# nexa-classify-api / config.py
# Uploaded by Prototype6239 via huggingface_hub (commit a229747).
"""
config.py
─────────
Updated for Mac M4 / Apple Silicon MPS.
Key changes vs Windows version:
- Auto device detection: MPS → CUDA → CPU
- Sample caps removed (full 120K training set now feasible)
- batch_size 16 (grad_accum_steps=2 → effective batch 32)
- max_length 128 (AG News headlines fit comfortably; saves ~60% VRAM vs 256)
- Default model upgraded to 'roberta-base'
- label_smoothing added for better calibration
- gradient_checkpointing enabled by default (MPS OOM safeguard)
"""
import os
import torch
from dataclasses import dataclass, field
from typing import List, Optional
def _auto_device() -> str:
    """Detect the best available compute device at import time.

    Preference order: Apple Metal (``"mps"``) → ``"cuda"`` → ``"cpu"``.
    ``torch.backends.mps`` is looked up defensively so that PyTorch builds
    without the MPS backend fall through to CUDA/CPU instead of raising
    ``AttributeError``.
    """
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


@dataclass
class Config:
    """Central training/inference settings for the AG News classifier.

    Constructing an instance validates a few numeric settings, creates the
    data/model/output/log directories (relative to the current working
    directory unless absolute paths are given) and prints the selected
    device and model checkpoint.
    """

    # ── Dataset ──────────────────────────────────────────────────────────────
    dataset_name: str = "ag_news"
    num_labels: int = 4
    label_names: List[str] = field(
        default_factory=lambda: ["World", "Sports", "Business", "Sci/Tech"]
    )
    # Full dataset — no caps needed on M4 MPS (None = use every sample).
    max_train_samples: Optional[int] = None  # 120,000
    max_eval_samples: Optional[int] = None  # ~12,000
    max_test_samples: Optional[int] = None  # 7,600

    # ── Model ─────────────────────────────────────────────────────────────────
    # Supported checkpoints (swap as needed):
    #   "distilbert-base-uncased" — 66M params  — fastest  (~45–70 min on M4)
    #   "bert-base-uncased"       — 110M params — balanced (~90–120 min on M4)
    #   "roberta-base"            — 125M params — best acc (~90–150 min on M4)
    model_checkpoint: str = "roberta-base"
    max_length: int = 128  # 128 is ample for AG News; saves ~60% VRAM vs 256

    # ── Training Hyper-parameters ─────────────────────────────────────────────
    batch_size: int = 16  # 16 × grad_accum_steps=2 → effective batch 32
    num_epochs: int = 3  # Safe training epochs
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.06
    weight_decay: float = 0.01
    grad_accum_steps: int = 2  # Accumulate 2 steps → effective batch 32
    label_smoothing: float = 0.1  # Regularisation: prevents over-confidence
    use_gradient_checkpointing: bool = True  # ON by default — MPS OOM safeguard

    # ── Hardware (auto-detected) ───────────────────────────────────────────────
    device: str = field(default_factory=_auto_device)
    # num_workers=0 is safest with HuggingFace datasets in torch format on Mac
    num_workers: int = 0
    seed: int = 42
    low_confidence_threshold: float = 0.60

    # ── Paths ─────────────────────────────────────────────────────────────────
    data_dir: str = "data"
    models_dir: str = "saved_models"
    outputs_dir: str = "outputs"
    # None → resolved to <outputs_dir>/logs in __post_init__ ("outputs/logs"
    # with the defaults), so overriding outputs_dir also relocates the logs.
    logs_dir: Optional[str] = None

    def __post_init__(self) -> None:
        """Validate settings, resolve derived paths and create directories.

        Raises:
            ValueError: if ``batch_size`` or ``grad_accum_steps`` is below 1,
                or ``label_smoothing`` lies outside ``[0, 1]``.
        """
        # Validate before touching the filesystem so a bad config leaves
        # no stray directories behind.
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.grad_accum_steps < 1:
            raise ValueError(
                f"grad_accum_steps must be >= 1, got {self.grad_accum_steps}"
            )
        if not 0.0 <= self.label_smoothing <= 1.0:
            raise ValueError(
                f"label_smoothing must be in [0, 1], got {self.label_smoothing}"
            )

        if self.logs_dir is None:
            self.logs_dir = os.path.join(self.outputs_dir, "logs")

        for d in (self.data_dir, self.models_dir, self.outputs_dir, self.logs_dir):
            os.makedirs(d, exist_ok=True)

        device_label = (
            "MPS — Apple Metal (M4)" if self.device == "mps" else self.device.upper()
        )
        print(f"[Config] Device: {device_label} | Model: {self.model_checkpoint}")

    @property
    def effective_batch_size(self) -> int:
        """Samples contributing to each optimizer step (batch × accumulation)."""
        return self.batch_size * self.grad_accum_steps
# Shared module-level Config instance, built once at import time; intended
# to be imported by other modules instead of constructing their own.
CFG: Config = Config()