trantuan1701 commited on
Commit
58e2d3b
·
1 Parent(s): 9d9f0fa
__pycache__/feature_extract.cpython-313.pyc CHANGED
Binary files a/__pycache__/feature_extract.cpython-313.pyc and b/__pycache__/feature_extract.cpython-313.pyc differ
 
__pycache__/inference_demo.cpython-313.pyc CHANGED
Binary files a/__pycache__/inference_demo.cpython-313.pyc and b/__pycache__/inference_demo.cpython-313.pyc differ
 
app.py CHANGED
@@ -3,8 +3,10 @@ from llm_classification import get_answer
3
  from inference_demo import (
4
  predict_randomforest_2f, predict_xgboost_2f, predict_lightgbm_2f,
5
  predict_svm_2f, predict_decisiontree_2f, predict_naivebayes_2f,
 
6
  predict_randomforest_6f, predict_xgboost_6f, predict_lightgbm_6f,
7
  predict_svm_6f, predict_decisiontree_6f, predict_naivebayes_6f,
 
8
  )
9
 
10
  PREDICT_FUNCS = {
@@ -14,12 +16,15 @@ PREDICT_FUNCS = {
14
  ("SVM", "2-feature"): predict_svm_2f,
15
  ("Decision Tree", "2-feature"): predict_decisiontree_2f,
16
  ("Naive Bayes", "2-feature"): predict_naivebayes_2f,
 
 
17
  ("Random Forest", "6-feature"): predict_randomforest_6f,
18
  ("XGBoost", "6-feature"): predict_xgboost_6f,
19
  ("LightGBM", "6-feature"): predict_lightgbm_6f,
20
  ("SVM", "6-feature"): predict_svm_6f,
21
  ("Decision Tree", "6-feature"): predict_decisiontree_6f,
22
  ("Naive Bayes", "6-feature"): predict_naivebayes_6f,
 
23
  }
24
 
25
  CLASSIFIERS = [
@@ -30,6 +35,7 @@ CLASSIFIERS = [
30
  "📈 SVM",
31
  "🌲 Decision Tree",
32
  "📊 Naive Bayes",
 
33
  "🤝 Ensemble"
34
  ]
35
  FEATURE_VERSIONS = ["2-feature", "6-feature"]
@@ -63,56 +69,59 @@ def explain_features(version: str) -> str:
63
  def infer(clf: str, version: str, text: str):
64
  if not text.strip():
65
  return {"⚠️ Please enter a sentence": 1.0}, ""
 
66
  if clf == "🔮 Gemini":
67
  y = get_answer(text)
68
- if y == 1:
69
- label = {"Positive 😀": 1.0}
70
- else:
71
- label = {"Negative 😞": 1.0}
72
- return label, ""
73
  if clf == "🤝 Ensemble":
74
- model_names = ["Random Forest", "XGBoost", "LightGBM", "SVM", "Decision Tree", "Naive Bayes"]
75
- votes_detail = []
76
- votes = []
77
  for m in model_names:
78
  func = PREDICT_FUNCS.get((m, version))
79
  if func:
80
  y = func(text)
81
  votes.append(y)
82
  votes_detail.append(f"- **{m}**: {'Positive 😀' if y == 1 else 'Negative 😞'}")
83
- if len(votes) == 0:
84
  return {"No models available": 1.0}, ""
85
- positive_votes = sum(votes)
86
- negative_votes = len(votes) - positive_votes
87
- total = len(votes)
88
- positive_pct = 100 * positive_votes / total
89
- negative_pct = 100 * negative_votes / total
90
- if positive_votes > negative_votes:
 
91
  label = {"Positive 😀": 1.0}
92
  final = "### Final Ensemble Result: **Positive 😀**"
93
- elif negative_votes > positive_votes:
94
  label = {"Negative 😞": 1.0}
95
  final = "### Final Ensemble Result: **Negative 😞**"
96
  else:
97
  label = {"Tie 🤔": 1.0}
98
  final = "### Final Ensemble Result: **Tie 🤔**"
99
- detail_text = "\n".join(votes_detail)
100
  detail_md = (
101
  f"{final}\n\n"
102
- f"**Votes:** {positive_votes} positive ({positive_pct:.1f}%) | "
103
- f"{negative_votes} negative ({negative_pct:.1f}%) out of {total} models.\n\n"
104
- f"**Individual model decisions:**\n{detail_text}"
105
  )
106
  return label, detail_md
107
- func = PREDICT_FUNCS.get((clf.replace("🌳 ","").replace("⚡ ","").replace("💡 ","").replace("📈 ","").replace("🌲 ","").replace("📊 ",""), version))
 
 
 
 
 
 
 
 
 
 
108
  if func is None:
109
  return {"Model not found": 1.0}, ""
110
  y = func(text)
111
- if y == 1:
112
- label = {"Positive 😀": 1.0}
113
- else:
114
- label = {"Negative 😞": 1.0}
115
- return label, ""
116
 
117
  with gr.Blocks(
118
  title="Sentiment Classifier Demo",
 
3
  from inference_demo import (
4
  predict_randomforest_2f, predict_xgboost_2f, predict_lightgbm_2f,
5
  predict_svm_2f, predict_decisiontree_2f, predict_naivebayes_2f,
6
+ predict_logisticregression_2f,
7
  predict_randomforest_6f, predict_xgboost_6f, predict_lightgbm_6f,
8
  predict_svm_6f, predict_decisiontree_6f, predict_naivebayes_6f,
9
+ predict_logisticregression_6f,
10
  )
11
 
12
  PREDICT_FUNCS = {
 
16
  ("SVM", "2-feature"): predict_svm_2f,
17
  ("Decision Tree", "2-feature"): predict_decisiontree_2f,
18
  ("Naive Bayes", "2-feature"): predict_naivebayes_2f,
19
+ ("Logistic Regression", "2-feature"): predict_logisticregression_2f,
20
+
21
  ("Random Forest", "6-feature"): predict_randomforest_6f,
22
  ("XGBoost", "6-feature"): predict_xgboost_6f,
23
  ("LightGBM", "6-feature"): predict_lightgbm_6f,
24
  ("SVM", "6-feature"): predict_svm_6f,
25
  ("Decision Tree", "6-feature"): predict_decisiontree_6f,
26
  ("Naive Bayes", "6-feature"): predict_naivebayes_6f,
27
+ ("Logistic Regression", "6-feature"): predict_logisticregression_6f,
28
  }
29
 
30
  CLASSIFIERS = [
 
35
  "📈 SVM",
36
  "🌲 Decision Tree",
37
  "📊 Naive Bayes",
38
+ "🧮 Logistic Regression",
39
  "🤝 Ensemble"
40
  ]
41
  FEATURE_VERSIONS = ["2-feature", "6-feature"]
 
69
  def infer(clf: str, version: str, text: str):
70
  if not text.strip():
71
  return {"⚠️ Please enter a sentence": 1.0}, ""
72
+
73
  if clf == "🔮 Gemini":
74
  y = get_answer(text)
75
+ return ({"Positive 😀": 1.0} if y == 1 else {"Negative 😞": 1.0}), ""
76
+
 
 
 
77
  if clf == "🤝 Ensemble":
78
+ model_names = ["Random Forest", "XGBoost", "LightGBM", "SVM", "Decision Tree", "Naive Bayes", "Logistic Regression"]
79
+ votes_detail, votes = [], []
 
80
  for m in model_names:
81
  func = PREDICT_FUNCS.get((m, version))
82
  if func:
83
  y = func(text)
84
  votes.append(y)
85
  votes_detail.append(f"- **{m}**: {'Positive 😀' if y == 1 else 'Negative 😞'}")
86
+ if not votes:
87
  return {"No models available": 1.0}, ""
88
+
89
+ pos, total = sum(votes), len(votes)
90
+ neg = total - pos
91
+ pos_pct = 100 * pos / total
92
+ neg_pct = 100 * neg / total
93
+
94
+ if pos > neg:
95
  label = {"Positive 😀": 1.0}
96
  final = "### Final Ensemble Result: **Positive 😀**"
97
+ elif neg > pos:
98
  label = {"Negative 😞": 1.0}
99
  final = "### Final Ensemble Result: **Negative 😞**"
100
  else:
101
  label = {"Tie 🤔": 1.0}
102
  final = "### Final Ensemble Result: **Tie 🤔**"
103
+
104
  detail_md = (
105
  f"{final}\n\n"
106
+ f"**Votes:** {pos} positive ({pos_pct:.1f}%) | {neg} negative ({neg_pct:.1f}%) out of {total} models.\n\n"
107
+ f"**Individual model decisions:**\n" + "\n".join(votes_detail)
 
108
  )
109
  return label, detail_md
110
+
111
+ base_name = (
112
+ clf.replace("🌳 ","")
113
+ .replace("⚡ ","")
114
+ .replace("💡 ","")
115
+ .replace("📈 ","")
116
+ .replace("🌲 ","")
117
+ .replace("📊 ","")
118
+ .replace("🧮 ","")
119
+ )
120
+ func = PREDICT_FUNCS.get((base_name, version))
121
  if func is None:
122
  return {"Model not found": 1.0}, ""
123
  y = func(text)
124
+ return ({"Positive 😀": 1.0} if y == 1 else {"Negative 😞": 1.0}), ""
 
 
 
 
125
 
126
  with gr.Blocks(
127
  title="Sentiment Classifier Demo",
demo_models.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cf5cd9e9f927d6467888e9d249a99a086812f0c0a228a0b57407c2fe9eeb323d
3
- size 4826559
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf799bce43df9e171189f86bee098ae0c9b4bb56be43a485f1146a319e78bc5a
3
+ size 4827607
inference_demo.py CHANGED
@@ -1,46 +1,64 @@
1
  import pickle
2
  import numpy as np
3
- from feature_extract import extract_features_2, extract_features_6
4
 
5
- # ---- Load models + freqs ----
6
  with open("demo_models.pkl", "rb") as f:
7
  data = pickle.load(f)
8
 
9
  freqs = data["freqs"]
10
- models_2f = data["2f"]
11
- models_6f = data["6f"]
12
 
13
- # ---- Helper functions ----
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def _predict_2f(sentence: str, model_name: str) -> int:
15
- """Trích 2-feature predict 0/1."""
16
- x = extract_features_2(sentence, freqs)
17
- return int(models_2f[model_name].predict(x)[0])
18
 
 
19
  def _predict_6f(sentence: str, model_name: str) -> int:
20
- """Trích 6-feature predict 0/1."""
21
- x = extract_features_6(sentence, freqs)
22
- return int(models_6f[model_name].predict(x)[0])
23
 
24
  # 2-feature
25
- def predict_randomforest_2f(sentence): return _predict_2f(sentence, "Random Forest")
26
- def predict_xgboost_2f(sentence): return _predict_2f(sentence, "XGBoost")
27
- def predict_lightgbm_2f(sentence): return _predict_2f(sentence, "LightGBM")
28
- def predict_svm_2f(sentence): return _predict_2f(sentence, "SVM")
29
- def predict_decisiontree_2f(sentence): return _predict_2f(sentence, "Decision Tree")
30
- def predict_naivebayes_2f(sentence): return _predict_2f(sentence, "Naive Bayes")
 
31
 
32
  # 6-feature
33
- def predict_randomforest_6f(sentence): return _predict_6f(sentence, "Random Forest")
34
- def predict_xgboost_6f(sentence): return _predict_6f(sentence, "XGBoost")
35
- def predict_lightgbm_6f(sentence): return _predict_6f(sentence, "LightGBM")
36
- def predict_svm_6f(sentence): return _predict_6f(sentence, "SVM")
37
- def predict_decisiontree_6f(sentence): return _predict_6f(sentence, "Decision Tree")
38
- def predict_naivebayes_6f(sentence): return _predict_6f(sentence, "Naive Bayes")
39
-
40
- # ---- Test nhanh ----
41
  if __name__ == "__main__":
42
- test_sentence = "I love this new phone!"
43
- print("RandomForest 2f:", predict_randomforest_2f(test_sentence))
44
- print("RandomForest 6f:", predict_randomforest_6f(test_sentence))
45
- print("SVM 2f:", predict_svm_2f(test_sentence))
46
- print("SVM 6f:", predict_svm_6f(test_sentence))
 
 
 
1
  import pickle
2
  import numpy as np
3
+ from feature_extract import extract_features_2, extract_features_6
4
 
 
5
  with open("demo_models.pkl", "rb") as f:
6
  data = pickle.load(f)
7
 
8
  freqs = data["freqs"]
9
+ models_2f = data.get("2f", {})
10
+ models_6f = data.get("6f", {})
11
 
12
+ def _predict_with_model(sentence: str, model) -> int:
13
+ n = getattr(model, "n_features_in_", None)
14
+ if n == 6:
15
+ x = extract_features_6(sentence, freqs)
16
+ else:
17
+ x = extract_features_2(sentence, freqs)
18
+ return int(model.predict(x)[0])
19
+
20
+ def _smart_pick_model(model_name: str, prefer: str = "2f"):
21
+ if prefer == "6f":
22
+ model = models_6f.get(model_name) or models_2f.get(model_name)
23
+ else:
24
+ model = models_2f.get(model_name) or models_6f.get(model_name)
25
+ if model is None:
26
+ raise KeyError(f"Model '{model_name}' not found in saved models.")
27
+ return model
28
+
29
+ # --- 2-feature API (sẽ tự xử lý nếu model thực tế là 6f) ---
30
  def _predict_2f(sentence: str, model_name: str) -> int:
31
+ model = _smart_pick_model(model_name, prefer="2f")
32
+ return _predict_with_model(sentence, model)
 
33
 
34
+ # --- 6-feature API (sẽ tự xử lý nếu model thực tế là 2f) ---
35
  def _predict_6f(sentence: str, model_name: str) -> int:
36
+ model = _smart_pick_model(model_name, prefer="6f")
37
+ return _predict_with_model(sentence, model)
 
38
 
39
  # 2-feature
40
+ def predict_randomforest_2f(sentence): return _predict_2f(sentence, "Random Forest")
41
+ def predict_xgboost_2f(sentence): return _predict_2f(sentence, "XGBoost")
42
+ def predict_lightgbm_2f(sentence): return _predict_2f(sentence, "LightGBM")
43
+ def predict_svm_2f(sentence): return _predict_2f(sentence, "SVM")
44
+ def predict_decisiontree_2f(sentence): return _predict_2f(sentence, "Decision Tree")
45
+ def predict_naivebayes_2f(sentence): return _predict_2f(sentence, "Naive Bayes")
46
+ def predict_logisticregression_2f(sentence): return _predict_2f(sentence, "Logistic Regression")
47
 
48
  # 6-feature
49
+ def predict_randomforest_6f(sentence): return _predict_6f(sentence, "Random Forest")
50
+ def predict_xgboost_6f(sentence): return _predict_6f(sentence, "XGBoost")
51
+ def predict_lightgbm_6f(sentence): return _predict_6f(sentence, "LightGBM")
52
+ def predict_svm_6f(sentence): return _predict_6f(sentence, "SVM")
53
+ def predict_decisiontree_6f(sentence): return _predict_6f(sentence, "Decision Tree")
54
+ def predict_naivebayes_6f(sentence): return _predict_6f(sentence, "Naive Bayes")
55
+ def predict_logisticregression_6f(sentence): return _predict_6f(sentence, "Logistic Regression")
56
+
57
  if __name__ == "__main__":
58
+ s = "I love this new phone!"
59
+ print("RF 2f:", predict_randomforest_2f(s))
60
+ print("RF 6f:", predict_randomforest_6f(s))
61
+ print("SVM 2f:", predict_svm_2f(s))
62
+ print("SVM 6f:", predict_svm_6f(s))
63
+ print("LogReg 2f:", predict_logisticregression_2f(s))
64
+ print("LogReg 6f:", predict_logisticregression_6f(s))
training_model.py CHANGED
@@ -1,25 +1,21 @@
1
  # file: train_demo_models.py
2
  from __future__ import annotations
3
-
4
  import pickle
5
  import numpy as np
6
  from typing import Dict, Tuple, List
7
-
8
  import nltk
9
  from nltk.corpus import twitter_samples, stopwords
10
-
11
  from sklearn.ensemble import RandomForestClassifier
12
  from xgboost import XGBClassifier
13
  from lightgbm import LGBMClassifier
14
  from sklearn.svm import SVC
15
  from sklearn.tree import DecisionTreeClassifier
16
  from sklearn.naive_bayes import GaussianNB
17
-
18
  from sklearn.metrics import accuracy_score, log_loss
 
19
 
20
- from feature_extract import build_freqs, extract_features_2, extract_features_6
21
-
22
- # -------------------- NLTK setup --------------------
23
  def _ensure_nltk():
24
  try:
25
  twitter_samples.fileids()
@@ -30,7 +26,6 @@ def _ensure_nltk():
30
  except LookupError:
31
  nltk.download("stopwords", quiet=True)
32
 
33
- # -------------------- Data prep --------------------
34
  def load_twitter_data() -> Tuple[List[str], np.ndarray]:
35
  pos = twitter_samples.strings("positive_tweets.json")
36
  neg = twitter_samples.strings("negative_tweets.json")
@@ -38,97 +33,76 @@ def load_twitter_data() -> Tuple[List[str], np.ndarray]:
38
  y = np.array([1] * len(pos) + [0] * len(neg))
39
  return tweets, y
40
 
41
- def vectorize(tweets: List[str],
42
- freqs: Dict[Tuple[str, float], float],
43
- mode: str = "2f") -> np.ndarray:
44
- """mode: '2f' -> extract_features_2, '6f' -> extract_features_6"""
45
  feat_fn = extract_features_2 if mode == "2f" else extract_features_6
46
  rows = [feat_fn(t, freqs) for t in tweets]
47
  return np.vstack(rows) if rows else np.zeros((0, 2 if mode == "2f" else 6))
48
 
49
- # -------------------- Models --------------------
50
- def make_models() -> Dict[str, object]:
51
- return {
52
- "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
53
- "XGBoost": XGBClassifier(use_label_encoder=False, eval_metric="logloss"),
54
- "LightGBM": LGBMClassifier(random_state=42),
55
- "SVM": SVC(kernel="linear", probability=True, random_state=42),
56
- "Decision Tree": DecisionTreeClassifier(random_state=42),
57
- "Naive Bayes": GaussianNB(),
58
- }
59
-
60
- # -------------------- Train --------------------
61
- def train_models(X: np.ndarray, y: np.ndarray) -> Dict[str, object]:
62
- models = make_models()
63
- trained = {}
64
- print("Đang train các hình:")
65
- for name, clf in models.items():
66
- clf.fit(X, y.ravel())
67
- trained[name] = clf
68
-
69
- # --- ghi log sau train ---
70
- y_pred = clf.predict(X)
71
- acc = accuracy_score(y, y_pred)
72
- # log_loss cần probability
73
- try:
74
- y_proba = clf.predict_proba(X)
75
- loss = log_loss(y, y_proba)
76
- except Exception:
77
- loss = None
78
-
79
- if loss is not None:
80
- print(f"[{name}] Accuracy: {acc:.4f} | LogLoss: {loss:.4f}")
81
- else:
82
- print(f"[{name}] Accuracy: {acc:.4f} | (không có predict_proba để tính log_loss)")
83
- print("=" * 60)
84
  return trained
85
 
86
- def train_all_versions(save_path: str = "demo_models.pkl"):
87
- """
88
- Train và lưu mô hình + freqs ra file pickle.
89
- Trả về:
90
- {
91
- 'freqs': freqs,
92
- '2f': {model_name: trained_model, ...},
93
- '6f': {model_name: trained_model, ...}
94
- }
95
- """
96
  _ensure_nltk()
97
  tweets, y = load_twitter_data()
98
- freqs = build_freqs(tweets, y.reshape(-1, 1))
99
-
100
- # trích features
 
 
 
 
 
 
101
  X2 = vectorize(tweets, freqs, mode="2f")
102
  X6 = vectorize(tweets, freqs, mode="6f")
103
-
104
- print("\n===== Train với 2-feature =====")
105
- models_2f = train_models(X2, y)
106
-
107
- print("\n===== Train với 6-feature =====")
108
- models_6f = train_models(X6, y)
109
-
110
- data_to_save = {
111
- "freqs": freqs,
112
- "2f": models_2f,
113
- "6f": models_6f,
114
- }
115
-
116
- # lưu file pickle
117
  with open(save_path, "wb") as f:
118
  pickle.dump(data_to_save, f)
119
-
120
- print(f"\nĐã train và lưu mô hình + freqs vào file: {save_path}")
121
  return data_to_save
122
 
123
- # -------------------- Load --------------------
124
  def load_demo_models(save_path: str = "demo_models.pkl"):
125
- """Load lại mô hình + freqs từ file pickle."""
126
  with open(save_path, "rb") as f:
127
  data = pickle.load(f)
128
  return data
129
 
130
- # -------------------- CLI --------------------
131
  if __name__ == "__main__":
132
- models = train_all_versions() # train & save
133
  print("Các mô hình 2f:", list(models["2f"].keys()))
134
  print("Các mô hình 6f:", list(models["6f"].keys()))
 
1
  # file: train_demo_models.py
2
  from __future__ import annotations
3
+ import os
4
  import pickle
5
  import numpy as np
6
  from typing import Dict, Tuple, List
 
7
  import nltk
8
  from nltk.corpus import twitter_samples, stopwords
 
9
  from sklearn.ensemble import RandomForestClassifier
10
  from xgboost import XGBClassifier
11
  from lightgbm import LGBMClassifier
12
  from sklearn.svm import SVC
13
  from sklearn.tree import DecisionTreeClassifier
14
  from sklearn.naive_bayes import GaussianNB
15
+ from sklearn.linear_model import LogisticRegression
16
  from sklearn.metrics import accuracy_score, log_loss
17
+ from feature_extract import build_freqs, extract_features_2, extract_features_6
18
 
 
 
 
19
  def _ensure_nltk():
20
  try:
21
  twitter_samples.fileids()
 
26
  except LookupError:
27
  nltk.download("stopwords", quiet=True)
28
 
 
29
  def load_twitter_data() -> Tuple[List[str], np.ndarray]:
30
  pos = twitter_samples.strings("positive_tweets.json")
31
  neg = twitter_samples.strings("negative_tweets.json")
 
33
  y = np.array([1] * len(pos) + [0] * len(neg))
34
  return tweets, y
35
 
36
+ def vectorize(tweets: List[str], freqs: Dict[Tuple[str, float], float], mode: str = "2f") -> np.ndarray:
 
 
 
37
  feat_fn = extract_features_2 if mode == "2f" else extract_features_6
38
  rows = [feat_fn(t, freqs) for t in tweets]
39
  return np.vstack(rows) if rows else np.zeros((0, 2 if mode == "2f" else 6))
40
 
41
+ ALL_MODEL_SPECS: Dict[str, object] = {
42
+ "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
43
+ "XGBoost": XGBClassifier(use_label_encoder=False, eval_metric="logloss"),
44
+ "LightGBM": LGBMClassifier(random_state=42),
45
+ "SVM": SVC(kernel="linear", probability=True, random_state=42),
46
+ "Decision Tree": DecisionTreeClassifier(random_state=42),
47
+ "Naive Bayes": GaussianNB(),
48
+ "Logistic Regression": LogisticRegression(solver="liblinear", random_state=42),
49
+ }
50
+
51
+ def make_models(include: List[str] | None = None) -> Dict[str, object]:
52
+ if include is None:
53
+ return {k: v for k, v in ALL_MODEL_SPECS.items()}
54
+ return {k: ALL_MODEL_SPECS[k] for k in include}
55
+
56
+ def _fit_and_log(name: str, clf, X: np.ndarray, y: np.ndarray):
57
+ clf.fit(X, y.ravel())
58
+ y_pred = clf.predict(X)
59
+ acc = accuracy_score(y, y_pred)
60
+ try:
61
+ y_proba = clf.predict_proba(X)
62
+ loss = log_loss(y, y_proba)
63
+ print(f"[{name}] Accuracy: {acc:.4f} | LogLoss: {loss:.4f}")
64
+ except Exception:
65
+ print(f"[{name}] Accuracy: {acc:.4f} | (no predict_proba)")
66
+ return clf
67
+
68
+ def train_models(X: np.ndarray, y: np.ndarray, include: List[str] | None = None) -> Dict[str, object]:
69
+ specs = make_models(include)
70
+ trained: Dict[str, object] = {}
71
+ for name, clf in specs.items():
72
+ trained[name] = _fit_and_log(name, clf, X, y)
 
 
 
73
  return trained
74
 
75
+ def ensure_logreg_only(save_path: str = "demo_models.pkl"):
 
 
 
 
 
 
 
 
 
76
  _ensure_nltk()
77
  tweets, y = load_twitter_data()
78
+ if os.path.exists(save_path):
79
+ with open(save_path, "rb") as f:
80
+ data = pickle.load(f)
81
+ freqs = data.get("freqs")
82
+ models_2f: Dict[str, object] = data.get("2f", {})
83
+ models_6f: Dict[str, object] = data.get("6f", {})
84
+ else:
85
+ freqs = build_freqs(tweets, y.reshape(-1, 1))
86
+ models_2f, models_6f = {}, {}
87
  X2 = vectorize(tweets, freqs, mode="2f")
88
  X6 = vectorize(tweets, freqs, mode="6f")
89
+ if "Logistic Regression" not in models_2f:
90
+ new_models_2f = train_models(X2, y, include=["Logistic Regression"])
91
+ models_2f.update(new_models_2f)
92
+ if "Logistic Regression" not in models_6f:
93
+ new_models_6f = train_models(X6, y, include=["Logistic Regression"])
94
+ models_6f.update(new_models_6f)
95
+ data_to_save = {"freqs": freqs, "2f": models_2f, "6f": models_6f}
 
 
 
 
 
 
 
96
  with open(save_path, "wb") as f:
97
  pickle.dump(data_to_save, f)
 
 
98
  return data_to_save
99
 
 
100
  def load_demo_models(save_path: str = "demo_models.pkl"):
 
101
  with open(save_path, "rb") as f:
102
  data = pickle.load(f)
103
  return data
104
 
 
105
  if __name__ == "__main__":
106
+ models = ensure_logreg_only()
107
  print("Các mô hình 2f:", list(models["2f"].keys()))
108
  print("Các mô hình 6f:", list(models["6f"].keys()))