# Source: tmt/model/ffn.py (uploaded by vigneshwar234; commit c93f22b, verified)
"""
ffn.py — DualStreamFFN: parallel syntax + semantic feed-forward network.
Novel vs standard: instead of a single FFN (d_model → 4*d_model → d_model),
DualStreamFFN runs two parallel streams of half-width (d_model → d_stream),
each specialising on syntax or semantic content, then fuses them with a learned
gate. This gives the same parameter budget as a standard FFN while separating
representational concerns.
"""
from __future__ import annotations
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from .config import TMTConfig
class DualStreamFFN(nn.Module):
    """Feed-forward block made of two parallel streams blended by a learned gate.

    Each stream ("syntax" and "semantic") is an independent
    ``Linear -> GELU -> Linear -> Dropout`` stack mapping ``d_model`` to
    ``ffn_stream_dim`` and back.  A third linear layer followed by a sigmoid
    produces one mixing weight in (0, 1) per output feature; the result is the
    convex combination of the two stream outputs under that weight.
    """

    def __init__(self, cfg: TMTConfig) -> None:
        """Build the two streams and the fusion gate.

        Args:
            cfg: configuration object; only ``d_model``, ``ffn_stream_dim``
                and ``dropout`` are read here.
        """
        super().__init__()
        model_dim = cfg.d_model
        stream_dim = cfg.ffn_stream_dim  # hidden width of each stream

        # Sub-modules are registered in a fixed order (syntax, semantic, gate)
        # so that parameter ordering and state_dict keys stay stable.
        self.syntax_up = nn.Linear(model_dim, stream_dim)
        self.syntax_down = nn.Linear(stream_dim, model_dim)

        self.semantic_up = nn.Linear(model_dim, stream_dim)
        self.semantic_down = nn.Linear(stream_dim, model_dim)

        # Fusion gate: a d_model -> d_model projection, squashed with a
        # sigmoid in forward(), i.e. one weight per token and per feature.
        self.gate = nn.Linear(model_dim, model_dim)

        self.act = nn.GELU()
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Run both streams on ``x`` and blend them with the learned gate.

        Args:
            x: input whose last dimension is ``d_model``; the layout used
                upstream is (B, S, D).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Syntax stream is evaluated first; dropout is applied per stream,
        # before the two outputs are mixed.
        syntax_hidden = self.act(self.syntax_up(x))
        syntax_out = self.dropout(self.syntax_down(syntax_hidden))

        semantic_hidden = self.act(self.semantic_up(x))
        semantic_out = self.dropout(self.semantic_down(semantic_hidden))

        # The mixing weights are computed from the raw input, not from the
        # stream outputs.
        mix = torch.sigmoid(self.gate(x))
        return mix * syntax_out + (1.0 - mix) * semantic_out

    def __repr__(self) -> str:
        """Return a compact summary: stream width and total parameter count."""
        p = sum(param.numel() for param in self.parameters())
        return f"DualStreamFFN(streams=2x{self.syntax_up.out_features}, params={p:,})"