gladius-v2-kernel / kernel /temporal_lattice.py
amuzetnoM's picture
WYRM kernel source (v27 FINAL)
9463e5c verified
"""
GLADIUS v2.0 — Lattice Clock
Discretized temporal encoding on a multi-scale lattice grid.
Replaces continuous Time2Vec with quantized lattice positions.
Ali's framework:
"Model during forward pass = timeless"
"To bring it to our realm we need to compress its energy in the lattice"
"Each forward pass = one atomic oscillation between lattice lasers"
Softmax = superposition, argmax = collapse
Each forward pass snaps time to the nearest lattice point.
Between ticks, the model is genuinely timeless — no temporal leakage.
The tick counter is imposed, not learned. Like a heartbeat.
Usage:
clock = LatticeClock(config)
lattice_embed = clock(timestamp) # (B, hidden_dim)
hidden = hidden + lattice_embed.unsqueeze(1) # Broadcast across seq_len
"""
import torch
import torch.nn as nn
import math
class LatticeClock(nn.Module):
    """
    Multi-scale discrete lattice temporal encoding.

    Time is quantized onto ``lattice_size`` positions at ``num_scales``
    different scales. Each scale has a learned period (stored in log-space)
    and a learned phase offset, and owns one embedding table with one row
    per lattice position. The per-scale embeddings are concatenated to
    ``hidden_dim`` and passed through a Linear + SiLU fusion layer.

    With the default 4 scales the initial periods are log-spaced between
    125 ms and 3600 s (roughly sub-second, seconds, minutes, hours).

    Config attributes read:
        hidden_dim:     required, output width.
        lattice_size:   optional, positions per scale (default 256).
        lattice_scales: optional, number of scales (default 4).

    Buffers:
        tick_count:  long scalar, incremented once per forward() call.
        temperature: float scalar; see soft_quantize().
    """

    # Temperatures at or below this value select hard (floor) quantization in
    # soft_quantize(). It is also the floor anneal_temperature() decays to, so
    # a fully annealed clock is guaranteed to reach hard mode.
    _MIN_TEMPERATURE = 0.01

    def __init__(self, config):
        """Build the per-scale embedding tables, periods, phases and fusion.

        Raises:
            ValueError: if ``lattice_scales`` or ``lattice_size`` is < 1.
        """
        super().__init__()

        # Lattice parameters
        self.lattice_size = getattr(config, 'lattice_size', 256)
        self.num_scales = getattr(config, 'lattice_scales', 4)
        hidden_dim = config.hidden_dim

        if self.num_scales < 1:
            raise ValueError(
                f"lattice_scales must be >= 1, got {self.num_scales}"
            )
        if self.lattice_size < 1:
            raise ValueError(
                f"lattice_size must be >= 1, got {self.lattice_size}"
            )

        # Embedding dimension per scale; the first `remainder` scales get one
        # extra dimension so the widths always sum to exactly hidden_dim.
        self.dim_per_scale = hidden_dim // self.num_scales
        self.remainder = hidden_dim - self.dim_per_scale * self.num_scales

        # Learned lattice embeddings at each scale
        self.lattice_embeddings = nn.ModuleList([
            nn.Embedding(
                self.lattice_size,
                self.dim_per_scale + (1 if i < self.remainder else 0),
            )
            for i in range(self.num_scales)
        ])

        # Learned scale periods, kept in log-space for stability.
        # Evenly spaced in log-space from 125 ms to 3600 s; for the default
        # 4 scales this is approximately 0.125 s, 3.8 s, 117 s and 3600 s.
        default_periods = torch.linspace(
            math.log(0.125), math.log(3600.0), self.num_scales
        )
        self.scale_periods = nn.Parameter(default_periods)

        # Phase offsets per scale (in lattice units)
        self.phase = nn.Parameter(torch.zeros(self.num_scales))

        # Fusion: project concatenated embeddings back to hidden_dim
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

        # Tick counter: a buffer, not a parameter, so it is never trained.
        self.register_buffer('tick_count', torch.tensor(0, dtype=torch.long))

        # Quantization temperature; starts soft (1.0) and can be annealed
        # towards hard quantization via anneal_temperature().
        self.register_buffer('temperature', torch.tensor(1.0))

        # Initialize embeddings with small values
        for emb in self.lattice_embeddings:
            nn.init.normal_(emb.weight, mean=0, std=0.01)

    def _continuous_position(self, timestamp: torch.Tensor,
                             scale_idx: int) -> torch.Tensor:
        """Return the un-wrapped, real-valued lattice coordinate for a scale."""
        period = self.scale_periods[scale_idx].exp()
        return timestamp / period + self.phase[scale_idx]

    def quantize_time(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """
        Snap continuous time to the lattice cell that contains it.

        Args:
            timestamp: tensor of time values (forward() passes shape (batch,)).
            scale_idx: which scale to quantize at.

        Returns:
            Long tensor of the same shape with positions in [0, lattice_size).
        """
        position = self._continuous_position(timestamp, scale_idx)
        # Use a true floor: a bare .long() truncates toward zero, which puts
        # negative coordinates (negative time or negative learned phase) in a
        # different cell than soft_quantize() does.
        return torch.floor(position).long() % self.lattice_size

    def soft_quantize(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """
        Embed time by linearly interpolating between neighbouring cells.

        The blend weight is the fractional part of the lattice coordinate, so
        gradients reach the embeddings, the period and the phase.

        The temperature buffer acts as a switch, not as a continuous scale:
        above ``_MIN_TEMPERATURE`` the result is a plain linear blend of the
        floor and ceil embeddings; at or below it, the floor embedding is
        returned unchanged (identical to the inference path).

        Args:
            timestamp: tensor of time values (forward() passes shape (batch,)).
            scale_idx: which scale to embed at.

        Returns:
            Tensor of shape ``timestamp.shape + (scale_dim,)``.
        """
        embedding = self.lattice_embeddings[scale_idx]

        if self.temperature.item() <= self._MIN_TEMPERATURE:
            # Hard mode: no interpolation, same cell choice as quantize_time.
            return embedding(self.quantize_time(timestamp, scale_idx))

        position = self._continuous_position(timestamp, scale_idx)
        lower = torch.floor(position)
        frac = position - lower  # in [0, 1)

        # Wrap in integer space so the last cell interpolates towards cell 0.
        floor_pos = lower.long() % self.lattice_size
        ceil_pos = (floor_pos + 1) % self.lattice_size

        weight = frac.unsqueeze(-1)  # broadcast over the embedding dimension
        return embedding(floor_pos) * (1 - weight) + embedding(ceil_pos) * weight

    def forward(self, timestamp: torch.Tensor) -> torch.Tensor:
        """
        Compute the fused multi-scale lattice embedding.

        Uses soft_quantize() in training mode and hard quantization in eval
        mode. Increments ``tick_count`` by one per call as a side effect.

        Args:
            timestamp: (batch,) time values.

        Returns:
            (batch, hidden_dim) tensor.
        """
        embeddings = []
        for scale_idx in range(self.num_scales):
            if self.training:
                # Soft path keeps gradients flowing to period and phase.
                emb = self.soft_quantize(timestamp, scale_idx)
            else:
                # Hard quantization at inference
                pos = self.quantize_time(timestamp, scale_idx)
                emb = self.lattice_embeddings[scale_idx](pos)
            embeddings.append(emb)

        # Per-scale widths sum to hidden_dim, matching the fusion input.
        combined = torch.cat(embeddings, dim=-1)
        out = self.fusion(combined)

        # Involuntary tick
        self.tick_count += 1
        return out

    def anneal_temperature(self, step: int, total_steps: int):
        """
        Cosine-anneal the temperature buffer from 1.0 down to 0.01.

        Progress is ``step / total_steps`` clamped to [0, 1], so steps before
        0 keep the temperature at 1.0 and steps past ``total_steps`` keep it
        at the floor. ``total_steps <= 0`` is treated as 1.

        Args:
            step: current training step.
            total_steps: step at which the floor temperature is reached.
        """
        progress = min(max(step / max(total_steps, 1), 0.0), 1.0)
        span = 1.0 - self._MIN_TEMPERATURE
        new_temp = self._MIN_TEMPERATURE + span * (1 + math.cos(math.pi * progress)) / 2
        self.temperature.fill_(new_temp)

    def get_lattice_state(self) -> dict:
        """Return tick count, temperature, periods (seconds) and phases as plain Python values."""
        return {
            'tick_count': self.tick_count.item(),
            'temperature': self.temperature.item(),
            'scale_periods': self.scale_periods.detach().exp().tolist(),
            'phases': self.phase.detach().tolist(),
        }