# Source: tmt/model/layers.py (uploaded by vigneshwar234, commit 8f6eed4)
"""
layers.py — TMTLayer: one full layer of the TemporalMesh Transformer.
Combines MeshAttention → DualStreamFFN → ExitGate → MemoryAnchorCross.
Tokens that have already exited (exit_mask=True) are frozen — their
representation from the exiting layer is carried forward unchanged.
"""
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from .attention import MeshAttention
from .config import TMTConfig
from .exit_gate import ExitGate
from .ffn import DualStreamFFN
from .memory import MemoryAnchorCross
class TMTLayer(nn.Module):
    """One layer of the TemporalMesh Transformer.

    Applies, in order: pre-norm MeshAttention with a residual, pre-norm
    DualStreamFFN with a residual, the ExitGate, and pre-norm
    MemoryAnchorCross with a residual. Tokens that had already exited before
    this layer (``exit_mask`` True on input) are restored to their input
    representation. Their value from the layer where they exited is
    therefore carried forward unchanged.

    Args:
        cfg: model configuration. This class reads ``cfg.d_model`` and
            ``cfg.layer_norm_eps`` and passes ``cfg`` to every sub-module.
        layer_idx: position of this layer in the stack (shown in ``repr``).
    """

    def __init__(self, cfg: TMTConfig, layer_idx: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        # Sub-modules are registered in the original order so parameter
        # iteration / state_dict key order is unchanged.
        self.norm1 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.attn = MeshAttention(cfg)
        self.norm2 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ffn = DualStreamFFN(cfg)
        self.exit_gate = ExitGate(cfg)
        self.norm3 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.memory_cross = MemoryAnchorCross(cfg)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: Tensor,
        exit_mask: Tensor,
        decay_scalars: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run one layer, keeping previously exited tokens frozen.

        Args:
            x: (B, S, D) token representations.
            edge_index: (2, E) mesh edges, forwarded to MeshAttention.
            edge_weight: (E,) edge weights, forwarded to MeshAttention.
            exit_mask: (B, S) bool — True where a token exited at an
                earlier layer.
            decay_scalars: (B, S, D) optional temporal decay, forwarded to
                MeshAttention.

        Returns:
            x: (B, S, D) updated representations. Tokens already exited on
                input keep their input value. Tokens that exit at this layer
                keep this layer's full output.
            exit_mask: (B, S) updated exit mask, as returned by the ExitGate.
            confidence: (B, S) gate confidence scores.
            memory_state: (M, D) updated memory anchors.
        """
        # `x` itself is never handed to a sub-module and every update below
        # is out-of-place, so a plain reference is enough to restore from.
        x_in = x
        # Only tokens that exited at an *earlier* layer are frozen. Newly
        # exiting tokens must keep this layer's output. Clone in case the
        # gate updates the mask in place.
        exited_before = exit_mask.clone()

        # MeshAttention + residual
        attn_out = self.attn(self.norm1(x), edge_index, edge_weight, decay_scalars)
        x = x + attn_out

        # DualStreamFFN + residual
        ffn_out = self.ffn(self.norm2(x))
        x = x + ffn_out

        # ExitGate
        x, exit_mask, confidence = self.exit_gate(x, exit_mask)

        # Memory cross-attention + residual
        mem_out, memory_state = self.memory_cross(self.norm3(x))
        x = x + mem_out

        # Restore previously exited tokens. torch.where broadcasts the
        # (B, S, 1) mask over D, and it needs no `.any()` host sync: an
        # all-False mask simply yields `x`.
        x = torch.where(exited_before.unsqueeze(-1), x_in, x)

        return x, exit_mask, confidence, memory_state

    def __repr__(self) -> str:
        n_params = sum(param.numel() for param in self.parameters())
        return f"TMTLayer(idx={self.layer_idx}, params={n_params:,})"