klemenk committed on
Commit
f8a35c2
·
verified ·
1 Parent(s): c3250d7

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +36 -1
modeling_auristream.py CHANGED
@@ -165,7 +165,6 @@ class AuriStream(PreTrainedModel):
165
  )
166
  if output_logits:
167
  all_logits.append(future_logits)
168
- loss = loss / (len(self.future_heads) + 1)
169
 
170
  if return_dict:
171
  if output_logits:
@@ -195,7 +194,43 @@ class AuriStream(PreTrainedModel):
195
  return model_output
196
 
197
  return logits, loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
 
 
199
  return logits, None
200
 
201
  def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
 
165
  )
166
  if output_logits:
167
  all_logits.append(future_logits)
 
168
 
169
  if return_dict:
170
  if output_logits:
 
194
  return model_output
195
 
196
  return logits, loss
197
+
198
+ else:
199
+ if output_logits:
200
+ all_logits = [logits]
201
+
202
+ # future multi-step heads (unchanged)
203
+ if self.future_heads is not None:
204
+ for i, head in enumerate(self.future_heads):
205
+ future_logits = head(x[:, :-(i + 1)])
206
+ if output_logits:
207
+ all_logits.append(future_logits)
208
+
209
+ if return_dict:
210
+ if output_logits:
211
+ if output_hidden_states:
212
+ model_output = CausalLMOutput(
213
+ logits=all_logits,
214
+ hidden_states=hs_to_return,
215
+ )
216
+ else:
217
+ model_output = CausalLMOutput(
218
+ logits=all_logits,
219
+ )
220
+ else:
221
+ if output_hidden_states:
222
+ model_output = CausalLMOutput(
223
+ logits=logits,
224
+ hidden_states=hs_to_return,
225
+ )
226
+ else:
227
+ model_output = CausalLMOutput(
228
+ logits=logits,
229
+ )
230
+ return model_output
231
 
232
+ return logits, loss
233
+
234
  return logits, None
235
 
236
  def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,