vigneshwar234 committed on
Commit
244a709
·
verified ·
1 Parent(s): 86d7602

Add source: tmt/model/model.py

Browse files
Files changed (1) hide show
  1. tmt/model/model.py +119 -0
tmt/model/model.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model.py — TMTModel: full TemporalMesh Transformer.
3
+
4
+ Assembles: TokenEmbedding → TemporalPositionEncoder → MeshBuilder →
5
+ TMTLayer × n_layers → OutputProjection.
6
+
7
+ Every forward pass returns a TMTOutput dataclass containing logits plus all
8
+ intermediate diagnostic tensors (exit_masks, graph edges, memory state).
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass, field
13
+ from typing import List, Optional, Tuple
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch import Tensor
18
+
19
+ from .config import TMTConfig
20
+ from .embedding import TemporalPositionEncoder, TokenEmbedding
21
+ from .layers import TMTLayer
22
+ from .mesh import MeshBuilder
23
+
24
+
25
@dataclass
class TMTOutput:
    """Container for everything a single ``TMTModel.forward`` call returns.

    Holds the vocabulary logits together with the intermediate diagnostic
    tensors collected during the pass.
    """

    # Vocabulary logits, (B, S, V).
    logits: Tensor
    # One (B, S) bool mask per layer.
    exit_masks: List[Tensor]
    # One (B, S) float tensor per layer.
    confidences: List[Tensor]
    # (edge_index, edge_weight) pair describing the mesh graph.
    graph_edges: Tuple[Tensor, Tensor]
    # (M, D) final memory anchors. The model initialises this to None and only
    # overwrites it inside the layer loop, so it stays None when no layer ran.
    memory_state: Optional[Tensor]
    # Temporal decay weights, (B, S, D).
    decay_scalars: Tensor
33
+
34
+
35
class TMTModel(nn.Module):
    """Full TemporalMesh Transformer.

    Pipeline, in order: ``TokenEmbedding`` -> ``TemporalPositionEncoder`` ->
    ``MeshBuilder`` -> ``cfg.n_layers`` x ``TMTLayer`` -> ``LayerNorm`` ->
    vocabulary projection. The output projection shares its weight matrix
    with the token embedding.
    """

    def __init__(self, cfg: TMTConfig) -> None:
        """Build all sub-modules from ``cfg`` and initialise linear weights.

        Args:
            cfg: Model configuration. This constructor reads ``graph_k``,
                ``n_layers``, ``d_model``, ``layer_norm_eps`` and
                ``vocab_size`` directly; the embedding, position encoder and
                layers receive the whole object.
        """
        super().__init__()
        self.cfg = cfg

        self.embedding = TokenEmbedding(cfg)
        self.pos_encoder = TemporalPositionEncoder(cfg)
        self.mesh_builder = MeshBuilder(cfg.graph_k)
        self.layers = nn.ModuleList(
            [TMTLayer(cfg, i) for i in range(cfg.n_layers)]
        )
        self.norm = nn.LayerNorm(cfg.d_model, eps=cfg.layer_norm_eps)
        self.output_proj = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        # Tie output projection weights to embedding for parameter efficiency
        self.output_proj.weight = self.embedding.embed.weight

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialise every ``nn.Linear``: N(0, 0.02) weights, zero biases.

        NOTE: ``output_proj`` is an ``nn.Linear`` whose weight is the tied
        embedding matrix, so this loop also re-initialises the embedding
        weights in place.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def _build_mesh(self, x: Tensor, B: int, S: int) -> Tuple[Tensor, Tensor]:
        """Flatten ``x`` to ``(B * S, d_model)`` and run the mesh builder on it.

        Returns:
            The ``(edge_index, edge_weight)`` pair produced by
            ``self.mesh_builder``.
        """
        x_flat = x.reshape(B * S, self.cfg.d_model)
        edge_index, edge_weight = self.mesh_builder(x_flat, B, S)
        return edge_index, edge_weight

    def forward(self, input_ids: Tensor) -> TMTOutput:
        """Run the full model on a batch of token ids.

        Args:
            input_ids: (B, S) integer token ids.

        Returns:
            TMTOutput with logits and all diagnostic fields.

        Raises:
            ValueError: If ``input_ids`` is not 2-dimensional.
        """
        # Fail early with a readable message instead of a tuple-unpacking
        # error when the caller forgets the batch dimension.
        if input_ids.dim() != 2:
            raise ValueError(
                "input_ids must have shape (B, S); "
                f"got {tuple(input_ids.shape)}"
            )
        B, S = input_ids.shape

        # Phase 1: embed + temporal position encode
        x = self.embedding(input_ids)            # (B, S, D)
        x, decay_scalars = self.pos_encoder(x)   # (B, S, D), (B, S, D)

        # Phase 2: build dynamic mesh graph
        edge_index, edge_weight = self._build_mesh(x, B, S)

        # Phase 3: pass through TMT layers with adaptive depth routing
        exit_mask = torch.zeros(B, S, dtype=torch.bool, device=input_ids.device)
        exit_masks: List[Tensor] = []
        confidences: List[Tensor] = []
        # Stays None if there are no layers to produce a memory state.
        memory_state: Optional[Tensor] = None

        for layer in self.layers:
            x, exit_mask, confidence, memory_state = layer(
                x, edge_index, edge_weight, exit_mask, decay_scalars
            )
            # Snapshot per-layer diagnostics so later layers cannot alias them.
            exit_masks.append(exit_mask.clone())
            confidences.append(confidence.clone())

            # Rebuild graph after each layer using updated representations.
            # This also runs after the last layer, so the graph returned in
            # the output reflects the final hidden states.
            edge_index, edge_weight = self._build_mesh(x, B, S)

        # Phase 4: project to vocabulary
        x = self.norm(x)
        logits = self.output_proj(x)             # (B, S, V)

        return TMTOutput(
            logits=logits,
            exit_masks=exit_masks,
            confidences=confidences,
            graph_edges=(edge_index, edge_weight),
            memory_state=memory_state,
            decay_scalars=decay_scalars,
        )

    def param_count(self) -> int:
        """Return the total number of elements across ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())

    def __repr__(self) -> str:
        """Return a short summary: the config plus parameter count in millions."""
        return (
            f"TMTModel(\n"
            f" cfg={self.cfg},\n"
            f" total_params={self.param_count() / 1e6:.2f}M\n"
            f")"
        )