zhihanyang commited on
Commit
7b526cf
·
verified ·
1 Parent(s): 9d4a36a

Upload EsoLM

Browse files
Files changed (4) hide show
  1. config.json +23 -0
  2. config.py +27 -0
  3. model.py +1074 -0
  4. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "sahoo-diffusion/Eso-LM-B-alpha-1",
3
+ "architectures": [
4
+ "EsoLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "config.EsoLMConfig",
8
+ "AutoModelForMaskedLM": "model.EsoLM"
9
+ },
10
+ "cond_dim": 128,
11
+ "dropout": 0.1,
12
+ "hidden_dim": 768,
13
+ "hidden_size": 768,
14
+ "mask_index": 50257,
15
+ "model_length": 1024,
16
+ "model_type": "EsoLM",
17
+ "n_blocks": 12,
18
+ "n_heads": 12,
19
+ "return_dict": false,
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.49.0",
22
+ "vocab_size": 50258
23
+ }
config.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+
3
+
4
class EsoLMConfig(transformers.PretrainedConfig):
    """Hugging Face configuration class for EsoLM.

    Defaults match the released Eso-LM-B checkpoint (GPT-2 vocabulary plus
    one extra mask token, 12-layer/12-head 768-dim backbone).
    """
    model_type = 'EsoLM'

    def __init__(
        self,
        vocab_size: int = 50258,
        mask_index: int = 50257,
        model_length: int = 1024,
        hidden_size: int = 768,
        cond_dim: int = 128,
        n_blocks: int = 12,
        n_heads: int = 12,
        dropout: float = 0.1,
        **kwargs):
        super().__init__(**kwargs)
        # Record every model hyper-parameter on the config instance so that
        # `to_dict()`/serialization picks them up.
        for attr_name, attr_value in (
                ('vocab_size', vocab_size),
                ('mask_index', mask_index),
                ('model_length', model_length),
                ('hidden_size', hidden_size),
                ('cond_dim', cond_dim),
                ('n_blocks', n_blocks),
                ('n_heads', n_heads),
                ('dropout', dropout)):
            setattr(self, attr_name, attr_value)
model.py ADDED
@@ -0,0 +1,1074 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import typing
3
+
4
+ import einops
5
+ from functools import partial
6
+ import huggingface_hub
7
+ import omegaconf
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
12
+ import transformers
13
+ from functools import lru_cache
14
+ from .config import EsoLMConfig
15
+
16
# --- Global PyTorch performance knobs (module-import side effects) ---
# NOTE(review): these mutate process-wide state for every importer of this
# module; confirm that is intended for a hub-distributed model file.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True
import torch._inductor.config as inductor_cfg
# Inductor settings used by the torch.compile call further below.
inductor_cfg.triton.cudagraphs = True
inductor_cfg.coordinate_descent_tuning = True

# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
28
+
29
+
30
+ @lru_cache
31
+ def _causal_mask(b, h, q_idx, kv_idx):
32
+ causal = q_idx >= kv_idx
33
+ return causal
34
+
35
+
36
@lru_cache
def _get_causal_mask(seq_len):
    """Build (and memoize per seq_len) a causal flex-attention BlockMask."""
    return create_block_mask(
        _causal_mask,
        B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len)
41
+
42
+
43
+ @lru_cache
44
+ def _bidirectional_mask(b, h, q_idx, kv_idx):
45
+ bidirectional = q_idx == q_idx
46
+ return bidirectional
47
+
48
+
49
@lru_cache
def _get_bidirectional_mask(seq_len):
    """Build (and memoize per seq_len) an all-true (bidirectional) BlockMask."""
    return create_block_mask(
        _bidirectional_mask,
        B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len)
54
+
55
+
56
@lru_cache
def _mixed_mask(b, h, q_idx, kv_idx, cutoffs):
    """Causal mask, plus full attention for queries past this sample's cutoff.

    NOTE(review): `@lru_cache` keys on argument identity when tensors are
    passed in — presumably harmless for flex-attention's mask tracing, but
    worth confirming it is intentional.
    """
    causal = q_idx >= kv_idx
    # cutoffs holds one boundary per batch element, indexed by b.
    block_identity = q_idx >= cutoffs[b]
    return causal | block_identity
61
+
62
+
63
@lru_cache
def _get_mixed_mask(seq_len, cutoffs):
    """Build (and memoize) the mixed causal/full BlockMask for given cutoffs."""
    return create_block_mask(
        partial(_mixed_mask, cutoffs=cutoffs),
        B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len)
68
+
69
+
70
@lru_cache
def _mixed2_mask(b, h, q_idx, kv_idx, cutoffs):
    """Causal mask, plus full attention within the pre-cutoff prefix.

    Unlike `_mixed_mask`, the bidirectional region here is the block of
    positions strictly before this sample's cutoff (both query and key).
    """
    causal = q_idx >= kv_idx
    block_identity = (q_idx < cutoffs[b]) & (kv_idx < cutoffs[b])
    return causal | block_identity
75
+
76
+
77
@lru_cache
def _get_mixed2_mask(seq_len, cutoffs):
    """Build (and memoize) the prefix-bidirectional BlockMask for given cutoffs."""
    return create_block_mask(
        partial(_mixed2_mask, cutoffs=cutoffs),
        B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len)
82
+
83
+
84
def _block_diff_mask(b, h, q_idx, kv_idx, block_size=1, n=None):
    """
    Copied directly from BD3LM's codebase: https://github.com/kuleshov-group/bd3lms

    Constructs the specialized block diffusion attention mask for training
    composed of three masks:
    - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks
    - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context
    - **Block Causal Mask (M_BC)**: Attention to update x0

    The sequence is a concatenation [xt | x0]: positions < n belong to the
    noised sequence xt, positions >= n to the clean sequence x0.

    Args:
        b, h: Batch and head indices (ignored for mask logic).
        q_idx, kv_idx: Query and Key indices.
        block_size: Defines the block structure.
        n: Length of xt (offset at which x0 starts).

    Returns:
        A boolean attention mask.
    """

    # Indicate whether token belongs to xt or x0
    x0_flag_q = (q_idx >= n)
    x0_flag_kv = (kv_idx >= n)

    # Compute block indices (positions in x0 are re-based by subtracting n
    # so xt and x0 share the same block numbering)
    block_q = torch.where(x0_flag_q == 1,
                          (q_idx - n) // block_size,
                          q_idx // block_size)
    block_kv = torch.where(x0_flag_kv == 1,
                           (kv_idx - n) // block_size,
                           kv_idx // block_size)

    # **1. Block Diagonal Mask (M_BD) **
    block_diagonal = (
        block_q == block_kv) & (x0_flag_q == x0_flag_kv)

    # **2. Offset Block-Causal Mask (M_OBC) **
    # xt queries may attend to strictly-earlier x0 blocks.
    offset_block_causal = ((block_q > block_kv)
                           & (x0_flag_kv == 1)
                           & (x0_flag_q == 0))

    # **3. Block-Causal Mask (M_BC) **
    # x0 queries attend block-causally within x0.
    block_causal = (block_q >= block_kv) & (
        x0_flag_kv == 1) & (x0_flag_q == 1)

    # **4. Combine Masks **
    return block_diagonal | offset_block_causal | block_causal
131
+
132
+
133
@lru_cache
def _get_seq_mask(seq_len):
    """Build (and memoize) the BD3LM mask over the [xt | x0] concatenation."""
    # here, seq_len means the length of zt only; the full attention span is
    # the 2*seq_len concatenation of noised and clean sequences
    return create_block_mask(
        partial(_block_diff_mask, block_size=1, n=seq_len),
        B=None, H=None, Q_LEN=seq_len*2, KV_LEN=seq_len*2)
139
+
140
+
141
def _block_diff_mask_prefix_lm(b, h, q_idx, kv_idx, n, cutoffs):
    """BD3LM mask extended with bidirectional attention over a per-sample
    prefix of the clean (x0) half.

    Positions n..n+cutoffs[b] of the concatenated [xt | x0] sequence attend
    to each other freely, in addition to the base block-diffusion mask.
    """
    block_diff_mask_output = _block_diff_mask(
        b, h, q_idx, kv_idx, block_size=1, n=n)
    # Prefix-LM region: full attention within the first cutoffs[b] clean tokens.
    block_prefix_lm = (
        (n <= q_idx) & (q_idx < n + cutoffs[b])
        & (n <= kv_idx) & (kv_idx < n + cutoffs[b]))
    return block_diff_mask_output | block_prefix_lm
148
+
149
+
150
@lru_cache
def _get_seq_mask_prefix_lm(seq_len, cutoffs):
    """Build (and memoize) the prefix-LM variant of the BD3LM BlockMask."""
    # here, seq_len means the length of zt only
    return create_block_mask(
        partial(_block_diff_mask_prefix_lm, n=seq_len, cutoffs=cutoffs),
        B=None, H=None, Q_LEN=seq_len*2, KV_LEN=seq_len*2)
156
+
157
+
158
# Compile flex_attention once at import time; 'reduce-overhead' enables CUDA
# graphs (see the inductor flags set above).
flex_attention_compiled = torch.compile(flex_attention, dynamic=False, fullgraph=True, mode='reduce-overhead')
# Alternative compilation modes kept for reference/debugging:
# flex_attention_compiled = torch.compile(flex_attention, dynamic=False, fullgraph=True, mode='max-autotune-no-cudagraphs')
# flex_attention_compiled = flex_attention
# flex_attention_compiled = torch.compile(flex_attention, dynamic=True)
162
+
163
+
164
def fused_flex_attention(q, k, v, mask=None):
    """Run the module-level compiled flex-attention kernel with a BlockMask."""
    return flex_attention_compiled(q, k, v, block_mask=mask)
166
+
167
+
168
def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool) -> torch.Tensor:
    """Fused (bias add) -> dropout -> scale -> (residual add).

    Computes ``residual + scale * dropout(x + bias)``; `bias` and
    `residual` are each optional and skipped when None. Dropout is only
    active when `training` is True.
    """
    pre_dropout = x if bias is None else x + bias
    out = scale * F.dropout(pre_dropout, p=prob, training=training)
    if residual is None:
        return out
    return residual + out
183
+
184
+
185
def get_bias_dropout_add_scale(training):
    """Return a closure over `bias_dropout_add_scale` with `training` bound."""
    def _bias_dropout_add(x, bias, scale, residual, prob):
        return bias_dropout_add_scale(
            x, bias, scale, residual, prob, training)
    return _bias_dropout_add
191
+
192
+
193
+ # function overload
194
def modulate(x: torch.Tensor,
             shift: torch.Tensor,
             scale: torch.Tensor) -> torch.Tensor:
    """AdaLN modulation: scale `x` by (1 + scale) then add `shift`."""
    scaled = x * (1 + scale)
    return scaled + shift
198
+
199
+
200
@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
    """TorchScript-compiled variant of bias_dropout_add_scale with training=True."""
    return bias_dropout_add_scale(
        x, bias, scale, residual, prob, True)
209
+
210
+
211
@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
    """TorchScript-compiled variant of bias_dropout_add_scale with training=False
    (dropout disabled)."""
    return bias_dropout_add_scale(
        x, bias, scale, residual, prob, False)
220
+
221
+
222
@torch.jit.script
def modulate_fused(x: torch.Tensor,
                   shift: torch.Tensor,
                   scale: torch.Tensor) -> torch.Tensor:
    """TorchScript-compiled wrapper around `modulate`."""
    return modulate(x, shift, scale)
227
+
228
+
229
class Rotary(torch.nn.Module):
    """Rotary positional embedding (RoPE) with a per-sequence-length cache.

    forward() returns cos/sin tables shaped (1, seq_len, 3, 1, dim), where
    the third axis indexes q/k/v; the v slice is forced to the identity
    rotation so the same tables can be applied to a packed qkv tensor.
    """
    def __init__(self, dim, base=10_000):
        super().__init__()
        # Standard RoPE inverse-frequency spectrum over even channels.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        # Cache keyed on the last-seen sequence length.
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            # Recompute tables only when the sequence length changes.
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # dims are: batch, seq_len, qkv, head, dim
            self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
            self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)
            # This makes the transformation on v an identity.
            self.cos_cached[:,:,2,:,:].fill_(1.)
            self.sin_cached[:,:,2,:,:].fill_(0.)

        return self.cos_cached, self.sin_cached
253
+
254
+
255
def rotate_half(x, interleaved=False):
    """Rotate channel pairs of `x` by 90 degrees (RoPE helper).

    Copied and refactored from FlashAttention. In the default (non-
    interleaved) layout the two halves of the last dim form the pairs;
    in the interleaved layout adjacent even/odd channels do.
    """
    if not interleaved:
        first_half, second_half = x.chunk(2, dim=-1)
        return torch.cat((-second_half, first_half), dim=-1)
    even, odd = x[..., ::2], x[..., 1::2]
    return einops.rearrange(
        torch.stack((-odd, even), dim=-1),
        "... d two -> ... (d two)",
        two=2)
265
+
266
+
267
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    Apply rotary position embedding to `x` in pure PyTorch.

    Copied and refactored from FlashAttention
    x: (batch_size, seq_len, nheads, headdim)
    cos, sin: (seq_len, rotary_dim / 2) or (batch_size, seq_len, rotary_dim / 2)

    Only the first `2 * cos.shape[-1]` channels are rotated; the remainder
    (if any) is passed through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Expand cos/sin across the head axis and duplicate per channel pair.
    pattern = "... d -> ... 1 (2 d)"
    if interleaved:
        pattern = "... d -> ... 1 (d 2)"
    cos = einops.repeat(cos, pattern)
    sin = einops.repeat(sin, pattern)
    return torch.cat(
        [x[..., :ro_dim] * cos
         + rotate_half(x[..., :ro_dim],
                       interleaved) * sin, x[..., ro_dim:]],
        dim=-1)
285
+
286
+
287
+ def _split_rotary(rotary_cos_sin, dtype):
288
+ cos, sin = rotary_cos_sin
289
+ cos = cos.to(dtype)
290
+ sin = sin.to(dtype)
291
+ cos = cos[0,:,0,0,:cos.shape[-1]//2]
292
+ sin = sin[0,:,0,0,:sin.shape[-1]//2]
293
+ return cos, sin
294
+
295
+
296
def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
    """Split a packed (B, S, 3, H, D) qkv tensor and rotate q and k.

    RoPE is applied in fp32 semantics (autocast disabled); v is returned
    unrotated. Assumes a batch-size-1 rotary table (see the *_batch variant).
    """
    with torch.amp.autocast('cuda', enabled=False):
        cos, sin = _split_rotary(rotary_cos_sin, dtype=qkv.dtype)
        q, k, v = qkv.chunk(3, dim=2)
        q = apply_rotary_emb_torch(
            q.squeeze(dim=2), cos, sin)
        k = apply_rotary_emb_torch(
            k.squeeze(dim=2), cos, sin)
        v = v.squeeze(dim=2)
    return q, k, v
306
+
307
+
308
def split_and_apply_rotary_pos_emb_batch(qkv, rotary_cos_sin):
    """Batched variant of `split_and_apply_rotary_pos_emb`.

    Here the rotary tables carry a real batch dimension (per-sample position
    permutations), so the batch axis is kept when slicing instead of
    indexing element 0.
    """
    with torch.amp.autocast('cuda', enabled=False):
        cos, sin = rotary_cos_sin
        cos = cos.to(qkv.dtype)
        sin = sin.to(qkv.dtype)
        cos = cos[:,:,0,0,:cos.shape[-1]//2]  # difference is here
        sin = sin[:,:,0,0,:sin.shape[-1]//2]  # difference is here
        q, k, v = qkv.chunk(3, dim=2)
        q = apply_rotary_emb_torch(
            q.squeeze(dim=2), cos, sin)
        k = apply_rotary_emb_torch(
            k.squeeze(dim=2), cos, sin)
        v = v.squeeze(dim=2)
    return q, k, v
322
+
323
+
324
def flex_attention_multi_headed(q, k, v, mask):
    """Run masked flex attention on (B, S, H, D) inputs; return (B, S, H*D).

    flex_attention expects head-major (B, H, S, D) layout, hence the
    transposes on the way in and out.
    """
    q = q.transpose(1, 2).contiguous()
    k = k.transpose(1, 2).contiguous()
    v = v.transpose(1, 2).contiguous()
    attention_output = fused_flex_attention(q, k, v, mask=mask)
    attention_output = attention_output.transpose(1, 2).contiguous()
    return einops.rearrange(attention_output, 'b s h d -> b s (h d)')
331
+
332
+ #################################################################################
333
+ # Layers #
334
+ #################################################################################
335
class LayerNorm(nn.Module):
    """Bias-free LayerNorm whose normalization always runs in float32.

    Only a learnable per-channel weight is kept (no bias); the autocast
    context is disabled so mixed-precision callers still normalize in fp32.
    """

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones([dim]))
        self.dim = dim

    def forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            normed = F.layer_norm(x.float(), [self.dim])
        return normed * self.weight[None, None, :]
344
+
345
+
346
def residual_linear(x, W, x_skip, residual_scale):
    """Compute ``x_skip + residual_scale * (x @ W.T)`` via a fused addmm.

    `x` and `x_skip` may have arbitrary leading dims; they are flattened to
    2-D for the matmul and the result is reshaped back.
    """
    dim_out, dim_in = W.shape
    flat_result = torch.addmm(
        x_skip.view(-1, dim_out),
        x.view(-1, dim_in),
        W.T,
        alpha=residual_scale)
    return flat_result.view(*x.shape[:-1], dim_out)
354
+
355
+
356
+ #################################################################################
357
+ # Embedding Layers for Timesteps and Class Labels #
358
+ #################################################################################
359
class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps via sinusoidal features + SiLU MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True))
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return (N, dim) sinusoidal embeddings of the 1-D tensor `t`.

        Frequencies decay geometrically from 1 to 1/max_period; columns are
        laid out [cos | sin], zero-padded by one column when `dim` is odd.
        (From https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py)
        """
        half = dim // 2
        freqs = torch.exp(
            - math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
            / half)
        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        frequency_features = self.timestep_embedding(
            t, self.frequency_embedding_size)
        return self.mlp(frequency_features)
399
+
400
+
401
class LabelEmbedder(nn.Module):
    """Embeds class labels into vector representations.

    One extra table row is reserved so index `num_classes` can serve as the
    dropped/null label for classifier-free guidance.
    """

    def __init__(self, num_classes, cond_size):
        super().__init__()
        # num_classes + 1 rows: the last one is the CFG null label.
        self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
        self.num_classes = num_classes
        # TODO think of initializing with 0.02 std deviation like in original DiT paper

    def forward(self, labels):
        return self.embedding_table(labels)
416
+
417
+
418
+ #################################################################################
419
+ # Core Model #
420
+ #################################################################################
421
+
422
class DDiTBlockCausal(nn.Module):
    """Causal transformer block (pre-LN self-attention + GELU MLP) with an
    optional single-token KV-cache path for autoregressive decoding.

    Bug fix vs. the original: `forward` called
    ``self._attention_with_kv_cache(qkv.detach())`` without the required
    `rotary_cos_sin` argument, so the ``kv_cache=True`` path always raised
    TypeError. The rotary tables are now forwarded.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads

        self.dim = dim
        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * dim, dim, bias=True))
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = dropout

        # KV cache, populated only via forward(..., kv_cache=True).
        self.past_k = None
        self.past_v = None

    def _get_bias_dropout_scale(self):
        # Pick the TorchScript variant matching train/eval mode.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def reset_kv_cache(self):
        """Drop cached keys/values (call before starting a new generation)."""
        self.past_k = None
        self.past_v = None

    def _process_and_update_kv(self, k, v):
        """Append this step's k/v to the cache; return the full k/v so far."""
        if (self.past_k is not None
            and self.past_v is not None):
            k = torch.cat([self.past_k, k], dim=1)
            v = torch.cat([self.past_v, v], dim=1)
        self.past_k = k
        self.past_v = v
        return k, v

    @torch.no_grad()
    def _attention_with_kv_cache(self, qkv, rotary_cos_sin):
        """Attention for exactly one new token using the KV cache.

        qkv: [bs, 1, 3, h, d]. The new query takes the last rotary position;
        keys/values span the whole cached sequence.
        """
        assert qkv.shape[1] == 1
        q, k, v = qkv.chunk(3, dim=2)
        k, v = self._process_and_update_kv(k=k, v=v)
        with torch.amp.autocast('cuda', enabled=False):
            cos, sin = _split_rotary(rotary_cos_sin, q.dtype)
            q = apply_rotary_emb_torch(
                q.squeeze(dim=2), cos[-1:, :], sin[-1:, :])
            k = apply_rotary_emb_torch(k.squeeze(dim=2), cos, sin)
            v = v.squeeze(dim=2)
        scale = q.shape[-1] ** 0.5
        # swap seq_len and num_heads
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # No explicit mask needed: with a single query attending to the full
        # (past) key sequence, the attention is causal by construction.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        x = torch.matmul(attn_weights, v).transpose(1, 2)
        return x.view(x.shape[0], 1, self.dim)

    def forward(self, x, rotary_cos_sin, kv_cache=False, **kwargs):
        """Pre-LN causal attention + MLP with fused dropout/residual adds."""
        del kwargs
        bias_dropout_scale_fn = self._get_bias_dropout_scale()
        x_skip = x
        x = self.norm1(x)
        qkv = einops.rearrange(
            self.attn_qkv(x),
            'b s (three h d) -> b s three h d',
            three=3,
            h=self.n_heads)

        if kv_cache:
            # BUG FIX: rotary_cos_sin is now forwarded (previously omitted,
            # which raised TypeError on this path).
            x = self._attention_with_kv_cache(qkv.detach(), rotary_cos_sin)
        else:
            q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
            # recreate the mask every time (cheap) to fit different input
            # lengths; different input lengths happen during generation
            attn_mask = _get_causal_mask(x.shape[1])
            x = flex_attention_multi_headed(q, k, v, attn_mask)

        scale = torch.ones(1, device=x.device, dtype=x.dtype)
        x = bias_dropout_scale_fn(
            self.attn_out(x), None, scale, x_skip, self.dropout)

        # mlp operation
        x = bias_dropout_scale_fn(
            self.mlp(self.norm2(x)), None, scale, x, self.dropout)
        return x
512
+
513
+
514
class DDiTBlock(nn.Module):
    """Bidirectional DiT block with optional adaLN conditioning and a KV-cache
    decode path used during EsoLM generation.

    When `adaLN` is True, a conditioning vector `c` produces per-sample
    shift/scale/gate terms for both the attention and MLP sub-layers.
    """
    def __init__(self, dim, n_heads, adaLN,
                 cond_dim=None, mlp_ratio=4,
                 dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.adaLN = adaLN

        self.norm1 = LayerNorm(dim)
        self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * dim, dim, bias=True))
        self.dropout2 = nn.Dropout(dropout)
        self.dropout = dropout

        if self.adaLN:
            # Zero-init so the block starts as identity modulation (DiT trick).
            self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim)
            self.adaLN_modulation.weight.data.zero_()
            self.adaLN_modulation.bias.data.zero_()

        # KV cache for incremental generation.
        self.past_k = None
        self.past_v = None
        # Additive mask value standing in for -inf in attention scores.
        self.neg_infinity = -1000000.0

    def _get_bias_dropout_scale(self):
        # Pick the TorchScript variant matching train/eval mode.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def reset_kv_cache(self):
        """Drop cached keys/values (call before starting a new generation)."""
        self.past_k = None
        self.past_v = None

    def _process_and_update_kv(self, k, v, num_clean):
        """Merge new k/v with the cache; cache only the first `num_clean`
        positions (tokens finalized in the previous iteration)."""
        if num_clean == 0:
            # no caching if all we see is mask tokens
            return k, v
        else:
            if (self.past_k is None
                and self.past_v is None):
                self.past_k = k[:, :num_clean, :, :]
                self.past_v = v[:, :num_clean, :, :]
                return k, v
            else:
                k_so_far = torch.cat([self.past_k, k], dim=1)
                v_so_far = torch.cat([self.past_v, v], dim=1)
                # only update the kv cache with kv values from
                # clean tokens generated during the previous
                # iteration
                self.past_k = torch.cat(
                    [self.past_k, k[:, :num_clean, :, :]], dim=1)
                self.past_v = torch.cat(
                    [self.past_v, v[:, :num_clean, :, :]], dim=1)
                return k_so_far, v_so_far

    @torch.no_grad()
    def _attention_with_kv_cache(self, qkv, rotary_cos_sin,
                                 num_clean, num_clean_and_mask):
        """Cached attention over [previously generated | new clean | masked]
        positions; queries cover only the trailing `num_clean_and_mask`."""
        # num_clean: num gen last
        # num_clean_and_mask: num gen last + num to gen
        assert qkv.shape[1] == num_clean_and_mask
        # qkv shape:
        # [bs, num gen last + num to gen, 3, h, d]
        q, k, v = qkv.chunk(3, dim=2)
        q = q.squeeze(dim=2)
        k = k.squeeze(dim=2)
        v = v.squeeze(dim=2)
        k, v = self._process_and_update_kv(
            k=k, v=v, num_clean=num_clean)
        # new kv shape:
        # [bs,
        #  num gen before last + num gen last + num to gen,
        #  h, d]
        with torch.amp.autocast('cuda', enabled=False):
            cos, sin = rotary_cos_sin
            cos = cos.to(qkv.dtype)
            sin = sin.to(qkv.dtype)
            cos = cos[:,:,0,0,:cos.shape[-1]//2]
            sin = sin[:,:,0,0,:sin.shape[-1]//2]
            # Queries use only the trailing positions of the rotary tables.
            cos_part = cos[:, -num_clean_and_mask:]
            sin_part = sin[:, -num_clean_and_mask:]
            q = apply_rotary_emb_torch(q, cos_part, sin_part)
            k = apply_rotary_emb_torch(k, cos, sin)
        scale = q.shape[-1] ** 0.5
        # shapes after transpose:
        # q: [bs, h, num gen last + num to gen, d]
        # k: [bs, h, num gen before last + num gen last + num to gen, d]
        # v: [bs, h, num gen before last + num gen last + num to gen, d]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # attn_scores shape:
        # [bs, h,
        #  num gen last + num to gen,
        #  num gen before last + num gen last + num to gen]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        ones = torch.ones(
            num_clean_and_mask, num_clean_and_mask).to(qkv.device)
        # A contains very large negative values above the diagonal
        # - q attends to all v values over "num gen before last"
        # - q attends causally to v values within "num gen last
        #   + num to gen"
        A = self.neg_infinity * torch.triu(ones, diagonal=1)
        A = A.view(1, 1, num_clean_and_mask, num_clean_and_mask)
        attn_scores[:, :, :, -num_clean_and_mask:] += A
        attn_weights = F.softmax(attn_scores, dim=-1)
        # matmul shape: [bs, h, num gen last + num to gen, d]
        # shape after transpose: [bs, num gen last + num to gen, h, d]
        attn_output = torch.matmul(attn_weights, v).transpose(1, 2)
        return einops.rearrange(attn_output, 'b s h d -> b s (h d)')

    def forward(self, x, rotary_cos_sin, c=None, attn_mask=None,
                kv_cache=False, num_clean=None, num_clean_and_mask=None):
        """Pre-LN attention + MLP; adaLN-modulated when `c` is provided."""
        bias_dropout_scale_fn = self._get_bias_dropout_scale()

        x_skip = x
        x = self.norm1(x)
        if self.adaLN:
            # self.adaLN_modulation(c): (128, 1536)
            # self.adaLN_modulation(c)[:, None]: (128, 1, 1536)
            # "" .chunk(6, dim=2) returns 6 tuples of shapes (128, 1, 256)
            (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
             gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
            x = modulate_fused(x, shift_msa, scale_msa)

        qkv = einops.rearrange(
            self.attn_qkv(x),
            'b s (three h d) -> b s three h d',
            three=3,
            h=self.n_heads).contiguous()
        if kv_cache:
            x = self._attention_with_kv_cache(
                qkv.detach(), rotary_cos_sin,
                num_clean=num_clean, num_clean_and_mask=num_clean_and_mask)
        else:
            # A batched rotary table means per-sample position permutations.
            if rotary_cos_sin[0].shape[0] > 1:
                q, k, v = split_and_apply_rotary_pos_emb_batch(qkv, rotary_cos_sin)
            else:
                q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
            x = flex_attention_multi_headed(q, k, v, attn_mask)

        if self.adaLN:
            x = bias_dropout_scale_fn(self.attn_out(x),
                                      None,
                                      gate_msa,
                                      x_skip,
                                      self.dropout)
            x = bias_dropout_scale_fn(
                self.mlp(modulate_fused(
                    self.norm2(x), shift_mlp, scale_mlp)),
                None, gate_mlp, x, self.dropout)
        else:
            scale = torch.ones(1, device=x.device, dtype=x.dtype)
            x = bias_dropout_scale_fn(
                self.attn_out(x), None, scale, x_skip, self.dropout)
            x = bias_dropout_scale_fn(
                self.mlp(self.norm2(x)), None, scale, x, self.dropout)
        return x
680
+
681
+
682
class EmbeddingLayer(nn.Module):
    """Token embedding accepting hard ids (B, L) or logits (B, L, V).

    For logits, returns the softmax-weighted expectation over embedding
    rows (a "soft" lookup), computed in float32 and cast back.
    """

    def __init__(self, dim, vocab_dim):
        super().__init__()
        self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
        torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

    def forward(self, x):
        if x.ndim == 2:
            # Hard token ids: plain table lookup.
            return self.embedding[x]
        assert x.ndim == 3
        # Soft lookup: expectation of embedding rows under softmax(x).
        token_probs = torch.nn.functional.softmax(x, dim=-1).float()
        return torch.einsum(
            "blv,ve->ble", token_probs, self.embedding.float()).to(x.dtype)
696
+
697
+
698
class DDiTFinalLayer(nn.Module):
    """Output head: LayerNorm -> (optional adaLN shift/scale) -> Linear.

    The output linear and the adaLN modulation are zero-initialized so the
    head starts out emitting zeros (standard DiT initialization).
    """

    def __init__(self, hidden_size, out_channels, cond_dim,
                 adaLN):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.linear.weight.data.zero_()
        self.linear.bias.data.zero_()
        self.adaLN = adaLN
        if self.adaLN:
            self.adaLN_modulation = nn.Linear(cond_dim,
                                              2 * hidden_size,
                                              bias=True)
            self.adaLN_modulation.weight.data.zero_()
            self.adaLN_modulation.bias.data.zero_()

    def forward(self, x, c):
        x = self.norm_final(x)
        if self.adaLN:
            # Condition on c via a per-sample shift and scale.
            shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
            x = modulate_fused(x, shift, scale)
        return self.linear(x)
722
+
723
+
724
class DiT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    """Diffusion Transformer backbone.

    Builds either a causal stack (DDiTBlockCausal, no conditioning) or a
    bidirectional adaLN-conditioned stack (DDiTBlock + TimestepEmbedder),
    selected by `config.algo.causal_attention`.

    Fix vs. the original: `type(config) == dict` replaced with
    `isinstance(config, dict)` (idiomatic, and also accepts dict subclasses).
    """
    def __init__(self, config, vocab_size: int):
        super().__init__()
        # Accept a plain (or subclassed) dict and promote it to an OmegaConf
        # config so attribute access (config.algo.*, config.model.*) works.
        if isinstance(config, dict):
            config = omegaconf.OmegaConf.create(config)
        self.causal = config.algo.causal_attention
        # adaLN conditioning is used exactly when attention is bidirectional.
        self.adaLN = not self.causal
        self.config = config
        self.vocab_size = vocab_size
        dim = config.model.hidden_size
        cond_dim = config.model.cond_dim
        self.vocab_embed = EmbeddingLayer(dim, vocab_size)
        if not self.causal:
            # Maps diffusion noise level sigma to the conditioning vector.
            self.sigma_map = TimestepEmbedder(cond_dim)
        self.rotary_dim = dim // config.model.n_heads
        self.rotary_emb = Rotary(self.rotary_dim)

        blocks = []
        for _ in range(config.model.n_blocks):
            if self.causal:
                block = DDiTBlockCausal(
                    dim=dim,
                    n_heads=config.model.n_heads,
                    dropout=config.model.dropout)
            else:
                block = DDiTBlock(
                    dim=dim,
                    n_heads=config.model.n_heads,
                    cond_dim=cond_dim,
                    adaLN=self.adaLN,
                    dropout=config.model.dropout)
            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)

        self.output_layer = DDiTFinalLayer(
            hidden_size=dim,
            out_channels=vocab_size,
            cond_dim=cond_dim,
            adaLN=self.adaLN)
        self.scale_by_sigma = config.model.scale_by_sigma

    def _get_bias_dropout_scale(self):
        # Pick the TorchScript variant matching train/eval mode.
        if self.training:
            return bias_dropout_add_scale_fused_train
        else:
            return bias_dropout_add_scale_fused_inference

    def reset_kv_cache(self):
        """Clear the KV cache of every block (start of a new generation)."""
        for block in self.blocks:
            block.reset_kv_cache()

    def forward(self, x, sigma, x0=None, kv_cache=False):
        """Return per-token logits for token ids (or logits) `x`.

        sigma conditions the adaLN path; it is ignored for causal stacks.
        With kv_cache=True only the last position is processed (rotary
        tables are computed on the full sequence first).
        """
        assert x0 is None
        x = self.vocab_embed(x)
        if self.causal:
            t_cond = None
        else:
            t_cond = F.silu(self.sigma_map(sigma))

        # Rotary tables are built for the full length even on the cached
        # path, so the new token gets the correct absolute position.
        rotary_cos_sin = self.rotary_emb(x)
        if kv_cache:
            x = x[:, -1:, :]
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            for i in range(len(self.blocks)):
                x = self.blocks[i](
                    x, rotary_cos_sin, c=t_cond, kv_cache=kv_cache)
            x = self.output_layer(x, c=t_cond)
        return x
792
+
793
+
794
+ def _get_reverse_indices(indices):
795
+ """
796
+ indices: LongTensor of shape [B, N] representing permutations
797
+ returns: LongTensor of shape [B, N] representing the inverse permutations
798
+ """
799
+ B, N = indices.shape
800
+ reverse_indices = torch.empty_like(indices)
801
+ arange = torch.arange(N, device=indices.device).unsqueeze(0).expand(B, -1)
802
+ reverse_indices.scatter_(1, indices, arange)
803
+ return reverse_indices
804
+
805
+
806
class EsoLMDiT(DiT):
    """DiT variant for EsoLM: a masked-diffusion phase plus a gap-filling
    sequential (AR) phase over a shared backbone.

    Sequences are re-ordered so clean tokens come first and masked tokens
    last; attention masks are then built over that sorted layout.
    """

    def __init__(self, config, vocab_size: int, mask_index: int):
        super().__init__(config, vocab_size)
        # sequential not causal
        # this also makes sure that
        # - sigma_map was created
        # - DDiTBlock was used instead of DDiTBlockCausal
        assert not self.causal and self.adaLN
        # Token id used for masked positions.
        self.mask_index = mask_index

        # Per-phase controls: whether to shuffle token order and which
        # attention-mask family to use ('causal'/'bidirectional'/'mixed'/...).
        self.diffusion_shuffle = config.algo.diffusion_shuffle
        self.diffusion_attn_mode = config.algo.diffusion_attn_mode
        self.sequential_shuffle = config.algo.sequential_shuffle
        self.sequential_attn_mode = config.algo.sequential_attn_mode

        # Lazily-built, cached attention masks (fixed seq_len assumed).
        self.mdlm_mask = None
        self.seq_mask = None

    def _sort_indices(
        self, indices, shuffle, keep_masks_unshuffled=False):
        """Sort each row so clean tokens precede masked tokens.

        Returns the sorted ids and the permutation (`sort_idx`) that
        produced them; `_get_reverse_indices(sort_idx)` undoes the sort.
        """
        masked = (indices == self.mask_index)
        if shuffle:
            # Random offsets in [0, 0.9) randomize order within each group.
            offsets = torch.rand(
                indices.shape).to(indices.device) * 0.9
            if keep_masks_unshuffled:
                # induce left-to-right order within masked tokens
                # only for sequential part
                # NOTE(review): linspace is assigned over ALL masked
                # positions in the batch (row-major), which is still
                # monotonically increasing within each row.
                offsets[masked] = torch.linspace(
                    0, 1, torch.sum(masked)).to(indices.device)
        else:
            # Deterministic offsets preserve left-to-right order per group.
            offsets = torch.linspace(
                0, 0.9, indices.shape[1]).to(indices.device)
        # masked contributes +1, so ascending sort puts CLEAN tokens first
        # (left) and masked tokens last (right).
        sort_idx = (masked + offsets).argsort(descending=False)
        indices = torch.gather(indices, dim=1, index=sort_idx)
        return indices, sort_idx

    def _sort_rotary_cos_sin(self, rotary_cos_sin, sort_idx):
        """Permute rotary cos/sin tables to match a sorted token order."""
        # example cos shape: (1, 128, 3, 1, 32)
        # 128 for seq_len, 3 for qkv, 32 for head dim
        cos, sin = rotary_cos_sin
        bs = sort_idx.shape[0]
        # Broadcast the single table to the batch, then gather per-row.
        cos = cos.expand(bs, -1, -1, -1, -1)
        sin = sin.expand(bs, -1, -1, -1, -1)
        cos = torch.gather(
            cos, dim=1,
            index=sort_idx[:, :, None, None, None].expand(
                -1, -1, 3, -1, self.rotary_dim)).contiguous()
        sin = torch.gather(
            sin, dim=1,
            index=sort_idx[:, :, None, None, None].expand(
                -1, -1, 3, -1, self.rotary_dim)).contiguous()
        return cos, sin

    def _get_attention_mask(self, seq_len, attn_mode=None,
                            cutoffs=None):
        """Build (or fetch the cached) attention mask for `attn_mode`.

        `cutoffs` gives, per row, the number of clean tokens (the prefix
        length after sorting). Returns None for unrecognized modes.
        """
        if attn_mode == 'causal':
            # Cached: assumes seq_len never changes across calls.
            if self.mdlm_mask is None:
                self.mdlm_mask = _get_causal_mask(seq_len)
            return self.mdlm_mask
        elif attn_mode == 'bidirectional':
            if self.mdlm_mask is None:
                self.mdlm_mask = _get_bidirectional_mask(seq_len)
            return self.mdlm_mask
        elif attn_mode == 'mixed':
            # causal over clean tokens
            # bidirectional over masked tokens
            return _get_mixed_mask(seq_len=seq_len,
                                   cutoffs=cutoffs)
        elif attn_mode == 'mixed2':
            # bidirectional over clean tokens
            # causal over masked tokens
            return _get_mixed2_mask(seq_len=seq_len,
                                    cutoffs=cutoffs)

    def _diffusion_features(self, zt, sort_idx=None,
                            attn_mode=None, cutoffs=None):
        """Prepare embeddings, rotary tables, and the attention mask for
        the masked-diffusion phase.

        NOTE(review): after `_sort_indices`, clean (unmasked) tokens sit on
        the LEFT and masked tokens on the RIGHT — the original comment here
        claimed the opposite; `cutoffs` counts the clean prefix.
        """
        if cutoffs is None:
            cutoffs = torch.sum(zt != self.mask_index, dim=1)
        if attn_mode is None:
            attn_mode = self.diffusion_attn_mode
        if sort_idx is None:
            zt, sort_idx = self._sort_indices(
                zt, self.diffusion_shuffle)
        # When sort_idx is supplied (sampling), zt must already be sorted.
        x = self.vocab_embed(zt)
        rotary_cos_sin = self.rotary_emb(x)
        rotary_cos_sin = self._sort_rotary_cos_sin(
            rotary_cos_sin, sort_idx)
        attention_mask = self._get_attention_mask(
            seq_len=zt.shape[1],
            attn_mode=attn_mode,
            cutoffs=cutoffs)
        return {'x': x,
                'rotary': rotary_cos_sin,
                'attention': attention_mask,
                'sorted_indices': sort_idx}

    def _sequential_features(self, zt, x0):
        """Prepare features for the sequential (gap-filling AR) phase.

        Concatenates the sorted noisy sequence `zt` with the identically
        sorted clean targets `x0` (BD3LM-style) so masked positions can
        attend to their ground-truth context during training.
        """
        # gap-filling AR with trick from BD3LM
        # NOTE(review): sorting places clean tokens left, masked right
        # (masked order kept left-to-right via keep_masks_unshuffled).
        seq_len = zt.shape[1]
        zt, sort_idx = self._sort_indices(
            zt, self.sequential_shuffle,
            keep_masks_unshuffled=True)
        # Apply the same permutation to the clean targets.
        x0 = torch.gather(x0, dim=1, index=sort_idx)
        zt_and_x0 = torch.cat([zt, x0], dim=1)
        cutoffs = torch.sum(zt != self.mask_index, dim=1)
        x = self.vocab_embed(zt_and_x0)
        # Rotary tables are built for one seq_len, permuted, then duplicated
        # so the x0 half shares the zt half's positions.
        rotary_cos_sin = self.rotary_emb(x[:, :seq_len])
        rotary_cos_sin = self._sort_rotary_cos_sin(
            rotary_cos_sin, sort_idx)
        cos, sin = rotary_cos_sin
        cos = torch.cat([cos, cos], dim=1)
        sin = torch.cat([sin, sin], dim=1)
        rotary_cos_sin = (cos, sin)

        if self.sequential_attn_mode == 'causal':
            # Cached: assumes seq_len never changes across calls.
            if self.seq_mask is None:
                self.seq_mask = _get_seq_mask(seq_len)
            return {'x': x,
                    'rotary': rotary_cos_sin,
                    'attention': self.seq_mask,
                    'sorted_indices': sort_idx}
        elif self.sequential_attn_mode == 'mixed':
            # Prefix-LM mask depends on per-row cutoffs, so no caching.
            return {'x': x,
                    'rotary': rotary_cos_sin,
                    'attention': _get_seq_mask_prefix_lm(
                        seq_len, cutoffs=cutoffs),
                    'sorted_indices': sort_idx}

    def forward(self, zt, sigma, x0=None):
        """Training forward pass.

        Runs the diffusion phase when `x0` is None, otherwise the
        sequential phase. Output logits are un-sorted back to the
        caller's original token order.
        """
        diffusion_mode = x0 is None
        seq_len = zt.shape[1]

        if diffusion_mode:
            features = self._diffusion_features(zt)
        else:
            features = self._sequential_features(zt, x0)
        x = features['x']
        t_cond = F.silu(self.sigma_map(sigma))
        with torch.amp.autocast('cuda', enabled=False):
            for i in range(len(self.blocks)):
                x = self.blocks[i](x, features['rotary'], c=t_cond,
                                   attn_mask=features['attention'])
            x = self.output_layer(x, c=t_cond)

        if not diffusion_mode:
            # Sequential mode processed [zt, x0]; keep only the zt half.
            x = x[:, :seq_len]
        # Undo the clean-first sort so logits align with the input order.
        # NOTE(review): indentation reconstructed — the un-sort applies in
        # BOTH modes, since both feature builders permute the sequence.
        sort_idx_reversed = _get_reverse_indices(features['sorted_indices'])
        x = torch.gather(
            x, dim=1,
            index=sort_idx_reversed[:, :, None].expand(
                -1, -1, self.vocab_size))
        return x

    @torch.no_grad()
    def forward_sample(self, zt, sort_idx, attn_mode=None,
                       cutoffs=None, kv_cache=False,
                       last_k_start=None,
                       curr_k_start=None,
                       curr_k_end=None):
        """
        zt is expected to be sorted as per sort_idx.

        When kv_cache is true:
        - zt will have shape (num_samples, model.length); we need its shape to generate
          all the rotary embeddings because any of them can be selected by
          the random ordering
        - sort_idx will have shape
          (num_samples, model.length) for the same reason
        - last_k_start (starting index)
        - curr_k_start
        - curr_k_end (ending index)
        - use these to select features['x'] to pass into the blocks

        Within self._diffusion_features, zt will be used
        to generate the full rotary embeddings, and sort_idx
        will index the embedded zt into shape
        (num_samples, num_tokens_generated_last_time (non-mask) + num_tokens_to_gen (mask), hidden)

        We want to append the kv values for num_tokens_generated_last_time to the old kv cache
        and not build up kv values for num_tokens_to_gen (because they are masks)
        """
        assert attn_mode is not None
        ones = torch.ones(zt.shape[0], device=zt.device)
        if cutoffs is not None:
            # Broadcast a scalar cutoff to one value per sample.
            # NOTE(review): `ones` is float, so integer cutoffs become a
            # float tensor here — downstream mask builders must accept that.
            cutoffs = cutoffs * ones
            assert cutoffs.ndim == 1
        features = self._diffusion_features(
            zt=zt,
            sort_idx=sort_idx,
            attn_mode=attn_mode,
            cutoffs=cutoffs)
        # Sampling uses sigma = 0 (fully denoised conditioning).
        zeros = torch.zeros(zt.shape[0], device=zt.device)
        t_cond = F.silu(self.sigma_map(zeros))

        x = features['x']
        rotary = features['rotary']
        if kv_cache:
            # expect x to be sorted
            # Window: tokens generated last step (to be cached) plus the
            # masked tokens to generate now.
            x = x[:, last_k_start:curr_k_end, :]
            # rotary is already sorted here
            # looking ahead
            cos, sin = rotary
            rotary = (cos[:, :curr_k_end], sin[:, :curr_k_end])
            num_clean = curr_k_start - last_k_start
            num_clean_and_mask = curr_k_end - last_k_start
        else:
            num_clean = None
            num_clean_and_mask = None

        with torch.amp.autocast('cuda', enabled=False):
            for i in range(len(self.blocks)):
                x = self.blocks[i](
                    x, rotary, c=t_cond,
                    attn_mask=features['attention'],
                    kv_cache=kv_cache,
                    num_clean=num_clean,
                    num_clean_and_mask=num_clean_and_mask)
            x = self.output_layer(x, c=t_cond)

        if kv_cache:
            # Drop logits over the already-generated (clean) prefix; only
            # the masked positions' logits are needed.
            x = x[:, num_clean:, :]
        return x
1033
+
1034
+
1035
class EsoLMHFDiT(nn.Module):
    """Backbone transformer for the Hugging Face EsoLM wrapper.

    Mirrors the training-time DiT stack but reads every hyper-parameter
    from an EsoLMConfig instead of the training config tree.
    """

    def __init__(self, config):
        super().__init__()
        self.vocab_embed = EmbeddingLayer(
            config.hidden_size, config.vocab_size)
        self.sigma_map = TimestepEmbedder(config.cond_dim)
        # Rotary embeddings act per attention head.
        self.rotary_dim = config.hidden_size // config.n_heads
        self.rotary_emb = Rotary(self.rotary_dim)

        # One adaLN-conditioned DDiT block per layer, identical settings.
        self.blocks = nn.ModuleList(
            DDiTBlock(
                dim=config.hidden_size,
                n_heads=config.n_heads,
                cond_dim=config.cond_dim,
                adaLN=True,
                dropout=config.dropout)
            for _ in range(config.n_blocks))

        self.output_layer = DDiTFinalLayer(
            hidden_size=config.hidden_size,
            out_channels=config.vocab_size,
            cond_dim=config.cond_dim,
            adaLN=True)

    def reset_kv_cache(self):
        """Invalidate the KV cache held by every attention block."""
        for layer in self.blocks:
            layer.reset_kv_cache()
1064
+
1065
+
1066
class EsoLM(transformers.PreTrainedModel):
    """HF-compatible model.

    Wraps the EsoLMHFDiT backbone so it can be loaded via
    AutoModelForMaskedLM with trust_remote_code (see config.json's
    auto_map).
    """
    # Ties this model to its configuration class for AutoConfig/AutoModel.
    config_class = EsoLMConfig
    base_model_prefix = 'esolm'

    def __init__(self, config: EsoLMConfig):
        super().__init__(config)
        # NOTE(review): PreTrainedModel.__init__ already stores `config`
        # on `self.config`; this reassignment is redundant but harmless.
        self.config = config
        self.backbone = EsoLMHFDiT(config)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b75bc23568da985ade7f6d7cbfad0eb6f6ce066f094e8c72d3b13b5ca2e7ee0
3
+ size 678522728