# TaoNet-pico-T1 / model.py
# Lobakkang's picture
# Upload TaoNet model to HuggingFace Hub
# 2981407 verified
"""
SimpleLLM - Mamba-style State-Space Model with ternary quantization.
"""
import torch
import torch.nn as nn
import torch.nn. functional as F
from .ssm import SSMBlock
from .bitlinear import BitLinear, RMSNorm, ActivationQuantize
from .factorized_embedding import FactorizedEmbedding
from .mla import MemoryOptimizedMLA
class SSMBlockWrapper(nn.Module):
    """
    Residual wrapper pairing an SSM mixer with a BitLinear feed-forward net.

    Data flow implemented by this wrapper:
        x → SSM → Dropout → Add → FFN → Dropout → Add → output

    The wrapper applies no normalization of its own.
    NOTE(review): any pre-norm presumably lives inside SSMBlock / BitLinear —
    confirm against those modules.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``d_model``, ``d_ff`` and ``dropout``;
                it is also handed unchanged to ``SSMBlock``.
        """
        super().__init__()
        self.ssm = SSMBlock(config)
        # Two-layer ReLU MLP: d_model → d_ff → d_model, both layers bias-free.
        ffn_layers = [
            BitLinear(config.d_model, config.d_ff, bias=False),
            nn.ReLU(),
            BitLinear(config.d_ff, config.d_model, bias=False),
        ]
        self.feed_forward = nn.Sequential(*ffn_layers)
        # One dropout module shared by both residual branches.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """
        Run the SSM branch and then the FFN branch, each added residually.

        Args:
            x: input activations.
            mask: optional mask, forwarded positionally to the SSM block.

        Returns:
            Tensor produced by the two residual additions.
        """
        # SSM branch with residual connection.
        ssm_branch = self.dropout(self.ssm(x, mask))
        hidden = x + ssm_branch
        # FFN branch with residual connection.
        ffn_branch = self.dropout(self.feed_forward(hidden))
        return hidden + ffn_branch
class MLABlockWrapper(nn.Module):
    """
    MLA block with residual connections and a three-layer ReLU FFN.

    Data flow implemented by this wrapper:
        x → MLA → Dropout → Add → FFN → Dropout → Add → output

    The wrapper applies no normalization of its own.
    NOTE(review): any pre-norm presumably lives inside MemoryOptimizedMLA —
    confirm against that module.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``d_model``, ``d_ff`` and ``dropout``;
                it is also handed unchanged to ``MemoryOptimizedMLA``.
        """
        super().__init__()
        self.mla = MemoryOptimizedMLA(config)
        self.ffn = self._build_ffn(config)
        # One dropout module shared by both residual branches.
        self.dropout = nn.Dropout(config.dropout)

    @staticmethod
    def _build_ffn(config):
        """Build the bias-free d_model → d_ff → d_model → d_model ReLU stack."""
        return nn.Sequential(
            nn.Linear(config.d_model, config.d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(config.d_ff, config.d_model, bias=False),
            nn.ReLU(),
            # The preceding layer emits d_model features, so this layer must
            # accept d_model (not d_ff) inputs; otherwise the stack only runs
            # when d_ff == d_model. Weight shapes are identical in that case,
            # so existing checkpoints still load.
            nn.Linear(config.d_model, config.d_model, bias=False),
        )

    def forward(self, x, mask=None):
        """
        Run the MLA branch and then the FFN branch, each added residually.

        Args:
            x: input activations whose last dimension is ``d_model``.
            mask: optional attention mask, passed to the MLA module as
                the ``mask`` keyword argument.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # MLA branch with residual connection.
        x = x + self.dropout(self.mla(x, mask=mask))
        # FFN branch with residual connection.
        x = x + self.dropout(self.ffn(x))
        return x
class SimpleLLM(nn.Module):
    """
    Language model built from a hybrid stack of SSM and MLA blocks.

    Architecture: Token Embedding → (SSM Blocks + MLA Blocks) → Output Head

    Block ordering is selected by ``config.block_arrangement``:
      - "interleaving": ``config.ssm_per_mla`` SSM blocks, then one MLA block,
        repeated, e.g.
          - ssm_per_mla = 2: SSM, SSM, MLA, SSM, SSM, MLA, ...
          - ssm_per_mla = 3: SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA, ...
      - "layered": ``config.layered_mla_num`` MLA blocks first, SSM blocks after.
    """

    def __init__(self, config):
        """
        Build the embedding, the hybrid block stack and the output head.

        Args:
            config: object exposing at least ``vocab_size``, ``d_model``,
                ``d_embed_rank``, ``dropout``, ``n_layers``,
                ``block_arrangement`` and the arrangement-specific fields
                (``ssm_per_mla`` or ``layered_mla_num``). The architecture
                summary additionally reads ``d_state``, ``n_heads`` and
                ``d_kv_comp``.

        Raises:
            ValueError: if ``config.block_arrangement`` is neither
                "interleaving" nor "layered".

        Side effects:
            Prints an architecture summary to stdout.
        """
        super().__init__()
        self.config = config

        # Factorized embeddings
        self.token_embedding = FactorizedEmbedding(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            d_embed_rank=config.d_embed_rank,
        )
        self.dropout = nn.Dropout(config.dropout)

        # Build block architecture based on arrangement strategy
        self.blocks = nn.ModuleList()
        if config.block_arrangement == "interleaving":
            self._build_interleaving_blocks(config)
        elif config.block_arrangement == "layered":
            self._build_layered_blocks(config)
        else:
            raise ValueError(f"Unknown block_arrangement: {config.block_arrangement}")

        # Two-stage output projection (mirrors the factorized embedding).
        # Stage 1: d_model → d_embed_rank
        self.output_proj = nn.Linear(config.d_model, config.d_embed_rank, bias=False)
        # Stage 2: d_embed_rank → vocab_size, sharing the embedding table
        self.lm_head = nn.Linear(config.d_embed_rank, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.embed.weight

        # Layer norms applied before each stage of the output projection
        self.pre_final_norm = nn.LayerNorm(config.d_model)
        self.final_norm = nn.LayerNorm(config.d_embed_rank)

        self.apply(self._init_weights)
        # Non-persistent: the mask cache is rebuilt on demand and never saved.
        self.register_buffer("causal_mask_cache", None, persistent=False)
        self._print_architecture()

    def _build_interleaving_blocks(self, config):
        """
        Append blocks in the pattern "ssm_per_mla SSM blocks, then one MLA".

        Example with ssm_per_mla=3 and n_layers=16:
            SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA,
            SSM, SSM, SSM, MLA

        At least one group is attempted even when ``n_layers`` is smaller
        than a full group; the total never exceeds ``config.n_layers`` and
        any leftover slots are filled with SSM blocks.
        """
        ssm_per_mla = config.ssm_per_mla
        num_mla_blocks = max(1, config.n_layers // (ssm_per_mla + 1))
        block_idx = 0
        for _ in range(num_mla_blocks):
            # SSM blocks that precede this group's MLA block
            for _ in range(ssm_per_mla):
                if block_idx < config.n_layers:
                    self.blocks.append(SSMBlockWrapper(config))
                    block_idx += 1
            # The group's MLA block
            if block_idx < config.n_layers:
                self.blocks.append(MLABlockWrapper(config))
                block_idx += 1
        # Remaining SSM blocks (n_layers not evenly divisible by the group size)
        while block_idx < config.n_layers:
            self.blocks.append(SSMBlockWrapper(config))
            block_idx += 1

    def _build_layered_blocks(self, config):
        """
        Append ``layered_mla_num`` MLA blocks followed by SSM blocks.

        Example with layered_mla_num=4 and n_layers=16:
            MLA, MLA, MLA, MLA, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM,
            SSM, SSM, SSM, SSM

        The MLA count is capped at ``config.n_layers``.
        """
        num_mla = config.layered_mla_num
        for _ in range(min(num_mla, config.n_layers)):
            self.blocks.append(MLABlockWrapper(config))
        # Whatever is left of the layer budget becomes SSM blocks
        num_ssm = config.n_layers - len(self.blocks)
        for _ in range(num_ssm):
            self.blocks.append(SSMBlockWrapper(config))

    def _init_weights(self, module):
        """Initialise plain Linear and Embedding weights with N(0, 0.02).

        BitLinear modules are skipped so they keep their own initialisation;
        Linear biases, when present, are zeroed.
        """
        if isinstance(module, nn.Linear) and not isinstance(module, BitLinear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _print_architecture(self):
        """Print a parameter and layout summary of the model to stdout."""
        total_params = self.count_parameters()
        embed_params = self.token_embedding.get_num_params()
        output_proj_params = self.config.d_model * self.config.d_embed_rank
        # Count SSM and MLA blocks
        num_ssm = sum(1 for b in self.blocks if isinstance(b, SSMBlockWrapper))
        num_mla = sum(1 for b in self.blocks if isinstance(b, MLABlockWrapper))
        print(f"\n{'='*60}")
        print("MODEL ARCHITECTURE - HYBRID SSM + MLA")
        print(f"{'='*60}")
        print(f"Embedding: {embed_params/1e6:>6.2f}M params")
        print(f"Hybrid Blocks: {num_ssm} SSM + {num_mla} MLA = {num_ssm + num_mla} total")
        print(f"Output Proj: {output_proj_params/1e6:>6.2f}M params")
        print(f"Output Head: tied to embedding (0 extra params)")
        print(f"{'─'*60}")
        print(f"Total: {total_params/1e6:>6.2f}M params")
        print(f"{'='*60}")
        print(f"Config: {self.config.n_layers} layers, {self.config.d_model} dim")
        print(f"SSM: d_state={self.config.d_state}")
        print(f"MLA: n_heads={self.config.n_heads}, d_kv_comp={self.config.d_kv_comp}")
        # Print arrangement-specific info
        if self.config.block_arrangement == "interleaving":
            print(f"Arrangement: INTERLEAVING (ssm_per_mla={self.config.ssm_per_mla})")
        elif self.config.block_arrangement == "layered":
            print(f"Arrangement: LAYERED (mla_blocks={self.config.layered_mla_num}, ssm_blocks={num_ssm})")
        print(f"{'='*60}\n")

    def _get_causal_mask(self, seq_len, device):
        """
        Return a lower-triangular mask of shape [1, 1, seq_len, seq_len].

        The mask is cached in the ``causal_mask_cache`` buffer and rebuilt
        when the cache is empty, too small, or on a different device than
        the one requested.

        Args:
            seq_len: number of positions to cover.
            device: ``torch.device`` (or device string) for the mask.
        """
        cache = self.causal_mask_cache
        stale = (
            cache is None
            or cache.size(-1) < seq_len
            or cache.device != torch.device(device)
        )
        if stale:
            mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
            self.causal_mask_cache = mask.unsqueeze(0).unsqueeze(0)
        return self.causal_mask_cache[:, :, :seq_len, :seq_len]

    def forward(self, input_ids, attention_mask=None):
        """
        Compute next-token logits for a batch of sequences.

        Args:
            input_ids: 2-D tensor of token ids, [batch, seq_len].
            attention_mask: optional tensor that is expanded to
                [batch, 1, 1, seq_len] and multiplied into the causal mask
                (presumably 1 = real token, 0 = padding — TODO confirm
                against the MLA/SSM mask handling).

        Returns:
            Logits over the vocabulary for every position.
        """
        # Unpacking also enforces that input_ids is 2-D.
        _, seq_len = input_ids.shape

        # Causal mask, optionally combined with the padding mask
        causal_mask = self._get_causal_mask(seq_len, input_ids.device)
        if attention_mask is not None:
            padding_mask = attention_mask.unsqueeze(1).unsqueeze(1)
            causal_mask = causal_mask * padding_mask

        # Token embedding
        x = self.token_embedding(input_ids)
        x = self.dropout(x)
        x = ActivationQuantize.apply(x)

        # Hybrid SSM + MLA blocks
        for block in self.blocks:
            x = block(x, causal_mask)

        # Two-stage output projection
        x = self.pre_final_norm(x)
        x = self.output_proj(x)  # d_model → d_embed_rank
        x = self.final_norm(x)  # Normalize before output head
        logits = self.lm_head(x)  # d_embed_rank → vocab_size
        return logits

    def init_ssm_states(self, batch_size, device, dtype):
        """
        Initialize SSM states for all SSM blocks (MLA blocks are skipped).

        Returns:
            states: list with one state per SSM block, in block order, as
                produced by each block's ``ssm.init_state``.
        """
        return [
            block.ssm.init_state(batch_size, device, dtype)
            for block in self.blocks
            if isinstance(block, SSMBlockWrapper)
        ]

    def inference_step(self, input_id, states, return_hidden_states=False):
        """
        Single inference step for autoregressive generation (RNN-like).

        Args:
            input_id: Python int, 0-dim tensor, [batch] tensor or
                [batch, 1] tensor holding the current token id(s).
            states: list of SSM states from the previous step, one per SSM
                block in block order (see ``init_ssm_states``).
            return_hidden_states: if True, also return detached copies of
                the updated SSM states.

        Returns:
            logits: [batch, vocab_size] output logits for the next token.
            new_states: list of updated SSM states, one per SSM block.
            hidden_states: (only when requested) detached clones of the
                updated states.

        NOTE: MLA blocks are run on the current token alone with no mask and
        no cache of earlier tokens, so for models containing MLA blocks this
        step does not reproduce what ``forward`` computes on the full
        sequence.
        """
        # Normalise every accepted input form to a [batch, 1] tensor.
        if isinstance(input_id, int):
            input_id = torch.tensor(
                [[input_id]], dtype=torch.long, device=next(self.parameters()).device
            )
        elif input_id.dim() == 0:
            input_id = input_id.reshape(1, 1)
        elif input_id.dim() == 1:
            # One token per batch row: [batch] → [batch, 1]
            input_id = input_id.unsqueeze(1)

        # Embed the token
        x = self.token_embedding(input_id)  # [batch, 1, d_model]
        x = x.squeeze(1)  # [batch, d_model]
        x = ActivationQuantize.apply(x)

        # Pass through hybrid blocks
        new_states = []
        hidden_states = [] if return_hidden_states else None
        state_idx = 0  # Position in `states`; advances only on SSM blocks
        for block in self.blocks:
            if isinstance(block, SSMBlockWrapper):
                # SSM branch with explicit state hand-over
                ssm_out, new_state = block.ssm.step(x, states[state_idx])
                if return_hidden_states:
                    hidden_states.append(new_state.clone().detach())
                x = x + block.dropout(ssm_out)
                # FFN + residual
                x = x + block.dropout(block.feed_forward(x))
                new_states.append(new_state)
                state_idx += 1
            else:
                # MLA block (stateless here — see NOTE in the docstring)
                x = block(x.unsqueeze(1), mask=None).squeeze(1)

        # Output projection
        x = self.pre_final_norm(x)
        x = self.output_proj(x)
        x = self.final_norm(x)
        logits = self.lm_head(x)

        if return_hidden_states:
            return logits, new_states, hidden_states
        return logits, new_states

    def count_parameters(self):
        """Return the number of trainable (requires_grad) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_non_embedding_parameters(self):
        """Return trainable parameters minus the embedding's own count."""
        total = self.count_parameters()
        embedding_params = self.token_embedding.get_num_params()
        return total - embedding_params

    @staticmethod
    def _apply_repetition_penalty(logits, input_ids, repetition_penalty):
        """Penalise logits of tokens already present in each row of input_ids.

        Positive logits are divided by the penalty and non-positive logits
        are multiplied by it. Returns a new tensor.
        """
        token_ids = input_ids.long()  # gather/scatter require int64 indices
        seen = torch.gather(logits, 1, token_ids)
        seen = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)
        # Duplicate ids all carry the same penalised value, so write order
        # does not matter.
        return logits.scatter(1, token_ids, seen)

    @staticmethod
    def _top_k_filter(logits, top_k):
        """Set every logit below the row's k-th largest value to -inf."""
        k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, k).values[:, [-1]]
        return logits.masked_fill(logits < kth_value, float('-inf'))

    @staticmethod
    def _top_p_filter(logits, top_p):
        """Keep the smallest prefix of tokens whose probability mass exceeds top_p.

        The token that crosses the threshold is kept, and the most likely
        token is always kept. Everything else is set to -inf.
        """
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove_sorted = cumulative_probs > top_p
        # Shift right by one so the first token past the threshold survives.
        remove_sorted[:, 1:] = remove_sorted[:, :-1].clone()
        remove_sorted[:, 0] = False
        # Map the removal flags from sorted order back to vocabulary order.
        remove = torch.zeros_like(remove_sorted).scatter(1, sorted_indices, remove_sorted)
        return logits.masked_fill(remove, float('-inf'))

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=50,
        temperature=1.0,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True
    ):
        """
        Generate tokens autoregressively.

        Each step re-runs ``forward`` on the last ``config.max_seq_len``
        tokens, then applies temperature, repetition penalty, top-k and
        top-p filtering before sampling (or taking the argmax).

        Args:
            input_ids: [batch, seq_len] prompt token ids.
            max_new_tokens: maximum number of tokens appended per row.
            temperature: logit divisor, clamped below at 1e-5.
            top_k: keep only the k largest logits (None or <= 0 disables).
            top_p: nucleus threshold (None or >= 1.0 disables).
            repetition_penalty: penalty for tokens already in the sequence
                (1.0 disables).
            do_sample: sample from the distribution if True, else greedy.

        Returns:
            ``input_ids`` with the generated tokens appended. Once a row
            has produced ``config.eos_token_id`` it is padded with that id
            until every row has finished, at which point generation stops.

        Side effects:
            Switches the module to eval mode and leaves it there.
        """
        self.eval()
        eos_token_id = self.config.eos_token_id
        # Rows that have already emitted EOS.
        finished = torch.zeros(
            input_ids.shape[0], dtype=torch.bool, device=input_ids.device
        )

        for _ in range(max_new_tokens):
            # Crop to max_seq_len
            idx_cond = input_ids[:, -self.config.max_seq_len:]

            # Forward; keep only the last position's logits
            logits = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)

            if repetition_penalty != 1.0:
                logits = self._apply_repetition_penalty(
                    logits, input_ids, repetition_penalty
                )
            if top_k is not None and top_k > 0:
                logits = self._top_k_filter(logits, top_k)
            if top_p is not None and top_p < 1.0:
                logits = self._top_p_filter(logits, top_p)

            # Sample or greedy
            probs = F.softmax(logits, dim=-1)
            if do_sample:
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(probs, dim=-1, keepdim=True)

            if eos_token_id is not None:
                # Finished rows keep emitting EOS instead of new content.
                next_token = next_token.masked_fill(finished.unsqueeze(1), eos_token_id)
                finished = finished | (next_token.squeeze(1) == eos_token_id)

            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop once every row has produced EOS
            if eos_token_id is not None and finished.all():
                break

        return input_ids

    def get_num_params(self, non_embedding=True):
        """Return the trainable parameter count, excluding the embedding by default."""
        if non_embedding:
            return self.count_non_embedding_parameters()
        return self.count_parameters()