Spaces:
Sleeping
Sleeping
Commit ·
58e2d3b
1
Parent(s): 9d9f0fa
damn
Browse files- __pycache__/feature_extract.cpython-313.pyc +0 -0
- __pycache__/inference_demo.cpython-313.pyc +0 -0
- app.py +35 -26
- demo_models.pkl +2 -2
- inference_demo.py +48 -30
- training_model.py +54 -80
__pycache__/feature_extract.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/feature_extract.cpython-313.pyc and b/__pycache__/feature_extract.cpython-313.pyc differ
|
|
|
__pycache__/inference_demo.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/inference_demo.cpython-313.pyc and b/__pycache__/inference_demo.cpython-313.pyc differ
|
|
|
app.py
CHANGED
|
@@ -3,8 +3,10 @@ from llm_classification import get_answer
|
|
| 3 |
from inference_demo import (
|
| 4 |
predict_randomforest_2f, predict_xgboost_2f, predict_lightgbm_2f,
|
| 5 |
predict_svm_2f, predict_decisiontree_2f, predict_naivebayes_2f,
|
|
|
|
| 6 |
predict_randomforest_6f, predict_xgboost_6f, predict_lightgbm_6f,
|
| 7 |
predict_svm_6f, predict_decisiontree_6f, predict_naivebayes_6f,
|
|
|
|
| 8 |
)
|
| 9 |
|
| 10 |
PREDICT_FUNCS = {
|
|
@@ -14,12 +16,15 @@ PREDICT_FUNCS = {
|
|
| 14 |
("SVM", "2-feature"): predict_svm_2f,
|
| 15 |
("Decision Tree", "2-feature"): predict_decisiontree_2f,
|
| 16 |
("Naive Bayes", "2-feature"): predict_naivebayes_2f,
|
|
|
|
|
|
|
| 17 |
("Random Forest", "6-feature"): predict_randomforest_6f,
|
| 18 |
("XGBoost", "6-feature"): predict_xgboost_6f,
|
| 19 |
("LightGBM", "6-feature"): predict_lightgbm_6f,
|
| 20 |
("SVM", "6-feature"): predict_svm_6f,
|
| 21 |
("Decision Tree", "6-feature"): predict_decisiontree_6f,
|
| 22 |
("Naive Bayes", "6-feature"): predict_naivebayes_6f,
|
|
|
|
| 23 |
}
|
| 24 |
|
| 25 |
CLASSIFIERS = [
|
|
@@ -30,6 +35,7 @@ CLASSIFIERS = [
|
|
| 30 |
"📈 SVM",
|
| 31 |
"🌲 Decision Tree",
|
| 32 |
"📊 Naive Bayes",
|
|
|
|
| 33 |
"🤝 Ensemble"
|
| 34 |
]
|
| 35 |
FEATURE_VERSIONS = ["2-feature", "6-feature"]
|
|
@@ -63,56 +69,59 @@ def explain_features(version: str) -> str:
|
|
| 63 |
def infer(clf: str, version: str, text: str):
|
| 64 |
if not text.strip():
|
| 65 |
return {"⚠️ Please enter a sentence": 1.0}, ""
|
|
|
|
| 66 |
if clf == "🔮 Gemini":
|
| 67 |
y = get_answer(text)
|
| 68 |
-
if y == 1:
|
| 69 |
-
|
| 70 |
-
else:
|
| 71 |
-
label = {"Negative 😞": 1.0}
|
| 72 |
-
return label, ""
|
| 73 |
if clf == "🤝 Ensemble":
|
| 74 |
-
model_names = ["Random Forest", "XGBoost", "LightGBM", "SVM", "Decision Tree", "Naive Bayes"]
|
| 75 |
-
votes_detail = []
|
| 76 |
-
votes = []
|
| 77 |
for m in model_names:
|
| 78 |
func = PREDICT_FUNCS.get((m, version))
|
| 79 |
if func:
|
| 80 |
y = func(text)
|
| 81 |
votes.append(y)
|
| 82 |
votes_detail.append(f"- **{m}**: {'Positive 😀' if y == 1 else 'Negative 😞'}")
|
| 83 |
-
if
|
| 84 |
return {"No models available": 1.0}, ""
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 91 |
label = {"Positive 😀": 1.0}
|
| 92 |
final = "### Final Ensemble Result: **Positive 😀**"
|
| 93 |
-
elif
|
| 94 |
label = {"Negative 😞": 1.0}
|
| 95 |
final = "### Final Ensemble Result: **Negative 😞**"
|
| 96 |
else:
|
| 97 |
label = {"Tie 🤔": 1.0}
|
| 98 |
final = "### Final Ensemble Result: **Tie 🤔**"
|
| 99 |
-
|
| 100 |
detail_md = (
|
| 101 |
f"{final}\n\n"
|
| 102 |
-
f"**Votes:** {
|
| 103 |
-
f"
|
| 104 |
-
f"**Individual model decisions:**\n{detail_text}"
|
| 105 |
)
|
| 106 |
return label, detail_md
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
if func is None:
|
| 109 |
return {"Model not found": 1.0}, ""
|
| 110 |
y = func(text)
|
| 111 |
-
if y == 1:
|
| 112 |
-
label = {"Positive 😀": 1.0}
|
| 113 |
-
else:
|
| 114 |
-
label = {"Negative 😞": 1.0}
|
| 115 |
-
return label, ""
|
| 116 |
|
| 117 |
with gr.Blocks(
|
| 118 |
title="Sentiment Classifier Demo",
|
|
|
|
| 3 |
from inference_demo import (
|
| 4 |
predict_randomforest_2f, predict_xgboost_2f, predict_lightgbm_2f,
|
| 5 |
predict_svm_2f, predict_decisiontree_2f, predict_naivebayes_2f,
|
| 6 |
+
predict_logisticregression_2f,
|
| 7 |
predict_randomforest_6f, predict_xgboost_6f, predict_lightgbm_6f,
|
| 8 |
predict_svm_6f, predict_decisiontree_6f, predict_naivebayes_6f,
|
| 9 |
+
predict_logisticregression_6f,
|
| 10 |
)
|
| 11 |
|
| 12 |
PREDICT_FUNCS = {
|
|
|
|
| 16 |
("SVM", "2-feature"): predict_svm_2f,
|
| 17 |
("Decision Tree", "2-feature"): predict_decisiontree_2f,
|
| 18 |
("Naive Bayes", "2-feature"): predict_naivebayes_2f,
|
| 19 |
+
("Logistic Regression", "2-feature"): predict_logisticregression_2f,
|
| 20 |
+
|
| 21 |
("Random Forest", "6-feature"): predict_randomforest_6f,
|
| 22 |
("XGBoost", "6-feature"): predict_xgboost_6f,
|
| 23 |
("LightGBM", "6-feature"): predict_lightgbm_6f,
|
| 24 |
("SVM", "6-feature"): predict_svm_6f,
|
| 25 |
("Decision Tree", "6-feature"): predict_decisiontree_6f,
|
| 26 |
("Naive Bayes", "6-feature"): predict_naivebayes_6f,
|
| 27 |
+
("Logistic Regression", "6-feature"): predict_logisticregression_6f,
|
| 28 |
}
|
| 29 |
|
| 30 |
CLASSIFIERS = [
|
|
|
|
| 35 |
"📈 SVM",
|
| 36 |
"🌲 Decision Tree",
|
| 37 |
"📊 Naive Bayes",
|
| 38 |
+
"🧮 Logistic Regression",
|
| 39 |
"🤝 Ensemble"
|
| 40 |
]
|
| 41 |
FEATURE_VERSIONS = ["2-feature", "6-feature"]
|
|
|
|
| 69 |
def infer(clf: str, version: str, text: str):
|
| 70 |
if not text.strip():
|
| 71 |
return {"⚠️ Please enter a sentence": 1.0}, ""
|
| 72 |
+
|
| 73 |
if clf == "🔮 Gemini":
|
| 74 |
y = get_answer(text)
|
| 75 |
+
return ({"Positive 😀": 1.0} if y == 1 else {"Negative 😞": 1.0}), ""
|
| 76 |
+
|
|
|
|
|
|
|
|
|
|
| 77 |
if clf == "🤝 Ensemble":
|
| 78 |
+
model_names = ["Random Forest", "XGBoost", "LightGBM", "SVM", "Decision Tree", "Naive Bayes", "Logistic Regression"]
|
| 79 |
+
votes_detail, votes = [], []
|
|
|
|
| 80 |
for m in model_names:
|
| 81 |
func = PREDICT_FUNCS.get((m, version))
|
| 82 |
if func:
|
| 83 |
y = func(text)
|
| 84 |
votes.append(y)
|
| 85 |
votes_detail.append(f"- **{m}**: {'Positive 😀' if y == 1 else 'Negative 😞'}")
|
| 86 |
+
if not votes:
|
| 87 |
return {"No models available": 1.0}, ""
|
| 88 |
+
|
| 89 |
+
pos, total = sum(votes), len(votes)
|
| 90 |
+
neg = total - pos
|
| 91 |
+
pos_pct = 100 * pos / total
|
| 92 |
+
neg_pct = 100 * neg / total
|
| 93 |
+
|
| 94 |
+
if pos > neg:
|
| 95 |
label = {"Positive 😀": 1.0}
|
| 96 |
final = "### Final Ensemble Result: **Positive 😀**"
|
| 97 |
+
elif neg > pos:
|
| 98 |
label = {"Negative 😞": 1.0}
|
| 99 |
final = "### Final Ensemble Result: **Negative 😞**"
|
| 100 |
else:
|
| 101 |
label = {"Tie 🤔": 1.0}
|
| 102 |
final = "### Final Ensemble Result: **Tie 🤔**"
|
| 103 |
+
|
| 104 |
detail_md = (
|
| 105 |
f"{final}\n\n"
|
| 106 |
+
f"**Votes:** {pos} positive ({pos_pct:.1f}%) | {neg} negative ({neg_pct:.1f}%) out of {total} models.\n\n"
|
| 107 |
+
f"**Individual model decisions:**\n" + "\n".join(votes_detail)
|
|
|
|
| 108 |
)
|
| 109 |
return label, detail_md
|
| 110 |
+
|
| 111 |
+
base_name = (
|
| 112 |
+
clf.replace("🌳 ","")
|
| 113 |
+
.replace("⚡ ","")
|
| 114 |
+
.replace("💡 ","")
|
| 115 |
+
.replace("📈 ","")
|
| 116 |
+
.replace("🌲 ","")
|
| 117 |
+
.replace("📊 ","")
|
| 118 |
+
.replace("🧮 ","")
|
| 119 |
+
)
|
| 120 |
+
func = PREDICT_FUNCS.get((base_name, version))
|
| 121 |
if func is None:
|
| 122 |
return {"Model not found": 1.0}, ""
|
| 123 |
y = func(text)
|
| 124 |
+
return ({"Positive 😀": 1.0} if y == 1 else {"Negative 😞": 1.0}), ""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
with gr.Blocks(
|
| 127 |
title="Sentiment Classifier Demo",
|
demo_models.pkl
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf799bce43df9e171189f86bee098ae0c9b4bb56be43a485f1146a319e78bc5a
|
| 3 |
+
size 4827607
|
inference_demo.py
CHANGED
|
@@ -1,46 +1,64 @@
|
|
| 1 |
import pickle
|
| 2 |
import numpy as np
|
| 3 |
-
from feature_extract import extract_features_2, extract_features_6
|
| 4 |
|
| 5 |
-
# ---- Load models + freqs ----
|
| 6 |
with open("demo_models.pkl", "rb") as f:
|
| 7 |
data = pickle.load(f)
|
| 8 |
|
| 9 |
freqs = data["freqs"]
|
| 10 |
-
models_2f = data
|
| 11 |
-
models_6f = data
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def _predict_2f(sentence: str, model_name: str) -> int:
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
return int(models_2f[model_name].predict(x)[0])
|
| 18 |
|
|
|
|
| 19 |
def _predict_6f(sentence: str, model_name: str) -> int:
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
return int(models_6f[model_name].predict(x)[0])
|
| 23 |
|
| 24 |
# 2-feature
|
| 25 |
-
def predict_randomforest_2f(sentence):
|
| 26 |
-
def predict_xgboost_2f(sentence):
|
| 27 |
-
def predict_lightgbm_2f(sentence):
|
| 28 |
-
def predict_svm_2f(sentence):
|
| 29 |
-
def predict_decisiontree_2f(sentence):
|
| 30 |
-
def predict_naivebayes_2f(sentence):
|
|
|
|
| 31 |
|
| 32 |
# 6-feature
|
| 33 |
-
def predict_randomforest_6f(sentence):
|
| 34 |
-
def predict_xgboost_6f(sentence):
|
| 35 |
-
def predict_lightgbm_6f(sentence):
|
| 36 |
-
def predict_svm_6f(sentence):
|
| 37 |
-
def predict_decisiontree_6f(sentence):
|
| 38 |
-
def predict_naivebayes_6f(sentence):
|
| 39 |
-
|
| 40 |
-
|
| 41 |
if __name__ == "__main__":
|
| 42 |
-
|
| 43 |
-
print("
|
| 44 |
-
print("
|
| 45 |
-
print("SVM 2f:", predict_svm_2f(
|
| 46 |
-
print("SVM 6f:", predict_svm_6f(
|
|
|
|
|
|
|
|
|
| 1 |
import pickle
|
| 2 |
import numpy as np
|
| 3 |
+
from feature_extract import extract_features_2, extract_features_6
|
| 4 |
|
|
|
|
| 5 |
with open("demo_models.pkl", "rb") as f:
|
| 6 |
data = pickle.load(f)
|
| 7 |
|
| 8 |
freqs = data["freqs"]
|
| 9 |
+
models_2f = data.get("2f", {})
|
| 10 |
+
models_6f = data.get("6f", {})
|
| 11 |
|
| 12 |
+
def _predict_with_model(sentence: str, model) -> int:
|
| 13 |
+
n = getattr(model, "n_features_in_", None)
|
| 14 |
+
if n == 6:
|
| 15 |
+
x = extract_features_6(sentence, freqs)
|
| 16 |
+
else:
|
| 17 |
+
x = extract_features_2(sentence, freqs)
|
| 18 |
+
return int(model.predict(x)[0])
|
| 19 |
+
|
| 20 |
+
def _smart_pick_model(model_name: str, prefer: str = "2f"):
|
| 21 |
+
if prefer == "6f":
|
| 22 |
+
model = models_6f.get(model_name) or models_2f.get(model_name)
|
| 23 |
+
else:
|
| 24 |
+
model = models_2f.get(model_name) or models_6f.get(model_name)
|
| 25 |
+
if model is None:
|
| 26 |
+
raise KeyError(f"Model '{model_name}' not found in saved models.")
|
| 27 |
+
return model
|
| 28 |
+
|
| 29 |
+
# --- 2-feature API (sẽ tự xử lý nếu model thực tế là 6f) ---
|
| 30 |
def _predict_2f(sentence: str, model_name: str) -> int:
|
| 31 |
+
model = _smart_pick_model(model_name, prefer="2f")
|
| 32 |
+
return _predict_with_model(sentence, model)
|
|
|
|
| 33 |
|
| 34 |
+
# --- 6-feature API (sẽ tự xử lý nếu model thực tế là 2f) ---
|
| 35 |
def _predict_6f(sentence: str, model_name: str) -> int:
|
| 36 |
+
model = _smart_pick_model(model_name, prefer="6f")
|
| 37 |
+
return _predict_with_model(sentence, model)
|
|
|
|
| 38 |
|
| 39 |
# 2-feature
|
| 40 |
+
def predict_randomforest_2f(sentence): return _predict_2f(sentence, "Random Forest")
|
| 41 |
+
def predict_xgboost_2f(sentence): return _predict_2f(sentence, "XGBoost")
|
| 42 |
+
def predict_lightgbm_2f(sentence): return _predict_2f(sentence, "LightGBM")
|
| 43 |
+
def predict_svm_2f(sentence): return _predict_2f(sentence, "SVM")
|
| 44 |
+
def predict_decisiontree_2f(sentence): return _predict_2f(sentence, "Decision Tree")
|
| 45 |
+
def predict_naivebayes_2f(sentence): return _predict_2f(sentence, "Naive Bayes")
|
| 46 |
+
def predict_logisticregression_2f(sentence): return _predict_2f(sentence, "Logistic Regression")
|
| 47 |
|
| 48 |
# 6-feature
|
| 49 |
+
def predict_randomforest_6f(sentence): return _predict_6f(sentence, "Random Forest")
|
| 50 |
+
def predict_xgboost_6f(sentence): return _predict_6f(sentence, "XGBoost")
|
| 51 |
+
def predict_lightgbm_6f(sentence): return _predict_6f(sentence, "LightGBM")
|
| 52 |
+
def predict_svm_6f(sentence): return _predict_6f(sentence, "SVM")
|
| 53 |
+
def predict_decisiontree_6f(sentence): return _predict_6f(sentence, "Decision Tree")
|
| 54 |
+
def predict_naivebayes_6f(sentence): return _predict_6f(sentence, "Naive Bayes")
|
| 55 |
+
def predict_logisticregression_6f(sentence): return _predict_6f(sentence, "Logistic Regression")
|
| 56 |
+
|
| 57 |
if __name__ == "__main__":
|
| 58 |
+
s = "I love this new phone!"
|
| 59 |
+
print("RF 2f:", predict_randomforest_2f(s))
|
| 60 |
+
print("RF 6f:", predict_randomforest_6f(s))
|
| 61 |
+
print("SVM 2f:", predict_svm_2f(s))
|
| 62 |
+
print("SVM 6f:", predict_svm_6f(s))
|
| 63 |
+
print("LogReg 2f:", predict_logisticregression_2f(s))
|
| 64 |
+
print("LogReg 6f:", predict_logisticregression_6f(s))
|
training_model.py
CHANGED
|
@@ -1,25 +1,21 @@
|
|
| 1 |
# file: train_demo_models.py
|
| 2 |
from __future__ import annotations
|
| 3 |
-
|
| 4 |
import pickle
|
| 5 |
import numpy as np
|
| 6 |
from typing import Dict, Tuple, List
|
| 7 |
-
|
| 8 |
import nltk
|
| 9 |
from nltk.corpus import twitter_samples, stopwords
|
| 10 |
-
|
| 11 |
from sklearn.ensemble import RandomForestClassifier
|
| 12 |
from xgboost import XGBClassifier
|
| 13 |
from lightgbm import LGBMClassifier
|
| 14 |
from sklearn.svm import SVC
|
| 15 |
from sklearn.tree import DecisionTreeClassifier
|
| 16 |
from sklearn.naive_bayes import GaussianNB
|
| 17 |
-
|
| 18 |
from sklearn.metrics import accuracy_score, log_loss
|
|
|
|
| 19 |
|
| 20 |
-
from feature_extract import build_freqs, extract_features_2, extract_features_6
|
| 21 |
-
|
| 22 |
-
# -------------------- NLTK setup --------------------
|
| 23 |
def _ensure_nltk():
|
| 24 |
try:
|
| 25 |
twitter_samples.fileids()
|
|
@@ -30,7 +26,6 @@ def _ensure_nltk():
|
|
| 30 |
except LookupError:
|
| 31 |
nltk.download("stopwords", quiet=True)
|
| 32 |
|
| 33 |
-
# -------------------- Data prep --------------------
|
| 34 |
def load_twitter_data() -> Tuple[List[str], np.ndarray]:
|
| 35 |
pos = twitter_samples.strings("positive_tweets.json")
|
| 36 |
neg = twitter_samples.strings("negative_tweets.json")
|
|
@@ -38,97 +33,76 @@ def load_twitter_data() -> Tuple[List[str], np.ndarray]:
|
|
| 38 |
y = np.array([1] * len(pos) + [0] * len(neg))
|
| 39 |
return tweets, y
|
| 40 |
|
| 41 |
-
def vectorize(tweets: List[str],
|
| 42 |
-
freqs: Dict[Tuple[str, float], float],
|
| 43 |
-
mode: str = "2f") -> np.ndarray:
|
| 44 |
-
"""mode: '2f' -> extract_features_2, '6f' -> extract_features_6"""
|
| 45 |
feat_fn = extract_features_2 if mode == "2f" else extract_features_6
|
| 46 |
rows = [feat_fn(t, freqs) for t in tweets]
|
| 47 |
return np.vstack(rows) if rows else np.zeros((0, 2 if mode == "2f" else 6))
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
acc
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
else:
|
| 82 |
-
print(f"[{name}] Accuracy: {acc:.4f} | (không có predict_proba để tính log_loss)")
|
| 83 |
-
print("=" * 60)
|
| 84 |
return trained
|
| 85 |
|
| 86 |
-
def
|
| 87 |
-
"""
|
| 88 |
-
Train và lưu mô hình + freqs ra file pickle.
|
| 89 |
-
Trả về:
|
| 90 |
-
{
|
| 91 |
-
'freqs': freqs,
|
| 92 |
-
'2f': {model_name: trained_model, ...},
|
| 93 |
-
'6f': {model_name: trained_model, ...}
|
| 94 |
-
}
|
| 95 |
-
"""
|
| 96 |
_ensure_nltk()
|
| 97 |
tweets, y = load_twitter_data()
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
X2 = vectorize(tweets, freqs, mode="2f")
|
| 102 |
X6 = vectorize(tweets, freqs, mode="6f")
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
data_to_save = {
|
| 111 |
-
"freqs": freqs,
|
| 112 |
-
"2f": models_2f,
|
| 113 |
-
"6f": models_6f,
|
| 114 |
-
}
|
| 115 |
-
|
| 116 |
-
# lưu file pickle
|
| 117 |
with open(save_path, "wb") as f:
|
| 118 |
pickle.dump(data_to_save, f)
|
| 119 |
-
|
| 120 |
-
print(f"\nĐã train và lưu mô hình + freqs vào file: {save_path}")
|
| 121 |
return data_to_save
|
| 122 |
|
| 123 |
-
# -------------------- Load --------------------
|
| 124 |
def load_demo_models(save_path: str = "demo_models.pkl"):
|
| 125 |
-
"""Load lại mô hình + freqs từ file pickle."""
|
| 126 |
with open(save_path, "rb") as f:
|
| 127 |
data = pickle.load(f)
|
| 128 |
return data
|
| 129 |
|
| 130 |
-
# -------------------- CLI --------------------
|
| 131 |
if __name__ == "__main__":
|
| 132 |
-
models =
|
| 133 |
print("Các mô hình 2f:", list(models["2f"].keys()))
|
| 134 |
print("Các mô hình 6f:", list(models["6f"].keys()))
|
|
|
|
| 1 |
# file: train_demo_models.py
|
| 2 |
from __future__ import annotations
|
| 3 |
+
import os
|
| 4 |
import pickle
|
| 5 |
import numpy as np
|
| 6 |
from typing import Dict, Tuple, List
|
|
|
|
| 7 |
import nltk
|
| 8 |
from nltk.corpus import twitter_samples, stopwords
|
|
|
|
| 9 |
from sklearn.ensemble import RandomForestClassifier
|
| 10 |
from xgboost import XGBClassifier
|
| 11 |
from lightgbm import LGBMClassifier
|
| 12 |
from sklearn.svm import SVC
|
| 13 |
from sklearn.tree import DecisionTreeClassifier
|
| 14 |
from sklearn.naive_bayes import GaussianNB
|
| 15 |
+
from sklearn.linear_model import LogisticRegression
|
| 16 |
from sklearn.metrics import accuracy_score, log_loss
|
| 17 |
+
from feature_extract import build_freqs, extract_features_2, extract_features_6
|
| 18 |
|
|
|
|
|
|
|
|
|
|
| 19 |
def _ensure_nltk():
|
| 20 |
try:
|
| 21 |
twitter_samples.fileids()
|
|
|
|
| 26 |
except LookupError:
|
| 27 |
nltk.download("stopwords", quiet=True)
|
| 28 |
|
|
|
|
| 29 |
def load_twitter_data() -> Tuple[List[str], np.ndarray]:
|
| 30 |
pos = twitter_samples.strings("positive_tweets.json")
|
| 31 |
neg = twitter_samples.strings("negative_tweets.json")
|
|
|
|
| 33 |
y = np.array([1] * len(pos) + [0] * len(neg))
|
| 34 |
return tweets, y
|
| 35 |
|
| 36 |
+
def vectorize(tweets: List[str], freqs: Dict[Tuple[str, float], float], mode: str = "2f") -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
| 37 |
feat_fn = extract_features_2 if mode == "2f" else extract_features_6
|
| 38 |
rows = [feat_fn(t, freqs) for t in tweets]
|
| 39 |
return np.vstack(rows) if rows else np.zeros((0, 2 if mode == "2f" else 6))
|
| 40 |
|
| 41 |
+
ALL_MODEL_SPECS: Dict[str, object] = {
|
| 42 |
+
"Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
|
| 43 |
+
"XGBoost": XGBClassifier(use_label_encoder=False, eval_metric="logloss"),
|
| 44 |
+
"LightGBM": LGBMClassifier(random_state=42),
|
| 45 |
+
"SVM": SVC(kernel="linear", probability=True, random_state=42),
|
| 46 |
+
"Decision Tree": DecisionTreeClassifier(random_state=42),
|
| 47 |
+
"Naive Bayes": GaussianNB(),
|
| 48 |
+
"Logistic Regression": LogisticRegression(solver="liblinear", random_state=42),
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
def make_models(include: List[str] | None = None) -> Dict[str, object]:
|
| 52 |
+
if include is None:
|
| 53 |
+
return {k: v for k, v in ALL_MODEL_SPECS.items()}
|
| 54 |
+
return {k: ALL_MODEL_SPECS[k] for k in include}
|
| 55 |
+
|
| 56 |
+
def _fit_and_log(name: str, clf, X: np.ndarray, y: np.ndarray):
|
| 57 |
+
clf.fit(X, y.ravel())
|
| 58 |
+
y_pred = clf.predict(X)
|
| 59 |
+
acc = accuracy_score(y, y_pred)
|
| 60 |
+
try:
|
| 61 |
+
y_proba = clf.predict_proba(X)
|
| 62 |
+
loss = log_loss(y, y_proba)
|
| 63 |
+
print(f"[{name}] Accuracy: {acc:.4f} | LogLoss: {loss:.4f}")
|
| 64 |
+
except Exception:
|
| 65 |
+
print(f"[{name}] Accuracy: {acc:.4f} | (no predict_proba)")
|
| 66 |
+
return clf
|
| 67 |
+
|
| 68 |
+
def train_models(X: np.ndarray, y: np.ndarray, include: List[str] | None = None) -> Dict[str, object]:
|
| 69 |
+
specs = make_models(include)
|
| 70 |
+
trained: Dict[str, object] = {}
|
| 71 |
+
for name, clf in specs.items():
|
| 72 |
+
trained[name] = _fit_and_log(name, clf, X, y)
|
|
|
|
|
|
|
|
|
|
| 73 |
return trained
|
| 74 |
|
| 75 |
+
def ensure_logreg_only(save_path: str = "demo_models.pkl"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
_ensure_nltk()
|
| 77 |
tweets, y = load_twitter_data()
|
| 78 |
+
if os.path.exists(save_path):
|
| 79 |
+
with open(save_path, "rb") as f:
|
| 80 |
+
data = pickle.load(f)
|
| 81 |
+
freqs = data.get("freqs")
|
| 82 |
+
models_2f: Dict[str, object] = data.get("2f", {})
|
| 83 |
+
models_6f: Dict[str, object] = data.get("6f", {})
|
| 84 |
+
else:
|
| 85 |
+
freqs = build_freqs(tweets, y.reshape(-1, 1))
|
| 86 |
+
models_2f, models_6f = {}, {}
|
| 87 |
X2 = vectorize(tweets, freqs, mode="2f")
|
| 88 |
X6 = vectorize(tweets, freqs, mode="6f")
|
| 89 |
+
if "Logistic Regression" not in models_2f:
|
| 90 |
+
new_models_2f = train_models(X2, y, include=["Logistic Regression"])
|
| 91 |
+
models_2f.update(new_models_2f)
|
| 92 |
+
if "Logistic Regression" not in models_6f:
|
| 93 |
+
new_models_6f = train_models(X6, y, include=["Logistic Regression"])
|
| 94 |
+
models_6f.update(new_models_6f)
|
| 95 |
+
data_to_save = {"freqs": freqs, "2f": models_2f, "6f": models_6f}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
with open(save_path, "wb") as f:
|
| 97 |
pickle.dump(data_to_save, f)
|
|
|
|
|
|
|
| 98 |
return data_to_save
|
| 99 |
|
|
|
|
| 100 |
def load_demo_models(save_path: str = "demo_models.pkl"):
|
|
|
|
| 101 |
with open(save_path, "rb") as f:
|
| 102 |
data = pickle.load(f)
|
| 103 |
return data
|
| 104 |
|
|
|
|
| 105 |
if __name__ == "__main__":
|
| 106 |
+
models = ensure_logreg_only()
|
| 107 |
print("Các mô hình 2f:", list(models["2f"].keys()))
|
| 108 |
print("Các mô hình 6f:", list(models["6f"].keys()))
|