""" VortexModel: Main model class combining SSM, attention, science modules, and SciGate FFN. Implements two block types: SSM-only and attention+science+SciGate FFN. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, List, Dict from .ssm_layer import VortexSSM from .attention_layer import VortexLocalAttention from .scigate_ffn import SciGateFFN from .science_modules import ( EquationModule, NumericalReasoningModule, CitationModule, MolecularModule, ) class VortexBlock(nn.Module): """ Two types of blocks: 1. SSMBlock: only VortexSSM 2. AttentionBlock: VortexLocalAttention + ScienceModules + SciGateFFN """ def __init__( self, config: Dict, is_ssm_block: bool = True, ): """ Initialize a Vortex block. Args: config: Model configuration is_ssm_block: If True, this is an SSM-only block; else attention+science+FFN """ super().__init__() self.config = config self.is_ssm_block = is_ssm_block self.d_model = config["d_model"] if is_ssm_block: # SSM-only block self.ssm = VortexSSM( d_model=config["d_model"], d_state=config["d_state"], d_conv=config["d_conv"], ) self.norm = nn.LayerNorm(config["d_model"]) else: # Attention + Science + FFN block self.attn = VortexLocalAttention( d_model=config["d_model"], num_heads=config["num_heads"], window_size=config["window_size"], use_flash_attention=config.get("use_flash_attention", True), ) self.attn_norm = nn.LayerNorm(config["d_model"]) # Science modules (enabled based on config flags) self.equation_module = None self.numerical_module = None self.citation_module = None self.molecular_module = None if config.get("enable_equation_module", True): self.equation_module = EquationModule(config["d_model"]) if config.get("enable_numerical_module", True): self.numerical_module = NumericalReasoningModule(config["d_model"]) if config.get("enable_citation_module", True): self.citation_module = CitationModule(config["d_model"]) if config.get("enable_molecular_module", True): self.molecular_module = MolecularModule(config["d_model"]) # SciGate FFN self.ffn = SciGateFFN( d_model=config["d_model"], expansion=config["ffn_expansion"], num_domains=config["num_domains"], ) self.ffn_norm = nn.LayerNorm(config["d_model"]) # Final layer norm for both block types self.final_norm = nn.LayerNorm(config["d_model"]) def forward( self, x: torch.Tensor, domain_ids: Optional[torch.Tensor] = None, domain_tags: Optional[torch.Tensor] = None, text: Optional[List[str]] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass through the block. Args: x: Input tensor (batch, seq_len, d_model) domain_ids: Optional domain IDs for SciGate FFN domain_tags: Optional domain tag masks text: Optional original text for science module span detection attention_mask: Optional attention mask Returns: Output tensor (batch, seq_len, d_model) """ residual = x if self.is_ssm_block: # SSM-only pathway x = self.norm(x) x = self.ssm(x) x = residual + x x = self.final_norm(x) else: # Attention + Science + FFN pathway # Attention residual_attn = x x = self.attn_norm(x) global_mask = self._detect_global_tokens(x) if hasattr(self, '_detect_global_tokens') else None x = self.attn(x, global_mask=global_mask, attention_mask=attention_mask) x = residual_attn + x # Science modules (applied sequentially) if self.equation_module is not None: x = x + self.equation_module(x, text=text) if self.numerical_module is not None: x = x + self.numerical_module(x, text=text) if self.citation_module is not None: x_cited, _ = self.citation_module(x, text=text) x = x + x_cited if self.molecular_module is not None: x = x + self.molecular_module(x, text=text) # SciGate FFN residual_ffn = x x = self.ffn_norm(x) x = self.ffn(x, domain_ids=domain_ids, domain_tags=domain_tags) x = residual_ffn + x x = self.final_norm(x) return x def _detect_global_tokens(self, x: torch.Tensor) -> torch.Tensor: """ Detect global tokens that should attend across the entire sequence. Global tokens are those with special domain tags or high norm. """ # Simple heuristic: tokens with large L2 norm are likely special norms = torch.norm(x, dim=-1) # (batch, seq_len) threshold = torch.quantile(norms, 0.95, dim=-1, keepdim=True) global_mask = norms > threshold return global_mask class VortexModel(nn.Module): """ Main Vortex model combining SSM and attention blocks. Supports both 7B and 13B configurations. """ def __init__( self, config: Dict, ): """ Initialize VortexModel. Args: config: Model configuration (from vortex_7b_config.py or vortex_13b_config.py) """ super().__init__() self.config = config # Token embedding self.embed_tokens = nn.Embedding(config["vocab_size"], config["d_model"]) # Build blocks according to layer ratio self.blocks = nn.ModuleList() self._build_blocks() # Final layer norm self.ln_f = nn.LayerNorm(config["d_model"]) # Output projection (weights will be tied by HuggingFace if config.tie_word_embeddings=True) self.lm_head = nn.Linear(config["d_model"], config["vocab_size"], bias=False) # Initialize weights self._initialize_weights() def _build_blocks(self): """Build the sequence of SSM and attention blocks.""" num_layers = self.config["num_layers"] ssm_ratio = self.config["ssm_ratio"] # Calculate number of each block type num_ssm_blocks = int(num_layers * ssm_ratio) num_attn_blocks = num_layers - num_ssm_blocks # Determine block pattern if ssm_ratio == 0.6: # 7B pattern: SSM, SSM, Attn, SSM, SSM, Attn... pattern = [0, 0, 1] # 0=SSM, 1=Attn # Repeat pattern and fill remaining blocks = [] while len(blocks) < num_layers: blocks.extend(pattern[:min(len(pattern), num_layers - len(blocks))]) else: # 13B pattern: SSM, Attn, SSM, Attn... pattern = [0, 1] blocks = [] while len(blocks) < num_layers: blocks.extend(pattern[:min(len(pattern), num_layers - len(blocks))]) # Ensure exact count blocks = blocks[:num_layers] assert len(blocks) == num_layers # Create blocks for is_attn in blocks: block = VortexBlock( config=self.config, is_ssm_block=not is_attn, ) self.blocks.append(block) print(f"Built {num_layers} layers: {num_ssm_blocks} SSM, {num_attn_blocks} Attention") def _initialize_weights(self): """Initialize weights.""" nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02) for block in self.blocks: if hasattr(block, 'ssm'): block.ssm._initialize_weights() if hasattr(block, 'attn'): block.attn._initialize_weights() if hasattr(block, 'ffn'): block.ffn._initialize_weights() def forward( self, input_ids: torch.Tensor, domain_ids: Optional[torch.Tensor] = None, domain_tags: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, text: Optional[List[str]] = None, return_dict: bool = True, ) -> torch.Tensor: """ Forward pass through the model. Args: input_ids: Token IDs (batch, seq_len) domain_ids: Optional domain IDs domain_tags: Optional domain tag masks attention_mask: Optional attention mask (batch, seq_len) text: Optional original text for science modules return_dict: Whether to return dict (always returns tensor for now) Returns: Logits (batch, seq_len, vocab_size) """ # Embed tokens x = self.embed_tokens(input_ids) # Pass through blocks for block in self.blocks: x = block( x, domain_ids=domain_ids, domain_tags=domain_tags, text=text, attention_mask=attention_mask, ) # Final norm x = self.ln_f(x) # Project to vocabulary logits = self.lm_head(x) if return_dict: return {"logits": logits, "last_hidden_state": x} return logits def get_num_params(self) -> int: """Get total number of parameters.""" return sum(p.numel() for p in self.parameters()) def get_trainable_params(self) -> int: """Get number of trainable parameters.""" return sum(p.numel() for p in self.parameters() if p.requires_grad) def estimate_memory_usage( self, batch_size: int, seq_len: int, use_gradient_checkpointing: bool = False, ) -> Dict[str, float]: """ Estimate memory usage for a given batch size and sequence length. Returns: Dictionary with memory estimates in GB """ params = self.get_num_params() param_bytes = params * 2 # Assuming bfloat16 # Activation memory (rough estimate) # Each layer: activations ~ batch * seq_len * d_model * 2 activations_per_layer = batch_size * seq_len * self.config["d_model"] * 2 total_activations = activations_per_layer * self.config["num_layers"] # Gradients (same size as parameters) gradients = param_bytes # Optimizer states (AdamW: 2x parameters) optimizer_states = params * 2 * 2 total_memory = (param_bytes + total_activations + gradients + optimizer_states) / 1e9 return { "parameters_gb": param_bytes / 1e9, "activations_gb": total_activations / 1e9, "gradients_gb": gradients / 1e9, "optimizer_states_gb": optimizer_states / 1e9, "total_gb": total_memory, } def test_vortex_model(): """Test the VortexModel.""" from configs.vortex_7b_config import VORTEX_7B_CONFIG config = VORTEX_7B_CONFIG.copy() # Reduce size for testing config["d_model"] = 512 config["num_layers"] = 4 config["num_heads"] = 8 config["vocab_size"] = 1000 model = VortexModel(config) batch_size = 2 seq_len = 128 input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len)) # Forward pass output = model(input_ids) logits = output["logits"] print(f"Model parameters: {model.get_num_params():,}") print(f"Input shape: {input_ids.shape}") print(f"Logits shape: {logits.shape}") assert logits.shape == (batch_size, seq_len, config["vocab_size"]) # Memory estimate mem = model.estimate_memory_usage(batch_size, seq_len) print(f"Memory estimate for batch={batch_size}, seq_len={seq_len}:") for k, v in mem.items(): print(f" {k}: {v:.2f} GB") print("VortexModel test passed!") if __name__ == "__main__": test_vortex_model()