File size: 4,474 Bytes
4b661a2
 
 
 
 
 
 
 
 
 
 
 
 
d9f2f9f
4b661a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9f2f9f
 
 
 
4b661a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
attention.py — MeshAttention: multi-head attention over graph edges.

Novel vs standard: instead of every token attending to every other token
(O(S²) full attention), MeshAttention restricts attention to graph neighbours.
Temporal decay is multiplied into the attention weights — not added as bias —
so semantically close but temporally distant tokens are suppressed.

Formula: attn = softmax(QK^T / sqrt(d)) * sigmoid(W_decay * temporal_distance)
"""
from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor

from .config import TMTConfig


class MeshAttention(nn.Module):
    """
    Multi-head attention constrained to dynamic graph edges with temporal decay.

    Scores are computed densely for every (query, key) pair and then combined
    with an additive ``(B, S, S)`` graph mask:

    - graph edges contribute their edge weight as a bias;
    - every token may always attend to itself;
    - all other pairs receive ``-inf``.

    This class has no torch_geometric code path. The dense masked form is
    always used, so memory is quadratic in the sequence length ``S``.
    """

    def __init__(self, cfg: TMTConfig) -> None:
        """
        Args:
            cfg: config providing ``d_model``, ``n_heads``, ``decay_rate``
                and ``dropout``.

        Raises:
            ValueError: if ``cfg.d_model`` is not divisible by ``cfg.n_heads``.
        """
        super().__init__()
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if cfg.d_model % cfg.n_heads != 0:
            raise ValueError(
                f"d_model ({cfg.d_model}) must be divisible by "
                f"n_heads ({cfg.n_heads})"
            )
        self.d_model = cfg.d_model
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_model // cfg.n_heads
        self.scale = math.sqrt(self.d_head)

        self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

        # Learned temporal decay weight: one scalar per head, initialised to
        # cfg.decay_rate.
        self.w_decay = nn.Parameter(torch.ones(cfg.n_heads) * cfg.decay_rate)

        self.dropout = nn.Dropout(cfg.dropout)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: Tensor,
        decay_scalars: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Apply graph-masked multi-head attention.

        Args:
            x:             (B, S, D) token features.
            edge_index:    (2, E) edges over flattened global node indices
                           ``b * S + s``. Edges whose two endpoints lie in
                           different batch elements are ignored.
            edge_weight:   (E,) per-edge weights, added to the scores as a
                           bias (cosine similarities, per the caller).
            decay_scalars: optional (B, S, D) per-token temporal decay. It is
                           averaged over D to give one scalar per token.
        Returns:
            out: (B, S, D)
        """
        B, S, _ = x.shape

        # Project and split into heads: (B, S, D) -> (B, H, S, d_head).
        Q = rearrange(self.q_proj(x), "b s (h d) -> b h s d", h=self.n_heads)
        K = rearrange(self.k_proj(x), "b s (h d) -> b h s d", h=self.n_heads)
        V = rearrange(self.v_proj(x), "b s (h d) -> b h s d", h=self.n_heads)

        # Dense scaled dot-product scores for all pairs: (B, H, S, S).
        scores = torch.einsum("bhid,bhjd->bhij", Q, K) / self.scale

        # Additive graph mask: -inf everywhere except graph edges.
        mask = torch.full((B, S, S), float("-inf"), device=x.device)
        if edge_index.numel() > 0:
            src_global = edge_index[0]  # (E,)
            dst_global = edge_index[1]  # (E,)
            b_idx = src_global // S
            # A per-batch (S, S) mask can only express edges within a single
            # sequence. A cross-batch edge would otherwise be remapped onto an
            # unrelated token of the source's own sequence, so drop it.
            keep = b_idx == dst_global // S
            mask[b_idx[keep], src_global[keep] % S, dst_global[keep] % S] = (
                edge_weight[keep].float()
            )

        # Every token may always attend to itself. Index assignment (not
        # addition) turns the -inf default on the diagonal into 0. It also
        # overrides any explicit self-edge weight.
        diag_idx = torch.arange(S, device=x.device)
        mask[:, diag_idx, diag_idx] = 0.0

        # Broadcast the mask over heads. The zero diagonal guarantees that no
        # row is entirely -inf, so the softmax is always well defined.
        scores = scores + mask.unsqueeze(1)
        attn = F.softmax(scores, dim=-1)  # (B, H, S, S)

        if decay_scalars is not None:
            # Average the decay across D to one scalar per token: (B, S).
            token_decay = decay_scalars.mean(dim=-1)
            # Per-head scaling w_decay (H,) * token_decay (B, S) -> (B, H, S).
            head_decay = torch.sigmoid(
                rearrange(self.w_decay, "h -> 1 h 1")
                * rearrange(token_decay, "b s -> b 1 s")
            )
            # NOTE(review): unsqueeze(-1) scales each *query* row i by that
            # token's own decay. This gates the whole output of token i and
            # leaves the relative weights over keys unchanged. If the intent is
            # to down-weight temporally distant *keys*, this should be
            # unsqueeze(-2). Confirm against the design before changing, since
            # it alters the behaviour of trained weights.
            attn = attn * head_decay.unsqueeze(-1)

        attn = self.dropout(attn)
        # The float32 mask promotes half/bfloat16 scores (and hence attn) to
        # float32. Cast back so both einsum operands share V's dtype; this is a
        # no-op for float32 models and under autocast.
        out = torch.einsum("bhij,bhjd->bhid", attn.to(V.dtype), V)
        out = rearrange(out, "b h s d -> b s (h d)")
        return self.out_proj(out)

    def __repr__(self) -> str:
        """Compact summary: head count, model width and total parameter count."""
        n_params = sum(param.numel() for param in self.parameters())
        return f"MeshAttention(heads={self.n_heads}, d={self.d_model}, params={n_params:,})"