Subham Sekhar Sahoo committed on
Commit
4ad52be
·
verified ·
1 Parent(s): e98b448

Upload DUO

Browse files
Files changed (1) hide show
  1. model.py +10 -5
model.py CHANGED
@@ -530,8 +530,11 @@ class HFDIT(torch.nn.Module):
530
  else:
531
  return bias_dropout_add_scale_fused_inference
532
 
533
- def forward(self, x, sigma):
 
534
  x = self.vocab_embed(x)
 
 
535
  if self.causal:
536
  t_cond = None
537
  else:
@@ -541,8 +544,12 @@ class HFDIT(torch.nn.Module):
541
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
542
  for i in range(len(self.blocks)):
543
  x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
 
 
544
  x = self.output_layer(x, c=t_cond)
545
- return x
 
 
546
 
547
 
548
 
@@ -585,10 +592,8 @@ class DUO(transformers.PreTrainedModel):
585
  else self.config.use_return_dict
586
 
587
  logits, all_hidden_states = self.backbone(
588
- indices=input_ids,
589
  sigma=timesteps,
590
- sample_mode=sample_mode,
591
- store_kv=store_kv,
592
  output_hidden_states=output_hidden_states,
593
  )
594
  if return_dict:
 
530
  else:
531
  return bias_dropout_add_scale_fused_inference
532
 
533
+ def forward(self, x, sigma, output_hidden_states=False):
534
+ all_hidden_states = []
535
  x = self.vocab_embed(x)
536
+ if output_hidden_states:
537
+ all_hidden_states.append(x)
538
  if self.causal:
539
  t_cond = None
540
  else:
 
544
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
545
  for i in range(len(self.blocks)):
546
  x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
547
+ if output_hidden_states:
548
+ all_hidden_states.append(x)
549
  x = self.output_layer(x, c=t_cond)
550
+ if output_hidden_states:
551
+ all_hidden_states.append(x)
552
+ return x, all_hidden_states
553
 
554
 
555
 
 
592
  else self.config.use_return_dict
593
 
594
  logits, all_hidden_states = self.backbone(
595
+ x=input_ids,
596
  sigma=timesteps,
 
 
597
  output_hidden_states=output_hidden_states,
598
  )
599
  if return_dict: