klemenk committed on
Commit
1d06c27
·
verified ·
1 Parent(s): 39e5822

Sync modeling_auristream.py from TuKoResearch/AuriStream200M_100Pred_librilight_200k

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +44 -17
modeling_auristream.py CHANGED
@@ -111,30 +111,35 @@ class AuriStream(PreTrainedModel):
111
  x = self.transformer.ln_f(x)
112
  logits = self.coch_head(x)
113
 
114
- if tgt is not None:
115
 
116
- if output_logits:
117
- all_logits = [logits]
118
-
 
119
  loss = F.cross_entropy(
120
  logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
121
  )
122
 
123
- # If we have more than one future head, compute the loss for each head
124
- if self.future_heads is not None:
125
- for i, head in enumerate(self.future_heads):
126
- future_logits = head(x[:, :-(i+1)])
 
 
127
  loss += F.cross_entropy(
128
  future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
129
  )
130
- if output_logits:
131
- all_logits.append(future_logits)
 
 
132
  # divide loss by number of future heads
133
  loss = loss / (len(self.future_heads) + 1)
134
 
135
- if return_dict:
136
- if output_logits:
137
- if output_hidden_states:
 
138
  model_output = CausalLMOutput(
139
  loss=loss,
140
  logits=all_logits,
@@ -142,23 +147,45 @@ class AuriStream(PreTrainedModel):
142
  )
143
  else:
144
  model_output = CausalLMOutput(
145
- loss=loss,
146
  logits=all_logits,
 
147
  )
148
  else:
149
- if output_hidden_states:
 
 
 
 
 
 
 
 
 
 
 
150
  model_output = CausalLMOutput(
151
  loss=loss,
152
  logits=logits,
153
  hidden_states=all_hidden_states,
154
  )
155
  else:
 
 
 
 
 
 
156
  model_output = CausalLMOutput(
157
  loss=loss,
158
  logits=logits,
159
  )
160
- return model_output
161
-
 
 
 
 
 
162
  return logits, loss
163
 
164
  return logits, None
 
111
  x = self.transformer.ln_f(x)
112
  logits = self.coch_head(x)
113
 
 
114
 
115
+ if output_logits:
116
+ all_logits = [logits]
117
+
118
+ if tgt is not None:
119
  loss = F.cross_entropy(
120
  logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
121
  )
122
 
123
+ # If we have more than one future head, compute the loss for each head
124
+ if self.future_heads is not None:
125
+ for i, head in enumerate(self.future_heads):
126
+ future_logits = head(x[:, :-(i+1)])
127
+
128
+ if tgt is not None:
129
  loss += F.cross_entropy(
130
  future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
131
  )
132
+ if output_logits:
133
+ all_logits.append(future_logits)
134
+
135
+ if tgt is not None:
136
  # divide loss by number of future heads
137
  loss = loss / (len(self.future_heads) + 1)
138
 
139
+ if return_dict:
140
+ if output_logits:
141
+ if output_hidden_states:
142
+ if tgt is not None:
143
  model_output = CausalLMOutput(
144
  loss=loss,
145
  logits=all_logits,
 
147
  )
148
  else:
149
  model_output = CausalLMOutput(
 
150
  logits=all_logits,
151
+ hidden_states=all_hidden_states,
152
  )
153
  else:
154
+ if tgt is not None:
155
+ model_output = CausalLMOutput(
156
+ loss=loss,
157
+ logits=all_logits,
158
+ )
159
+ else:
160
+ model_output = CausalLMOutput(
161
+ logits=all_logits,
162
+ )
163
+ else:
164
+ if output_hidden_states:
165
+ if tgt is not None:
166
  model_output = CausalLMOutput(
167
  loss=loss,
168
  logits=logits,
169
  hidden_states=all_hidden_states,
170
  )
171
  else:
172
+ model_output = CausalLMOutput(
173
+ logits=logits,
174
+ hidden_states=all_hidden_states,
175
+ )
176
+ else:
177
+ if tgt is not None:
178
  model_output = CausalLMOutput(
179
  loss=loss,
180
  logits=logits,
181
  )
182
+ else:
183
+ model_output = CausalLMOutput(
184
+ logits=logits,
185
+ )
186
+ return model_output
187
+
188
+ if tgt is not None:
189
  return logits, loss
190
 
191
  return logits, None