Update modeling_auristream.py
Browse files- modeling_auristream.py +6 -0
modeling_auristream.py
CHANGED
|
@@ -147,6 +147,12 @@ class AuriStream(PreTrainedModel):
|
|
| 147 |
|
| 148 |
return logits, loss
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
return logits, None
|
| 151 |
|
| 152 |
def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
|
|
|
|
| 147 |
|
| 148 |
return logits, loss
|
| 149 |
|
| 150 |
+
if output_logits and return_dict:
|
| 151 |
+
model_output = CausalLMOutput(
|
| 152 |
+
logits=all_logits,
|
| 153 |
+
)
|
| 154 |
+
return model_output
|
| 155 |
+
|
| 156 |
return logits, None
|
| 157 |
|
| 158 |
def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
|