# NOTE: removed pasted notebook residue ("Spaces:" header and two "Runtime error"
# lines) — not part of the script.
# train.py
"""Train an arXiv-abstract classifier (E5 embeddings + LightGBM) and export artifacts."""
import json
import os  # NOTE(review): unused here — kept in case other tooling relies on it
import re
import warnings
from pathlib import Path

import joblib
import lightgbm as lgb
import numpy as np
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors

# Silence noisy UserWarnings emitted by the transformer/sklearn stack.
warnings.filterwarnings("ignore", category=UserWarning)

# All exported model files land here (created at import time).
ARTIFACTS = Path("artifacts")
ARTIFACTS.mkdir(parents=True, exist_ok=True)
# ------------------------------
# 1) Load & filter data
# ------------------------------
def clean_text(s: str) -> str:
    """Normalize an abstract: strip punctuation and digits, collapse whitespace, lowercase."""
    flattened = s.replace("\n", " ")
    # Punctuation first, then digits (\w matches digits, so the order matters).
    stripped = re.sub(r"[^\w\s]", " ", flattened)
    stripped = re.sub(r"\d+", " ", stripped)
    return re.sub(r"\s+", " ", stripped).strip().lower()
def load_arxiv_subset(max_docs_per_class=600, seed=42):
    """Load a class-balanced subset of arXiv abstracts for five target categories.

    Streams the dataset, keeps records labeled with exactly one wanted
    top-level category, and caps each class at ``max_docs_per_class``.

    Args:
        max_docs_per_class: cap on samples harvested per category.
        seed: currently unused — kept for interface stability.

    Returns:
        DataFrame with columns: title, abstract, label, abstract_clean.

    Raises:
        ValueError: when no abstract-like column exists or no rows match.
    """
    ds = load_dataset("UniverseTBD/arxiv-abstracts-large", split="train")
    print("Available columns:", ds.column_names[:15])  # debug: inspect column names
    wanted = ["astro-ph", "cond-mat", "cs", "math", "physics"]
    # The abstract column may have a different name (e.g. 'abs' or 'text').
    abstract_field = None
    for cand in ["abstract", "abs", "text", "summary", "content"]:
        if cand in ds.column_names:
            abstract_field = cand
            break
    if not abstract_field:
        raise ValueError("❌ Không tìm thấy cột chứa abstract trong dataset.")
    rows = []
    per_class_cnt = {k: 0 for k in wanted}
    for r in ds:
        labs = r.get("categories", []) or []
        # 'categories' may be a list or a single string; normalize to a list.
        # NOTE(review): if the dataset stores space-separated category strings
        # (e.g. "astro-ph hep-th"), they would need splitting here — confirm
        # against the actual schema before relying on multi-label records.
        if isinstance(labs, str):
            labs = [labs]
        labs = [c for c in labs if c in wanted]
        if len(labs) != 1:  # keep only unambiguous single-category docs
            continue
        lab = labs[0]
        if per_class_cnt[lab] >= max_docs_per_class:
            continue
        # BUGFIX: read from the detected column — the original hard-coded
        # r.get("abstract"), defeating the abstract_field detection above.
        abs_text = (r.get(abstract_field) or "").strip()
        if len(abs_text) < 40:  # drop empty or degenerate abstracts
            continue
        rows.append({
            "title": r.get("title", ""),
            "abstract": abs_text,
            "label": lab,
        })
        per_class_cnt[lab] += 1
        # Stop streaming early once every class is full.
        if all(v >= max_docs_per_class for v in per_class_cnt.values()):
            break
    # Sanity-check the harvest before building the frame.
    if not rows:
        raise ValueError("❌ Không lấy được mẫu nào! Kiểm tra giá trị trong cột 'categories' có trùng với wanted không.")
    df = pd.DataFrame(rows)
    print("✅ Sample rows:")
    print(df.head())
    df["abstract_clean"] = df["abstract"].apply(clean_text)
    print(f"✅ Loaded {len(df)} samples.")
    return df
# ------------------------------
# 2) Embedding model
# ------------------------------
# Multilingual E5 encoder; inputs are prefixed "passage: " (see encode_texts).
EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
def encode_texts(model, texts, batch_size=64, normalize=True):
    """Embed texts with an E5-style model and return a float32 numpy array.

    Each text gets the "passage: " prefix that the E5 family expects;
    embeddings are L2-normalized unless ``normalize`` is False.
    """
    prefixed = ["passage: " + t for t in texts]
    vectors = model.encode(
        prefixed,
        batch_size=batch_size,
        show_progress_bar=True,
        normalize_embeddings=normalize,
    )
    return np.array(vectors, dtype=np.float32)
# ------------------------------
# 3) Train & export
# ------------------------------
def main():
    """Train the embedding + LightGBM pipeline and export all serving artifacts."""
    print("Loading data ...")
    df = load_arxiv_subset(max_docs_per_class=600)  # ~3k samples total
    # Sorted labels -> stable, reproducible label-id mapping across runs.
    label_names = sorted(df["label"].unique())
    label2id = {lb: i for i, lb in enumerate(label_names)}
    y_full = df["label"].map(label2id).values
    X_full = df["abstract_clean"].values
    X_train_txt, X_test_txt, y_train, y_test, meta_train, meta_test = train_test_split(
        X_full, y_full, df[["title", "abstract", "label"]].values,
        test_size=0.2, stratify=y_full, random_state=42
    )
    print("Loading embedding model ...")
    emb_model = SentenceTransformer(EMB_MODEL_NAME)
    print("Encoding train/test ...")
    X_train = encode_texts(emb_model, list(X_train_txt))
    X_test = encode_texts(emb_model, list(X_test_txt))
    print("Training LightGBM ...")
    clf = lgb.LGBMClassifier(
        boosting_type="gbdt",  # "goss"/"dart" would also work
        n_estimators=800,
        learning_rate=0.05,
        max_depth=-1,
        subsample=0.9,
        colsample_bytree=0.9,
        random_state=42,
        n_jobs=-1,
    )
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)
    acc = accuracy_score(y_test, preds)
    print(f"Accuracy (embeddings + LGBM): {acc:.4f}")
    print(classification_report(y_test, preds, target_names=label_names))
    # --------------------------
    # Similarity index (cosine)
    # --------------------------
    print("Fitting NearestNeighbors index ...")
    nn = NearestNeighbors(n_neighbors=5, metric="cosine", n_jobs=-1)
    nn.fit(X_train)  # index over the training embeddings
    # --------------------------
    # Class keywords by TF-IDF
    # --------------------------
    print("Building class-wise TF-IDF keywords ...")
    tfidf = TfidfVectorizer(
        stop_words="english",
        max_df=0.9,
        min_df=3,
        max_features=3000,
    )
    tfidf.fit(X_train_txt)
    # Top words per class = terms with the highest mean TF-IDF within that class.
    class_keywords = {}
    vocab = np.array(tfidf.get_feature_names_out())
    X_tfidf_train = tfidf.transform(X_train_txt)
    for lb, idx in label2id.items():
        mask = (y_train == idx)
        if mask.sum() == 0:
            class_keywords[lb] = []
            continue
        mean_scores = np.asarray(X_tfidf_train[mask].mean(axis=0)).ravel()
        top_idx = np.argsort(mean_scores)[-20:][::-1]
        class_keywords[lb] = vocab[top_idx].tolist()
    # --------------------------
    # Export artifacts
    # --------------------------
    print("Saving artifacts ...")
    joblib.dump(clf, ARTIFACTS / "lgbm_model.pkl")
    (ARTIFACTS / "emb_model_name.txt").write_text(EMB_MODEL_NAME, encoding="utf-8")
    joblib.dump(nn, ARTIFACTS / "nn_index.pkl")
    joblib.dump(tfidf, ARTIFACTS / "tfidf_explainer.pkl")
    # BUGFIX: the original passed bare open(...) handles to json.dump and never
    # closed them (handle leak, writes not guaranteed flushed); use context
    # managers with an explicit UTF-8 encoding instead.
    with open(ARTIFACTS / "label_names.json", "w", encoding="utf-8") as f:
        json.dump(label_names, f, ensure_ascii=False)
    with open(ARTIFACTS / "train_meta.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                "train_titles": [t for t, a, l in meta_train],
                "train_abstracts": [a for t, a, l in meta_train],
                "train_labels": [str(l) for t, a, l in meta_train],
            },
            f,
            ensure_ascii=False,
        )
    with open(ARTIFACTS / "class_keywords.json", "w", encoding="utf-8") as f:
        json.dump(class_keywords, f, ensure_ascii=False)
    (ARTIFACTS / "readme.txt").write_text(
        f"Accuracy: {acc:.4f}\nModel: LightGBM + {EMB_MODEL_NAME}\n",
        encoding="utf-8",
    )
    print("Done.")


if __name__ == "__main__":
    main()