klemenk commited on
Commit
c2cc9ee
·
verified ·
1 Parent(s): 28d0dca

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +64 -43
modeling_auristream.py CHANGED
@@ -111,30 +111,35 @@ class AuriStream(PreTrainedModel):
111
  x = self.transformer.ln_f(x)
112
  logits = self.coch_head(x)
113
 
114
- if tgt is not None:
115
 
116
- if output_logits:
117
- all_logits = [logits]
118
-
 
119
  loss = F.cross_entropy(
120
  logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
121
  )
122
 
123
- # If we have more than one future head, compute the loss for each head
124
- if self.future_heads is not None:
125
- for i, head in enumerate(self.future_heads):
126
- future_logits = head(x[:, :-(i+1)])
 
 
127
  loss += F.cross_entropy(
128
  future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
129
  )
130
- if output_logits:
131
- all_logits.append(future_logits)
 
 
132
  # divide loss by number of future heads
133
  loss = loss / (len(self.future_heads) + 1)
134
 
135
- if return_dict:
136
- if output_logits:
137
- if output_hidden_states:
 
138
  model_output = CausalLMOutput(
139
  loss=loss,
140
  logits=all_logits,
@@ -142,23 +147,45 @@ class AuriStream(PreTrainedModel):
142
  )
143
  else:
144
  model_output = CausalLMOutput(
145
- loss=loss,
146
  logits=all_logits,
 
147
  )
148
  else:
149
- if output_hidden_states:
 
 
 
 
 
 
 
 
 
 
 
150
  model_output = CausalLMOutput(
151
  loss=loss,
152
  logits=logits,
153
  hidden_states=all_hidden_states,
154
  )
155
  else:
 
 
 
 
 
 
156
  model_output = CausalLMOutput(
157
  loss=loss,
158
  logits=logits,
159
  )
160
- return model_output
161
-
 
 
 
 
 
162
  return logits, loss
163
 
164
  return logits, None
@@ -215,26 +242,35 @@ class AuriStream(PreTrainedModel):
215
  return sampled
216
 
217
  @torch.no_grad()
218
- def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
219
- top_k=500, top_p=0.5, seed=None):
 
 
 
 
 
 
 
220
  """
221
  Parameters:
222
- seq: torch.Tensor of shape (b, t, n_freq_bins)
223
- Input cochleagram to use for generation
224
  n_tokens: int
225
- Number of time bins to predict
226
  temp: float
227
  Temperature for sampling logits
 
 
 
 
228
  seed: int
229
  Random seed for sampling
230
 
231
  Returns:
232
- pred_coch: torch.Tensor of shape (b, t, n_freq_bins)
233
- The predicted cochleagram
234
- all_logits: (optional if return_logits is True) torch.Tensor of shape (b, n_tokens, n_freq_bins)
235
- The logits for each time step
236
- all_embs: (optional if return_embs is not None) list of torch.Tensor
237
- The embeddings for each transformer block
238
  """
239
 
240
  # Set seed if provided
@@ -250,14 +286,6 @@ class AuriStream(PreTrainedModel):
250
  # grab shape of the cochleagram
251
  b, t = seq.size()
252
 
253
- # TODO: double check this works then delete the block bellow:
254
- # pass the given input through the model to get the predictions and cache
255
- # the k and v values for each transformer block in the process
256
- # pos = torch.arange(0, t, dtype=torch.long, device=device)
257
- # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
258
- # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
259
- # x = self.transformer.drop(tok_emb + pos_emb)
260
-
261
  #### Embed conditioning sequence into KV cache
262
 
263
  tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
@@ -295,13 +323,6 @@ class AuriStream(PreTrainedModel):
295
  # using the last embedding of the input
296
  for i in range(n_tokens-1):
297
 
298
- # TODO: double check this works then delete the block bellow:
299
- # # Get the emb and pos embedding of just the last token
300
- # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t)
301
- # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
302
- # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
303
- # x = self.transformer.drop(tok_emb + pos_emb)
304
-
305
  # Get the emb and pos embedding of just the last token
306
  tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
307
  # if wpe exists in self.transformer apply learned positional embedding
 
111
  x = self.transformer.ln_f(x)
112
  logits = self.coch_head(x)
113
 
 
114
 
115
+ if output_logits:
116
+ all_logits = [logits]
117
+
118
+ if tgt is not None:
119
  loss = F.cross_entropy(
120
  logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
121
  )
122
 
123
+ # If we have more than one future head, compute the loss for each head
124
+ if self.future_heads is not None:
125
+ for i, head in enumerate(self.future_heads):
126
+ future_logits = head(x[:, :-(i+1)])
127
+
128
+ if tgt is not None:
129
  loss += F.cross_entropy(
130
  future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
131
  )
132
+ if output_logits:
133
+ all_logits.append(future_logits)
134
+
135
+ if tgt is not None:
136
  # divide loss by number of future heads
137
  loss = loss / (len(self.future_heads) + 1)
138
 
139
+ if return_dict:
140
+ if output_logits:
141
+ if output_hidden_states:
142
+ if tgt is not None:
143
  model_output = CausalLMOutput(
144
  loss=loss,
145
  logits=all_logits,
 
147
  )
148
  else:
149
  model_output = CausalLMOutput(
 
150
  logits=all_logits,
151
+ hidden_states=all_hidden_states,
152
  )
153
  else:
154
+ if tgt is not None:
155
+ model_output = CausalLMOutput(
156
+ loss=loss,
157
+ logits=all_logits,
158
+ )
159
+ else:
160
+ model_output = CausalLMOutput(
161
+ logits=all_logits,
162
+ )
163
+ else:
164
+ if output_hidden_states:
165
+ if tgt is not None:
166
  model_output = CausalLMOutput(
167
  loss=loss,
168
  logits=logits,
169
  hidden_states=all_hidden_states,
170
  )
171
  else:
172
+ model_output = CausalLMOutput(
173
+ logits=logits,
174
+ hidden_states=all_hidden_states,
175
+ )
176
+ else:
177
+ if tgt is not None:
178
  model_output = CausalLMOutput(
179
  loss=loss,
180
  logits=logits,
181
  )
182
+ else:
183
+ model_output = CausalLMOutput(
184
+ logits=logits,
185
+ )
186
+ return model_output
187
+
188
+ if tgt is not None:
189
  return logits, loss
190
 
191
  return logits, None
 
242
  return sampled
243
 
244
  @torch.no_grad()
245
+ def generate(
246
+ self,
247
+ seq: torch.Tensor,
248
+ n_tokens: int = 1,
249
+ temp: float = 1.0,
250
+ top_k: int = None,
251
+ top_p: float = None,
252
+ seed: int = None,
253
+ ):
254
  """
255
  Parameters:
256
+ seq: torch.Tensor of shape (b, t)
257
+ Input cochlear tokens to condition the generation
258
  n_tokens: int
259
+ Number of future tokens (5ms time bins) to predict
260
  temp: float
261
  Temperature for sampling logits
262
+ top_k: int
263
+ Restrict sampling to k tokens with highest probability (sample from all tokens if None)
264
+ top_p: float
265
+ Restrict sampling to most probable tokens with cumulative probability of p (sample from all tokens if None)
266
  seed: int
267
  Random seed for sampling
268
 
269
  Returns:
270
+ pred_coch: torch.Tensor of shape (b, t)
271
+ The generated cochlear tokens
272
+ all_logits: torch.Tensor of shape (b, n_tokens, vocab_size)
273
+ The logits at each time step
 
 
274
  """
275
 
276
  # Set seed if provided
 
286
  # grab shape of the cochleagram
287
  b, t = seq.size()
288
 
 
 
 
 
 
 
 
 
289
  #### Embed conditioning sequence into KV cache
290
 
291
  tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
 
323
  # using the last embedding of the input
324
  for i in range(n_tokens-1):
325
 
 
 
 
 
 
 
 
326
  # Get the emb and pos embedding of just the last token
327
  tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
328
  # if wpe exists in self.transformer apply learned positional embedding