nexa-classify-api / train_traditional.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
2.44 kB
"""
train_traditional.py
────────────────────
Entry-point: trains TF-IDF + Logistic Regression or Linear SVM.
Usage
─────
python train_traditional.py # Logistic Regression (default)
python train_traditional.py --model svm # Linear SVM
python train_traditional.py --full # Use all 120 K training samples
python train_traditional.py --model svm --full
"""
import argparse
import logging
from config import CFG
from data_loader import load_ag_news, get_raw_splits
import traditional_model as tm
# Configure the root logger once at import: INFO and above, with an
# HH:MM:SS timestamp and the level name left-padded to 8 characters.
# (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for the traditional-ML training script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            original behaviour. Passing an explicit list allows the script
            to be driven programmatically or from tests.

    Returns:
        A namespace with ``model`` (``"lr"`` or ``"svm"``, default ``"lr"``)
        and ``full`` (``True`` when ``--full`` was given).

    Raises:
        SystemExit: raised by argparse on invalid options, e.g. an unknown
            ``--model`` choice, or on ``--help``.
    """
    p = argparse.ArgumentParser(
        description="Train traditional ML document classifier (TF-IDF + LR or SVM)"
    )
    p.add_argument(
        "--model",
        default="lr",
        choices=["lr", "svm"],
        help="'lr' = Logistic Regression | 'svm' = Linear SVM (default: lr)",
    )
    p.add_argument(
        "--full",
        action="store_true",
        help="Disable sample cap; use all 120 K training examples (slower)",
    )
    # argv=None falls back to sys.argv[1:] inside argparse.
    return p.parse_args(argv)
def main() -> None:
    """Train, evaluate on the test split, and save a TF-IDF + linear model.

    Options come from ``parse_args()``. Unless ``--full`` is given, the train
    and eval splits are capped at ``CFG.max_train_samples`` and
    ``CFG.max_eval_samples``. The test split is always requested uncapped
    (``max_test=None``). Test-set evaluation output goes to
    ``outputs/<model>``, and the fitted pipeline is persisted through
    ``tm.save_model``.
    """
    args = parse_args()

    # --full lifts both caps; otherwise use the configured limits.
    max_train = None if args.full else CFG.max_train_samples
    max_eval = None if args.full else CFG.max_eval_samples

    # --full always makes max_train None, so this single check covers both
    # "--full was passed" and "the configured cap is None".
    if max_train is None:
        sample_label = "Full dataset"
    else:
        sample_label = f"{max_train:,} samples (subset)"

    rule = "-" * 60
    print(f"\n{rule}")
    print(" Document Classifier -- Traditional ML Training")
    print(f" Model : {args.model.upper()}")
    print(f" Samples : {sample_label}")
    print(f"{rule}\n")

    # Load data
    dataset = load_ag_news(max_train=max_train, max_eval=max_eval, max_test=None)
    X_train, y_train, X_val, y_val, X_test, y_test = get_raw_splits(dataset)

    # Train. The validation accuracy returned by tm.train is not used here.
    pipeline, _val_acc = tm.train(
        X_train, y_train,
        X_val, y_val,
        model_type=args.model,
    )

    # Evaluate on test set
    save_dir = f"outputs/{args.model}"
    results = tm.evaluate(pipeline, X_test, y_test, save_dir=save_dir)

    # Save the model
    tm.save_model(pipeline, name=args.model)

    print(f"\n Final test accuracy : {results['accuracy'] * 100:.2f}%")
    # NOTE(review): this path is hard-coded here and is not taken from
    # tm.save_model. Confirm it matches where save_model actually writes.
    print(f" Model saved to : saved_models/traditional_{args.model}.joblib\n")
# Run training only when executed as a script; importing this module
# (e.g. from tests) does not start a training run.
if __name__ == "__main__":
    main()