Spaces:
Running
Running
| """ | |
| train_traditional.py | |
| ──────────────────── | |
| Entry-point: trains TF-IDF + Logistic Regression or Linear SVM. | |
| Usage | |
| ───── | |
| python train_traditional.py # Logistic Regression (default) | |
| python train_traditional.py --model svm # Linear SVM | |
| python train_traditional.py --full # Use all 120 K training samples | |
| python train_traditional.py --model svm --full | |
| """ | |
| import argparse | |
| import logging | |
| from config import CFG | |
| from data_loader import load_ag_news, get_raw_splits | |
| import traditional_model as tm | |
# Configure the root logger once, at import time: INFO and above, with a
# time-of-day timestamp and a left-aligned, 8-character-wide level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the traditional-ML training script.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            so existing ``parse_args()`` callers behave exactly as before.
            Passing an explicit list lets the parser be driven from other
            code and from tests.

    Returns:
        A namespace with two attributes: ``model`` (``"lr"`` or ``"svm"``,
        default ``"lr"``) and ``full`` (``True`` only when ``--full`` is given).

    Raises:
        SystemExit: Raised by argparse on ``--help`` or on invalid arguments,
            e.g. a ``--model`` value outside the allowed choices.
    """
    parser = argparse.ArgumentParser(
        description="Train traditional ML document classifier (TF-IDF + LR or SVM)"
    )
    parser.add_argument(
        "--model",
        default="lr",
        choices=["lr", "svm"],
        help="'lr' = Logistic Regression | 'svm' = Linear SVM (default: lr)",
    )
    parser.add_argument(
        "--full",
        action="store_true",
        help="Disable sample cap; use all 120 K training examples (slower)",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Run one end-to-end training session for the selected classifier.

    Steps, in order: parse the CLI flags, print a banner, load the dataset
    through ``load_ag_news`` / ``get_raw_splits``, fit via ``tm.train``,
    evaluate on the test split with ``tm.evaluate`` (artifacts directed to
    ``outputs/<model>``), persist the pipeline with ``tm.save_model``, and
    print the final test accuracy.
    """
    cli = parse_args()

    # --full lifts both caps; otherwise the limits come from the shared config.
    if cli.full:
        train_cap = None
        eval_cap = None
    else:
        train_cap = CFG.max_train_samples
        eval_cap = CFG.max_eval_samples

    # The config itself may hold None, which also means "no cap".
    if cli.full or train_cap is None:
        subset_desc = "Full dataset"
    else:
        subset_desc = f"{train_cap:,} samples (subset)"

    rule = "-" * 60
    print(f"\n{rule}")
    print(" Document Classifier -- Traditional ML Training")
    print(f" Model : {cli.model.upper()}")
    print(f" Samples : {subset_desc}")
    print(f"{rule}\n")

    # Load data; max_test=None is always passed, regardless of --full.
    corpus = load_ag_news(max_train=train_cap, max_eval=eval_cap, max_test=None)
    X_train, y_train, X_val, y_val, X_test, y_test = get_raw_splits(corpus)

    # Train; the validation accuracy returned alongside the pipeline is not
    # used any further in this script.
    fitted, _val_acc = tm.train(X_train, y_train, X_val, y_val, model_type=cli.model)

    # Evaluate on the test set.
    metrics = tm.evaluate(fitted, X_test, y_test, save_dir=f"outputs/{cli.model}")

    # Save the model.
    tm.save_model(fitted, name=cli.model)

    print(f"\n Final test accuracy : {metrics['accuracy'] * 100:.2f}%")
    # NOTE(review): this path is hard-coded here, not returned by
    # tm.save_model -- confirm it matches where the model is actually written.
    print(f" Model saved to : saved_models/traditional_{cli.model}.joblib\n")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()