Upload AuriStream Parallel base model code
Browse files
modeling_auristream_parallel.py
CHANGED
|
@@ -198,10 +198,6 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 198 |
bsz, tg, gsz, vsz = logits.shape
|
| 199 |
return logits.reshape(bsz, tg * gsz, vsz)
|
| 200 |
|
| 201 |
-
def _expand_group_hidden(self, x: torch.Tensor, target_len: int) -> torch.Tensor:
|
| 202 |
-
expanded = x.repeat_interleave(self.group_size, dim=1)
|
| 203 |
-
return expanded[:, :target_len, :]
|
| 204 |
-
|
| 205 |
def forward(
|
| 206 |
self,
|
| 207 |
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -232,14 +228,12 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 232 |
|
| 233 |
all_hidden_states = ()
|
| 234 |
if output_hidden_states:
|
| 235 |
-
all_hidden_states = (self._expand_group_hidden(x, target_len=usable_len),)
|
| 236 |
|
| 237 |
for block in self.h:
|
| 238 |
x = block(x)
|
| 239 |
if output_hidden_states:
|
| 240 |
-
all_hidden_states = all_hidden_states + (
|
| 241 |
-
self._expand_group_hidden(x, target_len=usable_len),
|
| 242 |
-
)
|
| 243 |
|
| 244 |
x = self.ln_f(x)
|
| 245 |
logits = self._decode_parallel_logits(x)
|
|
|
|
| 198 |
bsz, tg, gsz, vsz = logits.shape
|
| 199 |
return logits.reshape(bsz, tg * gsz, vsz)
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
def forward(
|
| 202 |
self,
|
| 203 |
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
| 228 |
|
| 229 |
all_hidden_states = ()
|
| 230 |
if output_hidden_states:
|
| 231 |
+
all_hidden_states = (x,)
|
| 232 |
|
| 233 |
for block in self.h:
|
| 234 |
x = block(x)
|
| 235 |
if output_hidden_states:
|
| 236 |
+
all_hidden_states = all_hidden_states + (x,)
|
|
|
|
|
|
|
| 237 |
|
| 238 |
x = self.ln_f(x)
|
| 239 |
logits = self._decode_parallel_logits(x)
|