"""
config.py
─────────
Updated for Mac M4 / Apple Silicon MPS.

Key changes vs Windows version:
  - Auto device detection: MPS → CUDA → CPU
  - Sample caps removed (full 120K training set now feasible)
  - batch_size 16 (grad_accum_steps=2 → effective batch 32)
  - max_length 128 (AG News headlines fit comfortably; less memory than 256)
  - Default model upgraded to 'roberta-base'
  - label_smoothing added for better calibration
  - gradient_checkpointing enabled by default (MPS OOM safeguard)
  - logs_dir follows outputs_dir unless set explicitly
  - Basic validation of label count, batch/accumulation and smoothing values
"""

import os
from dataclasses import dataclass, field
from typing import List, Optional

import torch


def _auto_device() -> str:
    """Return the best available compute device: "mps", then "cuda", else "cpu".

    ``torch.backends.mps`` is looked up with ``getattr`` so that PyTorch
    builds lacking that backend fall through to CUDA/CPU instead of raising
    ``AttributeError`` when this module is imported.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if (
        mps_backend is not None
        and mps_backend.is_available()
        and mps_backend.is_built()
    ):
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


@dataclass
class Config:
    """Experiment configuration for AG News text classification.

    Constructing an instance has side effects: ``__post_init__`` validates a
    few settings, resolves ``logs_dir``, creates the data/model/output/log
    directories, and prints the selected device and model checkpoint.

    Raises:
        ValueError: if ``label_names`` does not contain exactly ``num_labels``
            entries, if ``batch_size`` or ``grad_accum_steps`` is below 1, or
            if ``label_smoothing`` lies outside [0, 1].
    """

    # ── Dataset ──────────────────────────────────────────────────────────────
    dataset_name: str = "ag_news"
    num_labels: int = 4
    label_names: List[str] = field(
        default_factory=lambda: ["World", "Sports", "Business", "Sci/Tech"]
    )
    # Full dataset — no caps needed on M4 MPS
    max_train_samples: Optional[int] = None  # 120,000
    max_eval_samples: Optional[int] = None   # ~12,000
    max_test_samples: Optional[int] = None   # 7,600

    # ── Model ────────────────────────────────────────────────────────────────
    # Supported checkpoints (swap as needed):
    #   "distilbert-base-uncased" —  66M params — fastest   (~45–70 min on M4)
    #   "bert-base-uncased"       — 110M params — balanced  (~90–120 min on M4)
    #   "roberta-base"            — 125M params — best acc  (~90–150 min on M4)
    model_checkpoint: str = "roberta-base"
    max_length: int = 128  # ample for AG News; uses less memory than 256

    # ── Training Hyper-parameters ────────────────────────────────────────────
    # Effective batch = batch_size × grad_accum_steps (see effective_batch_size);
    # with the defaults below that is 16 × 2 = 32.
    batch_size: int = 16
    num_epochs: int = 3  # Safe training epochs
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.06
    weight_decay: float = 0.01
    grad_accum_steps: int = 2
    label_smoothing: float = 0.1  # Regularisation: prevents over-confidence
    use_gradient_checkpointing: bool = True  # ON by default — MPS OOM safeguard

    # ── Hardware (auto-detected) ─────────────────────────────────────────────
    device: str = field(default_factory=_auto_device)
    # num_workers=0 is safest with HuggingFace datasets in torch format on Mac
    num_workers: int = 0
    seed: int = 42
    low_confidence_threshold: float = 0.60

    # ── Paths ────────────────────────────────────────────────────────────────
    data_dir: str = "data"
    models_dir: str = "saved_models"
    outputs_dir: str = "outputs"
    # None → resolved to <outputs_dir>/logs in __post_init__, so overriding
    # outputs_dir also relocates the logs. Always a str after construction.
    logs_dir: Optional[str] = None

    @property
    def effective_batch_size(self) -> int:
        """Return ``batch_size * grad_accum_steps`` (samples per accumulated step)."""
        return self.batch_size * self.grad_accum_steps

    def __post_init__(self) -> None:
        # Validate before touching the filesystem so a bad config has no
        # side effects.
        self._validate()

        if self.logs_dir is None:
            # Keep logs under whatever outputs_dir the caller chose.
            self.logs_dir = os.path.join(self.outputs_dir, "logs")

        for directory in (
            self.data_dir,
            self.models_dir,
            self.outputs_dir,
            self.logs_dir,
        ):
            os.makedirs(directory, exist_ok=True)

        device_label = (
            "MPS — Apple Metal (M4)" if self.device == "mps" else self.device.upper()
        )
        print(f"[Config] Device: {device_label} | Model: {self.model_checkpoint}")

    def _validate(self) -> None:
        """Raise ``ValueError`` for settings that cannot yield a valid run."""
        if len(self.label_names) != self.num_labels:
            raise ValueError(
                f"num_labels={self.num_labels} but {len(self.label_names)} "
                "label_names were given"
            )
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.grad_accum_steps < 1:
            raise ValueError(
                f"grad_accum_steps must be >= 1, got {self.grad_accum_steps}"
            )
        if not 0.0 <= self.label_smoothing <= 1.0:
            raise ValueError(
                f"label_smoothing must be in [0, 1], got {self.label_smoothing}"
            )


# Module-level singleton — imported by all modules
CFG = Config()