File size: 4,789 Bytes
2d0a056
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Model Loading Utilities for CDD
================================

Provides unified loading for MDLM and UDLM discrete diffusion models
from HuggingFace Hub.

Supported models:
    - kuleshov-group/mdlm-owt: MDLM trained on OpenWebText (110M params)
    - kuleshov-group/udlm-qm9: UDLM trained on QM9 molecules (92M params)
    - kuleshov-group/udlm-lm1b: UDLM trained on LM1B text (139M params)
"""

import torch
from typing import Optional, Tuple
from dataclasses import dataclass


@dataclass
class DiffusionModelInfo:
    """Descriptive record that accompanies a loaded diffusion model.

    Attributes:
        model_name: Identifier the model was loaded under.
        diffusion_type: Family tag, either "mdlm" or "udlm".
        vocab_size: Vocabulary size.
        model_length: Model length value.
        hidden_dim: Hidden dimension.
        n_params: Total parameter count.
        mask_token_id: Id of the mask token; stays ``None`` unless the
            model is an MDLM.
    """

    # Identifier the model was loaded under (e.g. a HuggingFace model id).
    model_name: str
    # Family tag: "mdlm" or "udlm".
    diffusion_type: str
    # Size of the vocabulary.
    vocab_size: int
    # Model length value.
    model_length: int
    # Hidden dimension.
    hidden_dim: int
    # Total number of parameters.
    n_params: int
    # Mask token id; only meaningful for MDLM, otherwise left as None.
    mask_token_id: Optional[int] = None


def load_mdlm(
    model_name: str = "kuleshov-group/mdlm-owt",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]:
    """Fetch a pretrained MDLM checkpoint together with its tokenizer.

    MDLM (Masked Diffusion Language Model) is an absorbing-noise model:
    - Forward: tokens are replaced with [MASK]
    - Reverse: [MASK] tokens are unmasked to clean tokens
    - time_conditioning=False: sigma is zeroed (enables KV-cache)

    The model is moved to ``device`` and switched to eval mode before it
    is returned. The tokenizer is always the "gpt2" tokenizer, loaded
    as-is; this function does not modify it.

    Args:
        model_name: HuggingFace model id.
        device: Device the model is moved to.
        torch_dtype: Precision passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer, info)`` tuple, where ``info`` is a
        ``DiffusionModelInfo`` with ``diffusion_type="mdlm"`` and
        ``mask_token_id`` set to ``vocab_size - 1``.
    """
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # trust_remote_code=True: the model class is provided by code in the
    # Hub repository rather than by the transformers package itself.
    loaded = AutoModelForMaskedLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=torch_dtype
    )
    model = loaded.to(device)
    model.eval()

    # GPT-2 tokenizer; MDLM training added a [MASK] token on top of it.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    cfg = model.config

    # The mask token occupies the final slot of the model vocabulary.
    mask_id = cfg.vocab_size - 1

    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    info = DiffusionModelInfo(
        model_name=model_name,
        diffusion_type="mdlm",
        vocab_size=cfg.vocab_size,
        model_length=cfg.model_length,
        hidden_dim=cfg.hidden_dim,
        n_params=total_params,
        mask_token_id=mask_id,
    )
    return model, tokenizer, info


def load_udlm(
    model_name: str = "kuleshov-group/udlm-qm9",
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.nn.Module, "transformers.PreTrainedTokenizer", DiffusionModelInfo]:
    """Fetch a pretrained UDLM checkpoint together with its tokenizer.

    UDLM (Uniform Diffusion Language Model) is a uniform-noise model:
    - Forward: tokens transition to any random token
    - Reverse: all tokens are re-sampled at every step
    - time_conditioning=True: sigma is used for conditioning

    The model is moved to ``device`` and switched to eval mode before it
    is returned. The tokenizer is chosen from ``model_name``: ids that
    contain "qm9" (case-insensitive) get "yairschiff/qm9-tokenizer";
    every other id, LM1B included, gets "bert-base-uncased".

    Args:
        model_name: HuggingFace model id.
        device: Device the model is moved to.
        torch_dtype: Precision passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer, info)`` tuple, where ``info`` is a
        ``DiffusionModelInfo`` with ``diffusion_type="udlm"`` and no
        mask token id.
    """
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # trust_remote_code=True: the model class is provided by code in the
    # Hub repository rather than by the transformers package itself.
    loaded = AutoModelForMaskedLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=torch_dtype
    )
    model = loaded.to(device)
    model.eval()

    # Only QM9 checkpoints need a dedicated tokenizer; LM1B and any
    # unrecognised id share the BERT uncased tokenizer.
    tokenizer_id = (
        "yairschiff/qm9-tokenizer"
        if "qm9" in model_name.lower()
        else "bert-base-uncased"
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    cfg = model.config
    info = DiffusionModelInfo(
        model_name=model_name,
        diffusion_type="udlm",
        vocab_size=cfg.vocab_size,
        model_length=cfg.model_length,
        hidden_dim=cfg.hidden_dim,
        n_params=total_params,
    )
    return model, tokenizer, info


def load_model(
    model_name: str,
    device: str = "cuda",
    torch_dtype: torch.dtype = torch.float32,
):
    """Choose a loader from the model id and delegate to it.

    The model family is inferred by a case-insensitive substring search
    of ``model_name``: "mdlm" is checked first, then "udlm".

    Args:
        model_name: HuggingFace model id.
        device: Device, forwarded to the selected loader.
        torch_dtype: Precision, forwarded to the selected loader.

    Returns:
        Tuple of (model, tokenizer, info) produced by ``load_mdlm`` or
        ``load_udlm``.

    Raises:
        ValueError: If ``model_name`` contains neither "mdlm" nor "udlm".
    """
    lowered_name = model_name.lower()

    if "mdlm" in lowered_name:
        return load_mdlm(model_name, device, torch_dtype)

    if "udlm" in lowered_name:
        return load_udlm(model_name, device, torch_dtype)

    # Neither family marker is present, so there is nothing to dispatch on.
    raise ValueError(
        f"Cannot auto-detect model type for '{model_name}'. "
        "Use load_mdlm() or load_udlm() directly."
    )