vigneshwar234 commited on
Commit
4b661a2
·
verified ·
1 Parent(s): d89f4f8

Add source: tmt/model/attention.py

Browse files
Files changed (1) hide show
  1. tmt/model/attention.py +120 -0
tmt/model/attention.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ attention.py — MeshAttention: multi-head attention over graph edges.
3
+
4
+ Novel vs standard: instead of every token attending to every other token
5
+ (O(S²) full attention), MeshAttention restricts attention to graph neighbours.
6
+ Temporal decay is multiplied into the attention weights — not added as bias —
7
+ so semantically close but temporally distant tokens are suppressed.
8
+
9
+ Formula: attn = softmax(QK^T / sqrt(d)) * sigmoid(W_decay * temporal_distance)
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import math
14
+ from typing import Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from einops import rearrange
20
+ from torch import Tensor
21
+
22
+ from .config import TMTConfig
23
+
24
+
25
+ class MeshAttention(nn.Module):
26
+ """
27
+ Multi-head attention constrained to dynamic graph edges with temporal decay.
28
+
29
+ Falls back to a sparse neighbour-masked full attention when torch_geometric
30
+ is unavailable, preserving identical semantics.
31
+ """
32
+
33
+ def __init__(self, cfg: TMTConfig) -> None:
34
+ super().__init__()
35
+ assert cfg.d_model % cfg.n_heads == 0
36
+ self.d_model = cfg.d_model
37
+ self.n_heads = cfg.n_heads
38
+ self.d_head = cfg.d_model // cfg.n_heads
39
+ self.scale = math.sqrt(self.d_head)
40
+
41
+ self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
42
+ self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
43
+ self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
44
+ self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
45
+
46
+ # Learned temporal decay weight (scalar applied per head)
47
+ self.w_decay = nn.Parameter(torch.ones(cfg.n_heads) * cfg.decay_rate)
48
+
49
+ self.dropout = nn.Dropout(cfg.dropout)
50
+
51
+ def forward(
52
+ self,
53
+ x: Tensor,
54
+ edge_index: Tensor,
55
+ edge_weight: Tensor,
56
+ decay_scalars: Optional[Tensor] = None,
57
+ ) -> Tensor:
58
+ """
59
+ Args:
60
+ x: (B, S, D)
61
+ edge_index: (2, E) global node indices
62
+ edge_weight: (E,) cosine similarity weights
63
+ decay_scalars: (B, S, D) per-token temporal decay from encoder
64
+ Returns:
65
+ out: (B, S, D)
66
+ """
67
+ B, S, D = x.shape
68
+
69
+ Q = self.q_proj(x) # (B, S, D)
70
+ K = self.k_proj(x)
71
+ V = self.v_proj(x)
72
+
73
+ # Reshape to multi-head
74
+ Q = rearrange(Q, "b s (h d) -> b h s d", h=self.n_heads)
75
+ K = rearrange(K, "b s (h d) -> b h s d", h=self.n_heads)
76
+ V = rearrange(V, "b s (h d) -> b h s d", h=self.n_heads)
77
+
78
+ # Full attention scores (B, H, S, S)
79
+ scores = torch.einsum("bhid,bhjd->bhij", Q, K) / self.scale
80
+
81
+ # Build sparse neighbour mask from edge_index
82
+ # edge_index is over global indices (B*S); remap to per-batch local
83
+ mask = torch.full((B, S, S), float("-inf"), device=x.device)
84
+ if edge_index.numel() > 0:
85
+ src_global = edge_index[0] # (E,)
86
+ dst_global = edge_index[1] # (E,)
87
+ b_idx = src_global // S
88
+ src_local = src_global % S
89
+ dst_local = dst_global % S
90
+ mask[b_idx, src_local, dst_local] = edge_weight.float()
91
+
92
+ # Also allow causal self (diagonal) so every token has at least itself
93
+ diag_mask = torch.zeros(S, S, device=x.device)
94
+ diag_mask.fill_diagonal_(0.0)
95
+ mask = mask + diag_mask.unsqueeze(0)
96
+
97
+ # Apply graph mask
98
+ scores = scores + mask.unsqueeze(1) # broadcast over heads
99
+
100
+ attn = F.softmax(scores, dim=-1) # (B, H, S, S)
101
+
102
+ # Temporal decay: multiply attention weights by sigmoid decay per token
103
+ if decay_scalars is not None:
104
+ # Average decay across D → (B, S) scalar per token
105
+ token_decay = decay_scalars.mean(dim=-1) # (B, S)
106
+ # Per-head decay scaling: w_decay (H,) * token_decay (B, S)
107
+ head_decay = torch.sigmoid(
108
+ rearrange(self.w_decay, "h -> 1 h 1") *
109
+ rearrange(token_decay, "b s -> b 1 s")
110
+ ) # (B, H, S)
111
+ attn = attn * head_decay.unsqueeze(-1)
112
+
113
+ attn = self.dropout(attn)
114
+ out = torch.einsum("bhij,bhjd->bhid", attn, V)
115
+ out = rearrange(out, "b h s d -> b s (h d)")
116
+ return self.out_proj(out)
117
+
118
+ def __repr__(self) -> str:
119
+ p = sum(p.numel() for p in self.parameters())
120
+ return f"MeshAttention(heads={self.n_heads}, d={self.d_model}, params={p:,})"