"""
SimpleLLM - Mamba-style State-Space Model with ternary quantization.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .ssm import SSMBlock
from .bitlinear import BitLinear, RMSNorm, ActivationQuantize
from .factorized_embedding import FactorizedEmbedding
from .mla import MemoryOptimizedMLA


class SSMBlockWrapper(nn.Module):
    """
    SSM block (Mamba-style) with residual connections and a ternary FFN.

    Structure: x → SSM → Dropout → Add → FFN → Dropout → Add → output

    NOTE(review): earlier documentation described a pre-norm block
    (x → Norm → SSM → Add → Norm → FFN → Add), but this wrapper owns no
    normalization layers (``RMSNorm`` is imported by this module yet unused).
    Either ``SSMBlock`` normalizes internally or the norms were dropped —
    confirm. Adding norm layers here would change the state_dict layout.
    """

    def __init__(self, config):
        super().__init__()
        self.ssm = SSMBlock(config)
        # Ternary-quantized (BitLinear) FFN: d_model → d_ff → d_model.
        self.feed_forward = nn.Sequential(
            BitLinear(config.d_model, config.d_ff, bias=False),
            nn.ReLU(),
            BitLinear(config.d_ff, config.d_model, bias=False),
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: hidden states; presumably [batch, seq_len, d_model] — confirm
               against SSMBlock.
            mask: optional mask, forwarded unchanged to the SSM.

        Returns:
            Tensor with the same shape as ``x`` (residual additions).
        """
        # SSM sub-layer with residual connection
        x = x + self.dropout(self.ssm(x, mask))
        # FFN sub-layer with residual connection
        x = x + self.dropout(self.feed_forward(x))
        return x


class MLABlockWrapper(nn.Module):
    """
    MLA block with residual connections and a full-precision FFN.

    Structure: x → MLA → Dropout → Add → FFN → Dropout → Add → output

    NOTE(review): earlier documentation described a pre-norm block, but this
    wrapper owns no normalization layers — confirm whether MemoryOptimizedMLA
    normalizes internally.
    """

    def __init__(self, config):
        super().__init__()
        self.mla = MemoryOptimizedMLA(config)
        # Exactly two projections: d_model → d_ff → d_model. A third
        # Linear(d_ff, d_model) after the second projection would receive a
        # d_model-wide input and fail whenever d_ff != d_model.
        self.ffn = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(config.d_ff, config.d_model, bias=False),
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: hidden states; presumably [batch, seq_len, d_model] — confirm.
            mask: optional attention mask passed to the MLA module.

        Returns:
            Tensor with the same shape as ``x`` (residual additions).
        """
        # MLA sub-layer with residual connection
        x = x + self.dropout(self.mla(x, mask=mask))
        # FFN sub-layer with residual connection
        x = x + self.dropout(self.ffn(x))
        return x


class SimpleLLM(nn.Module):
    """
    Language Model with Hybrid Mamba-style SSM + MLA blocks.

    Architecture:
        Token Embedding → (SSM Blocks + MLA Blocks) → Output Head

    Hybrid structure controlled by config.ssm_per_mla:
        - ssm_per_mla = 2: SSM, SSM, MLA, SSM, SSM, MLA, ...
        - ssm_per_mla = 3: SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA, ...
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Factorized embeddings
        self.token_embedding = FactorizedEmbedding(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            d_embed_rank=config.d_embed_rank,
        )
        self.dropout = nn.Dropout(config.dropout)

        # Build block architecture based on arrangement strategy
        self.blocks = nn.ModuleList()
        if config.block_arrangement == "interleaving":
            self._build_interleaving_blocks(config)
        elif config.block_arrangement == "layered":
            self._build_layered_blocks(config)
        else:
            raise ValueError(f"Unknown block_arrangement: {config.block_arrangement}")

        # =================================================================
        # Two-stage output projection (mirrors factorized embedding)
        # =================================================================
        # Stage 1: d_model → d_embed_rank (reverse of embedding projection)
        self.output_proj = nn.Linear(config.d_model, config.d_embed_rank, bias=False)
        # Stage 2: d_embed_rank → vocab_size (tied to embedding table)
        self.lm_head = nn.Linear(config.d_embed_rank, config.vocab_size, bias=False)
        # Tie lm_head weights to embedding table (shared Parameter object)
        self.lm_head.weight = self.token_embedding.embed.weight
        # =================================================================

        # Layer norms around the output projection to stabilize predictions
        self.pre_final_norm = nn.LayerNorm(config.d_model)
        self.final_norm = nn.LayerNorm(config.d_embed_rank)

        self.apply(self._init_weights)
        # Lazily built causal mask; non-persistent so it is not saved in checkpoints
        self.register_buffer("causal_mask_cache", None, persistent=False)
        self._print_architecture()

    def _build_interleaving_blocks(self, config):
        """
        Build interleaving block arrangement: SSM blocks followed by MLA blocks in a pattern.

        Example with ssm_per_mla=3 and n_layers=16:
            SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA, SSM, SSM, SSM, MLA
        """
        ssm_per_mla = config.ssm_per_mla
        num_mla_blocks = max(1, config.n_layers // (ssm_per_mla + 1))

        block_idx = 0
        for _ in range(num_mla_blocks):
            # Add SSM blocks before each MLA block
            for _ in range(ssm_per_mla):
                if block_idx < config.n_layers:
                    self.blocks.append(SSMBlockWrapper(config))
                    block_idx += 1

            # Add MLA block
            if block_idx < config.n_layers:
                self.blocks.append(MLABlockWrapper(config))
                block_idx += 1

        # Add remaining SSM blocks (if n_layers is not evenly divisible)
        while block_idx < config.n_layers:
            self.blocks.append(SSMBlockWrapper(config))
            block_idx += 1

    def _build_layered_blocks(self, config):
        """
        Build layered block arrangement: MLA blocks followed by SSM blocks.

        Example with layered_mla_num=4 and n_layers=16:
            MLA, MLA, MLA, MLA, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM, SSM
        """
        num_mla = config.layered_mla_num

        # Add MLA blocks first
        for _ in range(min(num_mla, config.n_layers)):
            self.blocks.append(MLABlockWrapper(config))

        # Add remaining SSM blocks
        num_ssm = config.n_layers - len(self.blocks)
        for _ in range(num_ssm):
            self.blocks.append(SSMBlockWrapper(config))

    def _init_weights(self, module):
        """Normal(0, 0.02) init for plain Linear/Embedding weights; BitLinear keeps its own init."""
        if isinstance(module, nn.Linear) and not isinstance(module, BitLinear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _print_architecture(self):
        """Print a parameter / block-layout summary of the model to stdout."""
        total_params = self.count_parameters()
        embed_params = self.token_embedding.get_num_params()
        output_proj_params = self.config.d_model * self.config.d_embed_rank

        # Count SSM and MLA blocks
        num_ssm = sum(1 for b in self.blocks if isinstance(b, SSMBlockWrapper))
        num_mla = sum(1 for b in self.blocks if isinstance(b, MLABlockWrapper))

        print(f"\n{'='*60}")
        print("MODEL ARCHITECTURE - HYBRID SSM + MLA")
        print(f"{'='*60}")
        print(f"Embedding:       {embed_params/1e6:>6.2f}M params")
        print(f"Hybrid Blocks:   {num_ssm} SSM + {num_mla} MLA = {num_ssm + num_mla} total")
        print(f"Output Proj:     {output_proj_params/1e6:>6.2f}M params")
        print("Output Head:     tied to embedding (0 extra params)")
        print(f"{'─'*60}")
        print(f"Total:           {total_params/1e6:>6.2f}M params")
        print(f"{'='*60}")
        print(f"Config: {self.config.n_layers} layers, {self.config.d_model} dim")
        print(f"SSM: d_state={self.config.d_state}")
        print(f"MLA: n_heads={self.config.n_heads}, d_kv_comp={self.config.d_kv_comp}")

        # Print arrangement-specific info
        if self.config.block_arrangement == "interleaving":
            print(f"Arrangement: INTERLEAVING (ssm_per_mla={self.config.ssm_per_mla})")
        elif self.config.block_arrangement == "layered":
            print(f"Arrangement: LAYERED (mla_blocks={self.config.layered_mla_num}, ssm_blocks={num_ssm})")

        print(f"{'='*60}\n")

    def _get_causal_mask(self, seq_len, device):
        """
        Return a [1, 1, seq_len, seq_len] lower-triangular mask of ones and zeros.

        The mask is cached and rebuilt only when a longer sequence is requested
        or the cached copy lives on a different device than requested.
        """
        cache = self.causal_mask_cache
        if cache is None or cache.size(-1) < seq_len or cache.device != device:
            mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
            self.causal_mask_cache = mask.unsqueeze(0).unsqueeze(0)
        return self.causal_mask_cache[:, :, :seq_len, :seq_len]

    def forward(self, input_ids, attention_mask=None):
        """
        Full-sequence forward pass.

        Args:
            input_ids: [batch, seq_len] token ids.
            attention_mask: optional padding mask; presumably [batch, seq_len]
                with 1 = keep, 0 = pad — confirm against callers.

        Returns:
            logits: [batch, seq_len, vocab_size]
        """
        seq_len = input_ids.size(1)

        # Causal mask, optionally combined with the padding mask (broadcast over
        # the last, key, dimension).
        causal_mask = self._get_causal_mask(seq_len, input_ids.device)
        if attention_mask is not None:
            padding_mask = attention_mask.unsqueeze(1).unsqueeze(1)
            causal_mask = causal_mask * padding_mask

        # Token embedding
        x = self.token_embedding(input_ids)
        x = self.dropout(x)
        x = ActivationQuantize.apply(x)

        # Hybrid SSM + MLA blocks
        for block in self.blocks:
            x = block(x, causal_mask)

        # Two-stage output projection
        x = self.pre_final_norm(x)
        x = self.output_proj(x)      # d_model → d_embed_rank
        x = self.final_norm(x)       # Normalize before output head
        logits = self.lm_head(x)     # d_embed_rank → vocab_size

        return logits

    def init_ssm_states(self, batch_size, device, dtype):
        """
        Initialize SSM states for all SSM blocks (MLA blocks carry no state here).

        Returns:
            states: List of [batch, d_state] tensors for each SSM block
        """
        states = []
        for block in self.blocks:
            if isinstance(block, SSMBlockWrapper):
                state = block.ssm.init_state(batch_size, device, dtype)
                states.append(state)
        return states

    def inference_step(self, input_id, states, return_hidden_states=False):
        """
        Single inference step for autoregressive generation (RNN-like).

        Args:
            input_id: [batch, 1] or [batch] token ids, a 0-dim tensor, or a Python int
            states: List of SSM states from previous step
            return_hidden_states: If True, also return SSM hidden states for visualization

        Returns:
            logits: [batch, vocab_size] - output logits for next token
            new_states: List of updated SSM states for SSM blocks
            hidden_states: (Optional) List of SSM hidden state values for each SSM layer
        """
        if isinstance(input_id, int):
            input_id = torch.tensor(
                [[input_id]], dtype=torch.long, device=next(self.parameters()).device
            )
        elif input_id.dim() == 0:
            input_id = input_id.view(1, 1)
        elif input_id.dim() == 1:
            # [batch] → [batch, 1]; one token per batch row.
            input_id = input_id.unsqueeze(-1)

        # Embed the token
        x = self.token_embedding(input_id)  # [batch, 1, d_model]
        x = x.squeeze(1)  # [batch, d_model]
        x = ActivationQuantize.apply(x)

        # Pass through hybrid blocks
        new_states = []
        hidden_states = [] if return_hidden_states else None
        state_idx = 0  # Position in `states` (advances only for SSM blocks)

        for block in self.blocks:
            if isinstance(block, SSMBlockWrapper):
                # SSM block with explicit recurrent state (mirrors SSMBlockWrapper.forward)
                residual = x
                ssm_out, new_state = block.ssm.step(x, states[state_idx])

                if return_hidden_states:
                    hidden_states.append(new_state.clone().detach())

                x = residual + block.dropout(ssm_out)

                # FFN + residual
                residual = x
                ffn_out = block.feed_forward(x)
                x = residual + block.dropout(ffn_out)

                new_states.append(new_state)
                state_idx += 1
            else:
                # NOTE(review): no KV cache is passed, so the MLA block only sees
                # the current token here, unlike in forward() where it attends
                # over the whole prefix. Confirm this approximation is intended.
                x = block(x.unsqueeze(1), mask=None).squeeze(1)

        # Output projection
        x = self.pre_final_norm(x)
        x = self.output_proj(x)
        x = self.final_norm(x)
        logits = self.lm_head(x)

        if return_hidden_states:
            return logits, new_states, hidden_states
        return logits, new_states

    def count_parameters(self):
        """Number of trainable parameters (tied weights are counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_non_embedding_parameters(self):
        """Trainable parameters excluding the factorized embedding."""
        total = self.count_parameters()
        embedding_params = self.token_embedding.get_num_params()
        return total - embedding_params

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=50,
        temperature=1.0,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
    ):
        """
        Generate tokens autoregressively.

        Args:
            input_ids: [batch, prompt_len] prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: logits are divided by max(temperature, 1e-5).
            top_k: keep only the k highest logits (disabled if None or <= 0).
            top_p: nucleus filtering threshold (disabled if None or >= 1.0).
            repetition_penalty: positive logits of already-seen tokens are
                divided by it, non-positive ones multiplied (1.0 disables).
            do_sample: sample from the distribution if True, else greedy argmax.

        Returns:
            [batch, prompt_len + n_generated] token ids. If
            ``config.eos_token_id`` is set, a row that has emitted it is padded
            with that id until every row has finished, at which point
            generation stops early.

        The model's train/eval mode is restored on return.
        """
        was_training = self.training
        self.eval()
        eos_token_id = getattr(self.config, "eos_token_id", None)
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)

        try:
            for _ in range(max_new_tokens):
                # Crop to max_seq_len
                idx_cond = input_ids[:, -self.config.max_seq_len:]

                # Forward; keep only the last position's logits
                logits = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)

                # Repetition penalty over every token already present in the row
                if repetition_penalty != 1.0:
                    seen = torch.gather(logits, 1, input_ids)
                    seen = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)
                    logits = logits.scatter(1, input_ids, seen)

                # Top-k filtering
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits = logits.masked_fill(logits < v[:, [-1]], float("-inf"))

                # Top-p (nucleus) filtering
                if top_p is not None and top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_remove = cumulative_probs > top_p
                    # Shift right so the token that crosses the threshold is kept
                    sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()
                    sorted_remove[:, 0] = False
                    # Map the removal mask from sorted order back to vocab order
                    remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
                    logits = logits.masked_fill(remove, float("-inf"))

                # Sample or greedy
                probs = F.softmax(logits, dim=-1)
                if do_sample:
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(probs, dim=-1, keepdim=True)

                if eos_token_id is not None:
                    # Rows that already finished keep emitting EOS as padding
                    next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
                    finished |= next_token.squeeze(-1) == eos_token_id

                input_ids = torch.cat([input_ids, next_token], dim=1)

                # Stop once every row has produced EOS
                if eos_token_id is not None and finished.all():
                    break
        finally:
            self.train(was_training)

        return input_ids

    def get_num_params(self, non_embedding=True):
        """Parameter count, excluding the embedding by default."""
        if non_embedding:
            return self.count_non_embedding_parameters()
        return self.count_parameters()