# Vortex-7b-V1 / models / scigate_ffn.py
# Uploaded by Zandy-Wandy: "Upload Vortex model" (commit bf64b03, verified)
"""
SciGateFFN: Science-aware gated feed-forward network.
Learns to activate different FFN pathways based on science domain.
Uses hybrid routing: explicit domain tags preferred, fallback to learned classifier.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
class SciGateFFN(nn.Module):
    """
    SwiGLU feed-forward network with multiplicative science-domain gating.

    The SwiGLU hidden activation is scaled element-wise by a per-domain
    vector produced by ``domain_gate`` from the domain routing weights.
    Routing weights come from explicit domain tags, from domain IDs, or from
    a small learned classifier over the mean-pooled input (see
    ``get_domain_one_hot``).
    """

    def __init__(
        self,
        d_model: int,
        expansion: int = 4,
        num_domains: int = 7,
        use_domain_tags: bool = True,
    ):
        """
        Initialize SciGateFFN.

        Args:
            d_model: Model dimension.
            expansion: FFN expansion factor; hidden size is
                ``d_model * expansion``.
            num_domains: Number of science domains.
            use_domain_tags: Whether to use explicit domain tags for routing.
                NOTE(review): this flag is stored on the module but is not
                consulted by ``get_domain_one_hot`` or ``forward``.
        """
        super().__init__()
        self.d_model = d_model
        self.expansion = expansion
        self.num_domains = num_domains
        self.use_domain_tags = use_domain_tags

        hidden_dim = d_model * expansion

        # SwiGLU: a single projection whose output is split into two halves
        # (value path and gate path) in forward().
        self.up_proj = nn.Linear(d_model, hidden_dim * 2, bias=False)
        self.down_proj = nn.Linear(hidden_dim, d_model, bias=False)

        # Maps domain routing weights (num_domains) to a per-channel scaling
        # vector (hidden_dim). Weight shape: (hidden_dim, num_domains).
        self.domain_gate = nn.Linear(num_domains, hidden_dim, bias=True)

        # Fallback domain classifier used when neither tags nor IDs are given;
        # it operates on the mean-pooled sequence representation.
        self.fallback_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Linear(d_model // 2, num_domains),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize every Linear layer: weights ~ N(0, 0.02), biases zero."""
        # Walk all sub-modules rather than a hand-written list: the fallback
        # classifier is an nn.Sequential, which has no ``weight`` attribute of
        # its own, so its inner Linear layers must be reached through
        # modules().
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def get_domain_one_hot(
        self,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Get per-token domain routing weights.

        Hybrid strategy, in priority order:
            1. ``domain_tags`` given and containing at least one non-zero
               entry: returned as float (assumed to be one-hot already).
            2. ``domain_ids`` given: converted to one-hot.
            3. ``hidden_states`` given: softmax of the fallback classifier
               applied to the sequence mean (a soft distribution, not a hard
               one-hot).
            4. Nothing given: an empty uniform tensor on CPU.

        Args:
            domain_ids: Integer domain IDs, (batch,) or (batch, seq_len).
            domain_tags: Domain tag mask, expected (batch, seq_len, num_domains).
            hidden_states: Hidden states, (batch, seq_len, d_model). Also
                supplies ``seq_len`` for expanding per-sequence IDs; when it
                is None the sequence length is taken as 0.

        Returns:
            Routing weights of shape (batch, seq_len, num_domains). The tag
            and ID paths return float32; the classifier path keeps the dtype
            of ``hidden_states``.
        """
        if hidden_states is not None:
            batch, seq_len, _ = hidden_states.shape
        else:
            batch, seq_len = 0, 0

        # 1. Explicit tags win, unless the mask is entirely empty.
        if domain_tags is not None and domain_tags.any():
            return domain_tags.float()

        # 2. Domain IDs from the data loader.
        if domain_ids is not None:
            # F.one_hot only accepts int64 indices; accept any integer dtype.
            domain_one_hot = F.one_hot(domain_ids.long(), num_classes=self.num_domains)
            if domain_ids.dim() == 1:
                # One domain per sequence: broadcast along the sequence axis.
                domain_one_hot = domain_one_hot.unsqueeze(1).expand(-1, seq_len, -1)
            return domain_one_hot.float()

        # 3. Learned fallback: classify the mean-pooled sequence.
        if hidden_states is not None:
            pooled = hidden_states.mean(dim=1)  # (batch, d_model)
            domain_logits = self.fallback_classifier(pooled)  # (batch, num_domains)
            domain_probs = F.softmax(domain_logits, dim=-1)
            return domain_probs.unsqueeze(1).expand(-1, seq_len, -1)

        # 4. No information at all: uniform weights. This is only reachable
        # with hidden_states=None, so batch and seq_len are both 0 here.
        uniform = torch.ones(batch, seq_len, self.num_domains, device="cpu")
        return uniform / self.num_domains

    def forward(
        self,
        x: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with domain-aware gating.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            domain_ids: Optional domain IDs, (batch,) or (batch, seq_len).
            domain_tags: Optional domain tag mask (batch, seq_len, num_domains).

        Returns:
            Output tensor (batch, seq_len, d_model).
        """
        # Routing weights: (batch, seq_len, num_domains).
        domain_weights = self.get_domain_one_hot(domain_ids, domain_tags, x)
        # The tag/ID paths yield float32; match the gate's parameter dtype so
        # a module cast to half/bfloat16/double does not hit a dtype mismatch.
        domain_weights = domain_weights.to(self.domain_gate.weight.dtype)

        # SwiGLU: split the up-projection into value and gate halves.
        up = self.up_proj(x)  # (batch, seq_len, hidden_dim * 2)
        up1, up2 = up.chunk(2, dim=-1)  # each (batch, seq_len, hidden_dim)
        hidden = up1 * F.silu(up2)

        # Per-domain channel scaling, (batch, seq_len, hidden_dim). Calling
        # the layer (instead of multiplying by its weight matrix alone) keeps
        # the bias in the graph so it can train and is not an unused
        # parameter; with a zero bias the result equals the plain
        # weight-only product.
        domain_scaling = self.domain_gate(domain_weights)

        # Multiplicative gating, then project back to the model dimension.
        hidden = hidden * domain_scaling
        return self.down_proj(hidden)
def test_scigate_ffn():
    """Smoke-test SciGateFFN: output shape must equal input shape on every routing path."""
    n_domains = 7
    width = 4096
    length = 128
    n_batch = 2

    layer = SciGateFFN(width, expansion=4, num_domains=n_domains)

    # Path 1: no domain information, so routing comes from the fallback classifier.
    inputs = torch.randn(n_batch, length, width)
    result = layer(inputs)
    print(f"Input shape: {inputs.shape}")
    print(f"Output shape: {result.shape}")
    assert result.shape == inputs.shape

    # Path 2: one explicit domain ID per sequence.
    per_sample_ids = torch.randint(0, n_domains, (n_batch,))
    assert layer(inputs, domain_ids=per_sample_ids).shape == inputs.shape

    # Path 3: explicit per-token tags, every token marked as domain 0 (math).
    math_tags = torch.zeros(n_batch, length, n_domains)
    math_tags[:, :, 0] = 1.0
    assert layer(inputs, domain_tags=math_tags).shape == inputs.shape

    print("SciGateFFN test passed!")
# Script entry point: run the smoke test only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    test_scigate_ffn()