eyla-v1-audit / model /backbone.py
Adiuk's picture
Upload folder using huggingface_hub
0e849e2 verified
"""
backbone.py — Eyla V2 Custom Hybrid Backbone
===============================================
Llama-3.2-1B compatible architecture with custom zero-cost extensions.
Architecture:
- 24 transformer layers (Llama-compatible for weight transplant)
- Grouped Query Attention (32 heads, 8 KV heads)
- RoPE (Rotary Position Embedding)
- RMSNorm + SiLU-gated MLP
- SSM side-cars at layers 4, 8, 12, 16, 20 (HiPPO init)
- Heuristic surprise gates (no learned params)
- Heuristic early exit (confidence-based)
- Heuristic complexity estimator (entropy-based)
Zero-cost design:
- Donor weights transplanted into all 24 layers → works on day 1
- SSM side-cars start as no-ops (gate=0) → no interference
- Heuristic gates need no training
- Online learning gradually activates SSM contribution
Naming convention matches LlamaForCausalLM for weight transplant:
- token_embedding ← model.embed_tokens
- layers.{i}.* ← model.layers.{i}.*
- final_norm ← model.norm
- lm_head ← lm_head
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, List, Tuple
import math
import logging
from .ssm_block import SSMBlock
from .heuristic_gates import HeuristicGates
logger = logging.getLogger(__name__)
# ── Default config matching Llama 3.2 1B ────────────────────────────────────
# Default hyper-parameters. EylaBackbone uses a copy of this dict when it is
# constructed without an explicit config.
EYLA_V2_CONFIG = {
    # Model width. Attention derives head_dim = hidden_size // num_attention_heads (64 here).
    "hidden_size": 2048,
    # Inner width of the SiLU-gated MLP (gate_proj / up_proj output size).
    "intermediate_size": 8192,
    # Query heads and shared key/value heads: 32 // 8 = 4 query heads per KV head.
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    # Number of TransformerLayer instances built by EylaBackbone.
    "num_layers": 24,
    "vocab_size": 128256,
    # Epsilon forwarded to every RMSNorm.
    "rms_norm_eps": 1e-5,
    # RoPE base frequency and the Llama 3 frequency-scaling parameters consumed by
    # RotaryEmbedding._apply_llama3_scaling (applied only when rope_type == "llama3").
    "rope_theta": 500000.0,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    # NOTE(review): not read by any code visible in this file; the RoPE cache grows on demand.
    "max_position_embeddings": 131072,
    # True -> EylaBackbone creates no lm_head and projects with token_embedding.weight.
    "tie_word_embeddings": True,
    # Eyla custom — SSM side-cars every 4 layers (BUILD_PLAN spec)
    # Layer indices that receive an SSMBlock; the next three values are passed to SSMBlock.
    "ssm_layers": [4, 8, 12, 16, 20],
    "ssm_state_dim": 64,
    "ssm_dt": 0.01,
    # Also used as the init std of memory_compressor and of the memory agents' last Linear.
    "side_car_init_std": 1e-5,
    # Forwarded to HeuristicGates as exit_confidence / exit_min_layers / surprise_threshold.
    "early_exit_confidence": 0.9,
    "early_exit_min_layers": 8,
    "surprise_threshold": 4.0,
    # NOTE: there is no "donor_layers" key here; TransformerLayer falls back to 16, so
    # layers with index >= 16 are treated as gated duplicates.
}
# ── Building blocks ─────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (matches LlamaRMSNorm).

    The mean-square statistic is computed in float32 and the normalised
    activations are cast back to the input dtype before the learned scale is
    applied. Reducing directly in half precision overflows ``x.pow(2)`` for
    moderately large activations (|x| > ~255 in float16), which turns the
    normaliser into 0 and silently zeroes the output.

    Args:
        hidden_size: size of the last (normalised) dimension.
        eps: constant added to the mean square for numerical stability.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and apply the learned scale."""
        input_dtype = x.dtype
        # Upcast so the squared values cannot overflow in fp16 / lose precision in bf16.
        x32 = x.to(torch.float32)
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back before scaling so the result dtype follows the usual
        # promotion rules between the input and the weight.
        return (x32 * inv_rms).to(input_dtype) * self.weight
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) with optional Llama 3 frequency scaling.

    Holds the inverse frequencies as a non-persistent buffer and lazily builds a
    cos/sin lookup table that is indexed by ``position_ids`` in ``forward``.

    Args:
        dim: per-head dimension; ``dim // 2`` frequencies are generated.
        theta: RoPE base.
        rope_scaling: optional dict; scaling is applied only when its
            ``"rope_type"`` entry equals ``"llama3"``.
    """

    def __init__(self, dim: int, theta: float = 500000.0, rope_scaling: Optional[Dict] = None):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        # Apply Llama 3 rope scaling if configured
        if rope_scaling is not None and rope_scaling.get("rope_type") == "llama3":
            inv_freq = self._apply_llama3_scaling(inv_freq, rope_scaling)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._max_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    @staticmethod
    def _apply_llama3_scaling(inv_freq: torch.Tensor, scaling: Dict) -> torch.Tensor:
        """Apply Llama 3 frequency scaling.

        Frequencies whose wavelength is shorter than
        ``original_max_position_embeddings / high_freq_factor`` are kept, those
        longer than ``original_max_position_embeddings / low_freq_factor`` are
        divided by ``factor``, and the band in between is linearly blended.

        Args:
            inv_freq: 1-D tensor of unscaled inverse frequencies.
            scaling: dict with ``"factor"`` (required) and optional
                ``"low_freq_factor"``, ``"high_freq_factor"``,
                ``"original_max_position_embeddings"``.

        Returns:
            Scaled inverse frequencies with the dtype of ``inv_freq``.
        """
        factor = scaling["factor"]
        low_freq_factor = scaling.get("low_freq_factor", 1.0)
        high_freq_factor = scaling.get("high_freq_factor", 4.0)
        old_context_len = scaling.get("original_max_position_embeddings", 8192)
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        # Work in float64 so the result is the same as evaluating the formula
        # with Python floats, then round once to the buffer dtype.
        freq = inv_freq.to(torch.float64)
        wavelen = 2 * math.pi / freq
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        blended = (1 - smooth) * freq / factor + smooth * freq
        scaled = torch.where(wavelen > low_freq_wavelen, freq / factor, blended)
        # The high-frequency test has priority, so it is applied last.
        scaled = torch.where(wavelen < high_freq_wavelen, freq, scaled)
        return scaled.to(inv_freq.dtype)

    def _build_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """(Re)build the cos/sin table when it is missing, too short, or stale.

        The table is also rebuilt when the requested device or dtype differs
        from the cached one; otherwise a model moved with ``.to(...)`` / cast
        with ``.half()`` after its first forward pass would keep receiving
        tensors on the old device or in the old dtype.
        """
        cached = self._cos_cached
        if (
            cached is not None
            and seq_len <= self._max_cached
            and cached.device == device
            and cached.dtype == dtype
        ):
            return
        # Never shrink: a device/dtype rebuild keeps the length already reached.
        self._max_cached = max(seq_len, self._max_cached, 2048)
        t = torch.arange(self._max_cached, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.to(device))
        emb = torch.cat([freqs, freqs], dim=-1)  # (seq, dim)
        self._cos_cached = emb.cos().to(dtype)
        self._sin_cached = emb.sin().to(dtype)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """
        Args:
            x: tensor whose device and dtype the returned tables must match
               (Attention passes the query tensor).
            position_ids: integer tensor (B, S) or (1, S) of absolute positions.
        Returns:
            cos, sin: (B, 1, S, dim) — the leading size follows position_ids;
            the singleton axis broadcasts over heads.
        """
        # .item() synchronises with the device; needed to size the table.
        seq_len = position_ids.max().item() + 1
        self._build_cache(seq_len, x.device, x.dtype)
        # Gather by position
        cos = self._cos_cached[position_ids].unsqueeze(1)  # (B, 1, S, dim)
        sin = self._sin_cached[position_ids].unsqueeze(1)
        return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension ``[a, b]`` (split at ``size // 2``) the result is
    ``[-b, a]``, which is the rotation used by RoPE.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key tensors with the given cos/sin tables.

    Both inputs get the same treatment: ``t * cos + rotate_half(t) * sin``.
    Returns the rotated ``(q, k)`` pair.
    """

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
# ── Attention ────────────────────────────────────────────────────────────────
class Attention(nn.Module):
    """
    Grouped Query Attention (GQA) — matches LlamaAttention naming.

    ``num_attention_heads`` query heads share ``num_key_value_heads`` key/value
    heads (32 and 8 in the default config, i.e. a 4:1 ratio). All projections
    are bias-free; RoPE is applied to queries and keys before caching.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.num_kv_heads = config["num_key_value_heads"]
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            theta=config.get("rope_theta", 500000.0),
            rope_scaling=config.get("rope_scaling"),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Causal self-attention over ``hidden_states``.

        Args:
            hidden_states: (B, S, hidden_size).
            position_ids: absolute positions of the S new tokens.
            attention_mask: optional padding mask, non-zero = attend, 0 = pad.
                It is broadcast against the key axis, so its last dimension
                must equal the total key length (cached + new tokens).
            past_key_value: optional cached (K, V), each
                (B, num_kv_heads, past_len, head_dim), already rotated.
            use_cache: return the concatenated (K, V) for the next step.

        Returns:
            (output of shape (B, S, hidden_size), new (K, V) or None).
        """
        B, S, _ = hidden_states.shape
        # Project Q, K, V
        q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE
        cos, sin = self.rotary_emb(q, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        # KV cache: concatenate with past keys/values
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        new_kv = (k, v) if use_cache else None
        # Repeat KV heads for GQA
        k_expanded = k.repeat_interleave(self.num_kv_groups, dim=1) if self.num_kv_groups > 1 else k
        v_expanded = v.repeat_interleave(self.num_kv_groups, dim=1) if self.num_kv_groups > 1 else v
        # Scaled dot-product attention
        kv_len = k_expanded.shape[2]
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_weights = torch.matmul(q, k_expanded.transpose(-2, -1)) * scale
        # Padding mask. The arithmetic form (1 - mask) * -inf must not be used:
        # it evaluates 0 * -inf = NaN for every *attended* position. A large
        # finite fill value (instead of -inf) additionally keeps rows whose
        # visible keys are all padding (e.g. left-padded prompts) finite, so
        # they cannot leak NaN into later layers through the value matmul.
        if attention_mask is not None:
            blocked = attention_mask.unsqueeze(1).unsqueeze(2).to(torch.bool).logical_not()
            attn_weights = attn_weights.masked_fill(blocked, torch.finfo(attn_weights.dtype).min)
        # Causal mask (Q_len x KV_len): query i may see keys up to past_len + i.
        causal_mask = torch.triu(
            torch.full((S, kv_len), float("-inf"), device=hidden_states.device, dtype=hidden_states.dtype),
            diagonal=kv_len - S + 1,
        )
        attn_weights = attn_weights + causal_mask.unsqueeze(0).unsqueeze(0)
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v_expanded)
        # Merge heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, self.hidden_size)
        return self.o_proj(attn_output), new_kv
# ── MLP ──────────────────────────────────────────────────────────────────────
class MLP(nn.Module):
    """Bias-free gated feed-forward block: down(silu(gate(x)) * up(x)).

    Attribute names (gate_proj, up_proj, down_proj) follow LlamaMLP.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        d_model = config["hidden_size"]
        d_inner = config["intermediate_size"]
        # Creation order is kept so seeded initialisation stays reproducible.
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.up_proj = nn.Linear(d_model, d_inner, bias=False)
        self.down_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to the last dimension of ``x``."""
        activated_gate = F.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(activated_gate * lifted)
# ── Transformer Layer ────────────────────────────────────────────────────────
class TransformerLayer(nn.Module):
    """
    One pre-norm decoder block, named like LlamaDecoderLayer.

    Sub-module names are kept identical to the donor so weights can be copied
    by name:
        self_attn.q_proj, self_attn.k_proj, self_attn.v_proj, self_attn.o_proj
        mlp.gate_proj, mlp.up_proj, mlp.down_proj
        input_layernorm, post_attention_layernorm

    Layers with ``layer_idx >= config["donor_layers"]`` (default 16) are
    treated as duplicates and get a scalar ``layer_gate`` parameter initialised
    to 0.0. The effective gate is ``layer_gate * sigmoid(layer_gate)``, which is
    0 at initialisation, so such a layer returns its input unchanged until the
    gate is trained away from zero.

    Layers listed in ``config["ssm_layers"]`` additionally own an SSMBlock
    whose output is added to the layer output.
    """

    # Descriptive labels only; neither table affects the computation.
    _PFC_REGIONS = {
        16: "dlPFC (Working Memory)",
        17: "dlPFC (Working Memory)",
        18: "vmPFC (Value/Emotion)",
        19: "vmPFC (Value/Emotion)",
        20: "OFC (Outcome Prediction)",
        21: "vlPFC (Response Inhibition)",
        22: "vlPFC (Response Inhibition)",
        23: "Anterior PFC (Metacognition)",
    }
    _SSM_BRAIN_REGIONS = {
        4: "Secondary Sensory Cortex",
        8: "Superior Temporal Sulcus",
        12: "Temporal-Parietal Junction",
        16: "Dorsolateral PFC",
        20: "Anterior PFC / Frontal Pole",
    }

    def __init__(self, config: Dict[str, Any], layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        width = config["hidden_size"]
        norm_eps = config.get("rms_norm_eps", 1e-5)

        # Llama-compatible sub-modules (creation order matters for seeded init).
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(width, norm_eps)
        self.post_attention_layernorm = RMSNorm(width, norm_eps)

        # GPT-2 style depth scaling of the two residual output projections,
        # so a randomly initialised stack stays numerically tame.
        depth_scale = 1.0 / math.sqrt(2 * config.get("num_layers", 24))
        for out_proj in (self.self_attn.o_proj, self.mlp.down_proj):
            nn.init.normal_(out_proj.weight, std=0.02 * depth_scale)

        # Gate for duplicated layers: starts closed (0.0).
        self.is_duplicate = layer_idx >= config.get("donor_layers", 16)
        if self.is_duplicate:
            self.layer_gate = nn.Parameter(torch.tensor(0.0))

        self.pfc_region = self._PFC_REGIONS.get(layer_idx, None)

        # Optional SSM side-car.
        self.has_ssm = layer_idx in config.get("ssm_layers", [])
        if self.has_ssm:
            self.ssm = SSMBlock(
                d_model=width,
                state_dim=config.get("ssm_state_dim", 64),
                dt=config.get("ssm_dt", 0.01),
                init_std=config.get("side_car_init_std", 1e-5),
            )
            self.ssm.brain_region = self._SSM_BRAIN_REGIONS.get(layer_idx, f"SSM@L{layer_idx}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run attention and MLP with residuals, then the optional gate and SSM.

        Returns the new hidden states and the attention module's (K, V) cache
        entry (None unless ``use_cache`` is set).
        """
        layer_input = hidden_states

        # Attention sub-block (pre-norm + residual).
        attn_out, new_kv = self.self_attn(
            self.input_layernorm(layer_input), position_ids, attention_mask,
            past_key_value=past_key_value, use_cache=use_cache,
        )
        after_attn = layer_input + attn_out

        # MLP sub-block (pre-norm + residual).
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        layer_output = after_attn + mlp_out

        # Duplicate layers interpolate between input and output:
        # gate == 0 -> pure pass-through of the layer input.
        if self.is_duplicate:
            gate = self.layer_gate * torch.sigmoid(self.layer_gate)
            layer_output = layer_input + gate * (layer_output - layer_input)

        # The SSM side-car is purely additive on top of the (possibly gated) output.
        if self.has_ssm:
            layer_output = layer_output + self.ssm(layer_output)
        return layer_output, new_kv
# ── Full Model ───────────────────────────────────────────────────────────────
class EylaBackbone(nn.Module):
"""
Eyla V2 Custom Hybrid Backbone.
Llama-3.2-1B compatible for weight transplant, with custom extensions:
- SSM side-cars (HiPPO init, zero-gated on day 1)
- Heuristic surprise gates
- Heuristic early exit
- Heuristic complexity estimator
The model works on day 1 after weight transplant with zero training.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None):
    """Build embeddings, the layer stack, the output head and the helper modules.

    Args:
        config: hyper-parameter dict; a copy of EYLA_V2_CONFIG is used when it
            is None (or empty).
    """
    super().__init__()
    self.config = config or EYLA_V2_CONFIG.copy()
    cfg = self.config
    width = cfg["hidden_size"]
    depth = cfg["num_layers"]
    vocab = cfg["vocab_size"]

    # Token embedding table (named for the donor's model.embed_tokens).
    self.token_embedding = nn.Embedding(vocab, width)

    # Decoder stack.
    blocks = [TransformerLayer(cfg, layer_idx=idx) for idx in range(depth)]
    self.layers = nn.ModuleList(blocks)

    # Norm applied after the last executed layer.
    self.final_norm = RMSNorm(width, cfg.get("rms_norm_eps", 1e-5))

    # With tied embeddings there is no separate head; get_lm_head_weight()
    # then returns token_embedding.weight.
    tied = cfg.get("tie_word_embeddings", True)
    self.lm_head = None if tied else nn.Linear(width, vocab, bias=False)

    # Linear projection to 256 dims used by compress_memory(), with a tiny init.
    self.memory_compressor = nn.Linear(width, 256, bias=False)
    nn.init.normal_(self.memory_compressor.weight, std=cfg.get("side_car_init_std", 1e-5))

    # Memory agents are created later by enable_memory_agents(); until then
    # only the target layer indices and an empty prediction store exist.
    self.memory_agent_layers = [7, 15]
    self.memory_agents = None
    self._memory_agent_predictions = {}

    # Parameter-free heuristic gates (surprise / early exit / complexity).
    self.gates = HeuristicGates(
        surprise_threshold=cfg.get("surprise_threshold", 4.0),
        exit_confidence=cfg.get("early_exit_confidence", 0.9),
        exit_min_layers=cfg.get("early_exit_min_layers", 4),
    )

    # Brain orchestrator stays off until enable_brain() is called.
    self.brain = None
def get_lm_head_weight(self) -> torch.Tensor:
    """Return the output-projection matrix.

    Falls back to the token embedding matrix when no separate ``lm_head``
    exists (tied embeddings).
    """
    head = self.lm_head
    return self.token_embedding.weight if head is None else head.weight
def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up ``input_ids`` in the token embedding table; no layers are run."""
    embeddings = self.token_embedding(input_ids)
    return embeddings
def enable_brain(self, config: Optional[Dict[str, Any]] = None):
    """
    Attach a BrainOrchestrator to this model and log its parameter counts.

    The orchestrator is built with this model's hidden size and SSM state
    dimension; ``config`` is passed through to it unchanged. Once attached,
    forward() calls its pre_layers / after_layer / post_forward hooks.
    """
    # Imported lazily so the orchestrator module is only loaded on demand.
    from .brain_orchestrator import BrainOrchestrator

    orchestrator = BrainOrchestrator(
        d_model=self.config["hidden_size"],
        state_dim=self.config.get("ssm_state_dim", 64),
        config=config,
    )
    self.brain = orchestrator
    brain_summary = self.brain.param_summary()
    logger.info(
        f"Brain enabled: {brain_summary['total_brain_params']:,} params "
        f"(gates: {brain_summary['gate_params']}, "
        f"nn_modules: {brain_summary['nn_module_params']:,})"
    )
def enable_memory_agents(self):
    """Create one bottleneck predictor per entry of ``memory_agent_layers``.

    Each agent is Linear(hidden, 128) -> SiLU -> Linear(128, 128) -> SiLU ->
    Linear(128, hidden), all bias-free. Only the last Linear is re-initialised
    with the small ``side_car_init_std``. Kept out of __init__ so the modules
    are allocated only when requested.
    """
    width = self.config["hidden_size"]
    narrow = 128
    final_std = self.config.get("side_car_init_std", 1e-5)

    def _make_agent() -> nn.Sequential:
        return nn.Sequential(
            nn.Linear(width, narrow, bias=False),
            nn.SiLU(),
            nn.Linear(narrow, narrow, bias=False),
            nn.SiLU(),
            nn.Linear(narrow, width, bias=False),
        )

    # Build every agent first, then re-initialise the output layers, so the
    # random-number consumption order is unchanged.
    agents = nn.ModuleDict()
    for layer_idx in self.memory_agent_layers:
        agents[str(layer_idx)] = _make_agent()
    self.memory_agents = agents
    for agent in self.memory_agents.values():
        nn.init.normal_(agent[-1].weight, std=final_std)

    total = sum(weight.numel() for weight in self.memory_agents.parameters())
    logger.info(f"Memory agents enabled at layers {self.memory_agent_layers}: {total:,} params")
def decode_from_hidden(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    start_layer: int = 0,
) -> torch.Tensor:
    """
    Push already-computed hidden states through the tail of the layer stack
    and project them to vocabulary logits.

    Positions are numbered 0..S-1 and no KV cache is used.

    Args:
        hidden_states: (B, S, d_model)
        attention_mask: (B, S) — 1=attend, 0=pad
        start_layer: index of the first layer to execute; earlier ones are skipped
    Returns:
        logits: (B, S, vocab_size)
    """
    batch, seq_len, _ = hidden_states.shape
    positions = torch.arange(seq_len, device=hidden_states.device)
    position_ids = positions.unsqueeze(0).expand(batch, seq_len)

    current = hidden_states
    for index, block in enumerate(self.layers):
        if index >= start_layer:
            current, _ = block(current, position_ids, attention_mask)

    normed = self.final_norm(current)
    # Replace NaN/inf before the projection so non-finite values do not reach the logits.
    normed = torch.nan_to_num(normed)
    return normed @ self.get_lm_head_weight().T
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    early_exit: bool = False,
    return_hidden_states: bool = False,
    past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
    use_cache: bool = False,
) -> Dict[str, Any]:
    """
    Embed the tokens, run the layer stack (with optional hooks) and project to logits.

    Args:
        input_ids: (B, S) input token IDs
        attention_mask: (B, S) 1=attend, 0=pad
        early_exit: ask the heuristic gate after every non-final layer whether to stop
        return_hidden_states: collect the (detached) output of every executed layer
        past_key_values: list of (K, V) tuples per layer for KV cache
        use_cache: if True, return new key_values for caching
    Returns:
        dict with:
            logits: (B, S, vocab_size)
            exit_layer: int — index of the last layer that ran
            complexity: value returned by the complexity estimator
            last_hidden_state: normalised hidden states fed to the projection
            brain_state: result of brain.post_forward (only when a brain is attached)
            hidden_states: list of (B, S, d_model), one per executed layer (if requested)
            past_key_values: list of (K, V) tuples (if use_cache)

    Note: when early exit fires together with ``use_cache``, the returned
    cache only holds entries for the layers that actually ran.
    """
    batch, seq_len = input_ids.shape
    device = input_ids.device

    hidden_states = self.token_embedding(input_ids)

    # New tokens are numbered after whatever is already in the cache.
    if past_key_values is None:
        past_len = 0
    else:
        past_len = past_key_values[0][0].shape[2]
    position_ids = torch.arange(past_len, past_len + seq_len, device=device).unsqueeze(0).expand(batch, -1)

    # Complexity is estimated from the raw embeddings, before any hook or layer.
    complexity = self.gates.complexity.estimate(hidden_states)

    # Brain hook 1: runs in float32, result cast back to the working dtype.
    if self.brain is not None:
        working_dtype = hidden_states.dtype
        hidden_states = self.brain.pre_layers(hidden_states.float()).to(working_dtype)

    collected_states = [] if return_hidden_states else None
    collected_kv = [] if use_cache else None
    last_index = len(self.layers) - 1
    exit_layer = last_index
    lm_head_weight = self.get_lm_head_weight()
    self._memory_agent_predictions = {}

    for index, block in enumerate(self.layers):
        agent_active = self.memory_agents is not None and index in self.memory_agent_layers

        # Memory agent prediction, computed from the input of this layer.
        if agent_active:
            predicted = self.memory_agents[str(index)](hidden_states.float()).to(hidden_states.dtype)
            self._memory_agent_predictions[index] = predicted

        cached = None if past_key_values is None else past_key_values[index]
        hidden_states, block_kv = block(
            hidden_states, position_ids, attention_mask,
            past_key_value=cached, use_cache=use_cache,
        )

        # The matching target is the (detached) output of the same layer.
        if agent_active:
            self._memory_agent_predictions[f"{index}_actual"] = hidden_states.detach()

        # Brain hook 2: also receives the SSM's last hidden state when the
        # layer has a side-car exposing one.
        if self.brain is not None:
            ssm_state = None
            if block.has_ssm and hasattr(block.ssm, 'last_hidden'):
                ssm_state = block.ssm.last_hidden
            working_dtype = hidden_states.dtype
            ssm_state_f32 = None if ssm_state is None else ssm_state.float()
            hidden_states = self.brain.after_layer(index, hidden_states.float(), ssm_state_f32).to(working_dtype)

        if use_cache:
            collected_kv.append(block_kv)
        if return_hidden_states:
            collected_states.append(hidden_states.detach())

        # Heuristic early exit is never evaluated on the final layer.
        if early_exit and index < last_index:
            stop_now, _confidence = self.gates.early_exit.should_exit(
                hidden_states, lm_head_weight, index
            )
            if stop_now:
                exit_layer = index
                break

    # Final norm, NaN/inf scrub, then the output projection.
    hidden_states = self.final_norm(hidden_states)
    hidden_states = torch.nan_to_num(hidden_states)
    logits = hidden_states @ lm_head_weight.T

    # Brain hook 3.
    brain_state = None
    if self.brain is not None:
        brain_state = self.brain.post_forward(logits.float(), hidden_states.float())

    result = {
        "logits": logits,
        "exit_layer": exit_layer,
        "complexity": complexity,
        "last_hidden_state": hidden_states,
    }
    if brain_state is not None:
        result["brain_state"] = brain_state
    if return_hidden_states:
        result["hidden_states"] = collected_states
    if use_cache:
        result["past_key_values"] = collected_kv
    return result
def get_memory_agent_surprise(self) -> Dict[int, float]:
    """Return, per memory-agent layer, the MSE between the agent's prediction
    and the layer output recorded during the last forward pass.

    Layers for which either value was not recorded are left out.
    """
    recorded = self._memory_agent_predictions
    result = {}
    for idx in self.memory_agent_layers:
        predicted = recorded.get(idx)
        observed = recorded.get(f"{idx}_actual")
        if predicted is None or observed is None:
            continue
        mse = torch.nn.functional.mse_loss(predicted.float(), observed.float())
        result[idx] = mse.item()
    return result
def compress_memory(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Project hidden states through ``memory_compressor``.

    Args:
        hidden_states: (B, S, d_model) or (B, d_model); for 3-D input only the
            last sequence position is kept.
    Returns:
        (B, 256) compressed memory vector
    """
    is_sequence = hidden_states.dim() == 3
    pooled = hidden_states[:, -1, :] if is_sequence else hidden_states
    return self.memory_compressor(pooled)
@torch.no_grad()
def generate(
    self,
    input_ids: torch.Tensor,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
    top_p: float = 0.9,
    repetition_penalty: float = 1.3,
    eos_token_id: Optional[int] = 128001,
) -> torch.Tensor:
    """
    Autoregressive generation with KV cache for fast inference.

    Args:
        input_ids: (B, S) starting tokens
        max_new_tokens: maximum number of tokens to append
        temperature: sampling temperature; values <= 0 select greedy decoding
        top_p: nucleus sampling threshold
        repetition_penalty: penalize repeated tokens (1.0 = off, >1.0 = penalize);
            applied independently to every row of the batch
        eos_token_id: token that ends a sequence (128001 for Llama 3.2);
            None disables early stopping
    Returns:
        (B, S + T) tokens with T <= max_new_tokens. Generation stops early once
        every row has produced ``eos_token_id``; rows that finish before the
        others are filled with ``eos_token_id``.
    """
    generated = input_ids.clone()
    # Prefill: process entire prompt, cache KV states
    outputs = self.forward(generated, use_cache=True)
    past_key_values = outputs["past_key_values"]
    next_logits = outputs["logits"][:, -1, :]
    finished = torch.zeros(generated.shape[0], dtype=torch.bool, device=generated.device)

    for step in range(max_new_tokens):
        # Repetition penalty (before temperature), vectorised over the batch:
        # every token already present in a row has its logit pushed towards
        # "less likely" — divided when positive, multiplied otherwise.
        if repetition_penalty != 1.0:
            seen = torch.gather(next_logits, 1, generated)
            seen = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)
            next_logits = next_logits.scatter(1, generated, seen)

        if temperature <= 0:
            # Greedy decoding; dividing by a zero temperature would yield inf/NaN.
            next_token = next_logits.argmax(dim=-1, keepdim=True)
        else:
            next_logits = next_logits / temperature
            # Top-p (nucleus) sampling: drop tokens once the probability mass
            # *before* them already exceeds top_p (the top token always stays).
            sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            sorted_mask = cumulative_probs - sorted_probs > top_p
            sorted_logits[sorted_mask] = float("-inf")
            probs = F.softmax(sorted_logits, dim=-1)
            next_token_sorted = torch.multinomial(probs, num_samples=1)
            next_token = sorted_indices.gather(-1, next_token_sorted)

        if eos_token_id is not None:
            # Rows that already ended keep emitting EOS instead of new samples.
            next_token = next_token.masked_fill(finished.unsqueeze(-1), eos_token_id)
            finished = finished | (next_token.squeeze(-1) == eos_token_id)

        generated = torch.cat([generated, next_token], dim=-1)

        if eos_token_id is not None and bool(finished.all()):
            break
        # No further token will be sampled, so skip the now-useless forward pass.
        if step + 1 >= max_new_tokens:
            break

        # Decode step: only process the new token, reuse cached KV
        outputs = self.forward(next_token, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs["past_key_values"]
        next_logits = outputs["logits"][:, -1, :]
    return generated
def get_side_car_params(self) -> List[nn.Parameter]:
    """Collect every parameter that does not come from the donor model.

    Order: per layer the SSM parameters followed by the duplicate-layer gate,
    then the memory compressor, the memory agents (if enabled) and finally
    the brain parameters (if a brain is attached).
    """
    collected = []
    for block in self.layers:
        if hasattr(block, "ssm") and block.has_ssm:
            collected.extend(block.ssm.parameters())
        # Gates of duplicated layers are trained together with the side-cars.
        if block.is_duplicate and hasattr(block, "layer_gate"):
            collected.append(block.layer_gate)

    collected.extend(self.memory_compressor.parameters())

    agents = self.memory_agents
    if agents is not None:
        collected.extend(agents.parameters())

    brain = self.brain
    if brain is not None:
        collected.extend(brain.get_brain_params())
    return collected
def get_donor_params(self) -> List[nn.Parameter]:
    """Return every model parameter that is not a side-car parameter.

    Membership is decided by object identity, in ``self.parameters()`` order.
    """
    excluded = frozenset(map(id, self.get_side_car_params()))
    return [param for param in self.parameters() if id(param) not in excluded]
def freeze_donor(self):
    """Turn off gradients for every parameter returned by get_donor_params()."""
    donor_params = self.get_donor_params()
    for param in donor_params:
        param.requires_grad = False
    logger.info("Frozen all donor parameters")
def unfreeze_side_cars(self):
    """Turn on gradients for every parameter returned by get_side_car_params()."""
    side_car_params = self.get_side_car_params()
    for param in side_car_params:
        param.requires_grad = True
    logger.info("Side-car parameters set to trainable")
def param_summary(self) -> Dict[str, int]:
    """Return element counts: total, trainable, frozen, donor and side-car.

    ``donor`` is defined as total minus side-car; ``frozen`` as total minus
    trainable.
    """
    every_param = list(self.parameters())
    total = sum(param.numel() for param in every_param)
    trainable = sum(param.numel() for param in every_param if param.requires_grad)
    side_car = sum(param.numel() for param in self.get_side_car_params())
    summary = {
        "total": total,
        "trainable": trainable,
        "frozen": total - trainable,
        "donor": total - side_car,
        "side_car": side_car,
    }
    return summary
def create_eyla_v2(config: Optional[Dict[str, Any]] = None) -> EylaBackbone:
    """Build an EylaBackbone from ``config`` and log its parameter counts."""
    backbone = EylaBackbone(config)
    summary = backbone.param_summary()
    logger.info(
        f"Created Eyla V2: {summary['total']:,} params "
        f"(donor: {summary['donor']:,}, side-car: {summary['side_car']:,})"
    )
    return backbone