Subham Sekhar Sahoo
committed on
Upload DUO
Browse files
model.py
CHANGED
|
@@ -530,8 +530,11 @@ class HFDIT(torch.nn.Module):
|
|
| 530 |
else:
|
| 531 |
return bias_dropout_add_scale_fused_inference
|
| 532 |
|
| 533 |
-
def forward(self, x, sigma):
|
|
|
|
| 534 |
x = self.vocab_embed(x)
|
|
|
|
|
|
|
| 535 |
if self.causal:
|
| 536 |
t_cond = None
|
| 537 |
else:
|
|
@@ -541,8 +544,12 @@ class HFDIT(torch.nn.Module):
|
|
| 541 |
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 542 |
for i in range(len(self.blocks)):
|
| 543 |
x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
|
|
|
|
|
|
|
| 544 |
x = self.output_layer(x, c=t_cond)
|
| 545 |
-
|
|
|
|
|
|
|
| 546 |
|
| 547 |
|
| 548 |
|
|
@@ -585,10 +592,8 @@ class DUO(transformers.PreTrainedModel):
|
|
| 585 |
else self.config.use_return_dict
|
| 586 |
|
| 587 |
logits, all_hidden_states = self.backbone(
|
| 588 |
-
|
| 589 |
sigma=timesteps,
|
| 590 |
-
sample_mode=sample_mode,
|
| 591 |
-
store_kv=store_kv,
|
| 592 |
output_hidden_states=output_hidden_states,
|
| 593 |
)
|
| 594 |
if return_dict:
|
|
|
|
| 530 |
else:
|
| 531 |
return bias_dropout_add_scale_fused_inference
|
| 532 |
|
| 533 |
+
def forward(self, x, sigma, output_hidden_states=False):
|
| 534 |
+
all_hidden_states = []
|
| 535 |
x = self.vocab_embed(x)
|
| 536 |
+
if output_hidden_states:
|
| 537 |
+
all_hidden_states.append(x)
|
| 538 |
if self.causal:
|
| 539 |
t_cond = None
|
| 540 |
else:
|
|
|
|
| 544 |
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 545 |
for i in range(len(self.blocks)):
|
| 546 |
x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
|
| 547 |
+
if output_hidden_states:
|
| 548 |
+
all_hidden_states.append(x)
|
| 549 |
x = self.output_layer(x, c=t_cond)
|
| 550 |
+
if output_hidden_states:
|
| 551 |
+
all_hidden_states.append(x)
|
| 552 |
+
return x, all_hidden_states
|
| 553 |
|
| 554 |
|
| 555 |
|
|
|
|
| 592 |
else self.config.use_return_dict
|
| 593 |
|
| 594 |
logits, all_hidden_states = self.backbone(
|
| 595 |
+
x=input_ids,
|
| 596 |
sigma=timesteps,
|
|
|
|
|
|
|
| 597 |
output_hidden_states=output_hidden_states,
|
| 598 |
)
|
| 599 |
if return_dict:
|