| |
| """SymbolicLight V1 model implementation.""" |
| import math |
| from dataclasses import dataclass, field |
| from typing import Optional, Tuple, List |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
@dataclass
class SymbolicLightConfig:
    """Default configuration for SymbolicLight V1.

    Geometry constraints (checked in ``__post_init__``):
        * ``embed_dim == n_heads * head_dim`` -- the attention and decay paths
          reshape ``[B, S, embed_dim]`` activations to ``[B, S, n_heads, head_dim]``.
        * ``head_dim`` is even -- rotary encoding rotates the two halves of a head.

    Raises:
        ValueError: if either constraint is violated.
    """
    vocab_size: int = 57344
    embed_dim: int = 1536
    n_layers: int = 22
    n_heads: int = 24
    head_dim: int = 64
    intermediate_dim: int = 6144
    max_seq_len: int = 512
    spike_chunk_size: int = 64  # time steps handed to one LIFScan call in SpikeEncoder
    dropout: float = 0.1

    # Spiking-neuron dynamics.
    spike_threshold: float = 1.0
    leak_factor: float = 0.95
    stdp_lr: float = 0.01
    enable_stdp: bool = False

    # Positional encoding and input frontend.
    rope_theta: float = 10000.0
    frontend_mode: str = "text"  # NOTE(review): not read by the model code in this file

    # Sparse attention, output head and top-k spike encoding.
    sparse_attn_window: int = 512
    n_global_anchors: int = 4
    enable_sparse_attn: bool = True
    enable_dynamic_prior: bool = True
    use_topk_mask: bool = False
    topk_sparsity: float = 0.89

    def __post_init__(self):
        # Fail at construction time instead of deep inside a forward-pass reshape.
        if self.embed_dim != self.n_heads * self.head_dim:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must equal n_heads * head_dim "
                f"({self.n_heads} * {self.head_dim} = {self.n_heads * self.head_dim})"
            )
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"head_dim must be even for rotary position encoding, got {self.head_dim}"
            )
|
|
|
|
|
|
|
|
|
|
class ATanSurrogate(torch.autograd.Function):
    """Hard threshold in the forward pass, smooth surrogate derivative in the backward.

    Forward: ``(membrane_potential >= threshold)`` as a float32 0/1 tensor.
    Backward: the step's derivative is replaced by
    ``1 / (1 + (2 * (membrane_potential - threshold)) ** 2)``; ``threshold``
    receives no gradient.
    """

    @staticmethod
    def forward(ctx, membrane_potential, threshold):
        # Store the threshold as a 0-dim tensor in the input's dtype/device so the
        # backward subtraction runs in that dtype.
        threshold_tensor = membrane_potential.new_tensor(threshold)
        ctx.save_for_backward(membrane_potential, threshold_tensor)
        fired = membrane_potential >= threshold
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        potential, threshold_tensor = ctx.saved_tensors
        alpha = 2.0
        scaled = alpha * (potential - threshold_tensor)
        surrogate = 1.0 / (1.0 + scaled ** 2)
        return grad_output * surrogate, None
|
|
|
|
def surrogate_spike(membrane_potential: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    """Return float32 0/1 spikes where ``membrane_potential >= threshold``.

    Gradients flow through the arctan-style surrogate of ``ATanSurrogate``
    rather than the (almost everywhere zero) derivative of a hard step.
    """
    spike_fn = ATanSurrogate.apply
    return spike_fn(membrane_potential, threshold)
|
|
|
|
|
|
|
|
|
|
class RotaryPositionEncoding(nn.Module):
    """Rotary position embedding (RoPE) in the "rotate-half" layout.

    Channel ``i`` is paired with channel ``i + dim // 2`` and the pair is rotated
    by ``position * theta ** (-2 i / dim)``.

    Args:
        dim: size of the last axis of the inputs; must be a positive even number.
        theta: base of the geometric frequency schedule.

    Raises:
        ValueError: if ``dim`` is not a positive even number.
    """
    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        if dim <= 0 or dim % 2 != 0:
            raise ValueError(f"RotaryPositionEncoding needs a positive even dim, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """
        Apply RoPE rotation along the last two axes of ``x``.

        Args:
            x: [..., S, D] input (e.g. [B, S, D] or [B, H, S, D]); D must equal ``dim``.
            offset: absolute position of the first row along S (for incremental decoding).
        Returns:
            Tensor with the same shape as ``x``. Angles are computed in float32, so a
            lower-precision ``x`` is promoted to float32 by the multiply.
        Raises:
            ValueError: if the last axis of ``x`` does not match ``dim``.
        """
        seq_len, dim = x.shape[-2], x.shape[-1]
        expected = self.inv_freq.numel() * 2
        if dim != expected:
            # Without this check a mismatched (or size-1) last axis silently broadcasts.
            raise ValueError(f"RotaryPositionEncoding expected last dimension {expected}, got {dim}")

        positions = torch.arange(offset, offset + seq_len, device=x.device, dtype=torch.float32)
        freqs = torch.outer(positions, self.inv_freq.to(x.device))  # [S, D/2]
        emb = torch.cat([freqs, freqs], dim=-1)  # [S, D]; broadcasts over leading axes
        cos_emb = emb.cos()
        sin_emb = emb.sin()

        half = dim // 2
        # rotate_half(x) = [-x2, x1] for x = [x1, x2]
        x_rotated = torch.cat([-x[..., half:], x[..., :half]], dim=-1)
        return x * cos_emb + x_rotated * sin_emb
|
|
|
|
|
|
|
|
|
|
class FrontendRouter(nn.Module):
    """Dispatch token IDs to a modality-specific embedding frontend.

    Only ``"text"`` is implemented: an ``nn.Embedding(vocab_size, embed_dim)``
    lookup. ``"vision"`` and ``"audio"`` are reserved and raise
    ``NotImplementedError``; any other modality raises ``ValueError``.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.config = config
        self.text_embedding = nn.Embedding(config.vocab_size, config.embed_dim)

    def forward(self, token_ids: torch.Tensor, modality: str = "text") -> torch.Tensor:
        if modality == "text":
            return self.text_embedding(token_ids)

        # Modalities that exist in the interface but have no frontend in this release.
        reserved = (
            ("vision", "Vision frontend is not included in this release."),
            ("audio", "Audio frontend is not included in this release."),
        )
        for name, message in reserved:
            if modality == name:
                raise NotImplementedError(message)

        raise ValueError(f"Unknown modality: {modality}")
|
|
|
|
|
|
|
|
|
|
|
|
| def _lif_scan_forward(x: torch.Tensor, v_mem: torch.Tensor, |
| leak: float, threshold: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| """ |
| JIT forward pass for temporal LIF neuron scanning. |
| Returns: (spikes [B,S,D], final_v_mem [B,D], all_v_mem [B,S,D] for backward) |
| """ |
| B, S, D = x.shape |
| spikes = torch.empty_like(x) |
| all_v = torch.empty_like(x) |
| for t in range(S): |
| v_mem = v_mem * leak + x[:, t, :] |
| v_mem = torch.clamp(v_mem, -3.0, 3.0) |
| all_v[:, t, :] = v_mem |
| spike = (v_mem >= threshold).float() |
| v_mem = v_mem * (1.0 - spike) |
| spikes[:, t, :] = spike |
| return spikes, v_mem, all_v |
|
|
|
|
class LIFScan(torch.autograd.Function):
    """Chunked LIF scan with a per-step ATan-style surrogate gradient.

    Forward runs ``_lif_scan_forward`` and returns ``(spikes, final_v_mem)``.
    Backward treats every time step independently:
    ``grad_x = grad_spikes * 1 / (1 + (2 * (v_t - threshold)) ** 2)`` using the
    saved pre-reset potentials ``v_t``. The incoming gradient for the final
    membrane is ignored and no gradient is returned for ``v_mem``, ``leak`` or
    ``threshold``, so nothing flows through the temporal recurrence or the clamp.
    """
    @staticmethod
    def forward(ctx, x, v_mem, leak, threshold):
        spike_train, last_v, potentials = _lif_scan_forward(x, v_mem, leak, threshold)
        ctx.threshold = threshold
        ctx.save_for_backward(potentials)
        return spike_train, last_v

    @staticmethod
    def backward(ctx, grad_spikes, grad_v_mem):
        (potentials,) = ctx.saved_tensors
        alpha = 2.0
        surrogate = 1.0 / (1.0 + (alpha * (potentials - ctx.threshold)) ** 2)
        # grad_v_mem is deliberately unused (see class docstring).
        return grad_spikes * surrogate, None, None, None
|
|
|
|
|
|
|
|
|
|
class SpikeEncoder(nn.Module):
    """
    Convert token IDs into 0/1 spike tensors plus a continuous residual stream.

    Flow: token_ids -> FrontendRouter (embedding) -> LayerNorm -> spike conversion.

    Spike conversion is one of:
      * LIF (default): a leaky integrate-and-fire scan over time, handed to
        ``LIFScan`` in chunks of ``spike_chunk_size`` steps (each chunk is still a
        sequential per-step loop).
      * top-k (``config.use_topk_mask``): at every position the
        ``k = (1 - topk_sparsity) * embed_dim`` channels with the largest |x| fire.
        In training a straight-through estimator keeps the top-k mask as the
        forward value and routes the gradient through the ATan surrogate of x.

    No positional embedding is added here; positions enter later through RoPE.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.config = config
        self.frontend = FrontendRouter(config)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.threshold = config.spike_threshold
        self.leak = config.leak_factor
        # Membrane potential [B, embed_dim]; reset or loaded from the cache on each forward.
        self.v_mem = None

    def _init_membrane(self, shape: torch.Size, device: torch.device):
        """Reset the membrane potential to zeros (default dtype) on ``device``."""
        self.v_mem = torch.zeros(shape, device=device)

    def forward(self, token_ids: torch.Tensor, use_cache: bool = False,
                cache: dict = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            token_ids: [batch, seq_len]
            use_cache: if True (and ``cache`` is given) the membrane potential is
                read from and written back to ``cache['v_mem']`` so it carries
                across calls; otherwise it is reset to zeros.
            cache: per-encoder cache dict; ``cache['seq_len']`` counts tokens seen.
        Returns:
            spikes: [batch, seq_len, embed_dim] 0/1 spikes
            continuous: [batch, seq_len, embed_dim] normalized embeddings (residual stream)
        """
        B, S = token_ids.shape

        if use_cache and cache is not None:
            if 'v_mem' not in cache:
                cache['v_mem'] = torch.zeros(B, self.config.embed_dim, device=token_ids.device)
            if 'seq_len' not in cache:
                cache['seq_len'] = 0
            self.v_mem = cache['v_mem']
            cache['seq_len'] += S
        else:
            self._init_membrane((B, self.config.embed_dim), token_ids.device)

        x = self.frontend(token_ids)
        x = self.norm(x)

        if getattr(self.config, 'use_topk_mask', False):
            embed_dim = self.config.embed_dim
            # Clamp so a sparsity outside [0, 1) cannot request more channels than exist.
            k = min(embed_dim, max(1, int((1.0 - self.config.topk_sparsity) * embed_dim)))
            _, topk_indices = torch.topk(x.abs(), k, dim=-1)
            spikes = torch.zeros_like(x)
            spikes.scatter_(-1, topk_indices, 1.0)
            if self.training:
                # Straight-through estimator: the forward value is the hard top-k mask,
                # the gradient is the surrogate's. (The previous form
                # `spikes + (soft - spikes).detach()` had the roles swapped: it emitted
                # threshold spikes and passed zero gradient.)
                soft = surrogate_spike(x, self.threshold)
                spikes = soft + (spikes - soft).detach()
        else:
            chunk_size = self.config.spike_chunk_size
            spikes_list = []
            for chunk_start in range(0, S, chunk_size):
                chunk_end = min(chunk_start + chunk_size, S)
                x_chunk = x[:, chunk_start:chunk_end, :]
                chunk_spikes, self.v_mem = LIFScan.apply(
                    x_chunk, self.v_mem, self.leak, self.threshold
                )
                spikes_list.append(chunk_spikes)
                if self.training:
                    # Truncate the membrane's history at chunk boundaries.
                    self.v_mem = self.v_mem.detach()
            # An empty sequence yields no chunks; torch.cat([]) would raise.
            spikes = torch.cat(spikes_list, dim=1) if spikes_list else torch.zeros_like(x)

        if use_cache and cache is not None:
            cache['v_mem'] = self.v_mem.detach()

        return spikes, x
|
|
|
|
|
|
|
|
|
|
class SparseLocalAttention(nn.Module):
    """
    Causal local-window attention whose keys are restricted to spiking positions.

    Query position i may attend to key position j when j <= i and either
    i - j <= window_size or j < n_global_anchors (global anchors). Keys whose
    position did not spike (``spike_mask`` False) are masked out as well.
    Scores are still computed densely (S_q x S_kv) and then masked, so cost is
    quadratic in sequence length; the sparsity is in which keys receive weight.
    Queries with no admissible key, and every non-spiking query position,
    produce a zero output.

    Relation to the decay path in SparseTCAM:
    - the decay path compresses history into a fixed-size state for coarse long-range memory
    - this path attends precisely to recent informative positions
    - a learned gate in SparseTCAM blends both paths
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.window_size = max(1, int(config.sparse_attn_window))
        self.n_global_anchors = config.n_global_anchors
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        # Used by the manual fallback; SDPA's default scale is also 1/sqrt(head_dim).
        self.scale = config.head_dim ** -0.5
        self._use_sdpa = hasattr(F, "scaled_dot_product_attention")

        self.q_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        self.k_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        self.v_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)

        # RoPE is applied per head to Q and K after projection.
        self.rope = RotaryPositionEncoding(config.head_dim, theta=config.rope_theta)

        # Structural mask for absolute positions [0, max_seq_len) x [0, max_seq_len).
        # Any (offset, S_q, S_kv) window that fits inside it is served by slicing.
        # Kept as a persistent buffer so existing state_dicts still load.
        positions = torch.arange(config.max_seq_len)
        self.register_buffer('_cached_mask', self._structural_mask(positions, positions))

    def _structural_mask(self, q_pos: torch.Tensor, k_pos: torch.Tensor) -> torch.Tensor:
        """Return a [len(q_pos), len(k_pos)] bool mask: causal & (in window | key is an anchor)."""
        distance = q_pos.unsqueeze(1) - k_pos.unsqueeze(0)
        causal = distance >= 0
        window = distance <= self.window_size
        anchors = k_pos.unsqueeze(0) < self.n_global_anchors
        return causal & (window | anchors)

    def forward(self, x: torch.Tensor, spike_mask: torch.Tensor,
                offset: int = 0, use_cache: bool = False, cache: dict = None) -> torch.Tensor:
        """
        Args:
            x: [B, S_q, D] continuous representation; RoPE is applied internally to Q/K
            spike_mask: [B, S_q] boolean mask, True where the position fired a spike
            offset: absolute position of the first query (RoPE and mask offset)
            use_cache: append this call's K/V/spike_mask to ``cache`` and attend over all of it
            cache: KV cache dict with keys 'K', 'V', 'spike_mask'
        Returns:
            attn_out: [B, S_q, D], zero at non-spiking query positions
        """
        B, S_q, D = x.shape

        # [B, S, D] -> [B, n_heads, S, head_dim]
        Q = self.q_proj(x).view(B, S_q, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, S_q, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, S_q, self.n_heads, self.head_dim).transpose(1, 2)

        # RoPE may promote to float32; cast back to the value dtype.
        Q = self.rope(Q.contiguous().view(B * self.n_heads, S_q, self.head_dim), offset=offset)
        Q = Q.view(B, self.n_heads, S_q, self.head_dim).to(V.dtype)
        K = self.rope(K.contiguous().view(B * self.n_heads, S_q, self.head_dim), offset=offset)
        K = K.view(B, self.n_heads, S_q, self.head_dim).to(V.dtype)

        if use_cache and cache is not None:
            if 'K' in cache:
                K = torch.cat([cache['K'], K], dim=2)
                V = torch.cat([cache['V'], V], dim=2)
                spike_mask_kv = torch.cat([cache['spike_mask'], spike_mask], dim=1)
            else:
                spike_mask_kv = spike_mask
            cache['K'] = K.detach()
            cache['V'] = V.detach()
            cache['spike_mask'] = spike_mask_kv.detach()
        else:
            spike_mask_kv = spike_mask

        S_kv = K.size(2)

        # Structural mask depends only on absolute positions, so slice the precomputed
        # one whenever the requested window fits; rebuild only beyond max_seq_len.
        max_len = self._cached_mask.size(0)
        if offset + S_q <= max_len and S_kv <= max_len:
            attn_mask = self._cached_mask[offset:offset + S_q, :S_kv]
        else:
            q_pos = torch.arange(offset, offset + S_q, device=x.device)
            k_pos = torch.arange(0, S_kv, device=x.device)
            attn_mask = self._structural_mask(q_pos, k_pos)

        # [B, 1, 1, S_kv] key mask combined with the [1, 1, S_q, S_kv] structural mask.
        spike_key_mask = spike_mask_kv.unsqueeze(1).unsqueeze(2)
        full_mask = attn_mask.unsqueeze(0).unsqueeze(0) & spike_key_mask

        # Rows with no admissible key would softmax over all -inf (NaN); unmask them
        # for the computation and zero their output afterwards.
        query_has_any_key = full_mask.any(dim=-1, keepdim=True)
        if self._use_sdpa:
            safe_mask = full_mask | ~query_has_any_key
            attn_out = F.scaled_dot_product_attention(Q, K, V, attn_mask=safe_mask, dropout_p=0.0)
            attn_out = attn_out.masked_fill(~query_has_any_key, 0.0)
        else:
            scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
            scores = scores.masked_fill(~full_mask, float('-inf'))
            scores = scores.masked_fill(~query_has_any_key, 0.0)
            attn_weights = F.softmax(scores, dim=-1).to(V.dtype)
            attn_weights = attn_weights.masked_fill(~query_has_any_key, 0.0)
            attn_out = torch.matmul(attn_weights, V)

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, S_q, D)

        # Only positions that spiked emit attention output.
        attn_out = attn_out * spike_mask.unsqueeze(-1).to(dtype=attn_out.dtype)

        return attn_out
|
|
|
|
|
|
|
|
|
|
class SparseTCAM(nn.Module):
    """Dual-path spike-gated sequence mixer.

    Decay path: spikes are projected by ``tcam_proj`` into per-head values u_t and
    accumulated with a per-head exponential moving average
        h_t = decay * h_{t-1} + (1 - decay) * u_t,   decay = sigmoid(decay_raw),
    evaluated as a causal depthwise ``conv1d`` over the chunk (plus the decayed
    carried-in state), or as one recurrent step when decoding a single token with
    ``use_cache``.

    Attention path (only when ``enable_sparse_attn``): ``SparseLocalAttention`` over
    the continuous stream with keys restricted to spiking positions.

    The paths are blended by ``sigmoid(attn_gate)``, passed through dropout and
    ``out_proj``, added to ``continuous`` and LayerNorm-ed (post-norm residual).
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.embed_dim = config.embed_dim
        self.threshold = config.spike_threshold
        self.leak = config.leak_factor  # NOTE(review): not read by forward in this class
        self.enable_sparse_attn = config.enable_sparse_attn

        # Spike -> per-head value projection for the decay path.
        self.tcam_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        # Output projection applied after blending the two paths.
        self.out_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.dropout)

        # Per-head decay logits; sigmoid(3.0) ~= 0.953 at initialization.
        self.decay_raw = nn.Parameter(torch.full((config.n_heads,), 3.0))

        if self.enable_sparse_attn:
            self.sparse_attn = SparseLocalAttention(config)
            # Gate logit: output = g * attention + (1 - g) * decay; sigmoid(0) = 0.5 at init.
            self.attn_gate = nn.Parameter(torch.zeros(1))

    def forward(self, spikes: torch.Tensor, continuous: torch.Tensor,
                use_cache: bool = False, cache: dict = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Mix information along the sequence.

        Args:
            spikes: [B, S, D] spike tensor (cast to ``continuous.dtype``).
            continuous: [B, S, D] residual stream.
            use_cache: incremental decoding; advances ``cache['rope_offset']`` and
                lets the attention path use its KV cache.
            cache: per-block dict. Whenever it is given (inference or streaming
                training) the decay state ``cache['h']`` is read and updated,
                even if ``use_cache`` is False.
        Returns:
            (out_spikes, out_continuous), both [B, S, D]; out_spikes thresholds
            out_continuous with the surrogate spike function.
        """
        B, S, D = spikes.shape
        compute_dtype = continuous.dtype
        if spikes.dtype != compute_dtype:
            spikes = spikes.to(compute_dtype)

        # RoPE position offset only advances during incremental decoding.
        offset = 0
        if use_cache and cache is not None:
            offset = cache.get('rope_offset', 0)
            cache['rope_offset'] = offset + S

        # [B, S, 1]: 1.0 where any channel at that position spiked, else 0.0.
        spike_energy = spikes.sum(dim=-1)
        active_mask = (spike_energy > 0).unsqueeze(-1).to(dtype=compute_dtype)

        # For non-negative spikes the mask leaves values unchanged (silent rows are
        # already zero), but it zeroes the gradient w.r.t. spikes on silent positions.
        tcam_out = self.tcam_proj(spikes * active_mask)

        # [B, S, n_heads, head_dim]
        tcam_out = tcam_out.view(B, S, self.n_heads, self.head_dim)

        decay = torch.sigmoid(self.decay_raw)

        # Carried decay state h: [B, n_heads, head_dim].
        if cache is not None:
            if 'h' not in cache:
                cache['h'] = torch.zeros(B, self.n_heads, self.head_dim, device=spikes.device, dtype=compute_dtype)
            h = cache['h']
        else:
            h = torch.zeros(B, self.n_heads, self.head_dim, device=spikes.device, dtype=compute_dtype)

        if use_cache and cache is not None and S == 1:
            # Single recurrent step for cached one-token decoding.
            h = decay.view(1, self.n_heads, 1) * h + (1 - decay.view(1, self.n_heads, 1)) * tcam_out[:, 0]
            cache['h'] = h.detach()
            context = h.unsqueeze(1)
        else:
            # Whole-chunk EMA as a depthwise causal conv: kernel[k] = decay^(S-1-k) * (1 - decay)
            # with S-1 zeros of left padding, so output t only sees inputs at positions <= t.
            powers = torch.arange(S - 1, -1, -1, dtype=compute_dtype, device=spikes.device)
            kernel = ((decay.view(-1, 1) ** powers.view(1, -1)) * (1 - decay).view(-1, 1)).unsqueeze(1)
            # [B, S, H, Dh] -> [B * Dh, H, S]: heads become conv groups, time the conv axis.
            tcam_out_trans = tcam_out.permute(0, 3, 2, 1).reshape(-1, self.n_heads, S)
            tcam_out_pad = F.pad(tcam_out_trans, (S - 1, 0))
            out = F.conv1d(tcam_out_pad, kernel, groups=self.n_heads)
            # Back to [B, S, H, Dh].
            context = out.view(-1, self.head_dim, self.n_heads, S).permute(0, 3, 2, 1)

            # Add the carried-in state's contribution: decay^(t+1) * h for t = 0..S-1.
            powers_fwd = torch.arange(1, S + 1, dtype=compute_dtype, device=spikes.device).view(1, S, 1, 1)
            decay_t = decay.view(1, 1, self.n_heads, 1) ** powers_fwd
            context = context + h.unsqueeze(1) * decay_t

        # Persist the last state for the next call (in the single-step branch this
        # re-stores the same h written above).
        if cache is not None:
            cache['h'] = context[:, -1, :, :].detach()

        decay_output = context.reshape(B, S, D)

        if self.enable_sparse_attn:
            spike_mask = (spikes.sum(dim=-1) > 0)
            # Attention keeps its own sub-dict; SparseLocalAttention only touches it when use_cache is True.
            attn_cache = cache.setdefault('attn', {}) if cache is not None else None
            attn_output = self.sparse_attn(
                continuous, spike_mask, offset=offset,
                use_cache=use_cache, cache=attn_cache
            )

            gate = torch.sigmoid(self.attn_gate)
            output = gate * attn_output + (1 - gate) * decay_output
        else:
            output = decay_output

        output = self.out_proj(self.dropout(output))

        # Post-norm residual connection.
        out_continuous = self.norm(continuous + output)

        out_spikes = surrogate_spike(out_continuous, self.threshold).to(out_continuous.dtype)

        return out_spikes, out_continuous
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpikingFeedForward(nn.Module):
    """
    Post-norm feed-forward block with a spiking hidden layer.

    Computes ``norm(x + down(dropout(spike(up(x)))))`` where ``spike`` is the 0/1
    threshold function with an ATan surrogate gradient (``surrogate_spike``),
    used in place of the usual Transformer MLP activation.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        # Registration order is kept (up, down, norm) so parameter order is unchanged.
        self.up = nn.Linear(config.embed_dim, config.intermediate_dim, bias=False)
        self.down = nn.Linear(config.intermediate_dim, config.embed_dim, bias=False)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.threshold = config.spike_threshold
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_spikes = surrogate_spike(self.up(x), self.threshold).to(x.dtype)
        projected = self.down(self.dropout(hidden_spikes))
        return self.norm(x + projected)
|
|
|
|
|
|
|
|
|
|
class SymbolicLightBlock(nn.Module):
    """One layer: SparseTCAM sequence mixer followed by SpikingFeedForward.

    The spikes emitted by the mixer are discarded; the block's output spikes are
    recomputed by thresholding the FFN output with the mixer's threshold.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.tcam = SparseTCAM(config)
        self.ffn = SpikingFeedForward(config)

    def forward(self, spikes, continuous, use_cache=False, cache=None):
        _, mixed = self.tcam(spikes, continuous, use_cache=use_cache, cache=cache)
        refined = self.ffn(mixed)
        out_spikes = surrogate_spike(refined, self.tcam.threshold).to(refined.dtype)
        return out_spikes, refined
|
|
|
|
|
|
|
|
|
|
class BayesianHead(nn.Module):
    """
    Output head: likelihood logits plus a weighted, optionally context-dependent prior.

        logits = output_proj(h) + prior_weight * prior(h)

    With ``enable_dynamic_prior`` (default True when the config lacks the field),
    ``prior(h)`` is a bottleneck MLP ``embed_dim -> embed_dim // 4 -> vocab_size``
    with GELU, so the bias over the vocabulary can shift with the context (e.g.
    towards cooking words in a recipe, or code tokens in a program). Otherwise
    it is a single learned per-token vector ``log_prior`` shared by all positions.
    ``prior_weight`` is a learnable scalar initialised to 0.1.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        # Registration order (output_proj, prior_weight, prior_net/log_prior) is
        # kept so parameter ordering and state_dict keys are unchanged.
        self.output_proj = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        self.use_dynamic_prior = getattr(config, 'enable_dynamic_prior', True)
        self.prior_weight = nn.Parameter(torch.tensor(0.1))

        if not self.use_dynamic_prior:
            self.log_prior = nn.Parameter(torch.zeros(config.vocab_size))
        else:
            hidden = config.embed_dim // 4
            self.prior_net = nn.Sequential(
                nn.Linear(config.embed_dim, hidden, bias=False),
                nn.GELU(),
                nn.Linear(hidden, config.vocab_size, bias=False),
            )

    def forward(self, continuous: torch.Tensor) -> torch.Tensor:
        """
        Args:
            continuous: [B, S, D]
        Returns:
            logits: [B, S, vocab_size]
        """
        likelihood = self.output_proj(continuous)
        prior = self.prior_net(continuous) if self.use_dynamic_prior else self.log_prior
        return likelihood + self.prior_weight * prior
|
|
|
|
|
|
|
|
|
|
class STDPUpdater:
    """Optional local co-firing (Hebbian, "STDP"-style) update of ``tcam_proj`` weights.

    When enabled, every block's ``tcam.tcam_proj.weight`` ([out, in]) receives
        0.05 * lr * co_firing / (B * S)
    where ``co_firing[d, e]`` counts (batch, step) pairs at which post-synaptic
    channel ``d`` and pre-synaptic channel ``e`` were both active. The weights
    are then clamped to [-5, 5]. Spike timing order is not used -- only
    simultaneous activity -- and every block receives the same update.
    """
    def __init__(self, config: "SymbolicLightConfig"):
        self.lr = config.stdp_lr
        self.enabled = config.enable_stdp

    @torch.no_grad()
    def update(self, model: nn.Module, pre_spikes: torch.Tensor, post_spikes: torch.Tensor):
        """Apply the co-firing update to ``model.blocks`` in place.

        Args:
            model: object exposing ``blocks``, each with ``tcam.tcam_proj.weight``.
            pre_spikes: [B, S, D_in] spikes feeding the projection.
            post_spikes: [B, S, D_out] spikes downstream of it.
        """
        if not self.enabled:
            return

        # Skip entirely unless some channel index fired (anywhere in time) in both tensors.
        causal = (pre_spikes.sum(dim=1, keepdim=True) > 0) & (post_spikes.sum(dim=1, keepdim=True) > 0)
        if not causal.any():
            return

        # The update does not depend on the block, so compute it once instead of per layer.
        pre_active = (pre_spikes > 0).float()
        post_active = (post_spikes > 0).float()
        co_firing = torch.einsum('bsd,bse->de', post_active, pre_active)
        delta = self.lr * co_firing / (pre_spikes.size(0) * pre_spikes.size(1))
        mask = (co_firing > 0).float()
        step = delta * mask * 0.05

        for block in model.blocks:
            w = block.tcam.tcam_proj.weight
            w.data += step
            w.data.clamp_(-5, 5)
|
|
|
|
|
|
|
|
|
|
class SymbolicLightModel(nn.Module):
    """SymbolicLight language model.

    Pipeline: SpikeEncoder -> ``n_layers`` x SymbolicLightBlock -> BayesianHead.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.config = config
        self.spike_encoder = SpikeEncoder(config)
        self.blocks = nn.ModuleList([
            SymbolicLightBlock(config) for _ in range(config.n_layers)
        ])
        self.output_head = BayesianHead(config)
        self.stdp = STDPUpdater(config)
        self.gradient_checkpointing = False

        # Re-initialize every Linear / Embedding weight with N(0, 0.02).
        self.apply(self._init_weights)

        n_params = sum(p.numel() for p in self.parameters())
        print(f"[SymbolicLight V1] model initialized | parameters: {n_params/1e6:.1f}M ({n_params/1e9:.3f}B)")

    def _init_weights(self, module):
        """N(0, 0.02) weights for Linear/Embedding layers; zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def compile_for_inference(self):
        """Apply torch.compile to selected inference-critical submodules."""
        self.spike_encoder = torch.compile(self.spike_encoder, mode='reduce-overhead')
        for block in self.blocks:
            block.tcam = torch.compile(block.tcam, mode='reduce-overhead')
            block.ffn = torch.compile(block.ffn, mode='reduce-overhead')
        print("[SymbolicLight V1] torch.compile applied for inference acceleration")

    def gradient_checkpointing_enable(self):
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        self.gradient_checkpointing = False

    def forward(self, token_ids: torch.Tensor, use_cache: bool = False,
                past_key_values: list = None,
                streaming_state: list = None):
        """
        Shared forward pass for training and inference.

        Args:
            token_ids: [B, S] input token IDs
            use_cache: incremental decoding with per-module caches
            past_key_values: list of per-module cache dicts
                ``[encoder_cache, block0_cache, block1_cache, ...]``; created if
                ``use_cache`` is True and it is None
            streaming_state: same layout as ``past_key_values``; used when
                ``use_cache`` is False to carry state across training chunks.
                NOTE: the spike encoder only uses its cache when ``use_cache`` is
                True, so in streaming mode only per-block state carries over.
        Returns:
            logits: [B, S, vocab_size]
        """
        # Explicit import: `import torch` alone does not guarantee torch.utils.checkpoint is loaded.
        from torch.utils.checkpoint import checkpoint

        n_caches = len(self.blocks) + 1
        if use_cache and past_key_values is None:
            past_key_values = [{} for _ in range(n_caches)]

        if not use_cache and streaming_state is not None:
            caches = streaming_state
        elif use_cache:
            caches = past_key_values
        else:
            caches = [None] * n_caches

        spikes, continuous = self.spike_encoder(token_ids, use_cache=use_cache, cache=caches[0])
        model_dtype = self.output_head.output_proj.weight.dtype
        if continuous.dtype != model_dtype:
            continuous = continuous.to(model_dtype)
        if spikes.dtype != model_dtype:
            spikes = spikes.to(model_dtype)
        initial_spikes = spikes

        for i, block in enumerate(self.blocks):
            block_cache = caches[i + 1]
            # Checkpointing is only used for stateless blocks (no cache to mutate twice on recompute).
            if self.training and self.gradient_checkpointing and not use_cache and block_cache is None:
                def _checkpointed_block(spk, cont, current_block=block):
                    return current_block(spk, cont, use_cache=False, cache=None)

                spikes, continuous = checkpoint(
                    _checkpointed_block, spikes, continuous,
                    use_reentrant=False,
                )
            else:
                spikes, continuous = block(
                    spikes, continuous,
                    use_cache=use_cache, cache=block_cache,
                )

        logits = self.output_head(continuous)

        # Optional local plasticity during evaluation on multi-token inputs.
        if not self.training and self.config.enable_stdp and initial_spikes.size(1) > 1:
            self.stdp.update(self, initial_spikes, spikes)

        return logits

    def _adaptive_temperature(self, raw_logits: torch.Tensor, base_temp: float) -> float:
        """Return ``max(0.1, base_temp - H * (base_temp - 0.1))``.

        ``H`` is the batch-mean entropy of ``softmax(raw_logits)`` divided by
        ``log(vocab_size)`` and clamped to [0, 1]. Low entropy keeps the
        temperature near ``base_temp``; high entropy moves it towards 0.1.
        """
        probs = F.softmax(raw_logits, dim=-1)
        p = probs.clamp(1e-7, 1.0)
        entropy = -(p * p.log()).sum(dim=-1).mean()
        max_entropy = math.log(self.config.vocab_size)
        norm_entropy = (entropy / max_entropy).clamp(0, 1)
        return max(0.1, base_temp - norm_entropy.item() * (base_temp - 0.1))

    def _sample_next_token(self, raw_logits: torch.Tensor, temperature: float,
                           top_k: int, adaptive_temperature: bool) -> torch.Tensor:
        """Sample one token per row from ``raw_logits`` [B, vocab]; returns [B, 1]."""
        temp = self._adaptive_temperature(raw_logits, temperature) if adaptive_temperature else temperature
        next_logits = raw_logits / temp
        if top_k > 0:
            k = min(top_k, next_logits.size(-1))  # topk raises if k > vocabulary size
            top_k_vals, _ = torch.topk(next_logits, k)
            min_top_k = top_k_vals[:, -1].unsqueeze(-1)
            next_logits[next_logits < min_top_k] = float('-inf')
        probs = F.softmax(next_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, prompt_ids: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 0.8, top_k: int = 50,
                 adaptive_temperature: bool = True,
                 eos_token_id: Optional[int] = 2) -> torch.Tensor:
        """
        Autoregressive sampling with cached incremental decoding.

        The prompt is processed once; afterwards each step feeds only the newly
        sampled token through the caches.

        Adaptive temperature (``adaptive_temperature=True``): the effective
        temperature is ``max(0.1, T - H * (T - 0.1))`` with ``H`` the normalized
        next-token entropy, i.e. confident (low-entropy) steps are sampled near
        ``T`` and uncertain (high-entropy) steps closer to 0.1.
        NOTE(review): this sharpens uncertain steps; confirm that direction is intended.

        Args:
            prompt_ids: [B, S] prompt token IDs.
            max_new_tokens: maximum number of tokens to sample; <= 0 returns a copy of the prompt.
            temperature: base temperature ``T``.
            top_k: keep only the ``top_k`` most likely tokens (<= 0 disables;
                clipped to the vocabulary size).
            adaptive_temperature: enable the entropy-based temperature above.
            eos_token_id: stop once every row has produced this ID; rows that
                finished early are padded with it. ``None`` disables early stopping.
        Returns:
            [B, S + n] token IDs with ``n <= max_new_tokens``.
        The module's train/eval mode is restored on return.
        """
        if max_new_tokens <= 0:
            return prompt_ids.clone()

        was_training = self.training
        self.eval()
        try:
            generated = prompt_ids.clone()
            past_key_values = [{} for _ in range(len(self.blocks) + 1)]

            logits = self.forward(prompt_ids, use_cache=True, past_key_values=past_key_values)
            finished = torch.zeros(logits.size(0), 1, dtype=torch.bool, device=logits.device)

            for step in range(max_new_tokens):
                if step > 0:
                    if eos_token_id is not None and bool(finished.all()):
                        break
                    logits = self.forward(next_token, use_cache=True, past_key_values=past_key_values)

                next_token = self._sample_next_token(
                    logits[:, -1, :], temperature, top_k, adaptive_temperature
                )
                if eos_token_id is not None:
                    # Rows that already emitted EOS keep emitting EOS (padding).
                    next_token = next_token.masked_fill(finished, eos_token_id)
                    finished = finished | (next_token == eos_token_id)
                generated = torch.cat([generated, next_token], dim=1)

            return generated
        finally:
            self.train(was_training)

    def get_sparsity_stats(self) -> dict:
        """Return the fraction of silent (zero) spikes after the encoder and each block.

        Uses a random [1, 32] token input on the model's device; blocks are fed
        the previous spikes as both spike and continuous input.
        """
        stats = {}
        device = self.output_head.output_proj.weight.device
        high = min(100, self.config.vocab_size)
        with torch.no_grad():
            dummy = torch.randint(0, high, (1, 32), device=device)
            spikes, _ = self.spike_encoder(dummy)
            stats['encoder_sparsity'] = 1.0 - spikes.mean().item()
            for i, block in enumerate(self.blocks):
                spikes, _ = block(spikes, spikes)
                stats[f'block_{i}_sparsity'] = 1.0 - spikes.mean().item()
        return stats
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: forward pass, streaming chunks, sparsity stats and generation.
    print("=" * 60)
    print(" SymbolicLight V1 model smoke test")
    print("=" * 60)

    config = SymbolicLightConfig(
        vocab_size=57344,
        embed_dim=768,
        n_layers=12,
        n_heads=12,
        head_dim=64,
    )
    # Draw dummy tokens from the configured vocabulary so the test stays valid
    # if vocab_size above is changed.
    vocab_size = config.vocab_size

    model = SymbolicLightModel(config)

    # 1) Full-sequence forward pass.
    dummy_input = torch.randint(0, vocab_size, (2, 128))
    print(f"\nInput: batch=2, seq_len=128")

    logits = model(dummy_input)
    print(f"Output logits: {logits.shape}")

    # 2) Two chunks sharing one streaming state.
    print(f"\nStreaming context test (2 chunks x 128 tokens)...")
    chunk1 = torch.randint(0, vocab_size, (2, 128))
    chunk2 = torch.randint(0, vocab_size, (2, 128))

    streaming_state = [{} for _ in range(len(model.blocks) + 1)]
    logits1 = model(chunk1, streaming_state=streaming_state)
    print(f" Chunk 1 logits: {logits1.shape}, streaming state saved [OK]")

    logits2 = model(chunk2, streaming_state=streaming_state)
    print(f" Chunk 2 logits: {logits2.shape}, cross-chunk memory passed [OK]")

    # 3) Spike sparsity per stage.
    stats = model.get_sparsity_stats()
    print(f"\nSparsity stats:")
    for k, v in stats.items():
        print(f" {k}: {v*100:.1f}% silent")

    # 4) Cached autoregressive generation.
    prompt = torch.randint(0, vocab_size, (1, 10))
    print(f"\nAutoregressive generation test (prompt=10, gen=20)...")
    output = model.generate(prompt, max_new_tokens=20)
    print(f"Generated sequence length: {output.shape[1]}")

    print("\n[PASS] SymbolicLight V1 smoke checks completed.")
    print(" [1] RoPE rotary position encoding [OK]")
    print(" [2] Cross-chunk state passing [OK]")
    print(" [3] BayesianHead dynamic prior [OK]")
    print(" [4] SpikeEncoder parallel scan [OK]")
    print(" [5] FrontendRouter multimodal stub [OK]")
|
|