klemenk committed on
Commit
248af8b
·
verified ·
1 Parent(s): 5abc44c

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +51 -13
modeling_auristream.py CHANGED
@@ -290,10 +290,12 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
290
  self,
291
  input_ids: Optional[torch.LongTensor] = None,
292
  labels: Optional[torch.LongTensor] = None,
 
293
  output_hidden_states: Optional[bool] = False,
294
  return_dict: Optional[bool] = True,
 
 
295
  # Legacy arguments for compatibility
296
- return_logits: Optional[bool] = True,
297
  seq: Optional[torch.LongTensor] = None,
298
  tgt: Optional[torch.LongTensor] = None,
299
  ):
@@ -303,13 +305,16 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
303
  Args:
304
  input_ids: Input token IDs of shape (batch_size, seq_len)
305
  labels: Target token IDs for computing loss
 
306
  output_hidden_states: Whether to return all hidden states
307
  return_dict: Whether to return a dict or tuple
 
 
308
  seq: Legacy argument (alias for input_ids)
309
  tgt: Legacy argument (alias for labels)
310
 
311
  Returns:
312
- CausalLMOutput with logits and optional loss
313
  """
314
  # Handle legacy arguments
315
  if seq is not None:
@@ -321,22 +326,55 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
321
  tok_emb = self.wte(input_ids)
322
  x = self.drop(tok_emb)
323
 
324
- # Collect hidden states if requested
325
  all_hidden_states = []
326
 
327
  # Forward through transformer blocks
328
- for block in self.h:
329
- if output_hidden_states:
330
- all_hidden_states.append(x)
 
331
  x = block(x)
332
 
333
- if output_hidden_states:
 
334
  all_hidden_states.append(x)
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  # Final layer norm and output head
337
  x = self.ln_f(x)
338
  logits = self.lm_head(x)
339
 
 
 
 
 
 
 
 
 
 
 
340
  # Compute loss if labels provided
341
  loss = None
342
  if labels is not None:
@@ -348,21 +386,21 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
348
  # Multi-token prediction loss
349
  if self.future_heads is not None:
350
  for i, head in enumerate(self.future_heads):
351
- future_logits = head(x[:, :-(i+1)])
352
  loss = loss + F.cross_entropy(
353
  future_logits.reshape(-1, self.config.vocab_size),
354
- labels[:, (i+1):].reshape(-1),
355
  )
356
 
357
  if not return_dict:
358
  if labels is not None:
359
- return logits, loss
360
- return logits.unsqueeze(0), None
361
 
362
  return CausalLMOutput(
363
  loss=loss,
364
- logits=logits.unsqueeze(0),
365
- hidden_states=all_hidden_states if output_hidden_states else None,
366
  )
367
 
368
  def sample_logits(
 
290
  self,
291
  input_ids: Optional[torch.LongTensor] = None,
292
  labels: Optional[torch.LongTensor] = None,
293
+ output_logits: Optional[bool] = False,
294
  output_hidden_states: Optional[bool] = False,
295
  return_dict: Optional[bool] = True,
296
+ up_until_layer: Optional[int] = None,
297
+ normalize_embeddings: Optional[str] = None,
298
  # Legacy arguments for compatibility
 
299
  seq: Optional[torch.LongTensor] = None,
300
  tgt: Optional[torch.LongTensor] = None,
301
  ):
 
305
  Args:
306
  input_ids: Input token IDs of shape (batch_size, seq_len)
307
  labels: Target token IDs for computing loss
308
+ output_logits: Whether to return all logits (including from future heads)
309
  output_hidden_states: Whether to return all hidden states
310
  return_dict: Whether to return a dict or tuple
311
+ up_until_layer: Stop forward pass at this layer index
312
+ normalize_embeddings: 'l2' or 'learned' to normalize hidden states
313
  seq: Legacy argument (alias for input_ids)
314
  tgt: Legacy argument (alias for labels)
315
 
316
  Returns:
317
+ CausalLMOutput with logits and optional loss, or tuple
318
  """
319
  # Handle legacy arguments
320
  if seq is not None:
 
326
  tok_emb = self.wte(input_ids)
327
  x = self.drop(tok_emb)
328
 
329
+ # Collect hidden states
330
  all_hidden_states = []
331
 
332
  # Forward through transformer blocks
333
+ for block_idx, block in enumerate(self.h):
334
+ all_hidden_states.append(x)
335
+ if up_until_layer is not None and block_idx == up_until_layer:
336
+ break
337
  x = block(x)
338
 
339
+ # Append final pre-ln_f state if we didn't exit early
340
+ if up_until_layer is None or block_idx == len(self.h) - 1:
341
  all_hidden_states.append(x)
342
 
343
+ # Normalize hidden states if requested
344
+ hs_to_return = all_hidden_states
345
+ if output_hidden_states and normalize_embeddings is not None:
346
+ if normalize_embeddings == 'l2':
347
+ hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states]
348
+ elif normalize_embeddings == 'learned':
349
+ hs_to_return = []
350
+ L = len(self.h)
351
+ for i, h in enumerate(all_hidden_states):
352
+ if i < L:
353
+ hs_to_return.append(self.h[i].norm1(h))
354
+ else:
355
+ hs_to_return.append(self.ln_f(h))
356
+
357
+ # If only hidden states requested (not logits), return early
358
+ if output_hidden_states and not output_logits and labels is None:
359
+ return BaseModelOutput(
360
+ last_hidden_state=x,
361
+ hidden_states=hs_to_return,
362
+ )
363
+
364
  # Final layer norm and output head
365
  x = self.ln_f(x)
366
  logits = self.lm_head(x)
367
 
368
+ # Collect all logits if requested
369
+ all_logits = [logits] if output_logits else None
370
+
371
+ # Compute future head logits
372
+ if self.future_heads is not None:
373
+ for i, head in enumerate(self.future_heads):
374
+ future_logits = head(x[:, :-(i + 1)])
375
+ if output_logits:
376
+ all_logits.append(future_logits)
377
+
378
  # Compute loss if labels provided
379
  loss = None
380
  if labels is not None:
 
386
  # Multi-token prediction loss
387
  if self.future_heads is not None:
388
  for i, head in enumerate(self.future_heads):
389
+ future_logits = head(x[:, :-(i + 1)])
390
  loss = loss + F.cross_entropy(
391
  future_logits.reshape(-1, self.config.vocab_size),
392
+ labels[:, (i + 1):].reshape(-1),
393
  )
394
 
395
  if not return_dict:
396
  if labels is not None:
397
+ return (all_logits if output_logits else logits), loss
398
+ return (all_logits if output_logits else logits), None
399
 
400
  return CausalLMOutput(
401
  loss=loss,
402
+ logits=all_logits if output_logits else logits,
403
+ hidden_states=hs_to_return if output_hidden_states else None,
404
  )
405
 
406
  def sample_logits(