Commit ·
62a3be1
1
Parent(s): bb37020
add working scripts
Browse files- README.md +5 -1
- data/samples/eval.json +18 -0
- data/samples/train.json +26 -0
- scripts/evaluate.py +97 -0
- scripts/seed_data.py +114 -0
README.md
CHANGED
|
@@ -22,10 +22,14 @@ pytest -v
|
|
| 22 |
Or manual smoke test in test_backend.py
|
| 23 |
|
| 24 |
|
| 25 |
-
### Train model
|
|
|
|
|
|
|
| 26 |
|
| 27 |
python scripts/train_model.py
|
| 28 |
|
|
|
|
|
|
|
| 29 |
## Initial structure
|
| 30 |
|
| 31 |
Context-aware NLP classification platform with MCP/
|
|
|
|
| 22 |
Or manual smoke test in test_backend.py
|
| 23 |
|
| 24 |
|
| 25 |
+
### Train-evaluate model
|
| 26 |
+
|
| 27 |
+
python scripts/seed_data.py
|
| 28 |
|
| 29 |
python scripts/train_model.py
|
| 30 |
|
| 31 |
+
python scripts/evaluate.py
|
| 32 |
+
|
| 33 |
## Initial structure
|
| 34 |
|
| 35 |
Context-aware NLP classification platform with MCP/
|
data/samples/eval.json
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"text": "Invoice for Q1 2025 total amount $15,200",
|
| 4 |
+
"label": "finance.invoice"
|
| 5 |
+
},
|
| 6 |
+
{
|
| 7 |
+
"text": "HR policy update regarding employee leave",
|
| 8 |
+
"label": "hr.policy"
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"text": "Contract agreement between Company A and Company B",
|
| 12 |
+
"label": "legal.contract"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"text": "Invoice for Q4 2025 total amount $12,000",
|
| 16 |
+
"label": "finance.invoice"
|
| 17 |
+
}
|
| 18 |
+
]
|
data/samples/train.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"text": "Mandatory compliance training policy for all staff",
|
| 4 |
+
"label": "hr.policy"
|
| 5 |
+
},
|
| 6 |
+
{
|
| 7 |
+
"text": "Invoice for Q2 2025 total amount $8,450",
|
| 8 |
+
"label": "finance.invoice"
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"text": "New guidelines for work-from-home policy",
|
| 12 |
+
"label": "hr.policy"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"text": "Invoice for Q3 2025 total amount $23,923",
|
| 16 |
+
"label": "finance.invoice"
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"text": "Non-disclosure agreement for external partners",
|
| 20 |
+
"label": "legal.contract"
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"text": "Service level agreement for client X",
|
| 24 |
+
"label": "legal.contract"
|
| 25 |
+
}
|
| 26 |
+
]
|
scripts/evaluate.py
CHANGED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import joblib
|
| 7 |
+
from sklearn.metrics import (
|
| 8 |
+
accuracy_score,
|
| 9 |
+
precision_recall_fscore_support,
|
| 10 |
+
classification_report
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
# Project-relative paths, resolved from this script's location so the
# script works regardless of the current working directory.
BASE_DIR = Path(__file__).resolve().parent.parent
MODELS_DIR = BASE_DIR / "models"  # trained model artifacts (joblib pipelines)
DATA_DIR = BASE_DIR / "data"  # datasets; eval default lives under data/samples/
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_model():
    """Load the trained classification pipeline from the models directory.

    Returns:
        The deserialized pipeline object.

    Raises:
        FileNotFoundError: if no trained model artifact exists yet
            (run scripts/train_model.py first).
    """
    artifact = MODELS_DIR / "trained_pipeline.joblib"
    if artifact.exists():
        return joblib.load(artifact)
    raise FileNotFoundError(f"Model not found: {artifact}")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_dataset(dataset_path: Path):
    """Load a JSON evaluation dataset and return ``(texts, labels)``.

    Accepts either a bare JSON list of samples or an object with a
    ``"samples"`` key. Each sample must be a dict containing ``"text"``
    and ``"label"``.

    Args:
        dataset_path: Path to the JSON dataset file.

    Returns:
        Tuple of (list of texts, list of labels), in file order.

    Raises:
        FileNotFoundError: if the file does not exist.
        RuntimeError: if the file is a known training dataset
            (evaluation on training data is forbidden).
        ValueError: on an unsupported format or a malformed sample.
    """
    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset not found: {dataset_path}")

    # Hard guard: never evaluate on training data
    if dataset_path.name in {"training_data.json", "train.json"}:
        raise RuntimeError(
            f"Refusing to evaluate on training dataset: {dataset_path.name}"
        )

    with dataset_path.open("r", encoding="utf-8") as f:
        raw = json.load(f)

    if isinstance(raw, list):
        samples = raw
    elif isinstance(raw, dict) and "samples" in raw:
        samples = raw["samples"]
    else:
        raise ValueError("Unsupported JSON dataset format")

    texts = []
    labels = []

    for i, item in enumerate(samples):
        # Check the element type explicitly: without this, a str element
        # turns `"text" not in item` into a substring test and an int
        # element raises TypeError instead of the intended ValueError.
        if not isinstance(item, dict) or "text" not in item or "label" not in item:
            raise ValueError(f"Invalid sample at index {i}: {item}")
        texts.append(item["text"])
        labels.append(item["label"])

    return texts, labels
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def evaluate(model, X, y):
    """Run the model on X and print aggregate metrics.

    Prints sample count, weighted accuracy/precision/recall/F1, and a
    per-class classification report against the gold labels y.
    """
    predictions = model.predict(X)

    accuracy = accuracy_score(y, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y, predictions, average="weighted", zero_division=0
    )

    banner = "===================================="
    print(banner)
    print("Offline Evaluation Results")
    print(banner)
    print(f"Samples : {len(y)}")
    print(f"Accuracy : {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall : {recall:.4f}")
    print(f"F1-score : {f1:.4f}")
    print()
    print("Detailed Classification Report")
    print("------------------------------------")
    print(classification_report(y, predictions, zero_division=0))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main():
    """CLI entry point: parse arguments, load model and dataset, evaluate."""
    parser = argparse.ArgumentParser(
        description="Offline evaluation using held-out JSON dataset"
    )
    parser.add_argument(
        "--data",
        default=str(DATA_DIR / "samples" / "eval.json"),
        help="Path to evaluation dataset (default: data/samples/eval.json)"
    )
    args = parser.parse_args()

    pipeline = load_model()
    texts, labels = load_dataset(Path(args.data))
    evaluate(pipeline, texts, labels)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Allow running as a standalone script: python scripts/evaluate.py
if __name__ == "__main__":
    main()
|
scripts/seed_data.py
CHANGED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Seed and split dataset for training and evaluation.
|
| 3 |
+
|
| 4 |
+
- Reads: data/samples/training_data.json
|
| 5 |
+
- Writes:
|
| 6 |
+
- data/samples/train.json
|
| 7 |
+
- data/samples/eval.json
|
| 8 |
+
|
| 9 |
+
This script enforces:
|
| 10 |
+
- Stratified split by label
|
| 11 |
+
- Deterministic output (fixed random seed)
|
| 12 |
+
- Basic data validation
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import random
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
|
| 20 |
+
# -------------------------
# Configuration
# -------------------------
RANDOM_SEED = 42  # fixed seed -> deterministic shuffle and split
TRAIN_RATIO = 0.7  # fraction of each label's samples that go to train

# Paths resolved relative to this script so cwd does not matter.
BASE_DIR = Path(__file__).resolve().parent.parent
SAMPLES_DIR = BASE_DIR / "data" / "samples"

SOURCE_FILE = SAMPLES_DIR / "training_data.json"  # input: full labeled dataset
TRAIN_FILE = SAMPLES_DIR / "train.json"  # output: stratified training split
EVAL_FILE = SAMPLES_DIR / "eval.json"  # output: held-out evaluation split
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main():
    """Seed deterministic train/eval splits from the source dataset.

    Reads SOURCE_FILE, performs a label-stratified split with a fixed
    random seed, writes TRAIN_FILE and EVAL_FILE, and prints a summary.

    Raises:
        FileNotFoundError: if SOURCE_FILE does not exist.
        ValueError: if the dataset is empty, not a list, or a sample is
            missing "text"/"label".
    """
    data = _load_source()
    train_data, eval_data = _split_stratified(data)
    _write_outputs(train_data, eval_data)
    _report(data, train_data, eval_data)


def _load_source():
    """Read and validate the source dataset; return the list of samples."""
    if not SOURCE_FILE.exists():
        raise FileNotFoundError(f"Source dataset not found: {SOURCE_FILE}")

    with open(SOURCE_FILE, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list) or len(data) == 0:
        raise ValueError("Dataset must be a non-empty list")

    for i, item in enumerate(data):
        if "text" not in item or "label" not in item:
            raise ValueError(f"Invalid sample at index {i}: {item}")

    return data


def _split_stratified(data):
    """Split samples per label into (train, eval) deterministically.

    Each label contributes at least one training sample (max(1, ...)),
    so a label with a single sample appears only in the training split.
    """
    random.seed(RANDOM_SEED)

    by_label = defaultdict(list)
    for item in data:
        by_label[item["label"]].append(item)

    train_data = []
    eval_data = []

    for label, items in by_label.items():
        random.shuffle(items)
        split_idx = max(1, int(len(items) * TRAIN_RATIO))
        train_data.extend(items[:split_idx])
        eval_data.extend(items[split_idx:])

    # Final shuffle (important): avoid outputs grouped by label.
    random.shuffle(train_data)
    random.shuffle(eval_data)

    return train_data, eval_data


def _write_outputs(train_data, eval_data):
    """Write both splits as pretty-printed UTF-8 JSON."""
    SAMPLES_DIR.mkdir(parents=True, exist_ok=True)

    with open(TRAIN_FILE, "w", encoding="utf-8") as f:
        json.dump(train_data, f, indent=2, ensure_ascii=False)

    with open(EVAL_FILE, "w", encoding="utf-8") as f:
        json.dump(eval_data, f, indent=2, ensure_ascii=False)


def _report(data, train_data, eval_data):
    """Print sample counts and per-label distributions for both splits."""
    print("====================================")
    print("Dataset seeding completed")
    print("====================================")
    print(f"Total samples : {len(data)}")
    print(f"Train samples : {len(train_data)}")
    print(f"Eval samples : {len(eval_data)}")
    print()

    print("Label distribution (train):")
    _print_distribution(train_data)

    print("\nLabel distribution (eval):")
    _print_distribution(eval_data)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _print_distribution(dataset):
|
| 105 |
+
dist = defaultdict(int)
|
| 106 |
+
for item in dataset:
|
| 107 |
+
dist[item["label"]] += 1
|
| 108 |
+
|
| 109 |
+
for label, count in sorted(dist.items()):
|
| 110 |
+
print(f" {label:<20} {count}")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Allow running as a standalone script: python scripts/seed_data.py
if __name__ == "__main__":
    main()
|