Vortex-7b-V1 / models /vortex_model.py
Zandy-Wandy's picture
Upload Vortex model
bf64b03 verified
"""
VortexModel: Main model class combining SSM, attention, science modules, and SciGate FFN.
Implements two block types: SSM-only and attention+science+SciGate FFN.
"""
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .ssm_layer import VortexSSM
from .attention_layer import VortexLocalAttention
from .scigate_ffn import SciGateFFN
from .science_modules import (
    EquationModule,
    NumericalReasoningModule,
    CitationModule,
    MolecularModule,
)
class VortexBlock(nn.Module):
    """
    Two types of blocks:
    1. SSMBlock: only VortexSSM
    2. AttentionBlock: VortexLocalAttention + ScienceModules + SciGateFFN

    In both variants the SSM / attention / FFN sub-layers receive a
    LayerNorm-ed input and their output is added back to the un-normalised
    input; the block output is passed through ``final_norm``.
    """

    def __init__(
        self,
        config: Dict,
        is_ssm_block: bool = True,
    ):
        """
        Initialize a Vortex block.

        Args:
            config: Model configuration. Keys read here: ``d_model`` always;
                ``d_state`` and ``d_conv`` for SSM blocks; ``num_heads``,
                ``window_size``, ``ffn_expansion`` and ``num_domains`` for
                attention blocks. Optional keys (all default to ``True``):
                ``use_flash_attention`` and the four
                ``enable_*_module`` flags.
            is_ssm_block: If True, this is an SSM-only block; else attention+science+FFN
        """
        super().__init__()
        self.config = config
        self.is_ssm_block = is_ssm_block
        self.d_model = config["d_model"]

        if is_ssm_block:
            # SSM-only block
            self.ssm = VortexSSM(
                d_model=config["d_model"],
                d_state=config["d_state"],
                d_conv=config["d_conv"],
            )
            self.norm = nn.LayerNorm(config["d_model"])
        else:
            # Attention + Science + FFN block
            self.attn = VortexLocalAttention(
                d_model=config["d_model"],
                num_heads=config["num_heads"],
                window_size=config["window_size"],
                use_flash_attention=config.get("use_flash_attention", True),
            )
            self.attn_norm = nn.LayerNorm(config["d_model"])

            # Science modules (enabled based on config flags). Each attribute
            # stays None when its flag is disabled so forward() can skip it.
            self.equation_module = None
            self.numerical_module = None
            self.citation_module = None
            self.molecular_module = None
            if config.get("enable_equation_module", True):
                self.equation_module = EquationModule(config["d_model"])
            if config.get("enable_numerical_module", True):
                self.numerical_module = NumericalReasoningModule(config["d_model"])
            if config.get("enable_citation_module", True):
                self.citation_module = CitationModule(config["d_model"])
            if config.get("enable_molecular_module", True):
                self.molecular_module = MolecularModule(config["d_model"])

            # SciGate FFN
            self.ffn = SciGateFFN(
                d_model=config["d_model"],
                expansion=config["ffn_expansion"],
                num_domains=config["num_domains"],
            )
            self.ffn_norm = nn.LayerNorm(config["d_model"])

        # Final layer norm for both block types
        self.final_norm = nn.LayerNorm(config["d_model"])

    def forward(
        self,
        x: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the block.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            domain_ids: Optional domain IDs for SciGate FFN
            domain_tags: Optional domain tag masks
            text: Optional original text for science module span detection
            attention_mask: Optional attention mask

        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        if self.is_ssm_block:
            # SSM-only pathway: pre-norm, SSM, residual add, final norm.
            x = x + self.ssm(self.norm(x))
            return self.final_norm(x)

        # Attention + Science + FFN pathway
        # Attention. Global-token detection runs on the normalised
        # activations, i.e. the same tensor the attention layer receives.
        normed = self.attn_norm(x)
        global_mask = self._detect_global_tokens(normed)
        x = x + self.attn(normed, global_mask=global_mask, attention_mask=attention_mask)

        # Science modules (applied sequentially); each enabled module's
        # output is added onto the running hidden state.
        if self.equation_module is not None:
            x = x + self.equation_module(x, text=text)
        if self.numerical_module is not None:
            x = x + self.numerical_module(x, text=text)
        if self.citation_module is not None:
            # The citation module returns a pair; only the first element
            # is used here.
            x_cited, _ = self.citation_module(x, text=text)
            x = x + x_cited
        if self.molecular_module is not None:
            x = x + self.molecular_module(x, text=text)

        # SciGate FFN
        x = x + self.ffn(self.ffn_norm(x), domain_ids=domain_ids, domain_tags=domain_tags)
        return self.final_norm(x)

    def _detect_global_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """
        Detect global tokens that should attend across the entire sequence.

        Simple heuristic: a token is flagged when its L2 norm is strictly
        above the 95th percentile of the norms in its own sequence.

        Args:
            x: Hidden states (batch, seq_len, d_model)

        Returns:
            Boolean mask (batch, seq_len); True marks a global token.
        """
        norms = torch.norm(x, dim=-1)  # (batch, seq_len)

        # torch.quantile raises on an empty input; nothing can be global then.
        if norms.shape[-1] == 0:
            return torch.zeros_like(norms, dtype=torch.bool)

        # torch.quantile only accepts float32/float64, so half-precision
        # activations (bf16/fp16) must be up-cast. Casting the reduced norms
        # rather than x keeps the extra memory at (batch, seq_len).
        if norms.dtype not in (torch.float32, torch.float64):
            norms = norms.float()

        threshold = torch.quantile(norms, 0.95, dim=-1, keepdim=True)
        return norms > threshold
class VortexModel(nn.Module):
    """
    Main Vortex model combining SSM and attention blocks.

    Token IDs are embedded, passed through the block stack, normalised by
    ``ln_f`` and projected to vocabulary logits by ``lm_head``.
    Supports both 7B and 13B configurations.
    """

    def __init__(
        self,
        config: Dict,
    ):
        """
        Initialize VortexModel.

        Args:
            config: Model configuration (from vortex_7b_config.py or vortex_13b_config.py).
                Keys read by this class: ``vocab_size``, ``d_model``,
                ``num_layers`` and ``ssm_ratio``; the full mapping is also
                handed to every block.
        """
        super().__init__()
        self.config = config

        # Token embedding
        self.embed_tokens = nn.Embedding(config["vocab_size"], config["d_model"])

        # Build blocks according to layer ratio
        self.blocks = nn.ModuleList()
        self._build_blocks()

        # Final layer norm
        self.ln_f = nn.LayerNorm(config["d_model"])

        # Output projection (weights will be tied by HuggingFace if config.tie_word_embeddings=True)
        self.lm_head = nn.Linear(config["d_model"], config["vocab_size"], bias=False)

        # Initialize weights
        self._initialize_weights()

    @staticmethod
    def _layout_pattern(num_layers: int, ssm_ratio: float) -> List[int]:
        """
        Return the block type of every layer, in order (0=SSM, 1=Attn).

        A ratio of 0.6 selects the 7B pattern (SSM, SSM, Attn, ...); any
        other ratio selects the 13B pattern (SSM, Attn, ...). The pattern is
        tiled and cut off at ``num_layers`` entries.
        """
        # Tolerance instead of ``==`` so a ratio that was computed rather
        # than typed as a literal (e.g. 0.2 * 3) still selects the 7B pattern.
        if abs(ssm_ratio - 0.6) < 1e-6:
            pattern = [0, 0, 1]
        else:
            pattern = [0, 1]
        return [pattern[i % len(pattern)] for i in range(num_layers)]

    def _build_blocks(self):
        """Build the sequence of SSM and attention blocks."""
        num_layers = self.config["num_layers"]
        layout = self._layout_pattern(num_layers, self.config["ssm_ratio"])

        # Create blocks
        for is_attn in layout:
            self.blocks.append(
                VortexBlock(
                    config=self.config,
                    is_ssm_block=not is_attn,
                )
            )

        # Report the counts of what was actually built; deriving them from
        # num_layers * ssm_ratio disagrees with the tiled pattern whenever
        # the pattern does not divide the layer count evenly.
        num_attn_blocks = sum(layout)
        num_ssm_blocks = len(layout) - num_attn_blocks
        print(f"Built {num_layers} layers: {num_ssm_blocks} SSM, {num_attn_blocks} Attention")

    def _initialize_weights(self):
        """
        Initialize weights.

        The embedding table is drawn from N(0, 0.02); every block sub-module
        that exists (``ssm``, ``attn``, ``ffn``) is asked to run its own
        ``_initialize_weights``.
        """
        nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=0.02)
        for block in self.blocks:
            if hasattr(block, 'ssm'):
                block.ssm._initialize_weights()
            if hasattr(block, 'attn'):
                block.attn._initialize_weights()
            if hasattr(block, 'ffn'):
                block.ffn._initialize_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        text: Optional[List[str]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass through the model.

        Args:
            input_ids: Token IDs (batch, seq_len)
            domain_ids: Optional domain IDs
            domain_tags: Optional domain tag masks
            attention_mask: Optional attention mask (batch, seq_len)
            text: Optional original text for science modules
            return_dict: If True (default) return a dict, otherwise return
                the logits tensor only.

        Returns:
            When ``return_dict`` is True, a dict with ``"logits"``
            (batch, seq_len, vocab_size) and ``"last_hidden_state"`` (the
            output of the final layer norm); otherwise the logits tensor.
        """
        # Embed tokens
        x = self.embed_tokens(input_ids)

        # Pass through blocks; every block receives the same side inputs.
        for block in self.blocks:
            x = block(
                x,
                domain_ids=domain_ids,
                domain_tags=domain_tags,
                text=text,
                attention_mask=attention_mask,
            )

        # Final norm
        x = self.ln_f(x)

        # Project to vocabulary
        logits = self.lm_head(x)

        if return_dict:
            return {"logits": logits, "last_hidden_state": x}
        return logits

    def get_num_params(self) -> int:
        """Get total number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def get_trainable_params(self) -> int:
        """Get number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def estimate_memory_usage(
        self,
        batch_size: int,
        seq_len: int,
        use_gradient_checkpointing: bool = False,
    ) -> Dict[str, float]:
        """
        Estimate memory usage for a given batch size and sequence length.

        All terms assume 2 bytes per element (bfloat16).

        Args:
            batch_size: Number of sequences per batch.
            seq_len: Tokens per sequence.
            use_gradient_checkpointing: Accepted for interface
                compatibility. NOTE(review): the estimate currently does not
                depend on this flag.

        Returns:
            Dictionary with memory estimates in GB
        """
        params = self.get_num_params()
        param_bytes = params * 2  # Assuming bfloat16

        # Activation memory (rough estimate)
        # Each layer: activations ~ batch * seq_len * d_model * 2
        activations_per_layer = batch_size * seq_len * self.config["d_model"] * 2
        total_activations = activations_per_layer * self.config["num_layers"]

        # Gradients (same size as parameters)
        gradients = param_bytes

        # Optimizer states (AdamW: 2x parameters)
        optimizer_states = params * 2 * 2

        total_memory = (param_bytes + total_activations + gradients + optimizer_states) / 1e9
        return {
            "parameters_gb": param_bytes / 1e9,
            "activations_gb": total_activations / 1e9,
            "gradients_gb": gradients / 1e9,
            "optimizer_states_gb": optimizer_states / 1e9,
            "total_gb": total_memory,
        }
def test_vortex_model():
    """Smoke-test VortexModel on a shrunken copy of the 7B configuration."""
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the model so the check runs quickly.
    overrides = {
        "d_model": 512,
        "num_layers": 4,
        "num_heads": 8,
        "vocab_size": 1000,
    }
    config = VORTEX_7B_CONFIG.copy()
    for key, value in overrides.items():
        config[key] = value

    model = VortexModel(config)

    batch_size, seq_len = 2, 128
    input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))

    # One forward pass; the default call returns a dict holding the logits.
    logits = model(input_ids)["logits"]

    print(f"Model parameters: {model.get_num_params():,}")
    print(f"Input shape: {input_ids.shape}")
    print(f"Logits shape: {logits.shape}")

    expected_shape = (batch_size, seq_len, config["vocab_size"])
    assert logits.shape == expected_shape

    # Print every entry of the memory estimate.
    mem = model.estimate_memory_usage(batch_size, seq_len)
    print(f"Memory estimate for batch={batch_size}, seq_len={seq_len}:")
    for k, v in mem.items():
        print(f" {k}: {v:.2f} GB")

    print("VortexModel test passed!")
# Run the smoke test when this module is executed as a script. Because the
# file uses package-relative imports, it has to be launched as part of its
# package (for example with ``python -m``) rather than by file path.
if __name__ == "__main__":
    test_vortex_model()