LeonardoMdSA committed
Commit 62a3be1 · 1 Parent(s): bb37020

add working scripts

README.md CHANGED
@@ -22,10 +22,14 @@ pytest -v
 Or manual smoke test in test_backend.py


-### Train model
+### Train-evaluate model
+
+python scripts/seed_data.py

 python scripts/train_model.py

+python scripts/evaluate.py
+
 ## Initial structure

 Context-aware NLP classification platform with MCP/
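The three commands are meant to run in this order, since seed_data.py writes the train/eval splits before training and evaluation happen. A minimal sketch of a one-shot runner for the whole loop (the run_all.py name is hypothetical; it assumes the three scripts live under scripts/ and exit nonzero on failure):

# run_all.py - hypothetical helper; runs the seed/train/evaluate steps
# in the order the README lists them and stops on the first failure.
import subprocess
import sys

for script in ("scripts/seed_data.py", "scripts/train_model.py", "scripts/evaluate.py"):
    subprocess.run([sys.executable, script], check=True)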
data/samples/eval.json ADDED
@@ -0,0 +1,18 @@
+[
+  {
+    "text": "Invoice for Q1 2025 total amount $15,200",
+    "label": "finance.invoice"
+  },
+  {
+    "text": "HR policy update regarding employee leave",
+    "label": "hr.policy"
+  },
+  {
+    "text": "Contract agreement between Company A and Company B",
+    "label": "legal.contract"
+  },
+  {
+    "text": "Invoice for Q4 2025 total amount $12,000",
+    "label": "finance.invoice"
+  }
+]
data/samples/train.json ADDED
@@ -0,0 +1,26 @@
+[
+  {
+    "text": "Mandatory compliance training policy for all staff",
+    "label": "hr.policy"
+  },
+  {
+    "text": "Invoice for Q2 2025 total amount $8,450",
+    "label": "finance.invoice"
+  },
+  {
+    "text": "New guidelines for work-from-home policy",
+    "label": "hr.policy"
+  },
+  {
+    "text": "Invoice for Q3 2025 total amount $23,923",
+    "label": "finance.invoice"
+  },
+  {
+    "text": "Non-disclosure agreement for external partners",
+    "label": "legal.contract"
+  },
+  {
+    "text": "Service level agreement for client X",
+    "label": "legal.contract"
+  }
+]
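Both sample files share one record shape, which is also what the scripts below validate. A typed sketch of that schema (the Sample name is hypothetical; the repo itself stores plain dicts):

from typing import TypedDict

class Sample(TypedDict):
    text: str   # raw document text
    label: str  # dotted category, e.g. "finance.invoice", "hr.policy", "legal.contract"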
scripts/evaluate.py CHANGED
@@ -0,0 +1,97 @@
+#!/usr/bin/env python
+import argparse
+import json
+from pathlib import Path
+
+import joblib
+from sklearn.metrics import (
+    accuracy_score,
+    precision_recall_fscore_support,
+    classification_report,
+)
+
+BASE_DIR = Path(__file__).resolve().parent.parent
+MODELS_DIR = BASE_DIR / "models"
+DATA_DIR = BASE_DIR / "data"
+
+
+def load_model():
+    model_path = MODELS_DIR / "trained_pipeline.joblib"
+    if not model_path.exists():
+        raise FileNotFoundError(f"Model not found: {model_path}")
+    return joblib.load(model_path)
+
+
+def load_dataset(dataset_path: Path):
+    if not dataset_path.exists():
+        raise FileNotFoundError(f"Dataset not found: {dataset_path}")
+
+    # Hard guard: never evaluate on training data
+    if dataset_path.name in {"training_data.json", "train.json"}:
+        raise RuntimeError(
+            f"Refusing to evaluate on training dataset: {dataset_path.name}"
+        )
+
+    with dataset_path.open("r", encoding="utf-8") as f:
+        raw = json.load(f)
+
+    if isinstance(raw, list):
+        samples = raw
+    elif isinstance(raw, dict) and "samples" in raw:
+        samples = raw["samples"]
+    else:
+        raise ValueError("Unsupported JSON dataset format")
+
+    texts = []
+    labels = []
+
+    for i, item in enumerate(samples):
+        if "text" not in item or "label" not in item:
+            raise ValueError(f"Invalid sample at index {i}: {item}")
+        texts.append(item["text"])
+        labels.append(item["label"])
+
+    return texts, labels
+
+
+def evaluate(model, X, y):
+    y_pred = model.predict(X)
+
+    acc = accuracy_score(y, y_pred)
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        y, y_pred, average="weighted", zero_division=0
+    )
+
+    print("====================================")
+    print("Offline Evaluation Results")
+    print("====================================")
+    print(f"Samples  : {len(y)}")
+    print(f"Accuracy : {acc:.4f}")
+    print(f"Precision: {precision:.4f}")
+    print(f"Recall   : {recall:.4f}")
+    print(f"F1-score : {f1:.4f}")
+    print()
+    print("Detailed Classification Report")
+    print("------------------------------------")
+    print(classification_report(y, y_pred, zero_division=0))
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Offline evaluation using held-out JSON dataset"
+    )
+    parser.add_argument(
+        "--data",
+        default=str(DATA_DIR / "samples" / "eval.json"),
+        help="Path to evaluation dataset (default: data/samples/eval.json)",
+    )
+
+    args = parser.parse_args()
+
+    model = load_model()
+    X, y = load_dataset(Path(args.data))
+    evaluate(model, X, y)
+
+
+if __name__ == "__main__":
+    main()
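The script's helpers can also be reused programmatically when the metrics are needed outside the CLI. A minimal sketch, assuming it runs from the repo root and a pipeline was already trained into models/trained_pipeline.joblib (scripts/ is not a package, so it is added to the import path first):

import sys
from pathlib import Path

sys.path.insert(0, "scripts")  # make scripts/evaluate.py importable
import evaluate  # noqa: E402

model = evaluate.load_model()
X, y = evaluate.load_dataset(Path("data/samples/eval.json"))
evaluate.evaluate(model, X, y)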
scripts/seed_data.py CHANGED
@@ -0,0 +1,114 @@
+"""
+Seed and split dataset for training and evaluation.
+
+- Reads: data/samples/training_data.json
+- Writes:
+  - data/samples/train.json
+  - data/samples/eval.json
+
+This script enforces:
+- Stratified split by label
+- Deterministic output (fixed random seed)
+- Basic data validation
+"""
+
+import json
+import random
+from pathlib import Path
+from collections import defaultdict
+
+# -------------------------
+# Configuration
+# -------------------------
+RANDOM_SEED = 42
+TRAIN_RATIO = 0.7
+
+BASE_DIR = Path(__file__).resolve().parent.parent
+SAMPLES_DIR = BASE_DIR / "data" / "samples"
+
+SOURCE_FILE = SAMPLES_DIR / "training_data.json"
+TRAIN_FILE = SAMPLES_DIR / "train.json"
+EVAL_FILE = SAMPLES_DIR / "eval.json"
+
+
+def main():
+    if not SOURCE_FILE.exists():
+        raise FileNotFoundError(f"Source dataset not found: {SOURCE_FILE}")
+
+    with open(SOURCE_FILE, "r", encoding="utf-8") as f:
+        data = json.load(f)
+
+    if not isinstance(data, list) or len(data) == 0:
+        raise ValueError("Dataset must be a non-empty list")
+
+    # -------------------------
+    # Basic validation
+    # -------------------------
+    for i, item in enumerate(data):
+        if "text" not in item or "label" not in item:
+            raise ValueError(f"Invalid sample at index {i}: {item}")
+
+    # -------------------------
+    # Stratified split
+    # -------------------------
+    random.seed(RANDOM_SEED)
+
+    by_label = defaultdict(list)
+    for item in data:
+        by_label[item["label"]].append(item)
+
+    train_data = []
+    eval_data = []
+
+    for label, items in by_label.items():
+        random.shuffle(items)
+
+        split_idx = max(1, int(len(items) * TRAIN_RATIO))
+
+        train_data.extend(items[:split_idx])
+        eval_data.extend(items[split_idx:])
+
+    # Final shuffle (important)
+    random.shuffle(train_data)
+    random.shuffle(eval_data)
+
+    # -------------------------
+    # Write outputs
+    # -------------------------
+    SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
+
+    with open(TRAIN_FILE, "w", encoding="utf-8") as f:
+        json.dump(train_data, f, indent=2, ensure_ascii=False)
+
+    with open(EVAL_FILE, "w", encoding="utf-8") as f:
+        json.dump(eval_data, f, indent=2, ensure_ascii=False)
+
+    # -------------------------
+    # Summary
+    # -------------------------
+    print("====================================")
+    print("Dataset seeding completed")
+    print("====================================")
+    print(f"Total samples : {len(data)}")
+    print(f"Train samples : {len(train_data)}")
+    print(f"Eval samples  : {len(eval_data)}")
+    print()
+
+    print("Label distribution (train):")
+    _print_distribution(train_data)
+
+    print("\nLabel distribution (eval):")
+    _print_distribution(eval_data)
+
+
+def _print_distribution(dataset):
+    dist = defaultdict(int)
+    for item in dataset:
+        dist[item["label"]] += 1
+
+    for label, count in sorted(dist.items()):
+        print(f"  {label:<20} {count}")
+
+
+if __name__ == "__main__":
+    main()
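One consequence of the max(1, int(len(items) * TRAIN_RATIO)) split worth noting: every label keeps at least one training sample, but a label with only one sample gets no eval coverage at all. A quick arithmetic check of the per-label split sizes, using hypothetical counts:

# Per-label split sizes for TRAIN_RATIO = 0.7 (hypothetical counts).
TRAIN_RATIO = 0.7
for n in (1, 2, 3, 10):
    split_idx = max(1, int(n * TRAIN_RATIO))
    print(f"{n} samples -> {split_idx} train, {n - split_idx} eval")
# Prints: 1 -> 1 train / 0 eval, 2 -> 1/1, 3 -> 2/1, 10 -> 7/3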