# core/loader.py
"""
Model loading and wrapping.
Provides:
- load_checkpoint(ckpt_path, config, device, dtype, use_ema, strict)
-> ModelWrapper
- ModelWrapper.__call__(x [1,L], t [1]) -> logits [1,L,V]
with autocast handled internally
"""
from __future__ import annotations
import re
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
import torch
from .model import LangDiT, create_model # noqa: F401
STEP_CHECKPOINT_RE = re.compile(r"step_(\d+)(?:\.pt|\.safetensors)$")
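
# A hedged note on the two ignore lists below (assumption, not from the source):
# "._extra_state" entries typically hold serialized runtime metadata (e.g.
# NVIDIA TransformerEngine FP8 state), and the RoPE "inv_freq" buffer is
# recomputed when the model is constructed, so neither carries learned weights.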
IGNORED_KEY_SUFFIXES = ("._extra_state",)
IGNORED_EXACT_KEYS = {"rope.rope.inv_freq"}

# ── checkpoint helpers ────────────────────────────────────────────────────────

def resolve_checkpoint(path: str) -> str:
    """If *path* is a directory, find a supported checkpoint file inside it."""
    p = Path(path)
    if p.is_file():
        return str(p)
    if p.is_dir():
        candidates = sorted(
            p.glob("step_*.pt"),
            key=lambda f: int(STEP_CHECKPOINT_RE.match(f.name).group(1))
            if STEP_CHECKPOINT_RE.match(f.name) else -1,
        )
        if not candidates:
            candidates = sorted(
                p.glob("step_*.safetensors"),
                key=lambda f: int(STEP_CHECKPOINT_RE.match(f.name).group(1))
                if STEP_CHECKPOINT_RE.match(f.name) else -1,
            )
        if candidates:
            return str(candidates[-1])
        named = [p / "model.safetensors", p / "checkpoint.safetensors"]
        for candidate in named:
            if candidate.is_file():
                return str(candidate)
        safetensors_files = sorted(p.glob("*.safetensors"))
        if len(safetensors_files) == 1:
            return str(safetensors_files[0])
        if (p / "model.safetensors.index.json").is_file():
            raise FileNotFoundError(
                "Sharded safetensors are not supported by whale4b yet. "
                "Pass a single .safetensors file instead."
            )
    raise FileNotFoundError(f"No checkpoint found at: {path}")
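
# Illustrative behaviour (hypothetical filenames): a directory holding
# step_9000.pt and step_20000.pt resolves to step_20000.pt, since candidates
# are ordered by the parsed step number; a plain lexicographic sort would put
# "step_9000" after "step_20000" and pick the wrong file.
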
def load_state_dict(ckpt_path: str, use_ema: bool = True):
    """Load raw state dict from ``.pt`` or ``.safetensors``, preferring EMA."""
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        return load_file(ckpt_path, device="cpu"), "safetensors"
    load_kwargs = {"map_location": "cpu", "weights_only": False}
    try:
        ckpt = torch.load(ckpt_path, mmap=True, **load_kwargs)
    except TypeError:
        # Older torch versions reject the mmap kwarg; retry without it.
        ckpt = torch.load(ckpt_path, **load_kwargs)
    if not isinstance(ckpt, dict):
        return ckpt, "raw"
    if use_ema and isinstance(ckpt.get("ema"), dict):
        return ckpt["ema"], "ema"
    if isinstance(ckpt.get("model"), dict):
        return ckpt["model"], "model"
    if isinstance(ckpt.get("state_dict"), dict):
        return ckpt["state_dict"], "state_dict"
    return ckpt, "root"
def _strip_prefix(sd: dict, prefix: str) -> dict:
    if not any(k.startswith(prefix) for k in sd):
        return sd
    out = {}
    for key, value in sd.items():
        out[key[len(prefix):] if key.startswith(prefix) else key] = value
    return out
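
# e.g. _strip_prefix({"module.blocks.0.w": w, "head.w": v}, "module.")
# -> {"blocks.0.w": w, "head.w": v}; keys without the prefix pass through
# unchanged (key names hypothetical).
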
def sanitize_state_dict(state_dict: dict) -> tuple[dict, list[str]]:
    """Strip wrapper prefixes and drop non-inference metadata keys."""
    for prefix in ("module.", "model.", "_orig_mod."):
        state_dict = _strip_prefix(state_dict, prefix)
    dropped: list[str] = []
    cleaned: dict = {}
    for key, value in state_dict.items():
        if key in IGNORED_EXACT_KEYS or any(
            key.endswith(suffix) for suffix in IGNORED_KEY_SUFFIXES
        ):
            dropped.append(key)
            continue
        cleaned[key] = value
    return cleaned, dropped
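
# e.g. a torch.compile + DDP key "module._orig_mod.blocks.0.qkv.weight" comes
# out as "blocks.0.qkv.weight", while a key like "blocks.0._extra_state" is
# dropped and reported in the second return value (key names hypothetical).
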
def resolve_dtype(dtype_name: str, device: torch.device):
    """Returns ``(amp_dtype, use_amp, model_dtype)``."""
    dtype_map = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    amp_dtype = dtype_map.get(dtype_name, torch.bfloat16)
    if dtype_name == "fp32":
        return amp_dtype, False, torch.float32
    if device.type == "cuda":
        return amp_dtype, True, amp_dtype
    if device.type == "mps" and dtype_name == "fp16":
        return amp_dtype, False, torch.float16
    return amp_dtype, False, torch.float32
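
# Decision summary (matches the branches above):
#   "fp32" anywhere      -> fp32 weights, autocast off
#   CUDA, "bf16"/"fp16"  -> weights cast to the AMP dtype, autocast on
#                           (unrecognized names fall back to bf16)
#   MPS, "fp16"          -> fp16 weights, autocast off
#   everything else      -> fp32 weights, autocast off
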
# ── ModelWrapper ──────────────────────────────────────────────────────────────

class ModelWrapper:
    """
    Wraps LangDiT into a standard ``(x [1,L], t [1]) -> logits [1,L,V]``
    callable. Handles autocast internally; callers never deal with AMP.
    """

    def __init__(
        self,
        model: LangDiT,
        vocab_size: int,
        mask_token_id: int,
        device: torch.device,
        use_amp: bool,
        amp_dtype: torch.dtype,
    ):
        self.model = model
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.device = device
        self.use_amp = use_amp
        self.amp_dtype = amp_dtype

    @torch.no_grad()
    def __call__(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        x: [1, L] int64
        t: [1] float
        Returns: [1, L, V] float32 logits (raw; no softmax)
        """
        x = x.to(self.device)
        t = t.to(self.device)
        amp_ctx = (
            torch.autocast(device_type="cuda", dtype=self.amp_dtype)
            if self.use_amp and self.device.type == "cuda"
            else nullcontext()
        )
        with amp_ctx:
            logits = self.model(x, t)
        return logits
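
# Call sketch (hypothetical shapes), given a wrapper from load_checkpoint below:
#
#     x = torch.full((1, 64), wrapper.mask_token_id, dtype=torch.long)
#     t = torch.ones(1)
#     logits = wrapper(x, t)  # [1, 64, wrapper.vocab_size]
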
def load_checkpoint(
    ckpt_path: str,
    config: dict,
    device: Optional[torch.device] = None,
    dtype: str = "bf16",
    use_ema: bool = True,
    strict: bool = False,
) -> ModelWrapper:
    """
    Full pipeline: resolve path -> load state dict -> build model -> wrap.
    Returns a ready-to-call ModelWrapper.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    amp_dtype, use_amp, model_dtype = resolve_dtype(dtype, device)
    resolved = resolve_checkpoint(ckpt_path)
    state_dict, source = load_state_dict(resolved, use_ema=use_ema)
    state_dict, dropped = sanitize_state_dict(state_dict)
    model = create_model(config).to(device=device, dtype=model_dtype)
    model.eval()
    missing, unexpected = model.load_state_dict(state_dict, strict=strict)
    del state_dict
    if missing:
        print(f"[loader] missing keys: {len(missing)}; sample: {missing[:3]}")
    if unexpected:
        print(f"[loader] unexpected keys: {len(unexpected)}; sample: {unexpected[:3]}")
    if dropped:
        print(f"[loader] dropped non-inference keys: {len(dropped)}; sample: {dropped[:3]}")
    print(f"[loader] loaded {resolved!r} (source={source}, dtype={model_dtype})")
    diff_cfg = config.get("diffusion", {})
    vocab_size = int(config["model"]["vocab_size"])
    mask_token_id = int(diff_cfg.get("mask_token_id", 14))
    return ModelWrapper(
        model=model,
        vocab_size=vocab_size,
        mask_token_id=mask_token_id,
        device=device,
        use_amp=use_amp,
        amp_dtype=amp_dtype,
    )
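
# ── usage sketch ──────────────────────────────────────────────────────────────
# A minimal, hypothetical end-to-end example. Only config["model"]["vocab_size"]
# and config["diffusion"]["mask_token_id"] are read in this module; every other
# key (and the YAML path) is a placeholder whose real schema is defined by
# core.model.create_model.
#
#     import yaml
#
#     with open("configs/base.yaml") as fh:  # hypothetical config path
#         config = yaml.safe_load(fh)
#     wrapper = load_checkpoint("checkpoints/", config, dtype="bf16")
#     x = torch.full((1, 128), wrapper.mask_token_id, dtype=torch.long)
#     t = torch.ones(1)
#     logits = wrapper(x, t)  # [1, 128, V] over the model vocabulary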