Update modeling_auristream.py
Browse files- modeling_auristream.py +3 -2
modeling_auristream.py
CHANGED
|
@@ -293,6 +293,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 293 |
output_hidden_states: Optional[bool] = False,
|
| 294 |
return_dict: Optional[bool] = True,
|
| 295 |
# Legacy arguments for compatibility
|
|
|
|
| 296 |
seq: Optional[torch.LongTensor] = None,
|
| 297 |
tgt: Optional[torch.LongTensor] = None,
|
| 298 |
):
|
|
@@ -356,11 +357,11 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 356 |
if not return_dict:
|
| 357 |
if labels is not None:
|
| 358 |
return logits, loss
|
| 359 |
-
return logits, None
|
| 360 |
|
| 361 |
return CausalLMOutput(
|
| 362 |
loss=loss,
|
| 363 |
-
logits=logits,
|
| 364 |
hidden_states=all_hidden_states if output_hidden_states else None,
|
| 365 |
)
|
| 366 |
|
|
|
|
| 293 |
output_hidden_states: Optional[bool] = False,
|
| 294 |
return_dict: Optional[bool] = True,
|
| 295 |
# Legacy arguments for compatibility
|
| 296 |
+
return_logits: Optional[bool] = True,
|
| 297 |
seq: Optional[torch.LongTensor] = None,
|
| 298 |
tgt: Optional[torch.LongTensor] = None,
|
| 299 |
):
|
|
|
|
| 357 |
if not return_dict:
|
| 358 |
if labels is not None:
|
| 359 |
return logits, loss
|
| 360 |
+
return logits.unsqueeze(0), None
|
| 361 |
|
| 362 |
return CausalLMOutput(
|
| 363 |
loss=loss,
|
| 364 |
+
logits=logits.unsqueeze(0),
|
| 365 |
hidden_states=all_hidden_states if output_hidden_states else None,
|
| 366 |
)
|
| 367 |
|