Upload BD3LM

modeling_bd3lm.py  (+23 -18)
@@ -299,7 +299,7 @@ class DDiTBlock(nn.Module):
     return bias_dropout_add_scale_fused_inference
 
 
-  def get_qkv(self, x, rotary_cos_sin, …
+  def get_qkv(self, x, rotary_cos_sin, store_kv=False):
     # compute qkv (potentially use cache)
     if self.kv_cache is not None:
       block_len = x.shape[1] - self.kv_cache.shape[1]
@@ -308,8 +308,8 @@ class DDiTBlock(nn.Module):
     else:
       qkv = self.attn_qkv(x)
 
-    # …
-    if …
+    # store kv cache in a sliding window (can't exceed context len)
+    if store_kv:
       if self.kv_cache is not None:
         cache_len = min(x.shape[1], self.n - block_len)
         self.kv_cache = qkv[:, -cache_len:]
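
Note: the new store_kv path keeps the cache inside a sliding window so it never grows past the context length self.n. A minimal standalone sketch of the same arithmetic, with tensor names and the uncached first step assumed only for illustration (the initialisation branch is outside this hunk):

import torch

def sliding_window_update(kv_cache, qkv, x_len, n):
    # Same arithmetic as the hunk above: keep at most n positions of qkv.
    if kv_cache is None:
        return qkv                             # assumed first-step behaviour (not shown above)
    block_len = x_len - kv_cache.shape[1]      # tokens contributed by the current block
    cache_len = min(x_len, n - block_len)      # sliding window: never exceed context len n
    return qkv[:, -cache_len:]                 # keep only the most recent positions

# Toy shapes: context length 8, a 2-token block arriving on top of 6 cached positions.
cache = torch.zeros(1, 6, 12)
qkv = torch.zeros(1, 8, 12)
print(sliding_window_update(cache, qkv, qkv.shape[1], n=8).shape)  # torch.Size([1, 6, 12])
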
@@ -347,7 +347,8 @@ class DDiTBlock(nn.Module):
     x = einops.rearrange(x, 'b s h d -> b s (h d)')
     return x
 
-  def forward(self, x, rotary_cos_sin, c, cross_attn_mask=None, …
+  def forward(self, x, rotary_cos_sin, c, cross_attn_mask=None,
+              sample_mode=False, store_kv=False):
     bias_dropout_scale_fn = self._get_bias_dropout_scale()
 
     (shift_msa, scale_msa, gate_msa, shift_mlp,
@@ -358,12 +359,12 @@ class DDiTBlock(nn.Module):
     x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
 
     # get qkvs
-    if cross_attn_mask is not None and not …
+    if cross_attn_mask is not None and not sample_mode:
       qkv_x = self.get_qkv(x[:,:self.n], rotary_cos_sin)
       qkv_x0 = self.get_qkv(x[:,self.n:], rotary_cos_sin)
       qkv = torch.cat((qkv_x, qkv_x0), dim=1)
     else:
-      qkv = self.get_qkv(x, rotary_cos_sin, …
+      qkv = self.get_qkv(x, rotary_cos_sin, store_kv=store_kv)
 
     if cross_attn_mask is None and self.attn_backend == 'flash_attn':
       x = regular_attention_multi_headed(qkv)
@@ -470,9 +471,8 @@ class DITBackbone(nn.Module):
     x0_attn_mask = torch.cat((torch.zeros_like(self_attn_mask), x0_attn_mask), dim=1)
     self.cross_attn_mask = torch.cat((cross_attn_mask, x0_attn_mask), dim=0)
 
-  def forward(self, indices, sigma, …
-    …
-    cross_attn = self.cross_attn and not disable_cross_attn
+  def forward(self, indices, sigma, sample_mode=False,
+              store_kv=False, output_hidden_states=False):
     if not self.config.time_conditioning:
       sigma = torch.zeros_like(sigma)
     all_hidden_states = []
@@ -480,11 +480,13 @@ class DITBackbone(nn.Module):
     if output_hidden_states:
       all_hidden_states.append(x)
     c = F.silu(self.sigma_map(sigma))
-    if cross_attn:
-      cross_attn_mask = self.cross_attn_mask.to(x.device)
-      if save_kv:
-        cross_attn_mask = cross_attn_mask[:x.shape[1], :x.shape[1]]
+    if self.cross_attn:
       rotary_cos_sin = self.rotary_emb(x[:, :self.n])
+      cross_attn_mask = self.cross_attn_mask.to(x.device)
+      # use block-causal mask only during sampling
+      if sample_mode:
+        cross_attn_mask = cross_attn_mask[
+          self.n:self.n+x.shape[1], self.n:self.n+x.shape[1]]
     else:
       cross_attn_mask = None
       rotary_cos_sin = self.rotary_emb(x)
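
Note: previously the mask was cropped to its top-left corner (cross_attn_mask[:x.shape[1], :x.shape[1]]); in sample mode it is now sliced starting at offset self.n, i.e. the region that governs how generated blocks attend to each other, so decoding sees a block-causal pattern. A toy version of that slice, where the mask layout itself is an assumption made only for illustration:

import torch

n, block_size, seq_len = 8, 2, 4                     # seq_len: tokens present during sampling
blocks = torch.arange(n) // block_size
block_causal = blocks[:, None] >= blocks[None, :]    # (n, n) block-causal pattern

full_mask = torch.zeros(2 * n, 2 * n, dtype=torch.bool)   # assumed [x_t; x_0] layout
full_mask[n:, n:] = block_causal

sampling_mask = full_mask[n:n + seq_len, n:n + seq_len]   # the slice taken in sample_mode
print(sampling_mask.int())
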
@@ -495,11 +497,12 @@ class DITBackbone(nn.Module):
         rotary_cos_sin,
         c,
         cross_attn_mask=cross_attn_mask,
-        …
+        sample_mode=sample_mode,
+        store_kv=store_kv)
       if output_hidden_states:
         all_hidden_states.append(x)
     logits = self.output_layer(x, c)
-    if cross_attn and not …
+    if self.cross_attn and not sample_mode:
       logits = logits[:, :self.n]
       all_hidden_states = [hidden_states[:, :self.n] for hidden_states in all_hidden_states]
     return logits, all_hidden_states
@@ -526,7 +529,8 @@ class BD3LM(transformers.PreTrainedModel):
     self,
     input_ids: torch.LongTensor = None,
     timesteps: torch.FloatTensor = None,
-    …
+    sample_mode: typing.Optional[bool] = None,
+    store_kv: typing.Optional[bool] = None,
     output_hidden_states: typing.Optional[bool] = None,
     return_dict: typing.Optional[bool] = None,
   ) -> typing.Union[
@@ -545,8 +549,9 @@ class BD3LM(transformers.PreTrainedModel):
     logits, all_hidden_states = self.backbone(
       indices=input_ids,
       sigma=timesteps,
-      …
-      …
+      sample_mode=sample_mode,
+      store_kv=store_kv,
+      output_hidden_states=output_hidden_states,
     )
     if return_dict:
       return modeling_outputs.MaskedLMOutput(
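
Note: with the two kwargs threaded from BD3LM.forward through the backbone down to each DDiTBlock, a sampling step can be driven from the Hugging Face wrapper roughly as below. The checkpoint id is a placeholder and the tensor shapes are guesses; only the keyword names come from this diff:

import torch
import transformers

# Placeholder checkpoint id; assumes the repo exposes BD3LM via trust_remote_code.
model = transformers.AutoModelForMaskedLM.from_pretrained(
    "<bd3lm-checkpoint>", trust_remote_code=True)

input_ids = torch.randint(0, 100, (1, 16))   # toy token ids for one partially decoded block
timesteps = torch.zeros(1)                   # zeroed anyway when time_conditioning is off

out = model(
    input_ids=input_ids,
    timesteps=timesteps,
    sample_mode=True,    # use the block-causal slice of the cross-attention mask
    store_kv=True,       # let each block keep its sliding-window kv cache
    return_dict=True)
print(out.logits.shape)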