""" Model Loading Utilities for CDD ================================ Provides unified loading for MDLM and UDLM discrete diffusion models from HuggingFace Hub. Supported models: - kuleshov-group/mdlm-owt: MDLM trained on OpenWebText (110M params) - kuleshov-group/udlm-qm9: UDLM trained on QM9 molecules (92M params) - kuleshov-group/udlm-lm1b: UDLM trained on LM1B text (139M params) """ import torch from typing import Optional, Tuple from dataclasses import dataclass @dataclass class DiffusionModelInfo: """Metadata about a loaded diffusion model.""" model_name: str diffusion_type: str # "mdlm" or "udlm" vocab_size: int model_length: int hidden_dim: int n_params: int mask_token_id: Optional[int] = None # Only for MDLM def load_mdlm( model_name: str = "kuleshov-group/mdlm-owt", device: str = "cuda", torch_dtype: torch.dtype = torch.float32, ) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]: """Load a pretrained MDLM model and tokenizer. MDLM (Masked Diffusion Language Model) uses absorbing noise: - Forward: tokens are replaced with [MASK] - Reverse: [MASK] tokens are unmasked to clean tokens - time_conditioning=False: sigma is zeroed (enables KV-cache) Args: model_name: HuggingFace model id. device: Device to load model on. torch_dtype: Model precision. Returns: Tuple of (model, tokenizer, info). """ from transformers import AutoModelForMaskedLM, AutoTokenizer model = AutoModelForMaskedLM.from_pretrained( model_name, trust_remote_code=True, torch_dtype=torch_dtype, ).to(device) model.eval() # MDLM uses GPT-2 tokenizer with an added [MASK] token tokenizer = AutoTokenizer.from_pretrained("gpt2") # The mask token is the last token in vocab (added during MDLM training) mask_token_id = model.config.vocab_size - 1 n_params = sum(p.numel() for p in model.parameters()) info = DiffusionModelInfo( model_name=model_name, diffusion_type="mdlm", vocab_size=model.config.vocab_size, model_length=model.config.model_length, hidden_dim=model.config.hidden_dim, n_params=n_params, mask_token_id=mask_token_id, ) return model, tokenizer, info def load_udlm( model_name: str = "kuleshov-group/udlm-qm9", device: str = "cuda", torch_dtype: torch.dtype = torch.float32, ) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]: """Load a pretrained UDLM model and tokenizer. UDLM (Uniform Diffusion Language Model) uses uniform noise: - Forward: tokens transition to any random token - Reverse: all tokens are re-sampled at every step - time_conditioning=True: sigma is used for conditioning Args: model_name: HuggingFace model id. device: Device to load model on. torch_dtype: Model precision. Returns: Tuple of (model, tokenizer, info). """ from transformers import AutoModelForMaskedLM, AutoTokenizer model = AutoModelForMaskedLM.from_pretrained( model_name, trust_remote_code=True, torch_dtype=torch_dtype, ).to(device) model.eval() # Determine tokenizer based on model if "qm9" in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained("yairschiff/qm9-tokenizer") elif "lm1b" in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") else: tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") n_params = sum(p.numel() for p in model.parameters()) info = DiffusionModelInfo( model_name=model_name, diffusion_type="udlm", vocab_size=model.config.vocab_size, model_length=model.config.model_length, hidden_dim=model.config.hidden_dim, n_params=n_params, ) return model, tokenizer, info def load_model( model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.float32, ): """Auto-detect and load a discrete diffusion model. Args: model_name: HuggingFace model id. device: Device. torch_dtype: Precision. Returns: Tuple of (model, tokenizer, info). """ if "mdlm" in model_name.lower(): return load_mdlm(model_name, device, torch_dtype) elif "udlm" in model_name.lower(): return load_udlm(model_name, device, torch_dtype) else: raise ValueError( f"Cannot auto-detect model type for '{model_name}'. " "Use load_mdlm() or load_udlm() directly." )