nexa-classify-api / traditional_model.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
6.34 kB
"""
traditional_model.py
────────────────────
Approach A: TF-IDF vectorisation + scikit-learn classifiers.
Provides two classifier options:
'lr' → Logistic Regression (supports probability scores)
'svm' → Linear SVM (slightly faster, no probability output)
"""
import logging
import os
import time
from typing import Dict, Literal, Tuple
import joblib
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
)
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from config import CFG
logger = logging.getLogger(__name__)
# ── Pipeline factory ──────────────────────────────────────────────────────────
def build_pipeline(model_type: Literal["lr", "svm"] = "lr") -> Pipeline:
    """
    Assemble an unfitted two-step sklearn Pipeline: TF-IDF -> classifier.

    Parameters
    ----------
    model_type : 'lr' selects LogisticRegression, 'svm' selects LinearSVC.
        The same string is also used as the name of the classifier step.

    Returns
    -------
    Pipeline with steps ("tfidf", vectoriser) and (model_type, classifier)

    Raises
    ------
    ValueError if model_type is neither 'lr' nor 'svm'.

    TF-IDF settings:
    - max_features=60_000   : vocabulary cap (per the original author's note
                              this covers ~99 % of AG News tokens)
    - ngram_range=(1, 2)    : unigrams + bigrams capture short phrases
    - sublinear_tf=True     : use 1 + log(TF) to dampen very frequent terms
    - min_df=2              : drop terms that occur in fewer than 2 documents
    - token_pattern=\\w{1,}  : keep single-character tokens as well
    """
    # Reject unsupported choices up front, before building any component.
    if model_type not in ("lr", "svm"):
        raise ValueError(
            f"Unknown model_type '{model_type}'. Valid choices: 'lr', 'svm'."
        )

    vectoriser = TfidfVectorizer(
        analyzer="word",
        token_pattern=r"\w{1,}",
        strip_accents="unicode",
        ngram_range=(1, 2),
        min_df=2,
        max_features=60_000,
        sublinear_tf=True,
    )

    if model_type == "lr":
        classifier = LogisticRegression(
            solver="saga",
            C=5.0,
            max_iter=1_000,
            n_jobs=-1,
            random_state=CFG.seed,
        )
    else:
        classifier = LinearSVC(
            C=1.0,
            max_iter=2_000,
            random_state=CFG.seed,
        )

    steps = [("tfidf", vectoriser), (model_type, classifier)]
    assembled = Pipeline(steps)
    logger.info(f"Pipeline: TF-IDF -> {classifier.__class__.__name__}")
    return assembled
# ── Training ──────────────────────────────────────────────────────────────────
def train(
    X_train, y_train,
    X_val, y_val,
    model_type: str = "lr",
) -> Tuple[Pipeline, float]:
    """
    Build a fresh pipeline, fit it on the training split and score it on the
    validation split.

    Parameters
    ----------
    X_train, y_train : training inputs and labels, passed straight to
        Pipeline.fit (presumably raw text documents and class ids -- confirm
        against the caller).
    X_val, y_val : validation inputs and labels, used only for scoring.
    model_type : forwarded to build_pipeline ('lr' or 'svm').

    Returns
    -------
    (fitted_pipeline, validation_accuracy)
    """
    model = build_pipeline(model_type)
    logger.info(f"Training {model_type.upper()} on {len(X_train):,} samples ...")

    # Time only the fit call, not pipeline construction or validation.
    started = time.perf_counter()
    model.fit(X_train, y_train)
    elapsed = time.perf_counter() - started
    logger.info(f"Training complete in {elapsed:.1f}s")

    val_acc = accuracy_score(y_val, model.predict(X_val))
    logger.info(f"Validation accuracy: {val_acc * 100:.2f}%")

    return model, val_acc
# ── Evaluation ────────────────────────────────────────────────────────────────
def evaluate(
    pipeline: Pipeline,
    X_test,
    y_test,
    save_dir: str = None,
) -> Dict:
    """
    Score a fitted pipeline on the test split.

    Prints the accuracy and the per-class classification report to stdout,
    then renders the confusion matrix (written to disk when save_dir is
    given).

    Parameters
    ----------
    pipeline : fitted Pipeline exposing predict().
    X_test, y_test : test inputs and their true labels.
    save_dir : optional directory for the confusion-matrix image; forwarded
        to _plot_confusion_matrix.

    Returns
    -------
    dict with keys: accuracy, report, confusion_matrix
    """
    predictions = pipeline.predict(X_test)

    results = {
        "accuracy": accuracy_score(y_test, predictions),
        "report": classification_report(
            y_test,
            predictions,
            target_names=CFG.label_names,
            digits=4,
        ),
        "confusion_matrix": confusion_matrix(y_test, predictions),
    }

    rule = "=" * 60
    print("\n" + rule)
    print(" TRADITIONAL MODEL -- TEST SET RESULTS")
    print(rule)
    print(f" Accuracy : {results['accuracy'] * 100:.2f}%\n")
    print(results["report"])

    _plot_confusion_matrix(
        results["confusion_matrix"],
        title="Traditional Model -- Confusion Matrix",
        save_dir=save_dir,
    )
    return results
# ── Persistence ───────────────────────────────────────────────────────────────
def save_model(pipeline: Pipeline, name: str = "lr") -> str:
    """
    Serialise the pipeline with joblib.

    The file is written to ``<CFG.models_dir>/traditional_<name>.joblib``.
    The models directory is created first if it does not exist yet, so a
    first-time save on a clean checkout does not fail.

    Parameters
    ----------
    pipeline : the (normally fitted) Pipeline to persist.
    name : suffix identifying the model variant in the file name.

    Returns
    -------
    Path of the written file.
    """
    # joblib.dump does not create missing parent directories; do it here,
    # mirroring what _plot_confusion_matrix does for its output directory.
    # An empty models_dir means "current directory" and needs no creation.
    if CFG.models_dir:
        os.makedirs(CFG.models_dir, exist_ok=True)

    path = os.path.join(CFG.models_dir, f"traditional_{name}.joblib")
    joblib.dump(pipeline, path)
    logger.info(f"Model saved -> {path}")
    return path
def load_model(name: str = "lr") -> Pipeline:
    """
    Deserialise a pipeline previously written by save_model.

    Looks for ``<CFG.models_dir>/traditional_<name>.joblib``.

    Returns
    -------
    The object stored in that file.

    Raises
    ------
    FileNotFoundError if no file exists at the expected path; the message
    includes the training command to run.

    NOTE: joblib files are pickle-based, so loading one can execute arbitrary
    code. Only load model files from a trusted source.
    """
    model_path = os.path.join(CFG.models_dir, f"traditional_{name}.joblib")

    if os.path.exists(model_path):
        restored = joblib.load(model_path)
        logger.info(f"Model loaded <- {model_path}")
        return restored

    raise FileNotFoundError(
        f"No saved model at '{model_path}'. "
        f"Run: python train_traditional.py --model {name}"
    )
# ── Internal helpers ──────────────────────────────────────────────────────────
def _plot_confusion_matrix(
    cm: np.ndarray,
    title: str,
    save_dir: str = None,
) -> None:
    """
    Render a confusion matrix as an annotated heatmap.

    Parameters
    ----------
    cm : confusion matrix of integer counts (annotated with the "d" format).
    title : figure title.
    save_dir : when given (and non-empty), the directory is created if
        needed and the figure is written to ``confusion_matrix.png`` inside
        it at 150 dpi.

    The figure is always closed before returning, including when drawing or
    saving raises, so repeated calls do not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=CFG.label_names,
            yticklabels=CFG.label_names,
            linewidths=0.5,
            ax=ax,
        )
        ax.set_xlabel("Predicted Label", fontsize=11)
        ax.set_ylabel("True Label", fontsize=11)
        ax.set_title(title, fontsize=13, fontweight="bold")
        # Operate on the explicit figure handle rather than pyplot's implicit
        # "current figure", which other code could have changed.
        fig.tight_layout()

        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            fig_path = os.path.join(save_dir, "confusion_matrix.png")
            fig.savefig(fig_path, dpi=150)
            logger.info(f"Confusion matrix -> {fig_path}")

        # This module selects the non-interactive Agg backend at import time;
        # under Agg, plt.show() cannot display anything and may only emit a
        # warning. Show the window only if an interactive backend is active.
        if matplotlib.get_backend().lower() != "agg":
            plt.show()
    finally:
        # pyplot keeps every figure registered until it is closed explicitly.
        plt.close(fig)