vigneshwar234 commited on
Commit
8f6eed4
·
verified ·
1 Parent(s): 244a709

Add source: tmt/model/layers.py

Browse files
Files changed (1) hide show
  1. tmt/model/layers.py +87 -0
tmt/model/layers.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ layers.py — TMTLayer: one full layer of the TemporalMesh Transformer.
3
+
4
+ Combines MeshAttention → DualStreamFFN → ExitGate → MemoryAnchorCross.
5
+ Tokens that have already exited (exit_mask=True) are frozen — their
6
+ representation from the exiting layer is carried forward unchanged.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch import Tensor
15
+
16
+ from .attention import MeshAttention
17
+ from .config import TMTConfig
18
+ from .exit_gate import ExitGate
19
+ from .ffn import DualStreamFFN
20
+ from .memory import MemoryAnchorCross
21
+
22
+
23
+ class TMTLayer(nn.Module):
24
+ def __init__(self, cfg: TMTConfig, layer_idx: int) -> None:
25
+ super().__init__()
26
+ self.layer_idx = layer_idx
27
+
28
+ self.norm1 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
29
+ self.attn = MeshAttention(cfg)
30
+
31
+ self.norm2 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
32
+ self.ffn = DualStreamFFN(cfg)
33
+
34
+ self.exit_gate = ExitGate(cfg)
35
+
36
+ self.norm3 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
37
+ self.memory_cross = MemoryAnchorCross(cfg)
38
+
39
+ def forward(
40
+ self,
41
+ x: Tensor,
42
+ edge_index: Tensor,
43
+ edge_weight: Tensor,
44
+ exit_mask: Tensor,
45
+ decay_scalars: Optional[Tensor] = None,
46
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
47
+ """
48
+ Args:
49
+ x: (B, S, D)
50
+ edge_index: (2, E)
51
+ edge_weight: (E,)
52
+ exit_mask: (B, S) bool — True where token has exited
53
+ decay_scalars: (B, S, D) optional temporal decay
54
+ Returns:
55
+ x: (B, S, D) updated representations
56
+ exit_mask: (B, S) updated exit mask
57
+ confidence: (B, S) gate confidence scores
58
+ memory_state: (M, D) updated memory anchors
59
+ """
60
+ # Save exited-token representations so we can restore after layer ops
61
+ x_frozen = x.clone()
62
+
63
+ # MeshAttention + residual
64
+ attn_out = self.attn(self.norm1(x), edge_index, edge_weight, decay_scalars)
65
+ x = x + attn_out
66
+
67
+ # DualStreamFFN + residual
68
+ ffn_out = self.ffn(self.norm2(x))
69
+ x = x + ffn_out
70
+
71
+ # ExitGate
72
+ x, exit_mask, confidence = self.exit_gate(x, exit_mask)
73
+
74
+ # Memory cross-attention + residual
75
+ mem_out, memory_state = self.memory_cross(self.norm3(x))
76
+ x = x + mem_out
77
+
78
+ # Freeze exited tokens: replace with their pre-layer values
79
+ if exit_mask.any():
80
+ frozen = exit_mask.unsqueeze(-1).expand_as(x)
81
+ x = torch.where(frozen, x_frozen, x)
82
+
83
+ return x, exit_mask, confidence, memory_state
84
+
85
+ def __repr__(self) -> str:
86
+ p = sum(p.numel() for p in self.parameters())
87
+ return f"TMTLayer(idx={self.layer_idx}, params={p:,})"