vigneshwar234 commited on
Commit
933490a
·
verified ·
1 Parent(s): 18cff1d

Add source: tmt/model/embedding.py

Browse files
Files changed (1) hide show
  1. tmt/model/embedding.py +107 -0
tmt/model/embedding.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ embedding.py — TokenEmbedding and TemporalPositionEncoder.
3
+
4
+ Novel vs standard: RoPE positional encoding is extended with per-token learned
5
+ decay scalars so that semantically distant tokens are attenuated before they
6
+ reach the attention layer — no recurrence needed.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import math
11
+ from typing import Tuple
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from einops import rearrange
16
+ from torch import Tensor
17
+
18
+ from .config import TMTConfig
19
+
20
+
21
+ class TokenEmbedding(nn.Module):
22
+ """Standard learned token embedding with output projection scale."""
23
+
24
+ def __init__(self, cfg: TMTConfig) -> None:
25
+ super().__init__()
26
+ self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
27
+ self.scale = math.sqrt(cfg.d_model)
28
+ nn.init.normal_(self.embed.weight, std=0.02)
29
+
30
+ def forward(self, token_ids: Tensor) -> Tensor:
31
+ # token_ids: (B, S) → (B, S, D)
32
+ return self.embed(token_ids) * self.scale
33
+
34
+ def __repr__(self) -> str:
35
+ p = sum(p.numel() for p in self.parameters())
36
+ return f"TokenEmbedding(params={p:,})"
37
+
38
+
39
+ class TemporalPositionEncoder(nn.Module):
40
+ """
41
+ RoPE base + learned temporal decay scalars per position.
42
+
43
+ Decay scalar: sigmoid(w_decay · t) where t is the absolute position index
44
+ normalised to [0, 1] over max_seq_len. The scalar multiplies the embedding
45
+ before it reaches MeshAttention so semantically distant tokens fade.
46
+ """
47
+
48
+ def __init__(self, cfg: TMTConfig) -> None:
49
+ super().__init__()
50
+ self.d_model = cfg.d_model
51
+ self.max_seq_len = cfg.max_seq_len
52
+ self.decay_rate = cfg.decay_rate
53
+
54
+ # Learned decay weights — one per position dimension pair
55
+ self.w_decay = nn.Parameter(
56
+ torch.full((cfg.d_model,), cfg.decay_rate)
57
+ )
58
+
59
+ # RoPE cos/sin cache (not a parameter — regenerated on device change)
60
+ self._build_rope_cache(cfg.max_seq_len, cfg.d_model)
61
+
62
+ def _build_rope_cache(self, max_len: int, d: int) -> None:
63
+ half = d // 2
64
+ theta = 1.0 / (10000 ** (torch.arange(0, half, dtype=torch.float32) / half))
65
+ pos = torch.arange(max_len, dtype=torch.float32)
66
+ freqs = torch.outer(pos, theta) # (max_len, half)
67
+ emb = torch.cat([freqs, freqs], dim=-1) # (max_len, d)
68
+ self.register_buffer("rope_cos", emb.cos(), persistent=False)
69
+ self.register_buffer("rope_sin", emb.sin(), persistent=False)
70
+
71
+ @staticmethod
72
+ def _rotate_half(x: Tensor) -> Tensor:
73
+ half = x.shape[-1] // 2
74
+ x1, x2 = x[..., :half], x[..., half:]
75
+ return torch.cat([-x2, x1], dim=-1)
76
+
77
+ def apply_rope(self, x: Tensor, seq_len: int) -> Tensor:
78
+ cos = self.rope_cos[:seq_len].unsqueeze(0) # (1, S, D)
79
+ sin = self.rope_sin[:seq_len].unsqueeze(0)
80
+ return x * cos + self._rotate_half(x) * sin
81
+
82
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
83
+ """
84
+ Args:
85
+ x: (B, S, D) token embeddings
86
+ Returns:
87
+ encoded: (B, S, D) with RoPE applied
88
+ decay_scalars: (B, S, D) sigmoid decay weights per token-dim
89
+ """
90
+ B, S, D = x.shape
91
+
92
+ # RoPE
93
+ x = self.apply_rope(x, S)
94
+
95
+ # Temporal decay: t ∈ [0, 1] normalised position
96
+ t = torch.arange(S, device=x.device, dtype=x.dtype) / max(S - 1, 1)
97
+ # w_decay broadcast: (S, D) → decay per token dimension
98
+ decay_scalars = torch.sigmoid(
99
+ -rearrange(t, "s -> s 1") * rearrange(self.w_decay, "d -> 1 d")
100
+ ) # (S, D)
101
+ decay_scalars = decay_scalars.unsqueeze(0).expand(B, -1, -1) # (B, S, D)
102
+
103
+ return x * decay_scalars, decay_scalars
104
+
105
+ def __repr__(self) -> str:
106
+ p = sum(p.numel() for p in self.parameters())
107
+ return f"TemporalPositionEncoder(d={self.d_model}, params={p:,})"