Spaces:
Sleeping
Sleeping
| # file: train_demo_models.py | |
| from __future__ import annotations | |
| import os | |
| import pickle | |
| import numpy as np | |
| from typing import Dict, Tuple, List | |
| import nltk | |
| from nltk.corpus import twitter_samples, stopwords | |
| from sklearn.ensemble import RandomForestClassifier | |
| from xgboost import XGBClassifier | |
| from lightgbm import LGBMClassifier | |
| from sklearn.svm import SVC | |
| from sklearn.tree import DecisionTreeClassifier | |
| from sklearn.naive_bayes import GaussianNB | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import accuracy_score, log_loss | |
| from feature_extract import build_freqs, extract_features_2, extract_features_6 | |
def _ensure_nltk():
    """Fetch the NLTK corpora this script depends on, if not already installed.

    Each corpus is probed first; NLTK raises LookupError when the data is
    missing, which triggers a quiet download.
    """
    # (probe callable, downloadable resource name) pairs.
    required = (
        (lambda: twitter_samples.fileids(), "twitter_samples"),
        (lambda: stopwords.words("english"), "stopwords"),
    )
    for probe, resource in required:
        try:
            probe()
        except LookupError:
            nltk.download(resource, quiet=True)
def load_twitter_data() -> Tuple[List[str], np.ndarray]:
    """Load the NLTK sample tweets.

    Returns:
        (tweets, labels) where positives come first and labels are
        1 for positive, 0 for negative.
    """
    positives = twitter_samples.strings("positive_tweets.json")
    negatives = twitter_samples.strings("negative_tweets.json")
    labels = np.concatenate([
        np.ones(len(positives), dtype=int),
        np.zeros(len(negatives), dtype=int),
    ])
    return positives + negatives, labels
def vectorize(tweets: List[str], freqs: Dict[Tuple[str, float], float], mode: str = "2f") -> np.ndarray:
    """Stack per-tweet feature rows into one 2-D array.

    mode == "2f" selects extract_features_2 (2 columns); any other mode
    selects extract_features_6 (6 columns). An empty tweet list yields an
    empty array with the matching column count.
    """
    if mode == "2f":
        extractor, width = extract_features_2, 2
    else:
        extractor, width = extract_features_6, 6
    rows = [extractor(tweet, freqs) for tweet in tweets]
    if not rows:
        return np.zeros((0, width))
    return np.vstack(rows)
# Registry of demo classifiers, keyed by display name.
# NOTE(review): these are module-level estimator instances used as specs;
# handing them out directly means fitting one mutates the registry entry.
# random_state is pinned to 42 wherever the estimator supports it.
ALL_MODEL_SPECS: Dict[str, object] = {
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
    # NOTE(review): use_label_encoder is deprecated/removed in recent XGBoost
    # releases — confirm the installed version still accepts it.
    "XGBoost": XGBClassifier(use_label_encoder=False, eval_metric="logloss"),
    "LightGBM": LGBMClassifier(random_state=42),
    # probability=True enables predict_proba on SVC (needed for log-loss).
    "SVM": SVC(kernel="linear", probability=True, random_state=42),
    "Decision Tree": DecisionTreeClassifier(random_state=42),
    "Naive Bayes": GaussianNB(),
    "Logistic Regression": LogisticRegression(solver="liblinear", random_state=42),
}
def make_models(include: List[str] | None = None) -> Dict[str, object]:
    """Return fresh, unfitted classifier instances.

    Args:
        include: names of models to build, or None for every model in
            ALL_MODEL_SPECS.

    Returns:
        Dict mapping model name -> a *new* estimator instance.

    Raises:
        KeyError: if a requested name is not in ALL_MODEL_SPECS.

    Fix: the original returned the shared singleton instances stored in
    ALL_MODEL_SPECS, so every training run re-fitted (mutated) the global
    registry objects. Deep-copying each spec makes every call independent.
    """
    from copy import deepcopy  # stdlib; local import keeps module imports unchanged

    names = ALL_MODEL_SPECS.keys() if include is None else include
    return {name: deepcopy(ALL_MODEL_SPECS[name]) for name in names}
def _fit_and_log(name: str, clf, X: np.ndarray, y: np.ndarray):
    """Fit *clf* on (X, y), print its training metrics, and return it.

    Accuracy is always reported; log-loss only when the estimator exposes
    predict_proba. Metrics are computed on the training data itself, so
    they measure fit, not generalization.

    Fix: the original wrapped the whole metrics section in a blanket
    ``except Exception``, which was meant to cover missing predict_proba
    but also silently swallowed genuine log_loss failures (e.g. malformed
    labels). Only the expected condition — no predict_proba — is handled now.
    """
    clf.fit(X, y.ravel())
    y_pred = clf.predict(X)
    acc = accuracy_score(y, y_pred)
    predict_proba = getattr(clf, "predict_proba", None)
    if predict_proba is None:
        print(f"[{name}] Accuracy: {acc:.4f} | (no predict_proba)")
    else:
        # Any error raised here is a real bug and should propagate.
        loss = log_loss(y, predict_proba(X))
        print(f"[{name}] Accuracy: {acc:.4f} | LogLoss: {loss:.4f}")
    return clf
def train_models(X: np.ndarray, y: np.ndarray, include: List[str] | None = None) -> Dict[str, object]:
    """Fit the selected models on (X, y).

    Args:
        X: feature matrix.
        y: label vector.
        include: model names to train, or None for all known models.

    Returns:
        Dict mapping model name -> fitted estimator.
    """
    return {
        name: _fit_and_log(name, estimator, X, y)
        for name, estimator in make_models(include).items()
    }
def ensure_logreg_only(save_path: str = "demo_models.pkl"):
    """Ensure a trained "Logistic Regression" model exists for both feature modes.

    Loads any previously saved state from *save_path* (word-frequency table
    plus trained models for the 2-feature and 6-feature extractors), trains
    only the missing "Logistic Regression" entries, persists everything back,
    and returns the ``{"freqs": ..., "2f": ..., "6f": ...}`` dict.

    SECURITY: ``pickle.load`` can execute arbitrary code — only load files
    that this script itself produced.

    Fix: a stale pickle without a "freqs" key left ``freqs`` as None and the
    original crashed later inside vectorize(); the frequency table is now
    rebuilt whenever it is missing.
    """
    _ensure_nltk()
    tweets, y = load_twitter_data()
    models_2f: Dict[str, object] = {}
    models_6f: Dict[str, object] = {}
    freqs = None
    if os.path.exists(save_path):
        with open(save_path, "rb") as f:
            data = pickle.load(f)
        freqs = data.get("freqs")
        models_2f = data.get("2f", {})
        models_6f = data.get("6f", {})
    if freqs is None:
        # Covers both "no saved file" and "saved file lacks freqs".
        freqs = build_freqs(tweets, y.reshape(-1, 1))
    X2 = vectorize(tweets, freqs, mode="2f")
    X6 = vectorize(tweets, freqs, mode="6f")
    if "Logistic Regression" not in models_2f:
        models_2f.update(train_models(X2, y, include=["Logistic Regression"]))
    if "Logistic Regression" not in models_6f:
        models_6f.update(train_models(X6, y, include=["Logistic Regression"]))
    data_to_save = {"freqs": freqs, "2f": models_2f, "6f": models_6f}
    with open(save_path, "wb") as f:
        pickle.dump(data_to_save, f)
    return data_to_save
def load_demo_models(save_path: str = "demo_models.pkl"):
    """Deserialize and return the saved demo-model bundle from *save_path*.

    SECURITY NOTE: pickle can execute arbitrary code on load — only open
    files produced by this script.
    """
    with open(save_path, "rb") as handle:
        return pickle.load(handle)
if __name__ == "__main__":
    # Train/load the demo bundle, then report which models exist per
    # feature mode (user-facing messages are in Vietnamese).
    bundle = ensure_logreg_only()
    print("Các mô hình 2f:", list(bundle["2f"].keys()))
    print("Các mô hình 6f:", list(bundle["6f"].keys()))