vigneshwar234 committed on
Commit
86d7602
·
verified ·
1 Parent(s): 328b377

Add source: tmt/model/mesh.py

Browse files
Files changed (1) hide show
  1. tmt/model/mesh.py +89 -0
tmt/model/mesh.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ mesh.py — MeshBuilder: constructs a dynamic token graph each forward pass.
3
+
4
+ Novel vs standard: unlike Graph Transformers that use fixed pre-defined graphs,
5
+ MeshBuilder recomputes the graph topology at every forward pass using cosine
6
+ similarity of the current token representations. Only the top-k nearest
7
+ neighbours are connected, giving a sparse O(S·k) edge set instead of O(S²).
8
+ """
9
+ from __future__ import annotations
10
+
11
+ from typing import Tuple
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import Tensor
16
+
17
+
18
def build_mesh(
    x: Tensor,
    k: int,
    batch_size: int,
    seq_len: int,
) -> Tuple[Tensor, Tensor]:
    """
    Build a dynamic kNN token graph from token embeddings.

    Every token is linked to its ``min(k, seq_len - 1)`` most cosine-similar
    tokens of the same sample. Self-loops and cross-sample edges are never
    produced.

    Args:
        x:          (B*S, D) flattened token representations
        k:          number of nearest neighbours per token (graph_k)
        batch_size: B
        seq_len:    S

    Returns:
        edge_index: (2, E) COO edge list in torch_geometric format.
                    Edges are within-batch — source and target indices are
                    global node indices (0 … B*S-1).
        edge_weight:(E,) cosine similarity of each edge.

        E equals ``B * S * min(k, S - 1)``. When no edge can exist
        (``B == 0``, ``S <= 1`` or ``k == 0``) both tensors are empty:
        edge_index is (2, 0) int64 and edge_weight is (0,).

    Raises:
        ValueError: if ``k``, ``batch_size`` or ``seq_len`` is negative, or
            if ``x`` is not a 2-D tensor with ``batch_size * seq_len`` rows.
    """
    if k < 0 or batch_size < 0 or seq_len < 0:
        raise ValueError(
            "k, batch_size and seq_len must be non-negative, got "
            f"k={k}, batch_size={batch_size}, seq_len={seq_len}"
        )

    num_nodes = batch_size * seq_len
    if x.dim() != 2 or x.size(0) != num_nodes:
        raise ValueError(
            f"x must have shape (batch_size * seq_len, D) = ({num_nodes}, D), "
            f"got {tuple(x.shape)}"
        )

    # Normalise so that a plain dot product equals cosine similarity.
    x_norm = F.normalize(x, p=2, dim=-1)  # (N, D)

    # A token can have at most S-1 neighbours once itself is excluded.
    actual_k = min(k, seq_len - 1)
    if batch_size == 0 or actual_k <= 0:
        # Nothing to connect. Returning early also avoids torch.cat on an
        # empty list and topk with a negative k, both of which raise.
        edge_index = torch.empty((2, 0), dtype=torch.long, device=x.device)
        edge_weight = x_norm.new_empty((0,))
        return edge_index, edge_weight

    # Local source ids [0,0,…,1,1,…] are identical for every sample, so build
    # them once and only shift by the sample offset inside the loop.
    local_src = (
        torch.arange(seq_len, device=x.device)
        .unsqueeze(1)
        .expand(-1, actual_k)
        .reshape(-1)
    )  # (S*k,)

    # Block-diagonal cosine similarity — only connect tokens within the same
    # sample so information never leaks across batch items. Processing one
    # sample at a time keeps the dense similarity matrix at (S, S).
    sources, targets, weights = [], [], []
    for start in range(0, num_nodes, seq_len):
        x_b = x_norm[start:start + seq_len]  # (S, D)
        sim = x_b @ x_b.T  # (S, S) cosine sim matrix

        # Mask self-similarity with -inf so top-k never picks a token as its
        # own neighbour.
        sim.fill_diagonal_(float("-inf"))

        # Top-k neighbours per token
        topk_vals, topk_idx = sim.topk(actual_k, dim=-1)  # (S, k)

        # Shift local indices to global node ids.
        sources.append(local_src + start)
        targets.append(topk_idx.reshape(-1) + start)
        weights.append(topk_vals.reshape(-1))

    edge_index = torch.stack(
        [torch.cat(sources), torch.cat(targets)], dim=0
    )  # (2, E)
    edge_weight = torch.cat(weights)  # (E,)

    return edge_index, edge_weight
76
+
77
+
78
class MeshBuilder(torch.nn.Module):
    """Thin nn.Module wrapper around build_mesh so it shows up in the model's repr.

    The module registers no parameters or buffers; it only stores the
    neighbour count ``k`` and forwards every call to ``build_mesh``.
    """

    def __init__(self, k: int) -> None:
        """Store ``k``, the number of nearest neighbours requested per token."""
        super().__init__()
        self.k = k

    def forward(self, x: Tensor, batch_size: int, seq_len: int) -> Tuple[Tensor, Tensor]:
        """Return ``build_mesh(x, self.k, batch_size, seq_len)`` unchanged."""
        return build_mesh(x, self.k, batch_size, seq_len)

    def extra_repr(self) -> str:
        """Describe the module inside nn.Module's standard repr.

        Using the ``extra_repr`` hook instead of overriding ``__repr__`` keeps
        the rendered text ``MeshBuilder(k=…, params=0)`` while letting
        nn.Module supply the real class name (so subclasses print correctly).
        """
        return f"k={self.k}, params=0"