Spaces:
Runtime error
Runtime error
| """Gemma 3n E2B boot via Unsloth FastModel (docs/modules/training.md §3.1). | |
| Contract: | |
| - Base model: ``unsloth/gemma-3n-E2B-it`` (4-bit Dynamic | |
| NF4 quantization). | |
| - Precision: hardware-aware. | |
| V100 (sm_70) — explicit FP16 (``dtype=torch.float16``); Gemma 3n is | |
| BF16-native, so we force FP16 on V100 to avoid BF16 software-emulation | |
| slowdown / numerical instability. | |
| H100 (sm_90) — BF16 (``dtype=torch.bfloat16``); uses native tensor cores. | |
| - LoRA: r=16, α=32, dropout=0.05, vision towers frozen, language + attention | |
| + MLP trainable via Unsloth's multimodal API (``finetune_vision_layers=False, | |
| finetune_language_layers=True, finetune_attention_modules=True, | |
| finetune_mlp_modules=True``), Unsloth gradient checkpointing, | |
| ``random_state=3407``. | |
| - V100 halt: ``next(model.parameters()).dtype`` MUST be ``torch.float16`` | |
| after FP16 load; any other dtype on that first parameter triggers | |
| :class:`BF16SlippageError` before LoRA attach and optimizer build. | |
| - H100 halt: ``next(model.parameters()).dtype`` MUST be ``torch.bfloat16`` | |
| after BF16 load; any other dtype on that first parameter triggers | |
| :class:`FP16SlippageError` before LoRA attach and optimizer build. | |
| Heavy imports (``unsloth``, ``torch``) are deferred inside functions so this | |
| cell loads on CPU-only CI runners where Unsloth is not installed. Tests mock | |
| ``FastModel.from_pretrained`` and ``FastModel.get_peft_model``. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Any, Literal | |
# --- Base checkpoint and context window -------------------------------------
# Defaults for the model id and sequence length handed to the model loader.
BASE_MODEL_ID: str = "unsloth/gemma-3n-E2B-it"
MAX_SEQ_LENGTH: int = 4096

# --- LoRA adapter hyperparameters -------------------------------------------
# Rank, scaling factor, dropout and seed used when the adapters are attached.
LORA_R: int = 16
LORA_ALPHA: int = 32
LORA_DROPOUT: float = 0.05
LORA_RANDOM_STATE: int = 3407

# --- Which layer groups receive adapters ------------------------------------
# Gemma 3n is multimodal: the vision flag is off so the vision/audio towers
# stay frozen and GRPO only updates the language stack (language layers,
# attention modules and MLP modules), following the Unsloth Gemma 3N
# fine-tuning notebook.
FINETUNE_VISION_LAYERS: bool = False
FINETUNE_LANGUAGE_LAYERS: bool = True
FINETUNE_ATTENTION_MODULES: bool = True
FINETUNE_MLP_MODULES: bool = True

# --- Supported accelerators -------------------------------------------------
# ``HardwareT`` is the static type; ``ALLOWED_HARDWARE`` is the same set of
# identifiers as a runtime tuple.
HardwareT = Literal["v100", "h100"]
ALLOWED_HARDWARE: tuple[HardwareT, ...] = ("v100", "h100")
class BF16SlippageError(AssertionError):
    """Signal that a model destined for V100 did not load as FP16.

    The dtype guard raises this when the first model parameter is anything
    other than ``torch.float16`` while the target hardware is ``"v100"``.

    Rationale: V100 (sm_70) has no BF16 tensor cores, so BF16 would run via
    software emulation — roughly a 10x slowdown plus the numerical-instability
    patterns described in ``docs/modules/training.md §7a``. Training must
    halt before the optimizer is built.
    """
class FP16SlippageError(AssertionError):
    """Signal that a model destined for H100 did not load as BF16.

    The dtype guard raises this when the first model parameter is anything
    other than ``torch.bfloat16`` while the target hardware is ``"h100"``.

    Rationale: H100 (sm_90) has native BF16 tensor cores; staying in FP16
    leaves that capability unused and may cause gradient underflow at large
    batch sizes. Training must halt before the optimizer is built.
    """
@dataclass(frozen=True)
class BootConfig:
    """Immutable bundle of arguments consumed by :func:`boot_gemma`.

    Declared as a frozen dataclass so that callers can override individual
    fields by keyword (``BootConfig(hardware="h100")``) while instances stay
    read-only after construction, per the DriftCall immutability rule.
    Without the decorator the class accepts no constructor arguments and its
    attributes are freely mutable. Every default mirrors the corresponding
    module-level constant.
    """

    # Checkpoint id, context window and quantization switch forwarded to
    # ``FastModel.from_pretrained``.
    base_model_id: str = BASE_MODEL_ID
    max_seq_length: int = MAX_SEQ_LENGTH
    load_in_4bit: bool = True

    # LoRA hyperparameters forwarded to ``FastModel.get_peft_model``.
    lora_r: int = LORA_R
    lora_alpha: int = LORA_ALPHA
    lora_dropout: float = LORA_DROPOUT
    lora_random_state: int = LORA_RANDOM_STATE

    # Layer groups that receive adapters (vision is off by default).
    finetune_vision_layers: bool = FINETUNE_VISION_LAYERS
    finetune_language_layers: bool = FINETUNE_LANGUAGE_LAYERS
    finetune_attention_modules: bool = FINETUNE_ATTENTION_MODULES
    finetune_mlp_modules: bool = FINETUNE_MLP_MODULES

    # Passed through verbatim as ``use_gradient_checkpointing``.
    use_gradient_checkpointing: str = "unsloth"

    # Target accelerator; selects the load dtype and which dtype guard applies.
    hardware: HardwareT = "v100"
def assert_dtype_for_hardware(model: Any, hardware: HardwareT) -> None:
    """Verify that the model's first parameter has the dtype ``hardware`` requires.

    Only ``next(iter(model.parameters()))`` is inspected; the remaining
    parameters are not scanned. ``boot_gemma`` calls this right after the
    base model is loaded and before LoRA is attached.

    Args:
        model: Any object exposing ``parameters()`` that yields objects with
            a ``dtype`` attribute.
        hardware: ``"v100"`` (requires ``torch.float16``) or ``"h100"``
            (requires ``torch.bfloat16``).

    Raises:
        ValueError: ``hardware`` is not one of the two supported identifiers.
            Previously any unknown value was silently treated as H100, which
            let a mistyped ``"V100"`` pass with BF16 weights.
        BF16SlippageError: on V100 when the first parameter is not FP16, or
            when the model has no parameters at all.
        FP16SlippageError: on H100 when the first parameter is not BF16, or
            when the model has no parameters at all.
    """
    # Validate before the heavy import so a bad identifier fails fast even on
    # machines without torch installed.
    if hardware not in ("v100", "h100"):
        raise ValueError(
            f"Unsupported hardware {hardware!r}; expected 'v100' or 'h100'."
        )

    import torch

    on_v100 = hardware == "v100"

    first_param = next(iter(model.parameters()), None)
    if first_param is None:  # pragma: no cover - defensive
        # Report the failure with the error class that matches the hardware.
        error_cls = BF16SlippageError if on_v100 else FP16SlippageError
        raise error_cls("Model has no parameters; cannot verify dtype.")

    dtype = first_param.dtype

    if on_v100 and dtype != torch.float16:
        raise BF16SlippageError(
            "BF16 slipped through: V100 unsafe. "
            f"next(model.parameters()).dtype == {dtype}, expected torch.float16. "
            "Root cause: Unsloth auto-picked BF16 despite dtype=torch.float16 kwarg. "
            "Halt training; do NOT proceed on V100."
        )

    if not on_v100 and dtype != torch.bfloat16:
        raise FP16SlippageError(
            "FP16 slipped through: H100 should use BF16. "
            f"next(model.parameters()).dtype == {dtype}, expected torch.bfloat16. "
            "Root cause: dtype kwarg may have forced FP16 on H100. "
            "Halt training; do NOT proceed on H100 with FP16."
        )
def assert_fp16_dtype(model: Any) -> None:
    """Run the V100 dtype guard on ``model``.

    Backwards-compatibility shim for call sites written before the
    hardware-aware :func:`assert_dtype_for_hardware` existed; it simply
    delegates with the hardware fixed to ``"v100"``. The check therefore
    requires the model's first parameter to be ``torch.float16`` and
    surfaces :class:`BF16SlippageError` (the halt message from
    ``docs/modules/training.md §3.1``) otherwise.
    """
    assert_dtype_for_hardware(model, hardware="v100")
def boot_gemma(config: BootConfig | None = None) -> tuple[Any, Any]:
    """Load the base model, verify its dtype, attach LoRA; return (model, tokenizer).

    Steps (training.md §3.1):

    1. Reject any ``hardware`` value outside ``ALLOWED_HARDWARE``.
    2. ``FastModel.from_pretrained(base_model_id, max_seq_length=...,
       load_in_4bit=..., dtype=...)`` with ``dtype=torch.float16`` on V100
       and ``dtype=torch.bfloat16`` on H100.
    3. ``assert_dtype_for_hardware(model, hardware)`` — raises
       :class:`BF16SlippageError` or :class:`FP16SlippageError` if the loaded
       dtype does not match the hardware.
    4. ``FastModel.get_peft_model(model, r=..., lora_alpha=...,
       lora_dropout=..., finetune_vision_layers=...,
       finetune_language_layers=..., finetune_attention_modules=...,
       finetune_mlp_modules=..., use_gradient_checkpointing=...,
       random_state=...)`` using the values from ``config``.
    5. Return ``(peft_model, tokenizer)``.

    Args:
        config: Boot settings; ``None`` means a default ``BootConfig()``.

    Returns:
        The LoRA-wrapped model and the tokenizer returned by the loader.

    Raises:
        ValueError: ``config.hardware`` is not in ``ALLOWED_HARDWARE``.
        BF16SlippageError / FP16SlippageError: propagated from the dtype guard.

    All heavy imports are lazy so the module is importable on CPU-only CI.
    """
    cfg = BootConfig() if config is None else config

    # Without this check any unknown identifier (e.g. "V100") would fall
    # through to the BF16 branch below and also pass the dtype guard as if it
    # were H100, defeating the V100 safety halt.
    if cfg.hardware not in ALLOWED_HARDWARE:
        raise ValueError(
            f"Unsupported hardware {cfg.hardware!r}; "
            f"expected one of {ALLOWED_HARDWARE}."
        )

    import torch
    from unsloth import FastModel

    # FP16 is forced on V100; H100 uses BF16.
    dtype = torch.float16 if cfg.hardware == "v100" else torch.bfloat16

    model, tokenizer = FastModel.from_pretrained(
        cfg.base_model_id,
        max_seq_length=cfg.max_seq_length,
        load_in_4bit=cfg.load_in_4bit,
        dtype=dtype,
    )

    # Halt here, before adapters are attached, if the loader ignored the
    # requested dtype.
    assert_dtype_for_hardware(model, cfg.hardware)

    peft_model = FastModel.get_peft_model(
        model,
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        finetune_vision_layers=cfg.finetune_vision_layers,
        finetune_language_layers=cfg.finetune_language_layers,
        finetune_attention_modules=cfg.finetune_attention_modules,
        finetune_mlp_modules=cfg.finetune_mlp_modules,
        use_gradient_checkpointing=cfg.use_gradient_checkpointing,
        random_state=cfg.lora_random_state,
    )
    return peft_model, tokenizer
# Public API of this module: constants, the config type, both slippage
# errors, and the dtype-guard / boot entry points. Kept in sorted order.
__all__ = [
    "ALLOWED_HARDWARE",
    "BASE_MODEL_ID",
    "BF16SlippageError",
    "BootConfig",
    "FINETUNE_ATTENTION_MODULES",
    "FINETUNE_LANGUAGE_LAYERS",
    "FINETUNE_MLP_MODULES",
    "FINETUNE_VISION_LAYERS",
    "FP16SlippageError",
    "HardwareT",
    "LORA_ALPHA",
    "LORA_DROPOUT",
    "LORA_R",
    "LORA_RANDOM_STATE",
    "MAX_SEQ_LENGTH",
    "assert_dtype_for_hardware",
    "assert_fp16_dtype",
    "boot_gemma",
]