""" transformer_model.py ──────────────────── Updated for Mac M4 / Apple Silicon MPS. Key changes vs Windows version: - Removed use_cpu=True → Trainer auto-detects MPS on Mac - Added label_smoothing_factor - Model-specific output directories (supports multiple architectures) - Gradient checkpointing toggle - Cleaned up device handling for inference """ import logging import os import time from typing import Dict, Optional, Tuple import numpy as np import torch import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import ( accuracy_score, classification_report, confusion_matrix, f1_score, ) from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, EarlyStoppingCallback, PreTrainedTokenizerBase, Trainer, TrainingArguments, ) from config import CFG logger = logging.getLogger(__name__) # ── Helper ──────────────────────────────────────────────────────────────────── def _checkpoint_to_dir(checkpoint: str) -> str: """Convert a HuggingFace checkpoint name to a safe directory name. Examples: 'roberta-base' → 'roberta_base' 'distilbert-base-uncased' → 'distilbert_base_uncased' """ return checkpoint.replace("/", "_").replace("-", "_") # ── Model factory ───────────────────────────────────────────────────────────── def build_model(checkpoint: str = None) -> AutoModelForSequenceClassification: """Load a pre-trained encoder with a randomly-initialised classification head.""" if checkpoint is None: checkpoint = CFG.model_checkpoint model = AutoModelForSequenceClassification.from_pretrained( checkpoint, num_labels=CFG.num_labels, id2label={i: n for i, n in enumerate(CFG.label_names)}, label2id={n: i for i, n in enumerate(CFG.label_names)}, ) if CFG.use_gradient_checkpointing: model.gradient_checkpointing_enable() logger.info("Gradient checkpointing: ON") total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"Model: {checkpoint} | total={total:,} trainable={trainable:,}") return model # ── Training arguments ──────────────────────────────────────────────────────── def get_training_args(checkpoint: str = None, output_dir: str = None) -> TrainingArguments: """ Build MPS-safe TrainingArguments for the HuggingFace Trainer. Critical Mac M4 notes ───────────────────── • Do NOT set use_cpu=True — the Trainer auto-detects MPS on Mac • fp16=False — MPS lacks full float16 operator coverage • bf16=False — Keep False for reliability (can try True on M2+) • dataloader_pin_memory=False — pin_memory only benefits CUDA • dataloader_num_workers=0 — HuggingFace torch datasets + multiprocessing can be unstable on Mac; 0 is safest Transformers 5.x deprecations handled here ────────────────────────────────────────── • warmup_ratio → warmup_steps (computed manually below) • logging_dir → TENSORBOARD_LOGGING_DIR env var """ if checkpoint is None: checkpoint = CFG.model_checkpoint if output_dir is None: output_dir = os.path.join(CFG.outputs_dir, _checkpoint_to_dir(checkpoint)) # Compute warmup_steps from ratio (replaces deprecated warmup_ratio arg) total_steps = (108_000 // CFG.batch_size) * CFG.num_epochs // CFG.grad_accum_steps warmup_steps = int(total_steps * CFG.warmup_ratio) # Set TensorBoard log dir via env var (replaces deprecated logging_dir arg) os.environ["TENSORBOARD_LOGGING_DIR"] = CFG.logs_dir return TrainingArguments( output_dir=output_dir, # ── Schedule ──────────────────────────────────────────────────────── num_train_epochs=CFG.num_epochs, per_device_train_batch_size=CFG.batch_size, per_device_eval_batch_size=CFG.batch_size * 2, gradient_accumulation_steps=CFG.grad_accum_steps, # ── Optimiser ──────────────────────────────────────────────────────── learning_rate=CFG.learning_rate, weight_decay=CFG.weight_decay, warmup_steps=warmup_steps, # replaces deprecated warmup_ratio lr_scheduler_type="cosine", label_smoothing_factor=CFG.label_smoothing, # ── Evaluation & checkpointing ─────────────────────────────────────── eval_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="accuracy", greater_is_better=True, save_total_limit=2, # ── Logging ────────────────────────────────────────────────────────── # logging_dir is deprecated in transformers 5.x; use TENSORBOARD_LOGGING_DIR env var instead logging_steps=100, report_to="none", # ── MPS / Mac-specific ──────────────────────────────────────────────── # NOTE: No use_cpu=True here — MPS is auto-detected on Mac M4 fp16=False, bf16=False, dataloader_num_workers=CFG.num_workers, dataloader_pin_memory=False, # ── Reproducibility ────────────────────────────────────────────────── seed=CFG.seed, data_seed=CFG.seed, push_to_hub=False, ) # ── Metrics ─────────────────────────────────────────────────────────────────── def compute_metrics(eval_pred) -> Dict[str, float]: """Called by Trainer after every validation epoch.""" logits, labels = eval_pred preds = np.argmax(logits, axis=-1) return { "accuracy": float(accuracy_score(labels, preds)), "f1_macro": float(f1_score(labels, preds, average="macro")), } # ── Training pipeline ───────────────────────────────────────────────────────── def train(tokenised_dataset, tokenizer: PreTrainedTokenizerBase, checkpoint: str = None) -> Trainer: """ Fine-tune a transformer encoder and return the Trainer with the best checkpoint loaded. """ if checkpoint is None: checkpoint = CFG.model_checkpoint model = build_model(checkpoint) training_args = get_training_args(checkpoint) data_collator = DataCollatorWithPadding(tokenizer, return_tensors="pt") trainer = Trainer( model=model, args=training_args, train_dataset=tokenised_dataset["train"], eval_dataset=tokenised_dataset["validation"], processing_class=tokenizer, data_collator=data_collator, compute_metrics=compute_metrics, callbacks=[EarlyStoppingCallback(early_stopping_patience=2)], ) steps_per_epoch = len(tokenised_dataset["train"]) // CFG.batch_size device_label = "MPS (Metal)" if CFG.device == "mps" else CFG.device.upper() logger.info("═" * 60) logger.info(f" Fine-Tuning: {checkpoint}") logger.info(f" Device : {device_label}") logger.info(f" train : {len(tokenised_dataset['train']):,}") logger.info(f" val : {len(tokenised_dataset['validation']):,}") logger.info(f" epochs : {CFG.num_epochs}") logger.info(f" batch : {CFG.batch_size}") logger.info(f" steps/ep : {steps_per_epoch:,}") logger.info(f" max_length : {CFG.max_length}") logger.info("═" * 60) t0 = time.perf_counter() trainer.train() elapsed = time.perf_counter() - t0 h, rem = divmod(int(elapsed), 3600) m, s = divmod(rem, 60) logger.info(f"Training complete: {h}h {m}m {s}s") return trainer # ── Evaluation ──────────────────────────────────────────────────────────────── def evaluate(trainer: Trainer, tokenised_dataset, checkpoint: str = None, save_dir: str = None) -> Dict: """Run predictions on the test split and print full report.""" if checkpoint is None: checkpoint = CFG.model_checkpoint logger.info(f"Evaluating {checkpoint} on test set …") predictions = trainer.predict(tokenised_dataset["test"]) preds = np.argmax(predictions.predictions, axis=-1) labels = predictions.label_ids acc = accuracy_score(labels, preds) report = classification_report(labels, preds, target_names=CFG.label_names, digits=4) cm = confusion_matrix(labels, preds) print("\n" + "═" * 60) print(f" {checkpoint.upper()} — TEST SET RESULTS") print("═" * 60) print(f" Accuracy : {acc * 100:.2f}%") print(f" Metrics : {predictions.metrics}\n") print(report) _plot_cm(cm, f"{checkpoint} — Confusion Matrix", save_dir=save_dir, cmap="Greens") return { "accuracy": acc, "report": report, "confusion_matrix": cm, "metrics": predictions.metrics, } # ── Persistence ─────────────────────────────────────────────────────────────── def save_model(trainer: Trainer, tokenizer: PreTrainedTokenizerBase, checkpoint: str = None) -> str: """Save best checkpoint + tokeniser to saved_models//.""" if checkpoint is None: checkpoint = CFG.model_checkpoint path = os.path.join(CFG.models_dir, _checkpoint_to_dir(checkpoint)) trainer.save_model(path) tokenizer.save_pretrained(path) logger.info(f"Model saved → {path}") return path def load_model(checkpoint: str = None) -> Tuple: """ Load a saved fine-tuned model and its tokeniser. Parameters ---------- checkpoint : HuggingFace checkpoint name, e.g. 'roberta-base'. If None, uses CFG.model_checkpoint. Returns ------- (model, tokenizer) — model is in eval mode """ if checkpoint is None: checkpoint = CFG.model_checkpoint path = os.path.join(CFG.models_dir, _checkpoint_to_dir(checkpoint)) if not os.path.isdir(path): raise FileNotFoundError( f"No saved model at '{path}'.\n" f"Run: python train_transformer.py (or python train_multi.py)" ) model = AutoModelForSequenceClassification.from_pretrained(path) tokenizer = AutoTokenizer.from_pretrained(path) model.eval() logger.info(f"Model loaded ← {path}") return model, tokenizer def load_quantized_model(checkpoint: str = "distilbert-base-uncased") -> Tuple: """ Load the INT8 dynamically quantized version of a model. Falls back to the FP32 model if the INT8 version is not found. Returns ------- (model, tokenizer, is_quantized) """ dir_name = _checkpoint_to_dir(checkpoint) int8_path = os.path.join(CFG.models_dir, f"{dir_name}_int8") fp32_path = os.path.join(CFG.models_dir, dir_name) model_file = os.path.join(int8_path, "model_int8.pt") if os.path.exists(model_file): # Apple Silicon/ARM requires the qengine (qnnpack) to be set before deserialising try: torch.backends.quantized.engine = "qnnpack" except Exception: pass try: model = torch.load(model_file, map_location="cpu", weights_only=False) except TypeError: model = torch.load(model_file, map_location="cpu") tokenizer = AutoTokenizer.from_pretrained(int8_path) model.eval() logger.info(f"INT8 quantized model loaded ← {int8_path}") return model, tokenizer, True logger.warning(f"INT8 model not found at {int8_path}. Falling back to FP32.") model, tokenizer = load_model(checkpoint) return model, tokenizer, False # ── Helpers ─────────────────────────────────────────────────────────────────── def _plot_cm(cm: np.ndarray, title: str, save_dir: str = None, cmap: str = "Blues") -> None: fig, ax = plt.subplots(figsize=(7, 6)) sns.heatmap(cm, annot=True, fmt="d", cmap=cmap, xticklabels=CFG.label_names, yticklabels=CFG.label_names, linewidths=0.5, ax=ax) ax.set_xlabel("Predicted Label", fontsize=11) ax.set_ylabel("True Label", fontsize=11) ax.set_title(title, fontsize=13, fontweight="bold") plt.tight_layout() if save_dir: os.makedirs(save_dir, exist_ok=True) path = os.path.join(save_dir, "confusion_matrix.png") plt.savefig(path, dpi=150) logger.info(f"Confusion matrix → {path}") plt.show() plt.close(fig)