File size: 1,898 Bytes
c93f22b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
ffn.py — DualStreamFFN: parallel syntax + semantic feed-forward network.

Novel vs standard: instead of a single FFN (d_model → 4*d_model → d_model),
DualStreamFFN runs two parallel streams of half-width (d_model → d_stream),
each specialising on syntax or semantic content, then fuses them with a learned
gate.  This gives the same parameter budget as a standard FFN while separating
representational concerns.
"""
from __future__ import annotations

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from .config import TMTConfig


class DualStreamFFN(nn.Module):
    """Two parallel feed-forward streams fused by a learned scalar gate."""

    def __init__(self, cfg: TMTConfig) -> None:
        super().__init__()
        d = cfg.d_model
        s = cfg.ffn_stream_dim  # each stream width (default 256)

        # Syntax stream
        self.syntax_up = nn.Linear(d, s)
        self.syntax_down = nn.Linear(s, d)

        # Semantic stream
        self.semantic_up = nn.Linear(d, s)
        self.semantic_down = nn.Linear(s, d)

        # Learned fusion gate: sigmoid(linear) → scalar ∈ (0,1) per token-dim
        self.gate = nn.Linear(d, d)

        self.act = nn.GELU()
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, S, D)
        Returns:
            out: (B, S, D)
        """
        syntax = self.dropout(self.syntax_down(self.act(self.syntax_up(x))))
        semantic = self.dropout(self.semantic_down(self.act(self.semantic_up(x))))

        # Learned fusion gate
        g = torch.sigmoid(self.gate(x))  # (B, S, D)
        return g * syntax + (1.0 - g) * semantic

    def __repr__(self) -> str:
        p = sum(p.numel() for p in self.parameters())
        return f"DualStreamFFN(streams=2x{self.syntax_up.out_features}, params={p:,})"