klemenk committed on
Commit
89f5db6
·
verified ·
1 Parent(s): c07a579

Upload AuriStream Parallel base model code

Browse files
Files changed (1) hide show
  1. modeling_auristream_parallel.py +18 -2
modeling_auristream_parallel.py CHANGED
@@ -198,6 +198,10 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
198
  bsz, tg, gsz, vsz = logits.shape
199
  return logits.reshape(bsz, tg * gsz, vsz)
200
 
 
 
 
 
201
  def forward(
202
  self,
203
  input_ids: Optional[torch.LongTensor] = None,
@@ -214,16 +218,28 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
214
  if input_ids is None:
215
  raise ValueError("input_ids (or seq) must be provided")
216
 
 
 
 
 
 
 
 
 
 
 
217
  x = self._group_embed(input_ids)
218
 
219
  all_hidden_states = ()
220
  if output_hidden_states:
221
- all_hidden_states = (x,)
222
 
223
  for block in self.h:
224
  x = block(x)
225
  if output_hidden_states:
226
- all_hidden_states = all_hidden_states + (x,)
 
 
227
 
228
  x = self.ln_f(x)
229
  logits = self._decode_parallel_logits(x)
 
198
  bsz, tg, gsz, vsz = logits.shape
199
  return logits.reshape(bsz, tg * gsz, vsz)
200
 
201
+ def _expand_group_hidden(self, x: torch.Tensor, target_len: int) -> torch.Tensor:
202
+ expanded = x.repeat_interleave(self.group_size, dim=1)
203
+ return expanded[:, :target_len, :]
204
+
205
  def forward(
206
  self,
207
  input_ids: Optional[torch.LongTensor] = None,
 
218
  if input_ids is None:
219
  raise ValueError("input_ids (or seq) must be provided")
220
 
221
+ usable_len = (input_ids.shape[1] // self.group_size) * self.group_size
222
+ if usable_len <= 0:
223
+ raise ValueError(
224
+ f"Input sequence length {input_ids.shape[1]} is too short for group_size={self.group_size}"
225
+ )
226
+ if usable_len != input_ids.shape[1]:
227
+ input_ids = input_ids[:, :usable_len]
228
+ if labels is not None:
229
+ labels = labels[:, :usable_len]
230
+
231
  x = self._group_embed(input_ids)
232
 
233
  all_hidden_states = ()
234
  if output_hidden_states:
235
+ all_hidden_states = (self._expand_group_hidden(x, target_len=usable_len),)
236
 
237
  for block in self.h:
238
  x = block(x)
239
  if output_hidden_states:
240
+ all_hidden_states = all_hidden_states + (
241
+ self._expand_group_hidden(x, target_len=usable_len),
242
+ )
243
 
244
  x = self.ln_f(x)
245
  logits = self._decode_parallel_logits(x)