Update modeling_auristream.py
Browse files
- modeling_auristream.py +36 -1
modeling_auristream.py
CHANGED
|
@@ -165,7 +165,6 @@ class AuriStream(PreTrainedModel):
|
|
| 165 |
)
|
| 166 |
if output_logits:
|
| 167 |
all_logits.append(future_logits)
|
| 168 |
-
loss = loss / (len(self.future_heads) + 1)
|
| 169 |
|
| 170 |
if return_dict:
|
| 171 |
if output_logits:
|
|
@@ -195,7 +194,43 @@ class AuriStream(PreTrainedModel):
|
|
| 195 |
return model_output
|
| 196 |
|
| 197 |
return logits, loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
|
|
|
|
|
|
| 199 |
return logits, None
|
| 200 |
|
| 201 |
def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
|
|
|
|
| 165 |
)
|
| 166 |
if output_logits:
|
| 167 |
all_logits.append(future_logits)
|
|
|
|
| 168 |
|
| 169 |
if return_dict:
|
| 170 |
if output_logits:
|
|
|
|
| 194 |
return model_output
|
| 195 |
|
| 196 |
return logits, loss
|
| 197 |
+
|
| 198 |
+
else:
|
| 199 |
+
if output_logits:
|
| 200 |
+
all_logits = [logits]
|
| 201 |
+
|
| 202 |
+
# future multi-step heads (unchanged)
|
| 203 |
+
if self.future_heads is not None:
|
| 204 |
+
for i, head in enumerate(self.future_heads):
|
| 205 |
+
future_logits = head(x[:, :-(i + 1)])
|
| 206 |
+
if output_logits:
|
| 207 |
+
all_logits.append(future_logits)
|
| 208 |
+
|
| 209 |
+
if return_dict:
|
| 210 |
+
if output_logits:
|
| 211 |
+
if output_hidden_states:
|
| 212 |
+
model_output = CausalLMOutput(
|
| 213 |
+
logits=all_logits,
|
| 214 |
+
hidden_states=hs_to_return,
|
| 215 |
+
)
|
| 216 |
+
else:
|
| 217 |
+
model_output = CausalLMOutput(
|
| 218 |
+
logits=all_logits,
|
| 219 |
+
)
|
| 220 |
+
else:
|
| 221 |
+
if output_hidden_states:
|
| 222 |
+
model_output = CausalLMOutput(
|
| 223 |
+
logits=logits,
|
| 224 |
+
hidden_states=hs_to_return,
|
| 225 |
+
)
|
| 226 |
+
else:
|
| 227 |
+
model_output = CausalLMOutput(
|
| 228 |
+
logits=logits,
|
| 229 |
+
)
|
| 230 |
+
return model_output
|
| 231 |
|
| 232 |
+
return logits, loss
|
| 233 |
+
|
| 234 |
return logits, None
|
| 235 |
|
| 236 |
def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
|