vigneshwar234 committed on
Commit
9d0034a
·
verified ·
1 Parent(s): c93f22b

Add source: tmt/model/config.py

Browse files
Files changed (1) hide show
  1. tmt/model/config.py +58 -0
tmt/model/config.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
TMTConfig — central configuration for the TemporalMesh Transformer.

Novel vs standard: a single config surface that governs dynamic graph topology
(graph_k), per-token adaptive depth (exit_threshold), temporal decay rate, and
the dual-stream FFN — none of which exist in vanilla transformer configs.
"""
from dataclasses import dataclass, field


@dataclass
class TMTConfig:
    """Hyperparameters for the TemporalMesh Transformer.

    Every field has a default, so ``TMTConfig()`` yields a complete
    configuration and individual values can be overridden by keyword.
    No validation is performed on the values.

    ``__repr__`` is defined explicitly, so the ``dataclass`` decorator keeps
    it instead of generating one. It prints a subset of the fields plus a
    rough parameter-count estimate rather than every field.
    """

    # Vocabulary & sequence
    vocab_size: int = 32000
    max_seq_len: int = 1024

    # Core dims
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 12

    # Innovation 1 — Mesh Attention
    graph_k: int = 8  # each token connects to k nearest neighbours by cosine sim

    # Innovation 2 — Temporal decay
    decay_rate: float = 0.1  # base for learned temporal decay scalars

    # Innovation 3 — Adaptive depth routing
    exit_threshold: float = 0.85  # confidence above which a token exits early

    # Dual-stream FFN
    dual_stream: bool = True
    # Each stream is d_model // 2 for the default d_model. NOTE: this is a
    # fixed default; it is not recomputed when d_model is overridden.
    ffn_stream_dim: int = 256

    # Memory anchors
    memory_anchors: int = 16  # number of persistent KV memory parameter vectors

    # Training
    dropout: float = 0.1
    layer_norm_eps: float = 1e-5

    def _estimate_params(self) -> int:
        """Return the rough parameter count shown in ``repr``.

        The estimate is the embedding table plus, per layer, four
        ``d_model x d_model`` attention projections, two
        ``d_model x ffn_stream_dim`` FFN matrices and one extra ``d_model``
        term (labelled "exit gate + memory" by the original author).
        It does not depend on ``dual_stream``, ``memory_anchors``,
        ``max_seq_len`` or ``n_heads``.
        """
        embedding = self.vocab_size * self.d_model
        per_layer = (
            4 * self.d_model * self.d_model  # attention projections
            + 2 * self.d_model * self.ffn_stream_dim  # dual stream FFN
            + self.d_model  # exit gate + memory
        )
        return embedding + self.n_layers * per_layer

    def __repr__(self) -> str:
        """Return a compact one-line summary with an estimated size in millions."""
        total_params_est = self._estimate_params()
        return (
            f"TMTConfig("
            f"vocab={self.vocab_size}, d={self.d_model}, "
            f"heads={self.n_heads}, layers={self.n_layers}, "
            f"k={self.graph_k}, decay={self.decay_rate}, "
            f"exit_thr={self.exit_threshold}, "
            f"~params={total_params_est / 1e6:.1f}M)"
        )