| | """
|
| | VortexModel: Main model class combining SSM, attention, science modules, and SciGate FFN.
|
| | Implements two block types: SSM-only and attention+science+SciGate FFN.
|
| | """
|
| |
|
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention_layer import VortexLocalAttention
from .scigate_ffn import SciGateFFN
from .science_modules import (
    CitationModule,
    EquationModule,
    MolecularModule,
    NumericalReasoningModule,
)
from .ssm_layer import VortexSSM
|
| |
|
| |
|
class VortexBlock(nn.Module):
    """
    A single layer of the Vortex architecture.

    Two types of blocks:
      1. SSM block: pre-norm ``VortexSSM`` with a residual connection.
      2. Attention block: pre-norm ``VortexLocalAttention``, followed by the
         optional science modules (equation / numerical / citation /
         molecular), each added residually, and a pre-norm ``SciGateFFN``
         with its own residual connection.

    Both block types apply ``final_norm`` to their output.
    """

    def __init__(
        self,
        config: Dict,
        is_ssm_block: bool = True,
    ):
        """
        Initialize a Vortex block.

        Args:
            config: Model configuration. Always reads ``d_model``; SSM blocks
                additionally read ``d_state`` and ``d_conv``; attention blocks
                read ``num_heads``, ``window_size``, ``ffn_expansion``,
                ``num_domains``, and the optional ``use_flash_attention`` /
                ``enable_*_module`` flags (all defaulting to True).
            is_ssm_block: If True, this is an SSM-only block; else
                attention + science modules + SciGate FFN.
        """
        super().__init__()
        self.config = config
        self.is_ssm_block = is_ssm_block
        self.d_model = config["d_model"]

        if is_ssm_block:
            self.ssm = VortexSSM(
                d_model=config["d_model"],
                d_state=config["d_state"],
                d_conv=config["d_conv"],
            )
            self.norm = nn.LayerNorm(config["d_model"])
        else:
            self.attn = VortexLocalAttention(
                d_model=config["d_model"],
                num_heads=config["num_heads"],
                window_size=config["window_size"],
                use_flash_attention=config.get("use_flash_attention", True),
            )
            self.attn_norm = nn.LayerNorm(config["d_model"])

            # Science modules are optional; each one defaults to enabled and
            # stays None when disabled (forward() checks for None).
            self.equation_module = None
            self.numerical_module = None
            self.citation_module = None
            self.molecular_module = None

            if config.get("enable_equation_module", True):
                self.equation_module = EquationModule(config["d_model"])

            if config.get("enable_numerical_module", True):
                self.numerical_module = NumericalReasoningModule(config["d_model"])

            if config.get("enable_citation_module", True):
                self.citation_module = CitationModule(config["d_model"])

            if config.get("enable_molecular_module", True):
                self.molecular_module = MolecularModule(config["d_model"])

            self.ffn = SciGateFFN(
                d_model=config["d_model"],
                expansion=config["ffn_expansion"],
                num_domains=config["num_domains"],
            )
            self.ffn_norm = nn.LayerNorm(config["d_model"])

        self.final_norm = nn.LayerNorm(config["d_model"])

    def forward(
        self,
        x: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the block.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            domain_ids: Optional domain IDs for SciGate FFN
            domain_tags: Optional domain tag masks
            text: Optional original text for science module span detection
            attention_mask: Optional attention mask

        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        if self.is_ssm_block:
            # Pre-norm SSM sub-layer with residual, then output norm.
            residual = x
            x = self.norm(x)
            x = self.ssm(x)
            x = residual + x
            return self.final_norm(x)

        # --- Attention sub-layer (pre-norm + residual) ---
        residual_attn = x
        x = self.attn_norm(x)
        # _detect_global_tokens is always defined on this class, so the former
        # hasattr() guard was dead code and has been removed.
        global_mask = self._detect_global_tokens(x)
        x = self.attn(x, global_mask=global_mask, attention_mask=attention_mask)
        x = residual_attn + x

        # --- Science modules, each applied residually when enabled ---
        if self.equation_module is not None:
            x = x + self.equation_module(x, text=text)

        if self.numerical_module is not None:
            x = x + self.numerical_module(x, text=text)

        if self.citation_module is not None:
            # CitationModule returns a pair; only the hidden states take part
            # in the residual here, the second value is discarded.
            x_cited, _ = self.citation_module(x, text=text)
            x = x + x_cited

        if self.molecular_module is not None:
            x = x + self.molecular_module(x, text=text)

        # --- SciGate FFN sub-layer (pre-norm + residual) ---
        residual_ffn = x
        x = self.ffn_norm(x)
        x = self.ffn(x, domain_ids=domain_ids, domain_tags=domain_tags)
        x = residual_ffn + x

        return self.final_norm(x)

    def _detect_global_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """
        Detect global tokens that should attend across the entire sequence.

        Tokens whose hidden-state L2 norm falls in the top 5% of their
        sequence are marked as global.

        Args:
            x: Hidden states (batch, seq_len, d_model)

        Returns:
            Boolean mask (batch, seq_len); True marks global tokens.
        """
        # torch.quantile only supports float32/float64, so compute the norms
        # in float32 to avoid a crash under fp16/bf16 mixed precision.
        norms = torch.norm(x.float(), dim=-1)
        threshold = torch.quantile(norms, 0.95, dim=-1, keepdim=True)
        return norms > threshold
|
| |
|
| |
|
class VortexModel(nn.Module):
    """
    Main Vortex model combining SSM and attention blocks.
    Supports both 7B and 13B configurations.
    """

    def __init__(
        self,
        config: Dict,
    ):
        """
        Initialize VortexModel.

        Args:
            config: Model configuration (from vortex_7b_config.py or
                vortex_13b_config.py). Must contain ``vocab_size``,
                ``d_model``, ``num_layers``, ``ssm_ratio``, plus the
                per-block keys consumed by VortexBlock.
        """
        super().__init__()
        self.config = config

        # Token embedding table.
        self.embed_tokens = nn.Embedding(config["vocab_size"], config["d_model"])

        # Interleaved SSM / attention blocks.
        self.blocks = nn.ModuleList()
        self._build_blocks()

        # Final norm and (untied) language-model head.
        self.ln_f = nn.LayerNorm(config["d_model"])
        self.lm_head = nn.Linear(config["d_model"], config["vocab_size"], bias=False)

        self._initialize_weights()

    def _build_blocks(self):
        """Build the sequence of SSM and attention blocks.

        The layout repeats a fixed pattern: ``[SSM, SSM, Attn]`` when
        ``ssm_ratio == 0.6``, otherwise alternating ``[SSM, Attn]``.
        """
        num_layers = self.config["num_layers"]
        ssm_ratio = self.config["ssm_ratio"]

        # 0 = SSM block, 1 = attention block.  The two branches previously
        # duplicated the same fill loop; a modular index is equivalent.
        pattern = [0, 0, 1] if ssm_ratio == 0.6 else [0, 1]
        layout = [pattern[i % len(pattern)] for i in range(num_layers)]
        assert len(layout) == num_layers

        for is_attn in layout:
            self.blocks.append(
                VortexBlock(
                    config=self.config,
                    is_ssm_block=not is_attn,
                )
            )

        # Report the counts actually built.  The old code derived them from
        # ssm_ratio, which disagrees with the [0, 1] pattern (always 50/50)
        # for any ratio other than 0.6.
        num_attn_blocks = sum(layout)
        num_ssm_blocks = num_layers - num_attn_blocks
        print(f"Built {num_layers} layers: {num_ssm_blocks} SSM, {num_attn_blocks} Attention")

    def _initialize_weights(self):
        """Initialize weights.

        Embeddings get N(0, 0.02); each sub-module that exposes its own
        ``_initialize_weights`` is delegated to.
        """
        nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
        for block in self.blocks:
            if hasattr(block, 'ssm'):
                block.ssm._initialize_weights()
            if hasattr(block, 'attn'):
                block.attn._initialize_weights()
            if hasattr(block, 'ffn'):
                block.ffn._initialize_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
        return_dict: bool = True,
    ):
        """
        Forward pass through the model.

        Args:
            input_ids: Token IDs (batch, seq_len)
            domain_ids: Optional domain IDs
            domain_tags: Optional domain tag masks
            attention_mask: Optional attention mask (batch, seq_len)
            text: Optional original text for science modules
            return_dict: If True (default), return a dict with ``"logits"``
                and ``"last_hidden_state"``; otherwise return the logits
                tensor alone.

        Returns:
            Dict with ``logits`` (batch, seq_len, vocab_size) and
            ``last_hidden_state`` (batch, seq_len, d_model) when
            ``return_dict`` is True, else the logits tensor.
        """
        x = self.embed_tokens(input_ids)

        for block in self.blocks:
            x = block(
                x,
                domain_ids=domain_ids,
                domain_tags=domain_tags,
                text=text,
                attention_mask=attention_mask,
            )

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if return_dict:
            return {"logits": logits, "last_hidden_state": x}
        return logits

    def get_num_params(self) -> int:
        """Get total number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def get_trainable_params(self) -> int:
        """Get number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def estimate_memory_usage(
        self,
        batch_size: int,
        seq_len: int,
        use_gradient_checkpointing: bool = False,
    ) -> Dict[str, float]:
        """
        Estimate training memory usage for a given batch size and sequence length.

        Assumes 2-byte (fp16/bf16) parameters, activations, and gradients, and
        two 2-byte optimizer states per parameter.  This is a rough back-of-
        the-envelope figure, not a measurement.

        Args:
            batch_size: Batch size.
            seq_len: Sequence length.
            use_gradient_checkpointing: If True, assume activation
                checkpointing keeps only about sqrt(num_layers) layers of
                activations resident at once.  (This flag was previously
                accepted but silently ignored.)

        Returns:
            Dictionary with memory estimates in GB.
        """
        params = self.get_num_params()
        param_bytes = params * 2  # 2 bytes per parameter

        num_layers = self.config["num_layers"]
        # 2 bytes per activation element per layer.
        activations_per_layer = batch_size * seq_len * self.config["d_model"] * 2
        if use_gradient_checkpointing:
            # Standard sqrt(L) trade-off for activation checkpointing.
            live_layers = max(1, math.ceil(math.sqrt(num_layers)))
        else:
            live_layers = num_layers
        total_activations = activations_per_layer * live_layers

        # Gradients mirror the parameters at the same precision.
        gradients = param_bytes

        # Two optimizer states at 2 bytes per parameter (matches the
        # original estimate; exact size depends on the optimizer used).
        optimizer_states = params * 2 * 2

        total_memory = (param_bytes + total_activations + gradients + optimizer_states) / 1e9

        return {
            "parameters_gb": param_bytes / 1e9,
            "activations_gb": total_activations / 1e9,
            "gradients_gb": gradients / 1e9,
            "optimizer_states_gb": optimizer_states / 1e9,
            "total_gb": total_memory,
        }
|
| |
|
| |
|
def test_vortex_model():
    """Smoke-test VortexModel end-to-end on a tiny configuration."""
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the forward pass runs quickly on CPU.
    config = dict(VORTEX_7B_CONFIG)
    config["d_model"] = 512
    config["num_layers"] = 4
    config["num_heads"] = 8
    config["vocab_size"] = 1000

    model = VortexModel(config)

    batch_size, seq_len = 2, 128
    input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))

    logits = model(input_ids)["logits"]

    print(f"Model parameters: {model.get_num_params():,}")
    print(f"Input shape: {input_ids.shape}")
    print(f"Logits shape: {logits.shape}")
    assert logits.shape == (batch_size, seq_len, config["vocab_size"])

    # Exercise the memory estimator as well.
    estimates = model.estimate_memory_usage(batch_size, seq_len)
    print(f"Memory estimate for batch={batch_size}, seq_len={seq_len}:")
    for name, gb in estimates.items():
        print(f"  {name}: {gb:.2f} GB")

    print("VortexModel test passed!")
|
| |
|
| |
|
# Run the smoke test when this module is executed directly.
if __name__ == "__main__":
    test_vortex_model()
|
| |
|