nexa-classify-api / train_transformer.py
Prototype6239's picture
Upload folder using huggingface_hub
a229747 verified
Raw
History Blame Contribute Delete
2.13 kB
"""
train_transformer.py
────────────────────
Entry-point: fine-tunes DistilBERT on AG News.
Usage
─────
python train_transformer.py # 20 K subset ≈ 1.5–2.5 hrs on i3
python train_transformer.py --full # 120 K full ≈ 8–10 hrs on i3
Tip: Start with the subset to verify everything works, then run --full overnight.
"""
import argparse
import logging
from config import CFG
from data_loader import load_ag_news, get_tokenizer, tokenise_dataset
import transformer_model as trm
# Configure the root logger once, at import time: INFO level, with each
# record rendered as "HH:MM:SS LEVEL    message".
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the fine-tuning script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line
            (``sys.argv[1:]``), which is the original behaviour. Passing an
            explicit list lets callers and tests drive the parser without
            mutating ``sys.argv``.

    Returns:
        A namespace with one boolean attribute, ``full``: ``True`` when
        ``--full`` was supplied, ``False`` otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune DistilBERT document classifier"
    )
    parser.add_argument(
        "--full",
        action="store_true",
        help="Use the full 120 K training set instead of the 20 K subset",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Run the fine-tuning pipeline end to end.

    In order: parse the CLI flags, print a run banner, load and tokenise the
    dataset, train, evaluate (passing ``outputs/transformer`` as the
    evaluation ``save_dir``), save the model and tokenizer, and print the
    final test accuracy.

    Without ``--full`` the train/eval sample caps come from
    ``CFG.max_train_samples`` / ``CFG.max_eval_samples``; with ``--full``
    both caps are passed as ``None``.
    """
    args = parse_args()

    # --full replaces both caps with None; otherwise use the configured caps.
    max_train = None if args.full else CFG.max_train_samples
    max_eval = None if args.full else CFG.max_eval_samples

    # Key the label on the value actually handed to the loader: formatting
    # None with ``:,`` would raise TypeError if the configured cap is None.
    if max_train is None:
        sample_label = "Full dataset"
    else:
        sample_label = f"{max_train:,} samples (subset)"

    rule = "─" * 60
    print(f"\n{rule}")
    print(" Document Classifier — DistilBERT Fine-Tuning")
    print(f" Base model : {CFG.model_checkpoint}")
    print(f" Samples : {sample_label}")
    print(f"{rule}\n")

    # Load and tokenise dataset (the test split is never capped here).
    dataset = load_ag_news(max_train=max_train, max_eval=max_eval, max_test=None)
    tokenizer = get_tokenizer()
    tokenised = tokenise_dataset(dataset, tokenizer)

    # Train
    trainer = trm.train(tokenised, tokenizer)

    # Evaluate on test set
    # NOTE(review): the model is saved only after evaluation, so an exception
    # in trm.evaluate means this script never reaches trm.save_model —
    # consider saving first if evaluation can fail.
    save_dir = "outputs/transformer"
    results = trm.evaluate(trainer, tokenised, save_dir=save_dir)

    # Save best checkpoint + tokeniser
    trm.save_model(trainer, tokenizer)

    print(f"\n Final test accuracy : {results['accuracy'] * 100:.2f}%")
    # NOTE(review): this path is hard-coded in the message, not derived from
    # trm.save_model — confirm it matches where the model is actually written.
    print(" Model saved to : saved_models/transformer/\n")
# Run the pipeline only when this file is executed as a script; importing
# the module must not start a training run.
if __name__ == "__main__":
    main()