| """ |
| Model Loading Utilities for CDD |
| ================================ |
| |
| Provides unified loading for MDLM and UDLM discrete diffusion models |
| from HuggingFace Hub. |
| |
| Supported models: |
| - kuleshov-group/mdlm-owt: MDLM trained on OpenWebText (110M params) |
| - kuleshov-group/udlm-qm9: UDLM trained on QM9 molecules (92M params) |
| - kuleshov-group/udlm-lm1b: UDLM trained on LM1B text (139M params) |
| """ |
|
|
| import torch |
| from typing import Optional, Tuple |
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class DiffusionModelInfo:
    """Descriptive record that accompanies a loaded discrete diffusion model.

    Fields are positional in the order declared below; only
    ``mask_token_id`` is optional.
    """

    # HuggingFace Hub id (or path) the model was loaded from.
    model_name: str
    # Noise family tag, e.g. "mdlm" (absorbing) or "udlm" (uniform).
    diffusion_type: str
    # Size of the model's output vocabulary (config.vocab_size).
    vocab_size: int
    # Sequence length recorded in the model config (config.model_length).
    model_length: int
    # Hidden width recorded in the model config (config.hidden_dim).
    hidden_dim: int
    # Total number of parameter elements in the model.
    n_params: int
    # Id of the [MASK] token; None for models that have no mask token.
    mask_token_id: Optional[int] = None
|
|
|
|
def load_mdlm(
    model_name: str = "kuleshov-group/mdlm-owt",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
    *,
    tokenizer_name: Optional[str] = None,
) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]:
    """Load a pretrained MDLM model and tokenizer.

    MDLM (Masked Diffusion Language Model) uses absorbing noise:
    - Forward: tokens are replaced with [MASK]
    - Reverse: [MASK] tokens are unmasked to clean tokens
    - time_conditioning=False: sigma is zeroed, so the network output does
      not depend on the timestep (NOTE(review): samplers may rely on this
      to cache outputs between steps — confirm against the consumer).

    Args:
        model_name: HuggingFace model id.
        device: Device to load model on.
        torch_dtype: Model precision.
        tokenizer_name: HuggingFace tokenizer id. ``None`` (default) selects
            ``"gpt2"``, the tokenizer paired with the default
            ``kuleshov-group/mdlm-owt`` checkpoint; override it for MDLM
            checkpoints trained on a different vocabulary.

    Returns:
        Tuple of (model, tokenizer, info). ``info.mask_token_id`` is the last
        id of the model vocabulary (``config.vocab_size - 1``).
    """
    # Imported lazily so importing this module does not require transformers.
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # The checkpoint's model class is hosted alongside the weights on the
    # Hub, hence trust_remote_code=True.
    model = AutoModelForMaskedLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    ).to(device)
    model.eval()  # inference mode (disables dropout)

    tokenizer = AutoTokenizer.from_pretrained(
        "gpt2" if tokenizer_name is None else tokenizer_name
    )

    config = model.config
    # The absorbing [MASK] state is taken to be the last model id
    # (assumed to sit just past the tokenizer's own ids — confirm for
    # checkpoints other than mdlm-owt).
    mask_token_id = config.vocab_size - 1

    info = DiffusionModelInfo(
        model_name=model_name,
        diffusion_type="mdlm",
        vocab_size=config.vocab_size,
        model_length=config.model_length,
        hidden_dim=config.hidden_dim,
        n_params=sum(p.numel() for p in model.parameters()),
        mask_token_id=mask_token_id,
    )

    return model, tokenizer, info
|
|
|
|
def load_udlm(
    model_name: str = "kuleshov-group/udlm-qm9",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
    *,
    tokenizer_name: Optional[str] = None,
) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]:
    """Load a pretrained UDLM model and tokenizer.

    UDLM (Uniform Diffusion Language Model) uses uniform noise:
    - Forward: tokens transition to any random token
    - Reverse: all tokens are re-sampled at every step
    - time_conditioning=True: sigma is used for conditioning

    Args:
        model_name: HuggingFace model id.
        device: Device to load model on.
        torch_dtype: Model precision.
        tokenizer_name: HuggingFace tokenizer id. ``None`` (default) picks
            one from ``model_name``: ``"yairschiff/qm9-tokenizer"`` when the
            id contains "qm9", otherwise ``"bert-base-uncased"``. An explicit
            id is loaded without ``trust_remote_code``.

    Returns:
        Tuple of (model, tokenizer, info). ``info.mask_token_id`` is ``None``
        because uniform-noise models have no mask token.
    """
    # Imported lazily so importing this module does not require transformers.
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # The checkpoint's model class is hosted alongside the weights on the
    # Hub, hence trust_remote_code=True.
    model = AutoModelForMaskedLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    ).to(device)
    model.eval()  # inference mode (disables dropout)

    if tokenizer_name is not None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    elif "qm9" in model_name.lower():
        # The QM9 SMILES tokenizer is a custom class shipped as remote code;
        # AutoTokenizer cannot load it non-interactively unless
        # trust_remote_code=True is passed.
        tokenizer = AutoTokenizer.from_pretrained(
            "yairschiff/qm9-tokenizer", trust_remote_code=True
        )
    else:
        # LM1B checkpoints use BERT's uncased vocabulary; it is also the
        # fallback for any other UDLM checkpoint.
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    config = model.config
    info = DiffusionModelInfo(
        model_name=model_name,
        diffusion_type="udlm",
        vocab_size=config.vocab_size,
        model_length=config.model_length,
        hidden_dim=config.hidden_dim,
        n_params=sum(p.numel() for p in model.parameters()),
    )

    return model, tokenizer, info
|
|
|
|
def load_model(
    model_name: str,
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", "DiffusionModelInfo"]:
    """Auto-detect the diffusion family from ``model_name`` and load it.

    Detection is a case-insensitive substring match on the model id.
    ``"mdlm"`` is checked first and dispatches to ``load_mdlm``.
    Otherwise ``"udlm"`` dispatches to ``load_udlm``.

    Args:
        model_name: HuggingFace model id.
        device: Device.
        torch_dtype: Precision.

    Returns:
        Tuple of (model, tokenizer, info).

    Raises:
        ValueError: If neither "mdlm" nor "udlm" occurs in ``model_name``.
    """
    lowered = model_name.lower()
    if "mdlm" in lowered:
        return load_mdlm(model_name, device, torch_dtype)
    if "udlm" in lowered:
        return load_udlm(model_name, device, torch_dtype)
    raise ValueError(
        f"Cannot auto-detect model type for '{model_name}'. "
        "Use load_mdlm() or load_udlm() directly."
    )
|
|