|
|
| """
|
| SymbolicLight — Next-Generation Neuro-Symbolic Spiking Large Model Architecture
|
| ==================================================
|
| Key Innovations:
|
| 1. SparseTCAM: Spiking sparse routing replacing Self-Attention
|
| 2. LIF Neurons: Event-driven replacing dense activation
|
| 3. EntropyGate: On-demand compute depth (early exit for simple queries)
|
| 4. BayesianHead: Bayesian token selection replacing Softmax
|
| 5. STDP: Online learning during inference (no backward pass needed)
|
| """
|
| import math
|
| from dataclasses import dataclass
|
| from typing import Optional, Tuple
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class SymbolicLightConfig:
    """SymbolicLight 0.1B Default Configuration

    Bundles every hyperparameter needed by the model; defaults describe the
    ~0.1B-parameter reference configuration.
    """
    # --- Backbone dimensions (transformer-like) ---
    vocab_size: int = 32000        # tokenizer vocabulary size
    embed_dim: int = 768           # model width (D)
    n_layers: int = 12             # number of SymbolicLightBlock layers
    n_heads: int = 12              # heads used by SparseTCAM's causal mixing
    head_dim: int = 64             # per-head width; n_heads * head_dim should equal embed_dim
    intermediate_dim: int = 2048   # SpikingFeedForward hidden width
    max_seq_len: int = 2048        # positional-embedding table size / generate() window
    dropout: float = 0.1           # dropout prob used inside TCAM and FFN

    # --- Spiking / neuromorphic knobs ---
    spike_threshold: float = 1.0          # LIF membrane firing threshold
    leak_factor: float = 0.95             # multiplicative membrane decay per step
    stdp_lr: float = 0.01                 # STDP online-learning rate
    entropy_exit_threshold: float = 0.3   # early-exit when spike entropy drops below this
    enable_entropy_exit: bool = False     # opt-in: entropy-gated early exit (eval only)
    enable_stdp: bool = False             # opt-in: online STDP weight updates (eval only)
|
|
|
|
|
|
|
|
|
|
|
class SurrogateSpike(torch.autograd.Function):
    """
    Forward: Hard threshold -> 0/1 spikes (non-differentiable)
    Backward: Use derivative of sigmoid as surrogate gradient (differentiable)

    This is the key mathematical trick enabling SNNs to be trained with backpropagation!
    """
    # Steepness of the surrogate sigmoid; larger => closer to a true step derivative.
    sigma = 10.0

    @staticmethod
    def forward(ctx, membrane_potential, threshold):
        # Keep the Python float on ctx instead of wrapping it in a tensor:
        # the previous torch.tensor(threshold) always lived on the CPU with
        # default dtype, ignoring the input's device/dtype. save_for_backward
        # is only required for *tensor* inputs/outputs.
        ctx.save_for_backward(membrane_potential)
        ctx.threshold = threshold
        return (membrane_potential >= threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        (membrane_potential,) = ctx.saved_tensors

        # Surrogate: d(spike)/dv ≈ sigma * s * (1 - s), s = sigmoid(sigma*(v - thr)).
        x = SurrogateSpike.sigma * (membrane_potential - ctx.threshold)
        sigmoid_x = torch.sigmoid(x)
        surrogate_grad = SurrogateSpike.sigma * sigmoid_x * (1.0 - sigmoid_x)
        # Second slot is the gradient w.r.t. `threshold` (non-tensor -> None).
        return grad_output * surrogate_grad, None
|
|
|
|
|
def surrogate_spike(membrane_potential: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    """Functional wrapper around :class:`SurrogateSpike`.

    Emits hard 0/1 spikes in the forward pass while routing gradients
    through the sigmoid surrogate in the backward pass.
    """
    spikes = SurrogateSpike.apply(membrane_potential, threshold)
    return spikes
|
|
|
|
|
|
|
|
|
|
|
class SpikeEncoder(nn.Module):
    """
    Converts discrete token IDs into spatio-temporal spike tensors.

    Process: token_id -> Embedding -> LayerNorm -> LIF Spiking

    Sequence positions are consumed one at a time as LIF "time steps": the
    membrane potential integrates each position's normalized embedding,
    fires a 0/1 spike wherever it crosses the threshold, and is hard-reset
    where it fired.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.threshold = config.spike_threshold  # LIF firing threshold
        self.leak = config.leak_factor           # multiplicative membrane decay per step

        # Learned absolute positional embeddings (one row per position).
        self.pos_embedding = nn.Embedding(config.max_seq_len, config.embed_dim)

        # Membrane-potential state. Registered as a buffer (initially None) so
        # it is never treated as a parameter; it is re-allocated at the start
        # of every forward pass, so no state carries over between batches.
        self.register_buffer("v_mem", None)

    def _init_membrane(self, shape: torch.Size, device: torch.device):
        """Initialize/reset membrane potential"""
        self.v_mem = torch.zeros(shape, device=device)

    def forward(self, token_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            token_ids: [batch, seq_len]
        Returns:
            spikes: [batch, seq_len, embed_dim] Sparse 0/1 spikes
            continuous: [batch, seq_len, embed_dim] Continuous representation (used for residuals)
        """
        B, S = token_ids.shape
        positions = torch.arange(S, device=token_ids.device).unsqueeze(0)

        # Token + positional embedding, normalized before spike encoding.
        x = self.embedding(token_ids) + self.pos_embedding(positions)
        x = self.norm(x)

        # Fresh membrane per call: one [B, D] potential shared across steps.
        self._init_membrane((B, x.size(-1)), x.device)
        spikes_list = []

        for t in range(S):
            # Leak-and-integrate: decay previous potential, add current input.
            self.v_mem = self.v_mem * self.leak + x[:, t, :]
            # Fire: hard 0/1 threshold with surrogate gradient for backprop.
            spike = surrogate_spike(self.v_mem, self.threshold)
            # Hard reset: zero the potential only where a spike was emitted.
            self.v_mem = self.v_mem * (1.0 - spike)
            spikes_list.append(spike)

        spikes = torch.stack(spikes_list, dim=1)
        return spikes, x
|
|
|
|
|
|
|
|
|
|
|
class SparseTCAM(nn.Module):
    """
    Software simulation of the S100 Graph-TCAM in-memory compute unit.

    Contrast with Self-Attention:
    - Attention: QxK^T -> All-to-all O(n^2) dense matrix multiplication
    - SparseTCAM: Spikes x Weights -> Only activate weight rows hit by spikes -> O(n*k), k << n

    On GPU the "sparse read" is emulated with a spike mask; tokens are then
    mixed causally via a per-head running (cumulative) mean.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.embed_dim = config.embed_dim
        self.threshold = config.spike_threshold
        self.leak = config.leak_factor

        # TCAM lookup emulation: a single linear projection over spike vectors.
        self.tcam_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        self.out_proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, spikes: torch.Tensor, continuous: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            spikes: [B, S, D] Input spikes (sparse 0/1)
            continuous: [B, S, D] Continuous representation (for residual connections)
        Returns:
            out_spikes: [B, S, D] Output spikes
            out_continuous: [B, S, D] Updated continuous representation
        """
        batch, seq_len, dim = spikes.shape

        # Gate out tokens that emitted no spikes at all (simulated sparse read).
        row_gate = (spikes.sum(dim=-1, keepdim=True) > 0).float()
        looked_up = self.tcam_proj(spikes * row_gate)

        # Split into heads, then mix causally: position t sees the running
        # average of all projected tokens up to and including t.
        per_head = looked_up.view(batch, seq_len, self.n_heads, self.head_dim)
        running_sum = torch.cumsum(per_head, dim=1)
        step_counts = torch.arange(1, seq_len + 1, device=spikes.device).float().view(1, seq_len, 1, 1)
        mixed = (running_sum / step_counts).view(batch, seq_len, dim)

        projected = self.out_proj(self.dropout(mixed))

        # Residual + LayerNorm on the continuous path, then re-spike it.
        out_continuous = self.norm(continuous + projected)
        out_spikes = surrogate_spike(out_continuous, self.threshold)
        return out_spikes, out_continuous
|
|
|
|
|
|
|
|
|
|
|
class EntropyGate(nn.Module):
    """
    Innovation from S22 Entropy Engine: Calculate information entropy of current spike stream.
    Low entropy = Model is highly certain -> Can early exit, no need to run all layers.
    High entropy = Model is still confused -> Continue to deeper layers.

    Transformers lack this capability: regardless of query simplicity, all layers must execute.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.threshold = config.entropy_exit_threshold  # exit when entropy drops below this

        # NOTE(review): this linear gate is never used in forward(); it adds
        # trainable parameters with no effect. Either wire it in or remove it
        # (removal changes the state_dict, so it is only flagged here).
        self.gate = nn.Linear(config.embed_dim, 1)

    def forward(self, spikes: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """
        Args:
            spikes: [B, S, D] 0/1 spike tensor.
        Returns:
            entropy: Scalar tensor — mean binary entropy of the per-token
                firing rates, averaged over batch AND sequence (not per-batch).
            should_exit: True only in eval mode when entropy < threshold;
                always False during training so every layer receives gradients.
        """
        # Fraction of neurons firing at each (batch, position): [B, S].
        firing_rate = spikes.mean(dim=-1)

        # Binary entropy H(p) = -p*log(p) - (1-p)*log(1-p); clamp keeps the
        # logs finite at p = 0 or 1. Mean collapses everything to a scalar.
        p = firing_rate.clamp(1e-7, 1 - 1e-7)
        entropy = -(p * p.log() + (1 - p) * (1 - p).log()).mean()

        # .item() forces a host sync; acceptable here since the gate is
        # evaluated once per layer.
        should_exit = (entropy.item() < self.threshold) if not self.training else False
        return entropy, should_exit
|
|
|
|
|
|
|
|
|
|
|
class SpikingFeedForward(nn.Module):
    """
    Spiking replacement for the transformer two-layer MLP.

    The hidden layer fires hard 0/1 spikes (via the surrogate-gradient
    threshold) instead of applying a smooth GELU/ReLU nonlinearity.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.up = nn.Linear(config.embed_dim, config.intermediate_dim, bias=False)
        self.down = nn.Linear(config.intermediate_dim, config.embed_dim, bias=False)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.threshold = config.spike_threshold
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expand -> spike -> dropout -> project back, then add & norm.
        hidden_spikes = surrogate_spike(self.up(x), self.threshold)
        projected = self.down(self.dropout(hidden_spikes))
        return self.norm(x + projected)
|
|
|
|
|
|
|
|
|
|
|
class SymbolicLightBlock(nn.Module):
    """
    One complete SymbolicLight layer, composed of:
    - SparseTCAM (In-memory compute routing)
    - SpikingFeedForward (Spiking feed-forward)
    - EntropyGate (Entropy gating)
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.tcam = SparseTCAM(config)
        self.ffn = SpikingFeedForward(config)
        self.entropy_gate = EntropyGate(config)

    def forward(self, spikes: torch.Tensor, continuous: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        # Route through the TCAM. Its spike output is discarded on purpose:
        # the stream is re-spiked after the feed-forward stage below.
        _, routed = self.tcam(spikes, continuous)
        refined = self.ffn(routed)
        fresh_spikes = surrogate_spike(refined, self.tcam.threshold)
        # Entropy of the fresh spike stream decides whether deeper layers run.
        _, should_exit = self.entropy_gate(fresh_spikes)
        return fresh_spikes, refined, should_exit
|
|
|
|
|
|
|
|
|
|
|
class BayesianHead(nn.Module):
    """
    Innovation from S100 LALU array: Use Bayesian posterior replacing Softmax.

    Softmax: P(word) = exp(logit) / Σexp(logits) <- Brutal normalization
    Bayesian: P(word|context) ∝ P(context|word) x P(word) <- Exact inference

    In V1, we approximate Bayesian updates using addition in the log domain.
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        # Log-likelihood projection over the vocabulary.
        self.output_proj = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        # Learnable log-prior over tokens, initialized flat (all zeros).
        self.log_prior = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, continuous: torch.Tensor) -> torch.Tensor:
        """
        Args:
            continuous: [B, S, D]
        Returns:
            logits: [B, S, vocab_size] (Log probabilities, can be directly trained with CrossEntropy)
        """
        # log P(context|word) + log P(word); the [vocab]-shaped prior
        # broadcasts across the batch and sequence dimensions.
        log_likelihood = self.output_proj(continuous)
        return log_likelihood + self.log_prior
|
|
|
|
|
|
|
|
|
|
|
class STDPUpdater:
    """
    Innovation from S100 ILE Inductive Learning Engine.

    Activated ONLY during inference (model.eval()).
    No loss.backward() required, purely local learning rules based on spike timing.
    """
    def __init__(self, config: SymbolicLightConfig):
        self.lr = config.stdp_lr            # STDP learning rate
        self.enabled = config.enable_stdp   # update() is a no-op unless enabled

    @torch.no_grad()
    def update(self, model: nn.Module, pre_spikes: torch.Tensor, post_spikes: torch.Tensor):
        """
        STDP Rules:
        - Pre-synaptic fires first -> Strengthen connection (LTP)
        - Post-synaptic fires first -> Weaken connection (LTD)

        Args:
            model: must expose `model.blocks`, each with `.tcam.tcam_proj`.
            pre_spikes / post_spikes: [B, S, D] spike tensors.

        NOTE(review): despite the LTP/LTD description, this V1 implementation
        never compares spike *timing*; see the inline notes below.
        """
        if not self.enabled:
            return

        # NOTE(review): "causal" only checks co-activity (both streams fired
        # somewhere along the sequence), not which stream fired first.
        causal = (pre_spikes.sum(dim=1, keepdim=True) > 0) & (post_spikes.sum(dim=1, keepdim=True) > 0)

        if causal.any():
            # The identical delta is applied to every block's TCAM projection.
            for block in model.blocks:
                w = block.tcam.tcam_proj.weight

                # NOTE(review): mean over (batch, seq) yields two [D] vectors,
                # so this inner product collapses to a single scalar — every
                # weight entry gets the same nudge (further scaled by 0.001).
                # It is also loop-invariant and could be hoisted out.
                delta = self.lr * (pre_spikes.mean(dim=(0, 1)) @ post_spikes.mean(dim=(0, 1)).unsqueeze(-1))
                w.data += delta.squeeze() * 0.001
                # Keep weights bounded under repeated online updates.
                w.data.clamp_(-5, 5)
|
|
|
|
|
|
|
|
|
|
|
class SymbolicLightModel(nn.Module):
    """
    SymbolicLight: Next-generation Neuro-Symbolic Spiking Large Model

    Usage:
        config = SymbolicLightConfig()
        model = SymbolicLightModel(config)

        # Training
        logits = model(token_ids)
        loss = F.cross_entropy(logits.view(-1, config.vocab_size), targets.view(-1))

        # Inference (Autoregressive generation)
        output_ids = model.generate(prompt_ids, max_new_tokens=100)
    """
    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.config = config
        self.spike_encoder = SpikeEncoder(config)
        self.blocks = nn.ModuleList([
            SymbolicLightBlock(config) for _ in range(config.n_layers)
        ])
        self.output_head = BayesianHead(config)
        self.stdp = STDPUpdater(config)

        # GPT-style N(0, 0.02) initialization for all Linear/Embedding weights.
        self.apply(self._init_weights)

        n_params = sum(p.numel() for p in self.parameters())
        print(f"[SymbolicLight] Model initialization complete | Parameters: {n_params/1e6:.1f}M ({n_params/1e9:.3f}B)")

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights to N(0, 0.02); zero any bias."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation (Universal for training + inference)
        Args:
            token_ids: [batch, seq_len]
        Returns:
            logits: [batch, seq_len, vocab_size]
        """
        spikes, continuous = self.spike_encoder(token_ids)
        initial_spikes = spikes  # kept as the pre-synaptic stream for STDP

        for block in self.blocks:
            spikes, continuous, should_exit = block(spikes, continuous)

            # Entropy-gated early exit: eval-only and opt-in, so training
            # always runs (and therefore trains) every layer.
            if should_exit and not self.training and self.config.enable_entropy_exit:
                break

        logits = self.output_head(continuous)

        # Optional online STDP update (inference only, no backward pass).
        if not self.training and self.config.enable_stdp:
            self.stdp.update(self, initial_spikes, spikes)

        return logits

    @torch.no_grad()
    def generate(self, prompt_ids: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 0.8, top_k: int = 50,
                 eos_token_id: int = 2) -> torch.Tensor:
        """
        Autoregressive text generation

        Args:
            prompt_ids: [1, prompt_len] Prompt token IDs (batch size must be 1)
            max_new_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_k: Sample only from the top k highest probability tokens
                (0 disables filtering; values above vocab_size are clamped)
            eos_token_id: Stop early once this token is sampled. Defaults to 2,
                the value previously hard-coded here.
        Returns:
            Generated complete token sequence
        """
        self.eval()
        generated = prompt_ids.clone()

        for _ in range(max_new_tokens):
            # Sliding context window: keep only the latest max_seq_len tokens.
            input_ids = generated[:, -self.config.max_seq_len:]
            logits = self.forward(input_ids)

            # Temperature-scaled distribution over the next token.
            next_logits = logits[:, -1, :] / temperature

            if top_k > 0:
                # Clamp k so an oversized top_k cannot crash torch.topk.
                k = min(top_k, next_logits.size(-1))
                top_k_vals, _ = torch.topk(next_logits, k)
                min_top_k = top_k_vals[:, -1].unsqueeze(-1)
                next_logits[next_logits < min_top_k] = float('-inf')

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)

            # .item() relies on the documented batch-size-1 contract.
            if next_token.item() == eos_token_id:
                break

        return generated

    def get_sparsity_stats(self) -> dict:
        """Returns model sparsity statistics (for papers and debugging).

        Runs a random [1, 32] token probe through the encoder and each block,
        reporting the fraction of silent (zero) spike entries per stage.
        """
        stats = {}
        with torch.no_grad():
            # Allocate the probe on the model's own device so this also works
            # after model.to("cuda"); previously it was always created on CPU.
            device = next(self.parameters()).device
            dummy = torch.randint(0, 100, (1, 32), device=device)
            spikes, _ = self.spike_encoder(dummy)
            stats['encoder_sparsity'] = 1.0 - spikes.mean().item()
            for i, block in enumerate(self.blocks):
                # NOTE(review): the spike tensor doubles as the continuous
                # input here — a rough probe, preserved from the original.
                spikes, _, _ = block(spikes, spikes)
                stats[f'block_{i}_sparsity'] = 1.0 - spikes.mean().item()
        return stats
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the 0.1B config, run one forward pass, probe the
    # per-stage sparsity, then generate a short continuation.
    banner = "=" * 60
    print(banner)
    print(" SymbolicLight Model Architecture Validation")
    print(banner)

    config = SymbolicLightConfig(
        vocab_size=32000,
        embed_dim=768,
        n_layers=12,
        n_heads=12,
        head_dim=64,
        intermediate_dim=2048,
    )
    model = SymbolicLightModel(config)

    # Forward pass on a random batch.
    dummy_input = torch.randint(0, 32000, (2, 128))
    print(f"\nInput: batch=2, seq_len=128")
    logits = model(dummy_input)
    print(f"Output logits: {logits.shape}")

    # Sparsity report (fraction of silent spike entries per stage).
    stats = model.get_sparsity_stats()
    print(f"\nSparsity analysis:")
    for k, v in stats.items():
        print(f"  {k}: {v*100:.1f}% Silent")

    # Short autoregressive generation run.
    prompt = torch.randint(0, 32000, (1, 10))
    print(f"\nAutoregressive generation test (prompt length=10, generating 20 tokens)...")
    output = model.generate(prompt, max_new_tokens=20)
    print(f"Generated sequence length: {output.shape[1]}")

    print("\n[PASS] SymbolicLight model architecture verified!")
|
|
|