klemenk committed on
Commit
6211130
·
verified ·
1 Parent(s): 89f5db6

Upload AuriStream Parallel base model code

Browse files
Files changed (1) hide show
  1. modeling_auristream_parallel.py +2 -8
modeling_auristream_parallel.py CHANGED
@@ -198,10 +198,6 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
198
  bsz, tg, gsz, vsz = logits.shape
199
  return logits.reshape(bsz, tg * gsz, vsz)
200
 
201
- def _expand_group_hidden(self, x: torch.Tensor, target_len: int) -> torch.Tensor:
202
- expanded = x.repeat_interleave(self.group_size, dim=1)
203
- return expanded[:, :target_len, :]
204
-
205
  def forward(
206
  self,
207
  input_ids: Optional[torch.LongTensor] = None,
@@ -232,14 +228,12 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
232
 
233
  all_hidden_states = ()
234
  if output_hidden_states:
235
- all_hidden_states = (self._expand_group_hidden(x, target_len=usable_len),)
236
 
237
  for block in self.h:
238
  x = block(x)
239
  if output_hidden_states:
240
- all_hidden_states = all_hidden_states + (
241
- self._expand_group_hidden(x, target_len=usable_len),
242
- )
243
 
244
  x = self.ln_f(x)
245
  logits = self._decode_parallel_logits(x)
 
198
  bsz, tg, gsz, vsz = logits.shape
199
  return logits.reshape(bsz, tg * gsz, vsz)
200
 
 
 
 
 
201
  def forward(
202
  self,
203
  input_ids: Optional[torch.LongTensor] = None,
 
228
 
229
  all_hidden_states = ()
230
  if output_hidden_states:
231
+ all_hidden_states = (x,)
232
 
233
  for block in self.h:
234
  x = block(x)
235
  if output_hidden_states:
236
+ all_hidden_states = all_hidden_states + (x,)
 
 
237
 
238
  x = self.ln_f(x)
239
  logits = self._decode_parallel_logits(x)