klemenk commited on
Commit
13eec23
·
verified ·
1 Parent(s): 92b5b06

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +3 -2
modeling_auristream.py CHANGED
@@ -293,6 +293,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
293
  output_hidden_states: Optional[bool] = False,
294
  return_dict: Optional[bool] = True,
295
  # Legacy arguments for compatibility
 
296
  seq: Optional[torch.LongTensor] = None,
297
  tgt: Optional[torch.LongTensor] = None,
298
  ):
@@ -356,11 +357,11 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
356
  if not return_dict:
357
  if labels is not None:
358
  return logits, loss
359
- return logits, None
360
 
361
  return CausalLMOutput(
362
  loss=loss,
363
- logits=logits,
364
  hidden_states=all_hidden_states if output_hidden_states else None,
365
  )
366
 
 
293
  output_hidden_states: Optional[bool] = False,
294
  return_dict: Optional[bool] = True,
295
  # Legacy arguments for compatibility
296
+ return_logits: Optional[bool] = True,
297
  seq: Optional[torch.LongTensor] = None,
298
  tgt: Optional[torch.LongTensor] = None,
299
  ):
 
357
  if not return_dict:
358
  if labels is not None:
359
  return logits, loss
360
+ return logits.unsuqeeze(0), None
361
 
362
  return CausalLMOutput(
363
  loss=loss,
364
+ logits=logits.unsuqeeze(0),
365
  hidden_states=all_hidden_states if output_hidden_states else None,
366
  )
367