File size: 3,781 Bytes
a229747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
train_multi.py
──────────────
Train multiple transformer architectures back-to-back.
Ideal for an overnight run on a Mac M4 — each model saves independently.
Results are compared at the end in a clean table.

Usage
─────
    # Train DistilBERT + BERT + RoBERTa (default)
    python train_multi.py

    # Train only DistilBERT and RoBERTa
    python train_multi.py --models distilbert-base-uncased roberta-base

    # Single model
    python train_multi.py --models roberta-base
"""
import argparse
import logging
import time

from config import CFG
from data_loader import load_ag_news, get_tokenizer, tokenise_dataset
import transformer_model as trm

# Configure the root logger once at import time: INFO and above, with a
# short wall-clock timestamp and a padded level name on every line.
logging.basicConfig(
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# HuggingFace checkpoints trained when --models is omitted, in run order.
DEFAULT_MODELS = [
    "distilbert-base-uncased", "bert-base-uncased", "roberta-base",
]


def train_single_architecture(checkpoint: str) -> dict:
    """
    End-to-end train, evaluate and save one transformer checkpoint.

    ``CFG.model_checkpoint`` is overridden with *checkpoint* while this
    function runs and restored to its previous value afterwards, whether the
    run succeeds or raises. The data helpers below are called without a
    checkpoint argument, so they presumably read it from ``CFG`` — TODO
    confirm in data_loader.

    Parameters
    ----------
    checkpoint : str
        HuggingFace checkpoint name, e.g. ``"roberta-base"``.

    Returns
    -------
    dict
        ``checkpoint``: the name passed in.
        ``accuracy``: ``results["accuracy"]`` from ``trm.evaluate``.
        ``f1_macro``: ``results["metrics"]["test_f1_macro"]``, or 0.0 if absent.
        ``time``: wall-clock training time (``trm.train`` only) as
        ``"Xh Ym Zs"``.

    Raises
    ------
    Whatever the data loading, training, evaluation or saving helpers raise;
    the config override is still undone in that case.
    """
    print(f"\n{'━' * 60}")
    print(f"  Model: {checkpoint}")
    print(f"{'━' * 60}")

    # Remember the current value so the override is truly scoped to this run.
    previous_checkpoint = getattr(CFG, "model_checkpoint", None)
    CFG.model_checkpoint = checkpoint   # Override for this run
    try:
        dataset   = load_ag_news()          # Full 120K (no cap in updated config)
        tokenizer = get_tokenizer()
        tokenised = tokenise_dataset(dataset, tokenizer)

        # Only the training step is timed; evaluation and saving are excluded.
        t0      = time.perf_counter()
        trainer = trm.train(tokenised, tokenizer, checkpoint=checkpoint)
        elapsed = time.perf_counter() - t0

        save_dir = f"outputs/{trm._checkpoint_to_dir(checkpoint)}"
        results  = trm.evaluate(trainer, tokenised,
                                checkpoint=checkpoint, save_dir=save_dir)
        trm.save_model(trainer, tokenizer, checkpoint=checkpoint)
    finally:
        # Undo the global config mutation even if any step above raised.
        CFG.model_checkpoint = previous_checkpoint

    h, rem = divmod(int(elapsed), 3600)
    m, s   = divmod(rem, 60)

    return {
        "checkpoint": checkpoint,
        "accuracy":   results["accuracy"],
        "f1_macro":   results["metrics"].get("test_f1_macro", 0.0),
        "time":       f"{h}h {m}m {s}s",
    }


def main() -> None:
    """
    Train every checkpoint named by ``--models`` in turn and print a summary.

    Each architecture is trained via ``train_single_architecture``. If one
    raises, the traceback is logged and the session continues with the next
    checkpoint, so a single crash does not discard the rest of a long run.
    The final table lists successful runs sorted by accuracy (best first,
    marked with ◀), followed by any failed checkpoints. If at least one
    checkpoint failed, the process exits with status 1 after the table is
    printed.
    """
    parser = argparse.ArgumentParser(
        description="Train multiple transformer architectures sequentially"
    )
    parser.add_argument(
        "--models", nargs="+", default=DEFAULT_MODELS,
        help="Space-separated list of HuggingFace checkpoint names",
    )
    args = parser.parse_args()

    device_label = "MPS (Apple Metal)" if CFG.device == "mps" else CFG.device.upper()
    print("\n  Multi-Architecture Training Session")
    print(f"  Device  : {device_label}")
    print(f"  Models  : {', '.join(args.models)}")
    print("  Dataset : AG News — full 120,000 training examples\n")

    all_results = []
    failed = []
    session_t0  = time.perf_counter()

    for checkpoint in args.models:
        try:
            result = train_single_architecture(checkpoint)
        except Exception:
            # Best-effort per model: record the failure and move on so later
            # architectures still get trained; the exit status reports it.
            logging.exception("Training failed for %s", checkpoint)
            failed.append(checkpoint)
            print(f"\n  ✗  {checkpoint}  —  FAILED\n")
            continue
        all_results.append(result)
        print(f"\n  ✓  {checkpoint}  —  acc={result['accuracy']*100:.2f}%  time={result['time']}\n")

    session_elapsed = time.perf_counter() - session_t0
    h, rem = divmod(int(session_elapsed), 3600)
    m, s   = divmod(rem, 60)

    print(f"\n{'═' * 66}")
    print(f"  {'Architecture':<28}  {'Accuracy':>10}  {'F1-Macro':>10}  {'Time':>10}")
    print(f"{'─' * 66}")
    # sorted() is stable even with reverse=True, so rank 0 is the same entry
    # max() would pick (first of any tied best); no need to recompute per row.
    ranked = sorted(all_results, key=lambda x: x["accuracy"], reverse=True)
    for rank, r in enumerate(ranked):
        name = r["checkpoint"].split("/")[-1]
        star = "  ◀ best" if rank == 0 else ""
        print(
            f"  {name:<28}  {r['accuracy']*100:>9.2f}%  "
            f"{r['f1_macro']:>10.4f}  {r['time']:>10}{star}"
        )
    for checkpoint in failed:
        name = checkpoint.split("/")[-1]
        print(f"  {name:<28}  {'FAILED':>10}")
    print(f"{'═' * 66}")
    print(f"\n  Total session time: {h}h {m}m {s}s\n")

    if failed:
        # Keep a non-zero exit status when any architecture failed, as an
        # uncaught exception would have produced.
        raise SystemExit(1)


if __name__ == "__main__":
    # Script entry point: main() returns None, which SystemExit maps to
    # exit status 0 — the same as falling off the end of the script.
    raise SystemExit(main())