# Author: syedmohaiminulhoque
# Commit: Complete CDD implementation: Constrained Discrete Diffusion (arXiv:2503.09790v3)
# Revision: 2d0a056 (verified)
"""
Model Loading Utilities for CDD
================================
Provides unified loading for MDLM and UDLM discrete diffusion models
from HuggingFace Hub.
Supported models:
- kuleshov-group/mdlm-owt: MDLM trained on OpenWebText (110M params)
- kuleshov-group/udlm-qm9: UDLM trained on QM9 molecules (92M params)
- kuleshov-group/udlm-lm1b: UDLM trained on LM1B text (139M params)
"""
import torch
from typing import Optional, Tuple
from dataclasses import dataclass
@dataclass
class DiffusionModelInfo:
    """Descriptive metadata that the loaders return next to a model.

    Instances are plain records: the loaders in this module fill every field
    from the loaded model's config and parameters, and nothing here performs
    validation.
    """

    # HuggingFace Hub id the weights were loaded from.
    model_name: str
    # Diffusion family tag; the loaders here write "mdlm" or "udlm".
    diffusion_type: str
    # Size of the model's output vocabulary (model.config.vocab_size).
    vocab_size: int
    # Sequence length the model was configured with (model.config.model_length).
    model_length: int
    # Transformer hidden width (model.config.hidden_dim).
    hidden_dim: int
    # Total number of parameters, summed over model.parameters().
    n_params: int
    # Id of the absorbing [MASK] token; populated for MDLM, left None for UDLM.
    mask_token_id: Optional[int] = None
def load_mdlm(
    model_name: str = "kuleshov-group/mdlm-owt",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
    *,
    tokenizer_name: str = "gpt2",
) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]:
    """Load a pretrained MDLM model and tokenizer.

    MDLM (Masked Diffusion Language Model) uses absorbing noise:
      - Forward: tokens are replaced with [MASK]
      - Reverse: [MASK] tokens are unmasked to clean tokens
      - time_conditioning=False: sigma is zeroed (enables KV-cache)

    Args:
        model_name: HuggingFace model id.
        device: Device to load model on.
        torch_dtype: Model precision.
        tokenizer_name: HuggingFace id of the tokenizer paired with the model.
            Defaults to "gpt2", the tokenizer used by kuleshov-group/mdlm-owt.

    Returns:
        Tuple of (model, tokenizer, info). ``info.mask_token_id`` is the id
        of the absorbing [MASK] state.
    """
    # Imported lazily so this module can be imported without transformers.
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # The checkpoint's architecture lives on the Hub, hence trust_remote_code.
    model = AutoModelForMaskedLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    ).to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    vocab_size = model.config.vocab_size

    # When the tokenizer already defines a mask token, that id is the absorbing
    # state. When it does not (GPT-2 has none), MDLM appended [MASK] to the end
    # of the vocabulary during training, so it is the last id.
    # NOTE(review): mirrors mask-index selection in the MDLM reference code.
    mask_token_id = getattr(tokenizer, "mask_token_id", None)
    if mask_token_id is None:
        mask_token_id = vocab_size - 1

    n_params = sum(p.numel() for p in model.parameters())
    info = DiffusionModelInfo(
        model_name=model_name,
        diffusion_type="mdlm",
        vocab_size=vocab_size,
        model_length=model.config.model_length,
        hidden_dim=model.config.hidden_dim,
        n_params=n_params,
        mask_token_id=mask_token_id,
    )
    return model, tokenizer, info
def load_udlm(
    model_name: str = "kuleshov-group/udlm-qm9",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
    *,
    tokenizer_name: Optional[str] = None,
) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]:
    """Load a pretrained UDLM model and tokenizer.

    UDLM (Uniform Diffusion Language Model) uses uniform noise:
      - Forward: tokens transition to any random token
      - Reverse: all tokens are re-sampled at every step
      - time_conditioning=True: sigma is used for conditioning

    Args:
        model_name: HuggingFace model id.
        device: Device to load model on.
        torch_dtype: Model precision.
        tokenizer_name: HuggingFace id of the tokenizer. When None it is
            inferred from ``model_name``: "yairschiff/qm9-tokenizer" for QM9
            checkpoints, "bert-base-uncased" for everything else (LM1B
            included).

    Returns:
        Tuple of (model, tokenizer, info). UDLM has no mask token, so
        ``info.mask_token_id`` is None.
    """
    # Imported lazily so this module can be imported without transformers.
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    model = AutoModelForMaskedLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    ).to(device)
    model.eval()

    if tokenizer_name is None:
        # Only QM9 needs a dedicated tokenizer; LM1B and unrecognized
        # checkpoints both use BERT's uncased vocabulary.
        tokenizer_name = (
            "yairschiff/qm9-tokenizer"
            if "qm9" in model_name.lower()
            else "bert-base-uncased"
        )

    # The QM9 tokenizer ships its tokenizer class as Hub remote code
    # (NOTE(review): confirm for the pinned revision). The flag is a no-op for
    # stock tokenizers, and it matches the trust already granted to the model.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)

    n_params = sum(p.numel() for p in model.parameters())
    info = DiffusionModelInfo(
        model_name=model_name,
        diffusion_type="udlm",
        vocab_size=model.config.vocab_size,
        model_length=model.config.model_length,
        hidden_dim=model.config.hidden_dim,
        n_params=n_params,
    )
    return model, tokenizer, info
def load_model(
    model_name: str,
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
):
    """Pick the right loader from the model id and delegate to it.

    The choice is a case-insensitive substring test on ``model_name``:
    "mdlm" is checked first, then "udlm".

    Args:
        model_name: HuggingFace model id.
        device: Device.
        torch_dtype: Precision.

    Returns:
        Tuple of (model, tokenizer, info) from the selected loader.

    Raises:
        ValueError: If the id mentions neither "mdlm" nor "udlm".
    """
    lowered_name = model_name.lower()

    if "mdlm" in lowered_name:
        return load_mdlm(model_name, device, torch_dtype)
    if "udlm" in lowered_name:
        return load_udlm(model_name, device, torch_dtype)

    message = (
        f"Cannot auto-detect model type for '{model_name}'. "
        "Use load_mdlm() or load_udlm() directly."
    )
    raise ValueError(message)