# 9463e5c — file size: 7,507 bytes
"""
GLADIUS v2.0 — Lattice Clock

Discretized temporal encoding on a multi-scale lattice grid.
Replaces continuous Time2Vec with quantized lattice positions.

Ali's framework:
    "Model during forward pass = timeless"
    "To bring it to our realm we need to compress its energy in the lattice"
    "Each forward pass = one atomic oscillation between lattice lasers"
    Softmax = superposition, argmax = collapse

Each forward pass snaps time onto the lattice (the lattice point at or below it).
Between ticks, the model is genuinely timeless — no temporal leakage.
The tick counter is imposed, not learned. Like a heartbeat.

Usage:
    clock = LatticeClock(config)
    lattice_embed = clock(timestamp)                # (B, hidden_dim)
    hidden = hidden + lattice_embed.unsqueeze(1)    # Broadcast across seq_len
"""
import torch
import torch.nn as nn
import math
class LatticeClock(nn.Module):
    """
    Multi-scale discrete lattice temporal encoding.

    At each of ``num_scales`` scales, a timestamp is mapped onto one of
    ``lattice_size`` lattice positions via ``timestamp / period + phase``,
    wrapped modulo ``lattice_size``. Each scale has a learned period (stored
    in log-space), a learned phase offset (in lattice units) and its own
    learned embedding table. The per-scale embeddings are concatenated to
    exactly ``hidden_dim`` features and passed through a Linear + SiLU fusion.

    Default periods are log-uniformly spaced from 0.125 to 3600 (seconds, if
    timestamps are in seconds). With the default 4 scales that is roughly
    0.125, 3.83, 117.4 and 3600, i.e. sub-second ticks up to hour-long ticks.
    The periods are trainable parameters, so they move during training.

    Training uses soft (interpolated) quantization so gradients reach the
    periods and phases. Inference uses hard floor quantization. Both paths
    share the same wrap-then-floor mapping, so they agree on every lattice
    point. The model observes time in quanta, not continuous flow.

    Config attributes read:
        hidden_dim (required): output embedding size.
        lattice_size (default 256): lattice positions per scale.
        lattice_scales (default 4): number of scales.

    Raises:
        ValueError: if ``lattice_scales`` or ``lattice_size`` is < 1.
    """

    # At or below this temperature, soft_quantize uses a hard floor lookup.
    # anneal_temperature() ends exactly at this value.
    _HARD_TEMPERATURE = 0.01

    def __init__(self, config):
        super().__init__()
        # Lattice parameters
        self.lattice_size = getattr(config, 'lattice_size', 256)
        self.num_scales = getattr(config, 'lattice_scales', 4)
        if self.num_scales < 1:
            raise ValueError(f"lattice_scales must be >= 1, got {self.num_scales}")
        if self.lattice_size < 1:
            raise ValueError(f"lattice_size must be >= 1, got {self.lattice_size}")
        hidden_dim = config.hidden_dim

        # Split hidden_dim across scales. The first `remainder` scales get one
        # extra feature so the concatenation is exactly hidden_dim wide.
        self.dim_per_scale = hidden_dim // self.num_scales
        self.remainder = hidden_dim - self.dim_per_scale * self.num_scales

        # Learned lattice embeddings at each scale
        self.lattice_embeddings = nn.ModuleList([
            nn.Embedding(self.lattice_size,
                         self.dim_per_scale + (1 if i < self.remainder else 0))
            for i in range(self.num_scales)
        ])

        # Learned scale periods, stored as log(period) for stability.
        # Log-uniformly spaced from 0.125 to 3600. With 4 scales this gives
        # approximately 0.125, 3.83, 117.4 and 3600.
        default_periods = torch.linspace(
            math.log(0.125), math.log(3600.0), self.num_scales
        )
        self.scale_periods = nn.Parameter(default_periods)

        # Phase offsets per scale, in lattice units (added after dividing by the period)
        self.phase = nn.Parameter(torch.zeros(self.num_scales))

        # Fusion: project concatenated embeddings back to hidden_dim
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

        # Tick counter — imposed, involuntary, never learned
        self.register_buffer('tick_count', torch.tensor(0, dtype=torch.long))

        # Temperature for soft quantization: starts soft (τ=1.0) and is
        # annealed towards hard (τ=0.01) by anneal_temperature().
        self.register_buffer('temperature', torch.tensor(1.0))

        # Initialize embeddings with small values
        for emb in self.lattice_embeddings:
            nn.init.normal_(emb.weight, mean=0, std=0.01)

    def _wrapped_position(self, timestamp: torch.Tensor, scale_idx: int):
        """
        Map ``timestamp`` onto the lattice of one scale.

        Returns:
            (wrapped, floor_pos): ``wrapped`` is the continuous lattice
            coordinate reduced modulo ``lattice_size``. It keeps the autograd
            graph to the period and phase. ``floor_pos`` is the integer
            lattice index at or below it, in [0, lattice_size).
        """
        period = self.scale_periods[scale_idx].exp()
        phase = self.phase[scale_idx]
        # Float remainder takes the sign of the divisor, so negative positions
        # wrap to the top of the lattice, e.g. -0.5 -> 255.5 for size 256.
        wrapped = (timestamp / period + phase) % self.lattice_size
        # `wrapped` is non-negative, so .long() is a true floor here. The
        # integer modulo catches float rounding up to exactly lattice_size.
        floor_pos = wrapped.long() % self.lattice_size
        return wrapped, floor_pos

    def quantize_time(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """
        Snap continuous time to the lattice point at or below it (hard floor).

        Uses the same wrap-then-floor mapping as ``soft_quantize``, so training
        and inference select the same lattice cell, including for negative
        positions.

        Args:
            timestamp: (batch,) — time values (see ``forward``)
            scale_idx: which scale to quantize at
        Returns:
            lattice_positions: (batch,) — integer positions in [0, lattice_size)
        """
        _, lattice_pos = self._wrapped_position(timestamp, scale_idx)
        return lattice_pos

    def soft_quantize(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """
        Temperature-scaled interpolation between neighbouring lattice embeddings.

        Blends the floor and ceil (wrapped) lattice embeddings with weight
        ``frac * τ``, where ``frac`` is the fractional lattice position and τ
        is the current temperature, clamped to at most 1. At τ=1 this is linear
        interpolation. As τ decreases it moves continuously towards the floor
        embedding. At τ <= 0.01 the floor embedding is returned directly (hard
        mode, no gradient to the period/phase). Gradients reach the period and
        phase through ``frac`` while τ > 0.01.

        Args:
            timestamp: (batch,) — time values (see ``forward``)
            scale_idx: which scale to embed at
        Returns:
            (batch, scale_dim) embedding for this scale
        """
        wrapped, floor_pos = self._wrapped_position(timestamp, scale_idx)
        table = self.lattice_embeddings[scale_idx]

        tau = self.temperature.item()
        if tau <= self._HARD_TEMPERATURE:
            # Hard mode — no interpolation
            return table(floor_pos)

        ceil_pos = (floor_pos + 1) % self.lattice_size
        frac = wrapped - wrapped.floor()
        # Clamp so τ > 1 never extrapolates past the ceil embedding.
        weight = (frac * min(tau, 1.0)).unsqueeze(-1)  # (B, 1)
        return table(floor_pos) * (1 - weight) + table(ceil_pos) * weight

    def forward(self, timestamp: torch.Tensor) -> torch.Tensor:
        """
        Compute the lattice temporal embedding and advance the tick counter.

        Args:
            timestamp: (batch,) — time values. The original notes say these
                are seconds normalized by TimeEngine (not visible here —
                confirm against the caller).
        Returns:
            lattice_embedding: (batch, hidden_dim)
        """
        if self.training:
            # Soft quantization lets gradients reach the periods and phases.
            embeddings = [self.soft_quantize(timestamp, k)
                          for k in range(self.num_scales)]
        else:
            # Hard quantization at inference
            embeddings = [self.lattice_embeddings[k](self.quantize_time(timestamp, k))
                          for k in range(self.num_scales)]

        # Concatenate multi-scale lattice embeddings
        combined = torch.cat(embeddings, dim=-1)  # (batch, hidden_dim)
        out = self.fusion(combined)

        # Involuntary tick — counts every forward call, in train and eval.
        self.tick_count += 1
        return out

    def anneal_temperature(self, step: int, total_steps: int):
        """
        Anneal quantization temperature: soft → hard over training.

        Cosine schedule from τ=1.0 at ``step <= 0`` to τ=0.01 at
        ``step >= total_steps``. At τ=0.01 ``soft_quantize`` switches to hard
        floor lookups. This mirrors the softmax → argmax transition:
        exploration (soft) → commitment (hard).
        """
        # Clamp to [0, 1] so out-of-range steps do not wrap the cosine.
        progress = min(max(step / max(total_steps, 1), 0.0), 1.0)
        t_min = self._HARD_TEMPERATURE
        new_temp = t_min + (1.0 - t_min) * (1 + math.cos(math.pi * progress)) / 2
        self.temperature.fill_(new_temp)

    def get_lattice_state(self) -> dict:
        """Return current lattice state (Python scalars/lists) for monitoring/EEG."""
        return {
            'tick_count': self.tick_count.item(),
            'temperature': self.temperature.item(),
            'scale_periods': self.scale_periods.detach().exp().tolist(),
            'phases': self.phase.detach().tolist(),
        }
|