# ml_exercise / inference_demo.py
# Last commit 58e2d3b ("damn") by trantuan1701.
import pickle
from pathlib import Path

import numpy as np

from feature_extract import extract_features_2, extract_features_6
# Load the pickled demo artifacts once, at import time.
# Look in the current working directory first (the original behaviour), then
# fall back to the directory containing this module. This way, importing the
# module from elsewhere still finds the file.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# ever load a trusted demo_models.pkl.
_MODEL_PATH = Path("demo_models.pkl")
if not _MODEL_PATH.exists():
    _MODEL_PATH = Path(__file__).resolve().parent / "demo_models.pkl"
with _MODEL_PATH.open("rb") as f:
    data = pickle.load(f)
# Passed as the second argument to both feature extractors; this key is required.
freqs = data["freqs"]
# Model registries looked up by model name. Either key may be absent from the
# pickle, in which case that registry is treated as empty.
models_2f = data.get("2f", {})
models_6f = data.get("6f", {})
def _predict_with_model(sentence: str, model) -> int:
n = getattr(model, "n_features_in_", None)
if n == 6:
x = extract_features_6(sentence, freqs)
else:
x = extract_features_2(sentence, freqs)
return int(model.predict(x)[0])
def _smart_pick_model(model_name: str, prefer: str = "2f"):
if prefer == "6f":
model = models_6f.get(model_name) or models_2f.get(model_name)
else:
model = models_2f.get(model_name) or models_6f.get(model_name)
if model is None:
raise KeyError(f"Model '{model_name}' not found in saved models.")
return model
# --- 2-feature API (also works if the stored model is actually a 6f model) ---
def _predict_2f(sentence: str, model_name: str) -> int:
    """Predict with *model_name*, preferring the 2-feature registry.

    If the name exists only among the 6-feature models, that model is used.
    The feature extractor is then chosen from the model's own feature count.
    """
    return _predict_with_model(sentence, _smart_pick_model(model_name, prefer="2f"))
# --- 6-feature API (also works if the stored model is actually a 2f model) ---
def _predict_6f(sentence: str, model_name: str) -> int:
    """Predict with *model_name*, preferring the 6-feature registry.

    If the name exists only among the 2-feature models, that model is used.
    The feature extractor is then chosen from the model's own feature count.
    """
    return _predict_with_model(sentence, _smart_pick_model(model_name, prefer="6f"))
# --- Public predictors that prefer the 2-feature models ---


def predict_randomforest_2f(sentence):
    """Predict with the "Random Forest" model, 2-feature registry first."""
    return _predict_2f(sentence, "Random Forest")


def predict_xgboost_2f(sentence):
    """Predict with the "XGBoost" model, 2-feature registry first."""
    return _predict_2f(sentence, "XGBoost")


def predict_lightgbm_2f(sentence):
    """Predict with the "LightGBM" model, 2-feature registry first."""
    return _predict_2f(sentence, "LightGBM")


def predict_svm_2f(sentence):
    """Predict with the "SVM" model, 2-feature registry first."""
    return _predict_2f(sentence, "SVM")


def predict_decisiontree_2f(sentence):
    """Predict with the "Decision Tree" model, 2-feature registry first."""
    return _predict_2f(sentence, "Decision Tree")


def predict_naivebayes_2f(sentence):
    """Predict with the "Naive Bayes" model, 2-feature registry first."""
    return _predict_2f(sentence, "Naive Bayes")


def predict_logisticregression_2f(sentence):
    """Predict with the "Logistic Regression" model, 2-feature registry first."""
    return _predict_2f(sentence, "Logistic Regression")
# --- Public predictors that prefer the 6-feature models ---


def predict_randomforest_6f(sentence):
    """Predict with the "Random Forest" model, 6-feature registry first."""
    return _predict_6f(sentence, "Random Forest")


def predict_xgboost_6f(sentence):
    """Predict with the "XGBoost" model, 6-feature registry first."""
    return _predict_6f(sentence, "XGBoost")


def predict_lightgbm_6f(sentence):
    """Predict with the "LightGBM" model, 6-feature registry first."""
    return _predict_6f(sentence, "LightGBM")


def predict_svm_6f(sentence):
    """Predict with the "SVM" model, 6-feature registry first."""
    return _predict_6f(sentence, "SVM")


def predict_decisiontree_6f(sentence):
    """Predict with the "Decision Tree" model, 6-feature registry first."""
    return _predict_6f(sentence, "Decision Tree")


def predict_naivebayes_6f(sentence):
    """Predict with the "Naive Bayes" model, 6-feature registry first."""
    return _predict_6f(sentence, "Naive Bayes")


def predict_logisticregression_6f(sentence):
    """Predict with the "Logistic Regression" model, 6-feature registry first."""
    return _predict_6f(sentence, "Logistic Regression")
if __name__ == "__main__":
    # Smoke test: run a few models on one sample sentence with both registry
    # preferences, printing each prediction in order.
    sample = "I love this new phone!"
    demo_runs = (
        ("RF 2f:", predict_randomforest_2f),
        ("RF 6f:", predict_randomforest_6f),
        ("SVM 2f:", predict_svm_2f),
        ("SVM 6f:", predict_svm_6f),
        ("LogReg 2f:", predict_logisticregression_2f),
        ("LogReg 6f:", predict_logisticregression_6f),
    )
    for label, predict in demo_runs:
        print(label, predict(sample))