# ml_exercise/training_model.py — uploaded by trantuan1701 (commit 58e2d3b)
# file: train_demo_models.py
from __future__ import annotations
import os
import pickle
import numpy as np
from typing import Dict, Tuple, List
import nltk
from nltk.corpus import twitter_samples, stopwords
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, log_loss
from feature_extract import build_freqs, extract_features_2, extract_features_6
def _ensure_nltk():
    """Fetch the NLTK resources this module needs, if not already present."""
    # (probe, download_name) pairs — each probe raises LookupError when the
    # corresponding corpus is missing from the local NLTK data path.
    required = (
        (lambda: twitter_samples.fileids(), "twitter_samples"),
        (lambda: stopwords.words("english"), "stopwords"),
    )
    for probe, resource in required:
        try:
            probe()
        except LookupError:
            nltk.download(resource, quiet=True)
def load_twitter_data() -> Tuple[List[str], np.ndarray]:
    """Load the NLTK twitter corpus.

    Returns:
        (tweets, labels): positive tweets first (label 1), then negative
        tweets (label 0), with labels as a flat integer array.
    """
    positive = twitter_samples.strings("positive_tweets.json")
    negative = twitter_samples.strings("negative_tweets.json")
    labels = np.concatenate(
        [np.ones(len(positive), dtype=int), np.zeros(len(negative), dtype=int)]
    )
    return positive + negative, labels
def vectorize(tweets: List[str], freqs: Dict[Tuple[str, float], float], mode: str = "2f") -> np.ndarray:
    """Convert tweets to a feature matrix (one row per tweet).

    mode "2f" uses the 2-feature extractor; any other value uses the
    6-feature extractor.
    """
    if mode == "2f":
        extractor, width = extract_features_2, 2
    else:
        extractor, width = extract_features_6, 6
    if not tweets:
        # Empty input: keep the column count consistent with the mode.
        return np.zeros((0, width))
    return np.vstack([extractor(tweet, freqs) for tweet in tweets])
# Registry of every demo classifier, keyed by display name.
# NOTE(review): these are *shared estimator instances*, not factories —
# anything that fits one of them mutates the entry in place (see
# make_models / train_models).
# NOTE(review): `use_label_encoder` was removed from XGBClassifier in
# xgboost >= 2.0; recent versions only warn that it is unused — confirm
# the pinned xgboost version.
ALL_MODEL_SPECS: Dict[str, object] = {
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
    "XGBoost": XGBClassifier(use_label_encoder=False, eval_metric="logloss"),
    "LightGBM": LGBMClassifier(random_state=42),
    # probability=True enables predict_proba (slower fit; required for log-loss).
    "SVM": SVC(kernel="linear", probability=True, random_state=42),
    "Decision Tree": DecisionTreeClassifier(random_state=42),
    "Naive Bayes": GaussianNB(),
    "Logistic Regression": LogisticRegression(solver="liblinear", random_state=42),
}
def make_models(include: List[str] | None = None) -> Dict[str, object]:
    """Return fresh, unfitted estimators keyed by display name.

    Args:
        include: optional subset of names from ALL_MODEL_SPECS; None selects all.

    Returns:
        Dict mapping model name -> an *independent* estimator instance.

    Raises:
        KeyError: if `include` contains a name not in ALL_MODEL_SPECS.
    """
    import copy  # local import keeps the fix self-contained

    names = ALL_MODEL_SPECS.keys() if include is None else include
    # Bug fix: the original returned the shared template instances, so
    # fitting them mutated ALL_MODEL_SPECS and a later call handed back
    # already-fitted models. Deep-copy each template instead.
    return {name: copy.deepcopy(ALL_MODEL_SPECS[name]) for name in names}
def _fit_and_log(name: str, clf, X: np.ndarray, y: np.ndarray):
    """Fit `clf` on (X, y), print training accuracy — plus log-loss when the
    estimator exposes probabilities — and return the fitted estimator.

    Note: metrics are computed on the training data itself.
    """
    clf.fit(X, y.ravel())
    acc = accuracy_score(y, clf.predict(X))
    # Fix: the original wrapped this in a broad `except Exception`, which
    # silently relabeled *any* failure (including a genuine log_loss error)
    # as "(no predict_proba)". Test the capability explicitly instead.
    if hasattr(clf, "predict_proba"):
        loss = log_loss(y, clf.predict_proba(X))
        print(f"[{name}] Accuracy: {acc:.4f} | LogLoss: {loss:.4f}")
    else:
        print(f"[{name}] Accuracy: {acc:.4f} | (no predict_proba)")
    return clf
def train_models(X: np.ndarray, y: np.ndarray, include: List[str] | None = None) -> Dict[str, object]:
    """Fit the selected models on (X, y) and return them keyed by name.

    `include=None` trains every model in ALL_MODEL_SPECS.
    """
    return {
        name: _fit_and_log(name, estimator, X, y)
        for name, estimator in make_models(include).items()
    }
def ensure_logreg_only(save_path: str = "demo_models.pkl"):
    """Ensure `save_path` holds the freq table plus a trained Logistic
    Regression for both feature modes, training only what is missing.

    Returns:
        The bundle dict {"freqs": ..., "2f": {...}, "6f": {...}}.
    """
    _ensure_nltk()
    tweets, y = load_twitter_data()

    freqs = None
    models_2f: Dict[str, object] = {}
    models_6f: Dict[str, object] = {}
    needs_save = not os.path.exists(save_path)
    if not needs_save:
        # NOTE(review): pickle.load is unsafe on untrusted input; acceptable
        # here only because this file is produced locally by this module.
        with open(save_path, "rb") as f:
            data = pickle.load(f)
        freqs = data.get("freqs")
        models_2f = data.get("2f", {})
        models_6f = data.get("6f", {})

    if freqs is None:
        # Fix: a stale pickle without "freqs" previously left freqs = None,
        # which crashed inside vectorize(). Rebuild the table in that case.
        freqs = build_freqs(tweets, y.reshape(-1, 1))
        needs_save = True

    # Fix: vectorize lazily — the original built both feature matrices even
    # when no model needed training.
    if "Logistic Regression" not in models_2f:
        X2 = vectorize(tweets, freqs, mode="2f")
        models_2f.update(train_models(X2, y, include=["Logistic Regression"]))
        needs_save = True
    if "Logistic Regression" not in models_6f:
        X6 = vectorize(tweets, freqs, mode="6f")
        models_6f.update(train_models(X6, y, include=["Logistic Regression"]))
        needs_save = True

    data_to_save = {"freqs": freqs, "2f": models_2f, "6f": models_6f}
    # Fix: only rewrite the pickle when something actually changed.
    if needs_save:
        with open(save_path, "wb") as f:
            pickle.dump(data_to_save, f)
    return data_to_save
def load_demo_models(save_path: str = "demo_models.pkl"):
    """Unpickle and return the saved demo-model bundle from `save_path`."""
    with open(save_path, "rb") as handle:
        return pickle.load(handle)
if __name__ == "__main__":
    # Train (or load) the Logistic Regression demo models, then list what
    # the saved bundle contains for each feature mode.
    bundle = ensure_logreg_only()
    print("Các mô hình 2f:", list(bundle["2f"].keys()))
    print("Các mô hình 6f:", list(bundle["6f"].keys()))