File size: 7,738 Bytes
b43d8da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""Gemma 3n E2B boot via Unsloth FastModel (docs/modules/training.md §3.1).

Contract:
  - Base model: ``unsloth/gemma-3n-E2B-it`` (4-bit Dynamic
    NF4 quantization).
  - Precision: hardware-aware.
      V100 (sm_70) — explicit FP16 (``dtype=torch.float16``); Gemma 3n is
      BF16-native, so we force FP16 on V100 to avoid BF16 software-emulation
      slowdown / numerical instability.
      H100 (sm_90) — BF16 (``dtype=torch.bfloat16``); uses native tensor cores.
  - LoRA: r=16, α=32, dropout=0.05, vision towers frozen, language + attention
    + MLP trainable via Unsloth's multimodal API (``finetune_vision_layers=False,
    finetune_language_layers=True, finetune_attention_modules=True,
    finetune_mlp_modules=True``), Unsloth gradient checkpointing,
    ``random_state=3407``.
  - V100 halt: ``next(model.parameters()).dtype`` MUST be ``torch.float16``
    after FP16 load; any other dtype on that first parameter (e.g. BF16)
    raises :class:`BF16SlippageError` before LoRA attach and optimizer build.
  - H100 halt: ``next(model.parameters()).dtype`` MUST be ``torch.bfloat16``
    after BF16 load; any other dtype on that first parameter (e.g. FP16)
    raises :class:`FP16SlippageError` before LoRA attach and optimizer build.

Heavy imports (``unsloth``, ``torch``) are deferred inside functions so this
cell loads on CPU-only CI runners where Unsloth is not installed. Tests mock
``FastModel.from_pretrained`` and ``FastModel.get_peft_model``.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal

# Defaults consumed by :class:`BootConfig`; ``boot_gemma`` forwards them to
# ``FastModel.from_pretrained`` (model id, sequence length) and
# ``FastModel.get_peft_model`` (LoRA rank, alpha, dropout, random state).
BASE_MODEL_ID: str = "unsloth/gemma-3n-E2B-it"
MAX_SEQ_LENGTH: int = 4096
LORA_R: int = 16
LORA_ALPHA: int = 32
LORA_DROPOUT: float = 0.05
LORA_RANDOM_STATE: int = 3407

# Gemma 3n multimodal LoRA flags — vision/audio towers stay frozen so GRPO
# trains only the language stack (Unsloth Gemma 3N notebook §fine-tune).
FINETUNE_VISION_LAYERS: bool = False
FINETUNE_LANGUAGE_LAYERS: bool = True
FINETUNE_ATTENTION_MODULES: bool = True
FINETUNE_MLP_MODULES: bool = True

# Supported GPU targets. In ``boot_gemma`` "v100" selects ``torch.float16``
# and "h100" selects ``torch.bfloat16``.
HardwareT = Literal["v100", "h100"]
ALLOWED_HARDWARE: tuple[HardwareT, ...] = ("v100", "h100")


class BF16SlippageError(AssertionError):
    """Signal that a V100 run loaded the model in the wrong precision.

    Raised by :func:`assert_dtype_for_hardware` when the target is ``"v100"``
    and the first model parameter is not ``torch.float16``.

    V100 (sm_70) lacks BF16 tensor cores. Silent BF16 via software emulation
    causes ~10x slowdown plus numerical-instability patterns in
    ``docs/modules/training.md §7a``. Halt before the optimizer is built.

    Subclasses :class:`AssertionError`, so ``except AssertionError`` handlers
    also catch it.
    """


class FP16SlippageError(AssertionError):
    """Signal that an H100 run loaded the model in the wrong precision.

    Raised by :func:`assert_dtype_for_hardware` when the target is ``"h100"``
    and the first model parameter is not ``torch.bfloat16``.

    H100 (sm_90) has native BF16 tensor cores. Running FP16 on H100 means
    leaving native hardware capability unused and may cause gradient underflow
    at large batch sizes. Halt before the optimizer is built.

    Subclasses :class:`AssertionError`, so ``except AssertionError`` handlers
    also catch it.
    """


@dataclass(frozen=True)
class BootConfig:
    """Arguments to :func:`boot_gemma`. Frozen per DriftCall immutability rule.

    Most fields default to the module-level constants of the same meaning, so
    ``BootConfig()`` yields the default training setup. Instances are
    immutable (``frozen=True``); create a new instance to change a value.
    """

    # --- Forwarded to ``FastModel.from_pretrained`` ---
    base_model_id: str = BASE_MODEL_ID
    max_seq_length: int = MAX_SEQ_LENGTH
    load_in_4bit: bool = True
    # --- Forwarded to ``FastModel.get_peft_model`` ---
    lora_r: int = LORA_R
    lora_alpha: int = LORA_ALPHA
    lora_dropout: float = LORA_DROPOUT
    lora_random_state: int = LORA_RANDOM_STATE  # passed as ``random_state``
    finetune_vision_layers: bool = FINETUNE_VISION_LAYERS
    finetune_language_layers: bool = FINETUNE_LANGUAGE_LAYERS
    finetune_attention_modules: bool = FINETUNE_ATTENTION_MODULES
    finetune_mlp_modules: bool = FINETUNE_MLP_MODULES
    use_gradient_checkpointing: str = "unsloth"
    # --- Precision selection ---
    # Target GPU: "v100" loads the model as torch.float16, "h100" as
    # torch.bfloat16; the loaded dtype is then verified against this value.
    hardware: HardwareT = "v100"


def assert_dtype_for_hardware(model: Any, hardware: HardwareT) -> None:
    """Halt unless the model's first parameter has the dtype ``hardware`` requires.

    Only the first item yielded by ``model.parameters()`` is inspected; later
    parameters are not scanned. ``boot_gemma`` calls this right after
    ``FastModel.from_pretrained`` and before LoRA attach / optimizer build.

    Args:
        model: Any object whose ``parameters()`` returns an iterable of
            objects exposing a ``dtype`` attribute.
        hardware: ``"v100"`` (requires ``torch.float16``) or ``"h100"``
            (requires ``torch.bfloat16``).

    Raises:
        ValueError: ``hardware`` is neither ``"v100"`` nor ``"h100"``.
        BF16SlippageError: target is ``"v100"`` and the first parameter is not
            ``torch.float16``, or the model has no parameters.
        FP16SlippageError: target is ``"h100"`` and the first parameter is not
            ``torch.bfloat16``, or the model has no parameters.
    """
    # Reject unknown targets before doing anything else. Without this guard a
    # typo such as "V100" would fall into the H100 branch and let a BF16 model
    # pass the gate on a V100 — the exact failure this check exists to stop.
    if hardware not in ("v100", "h100"):
        raise ValueError(
            f"Unsupported hardware {hardware!r}; expected 'v100' or 'h100'."
        )

    import torch

    targets_v100 = hardware == "v100"

    try:
        # iter() lets parameters() return any iterable, not only an iterator.
        first_param = next(iter(model.parameters()))
    except StopIteration as exc:  # pragma: no cover - defensive
        # Report with the error class that matches the hardware being checked.
        slippage_error = BF16SlippageError if targets_v100 else FP16SlippageError
        raise slippage_error(
            "Model has no parameters; cannot verify dtype."
        ) from exc

    dtype = first_param.dtype

    if targets_v100:
        if dtype != torch.float16:
            raise BF16SlippageError(
                f"BF16 slipped through: V100 unsafe. "
                f"next(model.parameters()).dtype == {dtype}, expected torch.float16. "
                f"Root cause: Unsloth auto-picked BF16 despite dtype=torch.float16 kwarg. "
                f"Halt training; do NOT proceed on V100."
            )
        return

    # hardware == "h100"
    if dtype != torch.bfloat16:
        raise FP16SlippageError(
            f"FP16 slipped through: H100 should use BF16. "
            f"next(model.parameters()).dtype == {dtype}, expected torch.bfloat16. "
            f"Root cause: dtype kwarg may have forced FP16 on H100. "
            f"Halt training; do NOT proceed on H100 with FP16."
        )


def assert_fp16_dtype(model: Any) -> None:
    """Assert the model's first parameter is torch.float16 (V100 safety).

    Thin wrapper around :func:`assert_dtype_for_hardware` for backwards
    compatibility with call sites that predate the hardware-aware API.

    Args:
        model: Object exposing ``parameters()``; passed through unchanged.

    Raises:
        BF16SlippageError: With the halt message from
            ``docs/modules/training.md §3.1`` when the V100 check fails.
    """
    # Always checks against the V100 (FP16) expectation.
    assert_dtype_for_hardware(model, "v100")


def boot_gemma(config: BootConfig | None = None) -> tuple[Any, Any]:
    """Load Gemma 3n E2B in 4-bit + attach LoRA; return (model, tokenizer).

    Steps (training.md §3.1):
      0. Validate ``config.hardware`` against ``ALLOWED_HARDWARE`` before any
         heavy import or model load.
      1. ``FastModel.from_pretrained(base_model_id, max_seq_length=...,
         load_in_4bit=True, dtype=torch.float16)`` on V100
         or ``dtype=torch.bfloat16`` on H100.
      2. ``assert_dtype_for_hardware(model, hardware)`` — raises
         :class:`BF16SlippageError` or :class:`FP16SlippageError` if the dtype
         does not match the hardware.
      3. ``FastModel.get_peft_model(model, r=16, lora_alpha=32,
         lora_dropout=0.05, finetune_vision_layers=False,
         finetune_language_layers=True, finetune_attention_modules=True,
         finetune_mlp_modules=True, use_gradient_checkpointing="unsloth",
         random_state=3407)``.
      4. Return ``(peft_model, tokenizer)``.

    All heavy imports are lazy so the module is importable on CPU-only CI.

    Args:
        config: Boot settings; ``None`` means ``BootConfig()`` with defaults.

    Returns:
        ``(peft_model, tokenizer)`` — the LoRA-wrapped model returned by
        ``FastModel.get_peft_model`` and the tokenizer returned by
        ``FastModel.from_pretrained``.

    Raises:
        ValueError: ``config.hardware`` is not in ``ALLOWED_HARDWARE``.
        BF16SlippageError: loaded dtype is wrong for V100.
        FP16SlippageError: loaded dtype is wrong for H100.
    """
    cfg = config if config is not None else BootConfig()

    # Fail fast on an unknown target. The dtype choice below treats every
    # non-"v100" value as BF16, so a typo such as "V100" would otherwise load
    # BF16 weights on a V100 — and only after an expensive model load.
    if cfg.hardware not in ALLOWED_HARDWARE:
        raise ValueError(
            f"Unsupported hardware {cfg.hardware!r}; "
            f"expected one of {ALLOWED_HARDWARE}."
        )

    import torch
    from unsloth import FastModel

    # V100 has no native BF16, so force FP16 there; H100 uses BF16.
    dtype = torch.float16 if cfg.hardware == "v100" else torch.bfloat16

    model, tokenizer = FastModel.from_pretrained(
        cfg.base_model_id,
        max_seq_length=cfg.max_seq_length,
        load_in_4bit=cfg.load_in_4bit,
        dtype=dtype,
    )

    # Precision gate: must run before LoRA attach so a wrong-dtype model never
    # reaches the optimizer.
    assert_dtype_for_hardware(model, cfg.hardware)

    peft_model = FastModel.get_peft_model(
        model,
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        finetune_vision_layers=cfg.finetune_vision_layers,
        finetune_language_layers=cfg.finetune_language_layers,
        finetune_attention_modules=cfg.finetune_attention_modules,
        finetune_mlp_modules=cfg.finetune_mlp_modules,
        use_gradient_checkpointing=cfg.use_gradient_checkpointing,
        random_state=cfg.lora_random_state,
    )

    return peft_model, tokenizer


# Explicit public API: the names exported by ``from <module> import *``.
__all__ = [
    "ALLOWED_HARDWARE",
    "BASE_MODEL_ID",
    "BF16SlippageError",
    "BootConfig",
    "FINETUNE_ATTENTION_MODULES",
    "FINETUNE_LANGUAGE_LAYERS",
    "FINETUNE_MLP_MODULES",
    "FINETUNE_VISION_LAYERS",
    "FP16SlippageError",
    "HardwareT",
    "LORA_ALPHA",
    "LORA_DROPOUT",
    "LORA_R",
    "LORA_RANDOM_STATE",
    "MAX_SEQ_LENGTH",
    "assert_dtype_for_hardware",
    "assert_fp16_dtype",
    "boot_gemma",
]