| | """
|
| | SimpleLLM - Mamba-style State-Space Model with ternary quantization.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn. functional as F
|
| |
|
| | from .ssm import SSMBlock
|
| | from .bitlinear import BitLinear, RMSNorm, ActivationQuantize
|
| | from .factorized_embedding import FactorizedEmbedding
|
| | from .mla import MemoryOptimizedMLA
|
| |
|
class SSMBlockWrapper(nn.Module):
    """Residual SSM block: SSM token mixing followed by a ternary FFN.

    For input ``x`` this computes:
        x = x + Dropout(SSM(x, mask))
        x = x + Dropout(FFN(x))
    where FFN is BitLinear(d_model -> d_ff) -> ReLU -> BitLinear(d_ff -> d_model).
    (Normalization, if any, lives inside SSMBlock / BitLinear — TODO confirm.)
    """

    def __init__(self, config):
        super().__init__()
        self.ssm = SSMBlock(config)
        # Position-wise feed-forward built from ternary-quantized linears.
        self.feed_forward = nn.Sequential(
            BitLinear(config.d_model, config.d_ff, bias=False),
            nn.ReLU(),
            BitLinear(config.d_ff, config.d_model, bias=False),
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        # Token mixing with residual connection.
        mixed = self.ssm(x, mask)
        x = x + self.dropout(mixed)
        # Channel mixing with residual connection.
        expanded = self.feed_forward(x)
        return x + self.dropout(expanded)
|
| |
|
| |
|
class MLABlockWrapper(nn.Module):
    """Residual MLA (Memory-Optimized Latent Attention) block with a position-wise FFN.

    For input ``x`` this computes:
        x = x + Dropout(MLA(x, mask))
        x = x + Dropout(FFN(x))
    """

    def __init__(self, config):
        super().__init__()
        self.mla = MemoryOptimizedMLA(config)
        # BUGFIX: the previous ffn chained a duplicated ReLU +
        # Linear(d_ff, d_model) after the projection back to d_model, so the
        # third Linear received a d_model-sized input while expecting d_ff —
        # a guaranteed shape mismatch whenever d_ff != d_model. The standard
        # two-layer FFN (d_model -> d_ff -> d_model) is restored here.
        self.ffn = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(config.d_ff, config.d_model, bias=False),
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        # Attention sub-layer with residual connection.
        x = x + self.dropout(self.mla(x, mask=mask))
        # Feed-forward sub-layer with residual connection.
        x = x + self.dropout(self.ffn(x))
        return x
|
| |
|
| |
|
class SimpleLLM(nn.Module):
    """
    Language model built from a hybrid stack of Mamba-style SSM and MLA blocks.

    Architecture: Token Embedding -> (SSM blocks + MLA blocks) -> Output Head.

    The mix of block types is controlled by config.block_arrangement:
        - "interleaving": runs of ssm_per_mla SSM blocks separated by MLA blocks
        - "layered": layered_mla_num MLA blocks first, SSM blocks after
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Factorized embedding (vocab -> rank -> d_model) keeps the table small.
        self.token_embedding = FactorizedEmbedding(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            d_embed_rank=config.d_embed_rank,
        )

        self.dropout = nn.Dropout(config.dropout)

        # Build the hybrid block stack via a small dispatch table.
        self.blocks = nn.ModuleList()
        builder = {
            "interleaving": self._build_interleaving_blocks,
            "layered": self._build_layered_blocks,
        }.get(config.block_arrangement)
        if builder is None:
            raise ValueError(f"Unknown block_arrangement: {config.block_arrangement}")
        builder(config)

        # Project d_model back down to the embedding rank so the LM head can
        # share weights with the factorized embedding table.
        self.output_proj = nn.Linear(config.d_model, config.d_embed_rank, bias=False)
        self.lm_head = nn.Linear(config.d_embed_rank, config.vocab_size, bias=False)
        # Weight tying: the LM head reuses the embedding matrix (0 extra params).
        self.lm_head.weight = self.token_embedding.embed.weight

        self.pre_final_norm = nn.LayerNorm(config.d_model)
        self.final_norm = nn.LayerNorm(config.d_embed_rank)

        self.apply(self._init_weights)
        # Non-persistent cache for the causal mask (rebuilt lazily in forward).
        self.register_buffer("causal_mask_cache", None, persistent=False)
        self._print_architecture()
|
| |
|
| | def _build_interleaving_blocks(self, config):
|
| | """
|
| | Build interleaving block arrangement: SSM blocks followed by MLA blocks in a pattern.
|
| |
|
| | Example with ssm_per_mla=3 and n_layers=16:
|
| | SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA
|
| | """
|
| | ssm_per_mla = config.ssm_per_mla
|
| | num_mla_blocks = max(1, config.n_layers // (ssm_per_mla + 1))
|
| |
|
| | block_idx = 0
|
| | for mla_idx in range(num_mla_blocks):
|
| |
|
| | for _ in range(ssm_per_mla):
|
| | if block_idx < config.n_layers:
|
| | self.blocks.append(SSMBlockWrapper(config))
|
| | block_idx += 1
|
| |
|
| |
|
| | if block_idx < config.n_layers:
|
| | self.blocks.append(MLABlockWrapper(config))
|
| | block_idx += 1
|
| |
|
| |
|
| | while block_idx < config.n_layers:
|
| | self.blocks.append(SSMBlockWrapper(config))
|
| | block_idx += 1
|
| |
|
| | def _build_layered_blocks(self, config):
|
| | """
|
| | Build layered block arrangement: MLA blocks followed by SSM blocks.
|
| |
|
| | Example with layered_mla_num=4 and n_layers=16:
|
| | MLA, MLA, MLA, MLA, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM
|
| | """
|
| | num_mla = config.layered_mla_num
|
| |
|
| |
|
| | for _ in range(min(num_mla, config.n_layers)):
|
| | self.blocks.append(MLABlockWrapper(config))
|
| |
|
| |
|
| | num_ssm = config.n_layers - len(self.blocks)
|
| | for _ in range(num_ssm):
|
| | self.blocks.append(SSMBlockWrapper(config))
|
| |
|
| | def _init_weights(self, module):
|
| | if isinstance(module, nn.Linear) and not isinstance(module, BitLinear):
|
| | nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| | if module. bias is not None:
|
| | nn.init.zeros_(module.bias)
|
| | elif isinstance(module, nn.Embedding):
|
| | nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| |
|
| | def _print_architecture(self):
|
| | total_params = self.count_parameters()
|
| | embed_params = self.token_embedding.get_num_params()
|
| | output_proj_params = self.config.d_model * self.config.d_embed_rank
|
| | ssm_params = total_params - embed_params - output_proj_params
|
| |
|
| |
|
| | num_ssm = sum(1 for b in self.blocks if isinstance(b, SSMBlockWrapper))
|
| | num_mla = sum(1 for b in self.blocks if isinstance(b, MLABlockWrapper))
|
| |
|
| | print(f"\n{'='*60}")
|
| | print("MODEL ARCHITECTURE - HYBRID SSM + MLA")
|
| | print(f"{'='*60}")
|
| | print(f"Embedding: {embed_params/1e6:>6.2f}M params")
|
| | print(f"Hybrid Blocks: {num_ssm} SSM + {num_mla} MLA = {num_ssm + num_mla} total")
|
| | print(f"Output Proj: {output_proj_params/1e6:>6.2f}M params")
|
| | print(f"Output Head: tied to embedding (0 extra params)")
|
| | print(f"{'─'*60}")
|
| | print(f"Total: {total_params/1e6:>6.2f}M params")
|
| | print(f"{'='*60}")
|
| | print(f"Config: {self.config.n_layers} layers, {self.config.d_model} dim")
|
| | print(f"SSM: d_state={self.config.d_state}")
|
| | print(f"MLA: n_heads={self.config.n_heads}, d_kv_comp={self.config.d_kv_comp}")
|
| |
|
| |
|
| | if self.config.block_arrangement == "interleaving":
|
| | print(f"Arrangement: INTERLEAVING (ssm_per_mla={self.config.ssm_per_mla})")
|
| | elif self.config.block_arrangement == "layered":
|
| | print(f"Arrangement: LAYERED (mla_blocks={self.config.layered_mla_num}, ssm_blocks={num_ssm})")
|
| |
|
| | print(f"{'='*60}\n")
|
| |
|
| | def _get_causal_mask(self, seq_len, device):
|
| | if self.causal_mask_cache is None or self.causal_mask_cache. size(-1) < seq_len:
|
| | mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
|
| | mask = mask.unsqueeze(0).unsqueeze(0)
|
| | self.causal_mask_cache = mask
|
| | return self.causal_mask_cache[: , :, :seq_len, :seq_len]
|
| |
|
| | def forward(self, input_ids, attention_mask=None):
|
| | batch_size, seq_len = input_ids.shape
|
| |
|
| |
|
| | causal_mask = self._get_causal_mask(seq_len, input_ids.device)
|
| | if attention_mask is not None:
|
| | padding_mask = attention_mask.unsqueeze(1).unsqueeze(1)
|
| | causal_mask = causal_mask * padding_mask
|
| |
|
| |
|
| | x = self.token_embedding(input_ids)
|
| | x = self.dropout(x)
|
| | x = ActivationQuantize.apply(x)
|
| |
|
| |
|
| | for block in self.blocks:
|
| | x = block(x, causal_mask)
|
| |
|
| |
|
| | x = self.pre_final_norm(x)
|
| | x = self.output_proj(x)
|
| | x = self.final_norm(x)
|
| | logits = self.lm_head(x)
|
| |
|
| | return logits
|
| |
|
| | def init_ssm_states(self, batch_size, device, dtype):
|
| | """
|
| | Initialize SSM states for all SSM blocks (MLA blocks are stateless).
|
| |
|
| | Returns:
|
| | states: List of [batch, d_state] tensors for each SSM block
|
| | """
|
| | states = []
|
| | for block in self.blocks:
|
| | if isinstance(block, SSMBlockWrapper):
|
| | state = block.ssm.init_state(batch_size, device, dtype)
|
| | states.append(state)
|
| | return states
|
| |
|
| | def inference_step(self, input_id, states, return_hidden_states=False):
|
| | """
|
| | Single inference step for autoregressive generation (RNN-like).
|
| |
|
| | Args:
|
| | input_id: [batch, 1] or scalar token id
|
| | states: List of SSM states from previous step
|
| | return_hidden_states: If True, also return SSM hidden states for visualization
|
| |
|
| | Returns:
|
| | logits: [batch, vocab_size] - output logits for next token
|
| | new_states: List of updated SSM states for SSM blocks
|
| | hidden_states: (Optional) List of SSM hidden state values for each SSM layer
|
| | """
|
| | if isinstance(input_id, int):
|
| | input_id = torch.tensor([[input_id]], dtype=torch.long, device=next(self.parameters()).device)
|
| | elif input_id.dim() == 1:
|
| | input_id = input_id.unsqueeze(0)
|
| |
|
| |
|
| | x = self.token_embedding(input_id)
|
| | x = x.squeeze(1)
|
| | x = ActivationQuantize.apply(x)
|
| |
|
| |
|
| | new_states = []
|
| | hidden_states = [] if return_hidden_states else None
|
| | state_idx = 0
|
| |
|
| | for block in self.blocks:
|
| | if isinstance(block, SSMBlockWrapper):
|
| |
|
| | residual = x
|
| | ssm_out, new_state = block.ssm.step(x, states[state_idx])
|
| |
|
| |
|
| | if return_hidden_states:
|
| | hidden_states.append(new_state.clone().detach())
|
| |
|
| | x = residual + block.dropout(ssm_out)
|
| |
|
| |
|
| | residual = x
|
| | ffn_out = block.feed_forward(x)
|
| | x = residual + block.dropout(ffn_out)
|
| |
|
| | new_states.append(new_state)
|
| | state_idx += 1
|
| | else:
|
| |
|
| | x = block(x.unsqueeze(1), mask=None).squeeze(1)
|
| |
|
| |
|
| | x = self.pre_final_norm(x)
|
| | x = self.output_proj(x)
|
| | x = self.final_norm(x)
|
| | logits = self.lm_head(x)
|
| |
|
| | if return_hidden_states:
|
| | return logits, new_states, hidden_states
|
| | else:
|
| | return logits, new_states
|
| |
|
| | def count_parameters(self):
|
| | return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| |
|
| | def count_non_embedding_parameters(self):
|
| | total = self.count_parameters()
|
| | embedding_params = self.token_embedding.get_num_params()
|
| | return total - embedding_params
|
| |
|
| | @torch.no_grad()
|
| | def generate(
|
| | self,
|
| | input_ids,
|
| | max_new_tokens=50,
|
| | temperature=1.0,
|
| | top_k=50,
|
| | top_p=0.9,
|
| | repetition_penalty=1.1,
|
| | do_sample=True
|
| | ):
|
| | """Generate tokens autoregressively."""
|
| | self.eval()
|
| |
|
| | for _ in range(max_new_tokens):
|
| |
|
| | idx_cond = input_ids[:, -self.config.max_seq_len:]
|
| |
|
| |
|
| | logits = self(idx_cond)
|
| | logits = logits[:, -1, : ] / max(temperature, 1e-5)
|
| |
|
| |
|
| | if repetition_penalty != 1.0:
|
| | for i in range(input_ids.shape[0]):
|
| | for token_id in set(input_ids[i].tolist()):
|
| | if logits[i, token_id] > 0:
|
| | logits[i, token_id] /= repetition_penalty
|
| | else:
|
| | logits[i, token_id] *= repetition_penalty
|
| |
|
| |
|
| | if top_k is not None and top_k > 0:
|
| | v, _ = torch.topk(logits, min(top_k, logits. size(-1)))
|
| | logits[logits < v[:, [-1]]] = float('-inf')
|
| |
|
| |
|
| | if top_p is not None and top_p < 1.0:
|
| | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| |
|
| | sorted_indices_to_remove = cumulative_probs > top_p
|
| | sorted_indices_to_remove[: , 1:] = sorted_indices_to_remove[:, :-1].clone()
|
| | sorted_indices_to_remove[:, 0] = 0
|
| |
|
| | for i in range(logits.shape[0]):
|
| | indices_to_remove = sorted_indices[i, sorted_indices_to_remove[i]]
|
| | logits[i, indices_to_remove] = float('-inf')
|
| |
|
| |
|
| | probs = F.softmax(logits, dim=-1)
|
| | if do_sample:
|
| | next_token = torch.multinomial(probs, num_samples=1)
|
| | else:
|
| | next_token = torch.argmax(probs, dim=-1, keepdim=True)
|
| |
|
| | input_ids = torch. cat([input_ids, next_token], dim=1)
|
| |
|
| |
|
| | if self.config.eos_token_id is not None:
|
| | if (next_token == self.config. eos_token_id).all():
|
| | break
|
| |
|
| | return input_ids
|
| |
|
| | def get_num_params(self, non_embedding=True):
|
| | if non_embedding:
|
| | return self.count_non_embedding_parameters()
|
| | return self.count_parameters()
|
| |
|