Upload AuriStream Parallel base model code
modeling_auristream_parallel.py
@@ -198,6 +198,10 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
         bsz, tg, gsz, vsz = logits.shape
         return logits.reshape(bsz, tg * gsz, vsz)
 
+    def _expand_group_hidden(self, x: torch.Tensor, target_len: int) -> torch.Tensor:
+        expanded = x.repeat_interleave(self.group_size, dim=1)
+        return expanded[:, :target_len, :]
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
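For readers skimming the hunk above, here is a minimal shape sketch of what the new _expand_group_hidden helper and the final reshape in _decode_parallel_logits do. It is illustrative only and not part of the repo: the sizes are made up, and the real model takes group_size from its config.

import torch

# Hypothetical sizes for illustration: 2 sequences, 3 groups of 4 tokens,
# hidden width 8, vocab 16.
batch, t_groups, group_size, hidden, vocab = 2, 3, 4, 8, 16

# Group-level hidden states: one vector per group of group_size tokens.
x = torch.randn(batch, t_groups, hidden)

# _expand_group_hidden (sketch): repeat each group vector group_size times
# along the time axis, then trim to the usable token length.
target_len = 10
expanded = x.repeat_interleave(group_size, dim=1)[:, :target_len, :]
print(expanded.shape)  # torch.Size([2, 10, 8]) -- one vector per token

# The final reshape in _decode_parallel_logits (sketch): merge the group axis
# and the within-group axis back into a single token axis.
logits = torch.randn(batch, t_groups, group_size, vocab)
print(logits.reshape(batch, t_groups * group_size, vocab).shape)  # torch.Size([2, 12, 16])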
@@ -214,16 +218,28 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
         if input_ids is None:
             raise ValueError("input_ids (or seq) must be provided")
 
+        usable_len = (input_ids.shape[1] // self.group_size) * self.group_size
+        if usable_len <= 0:
+            raise ValueError(
+                f"Input sequence length {input_ids.shape[1]} is too short for group_size={self.group_size}"
+            )
+        if usable_len != input_ids.shape[1]:
+            input_ids = input_ids[:, :usable_len]
+            if labels is not None:
+                labels = labels[:, :usable_len]
+
         x = self._group_embed(input_ids)
 
         all_hidden_states = ()
         if output_hidden_states:
-            all_hidden_states = (x,)
+            all_hidden_states = (self._expand_group_hidden(x, target_len=usable_len),)
 
         for block in self.h:
             x = block(x)
             if output_hidden_states:
-                all_hidden_states = all_hidden_states + (x,)
+                all_hidden_states = all_hidden_states + (
+                    self._expand_group_hidden(x, target_len=usable_len),
+                )
 
         x = self.ln_f(x)
         logits = self._decode_parallel_logits(x)
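The truncation added at the top of forward keeps only the largest prefix whose length is a multiple of group_size, so the sequence tiles evenly into groups. A standalone sketch of the arithmetic follows; the helper name is made up for illustration and is not part of the model API.

def usable_length(seq_len: int, group_size: int) -> int:
    # Mirrors: usable_len = (input_ids.shape[1] // group_size) * group_size
    return (seq_len // group_size) * group_size

assert usable_length(103, 4) == 100  # 3 trailing tokens are dropped
assert usable_length(100, 4) == 100  # already a multiple: nothing is cut
assert usable_length(3, 4) == 0      # shorter than one group -> forward raises ValueError

Because labels are truncated to the same usable_len and every entry of all_hidden_states is expanded back to usable_len positions, hidden states, logits, and labels all share one token axis after this change.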