| """ |
| Model loading and wrapping. |
| |
| Provides: |
| - load_checkpoint(ckpt_path, config, device, dtype, use_ema, strict) |
| -> ModelWrapper |
| - ModelWrapper.__call__(x [1,L], t [1]) -> logits [1,L,V] |
| with autocast handled internally |
| """ |
| from __future__ import annotations |
|
|
| import re |
| from contextlib import nullcontext |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
|
|
| from .model import LangDiT, create_model |
|
|
| STEP_CHECKPOINT_RE = re.compile(r"step_(\d+)(?:\.pt|\.safetensors)$") |
| IGNORED_KEY_SUFFIXES = ("._extra_state",) |
| IGNORED_EXACT_KEYS = {"rope.rope.inv_freq"} |
|
|
|
|
| |
|
|


def resolve_checkpoint(path: str) -> str:
    """If *path* is a directory, find a supported checkpoint file inside it."""

    def step_key(f: Path) -> int:
        m = STEP_CHECKPOINT_RE.match(f.name)
        return int(m.group(1)) if m else -1

    p = Path(path)
    if p.is_file():
        return str(p)
    if p.is_dir():
        # Prefer the highest-numbered step_*.pt, then step_*.safetensors.
        candidates = sorted(p.glob("step_*.pt"), key=step_key)
        if not candidates:
            candidates = sorted(p.glob("step_*.safetensors"), key=step_key)
        if candidates:
            return str(candidates[-1])
        # Fall back to conventional single-file names.
        for candidate in (p / "model.safetensors", p / "checkpoint.safetensors"):
            if candidate.is_file():
                return str(candidate)
        safetensors_files = sorted(p.glob("*.safetensors"))
        if len(safetensors_files) == 1:
            return str(safetensors_files[0])
        if (p / "model.safetensors.index.json").is_file():
            raise FileNotFoundError(
                "Sharded safetensors are not supported by whale4b yet. "
                "Pass a single .safetensors file instead."
            )
    raise FileNotFoundError(f"No checkpoint found at: {path}")
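

# Usage sketch for resolve_checkpoint. The directory layout is hypothetical,
# purely for illustration: given runs/exp1/ containing step_1000.pt and
# step_9000.pt, the highest step wins.
#
#     resolve_checkpoint("runs/exp1")               # -> "runs/exp1/step_9000.pt"
#     resolve_checkpoint("runs/exp1/step_1000.pt")  # existing files pass through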


def load_state_dict(ckpt_path: str, use_ema: bool = True):
    """Load a raw state dict from ``.pt`` or ``.safetensors``, preferring EMA
    weights when present. ``.safetensors`` files hold a flat tensor dict, so
    ``use_ema`` has no effect for them.
    """
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        return load_file(ckpt_path, device="cpu"), "safetensors"

    load_kwargs = {"map_location": "cpu", "weights_only": False}
    try:
        # mmap avoids reading the whole file into RAM up front; older torch
        # versions do not accept the kwarg, so fall back without it.
        ckpt = torch.load(ckpt_path, mmap=True, **load_kwargs)
    except TypeError:
        ckpt = torch.load(ckpt_path, **load_kwargs)

    if not isinstance(ckpt, dict):
        return ckpt, "raw"
    if use_ema and isinstance(ckpt.get("ema"), dict):
        return ckpt["ema"], "ema"
    if isinstance(ckpt.get("model"), dict):
        return ckpt["model"], "model"
    if isinstance(ckpt.get("state_dict"), dict):
        return ckpt["state_dict"], "state_dict"
    return ckpt, "root"
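

# Sketch of the (state_dict, source) contract; file names are hypothetical:
#
#     sd, src = load_state_dict("step_9000.pt")       # src == "ema" when an
#                                                     # EMA dict was saved
#     sd, src = load_state_dict("model.safetensors")  # src == "safetensors"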


def _strip_prefix(sd: dict, prefix: str) -> dict:
    if not any(k.startswith(prefix) for k in sd):
        return sd
    out = {}
    for key, value in sd.items():
        out[key[len(prefix):] if key.startswith(prefix) else key] = value
    return out


def sanitize_state_dict(state_dict: dict) -> tuple[dict, list[str]]:
    """Strip wrapper prefixes and drop non-inference metadata keys."""
    for prefix in ("module.", "model.", "_orig_mod."):
        state_dict = _strip_prefix(state_dict, prefix)

    dropped: list[str] = []
    cleaned: dict = {}
    for key, value in state_dict.items():
        if key in IGNORED_EXACT_KEYS or any(
            key.endswith(suffix) for suffix in IGNORED_KEY_SUFFIXES
        ):
            dropped.append(key)
            continue
        cleaned[key] = value
    return cleaned, dropped
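

# Sketch: wrapper prefixes ("module." from DDP, "model.", "_orig_mod." from
# torch.compile) are peeled off and metadata keys dropped. Here w and f stand
# for arbitrary tensors:
#
#     sd = {"module._orig_mod.blocks.0.w": w, "rope.rope.inv_freq": f}
#     cleaned, dropped = sanitize_state_dict(sd)
#     # cleaned == {"blocks.0.w": w}, dropped == ["rope.rope.inv_freq"]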


def resolve_dtype(dtype_name: str, device: torch.device):
    """Return ``(amp_dtype, use_amp, model_dtype)`` for the requested precision."""
    dtype_map = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    # Unknown names fall back to bf16.
    amp_dtype = dtype_map.get(dtype_name, torch.bfloat16)
    if dtype_name == "fp32":
        return amp_dtype, False, torch.float32
    if device.type == "cuda":
        # CUDA: keep weights in the reduced dtype and run under autocast.
        return amp_dtype, True, amp_dtype
    if device.type == "mps" and dtype_name == "fp16":
        # No autocast path here for MPS; run the model natively in fp16.
        return amp_dtype, False, torch.float16
    # CPU (and MPS with bf16): fp32 weights, no autocast.
    return amp_dtype, False, torch.float32
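

# Resulting triples, for reference:
#
#     resolve_dtype("bf16", torch.device("cuda"))  # (bfloat16, True,  bfloat16)
#     resolve_dtype("bf16", torch.device("cpu"))   # (bfloat16, False, float32)
#     resolve_dtype("fp32", torch.device("cuda"))  # (float32,  False, float32)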


class ModelWrapper:
    """
    Wraps LangDiT into a standard ``(x [1,L], t [1]) -> logits [1,L,V]``
    callable. Handles autocast internally, so callers never deal with AMP.
    """

    def __init__(
        self,
        model: LangDiT,
        vocab_size: int,
        mask_token_id: int,
        device: torch.device,
        use_amp: bool,
        amp_dtype: torch.dtype,
    ):
        self.model = model
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.device = device
        self.use_amp = use_amp
        self.amp_dtype = amp_dtype

    @torch.no_grad()
    def __call__(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        x: [1, L] int64
        t: [1] float
        Returns: [1, L, V] float32 logits (raw, no softmax)
        """
        x = x.to(self.device)
        t = t.to(self.device)
        amp_ctx = (
            torch.autocast(device_type="cuda", dtype=self.amp_dtype)
            if self.use_amp and self.device.type == "cuda"
            else nullcontext()
        )
        with amp_ctx:
            logits = self.model(x, t)
        # Upcast so callers always get float32, as the docstring promises,
        # even when the forward pass ran in bf16/fp16 under autocast.
        return logits.float()
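

# Call-contract sketch (mask filling; ``wrapper`` comes from load_checkpoint
# below):
#
#     x = torch.full((1, 64), wrapper.mask_token_id, dtype=torch.long)
#     t = torch.tensor([0.5])
#     logits = wrapper(x, t)   # [1, 64, vocab_size], float32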


def load_checkpoint(
    ckpt_path: str,
    config: dict,
    device: Optional[torch.device] = None,
    dtype: str = "bf16",
    use_ema: bool = True,
    strict: bool = False,
) -> ModelWrapper:
    """
    Full pipeline: resolve path -> load state dict -> build model -> wrap.
    Returns a ready-to-call ModelWrapper.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    amp_dtype, use_amp, model_dtype = resolve_dtype(dtype, device)

    resolved = resolve_checkpoint(ckpt_path)
    state_dict, source = load_state_dict(resolved, use_ema=use_ema)
    state_dict, dropped = sanitize_state_dict(state_dict)

    model = create_model(config).to(device=device, dtype=model_dtype)
    model.eval()
    missing, unexpected = model.load_state_dict(state_dict, strict=strict)
    del state_dict

    if missing:
        print(f"[loader] missing keys: {len(missing)}, sample: {missing[:3]}")
    if unexpected:
        print(f"[loader] unexpected keys: {len(unexpected)}, sample: {unexpected[:3]}")
    if dropped:
        print(f"[loader] dropped non-inference keys: {len(dropped)}, sample: {dropped[:3]}")
    print(f"[loader] loaded {resolved!r} (source={source}, dtype={model_dtype})")

    diff_cfg = config.get("diffusion", {})
    vocab_size = int(config["model"]["vocab_size"])
    mask_token_id = int(diff_cfg.get("mask_token_id", 14))

    return ModelWrapper(
        model=model,
        vocab_size=vocab_size,
        mask_token_id=mask_token_id,
        device=device,
        use_amp=use_amp,
        amp_dtype=amp_dtype,
    )
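

# End-to-end sketch. The config path and checkpoint directory are hypothetical;
# the config must hold whatever create_model expects, plus the keys read above
# (model.vocab_size, and optionally diffusion.mask_token_id):
#
#     import yaml
#     with open("configs/whale4b.yaml") as f:
#         config = yaml.safe_load(f)
#     wrapper = load_checkpoint("runs/exp1", config, dtype="bf16")
#     x = torch.full((1, 128), wrapper.mask_token_id, dtype=torch.long)
#     logits = wrapper(x, torch.tensor([1.0]))   # [1, 128, vocab_size]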