File size: 2,868 Bytes
8f6eed4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
layers.py — TMTLayer: one full layer of the TemporalMesh Transformer.

Combines MeshAttention → DualStreamFFN → ExitGate → MemoryAnchorCross.
Tokens that have already exited (exit_mask=True) are frozen — their
representation from the exiting layer is carried forward unchanged.
"""
from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from .attention import MeshAttention
from .config import TMTConfig
from .exit_gate import ExitGate
from .ffn import DualStreamFFN
from .memory import MemoryAnchorCross


class TMTLayer(nn.Module):
    """One layer of the TemporalMesh Transformer.

    The layer applies, in order: a pre-norm MeshAttention residual block, a
    pre-norm DualStreamFFN residual block, the ExitGate, and a pre-norm
    MemoryAnchorCross residual block. Tokens that had already exited before
    the layer was entered are passed through unchanged.
    """

    def __init__(self, cfg: TMTConfig, layer_idx: int) -> None:
        """Build the sub-modules of the layer.

        Args:
            cfg:       model configuration; ``d_model`` and ``layer_norm_eps``
                       are read here, and ``cfg`` is forwarded to every
                       sub-module constructor.
            layer_idx: position of this layer in the stack (only used for
                       ``__repr__`` in this class).
        """
        super().__init__()
        self.layer_idx = layer_idx

        self.norm1 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.attn = MeshAttention(cfg)

        self.norm2 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.ffn = DualStreamFFN(cfg)

        self.exit_gate = ExitGate(cfg)

        self.norm3 = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.memory_cross = MemoryAnchorCross(cfg)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: Tensor,
        exit_mask: Tensor,
        decay_scalars: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run the layer and freeze tokens that had exited before it.

        Args:
            x:             (B, S, D)
            edge_index:    (2, E)
            edge_weight:   (E,)
            exit_mask:     (B, S) bool — True where token has exited
            decay_scalars: (B, S, D) optional temporal decay
        Returns:
            x:              (B, S, D) updated representations; rows whose
                            incoming ``exit_mask`` was True equal the input
            exit_mask:      (B, S) updated exit mask, as returned by the gate
            confidence:     (B, S) gate confidence scores
            memory_state:   (M, D) updated memory anchors
        """
        # Remember which tokens were exited *on entry*. Only those are frozen:
        # a token that exits in this layer keeps this layer's output, so the
        # representation the gate judged is the one carried forward.
        # The (small) mask is copied in case the gate updates it in place.
        already_exited = exit_mask.clone()

        # No copy of the activations is needed: `x` is only rebound below,
        # never written in place by this method.
        layer_input = x

        # MeshAttention + residual
        x = x + self.attn(self.norm1(x), edge_index, edge_weight, decay_scalars)

        # DualStreamFFN + residual
        x = x + self.ffn(self.norm2(x))

        # ExitGate (may mark additional tokens as exited)
        x, exit_mask, confidence = self.exit_gate(x, exit_mask)

        # Memory cross-attention + residual
        mem_out, memory_state = self.memory_cross(self.norm3(x))
        x = x + mem_out

        # Restore previously exited tokens to their pre-layer values. Applied
        # unconditionally: with an all-False mask this is a value-preserving
        # select, and it avoids the host synchronisation an `.any()` check
        # would force. The (B, S, 1) condition broadcasts over D.
        x = torch.where(already_exited.unsqueeze(-1), layer_input, x)

        return x, exit_mask, confidence, memory_state

    def __repr__(self) -> str:
        """Return a one-line summary with the layer index and parameter count."""
        n_params = sum(param.numel() for param in self.parameters())
        return f"TMTLayer(idx={self.layer_idx}, params={n_params:,})"