vigneshwar234 committed on
Commit
c93f22b
·
verified ·
1 Parent(s): 4b661a2

Add source: tmt/model/ffn.py

Browse files
Files changed (1) hide show
  1. tmt/model/ffn.py +58 -0
tmt/model/ffn.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ffn.py — DualStreamFFN: parallel syntax + semantic feed-forward network.
3
+
4
+ Novel vs standard: instead of a single FFN (d_model → 4*d_model → d_model),
5
+ DualStreamFFN runs two parallel streams of half-width (d_model → d_stream),
6
+ each specialising on syntax or semantic content, then fuses them with a learned
7
+ gate. This gives the same parameter budget as a standard FFN while separating
8
+ representational concerns.
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+ from torch import Tensor
16
+
17
+ from .config import TMTConfig
18
+
19
+
20
class DualStreamFFN(nn.Module):
    """Feed-forward block made of two parallel streams blended by a learned gate.

    Both streams share the same shape (``d_model -> ffn_stream_dim -> d_model``
    with a GELU in between); the "syntax" / "semantic" names describe their
    intended roles, nothing in this class forces the specialisation. A sigmoid
    gate computed from the input produces one mixing weight per token and
    per feature, and the output is the convex combination of the two streams.
    """

    def __init__(self, cfg: TMTConfig) -> None:
        """Build both streams, the fusion gate, and the shared activation/dropout.

        Args:
            cfg: configuration object; the attributes ``d_model``,
                ``ffn_stream_dim`` and ``dropout`` are read.
        """
        super().__init__()
        model_dim = cfg.d_model
        stream_dim = cfg.ffn_stream_dim  # hidden width of each stream

        # NOTE: submodules are created in a fixed order (syntax, semantic,
        # gate) so seeded initialisation and parameter ordering stay stable.

        # Syntax stream: expand to the stream width, project back.
        self.syntax_up = nn.Linear(model_dim, stream_dim)
        self.syntax_down = nn.Linear(stream_dim, model_dim)

        # Semantic stream: same layout, separate weights.
        self.semantic_up = nn.Linear(model_dim, stream_dim)
        self.semantic_down = nn.Linear(stream_dim, model_dim)

        # Fusion gate: one logit per feature, squashed to (0, 1) in forward().
        self.gate = nn.Linear(model_dim, model_dim)

        self.act = nn.GELU()
        self.dropout = nn.Dropout(cfg.dropout)

    def _run_stream(self, x: Tensor, up: nn.Module, down: nn.Module) -> Tensor:
        """Apply ``up -> activation -> down -> dropout`` to ``x``."""
        hidden = self.act(up(x))
        return self.dropout(down(hidden))

    def forward(self, x: Tensor) -> Tensor:
        """Blend the two stream outputs with the input-dependent gate.

        Args:
            x: (B, S, D)
        Returns:
            out: (B, S, D)
        """
        # Syntax is evaluated before semantic so dropout draws its random
        # masks in a fixed order during training.
        syntax_out = self._run_stream(x, self.syntax_up, self.syntax_down)
        semantic_out = self._run_stream(x, self.semantic_up, self.semantic_down)

        # Gate near 1 favours the syntax stream, near 0 the semantic stream.
        mix = torch.sigmoid(self.gate(x))  # (B, S, D)
        return mix * syntax_out + (1.0 - mix) * semantic_out

    def __repr__(self) -> str:
        """Return a one-line summary: stream width and total parameter count."""
        total = sum(param.numel() for param in self.parameters())
        return f"DualStreamFFN(streams=2x{self.syntax_up.out_features}, params={total:,})"