"""
Model loading and wrapping.

Provides:
- load_checkpoint(ckpt_path, config, device, dtype, use_ema, strict) -> ModelWrapper
- ModelWrapper.__call__(x [1,L], t [1]) -> logits [1,L,V] with autocast handled internally
"""

from __future__ import annotations

import re
from contextlib import nullcontext
from pathlib import Path
from typing import Optional

import torch

from .model import LangDiT, create_model  # noqa: F401

STEP_CHECKPOINT_RE = re.compile(r"step_(\d+)(?:\.pt|\.safetensors)$")
IGNORED_KEY_SUFFIXES = ("._extra_state",)
IGNORED_EXACT_KEYS = {"rope.rope.inv_freq"}
# Key prefixes added by training-time wrappers: DDP adds "module.",
# torch.compile adds "_orig_mod."; "model." is presumably a trainer-side
# holder attribute — confirm against the training code.
WRAPPER_PREFIXES = ("module.", "model.", "_orig_mod.")


# ── checkpoint helpers ────────────────────────────────────────────────────────


def _step_number(path: Path) -> int:
    """Return the step encoded in a ``step_<N>.pt|.safetensors`` name, or -1."""
    match = STEP_CHECKPOINT_RE.match(path.name)
    return int(match.group(1)) if match else -1


def resolve_checkpoint(path: str) -> str:
    """Resolve *path* to a single checkpoint file.

    A file path is returned unchanged. For a directory the search order is:

    1. the highest-step ``step_<N>.pt`` file;
    2. otherwise the highest-step ``step_<N>.safetensors`` file;
    3. ``model.safetensors``, then ``checkpoint.safetensors``;
    4. the only ``*.safetensors`` file, when there is exactly one.

    Raises:
        FileNotFoundError: nothing usable was found, or the directory holds a
            sharded safetensors checkpoint (``model.safetensors.index.json``),
            which is not supported.
    """
    p = Path(path)
    if p.is_file():
        return str(p)

    if p.is_dir():
        # ``.pt`` step checkpoints win over ``.safetensors`` step checkpoints.
        # Names whose step cannot be parsed sort first (key -1).
        for pattern in ("step_*.pt", "step_*.safetensors"):
            candidates = sorted(p.glob(pattern), key=_step_number)
            if candidates:
                return str(candidates[-1])

        for candidate in (p / "model.safetensors", p / "checkpoint.safetensors"):
            if candidate.is_file():
                return str(candidate)

        safetensors_files = sorted(p.glob("*.safetensors"))
        if len(safetensors_files) == 1:
            return str(safetensors_files[0])

        if (p / "model.safetensors.index.json").is_file():
            raise FileNotFoundError(
                "Sharded safetensors are not supported by whale4b yet. "
                "Pass a single .safetensors file instead."
            )

    raise FileNotFoundError(f"No checkpoint found at: {path}")


def load_state_dict(ckpt_path: str, use_ema: bool = True):
    """Load a raw state dict from a ``.pt`` or ``.safetensors`` file.

    For a ``.pt`` checkpoint saved as a dict, the weights come from the first
    matching entry: ``"ema"`` (only when *use_ema*), then ``"model"``, then
    ``"state_dict"``. If none of these holds a dict, the whole dict is used.

    Returns:
        ``(state_dict, source)``, where *source* is one of ``"safetensors"``,
        ``"ema"``, ``"model"``, ``"state_dict"``, ``"root"`` (the top-level
        dict) or ``"raw"`` (a non-dict object, returned as-is).
    """
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        return load_file(ckpt_path, device="cpu"), "safetensors"

    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code embedded in the file. Only load checkpoints from trusted sources.
    load_kwargs = {"map_location": "cpu", "weights_only": False}
    try:
        # mmap avoids reading the whole file into RAM up front.
        ckpt = torch.load(ckpt_path, mmap=True, **load_kwargs)
    except (TypeError, RuntimeError):
        # TypeError: this torch version has no ``mmap`` keyword.
        # RuntimeError: legacy (non-zipfile) checkpoints cannot be mmapped.
        # A genuinely broken file raises again from the plain load below.
        ckpt = torch.load(ckpt_path, **load_kwargs)

    if not isinstance(ckpt, dict):
        return ckpt, "raw"
    if use_ema and isinstance(ckpt.get("ema"), dict):
        return ckpt["ema"], "ema"
    if isinstance(ckpt.get("model"), dict):
        return ckpt["model"], "model"
    if isinstance(ckpt.get("state_dict"), dict):
        return ckpt["state_dict"], "state_dict"
    return ckpt, "root"


def _strip_prefix(sd: dict, prefix: str) -> dict:
    """Remove *prefix* from every key that has it.

    Returns *sd* itself (same object) when no key carries the prefix;
    otherwise returns a new dict.
    """
    if not any(k.startswith(prefix) for k in sd):
        return sd
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in sd.items()
    }


def sanitize_state_dict(state_dict: dict) -> tuple[dict, list[str]]:
    """Strip wrapper prefixes and drop non-inference metadata keys.

    Wrapper prefixes (``WRAPPER_PREFIXES``) are stripped repeatedly until none
    remain. This fully unwraps nested wrappers in either nesting order, for
    example ``"module._orig_mod."`` and ``"_orig_mod.module."``.

    Returns:
        ``(cleaned, dropped)``: the cleaned state dict, and the keys removed
        because they are in ``IGNORED_EXACT_KEYS`` or end with one of
        ``IGNORED_KEY_SUFFIXES``.
    """
    while True:
        before = state_dict
        for prefix in WRAPPER_PREFIXES:
            state_dict = _strip_prefix(state_dict, prefix)
        # _strip_prefix returns the same object when it changes nothing, so an
        # unchanged identity means no wrapper prefix is left.
        if state_dict is before:
            break

    dropped: list[str] = []
    cleaned: dict = {}
    for key, value in state_dict.items():
        if key in IGNORED_EXACT_KEYS or key.endswith(IGNORED_KEY_SUFFIXES):
            dropped.append(key)
        else:
            cleaned[key] = value
    return cleaned, dropped


def resolve_dtype(dtype_name: str, device: torch.device):
    """Choose the autocast dtype and the weight dtype for *device*.

    *dtype_name* is ``"bf16"``, ``"fp16"`` or ``"fp32"`` (case-insensitive).
    The long forms ``"bfloat16"``, ``"float16"``/``"half"`` and
    ``"float32"``/``"float"`` are accepted too. Any other name falls back to
    bf16 and prints a warning.

    Returns ``(amp_dtype, use_amp, model_dtype)``:

    - fp32 on any device: no autocast, fp32 weights;
    - CUDA: autocast in the requested dtype, weights cast to that dtype;
    - MPS with fp16: no autocast, fp16 weights;
    - anything else (CPU, MPS with bf16): no autocast, fp32 weights.
    """
    dtype_map = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    aliases = {
        "bfloat16": "bf16",
        "float16": "fp16",
        "half": "fp16",
        "float32": "fp32",
        "float": "fp32",
    }
    name = dtype_name.lower() if isinstance(dtype_name, str) else dtype_name
    name = aliases.get(name, name)
    if name not in dtype_map:
        print(f"[loader] unknown dtype {dtype_name!r}; falling back to bf16")
        name = "bf16"
    amp_dtype = dtype_map[name]

    if name == "fp32":
        return amp_dtype, False, torch.float32
    if device.type == "cuda":
        return amp_dtype, True, amp_dtype
    if device.type == "mps" and name == "fp16":
        return amp_dtype, False, torch.float16
    return amp_dtype, False, torch.float32


# ── ModelWrapper ──────────────────────────────────────────────────────────────


class ModelWrapper:
    """
    Wraps LangDiT into a standard ``(x [1,L], t [1]) -> logits [1,L,V]`` callable.
    Handles autocast internally — callers never deal with AMP.
    """

    def __init__(
        self,
        model: LangDiT,
        vocab_size: int,
        mask_token_id: int,
        device: torch.device,
        use_amp: bool,
        amp_dtype: torch.dtype,
    ):
        self.model = model
        # vocab_size / mask_token_id are not used here; they are stored for
        # callers of the wrapper.
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.device = device
        self.use_amp = use_amp
        self.amp_dtype = amp_dtype

    @torch.no_grad()
    def __call__(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        x: [1, L] int64
        t: [1] float
        Returns: [1, L, V] float32 logits (raw — no softmax)
        """
        x = x.to(self.device)
        t = t.to(self.device)
        amp_ctx = (
            torch.autocast(device_type="cuda", dtype=self.amp_dtype)
            if self.use_amp and self.device.type == "cuda"
            else nullcontext()
        )
        with amp_ctx:
            logits = self.model(x, t)
        # Under autocast or half-precision weights the logits may come out as
        # bf16/fp16. Upcast so callers always get the documented float32.
        # .float() is a no-op for tensors that are already float32.
        return logits.float()


def load_checkpoint(
    ckpt_path: str,
    config: dict,
    device: Optional[torch.device] = None,
    dtype: str = "bf16",
    use_ema: bool = True,
    strict: bool = False,
) -> ModelWrapper:
    """
    Full pipeline: resolve path -> load state dict -> build model -> wrap.
    Returns a ready-to-call ModelWrapper.

    Args:
        ckpt_path: checkpoint file or directory (see ``resolve_checkpoint``).
        config: passed to ``create_model``. ``config["model"]["vocab_size"]``
            is required. ``config["diffusion"]["mask_token_id"]`` is optional
            and defaults to 14.
        device: target device, either a ``torch.device`` or a device string
            such as ``"cuda"``. Defaults to CUDA when available, else CPU.
        dtype: precision name passed to ``resolve_dtype``.
        use_ema: prefer the checkpoint's ``"ema"`` weights when present.
        strict: forwarded to ``nn.Module.load_state_dict``. When False,
            mismatched keys are only reported via prints.
    """
    # Accept device strings as well; resolve_dtype needs ``device.type``.
    device = (
        torch.device(device)
        if device
        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    amp_dtype, use_amp, model_dtype = resolve_dtype(dtype, device)

    resolved = resolve_checkpoint(ckpt_path)
    state_dict, source = load_state_dict(resolved, use_ema=use_ema)
    state_dict, dropped = sanitize_state_dict(state_dict)

    model = create_model(config).to(device=device, dtype=model_dtype)
    model.eval()
    missing, unexpected = model.load_state_dict(state_dict, strict=strict)
    # Release the CPU copy of the weights as soon as they have been copied in.
    del state_dict

    if missing:
        print(f"[loader] missing keys: {len(missing)} — sample: {missing[:3]}")
    if unexpected:
        print(f"[loader] unexpected keys: {len(unexpected)} — sample: {unexpected[:3]}")
    if dropped:
        print(f"[loader] dropped non-inference keys: {len(dropped)} — sample: {dropped[:3]}")
    print(f"[loader] loaded {resolved!r} (source={source}, dtype={model_dtype})")

    diff_cfg = config.get("diffusion", {})
    vocab_size = int(config["model"]["vocab_size"])
    mask_token_id = int(diff_cfg.get("mask_token_id", 14))

    return ModelWrapper(
        model=model,
        vocab_size=vocab_size,
        mask_token_id=mask_token_id,
        device=device,
        use_amp=use_amp,
        amp_dtype=amp_dtype,
    )