Spaces:
Running
Running
"""
train_multi.py
──────────────
Train multiple transformer architectures back-to-back.
Ideal for an overnight run on a Mac M4 — each model is saved independently.
Results are compared at the end in a clean table.

Usage
─────
    # Train DistilBERT + BERT + RoBERTa (default)
    python train_multi.py

    # Train only DistilBERT and RoBERTa
    python train_multi.py --models distilbert-base-uncased roberta-base

    # Single model
    python train_multi.py --models roberta-base
"""
| import argparse | |
| import logging | |
| import time | |
| from config import CFG | |
| from data_loader import load_ag_news, get_tokenizer, tokenise_dataset | |
| import transformer_model as trm | |
# Configure the root logger once, at import time: INFO level and
# "HH:MM:SS LEVEL    message" records.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# HuggingFace checkpoints trained when ``--models`` is not given on the CLI.
DEFAULT_MODELS = ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"]
def train_single_architecture(checkpoint: str) -> dict:
    """End-to-end train, evaluate and save one transformer checkpoint.

    ``CFG.model_checkpoint`` is overridden with *checkpoint* for the duration
    of this call and restored afterwards, even if a step raises. This keeps
    the global config from staying pointed at whichever model ran last.

    Args:
        checkpoint: HuggingFace checkpoint name, e.g. ``"roberta-base"``.

    Returns:
        A dict with keys:
          * ``"checkpoint"``: the checkpoint name passed in;
          * ``"accuracy"``: ``results["accuracy"]`` from ``trm.evaluate``;
          * ``"f1_macro"``: ``results["metrics"]["test_f1_macro"]``, or
            ``0.0`` when that metric is absent;
          * ``"time"``: training wall-clock time as ``"<h>h <m>m <s>s"``
            (evaluation and saving are not included).
    """
    print(f"\n{'─' * 60}")
    print(f" Model: {checkpoint}")
    print(f"{'─' * 60}")

    previous_checkpoint = CFG.model_checkpoint
    # Override for this run. NOTE(review): presumably the data_loader helpers
    # (e.g. get_tokenizer) read CFG.model_checkpoint; confirm.
    CFG.model_checkpoint = checkpoint
    try:
        dataset = load_ag_news()  # Full 120K (no cap in updated config)
        tokenizer = get_tokenizer()
        tokenised = tokenise_dataset(dataset, tokenizer)

        # Only the training call is timed.
        t0 = time.perf_counter()
        trainer = trm.train(tokenised, tokenizer, checkpoint=checkpoint)
        elapsed = time.perf_counter() - t0

        save_dir = f"outputs/{trm._checkpoint_to_dir(checkpoint)}"
        results = trm.evaluate(trainer, tokenised,
                               checkpoint=checkpoint, save_dir=save_dir)
        trm.save_model(trainer, tokenizer, checkpoint=checkpoint)
    finally:
        # Scope the global override to this call, as documented above.
        CFG.model_checkpoint = previous_checkpoint

    h, rem = divmod(int(elapsed), 3600)
    m, s = divmod(rem, 60)
    return {
        "checkpoint": checkpoint,
        "accuracy": results["accuracy"],
        "f1_macro": results["metrics"].get("test_f1_macro", 0.0),
        "time": f"{h}h {m}m {s}s",
    }
def main() -> None:
    """Command-line entry point: train every requested checkpoint in sequence.

    Parses ``--models`` (default ``DEFAULT_MODELS``) and calls
    ``train_single_architecture`` for each checkpoint in order. It then
    prints a summary table sorted by accuracy, with the best model starred.

    An exception while training one checkpoint is logged with its traceback,
    and the session continues with the next checkpoint. One crash therefore
    does not throw away the rest of an overnight run. ``KeyboardInterrupt`` is
    not caught, so Ctrl-C still stops the session. If any checkpoint failed,
    the failures are listed after the table and the process exits with
    status 1.
    """
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(
        description="Train multiple transformer architectures sequentially"
    )
    parser.add_argument(
        "--models", nargs="+", default=DEFAULT_MODELS,
        help="Space-separated list of HuggingFace checkpoint names",
    )
    args = parser.parse_args()

    device_label = "MPS (Apple Metal)" if CFG.device == "mps" else CFG.device.upper()
    print("\n Multi-Architecture Training Session")
    print(f" Device : {device_label}")
    print(f" Models : {', '.join(args.models)}")
    print(" Dataset : AG News — full 120,000 training examples\n")

    all_results = []
    failed = []
    session_t0 = time.perf_counter()
    for checkpoint in args.models:
        try:
            result = train_single_architecture(checkpoint)
        except Exception:
            # Per-model boundary: record the failure and keep going rather
            # than aborting every model still queued behind this one.
            logger.exception("Training failed for %s — skipping", checkpoint)
            failed.append(checkpoint)
            continue
        all_results.append(result)
        print(f"\n ✓ {checkpoint} — acc={result['accuracy']*100:.2f}% time={result['time']}\n")
    session_elapsed = time.perf_counter() - session_t0
    h, rem = divmod(int(session_elapsed), 3600)
    m, s = divmod(rem, 60)

    ranked = sorted(all_results, key=lambda x: x["accuracy"], reverse=True)
    # sorted() is stable, so on ties ranked[0] is the earliest-trained model,
    # the same one max() picks. It is computed once rather than once per row.
    best = ranked[0] if ranked else None

    print(f"\n{'─' * 66}")
    print(f" {'Architecture':<28} {'Accuracy':>10} {'F1-Macro':>10} {'Time':>10}")
    print(f"{'─' * 66}")
    for r in ranked:
        name = r["checkpoint"].split("/")[-1]
        star = " ★ best" if r is best else ""
        print(
            f" {name:<28} {r['accuracy']*100:>9.2f}% "
            f"{r['f1_macro']:>10.4f} {r['time']:>10}{star}"
        )
    print(f"{'─' * 66}")
    if failed:
        print(f" Failed ({len(failed)}): {', '.join(failed)}")
    print(f"\n Total session time: {h}h {m}m {s}s\n")

    if failed:
        # Non-zero exit so schedulers and shell scripts notice partial failure.
        raise SystemExit(1)


if __name__ == "__main__":
    main()