"""
SciGateFFN: Science-aware gated feed-forward network.

Learns to activate different FFN pathways based on the science domain of
the input. Routing is hybrid: explicit domain tags are preferred, with a
fallback to a small learned classifier when no tags are available.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class SciGateFFN(nn.Module):
    """Gated FFN with science domain routing.

    A SwiGLU feed-forward block whose hidden activations are scaled by a
    learned, domain-conditioned gate, so that different science domains
    (math, chemistry, biology, etc.) activate different pathways.
    """

    def __init__(
        self,
        d_model: int,
        expansion: int = 4,
        num_domains: int = 7,
        use_domain_tags: bool = True,
    ):
        """
        Initialize SciGateFFN.

        Args:
            d_model: Model dimension
            expansion: FFN expansion factor (default 4)
            num_domains: Number of science domains (7)
            use_domain_tags: Whether to use explicit domain tags for routing
        """
        super().__init__()
        self.d_model = d_model
        self.expansion = expansion
        self.num_domains = num_domains
        self.use_domain_tags = use_domain_tags

        hidden_dim = d_model * expansion

        # Standard SwiGLU architecture: up_proj splits into two paths.
        self.up_proj = nn.Linear(d_model, hidden_dim * 2, bias=False)
        self.down_proj = nn.Linear(hidden_dim, d_model, bias=False)

        # Domain-conditioned gate: maps a (soft) one-hot domain vector to a
        # per-channel multiplicative scaling of the FFN hidden state.
        self.domain_gate = nn.Linear(num_domains, hidden_dim, bias=True)

        # Fallback domain classifier (used when neither tags nor IDs are
        # present): a small MLP over the mean-pooled sequence representation.
        self.fallback_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Linear(d_model // 2, num_domains),
        )

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize every Linear layer: N(0, 0.02) weights, zero biases.

        Iterating ``self.modules()`` (rather than the top-level attributes,
        as the original did) also reaches the Linear layers nested inside the
        fallback-classifier ``nn.Sequential``, which an attribute-level
        ``hasattr(module, 'weight')`` check silently skips.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def get_domain_one_hot(
        self,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Get the per-token domain routing vector.

        Hybrid strategy (in priority order):
        1. If domain_tags provided (explicit [MATH], [CHEM] etc) and any tag
           is set, use those
        2. If domain_ids provided (from data loader), use those
        3. Fallback: classify from hidden_states

        Args:
            domain_ids: Tensor of domain IDs (batch, seq_len) or (batch,)
            domain_tags: Boolean mask for domain tags (batch, seq_len, num_domains)
            hidden_states: Hidden states for fallback classification
                (batch, seq_len, d_model); also supplies seq_len when
                broadcasting per-sequence domain_ids

        Returns:
            domain_one_hot: (batch, seq_len, num_domains)

        Raises:
            ValueError: If all three inputs are None (the original silently
                returned an empty (0, 0, num_domains) tensor here).
        """
        if domain_ids is None and domain_tags is None and hidden_states is None:
            raise ValueError(
                "At least one of domain_ids, domain_tags or hidden_states "
                "must be provided"
            )

        seq_len = hidden_states.shape[1] if hidden_states is not None else 0

        if domain_tags is not None and domain_tags.any():
            # Explicit domain tags are already (soft) one-hot.
            # NOTE(review): if only some tokens are tagged, untagged tokens
            # get an all-zero routing row (gate bias only) — confirm this is
            # the intended behavior for partially tagged batches.
            return domain_tags.float()
        elif domain_ids is not None:
            if domain_ids.dim() == 1:
                # One domain for the whole sequence: broadcast over seq_len
                # (taken from hidden_states, which forward() always passes).
                one_hot = F.one_hot(domain_ids, num_classes=self.num_domains)
                one_hot = one_hot.unsqueeze(1).expand(-1, seq_len, -1)
            else:
                # Per-token domain IDs.
                one_hot = F.one_hot(domain_ids, num_classes=self.num_domains)
            return one_hot.float()
        elif hidden_states is not None:
            # Fallback: infer a soft domain distribution from the hidden
            # states via mean pooling + classifier, shared across the sequence.
            pooled = hidden_states.mean(dim=1)  # (batch, d_model)
            domain_logits = self.fallback_classifier(pooled)  # (batch, num_domains)
            domain_probs = F.softmax(domain_logits, dim=-1)
            return domain_probs.unsqueeze(1).expand(-1, seq_len, -1)
        else:
            # Only an all-zero tag mask was supplied: no usable routing
            # signal, so route uniformly over all domains. Shape and device
            # come from the tag mask (the original used batch=seq_len=0 and
            # 'cpu' here, producing an empty tensor on the wrong device).
            batch, seq_len, _ = domain_tags.shape
            uniform = torch.ones(
                batch, seq_len, self.num_domains, device=domain_tags.device
            )
            return uniform / self.num_domains

    def forward(
        self,
        x: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with domain-aware gating.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            domain_ids: Optional domain IDs (batch,) or (batch, seq_len)
            domain_tags: Optional domain tag mask (batch, seq_len, num_domains)

        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        # Domain routing weights: (batch, seq_len, num_domains).
        domain_weights = self.get_domain_one_hot(domain_ids, domain_tags, x)

        # SwiGLU: project up, split in two, gate one half with SiLU of the other.
        up1, up2 = self.up_proj(x).chunk(2, dim=-1)  # each (b, s, hidden_dim)
        hidden = up1 * F.silu(up2)

        # Domain-specific multiplicative gating. Calling the Linear module
        # directly (instead of an einsum against .weight only, as the original
        # did) also applies the declared bias term; the bias is
        # zero-initialized, so behavior at init is unchanged, but the bias can
        # now actually train.
        domain_scaling = self.domain_gate(domain_weights)  # (b, s, hidden_dim)
        hidden = hidden * domain_scaling

        # Project back to model dimension.
        return self.down_proj(hidden)


def test_scigate_ffn():
    """Smoke-test SciGateFFN on all three routing paths."""
    batch_size = 2
    seq_len = 16
    # Small dimensions keep the test fast; the original used d_model=4096,
    # which allocates a ~270M-parameter up_proj just to check output shapes.
    d_model = 64
    num_domains = 7

    ffn = SciGateFFN(d_model, expansion=4, num_domains=num_domains)

    # Test with no domain info (fallback classifier path).
    x = torch.randn(batch_size, seq_len, d_model)
    output = ffn(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape

    # Test with explicit per-sequence domain IDs.
    domain_ids = torch.randint(0, num_domains, (batch_size,))
    output2 = ffn(x, domain_ids=domain_ids)
    assert output2.shape == x.shape

    # Test with explicit domain tags.
    domain_tags = torch.zeros(batch_size, seq_len, num_domains)
    domain_tags[:, :, 0] = 1.0  # All math
    output3 = ffn(x, domain_tags=domain_tags)
    assert output3.shape == x.shape

    print("SciGateFFN test passed!")


if __name__ == "__main__":
    test_scigate_ffn()