| """ |
| GLADIUS v2.0 — Lattice Clock |
| |
| Discretized temporal encoding on a multi-scale lattice grid. |
| Replaces continuous Time2Vec with quantized lattice positions. |
| |
| Ali's framework: |
| "Model during forward pass = timeless" |
| "To bring it to our realm we need to compress its energy in the lattice" |
| "Each forward pass = one atomic oscillation between lattice lasers" |
| Softmax = superposition, argmax = collapse |
| |
| Each forward pass snaps time to the nearest lattice point. |
| Between ticks, the model is genuinely timeless — no temporal leakage. |
| The tick counter is imposed, not learned. Like a heartbeat. |
| |
| Usage: |
| clock = LatticeClock(config) |
| lattice_embed = clock(timestamp) # (B, hidden_dim) |
| hidden = hidden + lattice_embed.unsqueeze(1) # Broadcast across seq_len |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import math |
|
|
|
|
class LatticeClock(nn.Module):
    """
    Multi-scale discrete lattice temporal encoding.

    Time is quantized onto ``lattice_size`` positions at ``num_scales``
    different scales. Each scale has a learnable log-period and phase; with
    the default four scales the periods are initialised log-uniformly
    between 0.125 s and 3600 s:
        Scale 0: sub-second (frame-level, ~125 ms ticks)
        Scale 1: seconds    (event-level)
        Scale 2: minutes    (context-level)
        Scale 3: hours      (session-level)

    Each lattice position has a learned embedding. The per-scale embeddings
    are concatenated to ``hidden_dim`` features and passed through a
    Linear + SiLU fusion layer.

    Config attributes read:
        hidden_dim:     required, output feature size.
        lattice_size:   optional, positions per scale (default 256).
        lattice_scales: optional, number of scales (default 4).
    """

    # Lowest temperature produced by ``anneal_temperature``; at or below this
    # value ``soft_quantize`` stops interpolating and does a hard lookup.
    HARD_QUANTIZE_TEMPERATURE = 0.01

    def __init__(self, config):
        super().__init__()

        self.lattice_size = getattr(config, 'lattice_size', 256)
        self.num_scales = getattr(config, 'lattice_scales', 4)
        hidden_dim = config.hidden_dim

        # Split hidden_dim as evenly as possible across the scales; the first
        # ``remainder`` scales get one extra feature so the widths sum to
        # exactly hidden_dim.
        self.dim_per_scale = hidden_dim // self.num_scales
        self.remainder = hidden_dim - self.dim_per_scale * self.num_scales

        # One embedding table per scale.
        self.lattice_embeddings = nn.ModuleList([
            nn.Embedding(
                self.lattice_size,
                self.dim_per_scale + (1 if i < self.remainder else 0),
            )
            for i in range(self.num_scales)
        ])

        # Periods are stored in log space (so exp() keeps them positive) and
        # initialised log-uniformly from 0.125 s to 3600 s.
        default_periods = torch.linspace(
            math.log(0.125), math.log(3600.0), self.num_scales
        )
        self.scale_periods = nn.Parameter(default_periods)

        # Learnable per-scale offset, measured in lattice cells.
        self.phase = nn.Parameter(torch.zeros(self.num_scales))

        # Mixes the concatenated per-scale embeddings.
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

        # Number of forward passes so far; a buffer so it is saved in the
        # state dict but never trained.
        self.register_buffer('tick_count', torch.tensor(0, dtype=torch.long))

        # Quantization temperature, driven by ``anneal_temperature``.
        self.register_buffer('temperature', torch.tensor(1.0))

        # Small initial embeddings.
        for emb in self.lattice_embeddings:
            nn.init.normal_(emb.weight, mean=0, std=0.01)

    def _continuous_position(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """Return the un-wrapped, real-valued lattice coordinate at one scale."""
        period = self.scale_periods[scale_idx].exp()
        return timestamp / period + self.phase[scale_idx]

    def _hard_quantization_active(self) -> bool:
        """Return True once the temperature has reached the annealing floor."""
        # Round the threshold to the buffer's dtype so the comparison does not
        # depend on how 0.01 happens to round in float32/float64/half.
        threshold = torch.tensor(
            self.HARD_QUANTIZE_TEMPERATURE, dtype=self.temperature.dtype
        ).item()
        return self.temperature.item() <= threshold

    def quantize_time(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """
        Snap continuous time down to the lattice cell that contains it.

        Args:
            timestamp: (batch,) — normalized time value
            scale_idx: which scale to quantize at
        Returns:
            lattice_positions: (batch,) — integer positions in [0, lattice_size)
        """
        continuous_pos = self._continuous_position(timestamp, scale_idx)

        # Floor (not truncate toward zero) so negative positions land in the
        # same cell that soft_quantize uses as its lower neighbour.
        return torch.floor(continuous_pos).long() % self.lattice_size

    def soft_quantize(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """
        Soft quantization using distance-weighted interpolation.
        Allows gradients to flow to the periods and phases during training.

        The result is a linear blend of the embeddings of the two lattice
        cells surrounding the continuous position (wrapping at the lattice
        edge). The blend itself does not depend on the temperature; the
        temperature only decides when to stop blending: once it reaches
        ``HARD_QUANTIZE_TEMPERATURE`` the lower cell's embedding is returned
        unchanged, which matches ``quantize_time``.

        Args:
            timestamp: (batch,) — normalized time value
            scale_idx: which scale to quantize at
        Returns:
            (batch, dim_of_scale) embedding for this scale
        """
        continuous_pos = self._continuous_position(timestamp, scale_idx)
        floor_val = torch.floor(continuous_pos)
        floor_pos = floor_val.long() % self.lattice_size

        embedding = self.lattice_embeddings[scale_idx]
        floor_emb = embedding(floor_pos)

        if self._hard_quantization_active():
            return floor_emb

        ceil_emb = embedding((floor_pos + 1) % self.lattice_size)

        # Fractional distance past the lower cell, in [0, 1). Cast to the
        # embedding dtype so float64 timestamps do not promote the output and
        # break the float32 fusion layer.
        weight = (continuous_pos - floor_val).unsqueeze(-1).to(floor_emb.dtype)
        return floor_emb * (1 - weight) + ceil_emb * weight

    def forward(self, timestamp: torch.Tensor) -> torch.Tensor:
        """
        Compute lattice temporal embedding.

        Training mode uses ``soft_quantize`` (differentiable w.r.t. periods
        and phases); eval mode uses the hard ``quantize_time`` lookup. Every
        call increments ``tick_count``.

        Args:
            timestamp: (batch,) — time in seconds (normalized by TimeEngine)
        Returns:
            lattice_embedding: (batch, hidden_dim)
        """
        embeddings = []

        for scale_idx in range(self.num_scales):
            if self.training:
                emb = self.soft_quantize(timestamp, scale_idx)
            else:
                pos = self.quantize_time(timestamp, scale_idx)
                emb = self.lattice_embeddings[scale_idx](pos)
            embeddings.append(emb)

        # Per-scale widths sum to hidden_dim by construction.
        combined = torch.cat(embeddings, dim=-1)
        out = self.fusion(combined)

        # One forward pass == one tick.
        self.tick_count += 1

        return out

    def anneal_temperature(self, step: int, total_steps: int):
        """
        Anneal quantization temperature: soft → hard over training.

        Uses a cosine schedule from 1.0 at ``step == 0`` down to
        ``HARD_QUANTIZE_TEMPERATURE`` at ``step >= total_steps``. Reaching
        the floor switches ``soft_quantize`` from interpolation to a hard
        lookup.

        Args:
            step: current training step
            total_steps: step at which the floor is reached (values < 1 are
                treated as 1)
        """
        floor = self.HARD_QUANTIZE_TEMPERATURE
        progress = min(step / max(total_steps, 1), 1.0)
        new_temp = floor + (1.0 - floor) * (1 + math.cos(math.pi * progress)) / 2
        self.temperature.fill_(new_temp)

    def get_lattice_state(self) -> dict:
        """Return current lattice state for monitoring/EEG."""
        return {
            'tick_count': self.tick_count.item(),
            'temperature': self.temperature.item(),
            # Periods are reported in linear (not log) units.
            'scale_periods': self.scale_periods.detach().exp().tolist(),
            'phases': self.phase.detach().tolist(),
        }
|
|