Update modeling_auristream.py
Browse files- modeling_auristream.py +2 -2
modeling_auristream.py
CHANGED
|
@@ -357,11 +357,11 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 357 |
if not return_dict:
|
| 358 |
if labels is not None:
|
| 359 |
return logits, loss
|
| 360 |
-
return logits.
|
| 361 |
|
| 362 |
return CausalLMOutput(
|
| 363 |
loss=loss,
|
| 364 |
-
logits=logits.
|
| 365 |
hidden_states=all_hidden_states if output_hidden_states else None,
|
| 366 |
)
|
| 367 |
|
|
|
|
| 357 |
if not return_dict:
|
| 358 |
if labels is not None:
|
| 359 |
return logits, loss
|
| 360 |
+
return logits.unsqueeze(0), None
|
| 361 |
|
| 362 |
return CausalLMOutput(
|
| 363 |
loss=loss,
|
| 364 |
+
logits=logits.unsqueeze(0),
|
| 365 |
hidden_states=all_hidden_states if output_hidden_states else None,
|
| 366 |
)
|
| 367 |
|