Spaces:
Sleeping
Sleeping
| """ | |
transformer_model.py
────────────────────
Updated for Mac M4 / Apple Silicon MPS.
Key changes vs Windows version:
  - Removed use_cpu=True — Trainer auto-detects MPS on Mac
  - Added label_smoothing_factor
  - Model-specific output directories (supports multiple architectures)
  - Gradient checkpointing toggle
  - Cleaned up device handling for inference
| """ | |
| import logging | |
| import os | |
| import time | |
| from typing import Dict, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| classification_report, | |
| confusion_matrix, | |
| f1_score, | |
| ) | |
| from transformers import ( | |
| AutoModelForSequenceClassification, | |
| AutoTokenizer, | |
| DataCollatorWithPadding, | |
| EarlyStoppingCallback, | |
| PreTrainedTokenizerBase, | |
| Trainer, | |
| TrainingArguments, | |
| ) | |
| from config import CFG | |
# Module-level logger; records are emitted under this module's import name.
logger = logging.getLogger(name=__name__)
| # ββ Helper ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _checkpoint_to_dir(checkpoint: str) -> str: | |
| """Convert a HuggingFace checkpoint name to a safe directory name. | |
| Examples: | |
| 'roberta-base' β 'roberta_base' | |
| 'distilbert-base-uncased' β 'distilbert_base_uncased' | |
| """ | |
| return checkpoint.replace("/", "_").replace("-", "_") | |
# ── Model factory ─────────────────────────────────────────────────────────────
def build_model(checkpoint: str = None) -> AutoModelForSequenceClassification:
    """Instantiate a pre-trained encoder with a randomly-initialised classification head.

    Parameters
    ----------
    checkpoint : HuggingFace checkpoint name or local path; falls back to
        CFG.model_checkpoint when None.

    Returns
    -------
    The model, configured with CFG.num_labels outputs and id2label/label2id
    maps built from CFG.label_names (label id = position in that list).
    Gradient checkpointing is switched on when CFG.use_gradient_checkpointing
    is truthy, and total/trainable parameter counts are logged.
    """
    name = CFG.model_checkpoint if checkpoint is None else checkpoint

    id_to_label = dict(enumerate(CFG.label_names))
    label_to_id = {label: idx for idx, label in id_to_label.items()}

    net = AutoModelForSequenceClassification.from_pretrained(
        name,
        num_labels=CFG.num_labels,
        id2label=id_to_label,
        label2id=label_to_id,
    )

    if CFG.use_gradient_checkpointing:
        net.gradient_checkpointing_enable()
        logger.info("Gradient checkpointing: ON")

    # Single pass over the parameters for both counts.
    n_total = 0
    n_trainable = 0
    for param in net.parameters():
        count = param.numel()
        n_total += count
        if param.requires_grad:
            n_trainable += count

    logger.info(f"Model: {name} | total={n_total:,} trainable={n_trainable:,}")
    return net
# ── Training arguments ────────────────────────────────────────────────────────
# Training-set size the warmup schedule assumed before it became a parameter.
# Used as the fallback so existing callers get exactly the same warmup_steps.
_DEFAULT_NUM_TRAIN_EXAMPLES = 108_000


def _estimate_warmup_steps(num_train_examples: int) -> int:
    """Return CFG.warmup_ratio of the estimated number of optimiser steps.

    The step count is approximated as
    ``(num_train_examples // batch_size) * num_epochs // grad_accum_steps``.
    """
    total_steps = (
        (num_train_examples // CFG.batch_size)
        * CFG.num_epochs
        // CFG.grad_accum_steps
    )
    return int(total_steps * CFG.warmup_ratio)


def get_training_args(checkpoint: Optional[str] = None,
                      output_dir: Optional[str] = None,
                      num_train_examples: Optional[int] = None) -> TrainingArguments:
    """
    Build MPS-safe TrainingArguments for the HuggingFace Trainer.

    Parameters
    ----------
    checkpoint : HuggingFace checkpoint name; defaults to CFG.model_checkpoint.
    output_dir : Where the Trainer writes checkpoints; defaults to
        ``<CFG.outputs_dir>/<checkpoint with '/' and '-' replaced by '_'>``.
    num_train_examples : Number of rows in the training split, used to size
        the LR warmup. Defaults to 108,000 (the value that used to be
        hard-coded); pass ``len(train_dataset)`` so warmup is right for other
        datasets or subsets.

    Raises
    ------
    ValueError
        If ``num_train_examples`` is negative.

    Critical Mac M4 notes
    ---------------------
    - Do NOT set use_cpu=True — the Trainer auto-detects MPS on Mac
    - fp16=False — MPS lacks full float16 operator coverage
    - bf16=False — kept False for reliability (can try True on M2+)
    - dataloader_pin_memory=False — pin_memory only benefits CUDA
    - dataloader_num_workers comes from CFG.num_workers — HuggingFace torch
      datasets + multiprocessing can be unstable on Mac, so 0 is safest

    Transformers 5.x deprecations handled here
    ------------------------------------------
    - warmup_ratio → warmup_steps (computed manually below)
    - logging_dir → TENSORBOARD_LOGGING_DIR env var

    Side effect: sets os.environ["TENSORBOARD_LOGGING_DIR"] for the process.
    """
    if checkpoint is None:
        checkpoint = CFG.model_checkpoint
    if output_dir is None:
        output_dir = os.path.join(CFG.outputs_dir, _checkpoint_to_dir(checkpoint))
    if num_train_examples is None:
        num_train_examples = _DEFAULT_NUM_TRAIN_EXAMPLES
    elif num_train_examples < 0:
        raise ValueError(
            f"num_train_examples must be >= 0, got {num_train_examples}"
        )

    # Warmup is passed as an explicit step count (replaces deprecated warmup_ratio).
    warmup_steps = _estimate_warmup_steps(num_train_examples)

    # TensorBoard log dir goes through the env var (replaces deprecated logging_dir).
    os.environ["TENSORBOARD_LOGGING_DIR"] = CFG.logs_dir

    return TrainingArguments(
        output_dir=output_dir,
        # ── Schedule ────────────────────────────────────────────────────────
        num_train_epochs=CFG.num_epochs,
        per_device_train_batch_size=CFG.batch_size,
        per_device_eval_batch_size=CFG.batch_size * 2,
        gradient_accumulation_steps=CFG.grad_accum_steps,
        # ── Optimiser ───────────────────────────────────────────────────────
        learning_rate=CFG.learning_rate,
        weight_decay=CFG.weight_decay,
        warmup_steps=warmup_steps,  # replaces deprecated warmup_ratio
        lr_scheduler_type="cosine",
        label_smoothing_factor=CFG.label_smoothing,
        # ── Evaluation & checkpointing ──────────────────────────────────────
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        save_total_limit=2,
        # ── Logging ─────────────────────────────────────────────────────────
        # logging_dir is deprecated in transformers 5.x; the
        # TENSORBOARD_LOGGING_DIR env var set above is used instead.
        logging_steps=100,
        report_to="none",
        # ── MPS / Mac-specific ──────────────────────────────────────────────
        # NOTE: No use_cpu=True here — MPS is auto-detected on Mac M4
        fp16=False,
        bf16=False,
        dataloader_num_workers=CFG.num_workers,
        dataloader_pin_memory=False,
        # ── Reproducibility ─────────────────────────────────────────────────
        seed=CFG.seed,
        data_seed=CFG.seed,
        push_to_hub=False,
    )
# ── Metrics ───────────────────────────────────────────────────────────────────
def compute_metrics(eval_pred) -> Dict[str, float]:
    """Compute accuracy and macro-F1 from model logits.

    Passed to the Trainer as ``compute_metrics``, so it runs after every
    validation epoch (eval_strategy="epoch") and inside ``trainer.predict``.

    Parameters
    ----------
    eval_pred : A ``(predictions, label_ids)`` pair (the Trainer's
        EvalPrediction unpacks the same way). ``predictions`` is the logits
        array, or a tuple whose first element is assumed to be the logits
        when the model returns extra outputs.

    Returns
    -------
    {"accuracy": float, "f1_macro": float}
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": float(accuracy_score(labels, preds)),
        # zero_division=0 yields the same score as sklearn's default ("warn"),
        # without an UndefinedMetricWarning for classes that are never predicted.
        "f1_macro": float(f1_score(labels, preds, average="macro",
                                   zero_division=0)),
    }
# ── Training pipeline ─────────────────────────────────────────────────────────
def train(tokenised_dataset, tokenizer: PreTrainedTokenizerBase,
          checkpoint: Optional[str] = None,
          early_stopping_patience: int = 2) -> Trainer:
    """
    Fine-tune a transformer encoder and return the Trainer
    with the best checkpoint loaded.

    Parameters
    ----------
    tokenised_dataset : Split container indexed by "train" and "validation";
        each split must support ``len()``. (Presumably a
        ``datasets.DatasetDict`` — confirm against the caller.)
    tokenizer : Used for dynamic padding and stored as the Trainer's
        processing_class.
    checkpoint : HuggingFace checkpoint name; defaults to CFG.model_checkpoint.
    early_stopping_patience : Evaluations without improvement before training
        stops early. Defaults to 2, the previously hard-coded value.

    Returns
    -------
    Trainer
        The trainer after ``trainer.train()`` has finished.
    """
    if checkpoint is None:
        checkpoint = CFG.model_checkpoint

    train_split = tokenised_dataset["train"]
    val_split = tokenised_dataset["validation"]
    n_train = len(train_split)

    model = build_model(checkpoint)
    training_args = get_training_args(checkpoint)
    # get_training_args(checkpoint) is not told the training-set size, so its
    # warmup_steps may rest on an assumed size. Re-derive it from the real split
    # with the same formula before the Trainer is constructed.
    optimiser_steps = (n_train // CFG.batch_size) * CFG.num_epochs // CFG.grad_accum_steps
    training_args.warmup_steps = int(optimiser_steps * CFG.warmup_ratio)

    data_collator = DataCollatorWithPadding(tokenizer, return_tensors="pt")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=val_split,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(
            early_stopping_patience=early_stopping_patience)],
    )

    # Batches per epoch via ceiling division: dataloader_drop_last is not set by
    # get_training_args, and HF's default keeps the final partial batch.
    steps_per_epoch = -(-n_train // CFG.batch_size)
    # str() so this also works if CFG.device is a torch.device, not a string.
    device = str(CFG.device)
    device_label = "MPS (Metal)" if device == "mps" else device.upper()

    rule = "─" * 60
    logger.info(rule)
    logger.info(f" Fine-Tuning: {checkpoint}")
    logger.info(f" Device : {device_label}")
    logger.info(f" train : {n_train:,}")
    logger.info(f" val : {len(val_split):,}")
    logger.info(f" epochs : {CFG.num_epochs}")
    logger.info(f" batch : {CFG.batch_size}")
    logger.info(f" steps/ep : {steps_per_epoch:,}")
    logger.info(f" max_length : {CFG.max_length}")
    logger.info(rule)

    t0 = time.perf_counter()
    trainer.train()
    elapsed = time.perf_counter() - t0
    h, rem = divmod(int(elapsed), 3600)
    m, s = divmod(rem, 60)
    logger.info(f"Training complete: {h}h {m}m {s}s")
    return trainer
# ── Evaluation ────────────────────────────────────────────────────────────────
def evaluate(trainer: Trainer, tokenised_dataset,
             checkpoint: Optional[str] = None,
             save_dir: Optional[str] = None) -> Dict:
    """Run predictions on the test split, print a full report and plot the
    confusion matrix.

    Parameters
    ----------
    trainer : Trainer holding the model to evaluate (e.g. the one returned by
        train()).
    tokenised_dataset : Split container with a "test" split.
    checkpoint : Name used in the log/print titles; defaults to
        CFG.model_checkpoint.
    save_dir : If given, the confusion-matrix PNG is written there.

    Returns
    -------
    dict with keys "accuracy" (float), "report" (classification_report text),
    "confusion_matrix" (square array, one row/column per CFG.label_names entry)
    and "metrics" (the metrics dict from trainer.predict).
    """
    if checkpoint is None:
        checkpoint = CFG.model_checkpoint
    logger.info(f"Evaluating {checkpoint} on test set …")

    predictions = trainer.predict(tokenised_dataset["test"])
    logits = predictions.predictions
    # If the model returns extra outputs, predictions can arrive as a tuple;
    # the classification logits are assumed to be its first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    labels = predictions.label_ids

    # Label ids are 0..len(label_names)-1 (see id2label in build_model). Pinning
    # them means a class absent from the test split neither makes
    # classification_report reject target_names (length mismatch) nor shrinks
    # the confusion matrix below the tick labels _plot_cm draws.
    label_ids = list(range(len(CFG.label_names)))

    acc = accuracy_score(labels, preds)
    report = classification_report(labels, preds, labels=label_ids,
                                   target_names=CFG.label_names, digits=4,
                                   zero_division=0)
    cm = confusion_matrix(labels, preds, labels=label_ids)

    rule = "─" * 60
    print("\n" + rule)
    print(f" {checkpoint.upper()} — TEST SET RESULTS")
    print(rule)
    print(f" Accuracy : {acc * 100:.2f}%")
    print(f" Metrics : {predictions.metrics}\n")
    print(report)

    _plot_cm(cm, f"{checkpoint} — Confusion Matrix",
             save_dir=save_dir, cmap="Greens")

    return {
        "accuracy": acc,
        "report": report,
        "confusion_matrix": cm,
        "metrics": predictions.metrics,
    }
# ── Persistence ───────────────────────────────────────────────────────────────
def save_model(trainer: Trainer, tokenizer: PreTrainedTokenizerBase,
               checkpoint: Optional[str] = None) -> str:
    """Save the trainer's model + tokeniser to <CFG.models_dir>/<model_dir>/.

    When called after train(), the trainer holds the best checkpoint, because
    get_training_args sets load_best_model_at_end=True.

    Parameters
    ----------
    trainer : Trainer whose model is written via ``trainer.save_model``.
    tokenizer : Tokeniser saved into the same directory.
    checkpoint : HuggingFace checkpoint name used to derive the directory;
        defaults to CFG.model_checkpoint.

    Returns
    -------
    str
        The directory written to — the same path load_model() reads from.
    """
    if checkpoint is None:
        checkpoint = CFG.model_checkpoint
    path = os.path.join(CFG.models_dir, _checkpoint_to_dir(checkpoint))
    trainer.save_model(path)
    # Save the tokenizer explicitly so the directory is loadable on its own by
    # AutoTokenizer.from_pretrained, as load_model() does.
    tokenizer.save_pretrained(path)
    logger.info(f"Model saved → {path}")
    return path
def load_model(checkpoint: Optional[str] = None) -> Tuple:
    """
    Load a saved fine-tuned model and its tokeniser.

    Parameters
    ----------
    checkpoint : HuggingFace checkpoint name, e.g. 'roberta-base'.
        If None, uses CFG.model_checkpoint. Resolved to
        ``<CFG.models_dir>/<checkpoint with '/' and '-' replaced by '_'>``,
        the directory save_model() writes to.

    Returns
    -------
    (model, tokenizer) — model is in eval mode

    Raises
    ------
    FileNotFoundError
        If the model directory does not exist.
    """
    if checkpoint is None:
        checkpoint = CFG.model_checkpoint
    path = os.path.join(CFG.models_dir, _checkpoint_to_dir(checkpoint))
    if not os.path.isdir(path):
        raise FileNotFoundError(
            f"No saved model at '{path}'.\n"
            "Run: python train_transformer.py (or python train_multi.py)"
        )
    model = AutoModelForSequenceClassification.from_pretrained(path)
    tokenizer = AutoTokenizer.from_pretrained(path)
    model.eval()  # inference mode (disables dropout)
    logger.info(f"Model loaded ← {path}")
    return model, tokenizer
def load_quantized_model(checkpoint: str = "distilbert-base-uncased") -> Tuple:
    """
    Load the INT8 dynamically quantized version of a model.

    Looks for ``<CFG.models_dir>/<model_dir>_int8/model_int8.pt`` (a whole
    pickled model, presumably written with ``torch.save(model)``) and a
    tokeniser in the same directory. Falls back to the FP32 model from
    load_model() if the INT8 file is not found.

    Parameters
    ----------
    checkpoint : HuggingFace checkpoint name. Note the default is
        'distilbert-base-uncased', not CFG.model_checkpoint.

    Returns
    -------
    (model, tokenizer, is_quantized) — model is in eval mode; is_quantized is
    True only when the INT8 file was loaded.

    Raises
    ------
    FileNotFoundError
        From load_model() when the INT8 file is missing and no FP32 model
        exists either.
    """
    dir_name = _checkpoint_to_dir(checkpoint)
    int8_path = os.path.join(CFG.models_dir, f"{dir_name}_int8")
    model_file = os.path.join(int8_path, "model_int8.pt")

    if not os.path.exists(model_file):
        logger.warning(f"INT8 model not found at {int8_path}. Falling back to FP32.")
        model, tokenizer = load_model(checkpoint)
        return model, tokenizer, False

    # Apple Silicon/ARM requires the qengine (qnnpack) to be set before
    # deserialising. Assigning an engine this torch build lacks raises, so
    # check support first and report it rather than silently ignoring it.
    if "qnnpack" in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = "qnnpack"
    else:
        logger.warning(
            "qnnpack quantized engine is not supported by this torch build; "
            f"keeping '{torch.backends.quantized.engine}'."
        )

    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # load model files this project produced itself — never untrusted ones.
    try:
        model = torch.load(model_file, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        model = torch.load(model_file, map_location="cpu")
    tokenizer = AutoTokenizer.from_pretrained(int8_path)
    model.eval()
    logger.info(f"INT8 quantized model loaded ← {int8_path}")
    return model, tokenizer, True
# ── Helpers ───────────────────────────────────────────────────────────────────
def _plot_cm(cm: np.ndarray, title: str,
             save_dir: Optional[str] = None, cmap: str = "Blues") -> None:
    """Render a confusion matrix as an annotated heatmap.

    Parameters
    ----------
    cm : Integer count matrix (annotated with fmt="d"); rows are true labels,
        columns predicted labels, both tick-labelled with CFG.label_names.
    title : Plot title.
    save_dir : If truthy, the figure is saved as
        ``<save_dir>/confusion_matrix.png`` (dpi=150), creating the directory
        if needed.
    cmap : Matplotlib colormap name for the heatmap.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        sns.heatmap(cm, annot=True, fmt="d", cmap=cmap,
                    xticklabels=CFG.label_names, yticklabels=CFG.label_names,
                    linewidths=0.5, ax=ax)
        ax.set_xlabel("Predicted Label", fontsize=11)
        ax.set_ylabel("True Label", fontsize=11)
        ax.set_title(title, fontsize=13, fontweight="bold")
        plt.tight_layout()
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            path = os.path.join(save_dir, "confusion_matrix.png")
            plt.savefig(path, dpi=150)
            logger.info(f"Confusion matrix → {path}")
        # The module selects the non-interactive Agg backend at import, so this
        # normally displays nothing.
        plt.show()
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)