| """Gemma 3n E2B boot via Unsloth FastModel (docs/modules/training.md §3.1). |
| |
| Contract: |
| - Base model: ``unsloth/gemma-3n-E2B-it`` (4-bit Dynamic |
| NF4 quantization). |
| - Precision: hardware-aware. |
| V100 (sm_70) — explicit FP16 (``dtype=torch.float16``); Gemma 3n is |
| BF16-native, so we force FP16 on V100 to avoid BF16 software-emulation |
| slowdown / numerical instability. |
| H100 (sm_90) — BF16 (``dtype=torch.bfloat16``); uses native tensor cores. |
| - LoRA: r=16, α=32, dropout=0.05, vision towers frozen, language + attention |
| + MLP trainable via Unsloth's multimodal API (``finetune_vision_layers=False, |
| finetune_language_layers=True, finetune_attention_modules=True, |
| finetune_mlp_modules=True``), Unsloth gradient checkpointing, |
| ``random_state=3407``. |
| - V100 halt: ``next(model.parameters()).dtype`` MUST be ``torch.float16`` |
| after FP16 load; any BF16 parameter triggers :class:`BF16SlippageError` |
| before optimizer build. |
| - H100 halt: ``next(model.parameters()).dtype`` MUST be ``torch.bfloat16`` |
| after BF16 load; any FP16 parameter triggers :class:`FP16SlippageError` |
| before optimizer build. |
| |
| Heavy imports (``unsloth``, ``torch``) are deferred inside functions so this |
| cell loads on CPU-only CI runners where Unsloth is not installed. Tests mock |
| ``FastModel.from_pretrained`` and ``FastModel.get_peft_model``. |
| """ |
|
|
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Literal, get_args
|
|
# Checkpoint loaded by boot_gemma and the maximum sequence length it is loaded with.
BASE_MODEL_ID: str = "unsloth/gemma-3n-E2B-it"
MAX_SEQ_LENGTH: int = 4096

# LoRA hyper-parameters (defaults for BootConfig, forwarded to get_peft_model).
LORA_R: int = 16
LORA_ALPHA: int = 32
LORA_DROPOUT: float = 0.05
LORA_RANDOM_STATE: int = 3407

# Unsloth multimodal layer selection: vision tower frozen; language,
# attention and MLP modules are made trainable.
FINETUNE_VISION_LAYERS: bool = False
FINETUNE_LANGUAGE_LAYERS: bool = True
FINETUNE_ATTENTION_MODULES: bool = True
FINETUNE_MLP_MODULES: bool = True

# Supported GPU targets. ALLOWED_HARDWARE is derived from the Literal so the
# runtime tuple and the static type cannot drift apart; get_args preserves the
# Literal's declaration order, i.e. ("v100", "h100").
HardwareT = Literal["v100", "h100"]
ALLOWED_HARDWARE: tuple[HardwareT, ...] = get_args(HardwareT)
|
|
|
|
class BF16SlippageError(AssertionError):
    """Signals that the V100 FP16 precision guard failed.

    Most commonly this means BF16 weights slipped through despite the
    ``dtype=torch.float16`` load request. V100 (sm_70) has no BF16 tensor
    cores, so BF16 falls back to software emulation — roughly a 10x slowdown
    plus the numerical-instability patterns catalogued in
    ``docs/modules/training.md §7a``. Raised so training halts before the
    optimizer is built.
    """
|
|
|
|
class FP16SlippageError(AssertionError):
    """Signals that the H100 BF16 precision guard failed.

    Most commonly this means FP16 weights appeared where BF16 was requested.
    H100 (sm_90) has native BF16 tensor cores, so FP16 there leaves that
    hardware capability unused and may cause gradient underflow at large
    batch sizes. Raised so training halts before the optimizer is built.
    """
|
|
|
|
@dataclass(frozen=True)
class BootConfig:
    """Arguments to :func:`boot_gemma`. Frozen per DriftCall immutability rule.

    Defaults mirror the module-level contract constants. Construction fails
    fast on an unsupported ``hardware`` value: the ``HardwareT`` Literal is
    only a static hint, and an unchecked typo (e.g. ``"V100"``) would
    otherwise be treated downstream as the BF16 (H100) path.

    Raises:
        ValueError: If ``hardware`` is not one of ``ALLOWED_HARDWARE``.
    """

    # Model id passed positionally to FastModel.from_pretrained.
    base_model_id: str = BASE_MODEL_ID
    max_seq_length: int = MAX_SEQ_LENGTH
    # Forwarded as ``load_in_4bit`` to FastModel.from_pretrained.
    load_in_4bit: bool = True
    # LoRA hyper-parameters forwarded to FastModel.get_peft_model.
    lora_r: int = LORA_R
    lora_alpha: int = LORA_ALPHA
    lora_dropout: float = LORA_DROPOUT
    lora_random_state: int = LORA_RANDOM_STATE
    # Unsloth multimodal layer-selection switches.
    finetune_vision_layers: bool = FINETUNE_VISION_LAYERS
    finetune_language_layers: bool = FINETUNE_LANGUAGE_LAYERS
    finetune_attention_modules: bool = FINETUNE_ATTENTION_MODULES
    finetune_mlp_modules: bool = FINETUNE_MLP_MODULES
    # Forwarded unchanged to get_peft_model; Unsloth accepts "unsloth" or a bool.
    use_gradient_checkpointing: bool | str = "unsloth"
    # Target GPU; boot_gemma loads FP16 on "v100" and BF16 on "h100".
    hardware: HardwareT = "v100"

    def __post_init__(self) -> None:
        """Reject hardware targets that have no precision policy."""
        if self.hardware not in ALLOWED_HARDWARE:
            raise ValueError(
                f"Unsupported hardware {self.hardware!r}; "
                f"expected one of {ALLOWED_HARDWARE}."
            )
|
|
|
|
def assert_dtype_for_hardware(model: Any, hardware: HardwareT) -> None:
    """Halt unless the loaded model's parameter dtypes suit ``hardware``.

    ``model.parameters()`` is called exactly once; the resulting iterator is
    consumed by two checks:

    1. The first parameter must be ``torch.float16`` on ``"v100"`` or
       ``torch.bfloat16`` on ``"h100"``.
    2. No remaining parameter may carry the opposite half-precision format
       (``torch.bfloat16`` on ``"v100"``, ``torch.float16`` on ``"h100"``), so
       a partial slip (e.g. FP16 embeddings but some BF16 layers) is caught as
       well. Other dtypes (``torch.float32``, integer storage dtypes, ...) are
       accepted by this second check.

    Called once at ``boot_gemma`` entry, before any LoRA attach or optimizer build.

    Args:
        model: Object whose ``parameters()`` yields items exposing ``.dtype``.
        hardware: Target GPU, ``"v100"`` or ``"h100"``.

    Raises:
        ValueError: If ``hardware`` is not a supported target.
        BF16SlippageError: On ``"v100"`` when either check fails or the model
            has no parameters.
        FP16SlippageError: On ``"h100"`` when either check fails or the model
            has no parameters.
    """
    import torch

    expected_by_hardware = {"v100": torch.float16, "h100": torch.bfloat16}
    if hardware not in expected_by_hardware:
        # Without this guard any unknown value would be judged by H100 rules.
        raise ValueError(
            f"Unsupported hardware {hardware!r}; "
            f"expected one of {tuple(expected_by_hardware)}."
        )
    on_v100 = hardware == "v100"
    expected = expected_by_hardware[hardware]
    # The other half-precision format: the one that must never appear.
    forbidden = torch.bfloat16 if on_v100 else torch.float16

    # iter() also accepts a parameters() that returns a list.
    params_iter = iter(model.parameters())
    try:
        first_param = next(params_iter)
    except StopIteration as exc:
        error_cls = BF16SlippageError if on_v100 else FP16SlippageError
        raise error_cls("Model has no parameters; cannot verify dtype.") from exc

    dtype = first_param.dtype
    if dtype != expected:
        if on_v100:
            raise BF16SlippageError(
                "BF16 slipped through: V100 unsafe. "
                f"next(model.parameters()).dtype == {dtype}, expected torch.float16. "
                "Root cause: Unsloth auto-picked BF16 despite dtype=torch.float16 kwarg. "
                "Halt training; do NOT proceed on V100."
            )
        raise FP16SlippageError(
            "FP16 slipped through: H100 should use BF16. "
            f"next(model.parameters()).dtype == {dtype}, expected torch.bfloat16. "
            "Root cause: dtype kwarg may have forced FP16 on H100. "
            "Halt training; do NOT proceed on H100 with FP16."
        )

    # Index 0 was the first parameter, so the rest are numbered from 1.
    for index, param in enumerate(params_iter, start=1):
        if param.dtype != forbidden:
            continue
        if on_v100:
            raise BF16SlippageError(
                "BF16 slipped through: V100 unsafe. "
                f"model.parameters()[{index}].dtype == {param.dtype} although the "
                f"first parameter is {dtype}. "
                "Halt training; do NOT proceed on V100."
            )
        raise FP16SlippageError(
            "FP16 slipped through: H100 should use BF16. "
            f"model.parameters()[{index}].dtype == {param.dtype} although the "
            f"first parameter is {dtype}. "
            "Halt training; do NOT proceed on H100 with FP16."
        )
|
|
|
|
def assert_fp16_dtype(model: Any) -> None:
    """Legacy V100-only precision guard.

    Kept for call sites written before the hardware-aware API. It simply
    delegates to :func:`assert_dtype_for_hardware` with ``hardware="v100"``.
    Any precision failure therefore surfaces as :class:`BF16SlippageError`,
    carrying the halt message from ``docs/modules/training.md §3.1``.
    """
    assert_dtype_for_hardware(model=model, hardware="v100")
|
|
|
|
def boot_gemma(config: BootConfig | None = None) -> tuple[Any, Any]:
    """Load Gemma 3n E2B, verify its precision, attach LoRA; return (model, tokenizer).

    Steps (training.md §3.1):
      0. Reject an unsupported ``cfg.hardware`` with :class:`ValueError` before
         importing torch/Unsloth, so a typo fails fast (even on CPU-only CI)
         instead of silently selecting BF16.
      1. ``FastModel.from_pretrained(base_model_id, max_seq_length=...,
         load_in_4bit=cfg.load_in_4bit, dtype=...)`` with ``torch.float16``
         on V100 or ``torch.bfloat16`` on H100.
      2. ``assert_dtype_for_hardware(model, hardware)``, which halts on a
         precision mismatch before any LoRA attach or optimizer build.
      3. ``FastModel.get_peft_model(model, r, lora_alpha, lora_dropout,
         finetune_vision_layers, finetune_language_layers,
         finetune_attention_modules, finetune_mlp_modules,
         use_gradient_checkpointing, random_state)``, all taken from ``cfg``.
      4. Return ``(peft_model, tokenizer)``.

    All heavy imports are lazy so the module is importable on CPU-only CI.

    Args:
        config: Boot parameters; ``None`` means ``BootConfig()`` defaults.

    Returns:
        ``(peft_model, tokenizer)`` as produced by Unsloth.

    Raises:
        ValueError: If ``cfg.hardware`` is not one of ``ALLOWED_HARDWARE``.
        BF16SlippageError: V100 precision check failed after load.
        FP16SlippageError: H100 precision check failed after load.
    """
    cfg = config if config is not None else BootConfig()

    if cfg.hardware not in ALLOWED_HARDWARE:
        raise ValueError(
            f"Unsupported hardware {cfg.hardware!r}; "
            f"expected one of {ALLOWED_HARDWARE}."
        )

    import torch
    from unsloth import FastModel

    # V100 lacks BF16 tensor cores -> force FP16; H100 runs BF16 natively.
    dtype = {"v100": torch.float16, "h100": torch.bfloat16}[cfg.hardware]

    model, tokenizer = FastModel.from_pretrained(
        cfg.base_model_id,
        max_seq_length=cfg.max_seq_length,
        load_in_4bit=cfg.load_in_4bit,
        dtype=dtype,
    )

    # Halt before LoRA attach / optimizer build if Unsloth ignored the dtype.
    assert_dtype_for_hardware(model, cfg.hardware)

    peft_model = FastModel.get_peft_model(
        model,
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        finetune_vision_layers=cfg.finetune_vision_layers,
        finetune_language_layers=cfg.finetune_language_layers,
        finetune_attention_modules=cfg.finetune_attention_modules,
        finetune_mlp_modules=cfg.finetune_mlp_modules,
        use_gradient_checkpointing=cfg.use_gradient_checkpointing,
        random_state=cfg.lora_random_state,
    )

    return peft_model, tokenizer
|
|
|
|
# Public API. Entries stay in ASCII order (upper-case names sort before
# lower-case ones), matching the previous listing exactly.
__all__ = [
    # Constants, type alias, config dataclass and exception classes.
    "ALLOWED_HARDWARE", "BASE_MODEL_ID", "BF16SlippageError", "BootConfig",
    "FINETUNE_ATTENTION_MODULES", "FINETUNE_LANGUAGE_LAYERS",
    "FINETUNE_MLP_MODULES", "FINETUNE_VISION_LAYERS", "FP16SlippageError",
    "HardwareT", "LORA_ALPHA", "LORA_DROPOUT", "LORA_R", "LORA_RANDOM_STATE",
    "MAX_SEQ_LENGTH",
    # Functions.
    "assert_dtype_for_hardware", "assert_fp16_dtype", "boot_gemma",
]
|
|