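# Source: bbc-document-classifier / src/models/train_compare_models.py
# Uploaded by: pearlll
# Commit: "Deploy document classifier app" (492754f)
# File size: 2.16 kB
# (Remaining file-viewer page controls -- Raw / History / Blame / Contribute /
# Delete -- were navigation text from the hosting page, not part of the script.)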
import os
import mlflow
import mlflow.sklearn
import pandas as pd
import joblib
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report
# Train several TF-IDF + classifier pipelines on the training split, score
# each on the validation split, log every run to MLflow, and persist the
# pipeline with the highest weighted F1 to disk.

# Vocabulary cap for the TF-IDF vectorizer. Defined once so the value used
# for training and the value logged to MLflow cannot drift apart.
MAX_FEATURES = 5000

# Where the winning pipeline is written; also drives the directory creation
# and the final status message below.
BEST_MODEL_PATH = "models/best_model.pkl"

train_df = pd.read_csv("data/splits/train.csv")
val_df = pd.read_csv("data/splits/val.csv")

# read_csv turns empty cells into NaN, and TfidfVectorizer raises on NaN
# documents. Treat a missing text as an empty document instead of crashing.
X_train = train_df["clean_text"].fillna("")
y_train = train_df["label_text"]
X_val = val_df["clean_text"].fillna("")
y_val = val_df["label_text"]

os.makedirs(os.path.dirname(BEST_MODEL_PATH), exist_ok=True)

# Candidate classifiers, keyed by the name used for the MLflow run and the
# logged model artifact.
models = {
    "logistic_regression": LogisticRegression(max_iter=1000),
    "linear_svm": LinearSVC(),
    "random_forest": RandomForestClassifier(n_estimators=100, random_state=42),
}

mlflow.set_experiment("bbc-document-classification")

best_model = None
best_model_name = None
# Start below any achievable score so the first evaluated model is always
# selected, even if every candidate scores 0.0. Starting at 0 with a strict
# ">" comparison would leave best_model as None in that case and a None
# object would be pickled to disk.
best_f1 = float("-inf")

for model_name, classifier in models.items():
    with mlflow.start_run(run_name=model_name):
        # A fresh vectorizer per run: each pipeline is fitted independently.
        pipeline = Pipeline([
            ("tfidf", TfidfVectorizer(max_features=MAX_FEATURES)),
            ("classifier", classifier),
        ])
        pipeline.fit(X_train, y_train)

        y_pred = pipeline.predict(X_val)
        accuracy = accuracy_score(y_val, y_pred)
        f1 = f1_score(y_val, y_pred, average="weighted")

        print("\n==============================")
        print(f"Model: {model_name}")
        print("Accuracy:", accuracy)
        print("F1 Score:", f1)
        print(classification_report(y_val, y_pred))

        mlflow.log_param("model_name", model_name)
        mlflow.log_param("vectorizer", "TF-IDF")
        mlflow.log_param("max_features", MAX_FEATURES)
        mlflow.log_metric("accuracy", accuracy)
        mlflow.log_metric("f1_score", f1)
        mlflow.sklearn.log_model(pipeline, model_name)

        # Strict ">" keeps the earliest model on ties (dict insertion order).
        if f1 > best_f1:
            best_f1 = f1
            best_model = pipeline
            best_model_name = model_name

joblib.dump(best_model, BEST_MODEL_PATH)

print("\nBest model:", best_model_name)
print("Best F1:", best_f1)
print(f"Saved to {BEST_MODEL_PATH}")