# file: train_demo_models.py
from __future__ import annotations
import os
import pickle
import numpy as np
from typing import Dict, Tuple, List
import nltk
from nltk.corpus import twitter_samples, stopwords
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, log_loss
from feature_extract import build_freqs, extract_features_2, extract_features_6
def _ensure_nltk():
    """Download the NLTK corpora this module needs, if they are not present.

    Each corpus is probed with a cheap lookup; a ``LookupError`` means the
    data is missing locally, so we fetch it quietly.
    """
    probes = (
        (lambda: twitter_samples.fileids(), "twitter_samples"),
        (lambda: stopwords.words("english"), "stopwords"),
    )
    for probe, corpus_name in probes:
        try:
            probe()
        except LookupError:
            nltk.download(corpus_name, quiet=True)
def load_twitter_data() -> Tuple[List[str], np.ndarray]:
    """Return (tweets, labels) from the NLTK twitter_samples corpus.

    Positive tweets come first and are labelled 1; negative tweets follow
    and are labelled 0. Labels are an integer numpy array.
    """
    positives = twitter_samples.strings("positive_tweets.json")
    negatives = twitter_samples.strings("negative_tweets.json")
    labels = np.concatenate(
        (np.ones(len(positives), dtype=int), np.zeros(len(negatives), dtype=int))
    )
    return positives + negatives, labels
def vectorize(tweets: List[str], freqs: Dict[Tuple[str, float], float], mode: str = "2f") -> np.ndarray:
    """Stack per-tweet feature rows into a 2-D matrix.

    ``mode == "2f"`` uses the 2-feature extractor; any other mode uses the
    6-feature one. An empty tweet list yields a (0, width) array.
    """
    if mode == "2f":
        extractor, width = extract_features_2, 2
    else:
        extractor, width = extract_features_6, 6
    if not tweets:
        return np.zeros((0, width))
    return np.vstack([extractor(tweet, freqs) for tweet in tweets])
# Registry of demo classifiers, keyed by display name.
# NOTE(review): these are live estimator *instances*, not factories — anything
# that fits one of them mutates this shared registry; use copies when training.
# NOTE(review): `use_label_encoder` was deprecated and later removed in newer
# xgboost releases — confirm against the pinned xgboost version.
ALL_MODEL_SPECS: Dict[str, object] = {
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
    "XGBoost": XGBClassifier(use_label_encoder=False, eval_metric="logloss"),
    "LightGBM": LGBMClassifier(random_state=42),
    "SVM": SVC(kernel="linear", probability=True, random_state=42),
    "Decision Tree": DecisionTreeClassifier(random_state=42),
    "Naive Bayes": GaussianNB(),
    "Logistic Regression": LogisticRegression(solver="liblinear", random_state=42),
}
def make_models(include: List[str] | None = None) -> Dict[str, object]:
    """Return fresh, unfitted copies of the requested model specs.

    Args:
        include: model names to select; ``None`` selects every registered model.

    Returns:
        Dict mapping model name to an independent (deep-copied) estimator.

    Raises:
        KeyError: if ``include`` names a model not in ``ALL_MODEL_SPECS``.

    Bug fix: the previous version returned the *shared* instances held in
    ``ALL_MODEL_SPECS``, so fitting a returned model mutated the module-level
    registry and later calls silently received already-fitted estimators.
    Deep-copying gives every caller its own clean instance.
    """
    import copy  # stdlib; kept local to avoid touching the module import block

    names = list(ALL_MODEL_SPECS) if include is None else include
    return {name: copy.deepcopy(ALL_MODEL_SPECS[name]) for name in names}
def _fit_and_log(name: str, clf, X: np.ndarray, y: np.ndarray):
    """Fit ``clf`` on (X, y), print its training accuracy (and log-loss when
    the model supports probability estimates), and return the fitted model.
    """
    clf.fit(X, y.ravel())
    train_acc = accuracy_score(y, clf.predict(X))
    # Best-effort: not every estimator exposes predict_proba (or log_loss may
    # fail on its output); fall back to an accuracy-only line in that case.
    try:
        loss = log_loss(y, clf.predict_proba(X))
        print(f"[{name}] Accuracy: {train_acc:.4f} | LogLoss: {loss:.4f}")
    except Exception:
        print(f"[{name}] Accuracy: {train_acc:.4f} | (no predict_proba)")
    return clf
def train_models(X: np.ndarray, y: np.ndarray, include: List[str] | None = None) -> Dict[str, object]:
    """Fit every selected model on (X, y) and return name -> fitted model.

    ``include=None`` trains the full registry; otherwise only the named models.
    """
    return {
        name: _fit_and_log(name, clf, X, y)
        for name, clf in make_models(include).items()
    }
def ensure_logreg_only(save_path: str = "demo_models.pkl"):
    """Ensure a pickled bundle at ``save_path`` contains Logistic Regression
    models for both the 2-feature and 6-feature representations.

    Loads an existing bundle when present, trains only what is missing, and
    re-saves the bundle. Returns the dict ``{"freqs": ..., "2f": ..., "6f": ...}``.

    Fixes over the previous version:
    - A stale/partial pickle without a ``"freqs"`` entry previously produced
      ``freqs = None`` and crashed inside ``vectorize``; freqs are now rebuilt
      from the corpus when missing.
    - Feature matrices are now built lazily, only when a model actually needs
      training, instead of unconditionally on every call.
    """
    _ensure_nltk()
    tweets, y = load_twitter_data()

    freqs = None
    models_2f: Dict[str, object] = {}
    models_6f: Dict[str, object] = {}
    if os.path.exists(save_path):
        # NOTE: pickle.load is unsafe on untrusted files; this assumes the
        # bundle was produced by this module.
        with open(save_path, "rb") as f:
            data = pickle.load(f)
        freqs = data.get("freqs")
        models_2f = data.get("2f", {})
        models_6f = data.get("6f", {})
    if freqs is None:
        # Cache absent or malformed: rebuild the word-frequency table.
        freqs = build_freqs(tweets, y.reshape(-1, 1))

    if "Logistic Regression" not in models_2f:
        X2 = vectorize(tweets, freqs, mode="2f")
        models_2f.update(train_models(X2, y, include=["Logistic Regression"]))
    if "Logistic Regression" not in models_6f:
        X6 = vectorize(tweets, freqs, mode="6f")
        models_6f.update(train_models(X6, y, include=["Logistic Regression"]))

    data_to_save = {"freqs": freqs, "2f": models_2f, "6f": models_6f}
    with open(save_path, "wb") as f:
        pickle.dump(data_to_save, f)
    return data_to_save
def load_demo_models(save_path: str = "demo_models.pkl"):
    """Load and return the pickled demo-model bundle stored at ``save_path``.

    NOTE: pickle deserialization is unsafe on untrusted files — only load
    bundles this module itself wrote.
    """
    with open(save_path, "rb") as handle:
        return pickle.load(handle)
if __name__ == "__main__":
    # Script entry point: make sure the Logistic Regression demo models exist,
    # then report which models each bundle holds.
    trained_bundle = ensure_logreg_only()
    print("Các mô hình 2f:", list(trained_bundle["2f"].keys()))
    print("Các mô hình 6f:", list(trained_bundle["6f"].keys()))
|