""" train_traditional.py ──────────────────── Entry-point: trains TF-IDF + Logistic Regression or Linear SVM. Usage ───── python train_traditional.py # Logistic Regression (default) python train_traditional.py --model svm # Linear SVM python train_traditional.py --full # Use all 120 K training samples python train_traditional.py --model svm --full """ import argparse import logging from config import CFG from data_loader import load_ag_news, get_raw_splits import traditional_model as tm logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%H:%M:%S", ) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description="Train traditional ML document classifier (TF-IDF + LR or SVM)" ) p.add_argument( "--model", default="lr", choices=["lr", "svm"], help="'lr' = Logistic Regression | 'svm' = Linear SVM (default: lr)", ) p.add_argument( "--full", action="store_true", help="Disable sample cap; use all 120 K training examples (slower)", ) return p.parse_args() def main() -> None: args = parse_args() max_train = None if args.full else CFG.max_train_samples max_eval = None if args.full else CFG.max_eval_samples sample_label = "Full dataset" if (args.full or max_train is None) else f"{max_train:,} samples (subset)" print(f"\n{'-' * 60}") print(f" Document Classifier -- Traditional ML Training") print(f" Model : {args.model.upper()}") print(f" Samples : {sample_label}") print(f"{'-' * 60}\n") # Load data dataset = load_ag_news(max_train=max_train, max_eval=max_eval, max_test=None) X_train, y_train, X_val, y_val, X_test, y_test = get_raw_splits(dataset) # Train pipeline, val_acc = tm.train( X_train, y_train, X_val, y_val, model_type=args.model, ) # Evaluate on test set save_dir = f"outputs/{args.model}" results = tm.evaluate(pipeline, X_test, y_test, save_dir=save_dir) # Save the model tm.save_model(pipeline, name=args.model) print(f"\n Final test accuracy : {results['accuracy'] * 100:.2f}%") print(f" Model saved to : saved_models/traditional_{args.model}.joblib\n") if __name__ == "__main__": main()