| """ |
| backbone.py — Eyla V2 Custom Hybrid Backbone |
| =============================================== |
| Llama-3.2-1B compatible architecture with custom zero-cost extensions. |
| |
| Architecture: |
| - 24 transformer layers (Llama-compatible for weight transplant) |
| - Grouped Query Attention (32 heads, 8 KV heads) |
| - RoPE (Rotary Position Embedding) |
| - RMSNorm + SiLU-gated MLP |
| - SSM side-cars at layers 4, 8, 12, 16, 20 (HiPPO init) |
| - Heuristic surprise gates (no learned params) |
| - Heuristic early exit (confidence-based) |
| - Heuristic complexity estimator (entropy-based) |
| |
| Zero-cost design: |
| - Donor weights transplanted into all 24 layers → works on day 1 |
| - SSM side-cars start as no-ops (gate=0) → no interference |
| - Heuristic gates need no training |
| - Online learning gradually activates SSM contribution |
| |
| Naming convention matches LlamaForCausalLM for weight transplant: |
| - token_embedding ← model.embed_tokens |
| - layers.{i}.* ← model.layers.{i}.* |
| - final_norm ← model.norm |
| - lm_head ← lm_head |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Dict, Any, List, Tuple |
| import math |
| import logging |
|
|
| from .ssm_block import SSMBlock |
| from .heuristic_gates import HeuristicGates |
|
|
logger = logging.getLogger(__name__)


# Default hyper-parameters for EylaBackbone.
# The first group mirrors the Llama-3.2-1B donor shapes so transplanted weights
# fit; the second group configures the Eyla-specific extensions.
EYLA_V2_CONFIG = {
    # --- Llama-compatible core ------------------------------------------------
    "hidden_size": 2048,
    "intermediate_size": 8192,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_layers": 24,
    # Layers with index >= donor_layers are treated as duplicates by
    # TransformerLayer and get a zero-initialised layer_gate. Declared here
    # explicitly (TransformerLayer otherwise falls back to 16) so the boundary
    # is visible when num_layers is changed.
    "donor_layers": 16,
    "vocab_size": 128256,
    "rms_norm_eps": 1e-5,
    "rope_theta": 500000.0,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "max_position_embeddings": 131072,
    "tie_word_embeddings": True,
    # --- Eyla extensions --------------------------------------------------------
    "ssm_layers": [4, 8, 12, 16, 20],  # layer indices that get an SSMBlock side-car
    "ssm_state_dim": 64,
    "ssm_dt": 0.01,
    "side_car_init_std": 1e-5,  # init std for side-car weights (SSM, memory modules)
    "early_exit_confidence": 0.9,
    "early_exit_min_layers": 8,
    "surprise_threshold": 4.0,
}
|
|
|
|
| |
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (matches LlamaRMSNorm).

    y = x / sqrt(mean(x**2, dim=-1) + eps) * weight

    The mean of squares is computed in at least float32 and the normalised
    result is cast back to the input dtype before the weight is applied, as in
    LlamaRMSNorm. Computing ``x**2`` directly in float16 overflows to inf for
    |x| > ~256, which would make the layer silently output zeros.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # per-channel scale
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Upcast half/bfloat16 to float32; float32/float64 are left unchanged.
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_c = x.to(compute_dtype)
        variance = x_c.pow(2).mean(-1, keepdim=True)
        normed = x_c * torch.rsqrt(variance + self.eps)
        return self.weight * normed.to(input_dtype)
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) — matches Llama 3 implementation with rope_scaling.

    Stores inverse frequencies for a head dimension ``dim`` and lazily builds a
    cos/sin table indexed by absolute position. The table is rebuilt whenever a
    longer sequence, a different device, or a different dtype is requested, so
    moving or casting the model after the first forward pass is safe.
    """

    def __init__(self, dim: int, theta: float = 500000.0, rope_scaling: Optional[Dict] = None):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

        # Llama 3 rescales low frequencies to extend the usable context window.
        if rope_scaling is not None and rope_scaling.get("rope_type") == "llama3":
            inv_freq = self._apply_llama3_scaling(inv_freq, rope_scaling)

        # Non-persistent: derived from config, not stored in the state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._max_cached = 0
        self._cos_cached: Optional[torch.Tensor] = None
        self._sin_cached: Optional[torch.Tensor] = None

    @staticmethod
    def _apply_llama3_scaling(inv_freq: torch.Tensor, scaling: Dict) -> torch.Tensor:
        """Apply Llama 3 frequency scaling (matches HF transformers).

        High frequencies (wavelength < original_len / high_freq_factor) are kept,
        low frequencies (wavelength > original_len / low_freq_factor) are divided
        by ``factor``, and the band in between is linearly interpolated.
        """
        factor = scaling["factor"]
        low_freq_factor = scaling.get("low_freq_factor", 1.0)
        high_freq_factor = scaling.get("high_freq_factor", 4.0)
        old_context_len = scaling.get("original_max_position_embeddings", 8192)

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        new_freqs = []
        for freq in inv_freq.tolist():  # Python floats (double precision)
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:
                new_freqs.append(freq)
            elif wavelen > low_freq_wavelen:
                new_freqs.append(freq / factor)
            else:
                smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                new_freqs.append((1 - smooth) * freq / factor + smooth * freq)

        return torch.tensor(new_freqs, dtype=inv_freq.dtype)

    def _build_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """Ensure the cos/sin tables cover ``seq_len`` positions on ``device`` in ``dtype``."""
        cached = self._cos_cached
        if (
            cached is not None
            and seq_len <= self._max_cached
            and cached.device == device
            and cached.dtype == dtype
        ):
            return
        # Never shrink an existing table (a rebuild may be due only to a
        # device/dtype change); always allocate at least 2048 positions.
        self._max_cached = max(seq_len, self._max_cached, 2048)
        t = torch.arange(self._max_cached, device=device, dtype=torch.float32)
        # Angles are computed in float32 even if the buffer was cast to half precision.
        freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32))
        emb = torch.cat([freqs, freqs], dim=-1)
        self._cos_cached = emb.cos().to(dtype)
        self._sin_cached = emb.sin().to(dtype)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """
        Args:
            x: tensor whose device and dtype select the cache (e.g. queries,
               (B, n_heads, S, head_dim)); its values are not used.
            position_ids: (B, S) or (1, S) integer absolute positions.
        Returns:
            cos, sin: (B or 1, 1, S, head_dim), broadcastable over heads.
        """
        seq_len = int(position_ids.max().item()) + 1
        self._build_cache(seq_len, x.device, x.dtype)
        # Index by absolute position; unsqueeze adds the head axis for broadcasting.
        cos = self._cos_cached[position_ids].unsqueeze(1)
        sin = self._sin_cached[position_ids].unsqueeze(1)
        return cos, sin
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last dimension [a, b] -> [-b, a] for RoPE.

    The split point is ``size // 2``, so for an odd size the second half is the
    longer one.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key tensors by position: ``t * cos + rot(t) * sin``.

    ``rot`` maps the last dimension [a, b] -> [-b, a], split at ``size // 2``.
    ``cos`` and ``sin`` must broadcast against ``q`` and ``k``.

    Returns:
        (q_rotated, k_rotated) tuple.
    """
    def _rot(t):
        mid = t.shape[-1] // 2
        return torch.cat([-t[..., mid:], t[..., :mid]], dim=-1)

    return tuple((t * cos) + (_rot(t) * sin) for t in (q, k))
|
|
|
|
| |
|
|
class Attention(nn.Module):
    """
    Grouped Query Attention (GQA) — matches LlamaAttention.

    32 query heads, 8 KV heads (4:1 ratio) with the default config; each KV head
    is shared by ``num_heads // num_kv_heads`` query heads.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.num_kv_heads = config["num_key_value_heads"]
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            theta=config.get("rope_theta", 500000.0),
            rope_scaling=config.get("rope_scaling"),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Causal GQA over cached + current keys.

        Args:
            hidden_states: (B, S, hidden_size).
            position_ids: (B, S) or (1, S) absolute positions used for RoPE.
            attention_mask: optional (B, KV_LEN) key mask, 1=attend, 0=pad, where
                KV_LEN = cached length + S (equal to S without a KV cache).
            past_key_value: optional (K, V), each (B, num_kv_heads, past_len, head_dim).
            use_cache: if True, also return the concatenated (K, V).

        Returns:
            (output (B, S, hidden_size), new (K, V) or None)
        """
        B, S, _ = hidden_states.shape

        # Project to (B, heads, S, head_dim).
        q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(q, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Append the new keys/values to the cache along the sequence axis.
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        # The cache stores the compact (num_kv_heads) tensors, before GQA expansion.
        new_kv = (k, v) if use_cache else None

        if self.num_kv_groups > 1:
            k_expanded = k.repeat_interleave(self.num_kv_groups, dim=1)
            v_expanded = v.repeat_interleave(self.num_kv_groups, dim=1)
        else:
            k_expanded, v_expanded = k, v

        kv_len = k_expanded.shape[2]
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_weights = torch.matmul(q, k_expanded.transpose(-2, -1)) * scale

        # Padding mask: overwrite padded keys with the dtype's finite minimum.
        # (The previous `(1 - mask) * -inf` produced 0 * -inf = NaN at every
        # attended position.) masked_fill cannot overflow, and a finite value keeps
        # fully padded rows from becoming all -inf / NaN after softmax.
        if attention_mask is not None:
            key_is_pad = (attention_mask == 0)[:, None, None, :]
            attn_weights = attn_weights.masked_fill(key_is_pad, torch.finfo(attn_weights.dtype).min)

        # Causal mask: query i (absolute position past_len + i) may not see later keys.
        causal_mask = torch.triu(
            torch.full((S, kv_len), float("-inf"), device=attn_weights.device, dtype=attn_weights.dtype),
            diagonal=kv_len - S + 1,
        )
        attn_weights = attn_weights + causal_mask

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v_expanded)

        # (B, heads, S, head_dim) -> (B, S, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(B, S, self.num_heads * self.head_dim)
        return self.o_proj(attn_output), new_kv
|
|
|
|
| |
|
|
class MLP(nn.Module):
    """SiLU-gated feed-forward block — matches LlamaMLP.

    out = down_proj(silu(gate_proj(x)) * up_proj(x)); all projections are bias-free.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        d_model = config["hidden_size"]
        d_ff = config["intermediate_size"]
        # Creation order (gate, up, down) fixes RNG consumption for seeded init.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
|
|
|
| |
|
|
class TransformerLayer(nn.Module):
    """
    One pre-norm decoder layer whose sub-module names mirror LlamaDecoderLayer.

    Parameter paths are kept identical for weight transplant:
        self_attn.q_proj, self_attn.k_proj, self_attn.v_proj, self_attn.o_proj
        mlp.gate_proj, mlp.up_proj, mlp.down_proj
        input_layernorm, post_attention_layernorm

    Layers with ``layer_idx >= config["donor_layers"]`` (default 16, i.e. layers
    16-23, duplicated from donor 8-15) own a scalar ``layer_gate`` initialised
    to 0.0. The layer's update is scaled by ``layer_gate * sigmoid(layer_gate)``,
    which is 0 at init, so these layers start as pass-throughs and do not
    disturb the hidden-state distribution. Online learning opens the gate.

    Layers listed in ``config["ssm_layers"]`` additionally add the output of an
    ``SSMBlock`` after the (possibly gated) transformer update.
    """

    def __init__(self, config: Dict[str, Any], layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        d_model = config["hidden_size"]
        eps = config.get("rms_norm_eps", 1e-5)
        total_layers = config.get("num_layers", 24)
        n_donor = config.get("donor_layers", 16)

        # Construction order matters: it fixes RNG consumption and parameter order.
        self.self_attn = Attention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(d_model, eps)
        self.post_attention_layernorm = RMSNorm(d_model, eps)

        # Depth-scaled init for the two residual-output projections.
        out_std = 0.02 * (1.0 / math.sqrt(2 * total_layers))
        for out_proj in (self.self_attn.o_proj, self.mlp.down_proj):
            nn.init.normal_(out_proj.weight, std=out_std)

        # Duplicate layers start closed (gate value 0 -> identity).
        self.is_duplicate = layer_idx >= n_donor
        if self.is_duplicate:
            self.layer_gate = nn.Parameter(torch.tensor(0.0))

        # Descriptive label only; not read anywhere in this class.
        pfc_labels = {
            16: "dlPFC (Working Memory)",
            17: "dlPFC (Working Memory)",
            18: "vmPFC (Value/Emotion)",
            19: "vmPFC (Value/Emotion)",
            20: "OFC (Outcome Prediction)",
            21: "vlPFC (Response Inhibition)",
            22: "vlPFC (Response Inhibition)",
            23: "Anterior PFC (Metacognition)",
        }
        self.pfc_region = pfc_labels.get(layer_idx)

        self.has_ssm = layer_idx in config.get("ssm_layers", [])
        if self.has_ssm:
            ssm_labels = {
                4: "Secondary Sensory Cortex",
                8: "Superior Temporal Sulcus",
                12: "Temporal-Parietal Junction",
                16: "Dorsolateral PFC",
                20: "Anterior PFC / Frontal Pole",
            }
            self.ssm = SSMBlock(
                d_model=d_model,
                state_dim=config.get("ssm_state_dim", 64),
                dt=config.get("ssm_dt", 0.01),
                init_std=config.get("side_car_init_std", 1e-5),
            )
            self.ssm.brain_region = ssm_labels.get(layer_idx, f"SSM@L{layer_idx}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run attention + MLP (with residuals), the optional gate, then the optional SSM.

        Returns the new hidden states and the attention's (K, V) cache entry
        (None unless ``use_cache``).
        """
        x_in = hidden_states

        # Pre-norm self-attention sub-block with residual.
        attn_out, new_kv = self.self_attn(
            self.input_layernorm(x_in), position_ids, attention_mask,
            past_key_value=past_key_value, use_cache=use_cache,
        )
        h = x_in + attn_out

        # Pre-norm MLP sub-block with residual.
        h = h + self.mlp(self.post_attention_layernorm(h))

        # Duplicate layers: interpolate between the layer input and its output.
        if self.is_duplicate:
            g = self.layer_gate * torch.sigmoid(self.layer_gate)
            h = x_in + g * (h - x_in)

        # SSM side-car adds its contribution on top.
        if self.has_ssm:
            h = h + self.ssm(h)

        return h, new_kv
|
|
|
|
| |
|
|
class EylaBackbone(nn.Module):
    """
    Eyla V2 Custom Hybrid Backbone.

    Llama-3.2-1B compatible for weight transplant, with custom extensions:
    - SSM side-cars at the layers listed in ``config["ssm_layers"]``
    - Heuristic surprise gates, early exit and complexity estimator (``HeuristicGates``)
    - Optional memory agents (``enable_memory_agents``) and brain orchestrator
      (``enable_brain``), created lazily on the model's current device

    The model works on day 1 after weight transplant with zero training.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()
        if config:
            self.config = config
        else:
            # Deep copy so mutating model.config (e.g. the nested rope_scaling
            # dict or the ssm_layers list) cannot leak into the module default.
            import copy
            self.config = copy.deepcopy(EYLA_V2_CONFIG)

        hidden_size = self.config["hidden_size"]
        num_layers = self.config["num_layers"]
        vocab_size = self.config["vocab_size"]

        self.token_embedding = nn.Embedding(vocab_size, hidden_size)

        self.layers = nn.ModuleList([
            TransformerLayer(self.config, layer_idx=i)
            for i in range(num_layers)
        ])

        self.final_norm = RMSNorm(hidden_size, self.config.get("rms_norm_eps", 1e-5))

        # With tied embeddings the output projection reuses token_embedding.weight.
        if self.config.get("tie_word_embeddings", True):
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Side-car: projects a hidden state to a 256-d memory vector (compress_memory).
        self.memory_compressor = nn.Linear(hidden_size, 256, bias=False)
        nn.init.normal_(self.memory_compressor.weight, std=self.config.get("side_car_init_std", 1e-5))

        # Memory agents are created on demand by enable_memory_agents().
        self.memory_agent_layers = [7, 15]
        self.memory_agents = None
        self._memory_agent_predictions: Dict[Any, torch.Tensor] = {}

        # Fallbacks match EYLA_V2_CONFIG so a partial config behaves like the default.
        self.gates = HeuristicGates(
            surprise_threshold=self.config.get("surprise_threshold", 4.0),
            exit_confidence=self.config.get("early_exit_confidence", 0.9),
            exit_min_layers=self.config.get("early_exit_min_layers", 8),
        )

        # Brain orchestrator is created on demand by enable_brain().
        self.brain = None

    def get_lm_head_weight(self) -> torch.Tensor:
        """Get the output projection weight (handles tied embeddings)."""
        if self.lm_head is not None:
            return self.lm_head.weight
        return self.token_embedding.weight

    def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Raw token embeddings (before any transformer layers)."""
        return self.token_embedding(input_ids)

    def enable_brain(self, config: Optional[Dict[str, Any]] = None):
        """
        Activate the brain orchestrator (86 brain systems).

        All gates start at 0 → day-1 identity preserved.
        Brain params are trainable; donor params should be frozen separately.
        The orchestrator is moved to the model's current device when it is an
        ``nn.Module``, so it can be enabled after the model was placed on a GPU.
        """
        from .brain_orchestrator import BrainOrchestrator
        self.brain = BrainOrchestrator(
            d_model=self.config["hidden_size"],
            state_dim=self.config.get("ssm_state_dim", 64),
            config=config,
        )
        if isinstance(self.brain, nn.Module):
            self.brain.to(self.token_embedding.weight.device)
        brain_summary = self.brain.param_summary()
        logger.info(
            f"Brain enabled: {brain_summary['total_brain_params']:,} params "
            f"(gates: {brain_summary['gate_params']}, "
            f"nn_modules: {brain_summary['nn_module_params']:,})"
        )

    def enable_memory_agents(self):
        """Initialize memory agents at layers 7 and 15 (call after model load to avoid OOM).

        Each agent is a small bottleneck MLP predicting a layer's output from its
        input. The last projection starts near zero (``side_car_init_std``).
        Agents are placed on the model's current device; forward runs them in
        their own parameter dtype.
        """
        hidden_size = self.config["hidden_size"]
        bottleneck = 128
        init_std = self.config.get("side_car_init_std", 1e-5)
        self.memory_agents = nn.ModuleDict({
            str(l): nn.Sequential(
                nn.Linear(hidden_size, bottleneck, bias=False),
                nn.SiLU(),
                nn.Linear(bottleneck, bottleneck, bias=False),
                nn.SiLU(),
                nn.Linear(bottleneck, hidden_size, bias=False),
            ) for l in self.memory_agent_layers
        })
        for key in self.memory_agents:
            nn.init.normal_(self.memory_agents[key][-1].weight, std=init_std)
        # Created after load, so they must follow the rest of the model's device.
        self.memory_agents.to(self.token_embedding.weight.device)
        total = sum(p.numel() for p in self.memory_agents.parameters())
        logger.info(f"Memory agents enabled at layers {self.memory_agent_layers}: {total:,} params")

    def decode_from_hidden(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        start_layer: int = 0,
    ) -> torch.Tensor:
        """
        Run transformer layers from start_layer onward, then output logits.

        Used by MemConsistencyLoss for teacher pass (memory-augmented decode).
        Brain hooks and memory agents are not applied here.

        Args:
            hidden_states: (B, S, d_model)
            attention_mask: (B, S) — 1=attend, 0=pad
            start_layer: skip layers before this index (values <= 0 run all layers)

        Returns:
            logits: (B, S, vocab_size)
        """
        B, S, _ = hidden_states.shape
        position_ids = torch.arange(S, device=hidden_states.device).unsqueeze(0).expand(B, S)

        for layer in self.layers[max(start_layer, 0):]:
            hidden_states, _ = layer(hidden_states, position_ids, attention_mask)

        hidden_states = self.final_norm(hidden_states)
        # Defensive: replace any NaN/inf before the output projection.
        hidden_states = torch.nan_to_num(hidden_states)
        return hidden_states @ self.get_lm_head_weight().T

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        early_exit: bool = False,
        return_hidden_states: bool = False,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Dict[str, Any]:
        """
        Full forward pass.

        Args:
            input_ids: (B, S) input token IDs
            attention_mask: (B, past_len + S) 1=attend, 0=pad (past_len = 0 without cache)
            early_exit: enable heuristic early exit. Ignored when ``use_cache`` is
                True, because exiting early would return fewer KV entries than
                layers and the next cached call could not index them.
            return_hidden_states: return per-layer hidden states
            past_key_values: list of (K, V) tuples per layer for KV cache
            use_cache: if True, return new key_values for caching

        Returns:
            dict with:
                logits: (B, S, vocab_size)
                hidden_states: list of (B, S, d_model) per layer (if requested)
                exit_layer: int — which layer we exited at
                complexity: value returned by gates.complexity.estimate
                last_hidden_state: final-normed hidden states
                brain_state: brain post-forward output (if brain enabled)
                past_key_values: list of (K, V) tuples (if use_cache)
        """
        B, S = input_ids.shape
        device = input_ids.device

        hidden_states = self.token_embedding(input_ids)

        # Positions continue from the cached length (read from layer 0's K cache).
        past_len = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        position_ids = torch.arange(past_len, past_len + S, device=device).unsqueeze(0).expand(B, -1)

        complexity = self.gates.complexity.estimate(hidden_states)

        if self.brain is not None:
            orig_dtype = hidden_states.dtype
            hidden_states = self.brain.pre_layers(hidden_states.float()).to(orig_dtype)

        all_hidden_states = [] if return_hidden_states else None
        new_key_values = [] if use_cache else None
        exit_layer = len(self.layers) - 1
        lm_head_weight = self.get_lm_head_weight()
        allow_early_exit = early_exit and not use_cache

        self._memory_agent_predictions = {}

        for i, layer in enumerate(self.layers):
            tracked = self.memory_agents is not None and i in self.memory_agent_layers

            # Memory agent predicts this layer's output from its input.
            if tracked:
                agent = self.memory_agents[str(i)]
                agent_dtype = agent[0].weight.dtype
                pred = agent(hidden_states.to(agent_dtype)).to(hidden_states.dtype)
                self._memory_agent_predictions[i] = pred

            past_kv = past_key_values[i] if past_key_values is not None else None
            hidden_states, layer_kv = layer(
                hidden_states, position_ids, attention_mask,
                past_key_value=past_kv, use_cache=use_cache,
            )

            # Stored detached; compared against the prediction in get_memory_agent_surprise.
            if tracked:
                self._memory_agent_predictions[f"{i}_actual"] = hidden_states.detach()

            if self.brain is not None:
                ssm_hidden = None
                if layer.has_ssm and hasattr(layer.ssm, 'last_hidden'):
                    ssm_hidden = layer.ssm.last_hidden
                orig_dtype = hidden_states.dtype
                ssm_f = ssm_hidden.float() if ssm_hidden is not None else None
                hidden_states = self.brain.after_layer(i, hidden_states.float(), ssm_f).to(orig_dtype)

            if use_cache:
                new_key_values.append(layer_kv)

            if return_hidden_states:
                all_hidden_states.append(hidden_states.detach())

            if allow_early_exit and i < len(self.layers) - 1:
                should_exit, _confidence = self.gates.early_exit.should_exit(
                    hidden_states, lm_head_weight, i
                )
                if should_exit:
                    exit_layer = i
                    break

        hidden_states = self.final_norm(hidden_states)
        # Defensive: replace any NaN/inf before the output projection.
        hidden_states = torch.nan_to_num(hidden_states)
        logits = hidden_states @ lm_head_weight.T

        brain_state = None
        if self.brain is not None:
            brain_state = self.brain.post_forward(logits.float(), hidden_states.float())

        result = {
            "logits": logits,
            "exit_layer": exit_layer,
            "complexity": complexity,
            "last_hidden_state": hidden_states,
        }
        if brain_state is not None:
            result["brain_state"] = brain_state
        if return_hidden_states:
            result["hidden_states"] = all_hidden_states
        if use_cache:
            result["past_key_values"] = new_key_values
        return result

    def get_memory_agent_surprise(self) -> Dict[int, float]:
        """Get surprise values from last forward pass (predicted vs actual MSE per layer)."""
        surprises = {}
        for layer_idx in self.memory_agent_layers:
            pred = self._memory_agent_predictions.get(layer_idx)
            actual = self._memory_agent_predictions.get(f"{layer_idx}_actual")
            if pred is not None and actual is not None:
                surprises[layer_idx] = F.mse_loss(pred.float(), actual.float()).item()
        return surprises

    def compress_memory(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compress hidden states for memory storage.

        Args:
            hidden_states: (B, S, d_model) or (B, d_model); for 3-D input only
                the last sequence position is used.

        Returns:
            (B, 256) compressed memory vector
        """
        if hidden_states.dim() == 3:
            hidden_states = hidden_states[:, -1, :]
        return self.memory_compressor(hidden_states)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_p: float = 0.9,
        repetition_penalty: float = 1.3,
        eos_token_id: Optional[int] = 128001,
    ) -> torch.Tensor:
        """
        Autoregressive generation with KV cache for fast inference.

        Args:
            input_ids: (B, S) starting tokens
            max_new_tokens: how many tokens to generate
            temperature: sampling temperature; <= 0 selects greedy decoding
            top_p: nucleus sampling threshold
            repetition_penalty: penalize repeated tokens (1.0 = off, >1.0 = penalize),
                applied per batch row to every token already in that row
            eos_token_id: stop once every row emitted this id (None disables)

        Returns:
            (B, S + n) generated tokens, n <= max_new_tokens
        """
        generated = input_ids.clone()
        if max_new_tokens <= 0:
            return generated

        # Prefill: process the whole prompt once and keep its KV cache.
        outputs = self.forward(generated, use_cache=True)
        past_key_values = outputs["past_key_values"]
        # Sample in float32 (copy) regardless of model dtype.
        next_logits = outputs["logits"][:, -1, :].float()

        for step in range(max_new_tokens):
            if repetition_penalty != 1.0:
                # Vectorised over the batch: positive logits shrink, others grow.
                seen = torch.gather(next_logits, 1, generated)
                seen = torch.where(seen > 0, seen / repetition_penalty, seen * repetition_penalty)
                next_logits = next_logits.scatter(1, generated, seen)

            if temperature <= 0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                scaled = next_logits / temperature
                sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                # Drop a token once the mass before it exceeds top_p (top token always kept).
                sorted_mask = cumulative_probs - sorted_probs > top_p
                sorted_logits = sorted_logits.masked_fill(sorted_mask, float("-inf"))
                probs = F.softmax(sorted_logits, dim=-1)
                next_token_sorted = torch.multinomial(probs, num_samples=1)
                next_token = sorted_indices.gather(-1, next_token_sorted)

            generated = torch.cat([generated, next_token], dim=-1)

            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
            if step == max_new_tokens - 1:
                break  # nothing more will be sampled; skip the extra forward pass

            # Incremental step: feed only the new token, reuse the cache.
            outputs = self.forward(next_token, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs["past_key_values"]
            next_logits = outputs["logits"][:, -1, :].float()

        return generated

    def get_side_car_params(self) -> List[nn.Parameter]:
        """Get all side-car parameters (for online learning), including brain params."""
        params = []
        for layer in self.layers:
            if hasattr(layer, "ssm") and layer.has_ssm:
                params.extend(layer.ssm.parameters())
            if layer.is_duplicate and hasattr(layer, "layer_gate"):
                params.append(layer.layer_gate)
        params.extend(self.memory_compressor.parameters())
        if self.memory_agents is not None:
            params.extend(self.memory_agents.parameters())
        if self.brain is not None:
            params.extend(self.brain.get_brain_params())
        return params

    def get_donor_params(self) -> List[nn.Parameter]:
        """Get all donor (transplanted) parameters: everything that is not a side-car."""
        side_car_ids = {id(p) for p in self.get_side_car_params()}
        return [p for p in self.parameters() if id(p) not in side_car_ids]

    def freeze_donor(self):
        """Freeze all donor parameters (requires_grad=False)."""
        for p in self.get_donor_params():
            p.requires_grad = False
        logger.info("Frozen all donor parameters")

    def unfreeze_side_cars(self):
        """Ensure side-car parameters are trainable."""
        for p in self.get_side_car_params():
            p.requires_grad = True
        logger.info("Side-car parameters set to trainable")

    def param_summary(self) -> Dict[str, int]:
        """Count parameters by category."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        side_car = sum(p.numel() for p in self.get_side_car_params())
        donor = total - side_car
        return {
            "total": total,
            "trainable": trainable,
            "frozen": total - trainable,
            "donor": donor,
            "side_car": side_car,
        }
|
|
|
|
def create_eyla_v2(config: Optional[Dict[str, Any]] = None) -> EylaBackbone:
    """Build an ``EylaBackbone`` (default config when ``config`` is None) and log its size.

    Returns:
        The freshly constructed model.
    """
    model = EylaBackbone(config)
    counts = model.param_summary()
    total, donor, side_car = counts["total"], counts["donor"], counts["side_car"]
    logger.info(
        f"Created Eyla V2: {total:,} params "
        f"(donor: {donor:,}, side-car: {side_car:,})"
    )
    return model
|
|