"""
traditional_model.py
────────────────────
Approach A: TF-IDF vectorisation + scikit-learn classifiers.

Provides two classifier options:
  'lr'  → Logistic Regression  (supports probability scores)
  'svm' → Linear SVM           (slightly faster, no probability output)
"""

import logging
import os
import time
from typing import Any, Dict, Literal, Optional, Tuple

import joblib
import matplotlib

# Select the non-interactive backend *before* pyplot is imported so figures
# can be rendered to files without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
)
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

from config import CFG

logger = logging.getLogger(__name__)


# ── Pipeline factory ──────────────────────────────────────────────────────────

def build_pipeline(model_type: Literal["lr", "svm"] = "lr") -> Pipeline:
    """
    Build a TF-IDF → classifier sklearn Pipeline.

    TF-IDF settings:
      - max_features=60_000 : vocabulary cap (original author's estimate:
                              covers ~99 % of AG News tokens)
      - ngram_range=(1, 2)  : unigrams + bigrams capture short phrases
      - sublinear_tf=True   : replace TF with 1 + log(TF) to dampen very
                              frequent terms
      - min_df=2            : discard terms that occur in fewer than two
                              documents
      - strip_accents       : "unicode" accent folding during preprocessing
      - token_pattern       : any run of one or more word characters, so
                              single-character tokens are kept

    Parameters
    ----------
    model_type : 'lr' for LogisticRegression, 'svm' for LinearSVC. The same
                 string is used as the name of the classifier step.

    Returns
    -------
    An unfitted Pipeline with steps ("tfidf", <model_type>).

    Raises
    ------
    ValueError if model_type is neither 'lr' nor 'svm'.
    """
    # Validate first so an invalid choice fails before any object is built.
    if model_type not in ("lr", "svm"):
        raise ValueError(
            f"Unknown model_type '{model_type}'. Valid choices: 'lr', 'svm'."
        )

    tfidf = TfidfVectorizer(
        max_features=60_000,
        ngram_range=(1, 2),
        sublinear_tf=True,
        min_df=2,
        strip_accents="unicode",
        analyzer="word",
        token_pattern=r"\w{1,}",
    )

    if model_type == "lr":
        clf = LogisticRegression(
            C=5.0,
            max_iter=1_000,
            solver="saga",
            n_jobs=-1,
            random_state=CFG.seed,
        )
    else:  # model_type == "svm"
        clf = LinearSVC(
            C=1.0,
            max_iter=2_000,
            random_state=CFG.seed,
        )

    pipeline = Pipeline([("tfidf", tfidf), (model_type, clf)])
    logger.info(f"Pipeline: TF-IDF -> {clf.__class__.__name__}")
    return pipeline


# ── Training ──────────────────────────────────────────────────────────────────

def train(
    X_train,
    y_train,
    X_val,
    y_val,
    model_type: Literal["lr", "svm"] = "lr",
) -> Tuple[Pipeline, float]:
    """
    Fit the pipeline and report validation accuracy.

    Parameters
    ----------
    X_train, y_train : training inputs and labels passed to ``Pipeline.fit``
                       (X_train must support ``len``).
    X_val, y_val     : held-out inputs and labels used only for the accuracy
                       figure that is logged and returned.
    model_type       : forwarded to ``build_pipeline``.

    Returns
    -------
    (fitted_pipeline, validation_accuracy)

    Raises
    ------
    ValueError if model_type is not accepted by ``build_pipeline``.
    """
    pipeline = build_pipeline(model_type)

    logger.info(f"Training {model_type.upper()} on {len(X_train):,} samples ...")
    t0 = time.perf_counter()
    pipeline.fit(X_train, y_train)
    elapsed = time.perf_counter() - t0
    logger.info(f"Training complete in {elapsed:.1f}s")

    val_preds = pipeline.predict(X_val)
    val_acc = accuracy_score(y_val, val_preds)
    logger.info(f"Validation accuracy: {val_acc * 100:.2f}%")

    return pipeline, val_acc


# ── Evaluation ────────────────────────────────────────────────────────────────

def evaluate(
    pipeline: Pipeline,
    X_test,
    y_test,
    save_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run the pipeline on the test set, print a full report and plot the
    confusion matrix.

    Parameters
    ----------
    pipeline : a fitted Pipeline.
    X_test, y_test : test inputs and true labels.
    save_dir : if given (and non-empty), the confusion-matrix figure is
               written to ``<save_dir>/confusion_matrix.png``; otherwise the
               figure is rendered but not saved.

    Returns
    -------
    dict with keys: accuracy, report, confusion_matrix
    """
    preds = pipeline.predict(X_test)
    acc = accuracy_score(y_test, preds)
    cm = confusion_matrix(y_test, preds)
    # NOTE(review): target_names must line up with the classes present in
    # y_test / preds -- presumably every CFG.label_names class occurs in the
    # test set; confirm against the data loader.
    report = classification_report(
        y_test,
        preds,
        target_names=CFG.label_names,
        digits=4,
    )

    print("\n" + "=" * 60)
    print(" TRADITIONAL MODEL -- TEST SET RESULTS")
    print("=" * 60)
    print(f" Accuracy : {acc * 100:.2f}%\n")
    print(report)

    _plot_confusion_matrix(
        cm,
        title="Traditional Model -- Confusion Matrix",
        save_dir=save_dir,
    )

    return {"accuracy": acc, "report": report, "confusion_matrix": cm}


# ── Persistence ───────────────────────────────────────────────────────────────

def save_model(pipeline: Pipeline, name: str = "lr") -> str:
    """
    Serialise the pipeline with joblib.

    The file is written to ``<CFG.models_dir>/traditional_<name>.joblib``;
    the directory is created if it does not exist yet.

    Returns
    -------
    The path the model was written to.
    """
    # joblib.dump does not create parent directories; an empty models_dir
    # means "current directory" and needs no creation.
    if CFG.models_dir:
        os.makedirs(CFG.models_dir, exist_ok=True)
    path = os.path.join(CFG.models_dir, f"traditional_{name}.joblib")
    joblib.dump(pipeline, path)
    logger.info(f"Model saved -> {path}")
    return path


def load_model(name: str = "lr") -> Pipeline:
    """
    Deserialise a saved pipeline.

    Reads ``<CFG.models_dir>/traditional_<name>.joblib``.

    Raises
    ------
    FileNotFoundError if no model file exists at that path.
    """
    path = os.path.join(CFG.models_dir, f"traditional_{name}.joblib")
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"No saved model at '{path}'. "
            f"Run: python train_traditional.py --model {name}"
        )
    # SECURITY: joblib.load is pickle-based and can execute arbitrary code;
    # only load model files that come from a trusted source.
    pipeline = joblib.load(path)
    logger.info(f"Model loaded <- {path}")
    return pipeline


# ── Internal helpers ──────────────────────────────────────────────────────────

def _plot_confusion_matrix(
    cm: np.ndarray,
    title: str,
    save_dir: Optional[str] = None,
) -> None:
    """
    Draw *cm* as an annotated heatmap and optionally save it as a PNG.

    Parameters
    ----------
    cm       : confusion matrix of integer counts (annotations use the "d"
               format); rows are true labels, columns are predicted labels.
    title    : figure title.
    save_dir : if truthy, the directory is created when missing and the
               figure is saved as ``confusion_matrix.png`` inside it.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=CFG.label_names,
        yticklabels=CFG.label_names,
        linewidths=0.5,
        ax=ax,
    )
    ax.set_xlabel("Predicted Label", fontsize=11)
    ax.set_ylabel("True Label", fontsize=11)
    ax.set_title(title, fontsize=13, fontweight="bold")
    fig.tight_layout()

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        fig_path = os.path.join(save_dir, "confusion_matrix.png")
        fig.savefig(fig_path, dpi=150)
        logger.info(f"Confusion matrix -> {fig_path}")

    # This module selects the Agg backend at import time; Agg cannot display
    # windows and plt.show() would only emit a UserWarning. Still show the
    # figure if a caller has switched to another backend afterwards.
    if matplotlib.get_backend().lower() != "agg":
        plt.show()
    plt.close(fig)