Sync modeling_auristream.py from TuKoResearch/AuriStream100M_40Pred_BigAudioDataset_500k
Browse files
- modeling_auristream.py +23 -9
modeling_auristream.py
CHANGED
|
@@ -134,15 +134,29 @@ class AuriStream(PreTrainedModel):
|
|
| 134 |
|
| 135 |
if return_dict:
|
| 136 |
if output_logits:
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
else:
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
return model_output
|
| 147 |
|
| 148 |
return logits, loss
|
|
@@ -577,4 +591,4 @@ class RMSNorm(nn.Module):
|
|
| 577 |
output = self._norm(x.float()).type_as(x)
|
| 578 |
if self.weight is not None:
|
| 579 |
return output * self.weight
|
| 580 |
-
return output
|
|
|
|
| 134 |
|
| 135 |
if return_dict:
|
| 136 |
if output_logits:
|
| 137 |
+
if output_hidden_states:
|
| 138 |
+
model_output = CausalLMOutput(
|
| 139 |
+
loss=loss,
|
| 140 |
+
logits=all_logits,
|
| 141 |
+
hidden_states=all_hidden_states,
|
| 142 |
+
)
|
| 143 |
+
else:
|
| 144 |
+
model_output = CausalLMOutput(
|
| 145 |
+
loss=loss,
|
| 146 |
+
logits=all_logits,
|
| 147 |
+
)
|
| 148 |
else:
|
| 149 |
+
if output_hidden_states:
|
| 150 |
+
model_output = CausalLMOutput(
|
| 151 |
+
loss=loss,
|
| 152 |
+
logits=logits,
|
| 153 |
+
hidden_states=all_hidden_states,
|
| 154 |
+
)
|
| 155 |
+
else:
|
| 156 |
+
model_output = CausalLMOutput(
|
| 157 |
+
loss=loss,
|
| 158 |
+
logits=logits,
|
| 159 |
+
)
|
| 160 |
return model_output
|
| 161 |
|
| 162 |
return logits, loss
|
|
|
|
| 591 |
output = self._norm(x.float()).type_as(x)
|
| 592 |
if self.weight is not None:
|
| 593 |
return output * self.weight
|
| 594 |
+
return output
|