klemenk commited on
Commit
980097b
·
verified ·
1 Parent(s): c219ef7

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +6 -0
modeling_auristream.py CHANGED
@@ -147,6 +147,12 @@ class AuriStream(PreTrainedModel):
147
 
148
  return logits, loss
149
 
 
 
 
 
 
 
150
  return logits, None
151
 
152
  def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
 
147
 
148
  return logits, loss
149
 
150
+ if output_logits and return_dict:
151
+ model_output = CausalLMOutput(
152
+ logits=all_logits,
153
+ )
154
+ return model_output
155
+
156
  return logits, None
157
 
158
  def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,