NhatHuy1110 committed on
Commit
3f14ab3
·
verified ·
1 Parent(s): a1bbbd5

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +200 -0
train.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+ import os
3
+ import json
4
+ import joblib
5
+ import numpy as np
6
+ import pandas as pd
7
+ from pathlib import Path
8
+ from datasets import load_dataset
9
+ from sentence_transformers import SentenceTransformer
10
+ from sklearn.model_selection import train_test_split
11
+ from sklearn.metrics import accuracy_score, classification_report
12
+ from sklearn.neighbors import NearestNeighbors
13
+ from sklearn.feature_extraction.text import TfidfVectorizer
14
+ import lightgbm as lgb
15
+ import re
16
+ import warnings
17
+
18
# Silence noisy UserWarnings emitted by the ML libraries used below.
warnings.filterwarnings("ignore", category=UserWarning)

# All trained models and metadata files are written under ./artifacts.
ARTIFACTS = Path("artifacts")
ARTIFACTS.mkdir(parents=True, exist_ok=True)
22
+
23
+ # ------------------------------
24
+ # 1) Load & filter data
25
+ # ------------------------------
26
def clean_text(s: str) -> str:
    """Normalize an abstract: drop punctuation and digits, collapse whitespace, lowercase."""
    flattened = s.replace("\n", " ")
    # Strip everything that is not a word character or whitespace, then digits.
    letters_only = re.sub(r"\d+", " ", re.sub(r"[^\w\s]", " ", flattened))
    return re.sub(r"\s+", " ", letters_only).strip().lower()
32
+
33
def load_arxiv_subset(max_docs_per_class=600, seed=42):
    """Load a balanced subset of arXiv abstracts for five target categories.

    Streams the UniverseTBD/arxiv-abstracts-large train split and keeps up to
    `max_docs_per_class` unambiguous (single-category) abstracts per class.

    Args:
        max_docs_per_class: cap on collected samples for each category.
        seed: currently unused; kept for interface compatibility.

    Returns:
        pandas.DataFrame with columns: title, abstract, label, abstract_clean.

    Raises:
        ValueError: if no abstract-like column is found, or no rows matched.
    """
    ds = load_dataset("UniverseTBD/arxiv-abstracts-large", split="train")
    print("Available columns:", ds.column_names[:15])  # debug: inspect column names

    wanted = ["astro-ph", "cond-mat", "cs", "math", "physics"]

    # The abstract column name varies across dataset versions (e.g. 'abs', 'text').
    abstract_field = None
    for cand in ["abstract", "abs", "text", "summary", "content"]:
        if cand in ds.column_names:
            abstract_field = cand
            break
    if not abstract_field:
        raise ValueError("❌ Không tìm thấy cột chứa abstract trong dataset.")

    rows = []
    per_class_cnt = {k: 0 for k in wanted}
    for r in ds:
        labs = r.get("categories", []) or []
        # `categories` may arrive as a list or as a single string.
        if isinstance(labs, str):
            labs = [labs]

        # NOTE(review): exact-match filtering drops sub-categories such as
        # "astro-ph.GA" — confirm restricting to top-level labels is intended.
        labs = [c for c in labs if c in wanted]
        if len(labs) != 1:  # keep only unambiguous single-category papers
            continue
        lab = labs[0]

        if per_class_cnt[lab] >= max_docs_per_class:
            continue

        # BUG FIX: read from the *detected* abstract column instead of the
        # hard-coded "abstract" key, which may not exist in this dataset.
        abs_text = (r.get(abstract_field) or "").strip()
        if len(abs_text) < 40:  # skip degenerate / near-empty abstracts
            continue

        rows.append({
            "title": r.get("title", ""),
            "abstract": abs_text,
            "label": lab,
        })
        per_class_cnt[lab] += 1

        # Stop early once every class has reached its quota.
        if all(v >= max_docs_per_class for v in per_class_cnt.values()):
            break

    # Sanity check: at least one row must have matched.
    if not rows:
        raise ValueError("❌ Không lấy được mẫu nào! Kiểm tra giá trị trong cột 'categories' có trùng với wanted không.")

    df = pd.DataFrame(rows)
    print("✅ Sample rows:")
    print(df.head())

    df["abstract_clean"] = df["abstract"].apply(clean_text)
    print(f"✅ Loaded {len(df)} samples.")
    return df
89
+
90
+ # ------------------------------
91
+ # 2) Embedding model
92
+ # ------------------------------
93
EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
def encode_texts(model, texts, batch_size=64, normalize=True):
    """Embed *texts* with an E5-style sentence-transformer model.

    E5 models expect a "passage: " prefix on documents, so it is prepended to
    every input before encoding.

    Args:
        model: a SentenceTransformer-like object exposing ``encode``.
        texts: iterable of strings to embed.
        batch_size: encoding batch size.
        normalize: whether to L2-normalize the returned embeddings.

    Returns:
        numpy.ndarray of dtype float32, one row per input text.
    """
    prefixed = ["passage: " + t for t in texts]
    vectors = model.encode(
        prefixed,
        batch_size=batch_size,
        show_progress_bar=True,
        normalize_embeddings=normalize,
    )
    return np.array(vectors, dtype=np.float32)
103
+
104
+ # ------------------------------
105
+ # 3) Train & export
106
+ # ------------------------------
107
def main():
    """Train the arXiv abstract classifier and export all serving artifacts.

    Pipeline:
      1. Load a balanced arXiv subset and split train/test (stratified).
      2. Embed abstracts with a SentenceTransformer model.
      3. Train a LightGBM classifier on the embeddings; report accuracy.
      4. Build a cosine NearestNeighbors index over the train embeddings.
      5. Extract per-class TF-IDF keywords for explanations.
      6. Persist everything under ./artifacts.
    """
    print("Loading data ...")
    df = load_arxiv_subset(max_docs_per_class=600)  # ~3k samples total
    label_names = sorted(df["label"].unique())
    label2id = {lb: i for i, lb in enumerate(label_names)}
    y_full = df["label"].map(label2id).values
    X_full = df["abstract_clean"].values

    # Stratified split; raw title/abstract/label rows ride along as metadata.
    X_train_txt, X_test_txt, y_train, y_test, meta_train, meta_test = train_test_split(
        X_full, y_full, df[["title", "abstract", "label"]].values,
        test_size=0.2, stratify=y_full, random_state=42
    )

    print("Loading embedding model ...")
    emb_model = SentenceTransformer(EMB_MODEL_NAME)

    print("Encoding train/test ...")
    X_train = encode_texts(emb_model, list(X_train_txt))
    X_test = encode_texts(emb_model, list(X_test_txt))

    print("Training LightGBM ...")
    clf = lgb.LGBMClassifier(
        boosting_type="gbdt",  # goss/dart would also work
        n_estimators=800,
        learning_rate=0.05,
        max_depth=-1,
        subsample=0.9,
        colsample_bytree=0.9,
        random_state=42,
        n_jobs=-1,
    )
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)
    acc = accuracy_score(y_test, preds)
    print(f"Accuracy (embeddings + LGBM): {acc:.4f}")
    print(classification_report(y_test, preds, target_names=label_names))

    # --------------------------
    # Similarity index (cosine)
    # --------------------------
    print("Fitting NearestNeighbors index ...")
    nn = NearestNeighbors(n_neighbors=5, metric="cosine", n_jobs=-1)
    nn.fit(X_train)  # index over train embeddings only

    # --------------------------
    # Class keywords by TF-IDF
    # --------------------------
    print("Building class-wise TF-IDF keywords ...")
    tfidf = TfidfVectorizer(
        stop_words="english",
        max_df=0.9,
        min_df=3,
        max_features=3000,
    )
    tfidf.fit(X_train_txt)

    # Top words per class = terms with the highest mean TF-IDF inside the class.
    class_keywords = {}
    vocab = np.array(tfidf.get_feature_names_out())
    X_tfidf_train = tfidf.transform(X_train_txt)
    for lb, idx in label2id.items():
        mask = (y_train == idx)
        if mask.sum() == 0:
            class_keywords[lb] = []
            continue
        mean_scores = np.asarray(X_tfidf_train[mask].mean(axis=0)).ravel()
        top_idx = np.argsort(mean_scores)[-20:][::-1]
        class_keywords[lb] = vocab[top_idx].tolist()

    # --------------------------
    # Export artifacts
    # --------------------------
    print("Saving artifacts ...")
    joblib.dump(clf, ARTIFACTS / "lgbm_model.pkl")
    (ARTIFACTS / "emb_model_name.txt").write_text(EMB_MODEL_NAME)
    joblib.dump(nn, ARTIFACTS / "nn_index.pkl")
    joblib.dump(tfidf, ARTIFACTS / "tfidf_explainer.pkl")

    # BUG FIX: the original passed bare open(...) handles to json.dump and never
    # closed them (leaked file descriptors, platform-default encoding).  Use
    # context managers with an explicit UTF-8 encoding instead.
    with (ARTIFACTS / "label_names.json").open("w", encoding="utf-8") as f:
        json.dump(label_names, f)
    with (ARTIFACTS / "train_meta.json").open("w", encoding="utf-8") as f:
        json.dump(
            {
                "train_titles": [t for t, a, l in meta_train],
                "train_abstracts": [a for t, a, l in meta_train],
                "train_labels": [str(l) for t, a, l in meta_train],
            },
            f,
        )
    with (ARTIFACTS / "class_keywords.json").open("w", encoding="utf-8") as f:
        json.dump(class_keywords, f)

    (ARTIFACTS / "readme.txt").write_text(
        f"Accuracy: {acc:.4f}\nModel: LightGBM + {EMB_MODEL_NAME}\n"
    )
    print("Done.")
198
+
199
# Script entry point: run the full training + artifact-export pipeline.
if __name__ == "__main__":
    main()