klemenk commited on
Commit
da6f56f
·
verified ·
1 Parent(s): deef453

Sync modeling_auristream.py from TuKoResearch/AuriStream100M_40Pred_BigAudioDataset_500k

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +23 -9
modeling_auristream.py CHANGED
@@ -134,15 +134,29 @@ class AuriStream(PreTrainedModel):
134
 
135
  if return_dict:
136
  if output_logits:
137
- model_output = CausalLMOutput(
138
- loss=loss,
139
- logits=all_logits,
140
- )
 
 
 
 
 
 
 
141
  else:
142
- model_output = CausalLMOutput(
143
- loss=loss,
144
- logits=logits,
145
- )
 
 
 
 
 
 
 
146
  return model_output
147
 
148
  return logits, loss
@@ -577,4 +591,4 @@ class RMSNorm(nn.Module):
577
  output = self._norm(x.float()).type_as(x)
578
  if self.weight is not None:
579
  return output * self.weight
580
- return output
 
134
 
135
  if return_dict:
136
  if output_logits:
137
+ if output_hidden_states:
138
+ model_output = CausalLMOutput(
139
+ loss=loss,
140
+ logits=all_logits,
141
+ hidden_states=all_hidden_states,
142
+ )
143
+ else:
144
+ model_output = CausalLMOutput(
145
+ loss=loss,
146
+ logits=all_logits,
147
+ )
148
  else:
149
+ if output_hidden_states:
150
+ model_output = CausalLMOutput(
151
+ loss=loss,
152
+ logits=logits,
153
+ hidden_states=all_hidden_states,
154
+ )
155
+ else:
156
+ model_output = CausalLMOutput(
157
+ loss=loss,
158
+ logits=logits,
159
+ )
160
  return model_output
161
 
162
  return logits, loss
 
591
  output = self._norm(x.float()).type_as(x)
592
  if self.weight is not None:
593
  return output * self.weight
594
+ return output