nexa-classify-api / train_multi.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
3.78 kB
"""
train_multi.py
──────────────
Train multiple transformer architectures back-to-back.
Ideal for an overnight run on a Mac M4 — each model saves independently.
Results are compared at the end in a clean table.
Usage
─────
# Train DistilBERT + BERT + RoBERTa (default)
python train_multi.py
# Train only DistilBERT and RoBERTa
python train_multi.py --models distilbert-base-uncased roberta-base
# Single model
python train_multi.py --models roberta-base
"""
import argparse
import logging
import time
from config import CFG
from data_loader import load_ag_news, get_tokenizer, tokenise_dataset
import transformer_model as trm
# Root logging setup: INFO and above, prefixed with a wall-clock timestamp
# (hours:minutes:seconds only) and a fixed-width level name.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)

# HuggingFace checkpoint names trained, in this order, when --models is omitted.
DEFAULT_MODELS = [
    "distilbert-base-uncased",
    "bert-base-uncased",
    "roberta-base",
]
def _format_duration(seconds: float) -> str:
    """Return *seconds* as ``"<h>h <m>m <s>s"``, truncating any fractional part."""
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours}h {minutes}m {secs}s"


def train_single_architecture(checkpoint: str) -> dict:
    """
    End-to-end train, evaluate and save one transformer checkpoint.

    ``CFG.model_checkpoint`` is overridden with *checkpoint* for the duration
    of the run and restored to its previous value afterwards (also when any
    step raises), so the global config does not leak the last model trained.

    Args:
        checkpoint: HuggingFace checkpoint name, e.g. ``"roberta-base"``.

    Returns:
        A dict with keys ``checkpoint``, ``accuracy`` (``results["accuracy"]``
        from ``trm.evaluate``), ``f1_macro`` (``"test_f1_macro"`` from the
        evaluation metrics, or 0.0 if absent) and ``time``. The ``time`` value
        is the wall-clock duration of ``trm.train`` only (evaluation and saving
        are excluded), formatted as ``"<h>h <m>m <s>s"``.
    """
    print(f"\n{'━' * 60}")
    print(f" Model: {checkpoint}")
    print(f"{'━' * 60}")

    previous_checkpoint = CFG.model_checkpoint
    CFG.model_checkpoint = checkpoint  # Override for this run
    try:
        dataset = load_ag_news()  # Full 120K (no cap in updated config)
        # NOTE(review): presumably get_tokenizer() resolves the tokenizer from
        # CFG.model_checkpoint, which is why the override above is needed.
        tokenizer = get_tokenizer()
        tokenised = tokenise_dataset(dataset, tokenizer)

        t0 = time.perf_counter()
        trainer = trm.train(tokenised, tokenizer, checkpoint=checkpoint)
        elapsed = time.perf_counter() - t0  # training time only

        save_dir = f"outputs/{trm._checkpoint_to_dir(checkpoint)}"
        results = trm.evaluate(trainer, tokenised,
                               checkpoint=checkpoint, save_dir=save_dir)
        trm.save_model(trainer, tokenizer, checkpoint=checkpoint)
    finally:
        # Honour the "for the duration of the run" contract.
        CFG.model_checkpoint = previous_checkpoint

    return {
        "checkpoint": checkpoint,
        "accuracy": results["accuracy"],
        "f1_macro": results["metrics"].get("test_f1_macro", 0.0),
        "time": _format_duration(elapsed),
    }
def main() -> None:
    """
    Train every checkpoint given via ``--models`` (default: ``DEFAULT_MODELS``)
    one after another, then print a summary table sorted by accuracy, best first.
    """
    parser = argparse.ArgumentParser(
        description="Train multiple transformer architectures sequentially"
    )
    parser.add_argument(
        "--models", nargs="+", default=DEFAULT_MODELS,
        help="Space-separated list of HuggingFace checkpoint names",
    )
    args = parser.parse_args()

    device_label = "MPS (Apple Metal)" if CFG.device == "mps" else CFG.device.upper()
    print("\n Multi-Architecture Training Session")
    print(f" Device : {device_label}")
    print(f" Models : {', '.join(args.models)}")
    print(" Dataset : AG News — full 120,000 training examples\n")

    all_results = []
    session_t0 = time.perf_counter()
    for checkpoint in args.models:
        result = train_single_architecture(checkpoint)
        all_results.append(result)
        print(f"\n ✓ {checkpoint} — acc={result['accuracy']*100:.2f}% time={result['time']}\n")
    session_elapsed = time.perf_counter() - session_t0
    h, rem = divmod(int(session_elapsed), 3600)
    m, s = divmod(rem, 60)

    # Pick the winner once (max() keeps the first of any ties) and mark that
    # exact row by identity, instead of recomputing max() and comparing dicts
    # on every iteration of the table loop.
    best = max(all_results, key=lambda x: x["accuracy"], default=None)

    print(f"\n{'═' * 66}")
    print(f" {'Architecture':<28} {'Accuracy':>10} {'F1-Macro':>10} {'Time':>10}")
    print(f"{'─' * 66}")
    for r in sorted(all_results, key=lambda x: x["accuracy"], reverse=True):
        name = r["checkpoint"].split("/")[-1]  # drop any "org/" prefix
        star = " ◀ best" if r is best else ""
        print(
            f" {name:<28} {r['accuracy']*100:>9.2f}% "
            f"{r['f1_macro']:>10.4f} {r['time']:>10}{star}"
        )
    print(f"{'═' * 66}")
    print(f"\n Total session time: {h}h {m}m {s}s\n")


if __name__ == "__main__":
    main()