Spaces:
Sleeping
Sleeping
| import pickle | |
| import numpy as np | |
| from feature_extract import extract_features_2, extract_features_6 | |
# Load the saved artefacts once, at import time.
# The bundle must contain a "freqs" entry (handed to the feature extractors);
# the "2f" / "6f" entries are optional and default to empty registries.
# Presumably each registry maps a model name (e.g. "SVM") to a fitted model.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ship a demo_models.pkl that you produced yourself.
with open("demo_models.pkl", "rb") as bundle_file:
    data = pickle.load(bundle_file)

freqs = data["freqs"]
models_2f, models_6f = (data.get(key, {}) for key in ("2f", "6f"))
| def _predict_with_model(sentence: str, model) -> int: | |
| n = getattr(model, "n_features_in_", None) | |
| if n == 6: | |
| x = extract_features_6(sentence, freqs) | |
| else: | |
| x = extract_features_2(sentence, freqs) | |
| return int(model.predict(x)[0]) | |
def _smart_pick_model(model_name: str, prefer: str = "2f"):
    """Return the saved model called *model_name*, trying one registry first.

    Args:
        model_name: Key under which the model was saved (e.g. "SVM").
        prefer: "6f" searches ``models_6f`` before ``models_2f``. Any other
            value searches ``models_2f`` first.

    Returns:
        The first model found whose registry entry is not ``None``.

    Raises:
        KeyError: If neither registry holds a model for *model_name*.
    """
    if prefer == "6f":
        registries = (models_6f, models_2f)
    else:
        registries = (models_2f, models_6f)

    for registry in registries:
        model = registry.get(model_name)
        # Explicit None check: the former ``a or b`` relied on the model's
        # truthiness, which calls __bool__/__len__ when the class defines
        # them (some scikit-learn estimators, e.g. ensembles and pipelines,
        # define __len__), so a valid model could be skipped or crash here.
        if model is not None:
            return model
    raise KeyError(f"Model '{model_name}' not found in saved models.")
# --- 2-feature API (also copes with a saved model that is really 6-feature) ---
def _predict_2f(sentence: str, model_name: str) -> int:
    """Predict with *model_name*, searching the 2-feature registry first.

    If the name only exists among the 6-feature models, that model is used
    instead. ``_predict_with_model`` then chooses the feature extractor from
    the model's own ``n_features_in_``.
    """
    chosen = _smart_pick_model(model_name, prefer="2f")
    label = _predict_with_model(sentence, chosen)
    return label
# --- 6-feature API (also copes with a saved model that is really 2-feature) ---
def _predict_6f(sentence: str, model_name: str) -> int:
    """Predict with *model_name*, searching the 6-feature registry first.

    If the name only exists among the 2-feature models, that model is used
    instead. ``_predict_with_model`` then chooses the feature extractor from
    the model's own ``n_features_in_``.
    """
    chosen = _smart_pick_model(model_name, prefer="6f")
    label = _predict_with_model(sentence, chosen)
    return label
# 2-feature entry points: one public function per saved model name.
def predict_randomforest_2f(sentence):
    """Predict with "Random Forest", preferring the 2-feature registry."""
    return _predict_2f(sentence, model_name="Random Forest")


def predict_xgboost_2f(sentence):
    """Predict with "XGBoost", preferring the 2-feature registry."""
    return _predict_2f(sentence, model_name="XGBoost")


def predict_lightgbm_2f(sentence):
    """Predict with "LightGBM", preferring the 2-feature registry."""
    return _predict_2f(sentence, model_name="LightGBM")


def predict_svm_2f(sentence):
    """Predict with "SVM", preferring the 2-feature registry."""
    return _predict_2f(sentence, model_name="SVM")


def predict_decisiontree_2f(sentence):
    """Predict with "Decision Tree", preferring the 2-feature registry."""
    return _predict_2f(sentence, model_name="Decision Tree")


def predict_naivebayes_2f(sentence):
    """Predict with "Naive Bayes", preferring the 2-feature registry."""
    return _predict_2f(sentence, model_name="Naive Bayes")


def predict_logisticregression_2f(sentence):
    """Predict with "Logistic Regression", preferring the 2-feature registry."""
    return _predict_2f(sentence, model_name="Logistic Regression")
# 6-feature entry points: one public function per saved model name.
def predict_randomforest_6f(sentence):
    """Predict with "Random Forest", preferring the 6-feature registry."""
    return _predict_6f(sentence, model_name="Random Forest")


def predict_xgboost_6f(sentence):
    """Predict with "XGBoost", preferring the 6-feature registry."""
    return _predict_6f(sentence, model_name="XGBoost")


def predict_lightgbm_6f(sentence):
    """Predict with "LightGBM", preferring the 6-feature registry."""
    return _predict_6f(sentence, model_name="LightGBM")


def predict_svm_6f(sentence):
    """Predict with "SVM", preferring the 6-feature registry."""
    return _predict_6f(sentence, model_name="SVM")


def predict_decisiontree_6f(sentence):
    """Predict with "Decision Tree", preferring the 6-feature registry."""
    return _predict_6f(sentence, model_name="Decision Tree")


def predict_naivebayes_6f(sentence):
    """Predict with "Naive Bayes", preferring the 6-feature registry."""
    return _predict_6f(sentence, model_name="Naive Bayes")


def predict_logisticregression_6f(sentence):
    """Predict with "Logistic Regression", preferring the 6-feature registry."""
    return _predict_6f(sentence, model_name="Logistic Regression")
if __name__ == "__main__":
    # Quick manual check: run several models on one sentence in both modes.
    sample = "I love this new phone!"
    demo_runs = (
        ("RF 2f:", predict_randomforest_2f),
        ("RF 6f:", predict_randomforest_6f),
        ("SVM 2f:", predict_svm_2f),
        ("SVM 6f:", predict_svm_6f),
        ("LogReg 2f:", predict_logisticregression_2f),
        ("LogReg 6f:", predict_logisticregression_6f),
    )
    for label, predict in demo_runs:
        print(label, predict(sample))