Spaces:
Running
Running
File size: 4,157 Bytes
a229747 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | """
config.py
─────────
Updated for Mac M4 / Apple Silicon MPS.
Key changes vs Windows version:
- Auto device detection: MPS → CUDA → CPU
- Sample caps removed (full 120K-example training set now feasible)
- batch_size 16 (grad_accum_steps=2 → effective batch 32)
- max_length 128 (AG News articles are short and fit comfortably; saves ~60% memory vs 256)
- Default model upgraded to 'roberta-base'
- label_smoothing added for better calibration
- gradient_checkpointing enabled by default (MPS OOM safeguard)
"""
import os
import torch
from dataclasses import dataclass, field
from typing import List, Optional
def _auto_device() -> str:
    """Pick the torch device string to use by default.

    Backends are checked in priority order: Apple Metal Performance Shaders
    ("mps"), then CUDA ("cuda"); "cpu" is the fallback when neither is usable.
    """
    metal = torch.backends.mps
    if metal.is_available() and metal.is_built():
        return "mps"
    return "cuda" if torch.cuda.is_available() else "cpu"
@dataclass
class Config:
    """Central experiment configuration for AG News text classification.

    Every field has a default, so ``Config()`` yields the baseline setup and
    individual values can be overridden by keyword, e.g.
    ``Config(model_checkpoint="distilbert-base-uncased", num_epochs=2)``.

    Constructing an instance has side effects (see ``__post_init__``): it
    validates a few numeric fields, creates the output directories and
    prints the selected device and model.
    """

    # ── Dataset ──────────────────────────────────────────────────────────────
    dataset_name: str = "ag_news"
    num_labels: int = 4
    label_names: List[str] = field(
        default_factory=lambda: ["World", "Sports", "Business", "Sci/Tech"]
    )
    # Full dataset — no caps needed on M4 MPS.
    # NOTE(review): None presumably means "no cap, use the whole split";
    # confirm against the data loader.
    max_train_samples: Optional[int] = None  # 120,000
    max_eval_samples: Optional[int] = None  # ~12,000
    max_test_samples: Optional[int] = None  # 7,600

    # ── Model ────────────────────────────────────────────────────────────────
    # Supported checkpoints (swap as needed):
    #   "distilbert-base-uncased" — 66M params  — fastest  (~45–70 min on M4)
    #   "bert-base-uncased"       — 110M params — balanced (~90–120 min on M4)
    #   "roberta-base"            — 125M params — best acc (~90–150 min on M4)
    model_checkpoint: str = "roberta-base"
    max_length: int = 128  # 128 is ample for AG News; saves ~60% memory vs 256

    # ── Training hyper-parameters ────────────────────────────────────────────
    batch_size: int = 16  # 16 × grad_accum_steps=2 → effective batch 32
    num_epochs: int = 3  # Safe training epochs
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.06
    weight_decay: float = 0.01
    grad_accum_steps: int = 2  # Accumulate 2 steps → effective batch 32
    label_smoothing: float = 0.1  # Regularisation: discourages over-confidence
    use_gradient_checkpointing: bool = True  # ON by default — MPS OOM safeguard

    # ── Hardware (auto-detected) ─────────────────────────────────────────────
    device: str = field(default_factory=_auto_device)
    # num_workers=0 is safest with HuggingFace datasets in torch format on Mac
    num_workers: int = 0
    seed: int = 42
    # NOTE(review): presumably predictions whose top probability falls below
    # this value are flagged as low-confidence; confirm against the consumer.
    low_confidence_threshold: float = 0.60

    # ── Paths ────────────────────────────────────────────────────────────────
    data_dir: str = "data"
    models_dir: str = "saved_models"
    outputs_dir: str = "outputs"
    # None (the default) resolves to "<outputs_dir>/logs" in __post_init__, so
    # overriding outputs_dir also relocates the logs. An explicit path is kept.
    logs_dir: Optional[str] = None

    @property
    def effective_batch_size(self) -> int:
        """Return ``batch_size * grad_accum_steps`` (examples per optimizer step)."""
        return self.batch_size * self.grad_accum_steps

    def __post_init__(self) -> None:
        """Validate settings, resolve derived paths, create dirs, announce device.

        Raises:
            ValueError: if ``batch_size`` or ``grad_accum_steps`` is below 1,
                or ``label_smoothing`` lies outside [0, 1].
        """
        # Fail fast, before any directories are created.
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.grad_accum_steps < 1:
            raise ValueError(
                f"grad_accum_steps must be >= 1, got {self.grad_accum_steps}"
            )
        if not 0.0 <= self.label_smoothing <= 1.0:
            raise ValueError(
                f"label_smoothing must be in [0, 1], got {self.label_smoothing}"
            )

        # Derive the log directory from outputs_dir unless set explicitly, so
        # Config(outputs_dir="run2") no longer keeps writing to outputs/logs.
        if self.logs_dir is None:
            self.logs_dir = os.path.join(self.outputs_dir, "logs")

        # Relative paths are created under the current working directory.
        for d in (self.data_dir, self.models_dir, self.outputs_dir, self.logs_dir):
            os.makedirs(d, exist_ok=True)

        device_label = (
            "MPS — Apple Metal (M4)" if self.device == "mps" else self.device.upper()
        )
        print(f"[Config] Device: {device_label} | Model: {self.model_checkpoint}")
# Module-level shared Config instance, built once when this module is imported.
# NOTE: constructing it runs Config.__post_init__, so importing this module
# creates the default data/model/output/log directories (relative to the
# current working directory) and prints the detected device and model.
# NOTE(review): original comment said "imported by all modules" — presumably
# other modules do `from config import CFG`; confirm.
CFG: Config = Config()
|