# Source: SymbolicLight-V1 / src/model.py
# Uploaded by symboliclight-ai: "Upload SymbolicLight V1 open weights" (commit 5762a7c, verified)
#!/usr/bin/env python3
"""SymbolicLight V1 model implementation."""
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class SymbolicLightConfig:
    """Default configuration for SymbolicLight V1.

    Every field is a plain hyper-parameter read by the modules in this file.
    ``embed_dim`` must equal ``n_heads * head_dim``: the mixer and attention
    paths reshape the embedding axis into ``(n_heads, head_dim)``.
    """
    # Number of rows in the token embedding table and width of the output logits.
    vocab_size: int = 57344
    # Width of the residual stream and of every spike tensor.
    embed_dim: int = 1536
    # Number of stacked SymbolicLightBlock layers.
    n_layers: int = 22
    # Head count and per-head width used by the decay path and the attention path.
    n_heads: int = 24
    head_dim: int = 64
    # Hidden width of the spiking feed-forward block.
    intermediate_dim: int = 6144
    # Side length of the attention mask that SparseLocalAttention precomputes.
    max_seq_len: int = 512
    # Number of time steps handed to the LIF scan per call in SpikeEncoder.
    spike_chunk_size: int = 64
    # Dropout probability used inside the mixer and feed-forward blocks.
    dropout: float = 0.1
    # Firing threshold shared by the encoder, the mixer and the feed-forward block.
    spike_threshold: float = 1.0
    # Multiplicative membrane leak applied at every step of the LIF scan.
    leak_factor: float = 0.95
    # Step size of the optional STDP weight update.
    stdp_lr: float = 0.01
    # When True, eval-mode forward passes apply the STDP update to the model weights.
    enable_stdp: bool = False
    # Base of the rotary position embedding frequencies.
    rope_theta: float = 10000.0
    # NOTE(review): not read anywhere in this file; presumably selects the
    # input frontend - confirm against callers.
    frontend_mode: str = "text"
    # Maximum query-key distance (inclusive) kept by the local attention mask.
    sparse_attn_window: int = 512
    # Number of leading key positions every later query may always attend to.
    n_global_anchors: int = 4
    # Enables the attention path next to the decay path in SparseTCAM.
    enable_sparse_attn: bool = True
    # Selects the context-conditioned prior network in BayesianHead
    # (False selects a static learned prior vector).
    enable_dynamic_prior: bool = True
    # When True, SpikeEncoder emits a per-position top-k mask instead of running the LIF scan.
    use_topk_mask: bool = False
    # Fraction of features left silent by the top-k mask: k = (1 - topk_sparsity) * embed_dim.
    topk_sparsity: float = 0.89
class ATanSurrogate(torch.autograd.Function):
    """Hard threshold in the forward pass, smooth surrogate gradient in the backward pass.

    Forward emits 1.0 where ``membrane_potential >= threshold`` and 0.0
    elsewhere (always float32). Backward scales the incoming gradient by
    ``1 / (1 + (2 * (v - threshold)) ** 2)``; the threshold itself receives
    no gradient.
    """

    @staticmethod
    def forward(ctx, membrane_potential, threshold):
        # The threshold is stored as a tensor so that it can travel through
        # save_for_backward together with the potential.
        threshold_t = torch.tensor(
            threshold,
            device=membrane_potential.device,
            dtype=membrane_potential.dtype,
        )
        ctx.save_for_backward(membrane_potential, threshold_t)
        fired = membrane_potential >= threshold
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        potential, threshold_t = ctx.saved_tensors
        sharpness = 2.0
        scaled_gap = sharpness * (potential - threshold_t)
        surrogate = 1.0 / (1.0 + scaled_gap ** 2)
        # Second slot corresponds to the (non-differentiable) threshold argument.
        return grad_output * surrogate, None
def surrogate_spike(membrane_potential: torch.Tensor, threshold: float = 1.0) -> torch.Tensor:
    """Threshold ``membrane_potential`` into 0/1 spikes via ``ATanSurrogate``.

    Args:
        membrane_potential: tensor of potentials to threshold.
        threshold: firing threshold.
    Returns:
        float32 tensor of the same shape holding 1.0 where the potential
        reached the threshold and 0.0 elsewhere.
    """
    spike_fn = ATanSurrogate.apply
    return spike_fn(membrane_potential, threshold)
class RotaryPositionEncoding(nn.Module):
    """Rotary position embedding over the last axis of a ``[B, S, D]`` tensor."""

    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        # One inverse frequency per pair of channels: theta ** (-2i / dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents))

    def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """
        Rotate ``x`` by position-dependent angles.
        Args:
            x: [B, S, D] input
            offset: absolute position of the first row of ``x``
                (non-zero during incremental decoding)
        Returns:
            [B, S, D] rotated tensor; the angles are computed in float32
        """
        _, seq_len, width = x.shape
        positions = torch.arange(
            offset, offset + seq_len, device=x.device, dtype=torch.float32
        )
        angles = torch.outer(positions, self.inv_freq.to(x.device))
        # The same angle table is used for both halves of the channel axis.
        angles = torch.cat([angles, angles], dim=-1).unsqueeze(0)
        half = width // 2
        lower, upper = x[..., :half], x[..., half:]
        swapped = torch.cat([-upper, lower], dim=-1)
        return x * angles.cos() + swapped * angles.sin()
class FrontendRouter(nn.Module):
    """Input frontend; only the text embedding path is implemented."""

    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.config = config
        self.text_embedding = nn.Embedding(config.vocab_size, config.embed_dim)

    def forward(self, token_ids: torch.Tensor, modality: str = "text") -> torch.Tensor:
        """Embed ``token_ids`` for the requested modality.

        Raises:
            NotImplementedError: for the "vision" and "audio" modalities.
            ValueError: for any other modality that is not "text".
        """
        # Reject everything that is not text first, then fall through.
        if modality == "vision":
            raise NotImplementedError("Vision frontend is not included in this release.")
        if modality == "audio":
            raise NotImplementedError("Audio frontend is not included in this release.")
        if modality != "text":
            raise ValueError(f"Unknown modality: {modality}")
        return self.text_embedding(token_ids)
def _lif_scan_forward(x: torch.Tensor, v_mem: torch.Tensor,
                      leak: float, threshold: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sequential leaky integrate-and-fire scan along the time axis of ``x``.

    Per step: leak the membrane, add the input, clamp to [-3, 3], emit a spike
    where the threshold is reached, and reset the membrane to zero where a
    spike was emitted.

    Args:
        x: [B, S, D] input current
        v_mem: [B, D] membrane potential before the first step
        leak: multiplicative leak per step
        threshold: firing threshold
    Returns:
        (spikes [B,S,D], final_v_mem [B,D], all_v_mem [B,S,D]) where
        ``all_v_mem`` holds the clamped pre-reset potentials used by the
        surrogate gradient.
    """
    _, n_steps, _ = x.shape
    spikes = torch.empty_like(x)
    all_v = torch.empty_like(x)
    step = 0
    while step < n_steps:
        v_mem = torch.clamp(v_mem * leak + x[:, step, :], -3.0, 3.0)
        # Record the potential before the reset; that is what backward needs.
        all_v[:, step, :] = v_mem
        fired = (v_mem >= threshold).float()
        spikes[:, step, :] = fired
        # Hard reset: neurons that fired restart from zero.
        v_mem = v_mem * (1.0 - fired)
        step += 1
    return spikes, v_mem, all_v
class LIFScan(torch.autograd.Function):
    """Autograd wrapper around the LIF scan.

    Backward applies the surrogate ``1 / (1 + (2 * (v - threshold)) ** 2)``
    independently at each time step; the gradient arriving for the final
    membrane potential is ignored and only ``x`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, x, v_mem, leak, threshold):
        scan_result = _lif_scan_forward(x, v_mem, leak, threshold)
        spikes, final_v, potentials = scan_result
        ctx.threshold = threshold
        ctx.save_for_backward(potentials)
        return spikes, final_v

    @staticmethod
    def backward(ctx, grad_spikes, grad_v_mem):
        (potentials,) = ctx.saved_tensors
        sharpness = 2.0
        scaled_gap = sharpness * (potentials - ctx.threshold)
        surrogate_grad = 1.0 / (1.0 + scaled_gap ** 2)
        # One slot per forward argument: x, v_mem, leak, threshold.
        return grad_spikes * surrogate_grad, None, None, None
class SpikeEncoder(nn.Module):
    """
    Convert discrete token IDs into spike tensors plus a continuous stream.

    Flow: token_id -> FrontendRouter -> LayerNorm -> spike conversion.
    Two spike conversions are available:
    - default: a leaky integrate-and-fire (LIF) scan along the sequence axis,
      run in chunks of ``config.spike_chunk_size`` steps;
    - ``config.use_topk_mask``: a per-position top-k mask over ``|x|``.

    No positional embedding is added here; RoPE is applied later inside the
    attention path.
    """

    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.config = config
        self.frontend = FrontendRouter(config)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.threshold = config.spike_threshold
        self.leak = config.leak_factor
        # Membrane potential left behind by the most recent forward call, [B, D].
        self.v_mem = None

    def _init_membrane(self, shape: torch.Size, device: torch.device):
        """Initialize or reset the membrane potential to zeros."""
        self.v_mem = torch.zeros(shape, device=device)

    def _topk_spikes(self, x: torch.Tensor) -> torch.Tensor:
        """Return a 0/1 mask selecting the k largest ``|x|`` features per position."""
        k = max(1, int((1.0 - self.config.topk_sparsity) * self.config.embed_dim))
        # torch.topk rejects k larger than the axis it selects from.
        k = min(k, x.size(-1))
        _, topk_indices = torch.topk(x.abs(), k, dim=-1)
        hard = torch.zeros_like(x)
        hard.scatter_(-1, topk_indices, 1.0)
        if not self.training:
            return hard
        # Straight-through estimator: the forward value is the top-k mask,
        # the gradient is the one of the surrogate spike. The differentiable
        # term has to be the one left outside detach(); with the operands the
        # other way round the mask is replaced by the threshold spikes and no
        # gradient reaches x.
        soft = surrogate_spike(x, self.threshold).to(hard.dtype)
        return soft + (hard - soft).detach()

    def _lif_spikes(self, x: torch.Tensor) -> torch.Tensor:
        """Run the LIF scan over ``x`` chunk by chunk, carrying ``self.v_mem`` along."""
        chunk_size = self.config.spike_chunk_size
        pieces = []
        for start in range(0, x.size(1), chunk_size):
            chunk = x[:, start:start + chunk_size, :]
            chunk_spikes, self.v_mem = LIFScan.apply(
                chunk, self.v_mem, self.leak, self.threshold
            )
            pieces.append(chunk_spikes)
            if self.training:
                # Cut the graph between chunks.
                self.v_mem = self.v_mem.detach()
        if not pieces:
            # Zero-length sequence: nothing to scan, and torch.cat rejects an empty list.
            return torch.zeros_like(x)
        return torch.cat(pieces, dim=1)

    def forward(self, token_ids: torch.Tensor, use_cache: bool = False,
                cache: Optional[dict] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            token_ids: [batch, seq_len]
            use_cache: if True and ``cache`` is given, the membrane potential
                is read from and written back to ``cache`` so that successive
                calls continue the same scan
            cache: dictionary holding ``'v_mem'`` and ``'seq_len'``; missing
                keys are created
        Returns:
            spikes: [batch, seq_len, embed_dim] 0/1 spikes
            continuous: [batch, seq_len, embed_dim] normalized embeddings
        """
        B, S = token_ids.shape
        caching = use_cache and cache is not None
        if caching:
            if 'v_mem' not in cache:
                cache['v_mem'] = torch.zeros(B, self.config.embed_dim, device=token_ids.device)
            if 'seq_len' not in cache:
                cache['seq_len'] = 0
            self.v_mem = cache['v_mem']
            cache['seq_len'] += S
        else:
            # Without a cache every call starts from a resting membrane.
            self._init_membrane((B, self.config.embed_dim), token_ids.device)

        x = self.norm(self.frontend(token_ids))

        if getattr(self.config, 'use_topk_mask', False):
            spikes = self._topk_spikes(x)
        else:
            spikes = self._lif_spikes(x)

        if caching:
            cache['v_mem'] = self.v_mem.detach()
        return spikes, x
class SparseLocalAttention(nn.Module):
    """
    Multi-head attention restricted by a causal local window, global anchors and spikes.

    A query at position q may attend to a key at position k when
    - k <= q (causal), and
    - q - k <= window_size, or k is one of the first ``n_global_anchors``
      positions, and
    - the key position fired a spike.
    Queries left without any admissible key, and queries that did not fire
    themselves, produce zeros.

    The restriction is realized as a boolean mask over a dense score matrix;
    the projections and the score computation still cover every position.
    SparseTCAM blends this path with its decay path through a learned gate.
    """

    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.window_size = max(1, int(config.sparse_attn_window))
        self.n_global_anchors = config.n_global_anchors
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = config.head_dim ** -0.5
        # Fall back to an explicit softmax when the fused kernel is unavailable.
        self._use_sdpa = hasattr(F, "scaled_dot_product_attention")
        width = config.embed_dim
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.rope = RotaryPositionEncoding(config.head_dim, theta=config.rope_theta)
        # Mask for the common case of a full-length sequence starting at position 0.
        full_len = config.max_seq_len
        self.register_buffer('_cached_mask', self._build_mask(0, full_len, full_len, None))

    def _build_mask(self, q_start, q_len, kv_len, device):
        """Boolean [q_len, kv_len] mask: causal and (inside the window or an anchor key)."""
        q_pos = torch.arange(q_start, q_start + q_len, device=device).unsqueeze(1)
        k_pos = torch.arange(0, kv_len, device=device).unsqueeze(0)
        gap = q_pos - k_pos
        reachable = (gap <= self.window_size) | (k_pos < self.n_global_anchors)
        return (gap >= 0) & reachable

    def forward(self, x: torch.Tensor, spike_mask: torch.Tensor,
                offset: int = 0, use_cache: bool = False, cache: dict = None) -> torch.Tensor:
        """
        Args:
            x: [B, S_q, D] continuous representation; RoPE is applied internally to Q/K
            spike_mask: [B, S_q] boolean mask, True means the position fired a spike
            offset: absolute position of the first query (RoPE and mask offset)
            use_cache: whether to extend and reuse the K/V cache
            cache: dictionary holding 'K', 'V' and 'spike_mask' of earlier positions
        Returns:
            attn_out: [B, S_q, D] attention output with zeros on inactive positions
        """
        B, S_q, D = x.shape
        H, Hd = self.n_heads, self.head_dim

        def split_heads(t):
            return t.view(B, S_q, H, Hd).transpose(1, 2)

        Q = split_heads(self.q_proj(x))
        K = split_heads(self.k_proj(x))
        V = split_heads(self.v_proj(x))

        def rotate(t):
            # RoPE works on [B*H, S, Hd] and computes in float32; cast back afterwards.
            flat = t.contiguous().view(B * H, S_q, Hd)
            return self.rope(flat, offset=offset).view(B, H, S_q, Hd).to(V.dtype)

        Q = rotate(Q)
        K = rotate(K)

        spike_mask_kv = spike_mask
        if use_cache and cache is not None:
            if 'K' in cache:
                # Keys are cached after rotation, so they are simply appended.
                K = torch.cat([cache['K'], K], dim=2)
                V = torch.cat([cache['V'], V], dim=2)
                spike_mask_kv = torch.cat([cache['spike_mask'], spike_mask], dim=1)
            cache['K'] = K.detach()
            cache['V'] = V.detach()
            cache['spike_mask'] = spike_mask_kv.detach()

        S_kv = K.size(2)
        matches_buffer = (
            offset == 0 and S_q == S_kv and S_q == self._cached_mask.size(0)
        )
        if matches_buffer:
            attn_mask = self._cached_mask
        else:
            attn_mask = self._build_mask(offset, S_q, S_kv, x.device)

        # [1, 1, S_q, S_kv] & [B, 1, 1, S_kv]: keys that did not fire are hidden.
        full_mask = attn_mask[None, None] & spike_mask_kv[:, None, None, :]
        query_has_any_key = full_mask.any(dim=-1, keepdim=True)
        no_key = ~query_has_any_key

        if self._use_sdpa:
            # Rows without any key are opened up completely so that the
            # softmax stays finite; their output is zeroed right after.
            attn_out = F.scaled_dot_product_attention(
                Q, K, V, attn_mask=full_mask | no_key, dropout_p=0.0
            )
            attn_out = attn_out.masked_fill(no_key, 0.0)
        else:
            scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
            scores = scores.masked_fill(~full_mask, float('-inf'))
            scores = scores.masked_fill(no_key, 0.0)
            attn_weights = F.softmax(scores, dim=-1).to(V.dtype)
            attn_weights = attn_weights.masked_fill(no_key, 0.0)
            attn_out = torch.matmul(attn_weights, V)

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, S_q, D)
        # Queries that did not fire contribute nothing.
        return attn_out * spike_mask.unsqueeze(-1).to(dtype=attn_out.dtype)
class SparseTCAM(nn.Module):
    """Dual-path spike-gated sequence mixer.

    Decay path: a per-head exponential moving average of the projected
    spikes, carried across calls through ``cache['h']``.
    Attention path (optional): SparseLocalAttention over the continuous
    stream. A learned scalar gate blends both paths before the output
    projection, residual connection and LayerNorm.
    """

    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.embed_dim = config.embed_dim
        self.threshold = config.spike_threshold
        self.leak = config.leak_factor
        self.enable_sparse_attn = config.enable_sparse_attn
        width = config.embed_dim
        self.tcam_proj = nn.Linear(width, width, bias=False)
        self.out_proj = nn.Linear(width, width, bias=False)
        self.norm = nn.LayerNorm(width)
        self.dropout = nn.Dropout(config.dropout)
        # Per-head decay logit; sigmoid(3.0) is roughly 0.95.
        self.decay_raw = nn.Parameter(torch.full((config.n_heads,), 3.0))
        if self.enable_sparse_attn:
            self.sparse_attn = SparseLocalAttention(config)
            # sigmoid(0) = 0.5: both paths start with equal weight.
            self.attn_gate = nn.Parameter(torch.zeros(1))

    def forward(self, spikes: torch.Tensor, continuous: torch.Tensor,
                use_cache: bool = False, cache: dict = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            spikes: [B, S, D] spike tensor from the previous stage
            continuous: [B, S, D] continuous residual stream
            use_cache: enables the RoPE offset, the K/V cache and the
                single-step recurrence used for incremental decoding
            cache: per-layer state dictionary; the decay state ``'h'`` is read
                and written whenever a cache is supplied, even when
                ``use_cache`` is False
        Returns:
            (out_spikes, out_continuous), both [B, S, D]
        """
        B, S, D = spikes.shape
        H, Hd = self.n_heads, self.head_dim
        compute_dtype = continuous.dtype
        device = spikes.device
        if spikes.dtype != compute_dtype:
            spikes = spikes.to(compute_dtype)

        decoding = use_cache and cache is not None
        offset = 0
        if decoding:
            offset = cache.get('rope_offset', 0)
            cache['rope_offset'] = offset + S

        # Positions where at least one feature fired.
        fired = spikes.sum(dim=-1) > 0
        gated_spikes = spikes * fired.unsqueeze(-1).to(dtype=compute_dtype)
        tcam_out = self.tcam_proj(gated_spikes).view(B, S, H, Hd)
        decay = torch.sigmoid(self.decay_raw)

        if cache is None:
            h = torch.zeros(B, H, Hd, device=device, dtype=compute_dtype)
        else:
            if 'h' not in cache:
                cache['h'] = torch.zeros(B, H, Hd, device=device, dtype=compute_dtype)
            h = cache['h']

        if decoding and S == 1:
            # Single-step recurrence: h <- decay * h + (1 - decay) * input.
            per_head = decay.view(1, H, 1)
            h = per_head * h + (1 - per_head) * tcam_out[:, 0]
            cache['h'] = h.detach()
            context = h.unsqueeze(1)
        else:
            # The same recurrence for all S steps at once, as a causal grouped
            # convolution with kernel (1 - decay) * decay ** lag per head.
            lags = torch.arange(S - 1, -1, -1, dtype=compute_dtype, device=device)
            kernel = (decay.view(-1, 1) ** lags.view(1, -1)) * (1 - decay).view(-1, 1)
            kernel = kernel.unsqueeze(1)
            series = tcam_out.permute(0, 3, 2, 1).reshape(-1, H, S)
            series = F.pad(series, (S - 1, 0))
            filtered = F.conv1d(series, kernel, groups=H)
            context = filtered.view(-1, Hd, H, S).permute(0, 3, 2, 1)
            # Contribution of the carried-in state: decay ** (t + 1) * h.
            steps = torch.arange(1, S + 1, dtype=compute_dtype, device=device).view(1, S, 1, 1)
            decay_t = decay.view(1, 1, H, 1) ** steps
            context = context + h.unsqueeze(1) * decay_t
            if cache is not None:
                cache['h'] = context[:, -1, :, :].detach()

        decay_output = context.reshape(B, S, D)

        if self.enable_sparse_attn:
            attn_cache = None if cache is None else cache.setdefault('attn', {})
            attn_output = self.sparse_attn(
                continuous, fired, offset=offset,
                use_cache=use_cache, cache=attn_cache
            )
            gate = torch.sigmoid(self.attn_gate)
            mixed = gate * attn_output + (1 - gate) * decay_output
        else:
            mixed = decay_output

        projected = self.out_proj(self.dropout(mixed))
        out_continuous = self.norm(continuous + projected)
        out_spikes = surrogate_spike(out_continuous, self.threshold).to(out_continuous.dtype)
        return out_spikes, out_continuous
class SpikingFeedForward(nn.Module):
    """
    Two-layer feed-forward block whose hidden activation is a 0/1 spike.

    up-projection -> surrogate spike -> dropout -> down-projection,
    followed by a residual connection and LayerNorm.
    """

    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        width, hidden = config.embed_dim, config.intermediate_dim
        self.up = nn.Linear(width, hidden, bias=False)
        self.down = nn.Linear(hidden, width, bias=False)
        self.norm = nn.LayerNorm(width)
        self.threshold = config.spike_threshold
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The spike function always returns float32; cast back to the input dtype.
        hidden_spikes = surrogate_spike(self.up(x), self.threshold).to(x.dtype)
        update = self.down(self.dropout(hidden_spikes))
        return self.norm(x + update)
class SymbolicLightBlock(nn.Module):
    """One model layer: a SparseTCAM mixer followed by a SpikingFeedForward block."""

    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.tcam = SparseTCAM(config)
        self.ffn = SpikingFeedForward(config)

    def forward(self, spikes, continuous, use_cache=False, cache=None):
        """Return ``(spikes, continuous)`` after the mixer and the feed-forward block."""
        # The mixer's own spikes are not used: the block's spikes are
        # re-derived from the stream after the feed-forward block.
        _, mixed = self.tcam(spikes, continuous, use_cache=use_cache, cache=cache)
        refined = self.ffn(mixed)
        fired = surrogate_spike(refined, self.tcam.threshold)
        return fired.to(refined.dtype), refined
class BayesianHead(nn.Module):
    """
    Output head that adds a weighted prior term to the projected logits.

    logits = output_proj(context) + prior_weight * prior

    With ``enable_dynamic_prior`` the prior is predicted from the same context
    by a bottleneck network (embed_dim -> embed_dim // 4 -> vocab_size, GELU
    in between); otherwise it is a single learned vector over the vocabulary.
    The docstring of earlier versions framed the two terms as a likelihood and
    a prior; no normalization is applied to either term here.
    """

    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        width, vocab = config.embed_dim, config.vocab_size
        self.output_proj = nn.Linear(width, vocab, bias=False)
        self.use_dynamic_prior = getattr(config, 'enable_dynamic_prior', True)
        self.prior_weight = nn.Parameter(torch.tensor(0.1))
        if not self.use_dynamic_prior:
            self.log_prior = nn.Parameter(torch.zeros(vocab))
            return
        bottleneck = width // 4
        self.prior_net = nn.Sequential(
            nn.Linear(width, bottleneck, bias=False),
            nn.GELU(),
            nn.Linear(bottleneck, vocab, bias=False),
        )

    def forward(self, continuous: torch.Tensor) -> torch.Tensor:
        """
        Args:
            continuous: [B, S, D]
        Returns:
            logits: [B, S, vocab_size]
        """
        likelihood_term = self.output_proj(continuous)
        if self.use_dynamic_prior:
            prior_term = self.prior_net(continuous)
        else:
            # The static prior broadcasts over batch and sequence.
            prior_term = self.log_prior
        return likelihood_term + self.prior_weight * prior_term
class STDPUpdater:
    """Optional local co-firing update applied to every block's ``tcam_proj`` weight."""

    def __init__(self, config: SymbolicLightConfig):
        self.lr = config.stdp_lr
        self.enabled = config.enable_stdp

    @torch.no_grad()
    def update(self, model: nn.Module, pre_spikes: torch.Tensor, post_spikes: torch.Tensor):
        """Strengthen weights between features that fired at the same positions.

        Args:
            model: object exposing ``blocks``, each with ``tcam.tcam_proj.weight``
            pre_spikes: [B, S, D] spikes entering the block stack
            post_spikes: [B, S, D] spikes leaving the block stack
        Does nothing when the updater is disabled or when no feature fired in
        both tensors. Weights are modified in place and clamped to [-5, 5].
        """
        if not self.enabled:
            return
        # Features that fired at least once in both tensors (per batch element).
        pre_any = pre_spikes.sum(dim=1, keepdim=True) > 0
        post_any = post_spikes.sum(dim=1, keepdim=True) > 0
        if not (pre_any & post_any).any():
            return

        # The increment depends only on the spike tensors, so it is computed
        # once and reused for every block instead of once per block.
        pre_active = (pre_spikes > 0).float()
        post_active = (post_spikes > 0).float()
        # co_firing[d, e]: number of positions where post feature d and pre
        # feature e fired together; laid out like a Linear weight [out, in].
        co_firing = torch.einsum('bsd,bse->de', post_active, pre_active)
        delta = self.lr * co_firing / (pre_spikes.size(0) * pre_spikes.size(1))
        mask = (co_firing > 0).float()
        increment = delta * mask * 0.05

        for block in model.blocks:
            w = block.tcam.tcam_proj.weight
            w.data += increment
            w.data.clamp_(-5, 5)
class SymbolicLightModel(nn.Module):
    """SymbolicLight language model.

    Pipeline: SpikeEncoder -> ``n_layers`` x SymbolicLightBlock -> BayesianHead.
    Every stage passes on a pair ``(spikes, continuous)``; the logits are
    computed from the continuous stream of the last block.
    """

    def __init__(self, config: SymbolicLightConfig):
        super().__init__()
        self.config = config
        self.spike_encoder = SpikeEncoder(config)
        self.blocks = nn.ModuleList([
            SymbolicLightBlock(config) for _ in range(config.n_layers)
        ])
        self.output_head = BayesianHead(config)
        self.stdp = STDPUpdater(config)
        self.gradient_checkpointing = False
        self.apply(self._init_weights)
        n_params = sum(p.numel() for p in self.parameters())
        print(f"[SymbolicLight V1] model initialized | parameters: {n_params/1e6:.1f}M ({n_params/1e9:.3f}B)")

    def _init_weights(self, module):
        """Normal(0, 0.02) for Linear and Embedding weights, zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def compile_for_inference(self):
        """Apply torch.compile to selected inference-critical submodules."""
        self.spike_encoder = torch.compile(self.spike_encoder, mode='reduce-overhead')
        for block in self.blocks:
            block.tcam = torch.compile(block.tcam, mode='reduce-overhead')
            block.ffn = torch.compile(block.ffn, mode='reduce-overhead')
        print("[SymbolicLight V1] torch.compile applied for inference acceleration")

    def gradient_checkpointing_enable(self):
        """Recompute block activations during backward instead of storing them."""
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Store block activations normally."""
        self.gradient_checkpointing = False

    def forward(self, token_ids: torch.Tensor, use_cache: bool = False,
                past_key_values: Optional[list] = None,
                streaming_state: Optional[list] = None):
        """
        Shared forward pass for training and inference.
        Args:
            token_ids: [B, S] input token IDs
            use_cache: whether to use the incremental-decoding caches
            past_key_values: list of cache dictionaries
                ``[encoder_cache, block0_cache, block1_cache, ...]``; when
                ``use_cache`` is True and this is None, a fresh list is created
                for this call only (it is not returned)
            streaming_state: list with the same layout, used when ``use_cache``
                is False to carry hidden state from one chunk to the next
        Returns:
            logits: [B, S, vocab_size]
        """
        n_caches = len(self.blocks) + 1
        if use_cache:
            if past_key_values is None:
                past_key_values = [{} for _ in range(n_caches)]
            caches = past_key_values
        elif streaming_state is not None:
            caches = streaming_state
        else:
            caches = [None] * n_caches

        spikes, continuous = self.spike_encoder(token_ids, use_cache=use_cache, cache=caches[0])

        # Align both streams with the dtype of the model weights.
        model_dtype = self.output_head.output_proj.weight.dtype
        if continuous.dtype != model_dtype:
            continuous = continuous.to(model_dtype)
        if spikes.dtype != model_dtype:
            spikes = spikes.to(model_dtype)
        initial_spikes = spikes

        use_checkpoint = self.training and self.gradient_checkpointing and not use_cache
        if use_checkpoint:
            # `import torch` alone is not guaranteed to load this submodule.
            # A from-import is used so that `torch` does not become a local
            # name inside this method.
            from torch.utils.checkpoint import checkpoint

        for i, block in enumerate(self.blocks):
            block_cache = caches[i + 1]
            if use_checkpoint and block_cache is None:
                # The block's defaults (use_cache=False, cache=None) apply.
                spikes, continuous = checkpoint(
                    block, spikes, continuous,
                    use_reentrant=False,
                )
            else:
                spikes, continuous = block(
                    spikes, continuous,
                    use_cache=use_cache, cache=block_cache,
                )

        logits = self.output_head(continuous)
        # The STDP rule only runs in eval mode and needs more than one position.
        if not self.training and self.config.enable_stdp and initial_spikes.size(1) > 1:
            self.stdp.update(self, initial_spikes, spikes)
        return logits

    def _adaptive_temperature(self, raw_logits: torch.Tensor, base_temp: float) -> float:
        """Scale the temperature by the normalized entropy of ``raw_logits``.

        Returns ``base_temp`` at zero entropy and moves linearly towards 0.1
        as the batch-mean entropy approaches ``log(vocab_size)``; the result
        is never below 0.1.
        """
        probs = F.softmax(raw_logits, dim=-1)
        p = probs.clamp(1e-7, 1.0)
        entropy = -(p * p.log()).sum(dim=-1).mean()
        max_entropy = math.log(self.config.vocab_size)
        norm_entropy = (entropy / max_entropy).clamp(0, 1)
        return max(0.1, base_temp - norm_entropy.item() * (base_temp - 0.1))

    def _sample_next(self, raw_logits: torch.Tensor, base_temp: float,
                     top_k: int, adaptive_temperature: bool) -> torch.Tensor:
        """Sample one token per row of ``raw_logits`` ([B, V]); returns [B, 1]."""
        if adaptive_temperature:
            temp = self._adaptive_temperature(raw_logits, base_temp)
        else:
            temp = base_temp
        next_logits = raw_logits / temp
        if top_k > 0:
            # torch.topk rejects k larger than the vocabulary axis.
            k = min(top_k, next_logits.size(-1))
            top_k_vals, _ = torch.topk(next_logits, k)
            min_top_k = top_k_vals[:, -1].unsqueeze(-1)
            next_logits = next_logits.masked_fill(next_logits < min_top_k, float('-inf'))
        probs = F.softmax(next_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, prompt_ids: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 0.8, top_k: int = 50,
                 adaptive_temperature: bool = True,
                 eos_token_id: int = 2) -> torch.Tensor:
        """
        Autoregressive generation with cached incremental decoding.

        The prompt is processed once; every later step feeds only the newly
        sampled token together with the caches.

        Adaptive temperature (when enabled):
        - the temperature equals ``temperature`` for zero-entropy logits and
          decreases towards 0.1 as the entropy of the logits rises
        - the effective range is [0.1, temperature]

        Args:
            prompt_ids: [B, S] prompt token IDs
            max_new_tokens: upper bound on generated tokens; values <= 0
                return a copy of the prompt
            temperature: base sampling temperature
            top_k: keep only the k most likely tokens (0 disables the filter)
            adaptive_temperature: enable the entropy-based adjustment
            eos_token_id: token that ends a sequence
        Returns:
            [B, S + n] prompt followed by the generated tokens. Generation
            stops once every sequence has produced ``eos_token_id``; sequences
            that finished earlier are padded with ``eos_token_id``.
        """
        self.eval()
        generated = prompt_ids.clone()
        if max_new_tokens <= 0:
            return generated

        past_key_values = [{} for _ in range(len(self.blocks) + 1)]
        # Per-sequence flag so that batches larger than one can stop cleanly.
        finished = torch.zeros(
            prompt_ids.size(0), dtype=torch.bool, device=prompt_ids.device
        )
        step_input = prompt_ids
        for _ in range(max_new_tokens):
            logits = self.forward(step_input, use_cache=True, past_key_values=past_key_values)
            next_token = self._sample_next(
                logits[:, -1, :], temperature, top_k, adaptive_temperature
            )
            next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
            generated = torch.cat([generated, next_token], dim=1)
            finished = finished | (next_token.squeeze(-1) == eos_token_id)
            if bool(finished.all()):
                break
            step_input = next_token
        return generated

    def get_sparsity_stats(self) -> dict:
        """Return the fraction of silent (zero) spikes after the encoder and after each block.

        A random 1 x 32 token batch is pushed through the same path as
        ``forward`` (without caches) under ``torch.no_grad()``. The model's
        train/eval mode is left untouched, so dropout is active when the model
        is in training mode.
        """
        stats = {}
        with torch.no_grad():
            # Build the probe on the model's device and inside its vocabulary.
            device = next(self.parameters()).device
            high = min(100, self.config.vocab_size)
            dummy = torch.randint(0, high, (1, 32), device=device)
            spikes, continuous = self.spike_encoder(dummy)
            stats['encoder_sparsity'] = 1.0 - spikes.mean().item()
            model_dtype = self.output_head.output_proj.weight.dtype
            spikes = spikes.to(model_dtype)
            continuous = continuous.to(model_dtype)
            for i, block in enumerate(self.blocks):
                # Thread the continuous stream through the stack as forward()
                # does, instead of feeding the spikes in as the stream.
                spikes, continuous = block(spikes, continuous)
                stats[f'block_{i}_sparsity'] = 1.0 - spikes.mean().item()
        return stats
def _run_smoke_test():
    """Build a reduced model and exercise forward, streaming, stats and generation."""
    vocab, batch, seq_len = 57344, 2, 128
    rule = "=" * 60
    print(rule)
    print(" SymbolicLight V1 model smoke test")
    print(rule)

    config = SymbolicLightConfig(
        vocab_size=57344,
        embed_dim=768,
        n_layers=12,
        n_heads=12,
        head_dim=64,
    )
    model = SymbolicLightModel(config)

    # Plain forward pass.
    dummy_input = torch.randint(0, vocab, (batch, seq_len))
    print(f"\nInput: batch=2, seq_len=128")
    logits = model(dummy_input)
    print(f"Output logits: {logits.shape}")

    # Two consecutive chunks sharing one streaming state.
    print(f"\nStreaming context test (2 chunks x 128 tokens)...")
    first_chunk, second_chunk = (
        torch.randint(0, vocab, (batch, seq_len)) for _ in range(2)
    )
    streaming_state = [{} for _ in range(len(model.blocks) + 1)]
    logits1 = model(first_chunk, streaming_state=streaming_state)
    print(f" Chunk 1 logits: {logits1.shape}, streaming state saved [OK]")
    logits2 = model(second_chunk, streaming_state=streaming_state)
    print(f" Chunk 2 logits: {logits2.shape}, cross-chunk memory passed [OK]")

    # Sparsity report.
    stats = model.get_sparsity_stats()
    print(f"\nSparsity stats:")
    for k, v in stats.items():
        print(f" {k}: {v*100:.1f}% silent")

    # Cached autoregressive generation.
    prompt = torch.randint(0, vocab, (1, 10))
    print(f"\nAutoregressive generation test (prompt=10, gen=20)...")
    output = model.generate(prompt, max_new_tokens=20)
    print(f"Generated sequence length: {output.shape[1]}")

    print("\n[PASS] SymbolicLight V1 smoke checks completed.")
    checklist = (
        " [1] RoPE rotary position encoding [OK]",
        " [2] Cross-chunk state passing [OK]",
        " [3] BayesianHead dynamic prior [OK]",
        " [4] SpikeEncoder parallel scan [OK]",
        " [5] FrontendRouter multimodal stub [OK]",
    )
    for line in checklist:
        print(line)


if __name__ == "__main__":
    _run_smoke_test()