"""
backbone.py — Eyla V2 Custom Hybrid Backbone
===============================================
Llama-3.2-1B compatible architecture with custom zero-cost extensions.

Architecture:
    - 24 transformer layers (Llama-compatible for weight transplant)
    - Grouped Query Attention (32 heads, 8 KV heads)
    - RoPE (Rotary Position Embedding)
    - RMSNorm + SiLU-gated MLP
    - SSM side-cars at layers 4, 8, 12, 16, 20 (HiPPO init)
    - Heuristic surprise gates (no learned params)
    - Heuristic early exit (confidence-based)
    - Heuristic complexity estimator (entropy-based)

Zero-cost design:
    - Donor weights transplanted into all 24 layers → works on day 1
    - SSM side-cars start as no-ops (gate=0) → no interference
    - Heuristic gates need no training
    - Online learning gradually activates SSM contribution

Naming convention matches LlamaForCausalLM for weight transplant:
    - token_embedding ← model.embed_tokens
    - layers.{i}.*    ← model.layers.{i}.*
    - final_norm      ← model.norm
    - lm_head         ← lm_head
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, List, Tuple
import math
import logging

from .ssm_block import SSMBlock
from .heuristic_gates import HeuristicGates

logger = logging.getLogger(__name__)


# ── Default config matching Llama 3.2 1B ────────────────────────────────────
# Keys are read with ``config[...]`` / ``config.get(...)`` by the classes in
# this module; the values below are the defaults used when EylaBackbone is
# constructed without an explicit config.

EYLA_V2_CONFIG = {
    # Transformer geometry
    "hidden_size": 2048,
    "intermediate_size": 8192,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_layers": 24,
    "vocab_size": 128256,
    "rms_norm_eps": 1e-5,
    # Rotary position embedding (Llama 3 frequency scaling)
    "rope_theta": 500000.0,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "max_position_embeddings": 131072,
    # When True the output projection reuses token_embedding.weight
    "tie_word_embeddings": True,
    # Eyla custom — SSM side-cars every 4 layers (BUILD_PLAN spec)
    "ssm_layers": [4, 8, 12, 16, 20],
    "ssm_state_dim": 64,
    "ssm_dt": 0.01,
    # Std-dev used to initialise side-car style projections near zero
    "side_car_init_std": 1e-5,
    # Heuristic gate settings (forwarded to HeuristicGates)
    "early_exit_confidence": 0.9,
    "early_exit_min_layers": 8,
    "surprise_threshold": 4.0,
}
# ── Building blocks ─────────────────────────────────────────────────────────


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (matches LlamaRMSNorm).

    Args:
        hidden_size: size of the last (normalized) dimension.
        eps: constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last dim, then by ``weight``."""
        input_dtype = x.dtype
        # Half-precision inputs are reduced in float32: squaring fp16
        # activations overflows to inf (|x| > ~256), which would zero the
        # output. float32 / float64 inputs are left untouched.
        if input_dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * norm).to(input_dtype) * self.weight


class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) — matches Llama 3 implementation with rope_scaling.

    Args:
        dim: per-head dimension (the rotary tables have this width).
        theta: RoPE base.
        rope_scaling: optional scaling dict; only ``rope_type == "llama3"``
            is acted upon, any other value leaves the frequencies unscaled.
    """

    def __init__(self, dim: int, theta: float = 500000.0, rope_scaling: Optional[Dict] = None):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

        # Apply Llama 3 rope scaling if configured
        if rope_scaling is not None and rope_scaling.get("rope_type") == "llama3":
            inv_freq = self._apply_llama3_scaling(inv_freq, rope_scaling)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Lazily-built cos/sin tables covering positions [0, _max_cached)
        self._max_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    @staticmethod
    def _apply_llama3_scaling(inv_freq: torch.Tensor, scaling: Dict) -> torch.Tensor:
        """Apply Llama 3 frequency scaling (matches HF transformers).

        High-frequency components (short wavelength) are kept, low-frequency
        components are divided by ``factor``, and the band in between is
        linearly interpolated between the two.
        """
        factor = scaling["factor"]
        low_freq_factor = scaling.get("low_freq_factor", 1.0)
        high_freq_factor = scaling.get("high_freq_factor", 4.0)
        old_context_len = scaling.get("original_max_position_embeddings", 8192)

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        new_freqs = []
        for freq in inv_freq.tolist():
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:
                new_freqs.append(freq)
            elif wavelen > low_freq_wavelen:
                new_freqs.append(freq / factor)
            else:
                smooth = (old_context_len / wavelen - low_freq_factor) / (
                    high_freq_factor - low_freq_factor
                )
                new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
        return torch.tensor(new_freqs, dtype=inv_freq.dtype)

    def _build_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """(Re)build the cos/sin tables if they cannot serve this request.

        The tables are rebuilt when they are missing, too short, or live on a
        different device / in a different dtype than requested (e.g. after
        the model was moved with ``.to(...)`` following an earlier forward).
        """
        cached = self._cos_cached
        if (
            cached is not None
            and seq_len <= self._max_cached
            and cached.device == device
            and cached.dtype == dtype
        ):
            return

        # Never shrink the table, and keep a 2048-position floor so short
        # prompts do not trigger a rebuild on every growth step.
        self._max_cached = max(seq_len, self._max_cached, 2048)
        t = torch.arange(self._max_cached, device=device, dtype=torch.float32)
        # Angles are always computed in float32, then cast once at the end.
        freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32))
        emb = torch.cat([freqs, freqs], dim=-1)  # (seq, dim)
        self._cos_cached = emb.cos().to(dtype)
        self._sin_cached = emb.sin().to(dtype)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """
        Args:
            x: (B, n_heads, S, head_dim) — only its device and dtype are used
            position_ids: (B, S) or (1, S)
        Returns:
            cos, sin: (B, 1, S, head_dim) (or (1, 1, S, head_dim)) for broadcasting
        """
        seq_len = position_ids.max().item() + 1
        self._build_cache(seq_len, x.device, x.dtype)

        # Gather by position
        cos = self._cos_cached[position_ids].unsqueeze(1)  # (B, 1, S, dim)
        sin = self._sin_cached[position_ids].unsqueeze(1)
        return cos, sin


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims of the input for RoPE."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors."""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# ── Attention ────────────────────────────────────────────────────────────────


class Attention(nn.Module):
    """
    Grouped Query Attention (GQA) — matches LlamaAttention.
    32 query heads, 8 KV heads (4:1 ratio) with the default config.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.num_kv_heads = config["num_key_value_heads"]
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            theta=config.get("rope_theta", 500000.0),
            rope_scaling=config.get("rope_scaling"),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Causal self-attention over ``hidden_states``.

        Args:
            hidden_states: (B, S, hidden_size)
            position_ids: (B, S) or (1, S) absolute positions of the S tokens
            attention_mask: optional key mask, non-zero = attend, 0 = pad.
                Shape (B, KV_len); a shorter mask (e.g. (B, S) during cached
                decoding) is taken to describe the most recent keys, and the
                older cached keys it does not cover are treated as attendable.
            past_key_value: cached (K, V), each (B, n_kv_heads, past_len, head_dim)
            use_cache: return the updated (K, V) for the next step

        Returns:
            (output, new_kv): output is (B, S, hidden_size); new_kv is the
            concatenated (K, V) when ``use_cache`` else None.
        """
        B, S, _ = hidden_states.shape

        # Project Q, K, V
        q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE
        cos, sin = self.rotary_emb(q, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # KV cache: concatenate with past keys/values
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        new_kv = (k, v) if use_cache else None

        # Repeat KV heads for GQA
        if self.num_kv_groups > 1:
            k_expanded = k.repeat_interleave(self.num_kv_groups, dim=1)
            v_expanded = v.repeat_interleave(self.num_kv_groups, dim=1)
        else:
            k_expanded, v_expanded = k, v

        # Scaled dot-product attention
        kv_len = k_expanded.shape[2]
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_weights = torch.matmul(q, k_expanded.transpose(-2, -1)) * scale

        # Causal mask (Q_len x KV_len): query i sits at absolute key index
        # i + (kv_len - S) and may not see any later key.
        device = hidden_states.device
        query_idx = torch.arange(kv_len - S, kv_len, device=device).unsqueeze(1)
        key_idx = torch.arange(kv_len, device=device).unsqueeze(0)
        blocked = (key_idx > query_idx).unsqueeze(0).unsqueeze(0)  # (1, 1, S, KV)

        # Padding mask. Built as a boolean and applied with masked_fill:
        # multiplying (1 - mask) by -inf yields 0 * -inf = NaN at every
        # position that should be attended.
        if attention_mask is not None:
            key_mask = attention_mask
            uncovered = kv_len - key_mask.shape[-1]
            if uncovered > 0:
                # Mask only describes the newest keys; older cached keys stay visible.
                key_mask = torch.cat(
                    [key_mask.new_ones((key_mask.shape[0], uncovered)), key_mask], dim=-1
                )
            blocked = blocked | (key_mask == 0).unsqueeze(1).unsqueeze(2)

        # The finite dtype minimum (not -inf) keeps rows whose keys are all
        # masked (e.g. left-padded queries) from turning into NaN in softmax.
        attn_weights = attn_weights.masked_fill(blocked, torch.finfo(attn_weights.dtype).min)

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v_expanded)

        # Merge heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, self.hidden_size)
        return self.o_proj(attn_output), new_kv


# ── MLP ──────────────────────────────────────────────────────────────────────


class MLP(nn.Module):
    """SiLU-gated MLP — matches LlamaMLP."""

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.gate_proj = nn.Linear(config["hidden_size"], config["intermediate_size"], bias=False)
        self.up_proj = nn.Linear(config["hidden_size"], config["intermediate_size"], bias=False)
        self.down_proj = nn.Linear(config["intermediate_size"], config["hidden_size"], bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(silu(gate_proj(x)) * up_proj(x))``."""
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


# ── Transformer Layer ────────────────────────────────────────────────────────


class TransformerLayer(nn.Module):
    """
    Single transformer layer — matches LlamaDecoderLayer naming.

    Sub-module names must match for weight transplant:
        self_attn.q_proj, self_attn.k_proj, self_attn.v_proj, self_attn.o_proj
        mlp.gate_proj, mlp.up_proj, mlp.down_proj
        input_layernorm, post_attention_layernorm

    Layers 16-23 (duplicated from donor 8-15) have a learnable layer_gate
    that starts at 0.0 so they act as pass-through on day 1. This prevents
    the duplicated layers from breaking the hidden state distribution.
    Online learning gradually opens the gate.
    """

    # ── Brain Region Labels ─────────────────────────────────────────────
    # PFC subdivision labels (layers 16-23 map to prefrontal cortex regions)
    _PFC_REGIONS = {
        16: "dlPFC (Working Memory)",
        17: "dlPFC (Working Memory)",
        18: "vmPFC (Value/Emotion)",
        19: "vmPFC (Value/Emotion)",
        20: "OFC (Outcome Prediction)",
        21: "vlPFC (Response Inhibition)",
        22: "vlPFC (Response Inhibition)",
        23: "Anterior PFC (Metacognition)",
    }

    # SSM brain region labels (5 side-cars = 5 brain regions)
    _SSM_BRAIN_REGIONS = {
        4: "Secondary Sensory Cortex",
        8: "Superior Temporal Sulcus",
        12: "Temporal-Parietal Junction",
        16: "Dorsolateral PFC",
        20: "Anterior PFC / Frontal Pole",
    }

    def __init__(self, config: Dict[str, Any], layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        num_layers = config.get("num_layers", 24)
        donor_layers = config.get("donor_layers", 16)

        # Standard Llama components
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config["hidden_size"], config.get("rms_norm_eps", 1e-5))
        self.post_attention_layernorm = RMSNorm(
            config["hidden_size"], config.get("rms_norm_eps", 1e-5)
        )

        # Deep init scaling (GPT-2 style) — prevents NaN with random weights
        # These weights will be overwritten by donor transplant anyway
        init_scale = 1.0 / math.sqrt(2 * num_layers)
        nn.init.normal_(self.self_attn.o_proj.weight, std=0.02 * init_scale)
        nn.init.normal_(self.mlp.down_proj.weight, std=0.02 * init_scale)

        # Duplicate layer gate: layers >= donor_layers start as pass-through (gate=0).
        # On day 1: output = input + gate * layer_output = input (since gate=0)
        # Through online learning: gate opens, layer contributes.
        self.is_duplicate = layer_idx >= donor_layers
        if self.is_duplicate:
            self.layer_gate = nn.Parameter(torch.tensor(0.0))

        # Label only; None for layers without a PFC mapping
        self.pfc_region = self._PFC_REGIONS.get(layer_idx, None)

        # SSM side-car (only at specific layers)
        self.has_ssm = layer_idx in config.get("ssm_layers", [])
        if self.has_ssm:
            self.ssm = SSMBlock(
                d_model=config["hidden_size"],
                state_dim=config.get("ssm_state_dim", 64),
                dt=config.get("ssm_dt", 0.01),
                init_std=config.get("side_car_init_std", 1e-5),
            )
            self.ssm.brain_region = self._SSM_BRAIN_REGIONS.get(layer_idx, f"SSM@L{layer_idx}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run attention + MLP (pre-norm residual), then the optional gate and SSM.

        Args:
            hidden_states: (B, S, hidden_size)
            position_ids: (B, S) or (1, S)
            attention_mask: optional key padding mask, forwarded to attention
            past_key_value: optional cached (K, V) for this layer
            use_cache: return this layer's updated (K, V)

        Returns:
            (hidden_states, new_kv) — new_kv is None unless ``use_cache``.
        """
        # Save layer input for duplicate gating
        layer_input = hidden_states

        # Pre-norm attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, new_kv = self.self_attn(
            hidden_states,
            position_ids,
            attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Pre-norm MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Duplicate layer gate: on day 1, gate=0 → layer is pass-through.
        #   output = input + gate * (layer_output - input)
        #   At gate=0: output = input (skip layer entirely)
        #   As gate opens: layer gradually contributes
        # The effective gate is layer_gate * sigmoid(layer_gate), which is 0
        # at layer_gate=0 but has a non-zero slope there, so it can train.
        if self.is_duplicate:
            gate = self.layer_gate * torch.sigmoid(self.layer_gate)
            hidden_states = layer_input + gate * (hidden_states - layer_input)

        # SSM side-car (additive — no interference on day 1)
        if self.has_ssm:
            hidden_states = hidden_states + self.ssm(hidden_states)

        return hidden_states, new_kv


# ── Full Model ───────────────────────────────────────────────────────────────


class EylaBackbone(nn.Module):
    """
    Eyla V2 Custom Hybrid Backbone.

    Llama-3.2-1B compatible for weight transplant, with custom extensions:
        - SSM side-cars (HiPPO init, zero-gated on day 1)
        - Heuristic surprise gates
        - Heuristic early exit
        - Heuristic complexity estimator

    The model works on day 1 after weight transplant with zero training.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()
        self.config = config or EYLA_V2_CONFIG.copy()

        hidden_size = self.config["hidden_size"]
        num_layers = self.config["num_layers"]
        vocab_size = self.config["vocab_size"]

        # Embeddings (matches Llama naming for transplant)
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList(
            [TransformerLayer(self.config, layer_idx=i) for i in range(num_layers)]
        )

        # Final norm
        self.final_norm = RMSNorm(hidden_size, self.config.get("rms_norm_eps", 1e-5))

        # Output head
        if self.config.get("tie_word_embeddings", True):
            self.lm_head = None  # Use token_embedding.weight
        else:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Memory compressor: use last hidden state (no extra module needed)
        # But keep a simple linear for compatibility with MemoryRetriever (256-d)
        self.memory_compressor = nn.Linear(hidden_size, 256, bias=False)
        nn.init.normal_(
            self.memory_compressor.weight, std=self.config.get("side_car_init_std", 1e-5)
        )

        # Memory agents at layers 7 and 15: predict expected hidden state
        # Comparison of predicted vs actual = surprise signal for online learning
        # Lazy-initialized via enable_memory_agents() to avoid OOM during model construction
        self.memory_agent_layers = [7, 15]
        self.memory_agents = None
        self._memory_agent_predictions = {}

        # Heuristic gates (NOT nn.Module — no parameters)
        # NOTE(review): the fallback of 4 for "early_exit_min_layers" differs
        # from the 8 in EYLA_V2_CONFIG; it only matters for custom configs
        # that omit the key — confirm which value is intended.
        self.gates = HeuristicGates(
            surprise_threshold=self.config.get("surprise_threshold", 4.0),
            exit_confidence=self.config.get("early_exit_confidence", 0.9),
            exit_min_layers=self.config.get("early_exit_min_layers", 4),
        )

        # Brain orchestrator (disabled by default — call enable_brain() to activate)
        self.brain = None

    def get_lm_head_weight(self) -> torch.Tensor:
        """Get the output projection weight (handles tied embeddings)."""
        if self.lm_head is not None:
            return self.lm_head.weight
        return self.token_embedding.weight

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Raw token embeddings (before any transformer layers)."""
        return self.token_embedding(input_ids)

    def enable_brain(self, config: Optional[Dict[str, Any]] = None):
        """
        Activate the brain orchestrator (86 brain systems).

        All gates start at 0 → day-1 identity preserved.
        Brain params are trainable; donor params should be frozen separately.
        """
        # Imported lazily so the orchestrator is only loaded when requested.
        from .brain_orchestrator import BrainOrchestrator

        self.brain = BrainOrchestrator(
            d_model=self.config["hidden_size"],
            state_dim=self.config.get("ssm_state_dim", 64),
            config=config,
        )
        brain_summary = self.brain.param_summary()
        logger.info(
            f"Brain enabled: {brain_summary['total_brain_params']:,} params "
            f"(gates: {brain_summary['gate_params']}, "
            f"nn_modules: {brain_summary['nn_module_params']:,})"
        )

    def enable_memory_agents(self):
        """Initialize memory agents at layers 7 and 15 (call after model load to avoid OOM).

        Each agent is a bias-free bottleneck MLP (hidden → 128 → 128 → hidden)
        whose last projection is initialised with ``side_car_init_std``. The
        agents are created on the same device as the token embedding so they
        can be used immediately after the model has been moved; they stay in
        float32 because ``forward`` feeds them float32 inputs.
        """
        hidden_size = self.config["hidden_size"]
        bottleneck = 128
        init_std = self.config.get("side_car_init_std", 1e-5)

        agents = nn.ModuleDict(
            {
                str(layer_idx): nn.Sequential(
                    nn.Linear(hidden_size, bottleneck, bias=False),
                    nn.SiLU(),
                    nn.Linear(bottleneck, bottleneck, bias=False),
                    nn.SiLU(),
                    nn.Linear(bottleneck, hidden_size, bias=False),
                )
                for layer_idx in self.memory_agent_layers
            }
        )
        for key in agents:
            nn.init.normal_(agents[key][-1].weight, std=init_std)

        self.memory_agents = agents.to(self.token_embedding.weight.device)

        total = sum(p.numel() for p in self.memory_agents.parameters())
        logger.info(f"Memory agents enabled at layers {self.memory_agent_layers}: {total:,} params")

    def decode_from_hidden(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        start_layer: int = 0,
    ) -> torch.Tensor:
        """
        Run transformer layers from start_layer onward, then output logits.
        Used by MemConsistencyLoss for teacher pass (memory-augmented decode).

        Args:
            hidden_states: (B, S, d_model)
            attention_mask: (B, S) — 1=attend, 0=pad
            start_layer: skip layers before this index

        Returns:
            logits: (B, S, vocab_size)
        """
        B, S, _ = hidden_states.shape
        position_ids = torch.arange(S, device=hidden_states.device).unsqueeze(0).expand(B, S)

        # max(..., 0): a negative start_layer runs every layer, not a tail slice.
        for layer in self.layers[max(start_layer, 0):]:
            hidden_states, _ = layer(hidden_states, position_ids, attention_mask)

        hidden_states = self.final_norm(hidden_states)
        # nan_to_num: safety net for random-weight initialization;
        # never triggers with real donor weights
        hidden_states = torch.nan_to_num(hidden_states)
        logits = hidden_states @ self.get_lm_head_weight().T
        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        early_exit: bool = False,
        return_hidden_states: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Dict[str, Any]:
        """
        Full forward pass.

        Args:
            input_ids: (B, S) input token IDs
            attention_mask: (B, S) 1=attend, 0=pad
            early_exit: enable heuristic early exit
            return_hidden_states: return per-layer hidden states
            past_key_values: list of (K, V) tuples per layer for KV cache
            use_cache: if True, return new key_values for caching

        Returns:
            dict with:
                logits: (B, S, vocab_size)
                last_hidden_state: (B, S, d_model) after the final norm
                hidden_states: list of (B, S, d_model) per layer (if requested)
                exit_layer: int — which layer we exited at
                complexity: value returned by the complexity estimator
                brain_state: output of the brain's post_forward (if brain enabled)
                past_key_values: list of (K, V) tuples (if use_cache)
        """
        B, S = input_ids.shape
        device = input_ids.device

        # Embeddings
        hidden_states = self.token_embedding(input_ids)

        # Position IDs — offset by past sequence length for KV cache
        past_len = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        position_ids = (
            torch.arange(past_len, past_len + S, device=device).unsqueeze(0).expand(B, -1)
        )

        # Estimate complexity from initial embeddings
        complexity = self.gates.complexity.estimate(hidden_states)

        # ── Brain hook 1: pre_layers ─────────────────────────────────────
        if self.brain is not None:
            orig_dtype = hidden_states.dtype
            hidden_states = self.brain.pre_layers(hidden_states.float()).to(orig_dtype)

        # Process through layers
        all_hidden_states = [] if return_hidden_states else None
        new_key_values = [] if use_cache else None
        exit_layer = len(self.layers) - 1
        lm_head_weight = self.get_lm_head_weight()
        self._memory_agent_predictions = {}

        for i, layer in enumerate(self.layers):
            has_agent = self.memory_agents is not None and i in self.memory_agent_layers

            # Memory agent: predict expected hidden state BEFORE this layer
            if has_agent:
                pred = self.memory_agents[str(i)](hidden_states.float()).to(hidden_states.dtype)
                self._memory_agent_predictions[i] = pred

            past_kv = past_key_values[i] if past_key_values is not None else None
            hidden_states, layer_kv = layer(
                hidden_states,
                position_ids,
                attention_mask,
                past_key_value=past_kv,
                use_cache=use_cache,
            )

            # Memory agent: store actual hidden state AFTER this layer for surprise
            if has_agent:
                self._memory_agent_predictions[f"{i}_actual"] = hidden_states.detach()

            # ── Brain hook 2: after_layer ────────────────────────────────
            if self.brain is not None:
                ssm_hidden = None
                if layer.has_ssm and hasattr(layer.ssm, 'last_hidden'):
                    ssm_hidden = layer.ssm.last_hidden
                orig_dtype = hidden_states.dtype
                ssm_f = ssm_hidden.float() if ssm_hidden is not None else None
                hidden_states = self.brain.after_layer(i, hidden_states.float(), ssm_f).to(
                    orig_dtype
                )

            if use_cache:
                new_key_values.append(layer_kv)

            if return_hidden_states:
                all_hidden_states.append(hidden_states.detach())

            # Early exit check (heuristic — no learned params)
            # NOTE(review): exiting early with use_cache=True returns a cache
            # with fewer entries than there are layers — confirm callers never
            # feed such a cache back in as past_key_values.
            if early_exit and i < len(self.layers) - 1:
                should_exit, _confidence = self.gates.early_exit.should_exit(
                    hidden_states, lm_head_weight, i
                )
                if should_exit:
                    exit_layer = i
                    break

        # Final norm + output projection
        hidden_states = self.final_norm(hidden_states)
        # nan_to_num: safety for random-weight init; never triggers with donor weights
        hidden_states = torch.nan_to_num(hidden_states)
        logits = hidden_states @ lm_head_weight.T

        # ── Brain hook 3: post_forward ───────────────────────────────────
        brain_state = None
        if self.brain is not None:
            brain_state = self.brain.post_forward(logits.float(), hidden_states.float())

        result = {
            "logits": logits,
            "exit_layer": exit_layer,
            "complexity": complexity,
            "last_hidden_state": hidden_states,
        }
        if brain_state is not None:
            result["brain_state"] = brain_state
        if return_hidden_states:
            result["hidden_states"] = all_hidden_states
        if use_cache:
            result["past_key_values"] = new_key_values

        return result

    def get_memory_agent_surprise(self) -> Dict[int, float]:
        """Get surprise values from last forward pass (predicted vs actual MSE per layer).

        Layers with no stored prediction (agents disabled, or the layer was
        not reached) are omitted from the result.
        """
        surprises = {}
        for layer_idx in self.memory_agent_layers:
            pred = self._memory_agent_predictions.get(layer_idx)
            actual = self._memory_agent_predictions.get(f"{layer_idx}_actual")
            if pred is not None and actual is not None:
                surprises[layer_idx] = F.mse_loss(pred.float(), actual.float()).item()
        return surprises

    def compress_memory(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compress hidden states for memory storage.

        Args:
            hidden_states: (B, S, d_model) or (B, d_model)

        Returns:
            (B, 256) compressed memory vector
        """
        if hidden_states.dim() == 3:
            # Use last token's hidden state
            hidden_states = hidden_states[:, -1, :]
        return self.memory_compressor(hidden_states)

    @staticmethod
    def _apply_repetition_penalty(
        logits: torch.Tensor, token_ids: torch.Tensor, penalty: float
    ) -> torch.Tensor:
        """Penalize, per batch row, every token already present in that row.

        Args:
            logits: (B, vocab) next-token logits
            token_ids: (B, T) int64 tokens seen so far in each row
            penalty: positive logits are divided by it, the rest multiplied

        Returns:
            (B, vocab) new tensor; ``logits`` itself is not modified.
        """
        seen = logits.gather(-1, token_ids)
        seen = torch.where(seen > 0, seen / penalty, seen * penalty)
        # Repeated ids all write the same (pre-computed) value, so the
        # scatter result does not depend on write order.
        return logits.scatter(-1, token_ids, seen)

    @staticmethod
    def _sample_next_token(logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
        """Pick one token per row: greedy if temperature <= 0, else top-p sampling.

        Args:
            logits: (B, vocab)
            temperature: softmax temperature; <= 0 selects the argmax
            top_p: nucleus threshold; the most likely token is always kept

        Returns:
            (B, 1) int64 token ids.
        """
        if temperature <= 0:
            # Dividing by zero would produce inf/NaN logits; greedy is the limit.
            return logits.argmax(dim=-1, keepdim=True)

        # Sample in float32 regardless of the model's compute dtype.
        scaled = logits.float() / temperature

        # Top-p (nucleus) sampling
        sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        # Drop a token when the mass *before* it already exceeds top_p.
        outside_nucleus = cumulative_probs - sorted_probs > top_p
        sorted_logits = sorted_logits.masked_fill(outside_nucleus, float("-inf"))

        probs = F.softmax(sorted_logits, dim=-1)
        choice = torch.multinomial(probs, num_samples=1)
        return sorted_indices.gather(-1, choice)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_p: float = 0.9,
        repetition_penalty: float = 1.3,
        eos_token_id: Optional[int] = 128001,
    ) -> torch.Tensor:
        """
        Autoregressive generation with KV cache for fast inference.

        Args:
            input_ids: (B, S) starting tokens
            max_new_tokens: how many tokens to generate at most
            temperature: sampling temperature (<= 0 → greedy decoding)
            top_p: nucleus sampling threshold
            repetition_penalty: penalize repeated tokens (1.0 = off, >1.0 = penalize);
                applied independently to every row of the batch
            eos_token_id: stop once every row emits this token in the same
                step (128001 for Llama 3.2); None disables the check

        Returns:
            (B, S + T) generated tokens, T <= max_new_tokens
        """
        generated = input_ids.clone()
        if max_new_tokens <= 0:
            return generated

        # Prefill: process entire prompt, cache KV states
        outputs = self.forward(generated, use_cache=True)

        for step in range(max_new_tokens):
            next_logits = outputs["logits"][:, -1, :]

            # Apply repetition penalty before temperature
            if repetition_penalty != 1.0:
                next_logits = self._apply_repetition_penalty(
                    next_logits, generated, repetition_penalty
                )

            next_token = self._sample_next_token(next_logits, temperature, top_p)
            generated = torch.cat([generated, next_token], dim=-1)

            # Stop on EOS
            if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                break

            # No logits are needed after the final token — skip that forward.
            if step == max_new_tokens - 1:
                break

            # Decode step: only process the new token, reuse cached KV
            outputs = self.forward(
                next_token, past_key_values=outputs["past_key_values"], use_cache=True
            )

        return generated

    def get_side_car_params(self) -> List[nn.Parameter]:
        """Get all side-car parameters (for online learning), including brain params."""
        params = []
        for layer in self.layers:
            if hasattr(layer, "ssm") and layer.has_ssm:
                params.extend(layer.ssm.parameters())
            # Layer gates for duplicate layers (16-23) must be trainable
            if layer.is_duplicate and hasattr(layer, "layer_gate"):
                params.append(layer.layer_gate)
        params.extend(self.memory_compressor.parameters())
        # Memory agent params (layers 7, 15) — when enabled
        if self.memory_agents is not None:
            params.extend(self.memory_agents.parameters())
        # Brain orchestrator params (when enabled)
        if self.brain is not None:
            params.extend(self.brain.get_brain_params())
        return params

    def get_donor_params(self) -> List[nn.Parameter]:
        """Get all donor (transplanted) parameters, i.e. every non-side-car parameter."""
        side_car_ids = {id(p) for p in self.get_side_car_params()}
        return [p for p in self.parameters() if id(p) not in side_car_ids]

    def freeze_donor(self):
        """Freeze all donor parameters (requires_grad=False)."""
        for p in self.get_donor_params():
            p.requires_grad = False
        logger.info("Frozen all donor parameters")

    def unfreeze_side_cars(self):
        """Ensure side-car parameters are trainable."""
        for p in self.get_side_car_params():
            p.requires_grad = True
        logger.info("Side-car parameters set to trainable")

    def param_summary(self) -> Dict[str, int]:
        """Count parameters by category."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        side_car = sum(p.numel() for p in self.get_side_car_params())
        donor = total - side_car

        return {
            "total": total,
            "trainable": trainable,
            "frozen": total - trainable,
            "donor": donor,
            "side_car": side_car,
        }


def create_eyla_v2(config: Optional[Dict[str, Any]] = None) -> EylaBackbone:
    """Factory function to create an Eyla V2 model (logs a parameter summary)."""
    model = EylaBackbone(config)
    summary = model.param_summary()
    logger.info(
        f"Created Eyla V2: {summary['total']:,} params "
        f"(donor: {summary['donor']:,}, side-car: {summary['side_car']:,})"
    )
    return model