klemenk committed on
Commit
92b5b06
·
verified ·
1 Parent(s): 9c3e596

Upload AuriStream base model code

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +23 -25
modeling_auristream.py CHANGED
@@ -251,13 +251,11 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
251
  super().__init__(config)
252
  self.config = config
253
 
254
- # Transformer components
255
- self.transformer = nn.ModuleDict(dict(
256
- wte=nn.Embedding(config.vocab_size, config.n_embd),
257
- drop=nn.Dropout(config.dropout),
258
- h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
259
- ln_f=RMSNorm(config.n_embd, bias=config.bias),
260
- ))
261
 
262
  # Multi-token prediction heads
263
  if hasattr(config, 'n_pred_steps') and config.n_pred_steps > 1:
@@ -269,7 +267,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
269
  self.future_heads = None
270
 
271
  # Output head
272
- self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
273
 
274
  # Initialize weights
275
  self.apply(self._init_weights)
@@ -279,10 +277,10 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
279
  torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
280
 
281
  def get_input_embeddings(self):
282
- return self.transformer.wte
283
 
284
  def set_input_embeddings(self, value):
285
- self.transformer.wte = value
286
 
287
  def get_num_params(self, non_embedding=True):
288
  """Return the number of parameters in the model."""
@@ -319,14 +317,14 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
319
  labels = tgt
320
 
321
  # Get embeddings
322
- tok_emb = self.transformer.wte(input_ids)
323
- x = self.transformer.drop(tok_emb)
324
 
325
  # Collect hidden states if requested
326
  all_hidden_states = []
327
 
328
  # Forward through transformer blocks
329
- for block in self.transformer.h:
330
  if output_hidden_states:
331
  all_hidden_states.append(x)
332
  x = block(x)
@@ -335,8 +333,8 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
335
  all_hidden_states.append(x)
336
 
337
  # Final layer norm and output head
338
- x = self.transformer.ln_f(x)
339
- logits = self.coch_head(x)
340
 
341
  # Compute loss if labels provided
342
  loss = None
@@ -438,42 +436,42 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
438
  b, t = seq.size()
439
 
440
  # Encode conditioning sequence into KV cache
441
- tok_emb = self.transformer.wte(seq)
442
- x = self.transformer.drop(tok_emb)
443
 
444
  k_list = []
445
  v_list = []
446
- for block in self.transformer.h:
447
  x, k, v = block(x, return_kv=True)
448
  k_list.append(k)
449
  v_list.append(v)
450
 
451
  k_cache = torch.stack(k_list, dim=0)
452
  v_cache = torch.stack(v_list, dim=0)
453
- x = self.transformer.ln_f(x)
454
 
455
  # First prediction
456
- logits = self.coch_head(x[:, [-1]])
457
  predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
458
  all_logits.append(logits)
459
 
460
  # Generate remaining tokens
461
  for i in range(n_tokens - 1):
462
- tok_emb = self.transformer.wte(predictions[-1])
463
- x = self.transformer.drop(tok_emb)
464
 
465
  k_list = []
466
  v_list = []
467
- for block_idx, block in enumerate(self.transformer.h):
468
  x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
469
  k_list.append(k)
470
  v_list.append(v)
471
 
472
- x = self.transformer.ln_f(x)
473
  k_cache = torch.stack(k_list, dim=0)
474
  v_cache = torch.stack(v_list, dim=0)
475
 
476
- logits = self.coch_head(x)
477
  predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
478
  all_logits.append(logits)
479
 
 
251
  super().__init__(config)
252
  self.config = config
253
 
254
+ # Transformer components (no wrapper to match weight keys)
255
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
256
+ self.drop = nn.Dropout(config.dropout)
257
+ self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
258
+ self.ln_f = RMSNorm(config.n_embd, bias=config.bias)
 
 
259
 
260
  # Multi-token prediction heads
261
  if hasattr(config, 'n_pred_steps') and config.n_pred_steps > 1:
 
267
  self.future_heads = None
268
 
269
  # Output head
270
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
271
 
272
  # Initialize weights
273
  self.apply(self._init_weights)
 
277
  torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
278
 
279
  def get_input_embeddings(self):
280
+ return self.wte
281
 
282
  def set_input_embeddings(self, value):
283
+ self.wte = value
284
 
285
  def get_num_params(self, non_embedding=True):
286
  """Return the number of parameters in the model."""
 
317
  labels = tgt
318
 
319
  # Get embeddings
320
+ tok_emb = self.wte(input_ids)
321
+ x = self.drop(tok_emb)
322
 
323
  # Collect hidden states if requested
324
  all_hidden_states = []
325
 
326
  # Forward through transformer blocks
327
+ for block in self.h:
328
  if output_hidden_states:
329
  all_hidden_states.append(x)
330
  x = block(x)
 
333
  all_hidden_states.append(x)
334
 
335
  # Final layer norm and output head
336
+ x = self.ln_f(x)
337
+ logits = self.lm_head(x)
338
 
339
  # Compute loss if labels provided
340
  loss = None
 
436
  b, t = seq.size()
437
 
438
  # Encode conditioning sequence into KV cache
439
+ tok_emb = self.wte(seq)
440
+ x = self.drop(tok_emb)
441
 
442
  k_list = []
443
  v_list = []
444
+ for block in self.h:
445
  x, k, v = block(x, return_kv=True)
446
  k_list.append(k)
447
  v_list.append(v)
448
 
449
  k_cache = torch.stack(k_list, dim=0)
450
  v_cache = torch.stack(v_list, dim=0)
451
+ x = self.ln_f(x)
452
 
453
  # First prediction
454
+ logits = self.lm_head(x[:, [-1]])
455
  predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
456
  all_logits.append(logits)
457
 
458
  # Generate remaining tokens
459
  for i in range(n_tokens - 1):
460
+ tok_emb = self.wte(predictions[-1])
461
+ x = self.drop(tok_emb)
462
 
463
  k_list = []
464
  v_list = []
465
+ for block_idx, block in enumerate(self.h):
466
  x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
467
  k_list.append(k)
468
  v_list.append(v)
469
 
470
+ x = self.ln_f(x)
471
  k_cache = torch.stack(k_list, dim=0)
472
  v_cache = torch.stack(v_list, dim=0)
473
 
474
+ logits = self.lm_head(x)
475
  predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
476
  all_logits.append(logits)
477