davidtran999 committed
Commit 980fef7 · verified · 1 Parent(s): f5ba315

Upload backend/chatbot/training/train_intent.py with huggingface_hub

backend/chatbot/training/train_intent.py ADDED
@@ -0,0 +1,198 @@
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+ import sys
+ import time
+ from datetime import datetime
+
+ import joblib
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
+ from sklearn.model_selection import train_test_split
+ from sklearn.naive_bayes import MultinomialNB
+ from sklearn.pipeline import Pipeline
+
+
+ ROOT_DIR = Path(__file__).resolve().parents[2]
+ if str(ROOT_DIR) not in sys.path:
+     sys.path.insert(0, str(ROOT_DIR))
+
+
+ BASE_DIR = Path(__file__).resolve().parent
+ DEFAULT_DATASET = BASE_DIR / "intent_dataset.json"
+ GENERATED_QA_DIR = BASE_DIR / "generated_qa"
+ ARTIFACT_DIR = BASE_DIR / "artifacts"
+ LOG_DIR = ROOT_DIR / "logs" / "intent"
+ ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
+ LOG_DIR.mkdir(parents=True, exist_ok=True)
+
+
+ def load_dataset(path: Path):
+     payload = json.loads(path.read_text(encoding="utf-8"))
+     texts = []
+     labels = []
+     for intent in payload.get("intents", []):
+         name = intent["name"]
+         for example in intent.get("examples", []):
+             texts.append(example)
+             labels.append(name)
+     return texts, labels, payload
+
+
+ def load_generated_qa(directory: Path):
+     """
+     Load generated QA questions as additional intent training samples.
+
+     Each JSON file is expected to contain a list of objects compatible
+     with `QAItem` from `generated_qa`, at minimum having:
+     - question: str
+     - intent: str
+     """
+     texts: list[str] = []
+     labels: list[str] = []
+
+     if not directory.exists():
+         return texts, labels
+
+     for path in sorted(directory.glob("*.json")):
+         try:
+             payload = json.loads(path.read_text(encoding="utf-8"))
+         except Exception:
+             # Skip malformed files but continue loading others
+             continue
+         if not isinstance(payload, list):
+             continue
+         for item in payload:
+             if not isinstance(item, dict):
+                 continue
+             question = str(item.get("question") or "").strip()
+             intent = str(item.get("intent") or "").strip() or "search_legal"
+             if not question:
+                 continue
+             texts.append(question)
+             labels.append(intent)
+     return texts, labels
+
+
+ def load_combined_dataset(path: Path, generated_dir: Path):
+     """
+     Load the seed intent dataset and merge it with generated QA questions.
+     """
+     texts, labels, meta = load_dataset(path)
+     gen_texts, gen_labels = load_generated_qa(generated_dir)
+
+     texts.extend(gen_texts)
+     labels.extend(gen_labels)
+     return texts, labels, meta
+
+
+ def build_pipelines():
+     vectorizer = TfidfVectorizer(
+         analyzer="word",
+         ngram_range=(1, 2),
+         lowercase=True,
+         token_pattern=r"\b\w+\b",
+     )
+
+     nb_pipeline = Pipeline([
+         ("tfidf", vectorizer),
+         ("clf", MultinomialNB()),
+     ])
+
+     logreg_pipeline = Pipeline([
+         ("tfidf", vectorizer),
+         ("clf", LogisticRegression(max_iter=1000, solver="lbfgs")),
+     ])
+
+     return {
+         "multinomial_nb": nb_pipeline,
+         "logistic_regression": logreg_pipeline,
+     }
+
+
+ def train(dataset_path: Path, test_size: float = 0.2, random_state: int = 42):
+     texts, labels, meta = load_combined_dataset(dataset_path, GENERATED_QA_DIR)
+     if not texts:
+         raise ValueError("Dataset is empty, cannot train")
+
+     X_train, X_test, y_train, y_test = train_test_split(
+         texts, labels, test_size=test_size, random_state=random_state, stratify=labels
+     )
+
+     pipelines = build_pipelines()
+     best_model = None
+     best_metrics = None
+
+     for name, pipeline in pipelines.items():
+         start = time.perf_counter()
+         pipeline.fit(X_train, y_train)
+         train_duration = time.perf_counter() - start
+
+         y_pred = pipeline.predict(X_test)
+         acc = accuracy_score(y_test, y_pred)
+         report = classification_report(y_test, y_pred, output_dict=True)
+         cm = confusion_matrix(y_test, y_pred, labels=sorted(set(labels)))
+
+         metrics = {
+             "model": name,
+             "accuracy": acc,
+             "train_duration_sec": train_duration,
+             "classification_report": report,
+             "confusion_matrix": cm.tolist(),
+             "labels": sorted(set(labels)),
+             "dataset_version": meta.get("version"),
+             "timestamp": datetime.utcnow().isoformat() + "Z",
+             "test_size": test_size,
+             "samples": len(texts),
+         }
+
+         if best_model is None or acc > best_metrics["accuracy"]:
+             best_model = pipeline
+             best_metrics = metrics
+
+     assert best_model is not None
+
+     model_path = ARTIFACT_DIR / "intent_model.joblib"
+     metrics_path = ARTIFACT_DIR / "metrics.json"
+     joblib.dump(best_model, model_path)
+     metrics_path.write_text(json.dumps(best_metrics, ensure_ascii=False, indent=2), encoding="utf-8")
+
+     log_entry = {
+         "event": "train_intent",
+         "model": best_metrics["model"],
+         "accuracy": best_metrics["accuracy"],
+         "timestamp": best_metrics["timestamp"],
+         "samples": best_metrics["samples"],
+         "dataset_version": best_metrics["dataset_version"],
+         "artifact": str(model_path.relative_to(ROOT_DIR)),
+     }
+
+     log_file = LOG_DIR / "train.log"
+     with log_file.open("a", encoding="utf-8") as fh:
+         fh.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
+
+     return model_path, metrics_path, best_metrics
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Train the chatbot intent model")
+     parser.add_argument("--dataset", type=Path, default=DEFAULT_DATASET, help="Path to intent_dataset.json")
+     parser.add_argument("--test-size", type=float, default=0.2, help="Fraction of data held out for testing")
+     parser.add_argument("--seed", type=int, default=42, help="Random seed value")
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     model_path, metrics_path, metrics = train(args.dataset, test_size=args.test_size, random_state=args.seed)
+     print("Training complete:")
+     print(f" Model: {metrics['model']}")
+     print(f" Accuracy: {metrics['accuracy']:.4f}")
+     print(f" Model artifact: {model_path}")
+     print(f" Metrics: {metrics_path}")
+
+
+ if __name__ == "__main__":
+     main()
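
For context, load_dataset reads a seed file shaped like {"version": ..., "intents": [{"name": ..., "examples": [...]}]}, load_generated_qa expects each generated_qa/*.json file to hold a list of {"question": ..., "intent": ...} objects, and train persists the best scikit-learn pipeline with joblib. The snippet below is a minimal sketch of how the saved artifact could be reused for inference after running the script; the repository-root working directory, the placeholder query, and anything beyond the intent_model.joblib path are illustrative assumptions, not part of the uploaded file.

# Sketch: train, then reuse the persisted pipeline for intent prediction.
# Train first (writes artifacts/intent_model.joblib and artifacts/metrics.json):
#   python backend/chatbot/training/train_intent.py --test-size 0.2 --seed 42
from pathlib import Path

import joblib

ARTIFACT = Path("backend/chatbot/training/artifacts/intent_model.joblib")

# The artifact is a full scikit-learn Pipeline (TF-IDF vectorizer + classifier),
# so raw text can be passed straight to predict().
pipeline = joblib.load(ARTIFACT)
query = "placeholder user question"  # hypothetical input, not taken from the dataset
predicted_intent = pipeline.predict([query])[0]
print(predicted_intent)  # e.g. "search_legal"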