"""
train_multi.py
──────────────
Train multiple transformer architectures back-to-back.
Ideal for an overnight run on Mac M4 — each model saves independently.
Results are compared at the end in a clean table.

Usage
─────
    # Train DistilBERT + BERT + RoBERTa (default)
    python train_multi.py

    # Train only DistilBERT and RoBERTa
    python train_multi.py --models distilbert-base-uncased roberta-base

    # Single model
    python train_multi.py --models roberta-base

If one architecture fails (bad checkpoint name, out-of-memory, ...) the error
is logged with its traceback and the remaining architectures still run; the
script exits with status 1 at the end if any architecture failed.
"""

import argparse
import logging
import time

from config import CFG
from data_loader import load_ag_news, get_tokenizer, tokenise_dataset
import transformer_model as trm

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)

DEFAULT_MODELS = [
    "distilbert-base-uncased",
    "bert-base-uncased",
    "roberta-base",
]


def _format_duration(seconds: float) -> str:
    """Format *seconds* as ``"<h>h <m>m <s>s"``, truncated to whole seconds."""
    h, rem = divmod(int(seconds), 3600)
    m, s = divmod(rem, 60)
    return f"{h}h {m}m {s}s"


def train_single_architecture(checkpoint: str) -> dict:
    """
    End-to-end train, evaluate and save one transformer checkpoint.

    CFG.model_checkpoint is overridden with *checkpoint* for the duration of
    the run and restored afterwards, even if the run raises. The data and
    tokenizer helpers are called without a checkpoint argument, so they
    presumably read it from CFG — TODO confirm in data_loader.

    Args:
        checkpoint: HuggingFace checkpoint name, e.g. "roberta-base".

    Returns:
        A dict with keys:
          - "checkpoint": the checkpoint name passed in.
          - "accuracy": results["accuracy"] from trm.evaluate. Presumably a
            fraction in [0, 1], since callers display it multiplied by 100 as
            a percentage.
          - "f1_macro": metrics["test_f1_macro"], or NaN if evaluate did not
            report it.
          - "time": wall-clock *training* time (excludes evaluation and
            saving), formatted as "<h>h <m>m <s>s".

    Raises:
        Whatever the data, training, evaluation or saving helpers raise.
    """
    print(f"\n{'━' * 60}")
    print(f" Model: {checkpoint}")
    print(f"{'━' * 60}")

    previous_checkpoint = CFG.model_checkpoint
    CFG.model_checkpoint = checkpoint  # Override for this run
    try:
        dataset = load_ag_news()  # Full 120K (no cap in updated config)
        tokenizer = get_tokenizer()
        tokenised = tokenise_dataset(dataset, tokenizer)

        t0 = time.perf_counter()
        trainer = trm.train(tokenised, tokenizer, checkpoint=checkpoint)
        elapsed = time.perf_counter() - t0

        # NOTE(review): relies on a private helper of transformer_model to
        # derive the output directory name.
        save_dir = f"outputs/{trm._checkpoint_to_dir(checkpoint)}"
        results = trm.evaluate(
            trainer, tokenised, checkpoint=checkpoint, save_dir=save_dir
        )
        trm.save_model(trainer, tokenizer, checkpoint=checkpoint)
    finally:
        # Undo the global override so later code does not see a stale value.
        CFG.model_checkpoint = previous_checkpoint

    return {
        "checkpoint": checkpoint,
        "accuracy": results["accuracy"],
        # NaN (prints as "nan") rather than 0.0, so a missing metric is not
        # mistaken for a genuinely bad score.
        "f1_macro": results["metrics"].get("test_f1_macro", float("nan")),
        "time": _format_duration(elapsed),
    }


def main() -> None:
    """
    Parse --models, train each architecture in turn, then print a comparison table.

    A failing architecture is logged and skipped so the remaining ones still
    run. If any architecture failed, the process exits with status 1 after the
    table has been printed.
    """
    parser = argparse.ArgumentParser(
        description="Train multiple transformer architectures sequentially"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        default=DEFAULT_MODELS,
        help="Space-separated list of HuggingFace checkpoint names",
    )
    args = parser.parse_args()

    device_label = "MPS (Apple Metal)" if CFG.device == "mps" else CFG.device.upper()
    print("\n Multi-Architecture Training Session")
    print(f" Device : {device_label}")
    print(f" Models : {', '.join(args.models)}")
    print(" Dataset : AG News — full 120,000 training examples\n")

    all_results = []
    failures = []
    session_t0 = time.perf_counter()

    for checkpoint in args.models:
        try:
            result = train_single_architecture(checkpoint)
        except Exception:
            # Per-model boundary: one broken architecture must not abort the
            # rest of an overnight run. The traceback is logged, not swallowed.
            logger.exception("Training failed for %s — skipping", checkpoint)
            failures.append(checkpoint)
            continue
        all_results.append(result)
        print(
            f"\n ✓ {checkpoint} — acc={result['accuracy']*100:.2f}% "
            f"time={result['time']}\n"
        )

    session_time = _format_duration(time.perf_counter() - session_t0)

    # Rank once. sorted() is stable, so among tied accuracies ranked[0] is the
    # first one in input order — the same element max() would pick.
    ranked = sorted(all_results, key=lambda x: x["accuracy"], reverse=True)
    best = ranked[0] if ranked else None

    print(f"\n{'═' * 66}")
    print(f" {'Architecture':<28} {'Accuracy':>10} {'F1-Macro':>10} {'Time':>10}")
    print(f"{'─' * 66}")
    for r in ranked:
        name = r["checkpoint"].split("/")[-1]
        star = " ◀ best" if r is best else ""
        print(
            f" {name:<28} {r['accuracy']*100:>9.2f}% "
            f"{r['f1_macro']:>10.4f} {r['time']:>10}{star}"
        )
    for checkpoint in failures:
        name = checkpoint.split("/")[-1]
        print(f" {name:<28} {'FAILED':>10}")
    print(f"{'═' * 66}")
    print(f"\n Total session time: {session_time}\n")

    if failures:
        raise SystemExit(1)


if __name__ == "__main__":
    main()