Text Generation
PyTorch
Transformers
English
language-model
graph-neural-network
sparse-attention
adaptive-depth
temporal-decay
mesh-attention
efficient-transformer
novel-architecture
causal-lm
research
preprint
mesh-transformer
dynamic-graph
early-exit
per-token-routing
Eval Results (legacy)
Instructions to use vigneshwar234/TemporalMesh-Transformer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use vigneshwar234/TemporalMesh-Transformer with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="vigneshwar234/TemporalMesh-Transformer")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("vigneshwar234/TemporalMesh-Transformer", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use vigneshwar234/TemporalMesh-Transformer with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "vigneshwar234/TemporalMesh-Transformer"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "vigneshwar234/TemporalMesh-Transformer",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/vigneshwar234/TemporalMesh-Transformer
- SGLang
How to use vigneshwar234/TemporalMesh-Transformer with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "vigneshwar234/TemporalMesh-Transformer" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "vigneshwar234/TemporalMesh-Transformer",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "vigneshwar234/TemporalMesh-Transformer" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "vigneshwar234/TemporalMesh-Transformer",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use vigneshwar234/TemporalMesh-Transformer with Docker Model Runner:
docker model run hf.co/vigneshwar234/TemporalMesh-Transformer
Add source: tmt/model/memory.py
Browse files- tmt/model/memory.py +90 -0
tmt/model/memory.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
memory.py — MemoryAnchorCross: persistent cross-attention to global memory nodes.
|
| 3 |
+
|
| 4 |
+
Novel vs standard: vanilla transformers have no persistent state across forward
|
| 5 |
+
passes. MemoryAnchorCross maintains cfg.memory_anchors (e.g. 16) learnable vectors, stored in one nn.Parameter, as
|
| 6 |
+
global "anchor" nodes that every token can attend to. After each forward pass
|
| 7 |
+
the anchors are updated via exponential moving average (EMA) of the current
|
| 8 |
+
token representations, giving the model a form of fast-weight memory without
|
| 9 |
+
recurrence.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from typing import Tuple
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
from torch import Tensor
|
| 20 |
+
|
| 21 |
+
from .config import TMTConfig
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MemoryAnchorCross(nn.Module):
    """Cross-attention from the token stream to persistent memory anchor nodes.

    Queries are projected from the input tokens; keys and values are projected
    from ``self.memory``, an ``nn.Parameter`` of shape ``(M, D)`` where
    ``M = cfg.memory_anchors`` and ``D = cfg.d_model``. In addition to being a
    regular trainable parameter, ``self.memory`` is overwritten at the end of
    every ``forward`` call with an exponential moving average (EMA) of the
    mean token representation of the current input.

    NOTE(review): the EMA update is not gated on ``self.training``, so calling
    the module in eval mode still mutates its state -- confirm this is intended.
    NOTE(review): every anchor row is blended with the same mean vector, so
    repeated calls pull all rows toward a common value.
    """

    def __init__(self, cfg: TMTConfig) -> None:
        """Build the attention projections and the memory parameter.

        Args:
            cfg: configuration object providing ``d_model``, ``n_heads``,
                ``memory_anchors`` and ``dropout``.

        Raises:
            ValueError: if ``n_heads`` is not positive or does not divide
                ``d_model`` evenly.
        """
        super().__init__()

        # Fail fast: an uneven head split would otherwise only surface later,
        # inside the head-splitting rearrange in forward().
        if cfg.n_heads <= 0 or cfg.d_model % cfg.n_heads != 0:
            raise ValueError(
                f"d_model ({cfg.d_model}) must be divisible by a positive "
                f"n_heads ({cfg.n_heads})"
            )

        self.d_model = cfg.d_model
        self.n_heads = cfg.n_heads
        self.d_head = cfg.d_model // cfg.n_heads
        self.n_anchors = cfg.memory_anchors
        self.ema_alpha = 0.9  # EMA decay for memory update

        # Persistent memory parameters -- shape (M, D)
        self.memory = nn.Parameter(
            torch.randn(cfg.memory_anchors, cfg.d_model) * 0.02
        )

        self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Attend from tokens to the memory anchors, then update the anchors.

        Args:
            x: (B, S, D) token representations

        Returns:
            out: (B, S, D) tokens enhanced by memory cross-attention
            memory_state: (M, D) detached copy of the updated memory anchors
                (for logging)
        """
        # Unpacking three values keeps the requirement that x is 3-D.
        batch_size, _, _ = x.shape
        scale = self.d_head ** -0.5

        # Queries come from tokens, Keys/Values from memory anchors
        Q = self.q_proj(x)                                         # (B, S, D)
        mem = self.memory.unsqueeze(0).expand(batch_size, -1, -1)  # (B, M, D)
        K = self.k_proj(mem)                                       # (B, M, D)
        V = self.v_proj(mem)                                       # (B, M, D)

        Q = rearrange(Q, "b s (h d) -> b h s d", h=self.n_heads)
        K = rearrange(K, "b m (h d) -> b h m d", h=self.n_heads)
        V = rearrange(V, "b m (h d) -> b h m d", h=self.n_heads)

        # Attention over memory anchors: (B, H, S, M)
        attn = torch.einsum("bhsd,bhmd->bhsm", Q, K) * scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = torch.einsum("bhsm,bhmd->bhsd", attn, V)  # (B, H, S, D//H)
        out = rearrange(out, "b h s d -> b s (h d)")
        out = self.out_proj(out)

        # EMA update of memory anchors using mean token representation
        with torch.no_grad():
            # The mean of an empty batch/sequence is NaN and would poison the
            # persistent memory for every later call, so skip the update.
            if x.numel() > 0:
                # Cast so a precision mismatch between x and the parameter
                # cannot silently change the parameter's dtype.
                token_mean = x.mean(dim=1).mean(dim=0).to(self.memory.dtype)  # (D,)
                # Rebind .data to a fresh tensor instead of modifying the
                # parameter in place, so the values used by the attention
                # computation above are not overwritten.
                self.memory.data = (
                    self.ema_alpha * self.memory.data
                    + (1 - self.ema_alpha) * token_mean.unsqueeze(0)
                )

        # Clone: a bare detach() shares storage with the live parameter, so a
        # later in-place update of the parameter would change the logged value.
        return out, self.memory.detach().clone()

    def __repr__(self) -> str:
        """Return a one-line summary with anchor count and parameter total."""
        p = sum(p.numel() for p in self.parameters())
        return f"MemoryAnchorCross(anchors={self.n_anchors}, params={p:,})"
|