Spaces:
Sleeping
Sleeping
"""
config.py
─────────
Updated for Mac M4 / Apple Silicon MPS.
Key changes vs Windows version:
  - Auto device detection: MPS → CUDA → CPU
  - Sample caps removed (full 120 K dataset now feasible)
  - batch_size 16 (grad_accum_steps=2 → effective batch 32)
  - max_length 128 (AG News headlines fit comfortably; saves ~60% VRAM vs 256)
  - Default model upgraded to 'roberta-base'
  - label_smoothing added for better calibration
  - gradient_checkpointing enabled by default (MPS OOM safeguard)
"""
| import os | |
| import torch | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional | |
| def _auto_device() -> str: | |
| """Detect the best available compute device at import time.""" | |
| if torch.backends.mps.is_available() and torch.backends.mps.is_built(): | |
| return "mps" | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| return "cpu" | |
@dataclass
class Config:
    """Settings for the AG News classification project.

    Groups the dataset, model, training, hardware and path options in one
    place. Every field has a default, so ``Config()`` works with no
    arguments and individual values can be overridden by keyword.

    The ``@dataclass`` decorator is required: the ``field(...)`` defaults
    and ``__post_init__`` below only take effect on a dataclass. Without it
    ``label_names`` and ``device`` would be raw ``dataclasses.Field``
    objects and ``__post_init__`` would never run.

    Instantiation has side effects: the four configured directories are
    created if missing and a one-line device/model banner is printed.
    """

    # -- Dataset ------------------------------------------------------------
    dataset_name: str = "ag_news"
    num_labels: int = 4
    label_names: List[str] = field(
        default_factory=lambda: ["World", "Sports", "Business", "Sci/Tech"]
    )
    # Full dataset -- no caps needed on M4 MPS.
    # NOTE(review): None presumably means "no cap"; confirm against the loader.
    max_train_samples: Optional[int] = None  # 120,000
    max_eval_samples: Optional[int] = None  # ~12,000
    max_test_samples: Optional[int] = None  # 7,600

    # -- Model --------------------------------------------------------------
    # Supported checkpoints (swap as needed):
    #   "distilbert-base-uncased" - 66M params  - fastest  (~45-70 min on M4)
    #   "bert-base-uncased"       - 110M params - balanced (~90-120 min on M4)
    #   "roberta-base"            - 125M params - best acc (~90-150 min on M4)
    model_checkpoint: str = "roberta-base"
    max_length: int = 128  # 128 is ample for AG News; saves ~60% VRAM vs 256

    # -- Training hyper-parameters ------------------------------------------
    batch_size: int = 16  # 16 x grad_accum_steps=2 -> effective batch 32
    num_epochs: int = 3  # Safe training epochs
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.06
    weight_decay: float = 0.01
    grad_accum_steps: int = 2  # Accumulate 2 steps -> effective batch 32
    label_smoothing: float = 0.1  # Regularisation: prevents over-confidence
    use_gradient_checkpointing: bool = True  # ON by default: MPS OOM safeguard

    # -- Hardware (auto-detected) -------------------------------------------
    # Resolved per instance via the module-level _auto_device() helper.
    device: str = field(default_factory=_auto_device)
    # num_workers=0 is safest with HuggingFace datasets in torch format on Mac
    num_workers: int = 0
    seed: int = 42
    low_confidence_threshold: float = 0.60

    # -- Paths --------------------------------------------------------------
    data_dir: str = "data"
    models_dir: str = "saved_models"
    outputs_dir: str = "outputs"
    logs_dir: str = os.path.join("outputs", "logs")

    def __post_init__(self) -> None:
        """Create the configured directories and print the device banner."""
        for directory in (
            self.data_dir,
            self.models_dir,
            self.outputs_dir,
            self.logs_dir,
        ):
            os.makedirs(directory, exist_ok=True)

        # NOTE(review): the "β" in the label below looks like a mis-decoded
        # dash; it is kept byte-for-byte because the original character
        # cannot be determined from here.
        device_label = (
            "MPS β Apple Metal (M4)" if self.device == "mps" else self.device.upper()
        )
        print(f"[Config] Device: {device_label} | Model: {self.model_checkpoint}")
# Module-level singleton: the shared Config instance imported by all modules.
CFG = Config()