Upload AuriStream base model code
Browse files
- modeling_auristream.py +23 -25
modeling_auristream.py
CHANGED
|
@@ -251,13 +251,11 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 251 |
super().__init__(config)
|
| 252 |
self.config = config
|
| 253 |
|
| 254 |
-
# Transformer components
|
| 255 |
-
self.
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
ln_f=RMSNorm(config.n_embd, bias=config.bias),
|
| 260 |
-
))
|
| 261 |
|
| 262 |
# Multi-token prediction heads
|
| 263 |
if hasattr(config, 'n_pred_steps') and config.n_pred_steps > 1:
|
|
@@ -269,7 +267,7 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 269 |
self.future_heads = None
|
| 270 |
|
| 271 |
# Output head
|
| 272 |
-
self.
|
| 273 |
|
| 274 |
# Initialize weights
|
| 275 |
self.apply(self._init_weights)
|
|
@@ -279,10 +277,10 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 279 |
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
| 280 |
|
| 281 |
def get_input_embeddings(self):
|
| 282 |
-
return self.
|
| 283 |
|
| 284 |
def set_input_embeddings(self, value):
|
| 285 |
-
self.
|
| 286 |
|
| 287 |
def get_num_params(self, non_embedding=True):
|
| 288 |
"""Return the number of parameters in the model."""
|
|
@@ -319,14 +317,14 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 319 |
labels = tgt
|
| 320 |
|
| 321 |
# Get embeddings
|
| 322 |
-
tok_emb = self.
|
| 323 |
-
x = self.
|
| 324 |
|
| 325 |
# Collect hidden states if requested
|
| 326 |
all_hidden_states = []
|
| 327 |
|
| 328 |
# Forward through transformer blocks
|
| 329 |
-
for block in self.
|
| 330 |
if output_hidden_states:
|
| 331 |
all_hidden_states.append(x)
|
| 332 |
x = block(x)
|
|
@@ -335,8 +333,8 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 335 |
all_hidden_states.append(x)
|
| 336 |
|
| 337 |
# Final layer norm and output head
|
| 338 |
-
x = self.
|
| 339 |
-
logits = self.
|
| 340 |
|
| 341 |
# Compute loss if labels provided
|
| 342 |
loss = None
|
|
@@ -438,42 +436,42 @@ class AuriStreamModel(AuriStreamPreTrainedModel):
|
|
| 438 |
b, t = seq.size()
|
| 439 |
|
| 440 |
# Encode conditioning sequence into KV cache
|
| 441 |
-
tok_emb = self.
|
| 442 |
-
x = self.
|
| 443 |
|
| 444 |
k_list = []
|
| 445 |
v_list = []
|
| 446 |
-
for block in self.
|
| 447 |
x, k, v = block(x, return_kv=True)
|
| 448 |
k_list.append(k)
|
| 449 |
v_list.append(v)
|
| 450 |
|
| 451 |
k_cache = torch.stack(k_list, dim=0)
|
| 452 |
v_cache = torch.stack(v_list, dim=0)
|
| 453 |
-
x = self.
|
| 454 |
|
| 455 |
# First prediction
|
| 456 |
-
logits = self.
|
| 457 |
predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
|
| 458 |
all_logits.append(logits)
|
| 459 |
|
| 460 |
# Generate remaining tokens
|
| 461 |
for i in range(n_tokens - 1):
|
| 462 |
-
tok_emb = self.
|
| 463 |
-
x = self.
|
| 464 |
|
| 465 |
k_list = []
|
| 466 |
v_list = []
|
| 467 |
-
for block_idx, block in enumerate(self.
|
| 468 |
x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
|
| 469 |
k_list.append(k)
|
| 470 |
v_list.append(v)
|
| 471 |
|
| 472 |
-
x = self.
|
| 473 |
k_cache = torch.stack(k_list, dim=0)
|
| 474 |
v_cache = torch.stack(v_list, dim=0)
|
| 475 |
|
| 476 |
-
logits = self.
|
| 477 |
predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
|
| 478 |
all_logits.append(logits)
|
| 479 |
|
|
|
|
| 251 |
super().__init__(config)
|
| 252 |
self.config = config
|
| 253 |
|
| 254 |
+
# Transformer components (no wrapper to match weight keys)
|
| 255 |
+
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
| 256 |
+
self.drop = nn.Dropout(config.dropout)
|
| 257 |
+
self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
|
| 258 |
+
self.ln_f = RMSNorm(config.n_embd, bias=config.bias)
|
|
|
|
|
|
|
| 259 |
|
| 260 |
# Multi-token prediction heads
|
| 261 |
if hasattr(config, 'n_pred_steps') and config.n_pred_steps > 1:
|
|
|
|
| 267 |
self.future_heads = None
|
| 268 |
|
| 269 |
# Output head
|
| 270 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 271 |
|
| 272 |
# Initialize weights
|
| 273 |
self.apply(self._init_weights)
|
|
|
|
| 277 |
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
| 278 |
|
| 279 |
def get_input_embeddings(self):
|
| 280 |
+
return self.wte
|
| 281 |
|
| 282 |
def set_input_embeddings(self, value):
|
| 283 |
+
self.wte = value
|
| 284 |
|
| 285 |
def get_num_params(self, non_embedding=True):
|
| 286 |
"""Return the number of parameters in the model."""
|
|
|
|
| 317 |
labels = tgt
|
| 318 |
|
| 319 |
# Get embeddings
|
| 320 |
+
tok_emb = self.wte(input_ids)
|
| 321 |
+
x = self.drop(tok_emb)
|
| 322 |
|
| 323 |
# Collect hidden states if requested
|
| 324 |
all_hidden_states = []
|
| 325 |
|
| 326 |
# Forward through transformer blocks
|
| 327 |
+
for block in self.h:
|
| 328 |
if output_hidden_states:
|
| 329 |
all_hidden_states.append(x)
|
| 330 |
x = block(x)
|
|
|
|
| 333 |
all_hidden_states.append(x)
|
| 334 |
|
| 335 |
# Final layer norm and output head
|
| 336 |
+
x = self.ln_f(x)
|
| 337 |
+
logits = self.lm_head(x)
|
| 338 |
|
| 339 |
# Compute loss if labels provided
|
| 340 |
loss = None
|
|
|
|
| 436 |
b, t = seq.size()
|
| 437 |
|
| 438 |
# Encode conditioning sequence into KV cache
|
| 439 |
+
tok_emb = self.wte(seq)
|
| 440 |
+
x = self.drop(tok_emb)
|
| 441 |
|
| 442 |
k_list = []
|
| 443 |
v_list = []
|
| 444 |
+
for block in self.h:
|
| 445 |
x, k, v = block(x, return_kv=True)
|
| 446 |
k_list.append(k)
|
| 447 |
v_list.append(v)
|
| 448 |
|
| 449 |
k_cache = torch.stack(k_list, dim=0)
|
| 450 |
v_cache = torch.stack(v_list, dim=0)
|
| 451 |
+
x = self.ln_f(x)
|
| 452 |
|
| 453 |
# First prediction
|
| 454 |
+
logits = self.lm_head(x[:, [-1]])
|
| 455 |
predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
|
| 456 |
all_logits.append(logits)
|
| 457 |
|
| 458 |
# Generate remaining tokens
|
| 459 |
for i in range(n_tokens - 1):
|
| 460 |
+
tok_emb = self.wte(predictions[-1])
|
| 461 |
+
x = self.drop(tok_emb)
|
| 462 |
|
| 463 |
k_list = []
|
| 464 |
v_list = []
|
| 465 |
+
for block_idx, block in enumerate(self.h):
|
| 466 |
x, k, v = block(x, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
|
| 467 |
k_list.append(k)
|
| 468 |
v_list.append(v)
|
| 469 |
|
| 470 |
+
x = self.ln_f(x)
|
| 471 |
k_cache = torch.stack(k_list, dim=0)
|
| 472 |
v_cache = torch.stack(v_list, dim=0)
|
| 473 |
|
| 474 |
+
logits = self.lm_head(x)
|
| 475 |
predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
|
| 476 |
all_logits.append(logits)
|
| 477 |
|