"""
Model Loading Utilities for CDD
===============================

Provides unified loading of MDLM and UDLM discrete diffusion models
from the HuggingFace Hub.

Supported models:
- kuleshov-group/mdlm-owt: MDLM trained on OpenWebText (110M params)
- kuleshov-group/udlm-qm9: UDLM trained on QM9 molecules (92M params)
- kuleshov-group/udlm-lm1b: UDLM trained on LM1B text (139M params)
"""
import torch
from typing import Optional, Tuple
from dataclasses import dataclass
@dataclass
class DiffusionModelInfo:
    """Summary record describing a diffusion model after it has been loaded.

    The loaders in this module fill every field from the model's config and
    parameters and return the record alongside the model and tokenizer.
    """

    # HuggingFace model id the model was loaded from.
    model_name: str
    # Diffusion family: "mdlm" or "udlm".
    diffusion_type: str
    # Taken from ``model.config.vocab_size``.
    vocab_size: int
    # Taken from ``model.config.model_length``.
    model_length: int
    # Taken from ``model.config.hidden_dim``.
    hidden_dim: int
    # Total element count summed over all model parameters.
    n_params: int
    # Id of the [MASK] token; only set for MDLM, left as None otherwise.
    mask_token_id: Optional[int] = None
def load_mdlm(
    model_name: str = "kuleshov-group/mdlm-owt",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]:
    """Fetch a pretrained MDLM checkpoint together with its tokenizer.

    MDLM (Masked Diffusion Language Model) is an absorbing-noise model:
    - Forward process: tokens are replaced with [MASK].
    - Reverse process: [MASK] tokens are unmasked back to clean tokens.
    - time_conditioning=False: sigma is zeroed (enables KV-cache).

    The model is moved to ``device`` and switched to eval mode before it is
    returned. The tokenizer is always the "gpt2" tokenizer, independent of
    ``model_name``.

    Args:
        model_name: HuggingFace model id.
        device: Device to load model on.
        torch_dtype: Model precision.

    Returns:
        Tuple of (model, tokenizer, info), where ``info.mask_token_id`` is
        ``vocab_size - 1``.
    """
    # Function-local import: transformers is only needed once a model is
    # actually being loaded.
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # NOTE: trust_remote_code=True lets transformers run the modeling code
    # published in the Hub repository, so only use repositories you trust.
    network = AutoModelForMaskedLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    )
    network = network.to(device)
    network.eval()

    # MDLM uses the GPT-2 tokenizer with an added [MASK] token.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    config = network.config
    vocab_size = config.vocab_size
    # The mask token is the last entry of the vocabulary (it was added
    # during MDLM training).
    mask_id = vocab_size - 1
    param_total = sum(weight.numel() for weight in network.parameters())

    info = DiffusionModelInfo(
        model_name=model_name,
        diffusion_type="mdlm",
        vocab_size=vocab_size,
        model_length=config.model_length,
        hidden_dim=config.hidden_dim,
        n_params=param_total,
        mask_token_id=mask_id,
    )
    return network, tokenizer, info
def load_udlm(
    model_name: str = "kuleshov-group/udlm-qm9",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]:
    """Fetch a pretrained UDLM checkpoint together with a matching tokenizer.

    UDLM (Uniform Diffusion Language Model) is a uniform-noise model:
    - Forward process: tokens transition to any random token.
    - Reverse process: all tokens are re-sampled at every step.
    - time_conditioning=True: sigma is used for conditioning.

    The model is moved to ``device`` and switched to eval mode before it is
    returned. The tokenizer is picked from ``model_name`` (case-insensitive):
    names containing "qm9" get "yairschiff/qm9-tokenizer"; every other name,
    LM1B checkpoints included, gets "bert-base-uncased".

    Args:
        model_name: HuggingFace model id.
        device: Device to load model on.
        torch_dtype: Model precision.

    Returns:
        Tuple of (model, tokenizer, info); ``info.mask_token_id`` keeps its
        default of None.
    """
    # Function-local import: transformers is only needed once a model is
    # actually being loaded.
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # NOTE: trust_remote_code=True lets transformers run the modeling code
    # published in the Hub repository, so only use repositories you trust.
    network = AutoModelForMaskedLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    )
    network = network.to(device)
    network.eval()

    # QM9 checkpoints have a dedicated tokenizer; LM1B checkpoints and any
    # unrecognised name share the BERT uncased tokenizer.
    lowered_name = model_name.lower()
    tokenizer_id = (
        "yairschiff/qm9-tokenizer" if "qm9" in lowered_name else "bert-base-uncased"
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    config = network.config
    param_total = sum(weight.numel() for weight in network.parameters())

    info = DiffusionModelInfo(
        model_name=model_name,
        diffusion_type="udlm",
        vocab_size=config.vocab_size,
        model_length=config.model_length,
        hidden_dim=config.hidden_dim,
        n_params=param_total,
    )
    return network, tokenizer, info
def load_model(
    model_name: str,
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
):
    """Load a discrete diffusion model, inferring its family from the name.

    The family is detected by a case-insensitive substring match on
    ``model_name``: "mdlm" is checked first and delegates to ``load_mdlm``,
    then "udlm" delegates to ``load_udlm``.

    Args:
        model_name: HuggingFace model id.
        device: Device.
        torch_dtype: Precision.

    Returns:
        Tuple of (model, tokenizer, info), exactly as produced by the
        selected loader.

    Raises:
        ValueError: If ``model_name`` contains neither "mdlm" nor "udlm".
    """
    lowered_name = model_name.lower()

    if "mdlm" in lowered_name:
        return load_mdlm(
            model_name=model_name, device=device, torch_dtype=torch_dtype
        )

    if "udlm" in lowered_name:
        return load_udlm(
            model_name=model_name, device=device, torch_dtype=torch_dtype
        )

    # Neither family marker is present, so the caller has to pick a loader.
    raise ValueError(
        f"Cannot auto-detect model type for '{model_name}'. "
        "Use load_mdlm() or load_udlm() directly."
    )
|