File size: 4,157 Bytes
a229747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
config.py
─────────
Updated for Apple Silicon (Mac M4) using the PyTorch MPS backend.

Key changes vs the Windows version:
  - Auto device detection: MPS → CUDA → CPU
  - Sample caps removed (the full 120K-sample training set is now feasible)
  - batch_size 16  (grad_accum_steps=2 → effective batch 32)
  - max_length 128 (AG News texts fit comfortably; saves ~60% memory vs 256)
  - Default model upgraded to 'roberta-base'
  - label_smoothing added for better calibration
  - gradient_checkpointing enabled by default (MPS out-of-memory safeguard)
"""
import os
import torch
from dataclasses import dataclass, field
from typing import List, Optional


def _auto_device() -> str:
    """Detect the best available compute device at import time."""
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


@dataclass
class Config:
    """Run configuration for AG News text classification.

    Every field has a default, so ``Config()`` yields the standard setup;
    override individual fields with keyword arguments, e.g.
    ``Config(model_checkpoint="bert-base-uncased", num_epochs=2)``.

    On construction, ``__post_init__`` validates a few cross-field
    invariants, creates the data/model/output/log directories (relative
    paths resolve against the current working directory), and prints a
    one-line device/model summary.

    Raises:
        ValueError: if ``label_names`` does not have exactly ``num_labels``
            entries, if ``batch_size`` or ``grad_accum_steps`` is below 1,
            or if ``label_smoothing`` lies outside [0, 1].
    """

    # ── Dataset ──────────────────────────────────────────────────────────────
    dataset_name: str       = "ag_news"
    num_labels:   int       = 4
    label_names:  List[str] = field(
        default_factory=lambda: ["World", "Sports", "Business", "Sci/Tech"]
    )

    # Full dataset — no caps needed on M4 MPS. ``None`` is intended to mean
    # "no cap" (confirm against the data-loading code that reads these).
    max_train_samples: Optional[int] = None   # 120,000
    max_eval_samples:  Optional[int] = None   # ~12,000
    max_test_samples:  Optional[int] = None   # 7,600

    # ── Model ─────────────────────────────────────────────────────────────────
    # Supported checkpoints (swap as needed):
    #   "distilbert-base-uncased"  — 66M params   — fastest  (~45–70 min on M4)
    #   "bert-base-uncased"        — 110M params  — balanced (~90–120 min on M4)
    #   "roberta-base"             — 125M params  — best acc (~90–150 min on M4)
    model_checkpoint: str = "roberta-base"
    max_length:       int = 128    # 128 is ample for AG News; saves ~60% memory vs 256

    # ── Training Hyper-parameters ─────────────────────────────────────────────
    batch_size:                 int   = 16    # 16 × grad_accum_steps=2 → effective batch 32
    num_epochs:                 int   = 3     # Safe training epochs
    learning_rate:              float = 2e-5
    warmup_ratio:               float = 0.06
    weight_decay:               float = 0.01
    grad_accum_steps:           int   = 2     # Accumulate 2 steps → effective batch 32
    label_smoothing:            float = 0.1   # Regularisation: prevents over-confidence
    use_gradient_checkpointing: bool  = True  # ON by default — critical MPS OOM safeguard

    # ── Hardware (auto-detected) ───────────────────────────────────────────────
    device:      str = field(default_factory=_auto_device)
    # num_workers=0 is safest with HuggingFace datasets in torch format on Mac
    num_workers: int = 0
    seed:        int = 42
    low_confidence_threshold: float = 0.60

    # ── Paths ─────────────────────────────────────────────────────────────────
    data_dir:    str = "data"
    models_dir:  str = "saved_models"
    outputs_dir: str = "outputs"
    # NOTE: this default is fixed, not derived from ``outputs_dir``; pass
    # ``logs_dir`` explicitly when overriding ``outputs_dir`` if logs should
    # move along with it.
    logs_dir:    str = os.path.join("outputs", "logs")

    def __post_init__(self) -> None:
        """Validate settings, create working directories, print a summary."""
        # Validate first so an invalid config never leaves directories behind.
        self._validate()
        for d in (self.data_dir, self.models_dir, self.outputs_dir, self.logs_dir):
            os.makedirs(d, exist_ok=True)
        device_label = (
            "MPS — Apple Metal (M4)" if self.device == "mps" else self.device.upper()
        )
        print(f"[Config] Device: {device_label}  |  Model: {self.model_checkpoint}")

    def _validate(self) -> None:
        """Raise ``ValueError`` for settings that would only fail later in training."""
        if len(self.label_names) != self.num_labels:
            raise ValueError(
                f"label_names has {len(self.label_names)} entries but "
                f"num_labels={self.num_labels}"
            )
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.grad_accum_steps < 1:
            raise ValueError(
                f"grad_accum_steps must be >= 1, got {self.grad_accum_steps}"
            )
        if not 0.0 <= self.label_smoothing <= 1.0:
            raise ValueError(
                f"label_smoothing must be in [0, 1], got {self.label_smoothing}"
            )


# Shared configuration instance, built once when this module is imported;
# other modules import CFG rather than constructing their own Config.
CFG: Config = Config()