Spaces:
Running
Running
"""
train_transformer.py
────────────────────
Entry-point: fine-tunes DistilBERT on AG News.

Usage
─────
    python train_transformer.py           # 20 K subset, ~1.5–2.5 hrs on i3
    python train_transformer.py --full    # 120 K full, ~8–10 hrs on i3

Tip: Start with the subset to verify everything works, then run --full overnight.
"""
| import argparse | |
| import logging | |
| from config import CFG | |
| from data_loader import load_ag_news, get_tokenizer, tokenise_dataset | |
| import transformer_model as trm | |
# Configure the root logger once at import time: INFO and above, with each
# record rendered as "<HH:MM:SS> <LEVEL padded to 8> <message>".
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave exactly as before. Passing an explicit list lets
            the parser be driven programmatically.

    Returns:
        Namespace with a boolean ``full`` attribute, ``True`` only when
        ``--full`` was supplied.
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune DistilBERT document classifier"
    )
    parser.add_argument(
        "--full",
        action="store_true",
        help="Use the full 120 K training set instead of the 20 K subset",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Run the fine-tuning pipeline: load, tokenise, train, evaluate, save.

    Reads ``--full`` from the command line. Without it, the train and eval
    sample caps come from ``CFG.max_train_samples`` and
    ``CFG.max_eval_samples``; with it, ``None`` is passed for both.
    ``max_test=None`` is always passed to ``load_ag_news``.
    """
    args = parse_args()
    max_train = None if args.full else CFG.max_train_samples
    max_eval = None if args.full else CFG.max_eval_samples

    # Guard on the cap itself rather than on the flag: a configured cap of
    # None cannot be rendered with the "," format spec (TypeError).
    if max_train is None:
        sample_label = "Full dataset"
    else:
        sample_label = f"{max_train:,} samples (subset)"

    rule = "─" * 60
    print(f"\n{rule}")
    print(" Document Classifier — DistilBERT Fine-Tuning")
    print(f" Base model : {CFG.model_checkpoint}")
    print(f" Samples : {sample_label}")
    print(f"{rule}\n")

    # Load and tokenise dataset
    dataset = load_ag_news(max_train=max_train, max_eval=max_eval, max_test=None)
    tokenizer = get_tokenizer()
    tokenised = tokenise_dataset(dataset, tokenizer)

    # Train
    trainer = trm.train(tokenised, tokenizer)

    # Evaluate on test set
    save_dir = "outputs/transformer"
    results = trm.evaluate(trainer, tokenised, save_dir=save_dir)

    # Save best checkpoint + tokeniser
    trm.save_model(trainer, tokenizer)

    print(f"\n Final test accuracy : {results['accuracy'] * 100:.2f}%")
    # NOTE(review): this path is hard-coded here; confirm it matches the
    # directory trm.save_model actually writes to.
    print(" Model saved to : saved_models/transformer/\n")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()