| | """
|
| | SciGateFFN: Science-aware gated feed-forward network.
|
| | Learns to activate different FFN pathways based on science domain.
|
| | Uses hybrid routing: explicit domain tags preferred, fallback to learned classifier.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional, Tuple
|
| |
|
| |
|
class SciGateFFN(nn.Module):
    """
    Gated FFN with science domain routing.

    A SwiGLU-style feed-forward block whose hidden activations are
    rescaled by a learned per-domain gate, so different science domains
    (math, chemistry, biology, etc.) activate different FFN pathways.

    Domain routing is hybrid: explicit domain tags are preferred, then
    integer domain IDs, then a small learned classifier over the hidden
    states.
    """

    def __init__(
        self,
        d_model: int,
        expansion: int = 4,
        num_domains: int = 7,
        use_domain_tags: bool = True,
    ):
        """
        Initialize SciGateFFN.

        Args:
            d_model: Model dimension
            expansion: FFN expansion factor (default 4)
            num_domains: Number of science domains (default 7)
            use_domain_tags: Whether explicit domain tags are expected for
                routing (informational; routing itself is decided per call
                from whichever arguments are actually provided)
        """
        super().__init__()
        self.d_model = d_model
        self.expansion = expansion
        self.num_domains = num_domains
        self.use_domain_tags = use_domain_tags

        hidden_dim = d_model * expansion

        # Fused up-projection: produces both the value and gate halves of
        # the SwiGLU activation in a single matmul.
        self.up_proj = nn.Linear(d_model, hidden_dim * 2, bias=False)
        self.down_proj = nn.Linear(hidden_dim, d_model, bias=False)

        # Maps a domain distribution (one-hot or soft) to a per-channel
        # scaling of the FFN hidden state.
        self.domain_gate = nn.Linear(num_domains, hidden_dim, bias=True)

        # Fallback classifier used when no explicit domain signal is given.
        self.fallback_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Linear(d_model // 2, num_domains),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize every Linear layer: N(0, 0.02) weights, zero bias.

        BUG FIX: the original iterated over top-level submodules only, so
        the Linear layers inside the fallback classifier's nn.Sequential
        were silently skipped (nn.Sequential has no .weight attribute).
        Recursing via self.modules() initializes all of them consistently.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def get_domain_one_hot(
        self,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Get the domain weight vector used for routing.

        Hybrid strategy (first available source wins):
        1. domain_tags: explicit tag mask ([MATH], [CHEM] etc), if any set
        2. domain_ids: integer domain labels (e.g. from the data loader)
        3. hidden_states: classified by the learned fallback classifier

        Args:
            domain_ids: Tensor of domain IDs, (batch,) or (batch, seq_len)
            domain_tags: Domain tag mask, (batch, seq_len, num_domains)
            hidden_states: Hidden states for fallback classification,
                (batch, seq_len, d_model)

        Returns:
            Domain weights of shape (batch, seq_len, num_domains)

        Raises:
            ValueError: If no routing source is provided at all.
        """
        # Source 1: explicit tags (used only if at least one tag is set).
        if domain_tags is not None and domain_tags.any():
            return domain_tags.float()

        # Source 2: integer domain IDs.
        if domain_ids is not None:
            domain_one_hot = F.one_hot(domain_ids, num_classes=self.num_domains)
            if domain_ids.dim() == 1:
                # Per-sequence labels: broadcast across the sequence.
                # BUG FIX: the original read seq_len from a degenerate
                # (0, 0, 0) fallback when hidden_states was None, which
                # expanded the one-hot to sequence length 0.
                seq_len = hidden_states.shape[1] if hidden_states is not None else 1
                domain_one_hot = domain_one_hot.unsqueeze(1).expand(-1, seq_len, -1)
            return domain_one_hot.float()

        # Source 3: learned fallback classification from hidden states.
        if hidden_states is not None:
            pooled = hidden_states.mean(dim=1)  # (batch, d_model)
            domain_logits = self.fallback_classifier(pooled)
            domain_probs = F.softmax(domain_logits, dim=-1)
            seq_len = hidden_states.shape[1]
            return domain_probs.unsqueeze(1).expand(-1, seq_len, -1)

        # BUG FIX: the original silently returned a degenerate
        # (0, 0, num_domains) "uniform" tensor here; fail loudly instead.
        raise ValueError(
            "SciGateFFN routing requires domain_tags, domain_ids, or hidden_states."
        )

    def forward(
        self,
        x: torch.Tensor,
        domain_ids: Optional[torch.Tensor] = None,
        domain_tags: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with domain-aware gating.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            domain_ids: Optional domain IDs, (batch,) or (batch, seq_len)
            domain_tags: Optional domain tag mask (batch, seq_len, num_domains)

        Returns:
            Output tensor (batch, seq_len, d_model)
        """
        # Domain routing weights, (batch, seq_len, num_domains).
        domain_weights = self.get_domain_one_hot(domain_ids, domain_tags, x)

        # SwiGLU: split the fused up-projection into value and gate halves.
        up1, up2 = self.up_proj(x).chunk(2, dim=-1)
        hidden = up1 * F.silu(up2)

        # Per-channel domain scaling, (batch, seq_len, hidden_dim).
        # BUG FIX: the original computed this as a raw einsum against
        # domain_gate.weight, which silently dropped the layer's bias
        # (declared with bias=True but never used). Calling the layer
        # applies weight and bias together.
        domain_scaling = self.domain_gate(domain_weights)

        hidden = hidden * domain_scaling
        return self.down_proj(hidden)
| |
|
| |
|
def test_scigate_ffn():
    """Smoke-test SciGateFFN: all three routing paths preserve input shape.

    BUG FIX: the original used d_model=4096 with expansion 4, allocating
    roughly 0.8 GB of weights for a shape-only smoke test; much smaller
    dimensions exercise exactly the same code paths.
    """
    batch_size = 2
    seq_len = 16
    d_model = 32
    num_domains = 7

    ffn = SciGateFFN(d_model, expansion=4, num_domains=num_domains)

    x = torch.randn(batch_size, seq_len, d_model)

    # Path 1: no explicit domain signal -> fallback classifier.
    output = ffn(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == x.shape

    # Path 2: per-sequence integer domain IDs.
    domain_ids = torch.randint(0, num_domains, (batch_size,))
    output2 = ffn(x, domain_ids=domain_ids)
    assert output2.shape == x.shape

    # Path 3: explicit domain tags.
    domain_tags = torch.zeros(batch_size, seq_len, num_domains)
    domain_tags[:, :, 0] = 1.0
    output3 = ffn(x, domain_tags=domain_tags)
    assert output3.shape == x.shape

    print("SciGateFFN test passed!")


if __name__ == "__main__":
    test_scigate_ffn()
|