"""
GLADIUS v2.0 — Lattice Clock

Discretized temporal encoding on a multi-scale lattice grid.
Replaces continuous Time2Vec with quantized lattice positions.

Ali's framework:
    "Model during forward pass = timeless"
    "To bring it to our realm we need to compress its energy in the lattice"
    "Each forward pass = one atomic oscillation between lattice lasers"
    Softmax = superposition, argmax = collapse

Each forward pass snaps time to a lattice point (the floor of the
continuous lattice position). Between ticks, the model is genuinely
timeless — no temporal leakage.
The tick counter is imposed, not learned. Like a heartbeat.

Usage:
    clock = LatticeClock(config)
    lattice_embed = clock(timestamp)               # (B, hidden_dim)
    hidden = hidden + lattice_embed.unsqueeze(1)   # Broadcast across seq_len
"""

import math

import torch
import torch.nn as nn


class LatticeClock(nn.Module):
    """
    Multi-scale discrete lattice temporal encoding.

    Time is quantized onto ``lattice_size`` positions at ``num_scales``
    different scales. Each scale has its own learned period and phase, and
    each lattice position at each scale has a learned embedding. The
    per-scale embeddings are concatenated to ``hidden_dim`` and passed
    through a small fusion layer.

    Scale periods are initialised log-uniformly between 0.125 s and 3600 s
    (for the default 4 scales: roughly 0.125 s, 3.8 s, 117 s, 3600 s) and are
    learnable afterwards, so each scale captures a different temporal
    resolution, from frame-level ticks up to session-level spans.

    The model observes time in quanta, not continuous flow.

    Config attributes read:
        hidden_dim: output embedding width (required).
        lattice_size: number of positions per scale (default 256).
        lattice_scales: number of scales (default 4).
    """

    # At or below this temperature, soft_quantize() stops interpolating and
    # returns the floor embedding only. anneal_temperature() ends here.
    _HARD_TEMPERATURE = 0.01

    def __init__(self, config):
        super().__init__()

        # Lattice parameters
        self.lattice_size = getattr(config, 'lattice_size', 256)
        self.num_scales = getattr(config, 'lattice_scales', 4)
        if self.lattice_size < 1 or self.num_scales < 1:
            raise ValueError(
                "lattice_size and lattice_scales must be positive, got "
                f"lattice_size={self.lattice_size}, "
                f"lattice_scales={self.num_scales}"
            )
        hidden_dim = config.hidden_dim

        # Embedding dimension per scale; the first `remainder` scales get one
        # extra dimension so the concatenation is exactly hidden_dim wide.
        self.dim_per_scale, self.remainder = divmod(hidden_dim, self.num_scales)

        # Learned lattice embeddings at each scale
        self.lattice_embeddings = nn.ModuleList([
            nn.Embedding(
                self.lattice_size,
                self.dim_per_scale + (1 if i < self.remainder else 0),
            )
            for i in range(self.num_scales)
        ])

        # Learned scale periods, stored in log-space for stability.
        # Defaults are evenly spaced in log-space between 125 ms and 3600 s,
        # i.e. geometrically spaced periods (≈0.125 s, 3.8 s, 117 s, 3600 s
        # when num_scales == 4).
        default_periods = torch.linspace(
            math.log(0.125), math.log(3600.0), self.num_scales
        )
        self.scale_periods = nn.Parameter(default_periods)

        # Phase offsets per scale (in lattice units)
        self.phase = nn.Parameter(torch.zeros(self.num_scales))

        # Fusion: project concatenated embeddings back to hidden_dim
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

        # Tick counter — imposed, involuntary, never learned
        self.register_buffer('tick_count', torch.tensor(0, dtype=torch.long))

        # Temperature switch for soft quantization.
        # Starts soft (τ=1.0); anneal_temperature() lowers it towards the
        # hard threshold.
        self.register_buffer('temperature', torch.tensor(1.0))

        # Initialize embeddings with small values
        for emb in self.lattice_embeddings:
            nn.init.normal_(emb.weight, mean=0, std=0.01)

    def _wrapped_position(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """Return the continuous lattice position wrapped into the lattice range.

        Shared by hard and soft quantization so both always agree on which
        lattice cell a timestamp falls into.
        """
        period = self.scale_periods[scale_idx].exp()
        phase = self.phase[scale_idx]
        # `%` on tensors follows the sign of the divisor, so negative
        # positions wrap to the top of the lattice instead of staying negative.
        return (timestamp / period + phase) % self.lattice_size

    def quantize_time(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """
        Snap continuous time to a lattice point (floor of the position).

        Args:
            timestamp: (batch,) — time value
            scale_idx: which scale to quantize at

        Returns:
            lattice_positions: (batch,) — integer positions in [0, lattice_size)
        """
        wrapped = self._wrapped_position(timestamp, scale_idx)
        # floor() before long(): a bare long() truncates toward zero, which
        # picks the wrong cell for negative positions. The final modulo guards
        # the rare case where float rounding makes the wrapped value equal
        # lattice_size.
        return wrapped.floor().long() % self.lattice_size

    def soft_quantize(self, timestamp: torch.Tensor, scale_idx: int) -> torch.Tensor:
        """
        Soft quantization using distance-weighted interpolation.

        Linearly blends the embeddings of the two neighbouring lattice points
        by the fractional position between them, which lets gradients reach
        the scale period and phase during training.

        The temperature acts as a switch, not a continuous scale: at or below
        the hard threshold (0.01) only the floor embedding is returned; above
        it, plain linear interpolation is used.

        Args:
            timestamp: (batch,) — time value
            scale_idx: which scale to quantize at

        Returns:
            (batch, scale_dim) embedding for this scale.
        """
        embedding = self.lattice_embeddings[scale_idx]
        wrapped = self._wrapped_position(timestamp, scale_idx)
        floor_val = wrapped.floor()

        floor_pos = floor_val.long() % self.lattice_size
        floor_emb = embedding(floor_pos)

        if self.temperature.item() <= self._HARD_TEMPERATURE:
            # Hard mode — no interpolation
            return floor_emb

        ceil_pos = (floor_pos + 1) % self.lattice_size
        ceil_emb = embedding(ceil_pos)

        # Fractional distance from the floor point, as a (B, 1) blend weight.
        # Cast to the embedding dtype so float64 timestamps do not promote the
        # result and break the float32 fusion layer.
        weight = (wrapped - floor_val).unsqueeze(-1).to(floor_emb.dtype)
        return floor_emb * (1 - weight) + ceil_emb * weight

    def forward(self, timestamp: torch.Tensor) -> torch.Tensor:
        """
        Compute lattice temporal embedding.

        Uses soft quantization in training mode (for gradient flow) and hard
        quantization in eval mode. Increments the tick counter on every call.

        Args:
            timestamp: (batch,) — time in seconds (normalized by TimeEngine)

        Returns:
            lattice_embedding: (batch, hidden_dim)
        """
        embeddings = []
        for scale_idx in range(self.num_scales):
            if self.training:
                # Soft quantization for gradient flow during training
                emb = self.soft_quantize(timestamp, scale_idx)
            else:
                # Hard quantization at inference
                pos = self.quantize_time(timestamp, scale_idx)
                emb = self.lattice_embeddings[scale_idx](pos)
            embeddings.append(emb)

        # Concatenate multi-scale lattice embeddings: (batch, hidden_dim)
        combined = torch.cat(embeddings, dim=-1)

        # Fuse
        out = self.fusion(combined)

        # Involuntary tick
        self.tick_count += 1

        return out

    def anneal_temperature(self, step: int, total_steps: int):
        """
        Anneal quantization temperature: soft → hard over training.

        Cosine schedule from 1.0 at step 0 down to 0.01 at ``total_steps``
        (and beyond). This mirrors the softmax → argmax transition:
        exploration (soft) → commitment (hard).

        Args:
            step: current training step.
            total_steps: step at which the schedule bottoms out; values
                below 1 are treated as 1.
        """
        progress = min(step / max(total_steps, 1), 1.0)
        new_temp = 0.01 + 0.99 * (1 + math.cos(math.pi * progress)) / 2
        self.temperature.fill_(new_temp)

    def get_lattice_state(self) -> dict:
        """Return current lattice state for monitoring/EEG.

        Returns:
            dict with 'tick_count' (int), 'temperature' (float),
            'scale_periods' (list of periods, already exponentiated out of
            log-space) and 'phases' (list of phase offsets).
        """
        return {
            'tick_count': self.tick_count.item(),
            'temperature': self.temperature.item(),
            'scale_periods': self.scale_periods.detach().exp().tolist(),
            'phases': self.phase.detach().tolist(),
        }