Update modeling_auristream.py
Browse files · modeling_auristream.py (+51 -13)
modeling_auristream.py
CHANGED
|
@@ -290,10 +290,12 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 290 |
self,
|
| 291 |
input_ids: Optional[torch.LongTensor] = None,
|
| 292 |
labels: Optional[torch.LongTensor] = None,
|
|
|
|
| 293 |
output_hidden_states: Optional[bool] = False,
|
| 294 |
return_dict: Optional[bool] = True,
|
|
|
|
|
|
|
| 295 |
# Legacy arguments for compatibility
|
| 296 |
-
return_logits: Optional[bool] = True,
|
| 297 |
seq: Optional[torch.LongTensor] = None,
|
| 298 |
tgt: Optional[torch.LongTensor] = None,
|
| 299 |
):
|
|
@@ -303,13 +305,16 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 303 |
Args:
|
| 304 |
input_ids: Input token IDs of shape (batch_size, seq_len)
|
| 305 |
labels: Target token IDs for computing loss
|
|
|
|
| 306 |
output_hidden_states: Whether to return all hidden states
|
| 307 |
return_dict: Whether to return a dict or tuple
|
|
|
|
|
|
|
| 308 |
seq: Legacy argument (alias for input_ids)
|
| 309 |
tgt: Legacy argument (alias for labels)
|
| 310 |
|
| 311 |
Returns:
|
| 312 |
-
CausalLMOutput with logits and optional loss
|
| 313 |
"""
|
| 314 |
# Handle legacy arguments
|
| 315 |
if seq is not None:
|
|
@@ -321,22 +326,55 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 321 |
tok_emb = self.wte(input_ids)
|
| 322 |
x = self.drop(tok_emb)
|
| 323 |
|
| 324 |
-
# Collect hidden states
|
| 325 |
all_hidden_states = []
|
| 326 |
|
| 327 |
# Forward through transformer blocks
|
| 328 |
-
for block in self.h:
|
| 329 |
-
|
| 330 |
-
|
|
|
|
| 331 |
x = block(x)
|
| 332 |
|
| 333 |
-
if
|
|
|
|
| 334 |
all_hidden_states.append(x)
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
# Final layer norm and output head
|
| 337 |
x = self.ln_f(x)
|
| 338 |
logits = self.lm_head(x)
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
# Compute loss if labels provided
|
| 341 |
loss = None
|
| 342 |
if labels is not None:
|
|
@@ -348,21 +386,21 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 348 |
# Multi-token prediction loss
|
| 349 |
if self.future_heads is not None:
|
| 350 |
for i, head in enumerate(self.future_heads):
|
| 351 |
-
future_logits = head(x[:, :-(i+1)])
|
| 352 |
loss = loss + F.cross_entropy(
|
| 353 |
future_logits.reshape(-1, self.config.vocab_size),
|
| 354 |
-
labels[:, (i+1):].reshape(-1),
|
| 355 |
)
|
| 356 |
|
| 357 |
if not return_dict:
|
| 358 |
if labels is not None:
|
| 359 |
-
return logits, loss
|
| 360 |
-
return logits
|
| 361 |
|
| 362 |
return CausalLMOutput(
|
| 363 |
loss=loss,
|
| 364 |
-
logits=logits
|
| 365 |
-
hidden_states=
|
| 366 |
)
|
| 367 |
|
| 368 |
def sample_logits(
|
|
|
|
| 290 |
self,
|
| 291 |
input_ids: Optional[torch.LongTensor] = None,
|
| 292 |
labels: Optional[torch.LongTensor] = None,
|
| 293 |
+
output_logits: Optional[bool] = False,
|
| 294 |
output_hidden_states: Optional[bool] = False,
|
| 295 |
return_dict: Optional[bool] = True,
|
| 296 |
+
up_until_layer: Optional[int] = None,
|
| 297 |
+
normalize_embeddings: Optional[str] = None,
|
| 298 |
# Legacy arguments for compatibility
|
|
|
|
| 299 |
seq: Optional[torch.LongTensor] = None,
|
| 300 |
tgt: Optional[torch.LongTensor] = None,
|
| 301 |
):
|
|
|
|
| 305 |
Args:
|
| 306 |
input_ids: Input token IDs of shape (batch_size, seq_len)
|
| 307 |
labels: Target token IDs for computing loss
|
| 308 |
+
output_logits: Whether to return all logits (including from future heads)
|
| 309 |
output_hidden_states: Whether to return all hidden states
|
| 310 |
return_dict: Whether to return a dict or tuple
|
| 311 |
+
up_until_layer: Stop forward pass at this layer index
|
| 312 |
+
normalize_embeddings: 'l2' or 'learned' to normalize hidden states
|
| 313 |
seq: Legacy argument (alias for input_ids)
|
| 314 |
tgt: Legacy argument (alias for labels)
|
| 315 |
|
| 316 |
Returns:
|
| 317 |
+
CausalLMOutput with logits and optional loss, or tuple
|
| 318 |
"""
|
| 319 |
# Handle legacy arguments
|
| 320 |
if seq is not None:
|
|
|
|
| 326 |
tok_emb = self.wte(input_ids)
|
| 327 |
x = self.drop(tok_emb)
|
| 328 |
|
| 329 |
+
# Collect hidden states
|
| 330 |
all_hidden_states = []
|
| 331 |
|
| 332 |
# Forward through transformer blocks
|
| 333 |
+
for block_idx, block in enumerate(self.h):
|
| 334 |
+
all_hidden_states.append(x)
|
| 335 |
+
if up_until_layer is not None and block_idx == up_until_layer:
|
| 336 |
+
break
|
| 337 |
x = block(x)
|
| 338 |
|
| 339 |
+
# Append final pre-ln_f state if we didn't exit early
|
| 340 |
+
if up_until_layer is None or block_idx == len(self.h) - 1:
|
| 341 |
all_hidden_states.append(x)
|
| 342 |
|
| 343 |
+
# Normalize hidden states if requested
|
| 344 |
+
hs_to_return = all_hidden_states
|
| 345 |
+
if output_hidden_states and normalize_embeddings is not None:
|
| 346 |
+
if normalize_embeddings == 'l2':
|
| 347 |
+
hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states]
|
| 348 |
+
elif normalize_embeddings == 'learned':
|
| 349 |
+
hs_to_return = []
|
| 350 |
+
L = len(self.h)
|
| 351 |
+
for i, h in enumerate(all_hidden_states):
|
| 352 |
+
if i < L:
|
| 353 |
+
hs_to_return.append(self.h[i].norm1(h))
|
| 354 |
+
else:
|
| 355 |
+
hs_to_return.append(self.ln_f(h))
|
| 356 |
+
|
| 357 |
+
# If only hidden states requested (not logits), return early
|
| 358 |
+
if output_hidden_states and not output_logits and labels is None:
|
| 359 |
+
return BaseModelOutput(
|
| 360 |
+
last_hidden_state=x,
|
| 361 |
+
hidden_states=hs_to_return,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
# Final layer norm and output head
|
| 365 |
x = self.ln_f(x)
|
| 366 |
logits = self.lm_head(x)
|
| 367 |
|
| 368 |
+
# Collect all logits if requested
|
| 369 |
+
all_logits = [logits] if output_logits else None
|
| 370 |
+
|
| 371 |
+
# Compute future head logits
|
| 372 |
+
if self.future_heads is not None:
|
| 373 |
+
for i, head in enumerate(self.future_heads):
|
| 374 |
+
future_logits = head(x[:, :-(i + 1)])
|
| 375 |
+
if output_logits:
|
| 376 |
+
all_logits.append(future_logits)
|
| 377 |
+
|
| 378 |
# Compute loss if labels provided
|
| 379 |
loss = None
|
| 380 |
if labels is not None:
|
|
|
|
| 386 |
# Multi-token prediction loss
|
| 387 |
if self.future_heads is not None:
|
| 388 |
for i, head in enumerate(self.future_heads):
|
| 389 |
+
future_logits = head(x[:, :-(i + 1)])
|
| 390 |
loss = loss + F.cross_entropy(
|
| 391 |
future_logits.reshape(-1, self.config.vocab_size),
|
| 392 |
+
labels[:, (i + 1):].reshape(-1),
|
| 393 |
)
|
| 394 |
|
| 395 |
if not return_dict:
|
| 396 |
if labels is not None:
|
| 397 |
+
return (all_logits if output_logits else logits), loss
|
| 398 |
+
return (all_logits if output_logits else logits), None
|
| 399 |
|
| 400 |
return CausalLMOutput(
|
| 401 |
loss=loss,
|
| 402 |
+
logits=all_logits if output_logits else logits,
|
| 403 |
+
hidden_states=hs_to_return if output_hidden_states else None,
|
| 404 |
)
|
| 405 |
|
| 406 |
def sample_logits(
|