Transformers
PyTorch
English
language-model
graph-attention
adaptive-depth
temporal-decay
efficient-llm
Eval Results (legacy)
Instructions to use vigneshwar234/TemporalMesh-Transformer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use vigneshwar234/TemporalMesh-Transformer with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("vigneshwar234/TemporalMesh-Transformer", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
Add TMT source code for direct installation from HuggingFace
Browse files
- tmt/model/attention.py +5 -5
- tmt/model/config.py +1 -1
- tmt/model/ffn.py +0 -1
- tmt/model/memory.py +8 -8
- tmt/model/mesh.py +0 -2
- tmt/model/model.py +1 -1
- tmt/training/loss.py +1 -1
- tmt/training/trainer.py +1 -2
tmt/model/attention.py
CHANGED
|
@@ -11,7 +11,7 @@ Formula: attn = softmax(QK^T / sqrt(d)) * sigmoid(W_decay * temporal_distance)
|
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
import math
|
| 14 |
-
from typing import Optional
|
| 15 |
|
| 16 |
import torch
|
| 17 |
import torch.nn as nn
|
|
@@ -89,10 +89,10 @@ class MeshAttention(nn.Module):
|
|
| 89 |
dst_local = dst_global % S
|
| 90 |
mask[b_idx, src_local, dst_local] = edge_weight.float()
|
| 91 |
|
| 92 |
-
#
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
mask
|
| 96 |
|
| 97 |
# Apply graph mask
|
| 98 |
scores = scores + mask.unsqueeze(1) # broadcast over heads
|
|
|
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
import math
|
| 14 |
+
from typing import Optional
|
| 15 |
|
| 16 |
import torch
|
| 17 |
import torch.nn as nn
|
|
|
|
| 89 |
dst_local = dst_global % S
|
| 90 |
mask[b_idx, src_local, dst_local] = edge_weight.float()
|
| 91 |
|
| 92 |
+
# Allow self-attention on the diagonal so every token attends to itself.
|
| 93 |
+
# Direct index-assignment instead of add so -inf diagonal becomes 0.
|
| 94 |
+
diag_idx = torch.arange(S, device=x.device)
|
| 95 |
+
mask[:, diag_idx, diag_idx] = 0.0
|
| 96 |
|
| 97 |
# Apply graph mask
|
| 98 |
scores = scores + mask.unsqueeze(1) # broadcast over heads
|
tmt/model/config.py
CHANGED
|
@@ -5,7 +5,7 @@ Novel vs standard: a single config surface that governs dynamic graph topology
|
|
| 5 |
(graph_k), per-token adaptive depth (exit_threshold), temporal decay rate, and
|
| 6 |
the dual-stream FFN — none of which exist in vanilla transformer configs.
|
| 7 |
"""
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
|
| 10 |
|
| 11 |
@dataclass
|
|
|
|
| 5 |
(graph_k), per-token adaptive depth (exit_threshold), temporal decay rate, and
|
| 6 |
the dual-stream FFN — none of which exist in vanilla transformer configs.
|
| 7 |
"""
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
|
| 10 |
|
| 11 |
@dataclass
|
tmt/model/ffn.py
CHANGED
|
@@ -11,7 +11,6 @@ from __future__ import annotations
|
|
| 11 |
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
| 14 |
-
from einops import rearrange
|
| 15 |
from torch import Tensor
|
| 16 |
|
| 17 |
from .config import TMTConfig
|
|
|
|
| 11 |
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
|
|
|
| 14 |
from torch import Tensor
|
| 15 |
|
| 16 |
from .config import TMTConfig
|
tmt/model/memory.py
CHANGED
|
@@ -53,7 +53,6 @@ class MemoryAnchorCross(nn.Module):
|
|
| 53 |
memory_state: (M, D) updated memory anchors (detached for logging)
|
| 54 |
"""
|
| 55 |
B, S, D = x.shape
|
| 56 |
-
M = self.n_anchors
|
| 57 |
scale = self.d_head ** -0.5
|
| 58 |
|
| 59 |
# Queries come from tokens, Keys/Values from memory anchors
|
|
@@ -75,13 +74,14 @@ class MemoryAnchorCross(nn.Module):
|
|
| 75 |
out = rearrange(out, "b h s d -> b s (h d)")
|
| 76 |
out = self.out_proj(out)
|
| 77 |
|
| 78 |
-
# EMA update
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
self.
|
| 83 |
-
|
| 84 |
-
|
|
|
|
| 85 |
|
| 86 |
return out, self.memory.detach()
|
| 87 |
|
|
|
|
| 53 |
memory_state: (M, D) updated memory anchors (detached for logging)
|
| 54 |
"""
|
| 55 |
B, S, D = x.shape
|
|
|
|
| 56 |
scale = self.d_head ** -0.5
|
| 57 |
|
| 58 |
# Queries come from tokens, Keys/Values from memory anchors
|
|
|
|
| 74 |
out = rearrange(out, "b h s d -> b s (h d)")
|
| 75 |
out = self.out_proj(out)
|
| 76 |
|
| 77 |
+
# EMA update only during training — eval must be deterministic
|
| 78 |
+
if self.training:
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
token_mean = x.mean(dim=1).mean(dim=0) # (D,) across batch
|
| 81 |
+
self.memory.data = (
|
| 82 |
+
self.ema_alpha * self.memory.data
|
| 83 |
+
+ (1 - self.ema_alpha) * token_mean.unsqueeze(0)
|
| 84 |
+
)
|
| 85 |
|
| 86 |
return out, self.memory.detach()
|
| 87 |
|
tmt/model/mesh.py
CHANGED
|
@@ -36,8 +36,6 @@ def build_mesh(
|
|
| 36 |
global node indices (0 … B*S-1).
|
| 37 |
edge_weight:(E,) cosine similarity of each edge.
|
| 38 |
"""
|
| 39 |
-
N = batch_size * seq_len # total nodes
|
| 40 |
-
|
| 41 |
# Normalise for cosine similarity
|
| 42 |
x_norm = F.normalize(x, p=2, dim=-1) # (N, D)
|
| 43 |
|
|
|
|
| 36 |
global node indices (0 … B*S-1).
|
| 37 |
edge_weight:(E,) cosine similarity of each edge.
|
| 38 |
"""
|
|
|
|
|
|
|
| 39 |
# Normalise for cosine similarity
|
| 40 |
x_norm = F.normalize(x, p=2, dim=-1) # (N, D)
|
| 41 |
|
tmt/model/model.py
CHANGED
|
@@ -9,7 +9,7 @@ intermediate diagnostic tensors (exit_masks, graph edges, memory state).
|
|
| 9 |
"""
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
from typing import List, Optional, Tuple
|
| 14 |
|
| 15 |
import torch
|
|
|
|
| 9 |
"""
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
from typing import List, Optional, Tuple
|
| 14 |
|
| 15 |
import torch
|
tmt/training/loss.py
CHANGED
|
@@ -48,7 +48,7 @@ def compute_loss(
|
|
| 48 |
|
| 49 |
# Exit gate auxiliary: encourage decisiveness
|
| 50 |
# Loss = -E[|conf - 0.5|] — penalise uncertainty
|
| 51 |
-
gate_loss = torch.zeros(
|
| 52 |
for conf in confidences:
|
| 53 |
gate_loss = gate_loss + -(conf - 0.5).abs().mean()
|
| 54 |
gate_loss = gate_loss / max(len(confidences), 1)
|
|
|
|
| 48 |
|
| 49 |
# Exit gate auxiliary: encourage decisiveness
|
| 50 |
# Loss = -E[|conf - 0.5|] — penalise uncertainty
|
| 51 |
+
gate_loss = torch.zeros((), device=logits.device)
|
| 52 |
for conf in confidences:
|
| 53 |
gate_loss = gate_loss + -(conf - 0.5).abs().mean()
|
| 54 |
gate_loss = gate_loss / max(len(confidences), 1)
|
tmt/training/trainer.py
CHANGED
|
@@ -7,12 +7,11 @@ Logs: train loss, val perplexity, exit rate per layer, and memory anchor norms.
|
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
import os
|
| 10 |
-
from dataclasses import dataclass
|
| 11 |
from typing import Optional
|
| 12 |
|
| 13 |
import torch
|
| 14 |
import torch.nn as nn
|
| 15 |
-
from torch import Tensor
|
| 16 |
from torch.optim import AdamW
|
| 17 |
from torch.utils.data import DataLoader
|
| 18 |
|
|
|
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
import os
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
from typing import Optional
|
| 12 |
|
| 13 |
import torch
|
| 14 |
import torch.nn as nn
|
|
|
|
| 15 |
from torch.optim import AdamW
|
| 16 |
from torch.utils.data import DataLoader
|
| 17 |
|