""" train_transformer.py ──────────────────── Entry-point: fine-tunes DistilBERT on AG News. Usage ───── python train_transformer.py # 20 K subset ≈ 1.5–2.5 hrs on i3 python train_transformer.py --full # 120 K full ≈ 8–10 hrs on i3 Tip: Start with the subset to verify everything works, then run --full overnight. """ import argparse import logging from config import CFG from data_loader import load_ag_news, get_tokenizer, tokenise_dataset import transformer_model as trm logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%H:%M:%S", ) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description="Fine-tune DistilBERT document classifier" ) p.add_argument( "--full", action="store_true", help="Use the full 120 K training set instead of the 20 K subset", ) return p.parse_args() def main() -> None: args = parse_args() max_train = None if args.full else CFG.max_train_samples max_eval = None if args.full else CFG.max_eval_samples sample_label = "Full dataset" if args.full else f"{max_train:,} samples (subset)" print(f"\n{'─' * 60}") print(f" Document Classifier — DistilBERT Fine-Tuning") print(f" Base model : {CFG.model_checkpoint}") print(f" Samples : {sample_label}") print(f"{'─' * 60}\n") # Load and tokenise dataset dataset = load_ag_news(max_train=max_train, max_eval=max_eval, max_test=None) tokenizer = get_tokenizer() tokenised = tokenise_dataset(dataset, tokenizer) # Train trainer = trm.train(tokenised, tokenizer) # Evaluate on test set save_dir = "outputs/transformer" results = trm.evaluate(trainer, tokenised, save_dir=save_dir) # Save best checkpoint + tokeniser trm.save_model(trainer, tokenizer) print(f"\n Final test accuracy : {results['accuracy'] * 100:.2f}%") print(f" Model saved to : saved_models/transformer/\n") if __name__ == "__main__": main()