vigneshwar234 commited on
Commit
d9f2f9f
·
verified ·
1 Parent(s): fddc74a

Add TMT source code for direct installation from HuggingFace

Browse files
tmt/model/attention.py CHANGED
@@ -11,7 +11,7 @@ Formula: attn = softmax(QK^T / sqrt(d)) * sigmoid(W_decay * temporal_distance)
11
  from __future__ import annotations
12
 
13
  import math
14
- from typing import Optional, Tuple
15
 
16
  import torch
17
  import torch.nn as nn
@@ -89,10 +89,10 @@ class MeshAttention(nn.Module):
89
  dst_local = dst_global % S
90
  mask[b_idx, src_local, dst_local] = edge_weight.float()
91
 
92
- # Also allow causal self (diagonal) so every token has at least itself
93
- diag_mask = torch.zeros(S, S, device=x.device)
94
- diag_mask.fill_diagonal_(0.0)
95
- mask = mask + diag_mask.unsqueeze(0)
96
 
97
  # Apply graph mask
98
  scores = scores + mask.unsqueeze(1) # broadcast over heads
 
11
  from __future__ import annotations
12
 
13
  import math
14
+ from typing import Optional
15
 
16
  import torch
17
  import torch.nn as nn
 
89
  dst_local = dst_global % S
90
  mask[b_idx, src_local, dst_local] = edge_weight.float()
91
 
92
+ # Allow self-attention on the diagonal so every token attends to itself.
93
+ # Direct index-assignment instead of add so -inf diagonal becomes 0.
94
+ diag_idx = torch.arange(S, device=x.device)
95
+ mask[:, diag_idx, diag_idx] = 0.0
96
 
97
  # Apply graph mask
98
  scores = scores + mask.unsqueeze(1) # broadcast over heads
tmt/model/config.py CHANGED
@@ -5,7 +5,7 @@ Novel vs standard: a single config surface that governs dynamic graph topology
5
  (graph_k), per-token adaptive depth (exit_threshold), temporal decay rate, and
6
  the dual-stream FFN — none of which exist in vanilla transformer configs.
7
  """
8
- from dataclasses import dataclass, field
9
 
10
 
11
  @dataclass
 
5
  (graph_k), per-token adaptive depth (exit_threshold), temporal decay rate, and
6
  the dual-stream FFN — none of which exist in vanilla transformer configs.
7
  """
8
+ from dataclasses import dataclass
9
 
10
 
11
  @dataclass
tmt/model/ffn.py CHANGED
@@ -11,7 +11,6 @@ from __future__ import annotations
11
 
12
  import torch
13
  import torch.nn as nn
14
- from einops import rearrange
15
  from torch import Tensor
16
 
17
  from .config import TMTConfig
 
11
 
12
  import torch
13
  import torch.nn as nn
 
14
  from torch import Tensor
15
 
16
  from .config import TMTConfig
tmt/model/memory.py CHANGED
@@ -53,7 +53,6 @@ class MemoryAnchorCross(nn.Module):
53
  memory_state: (M, D) updated memory anchors (detached for logging)
54
  """
55
  B, S, D = x.shape
56
- M = self.n_anchors
57
  scale = self.d_head ** -0.5
58
 
59
  # Queries come from tokens, Keys/Values from memory anchors
@@ -75,13 +74,14 @@ class MemoryAnchorCross(nn.Module):
75
  out = rearrange(out, "b h s d -> b s (h d)")
76
  out = self.out_proj(out)
77
 
78
- # EMA update of memory anchors using mean token representation
79
- with torch.no_grad():
80
- token_mean = x.mean(dim=1).mean(dim=0) # (D,) across batch
81
- self.memory.data = (
82
- self.ema_alpha * self.memory.data
83
- + (1 - self.ema_alpha) * token_mean.unsqueeze(0)
84
- )
 
85
 
86
  return out, self.memory.detach()
87
 
 
53
  memory_state: (M, D) updated memory anchors (detached for logging)
54
  """
55
  B, S, D = x.shape
 
56
  scale = self.d_head ** -0.5
57
 
58
  # Queries come from tokens, Keys/Values from memory anchors
 
74
  out = rearrange(out, "b h s d -> b s (h d)")
75
  out = self.out_proj(out)
76
 
77
+ # EMA update only during training eval must be deterministic
78
+ if self.training:
79
+ with torch.no_grad():
80
+ token_mean = x.mean(dim=1).mean(dim=0) # (D,) across batch
81
+ self.memory.data = (
82
+ self.ema_alpha * self.memory.data
83
+ + (1 - self.ema_alpha) * token_mean.unsqueeze(0)
84
+ )
85
 
86
  return out, self.memory.detach()
87
 
tmt/model/mesh.py CHANGED
@@ -36,8 +36,6 @@ def build_mesh(
36
  global node indices (0 … B*S-1).
37
  edge_weight:(E,) cosine similarity of each edge.
38
  """
39
- N = batch_size * seq_len # total nodes
40
-
41
  # Normalise for cosine similarity
42
  x_norm = F.normalize(x, p=2, dim=-1) # (N, D)
43
 
 
36
  global node indices (0 … B*S-1).
37
  edge_weight:(E,) cosine similarity of each edge.
38
  """
 
 
39
  # Normalise for cosine similarity
40
  x_norm = F.normalize(x, p=2, dim=-1) # (N, D)
41
 
tmt/model/model.py CHANGED
@@ -9,7 +9,7 @@ intermediate diagnostic tensors (exit_masks, graph edges, memory state).
9
  """
10
  from __future__ import annotations
11
 
12
- from dataclasses import dataclass, field
13
  from typing import List, Optional, Tuple
14
 
15
  import torch
 
9
  """
10
  from __future__ import annotations
11
 
12
+ from dataclasses import dataclass
13
  from typing import List, Optional, Tuple
14
 
15
  import torch
tmt/training/loss.py CHANGED
@@ -48,7 +48,7 @@ def compute_loss(
48
 
49
  # Exit gate auxiliary: encourage decisiveness
50
  # Loss = -E[|conf - 0.5|] — penalise uncertainty
51
- gate_loss = torch.zeros(1, device=logits.device)
52
  for conf in confidences:
53
  gate_loss = gate_loss + -(conf - 0.5).abs().mean()
54
  gate_loss = gate_loss / max(len(confidences), 1)
 
48
 
49
  # Exit gate auxiliary: encourage decisiveness
50
  # Loss = -E[|conf - 0.5|] — penalise uncertainty
51
+ gate_loss = torch.zeros((), device=logits.device)
52
  for conf in confidences:
53
  gate_loss = gate_loss + -(conf - 0.5).abs().mean()
54
  gate_loss = gate_loss / max(len(confidences), 1)
tmt/training/trainer.py CHANGED
@@ -7,12 +7,11 @@ Logs: train loss, val perplexity, exit rate per layer, and memory anchor norms.
7
  from __future__ import annotations
8
 
9
  import os
10
- from dataclasses import dataclass, field
11
  from typing import Optional
12
 
13
  import torch
14
  import torch.nn as nn
15
- from torch import Tensor
16
  from torch.optim import AdamW
17
  from torch.utils.data import DataLoader
18
 
 
7
  from __future__ import annotations
8
 
9
  import os
10
+ from dataclasses import dataclass
11
  from typing import Optional
12
 
13
  import torch
14
  import torch.nn as nn
 
15
  from torch.optim import AdamW
16
  from torch.utils.data import DataLoader
17