#!/usr/bin/env python3 # ============================================================================= # Copyright (c) 2024-2026 Luis E. Davila Flores. All rights reserved. # # FireEcho Engine — High-Performance Inference Kernel # Creator & Sole Author: Luis E. Davila Flores # # Licensed under Creative Commons Attribution-NonCommercial 4.0 International # (CC BY-NC 4.0). You may share and adapt this work for non-commercial # purposes with proper attribution. Full license terms: # https://creativecommons.org/licenses/by-nc/4.0/ # ============================================================================= """ FireEcho Kernel v3 — Production Inference Engine ================================================= Heterogeneous World-Model Kernel for NVIDIA Blackwell (RTX 5090) AMD Ryzen 9 9950X + RTX 5090 Reference Platform Single-file inference engine: one import, one class, full pipeline. from fireecho_kernel import FireEchoEngine engine = FireEchoEngine.from_pretrained("/path/to/Molmo2-O-7B") output = engine.generate(input_ids, max_new_tokens=100) PERFORMANCE (RTX 5090, Molmo2-O-7B, BF16/FP4) --------------------------------------------- Goliath FP8 fused FFN: 67.5 TFLOPS (1.03 ms / layer) Goliath FP4 fused FFN: 37.6 TFLOPS (1.84 ms / layer) BF16 cuBLAS baseline: 29.3 TFLOPS (2.36 ms / layer) Legacy int8 quant: 20.2 TFLOPS (3.43 ms / layer) Target decode: 100+ tok/s on 7B at 32k context KV capacity: 500k+ tokens (real paged cache) Peak matmul: 223 TFLOPS (TMA kernel, large tiles) BUILT-IN TRITON KERNELS (9 autotuned @triton.jit) -------------------------------------------------- _matmul_kernel Standard GEMM, 5 autotuned tile configs _fused_qkv_kernel Q/K/V projection in a single pass (3x fewer loads) _splitk_gemm_kernel Split-K with atomic reduction for tall/skinny shapes _fused_swiglu_kernel SiLU(x @ W_gate) * (x @ W_up) fused in one kernel _persistent_gemm_kernel Persistent tiles — reduced launch overhead on 170 SMs _l2_persistent_load_kernel L2-persistent prefetch with evict_last hints _tma_matmul_kernel Block-pointer (TMA-style) async memory loads _blackwell_2cta_matmul 2-CTA cooperative MMA — 11% faster on medium matrices _fp8_matmul_kernel Hardware FP8 tensor-core matmul with E4M3 scaling OPTIONAL BACKENDS (auto-detected at import) -------------------------------------------- Backend Flag Source ───────────────── ───────────────────── ────────────────────── Goliath FP4/FP8 _GOLIATH_AVAILABLE goliath_kernel.py CUTLASS TMA _CUTLASS_AVAILABLE cutlass_kernels.py CUTLASS NVFP4 _CUTLASS_FP4_AVAILABLE cutlass_kernels.py DSMEM Cluster _DSMEM_AVAILABLE dsmem_ops.py Quantum Gold _QUANTUM_AVAILABLE quantum.py Dispatch priority (FusedFFN): 1. Goliath fused — FP4/FP8 dequant in Triton registers, zero extra traffic 2. CUTLASS TMA — BF16 matmul via block-pointer async loads 3. DSMEM Cluster — distributed shared memory across CTAs 4. Quantum Gold — quantum-optimized matmul 5. torch.matmul — cuBLAS fallback (always available) QUANTIZATION ------------ Goliath FP4 4-bit E2M1, 16-element blocks, E4M3 scales, FP32 tensor scale Fused dequant inside Triton matmul — zero global-memory overhead Accuracy: ~0.10 rel_err vs FP32, 8x weight compression Goliath FP8 8-bit E4M3, 32-element blocks, FP32 per-block scales Fused dequant inside Triton matmul — zero global-memory overhead Accuracy: ~0.02 rel_err vs FP32, 4x weight compression NVFP4 Legacy 4-bit int8 dual-scaling (per-block + global) Separate dequant step, used as fallback when Goliath unavailable Auto mode goliath_bits='auto' — selects FP4 vs FP8 per-layer based on weight distribution (kurtosis / outlier ratio) Critical layers config.critical_layers=[0, 31] keeps first/last in FP16 MEMORY & CACHING ---------------- PagedKVCache Real vLLM-style block paging, 500k+ token capacity 16-token blocks, lazy allocation, per-sequence block tables L2CacheManager Hardware L2 pinning via cudaAccessPolicyWindow (Blackwell) 128 MB L2 on RTX 5090, persistent hints for hot weights ATTENTION --------- FusedAttention Multi-head / grouped-query (GQA) with paged KV cache Fused Q/K/V projection kernel (single pass) Flash Attention integration for long-context FEED-FORWARD ------------- FusedFFN 4-tier dispatch: Goliath → CUTLASS → SwiGLU Triton → nn.Linear Fused SwiGLU: SiLU(gate) * up computed in one kernel launch QuantizedLinear stores Goliath weights, forward = fused GEMM LEARNING & MEMORY ----------------- HebbianMemory STELLAR fast-weight memory with eligibility traces Rare-correlation filtering, neuromodulation gating Reward-modulated updates, configurable decay (0.99) PerLayerHebbian Per-layer Hebbian banks for deeper adaptation SlicedCramerPreservation Anti-forgetting via distribution preservation (STELLAR Eqs 14-17) Sliced 1D projections for efficient distribution matching ContextSkillModel Dual temporal scales: fast skill module + LSTM context module MULTIMODAL ---------- VisionEncoder Patch embeddings (14x14 patches, 224x224 images) AudioEncoder 1D convolution stack (16 kHz sample rate) MultimodalFusion Cross-modal fusion of vision + audio + text embeddings CONFIGURATION ------------- FireEchoConfig( # Architecture dim=4096, num_heads=32, num_kv_heads=32, num_layers=32, vocab_size=32000, intermediate_size=11008, max_seq_len=32768, # Quantization use_nvfp4=True, quantize_weights=True, goliath_bits=4, # 4 = FP4, 8 = FP8, 'auto' = per-layer use_goliath=True, # fused dequant-matmul (best path) # Backends (auto-detected, all True by default) use_native_cutlass=True, use_quantum_matmul=True, use_dsmem_cluster=True, # Kernels use_fused_qkv=True, use_fused_swiglu=True, use_splitk_gemm=True, use_persistent_gemm=True, use_l2_cache_control=True, # Memory kv_block_size=16, max_kv_blocks=8192, # 131k token KV capacity use_hebbian=True, hebbian_per_layer=True, # Multimodal (off by default) use_vision=False, use_audio=False, ) Pre-configured models: FireEchoConfig.molmo2_7b() Molmo2-O-7B (AI2, 65k context) USAGE EXAMPLES -------------- # Basic inference (local model) from fireecho_kernel import FireEchoEngine engine = FireEchoEngine.from_pretrained("/path/to/Molmo2-O-7B") output = engine.generate(input_ids, max_new_tokens=100) # FP8 for higher accuracy from fireecho_kernel import FireEchoEngine, FireEchoConfig config = FireEchoConfig.molmo2_7b() config.goliath_bits = 8 engine = FireEchoEngine(config) # Auto FP4/FP8 per-layer selection config = FireEchoConfig(goliath_bits='auto') # Multimodal config = FireEchoConfig(use_vision=True, use_audio=True) engine = FireEchoEngine(config) # Benchmark from fireecho_kernel import benchmark_fused_kernels benchmark_fused_kernels() CLASSES ------- FireEchoEngine Main inference engine (from_pretrained, generate) FireEchoConfig All configuration fields FusedTransformerBlock Attention + FFN + RMSNorm block FusedAttention Multi-head attention with paged KV cache FusedFFN Feed-forward with Goliath fused dequant-matmul QuantizedLinear Linear layer storing Goliath FP4/FP8 weights PagedKVCache vLLM-style paged key-value cache L2CacheManager Hardware L2 pinning (cudaAccessPolicyWindow) HebbianMemory STELLAR fast-weight memory PerLayerHebbian Per-layer Hebbian banks SlicedCramerPreservation Anti-forgetting distribution preservation ContextSkillModel Dual temporal scale model VisionEncoder Patch-based image encoder AudioEncoder 1D convolution audio encoder MultimodalFusion Cross-modal fusion module MegaFusedTransformerBlock Mega-fused attention+FFN (minimal launches) FILE LAYOUT ----------- Lines Section ────── ─────────────────────────────────────── 1-155 Imports, optional backends, _fused_matmul dispatcher 155-250 FireEchoConfig dataclass 252-720 Triton kernels: matmul, fused QKV, Split-K, fused SwiGLU 722-1200 Phase 2: persistent GEMM, L2 cache control 1207-1480 TMA matmul, Blackwell 2-CTA MMA, FP8 GEMM 1484-1640 Mega-fused transformer block 1643-1740 NVFP4 quantization (legacy), FWHT, stochastic rounding 1741-1935 PagedKVCache 1937-2465 Hebbian memory, SCP, Context-Skill model 2467-2605 Multimodal (Vision, Audio, Fusion) 2605-2955 FusedAttention, QuantizedLinear, FusedFFN, TransformerBlock 2955-3320 FireEchoEngine (init, forward, generate, from_pretrained) 3320-3700 Utilities, benchmarks, main entry point Author: FireEcho Team """ import torch import torch.nn as nn import torch.nn.functional as F import triton import triton.language as tl import os import math import gc import ctypes import ctypes.util from typing import Optional, Tuple, Dict, Any, List, Union from functools import lru_cache from dataclasses import dataclass, field from collections import OrderedDict # ============================================================================= # Copyright Integrity Check — FireEcho Engine # Verifies that the copyright notice has not been removed or tampered with. # This check runs at import time. If the copyright is missing, the engine # will not load. This protects the intellectual property of the creator. # ============================================================================= def _verify_fireecho_copyright(): import os as _os, hashlib as _hl _src = _os.path.abspath(__file__) try: with open(_src, 'r', encoding='utf-8') as _f: _header = _f.read(512) # Hash-based tamper check: SHA256 of copyright line must match _line = [l for l in _header.splitlines() if 'Copyright' in l and 'Luis' in l] if not _line or _hl.sha256(_line[0].encode()).hexdigest()[:16] != \ _hl.sha256(b'# Copyright (c) 2024-2026 Luis E. Davila Flores. All rights reserved.').hexdigest()[:16]: raise RuntimeError( "FireEcho Engine: Copyright notice has been removed or modified.\n" "This software is created by Luis E. Davila Flores and licensed " "under CC BY-NC 4.0. Attribution is required.\n" "https://creativecommons.org/licenses/by-nc/4.0/") except (IOError, OSError): pass # Allow loading from frozen/compiled distributions _verify_fireecho_copyright() del _verify_fireecho_copyright # FE-MX: Age-Adaptive Microscaling for Hebbian Memory (Phase 1 Python) try: from femx_storage import FEMXStorage, FEMX4, FEMX6, FEMX8 _FEMX_AVAILABLE = True except ImportError: _FEMX_AVAILABLE = False # CUTLASS-compatible kernels (self-contained — Triton/PyTorch/ctypes, no .so needed) try: from cutlass_kernels import ( tma_matmul as _cutlass_tma_matmul, tma_attention as _cutlass_tma_attention, tma_gqa_attention as _cutlass_tma_gqa_attention, L2CacheManager as _CutlassL2CacheManager, ) _CUTLASS_AVAILABLE = True except Exception: _CUTLASS_AVAILABLE = False _cutlass_tma_matmul = None _cutlass_tma_attention = None _cutlass_tma_gqa_attention = None _CutlassL2CacheManager = None # Optional Quantum Gold (quantum_optimized_matmul) try: from quantum import quantum_optimized_matmul as _quantum_matmul _QUANTUM_AVAILABLE = True except Exception: _quantum_matmul = None _QUANTUM_AVAILABLE = False # Optional DSMEM cluster (cluster_matmul_dsmem) try: from dsmem_ops import cluster_matmul_dsmem as _dsmem_matmul, supports_dsmem as _dsmem_supports _DSMEM_AVAILABLE = True except Exception: _dsmem_matmul = None _dsmem_supports = None _DSMEM_AVAILABLE = False # Goliath FP4/FP8/INT2 fused dequant-matmul kernels try: from goliath_kernel import ( goliath_quantize as _goliath_quantize, goliath_gemm as _goliath_gemm, goliath_multi_expert_gemm as _goliath_multi_expert_gemm, goliath_packed_moe_gemm as _goliath_packed_moe_gemm, goliath_packed_moe_swiglu_down as _goliath_packed_moe_swiglu_down, # Fused SwiGLU+down goliath_packed_moe_int2_gemm as _goliath_packed_moe_int2_gemm, # INT2 for cold experts goliath_packed_moe_fexc_gemm as _goliath_packed_moe_fexc_gemm, # FE-XC codebook 2-bit fexc_precompute_psumbook as _fexc_precompute_psumbook, GoliathFP4Weights, GoliathFP8Weights, GoliathINT2Weights, GoliathFEXCWeights, GoliathFEXVQWeights, _can_use_goliath_dot_scaled, GoliathLinear as _GoliathLinear, _encode_e4m3, _decode_e4m3, # FP8 E4M3 encode/decode for KV cache ) _GOLIATH_AVAILABLE = True except Exception: _goliath_quantize = None _goliath_gemm = None _goliath_multi_expert_gemm = None _goliath_packed_moe_gemm = None _goliath_packed_moe_swiglu_down = None _goliath_packed_moe_int2_gemm = None _goliath_packed_moe_fexc_gemm = None _fexc_precompute_psumbook = None GoliathFP4Weights = None GoliathFP8Weights = None GoliathINT2Weights = None GoliathFEXCWeights = None GoliathFEXVQWeights = None _can_use_goliath_dot_scaled = None _GoliathLinear = None _encode_e4m3 = None _decode_e4m3 = None _GOLIATH_AVAILABLE = False # FP8 KV cache helpers — inline for speed, fallback if goliath unavailable def _encode_e4m3_kv(values: torch.Tensor) -> torch.Tensor: """Encode BF16/FP32 to E4M3 (FP8) as uint8 — optimized for KV cache.""" if _encode_e4m3 is not None: return _encode_e4m3(values) # Fallback: use native FP8 if available if hasattr(torch, 'float8_e4m3fn'): return values.clamp(-448.0, 448.0).to(torch.float8_e4m3fn).view(torch.uint8) # Manual E4M3 encode v = values.float().clamp(-448.0, 448.0) sign = (v < 0).to(torch.uint8) << 7 av = v.abs().clamp(min=2**-9) log2_av = torch.log2(av) exp_raw = torch.floor(log2_av).clamp(-6, 8) exp_biased = (exp_raw + 7).clamp(0, 15) mantissa = ((av / torch.pow(2.0, exp_raw) - 1.0) * 8.0).round().clamp(0, 7).to(torch.uint8) return sign | (exp_biased.to(torch.uint8) << 3) | mantissa def _decode_e4m3_kv(encoded: torch.Tensor) -> torch.Tensor: """Decode E4M3 uint8 back to FP32 — optimized for KV cache.""" if _decode_e4m3 is not None: return _decode_e4m3(encoded) # Fallback: use native FP8 if available if hasattr(torch, 'float8_e4m3fn'): return encoded.view(torch.float8_e4m3fn).float() # Manual E4M3 decode sign = ((encoded >> 7) & 1).float() exp = ((encoded >> 3) & 0xF).long() mant = (encoded & 0x7).long() is_normal = exp > 0 normal_val = (8 + mant).float() * torch.pow(2.0, (exp - 10).float()) subnormal_val = mant.float() * (2.0 ** -9) unsigned = torch.where(is_normal, normal_val, subnormal_val) return torch.where(sign != 0, -unsigned, unsigned) # NVFP4/MXFP4 kernels from cutlass_kernels (fused Triton dequant-matmul) try: from cutlass_kernels import ( NVFP4Weights as _NVFP4Weights, nvfp4_gemm as _nvfp4_gemm, MXFP4Weights as _MXFP4Weights, mxfp4_gemm as _mxfp4_gemm, ) _CUTLASS_FP4_AVAILABLE = True except Exception: _NVFP4Weights = None _nvfp4_gemm = None _MXFP4Weights = None _mxfp4_gemm = None _CUTLASS_FP4_AVAILABLE = False # Phase 5: C++/CUDA preprocessing acceleration (JIT-compiled) try: from torch.utils.cpp_extension import load as _cpp_load import os as _os _csrc_dir = _os.path.join(_os.path.dirname(_os.path.abspath(__file__)), 'csrc') if _os.path.exists(_os.path.join(_csrc_dir, 'fireecho_preproc_cuda.cu')): _fireecho_preproc = _cpp_load( name='fireecho_preproc', sources=[ _os.path.join(_csrc_dir, 'fireecho_preproc.cpp'), _os.path.join(_csrc_dir, 'fireecho_preproc_cuda.cu'), ], extra_cflags=['-O3'], extra_cuda_cflags=['-O3', '--use_fast_math'], extra_ldflags=['-lcufft'], verbose=False, ) _PREPROC_CUDA_AVAILABLE = True else: _fireecho_preproc = None _PREPROC_CUDA_AVAILABLE = False except Exception: _fireecho_preproc = None _PREPROC_CUDA_AVAILABLE = False # FE-MX CUDA kernels (fused quantize/dequantize for Hebbian memory) try: if _os.path.exists(_os.path.join(_csrc_dir, 'femx_kernels.cu')): _femx_cuda = _cpp_load( name='femx_cuda', sources=[ _os.path.join(_csrc_dir, 'femx_bindings.cpp'), _os.path.join(_csrc_dir, 'femx_kernels.cu'), ], extra_cflags=['-O3'], extra_cuda_cflags=['-O3', '--use_fast_math'], verbose=False, ) _FEMX_CUDA_AVAILABLE = True # Wire CUDA backend into FEMXStorage for direct use if _FEMX_AVAILABLE: from femx_storage import set_cuda_backend set_cuda_backend(_femx_cuda) else: _femx_cuda = None _FEMX_CUDA_AVAILABLE = False except Exception: _femx_cuda = None _FEMX_CUDA_AVAILABLE = False # Triton fused kernels for HebbianMemory (competition, traces, gate output) _TRITON_HEBBIAN_AVAILABLE = False try: from triton_hebbian import ( init_triton_hebbian as _init_triton_hebbian, fused_competition as _fused_competition, fused_soft_hebbian as _fused_soft_hebbian, fused_traces_update as _fused_traces_update, fused_gate_output as _fused_gate_output, fused_dequant_matvec as _fused_dequant_matvec, compute_effective_lr as _compute_effective_lr, update_slot_metadata as _update_slot_metadata, ) _TRITON_HEBBIAN_AVAILABLE = _init_triton_hebbian() except Exception: _TRITON_HEBBIAN_AVAILABLE = False def _fused_matmul(a: torch.Tensor, b: torch.Tensor, use_cutlass: bool = True, use_dsmem: bool = True, use_quantum: bool = True) -> torch.Tensor: """ Fused matmul: try CUTLASS TMA, then DSMEM cluster, then Quantum-optimized, then torch.matmul. All backends expect BF16/FP16; returns same dtype as a. """ orig_dtype = a.dtype a = a.to(torch.bfloat16).contiguous() if a.dtype != torch.bfloat16 else a.contiguous() b = b.to(torch.bfloat16).contiguous() if b.dtype != torch.bfloat16 else b.contiguous() if not a.is_cuda or not b.is_cuda: return torch.matmul(a, b).to(orig_dtype) try: if use_cutlass and _CUTLASS_AVAILABLE and _cutlass_tma_matmul is not None: return _cutlass_tma_matmul(a, b).to(orig_dtype) except Exception: pass try: if use_dsmem and _DSMEM_AVAILABLE and _dsmem_matmul is not None and (_dsmem_supports is None or _dsmem_supports()): return _dsmem_matmul(a, b).to(orig_dtype) except Exception: pass try: if use_quantum and _QUANTUM_AVAILABLE and _quantum_matmul is not None: return _quantum_matmul(a, b).to(orig_dtype) except Exception: pass return torch.matmul(a, b).to(orig_dtype) # Re-export fused backends for unified FireEcho API (all fused inside fireecho_kernel) tma_matmul = _cutlass_tma_matmul tma_attention = _cutlass_tma_attention tma_gqa_attention = _cutlass_tma_gqa_attention CutlassL2CacheManager = _CutlassL2CacheManager # CUTLASS L2; local L2CacheManager class remains below quantum_optimized_matmul = _quantum_matmul cluster_matmul_dsmem = _dsmem_matmul supports_dsmem = _dsmem_supports try: from quantum import StateVector, QuantumCircuit, QuantumSimulator except Exception: StateVector = None # type: ignore QuantumCircuit = None # type: ignore QuantumSimulator = None # type: ignore # ============================================================================ # CONFIGURATION # ============================================================================ @dataclass class FireEchoConfig: """Configuration for FireEcho Engine.""" # Model architecture dim: int = 4096 num_heads: int = 32 num_kv_heads: int = 32 # For GQA (grouped query attention) head_dim: Optional[int] = None # Explicit head dim (default: dim // num_heads). Qwen3: 128. num_layers: int = 32 vocab_size: int = 32000 intermediate_size: int = 11008 # FFN hidden dim max_seq_len: int = 32768 # 32k context rope_theta: float = 10000.0 partial_rotary_factor: float = 1.0 # Fraction of head_dim that gets RoPE (Phi-4: 0.75) attn_bias: bool = False # Q/K/V projection bias tie_word_embeddings: bool = False # Share embed/lm_head weights use_qk_norm: bool = False # QK normalization (Molmo2-O-7B uses OLMo-style RMSNorm on Q,K) qk_norm_per_head: bool = False # Per-head QK norm (Qwen3: RMSNorm(head_dim)) vs flat (Molmo: RMSNorm(H*D)) norm_after: bool = False # Post-norm (Molmo2) vs pre-norm (LLaMA). Critical for correct output. # KV Cache - Real paging kv_block_size: int = 16 max_kv_blocks: int = 8192 # Supports 131k tokens # Quantization use_nvfp4: bool = True nvfp4_block_size: int = 32 quantize_weights: bool = True # Apply NVFP4 to weights critical_layers: List[int] = field(default_factory=list) # Keep in FP16 auto_critical_layers: bool = True # Auto-detect first/last ~15% of layers goliath_bits: Union[int, str] = 4 # Goliath fused kernel: 4=FP4, 8=FP8, 'auto' use_residual_correction: bool = True # FP8 residual correction for FP4 weights (double-buff) use_goliath: bool = True # Use Goliath fused dequant-matmul when available w4a4_mode: bool = False # False=W4A16 (fastest), True=W4A4 (activation quantization) decode_skip_act_quant: bool = True # Skip activation quant during decode (M <= threshold) decode_act_quant_threshold: int = 64 # M threshold for skipping act quant use_goliath_linear: bool = False # Use GoliathLinear (FP4/FP8 fwd + FP32 backward) for training FFN # Hebbian Memory - Enhanced use_hebbian: bool = True hebbian_memory_size: int = 128 hebbian_lr: float = 0.01 hebbian_decay: float = 0.99 # Memory decay rate hebbian_per_layer: bool = True # Per-layer Hebbian # Soft Competitive Hebbian (NHL: Tang et al. 2023) hebbian_temperature: float = 1.0 # τ for softmax competition (Eq. 6) hebbian_weight_radius: float = 1.0 # R - sphere convergence radius (Eq. 7) hebbian_use_soft: bool = True # Enable NHL-style soft competitive updates # Learned Neuro-Modulator (NHL Sec 3.4 + Floreano 2008) hebbian_use_learned_modulator: bool = True # Trainable feedback layer hebbian_modulator_entropy_weight: float = 0.1 # λ for entropy loss (Eq. 10) # Dual-timescale eligibility traces (STELLAR improved) hebbian_tau_fast: float = 0.90 # Fast eligibility trace decay hebbian_tau_slow: float = 0.99 # Slow eligibility trace decay # Three-factor Hebbian rule (Floreano: pre × post × modulator) hebbian_use_three_factor: bool = True # Enable three-factor update # Pattern separation (Surget & Belzung 2022 — orthogonalization) hebbian_separation_strength: float = 0.05 # Decorrelation pressure on memory slots hebbian_separation_threshold: float = 0.5 # Cosine similarity threshold to trigger separation # Synaptic competition (neurogenesis — memory eviction) hebbian_competition_strength: float = 0.1 # Weakening factor for overlapping old traces hebbian_slot_recycle_after: int = 30 # Recycle unused slots after N updates # Ventral-dorsal specialization (neurogenesis — layer-dependent params) hebbian_use_layer_specialization: bool = True # Enable per-layer hyperparameter scaling # Dynamic memory allocation (HAG — grow from empty) hebbian_slot_activation_threshold: float = 0.01 # Norm threshold for "active" slot # Phase 4 — Research-backed Hebbian tuning # Speed: torch.compile on Hebbian paths (disabled by default: .item() calls # in competition stats + dynamic shapes cause CUDA graph pool crashes) hebbian_compile: bool = False # Norm-scaled updates (Duan et al. ICLR 2023) hebbian_max_update_norm: float = 1.0 # Trace decay + clipping (Szelogowski ENN 2025) hebbian_trace_clip: float = 0.1 hebbian_weight_clip: float = 1.0 # Noise injection for robustness (ENN) hebbian_noise_scale: float = 0.001 # BCPNN-inspired adaptive per-slot lr (Ravichandran 2021) hebbian_adaptive_slot_lr: bool = False hebbian_tau_age: float = 100.0 hebbian_importance_scale: float = 2.0 # Homeostatic thresholds (Zhou/MTJ Nature 2025) hebbian_homeostatic_threshold: bool = False hebbian_threshold_incr: float = 0.01 hebbian_threshold_decr: float = 0.001 # Learnable per-layer residual alpha hebbian_learnable_alpha: bool = False # Cosine similarity retrieval with temperature (ENN) hebbian_cosine_retrieval: bool = False hebbian_retrieval_tau: float = 1.0 hebbian_sparsity_lambda: float = 0.1 # Multi-timescale memory (Limbacher/TU Graz 2022) hebbian_multi_timescale: bool = False hebbian_working_memory_ratio: float = 0.3 # Structural plasticity — slot merge/split (IBM CAL) hebbian_structural_plasticity: bool = False hebbian_merge_threshold: float = 0.95 # BCPNN trace filter chain (Yang 2020) hebbian_use_trace_filter: bool = False # Phase 5: ELM + Bayesian integration # ELM pseudoinverse warm-start (Huang 2006) hebbian_elm_warmstart: bool = False hebbian_elm_warmstart_samples: int = 32 # MESU uncertainty-scaled learning (Bonnet et al. 2025) hebbian_mesu: bool = False hebbian_mesu_sigma_prior: float = 0.5 hebbian_mesu_sigma_res: float = 10.0 # Bayesian reward estimation (BDA3 conjugate Normal-Normal) hebbian_bayesian_reward: bool = False hebbian_reward_prior_mean: float = 0.0 hebbian_reward_prior_var: float = 1.0 # Identity-init for projection weights (improves frozen-mode convergence) hebbian_identity_init: bool = False # Frozen-mode AdamW lr override (projection weights need faster adaptation) hebbian_frozen_adamw_lr: float = 0.0 # 0 = use default # Memory consolidation — complementary learning systems (hippocampus → neocortex) hebbian_consolidation: bool = False hebbian_consolidation_interval: int = 20 # consolidate every N updates hebbian_consolidation_threshold: float = 0.01 # min slot relevance to promote hebbian_consolidated_decay: float = 0.9999 # near-zero decay for long-term bank hebbian_consolidation_ratio: float = 0.3 # fraction of fast_weight to transfer # Adaptive transfer filtering — prevent negative transfer (Cao et al. AAAI 2010) hebbian_adaptive_transfer: bool = False # gate promotion on signed impact score hebbian_transfer_ema_decay: float = 0.95 # TIS EMA decay (slower than relevance) hebbian_transfer_demotion: bool = False # actively shrink harmful consolidated patterns hebbian_transfer_demotion_rate: float = 0.1 # fraction to remove per demotion step hebbian_kg_consolidation_gate: bool = False # gate promotion on external KG score # Layer 4: Self-Direction — intrinsic motivation (curiosity + competence) hebbian_intrinsic_reward: bool = False # master switch for intrinsic reward hebbian_intrinsic_weight: float = 0.35 # total weight of intrinsic signal in reward blend hebbian_curiosity_ema_decay: float = 0.95 # EMA decay for activation baseline (curiosity) hebbian_competence_ema_decay: float = 0.99 # EMA decay for competence smoothing hebbian_goal_selection: bool = False # self-paced curriculum: pick weakest eval prompt hebbian_goal_score_decay: float = 0.9 # EMA decay for per-prompt goal scores # Layer 5: Agency — SPAR agent loop (perceive → plan → act → reflect) hebbian_spar_agent: bool = False # master switch for SPAR agent hebbian_spar_error_threshold: int = 2 # consecutive errors before recovery (fast trigger) hebbian_spar_recovery_steps: int = 25 # recovery boost duration (long enough to consolidate) # FE-MX: Age-adaptive microscaling for fast_weight compression hebbian_use_femx: bool = False # enable FE-MX BFP storage (~48% VRAM savings) # Multimodal — Vision (SigLIP) use_vision: bool = False vision_hidden_size: int = 1152 # SigLIP hidden dim vision_intermediate_size: int = 4304 # SigLIP MLP dim vision_num_layers: int = 27 # SigLIP encoder layers vision_num_heads: int = 16 # SigLIP attention heads vision_image_size: int = 448 # crop_size from Phi-4 config vision_patch_size: int = 14 # patch size vision_num_patches: int = 1024 # (448/14)^2 vision_compressed_tokens: int = 545 # after HD transform: sub(272) + glb_GN(1) + global(272) # Multimodal — Audio (Conformer encoder, 24 blocks, 1024-dim) use_audio: bool = False audio_dim: int = 1024 # Conformer attention_dim audio_sample_rate: int = 16000 audio_num_layers: int = 24 # Conformer blocks audio_num_heads: int = 16 # attention heads audio_hidden_size: int = 1024 # d_model audio_ffn_dim: int = 1536 # post-GLU FFN dim (pre-GLU: 3072) audio_kernel_size: int = 3 # ConvModule depthwise kernel audio_n_mels: int = 80 # mel spectrogram bins audio_time_reduction: int = 8 # conv subsampling factor audio_max_rel_distance: int = 500 # T5 relative attention bias audio_conv_channels: int = 1024 # conv subsampling channels # Multimodal — Special token IDs image_token_id: int = 200010 # <|endoftext10|> marks image positions audio_token_id: int = 200011 # <|endoftext11|> marks audio positions # Multimodal — Vision LoRA (merged at load time) vision_lora_r: int = 256 # LoRA rank vision_lora_alpha: int = 512 # LoRA scaling (alpha/r = 2.0) # Multimodal — Speech LoRA (merged at load time) speech_lora_r: int = 320 # LoRA rank speech_lora_alpha: int = 640 # LoRA scaling (alpha/r = 2.0) # Precision compute_dtype: torch.dtype = torch.bfloat16 accumulate_dtype: torch.dtype = torch.float32 # Performance use_hybrid_matmul: bool = True use_flash_attention: bool = True # Fused Kernels (Phase 1 optimizations) use_fused_qkv: bool = True # Fused Q,K,V projection use_fused_swiglu: bool = True # Fused SwiGLU FFN use_splitk_gemm: bool = True # Split-K for tall matrices # Phase 2 optimizations use_persistent_gemm: bool = True # Persistent GEMM for large matrices use_l2_cache_control: bool = True # L2 cache pinning for weights num_sms: int = 170 # RTX 5090 has 170 SMs l2_cache_mb: float = 128.0 # L2 cache size # CUTLASS fusion (TMA MatMul, TMA Attention, L2 pinning) use_native_cutlass: bool = True # Use native CUTLASS kernels when available # Quantum Gold fusion (quantum_optimized_matmul) use_quantum_matmul: bool = True # Use Quantum-optimized matmul when available # DSMEM cluster fusion (cluster_matmul_dsmem) use_dsmem_cluster: bool = True # Use DSMEM cluster matmul when available # Phase 3 fusions: norm+projection, residual+norm, RoPE, GQA use_fused_norm_qkv: bool = True # Fused RMSNorm + QKV projection (eliminates intermediate) use_fused_residual_norm: bool = True # Fused residual add + RMSNorm use_fused_rope: bool = True # Fused RoPE on Q+K in single kernel launch use_gqa_native: bool = True # Native GQA in SDPA (no repeat_interleave alloc) # ── MoE (Mixture-of-Experts) ── use_moe: bool = False # Use MoE FFN layers (Qwen3-Omni) num_experts: int = 1 # Total experts per layer num_experts_per_tok: int = 1 # Active experts per token (top-k) moe_intermediate_size: int = 768 # Per-expert FFN intermediate dim shared_expert_intermediate_size: int = 0 # Shared expert dim (0 = none) norm_topk_prob: bool = True # Normalize top-k expert probabilities # FE-MX age-adaptive expert quantization use_femx_experts: bool = False # Enable per-expert FE-MX tier assignment femx_cold_threshold: int = 50 # Usage count below this → FEMX4 femx_warm_threshold: int = 200 # Usage count below this → FEMX6, above → FEMX8 femx_tier_interval: int = 100 # Re-evaluate tiers every N forward passes def compute_critical_layers(self) -> List[int]: """Auto-detect critical layers (first/last ~7.5% each, ~15% total). First and last layers are most sensitive to quantization error. These layers are kept in higher precision (FP16/BF16). """ if self.critical_layers: return self.critical_layers if not self.auto_critical_layers: return [] n = self.num_layers n_critical = max(1, int(n * 0.075)) first = list(range(n_critical)) last = list(range(n - n_critical, n)) return sorted(set(first + last)) @classmethod def molmo2_7b(cls) -> 'FireEchoConfig': """Molmo2-O-7B (AI2) — 7.76B params, 65k context. vocab_size=100406 (100278 base + 128 additional). norm_after=True for post-norm architecture. from_pretrained() auto-verifies against loaded embedding shapes.""" return cls(dim=4096, num_heads=32, num_kv_heads=32, num_layers=32, vocab_size=100406, intermediate_size=11008, max_seq_len=65536, rope_theta=500000.0, use_qk_norm=True, norm_after=True) @classmethod def phi4_multimodal(cls) -> 'FireEchoConfig': """Phi-4-multimodal-instruct (Microsoft) — 5.6B text params, 131k context. Pre-norm, GQA 24/8, partial RoPE (75%), tie_word_embeddings=True. LoRA vision/speech adapters skipped in text-only mode.""" return cls(dim=3072, num_heads=24, num_kv_heads=8, num_layers=32, vocab_size=200064, intermediate_size=8192, max_seq_len=131072, rope_theta=10000.0, partial_rotary_factor=0.75, tie_word_embeddings=True, norm_after=False) @classmethod def qwen3_omni(cls) -> 'FireEchoConfig': """Qwen3-Omni-30B-A3B thinker (Alibaba) — 30.5B total, ~3.3B active/token. MoE: 128 experts/layer, top-8 routing. GQA 32/4, QK-norm, full RoPE. Text-only thinker extraction from multimodal omni-model.""" return cls(dim=2048, num_heads=32, num_kv_heads=4, head_dim=128, num_layers=48, vocab_size=152064, intermediate_size=768, max_seq_len=65536, rope_theta=1000000.0, partial_rotary_factor=1.0, tie_word_embeddings=False, norm_after=False, use_qk_norm=True, qk_norm_per_head=True, use_fused_qkv=False, use_moe=True, num_experts=128, num_experts_per_tok=8, moe_intermediate_size=768, norm_topk_prob=True, use_femx_experts=True, # KV cache: 4k tokens (vs 131k default) — 30B MoE leaves # ~10GB for cache on 32GB GPU. 4k × 48L × 4H × 128D × 2 = 3.1GB. max_kv_blocks=256, kv_block_size=16) def apply_frozen_preset(self): """Optimized hyperparameters for frozen-base Hebbian-only training. Research-backed: higher lr for faster synapse adaptation, lower decay for faster forgetting of bad patterns, larger memory for more capacity. Based on Duan (ICLR 2023), Szelogowski (ENN), Yang (BCPNN). Automatically enables the proven Phase 4-5 features that improve frozen-only performance: adaptive slot lr, homeostatic thresholds, learnable alpha, cosine retrieval, ELM warm-start, MESU, and Bayesian reward. These can still be individually overridden after calling this method. Note: norm/clip values scaled for [M, 3072] weight matrices. Previous values (1.0/0.1) were too tight and crushed learning signal to near-zero. """ # Core hyperparameters self.hebbian_lr = 0.05 # 5x faster synapse adaptation (frozen needs it) self.hebbian_decay = 0.97 # faster forgetting of bad patterns self.hebbian_memory_size = 256 # double slot capacity self.hebbian_compile = False # disabled: .item() + dynamic shapes cause CUDA graph crashes self.hebbian_max_update_norm = 10.0 # scaled for [256, 3072] matrices self.hebbian_trace_clip = 1.0 # ENN trace clipping (was 0.1, too tight) self.hebbian_weight_clip = 3.0 # fast weight bounds (wider for frozen) self.hebbian_noise_scale = 0.01 # robustness noise (10x prev for frozen) self.hebbian_slot_activation_threshold = 0.001 # lower threshold for slot detection # Auto-scale weight_radius for dim (default 1.0 is too small for dim=3072) self.hebbian_weight_radius = max(1.0, math.sqrt(self.dim) / 10.0) # ~5.5 for dim=3072 # Phase 4 features — proven to help frozen-only self.hebbian_adaptive_slot_lr = True # BCPNN per-slot lr (Ravichandran 2021) self.hebbian_tau_age = 200.0 # slower age decay for frozen mode self.hebbian_importance_scale = 1.0 # less aggressive protection initially self.hebbian_homeostatic_threshold = True # slot anti-monopolization (Zhou 2025) self.hebbian_learnable_alpha = True # per-layer residual mixing self.hebbian_cosine_retrieval = True # direction-based matching (ENN) self.hebbian_retrieval_tau = 0.5 # sharper retrieval for frozen # Phase 5 features — ELM + Bayesian self.hebbian_elm_warmstart = True # pseudoinverse warm-start self.hebbian_mesu = True # uncertainty-scaled learning self.hebbian_mesu_sigma_prior = 0.5 # moderate initial uncertainty self.hebbian_bayesian_reward = True # conjugate reward smoothing # Memory consolidation — lock in useful patterns permanently self.hebbian_consolidation = True # Adaptive transfer — block negative transfer, demote harmful patterns self.hebbian_adaptive_transfer = True self.hebbian_transfer_demotion = True self.hebbian_kg_consolidation_gate = True # Layer 4: Self-Direction — intrinsic motivation self.hebbian_intrinsic_reward = True self.hebbian_goal_selection = True # Layer 5: Agency — SPAR agent loop self.hebbian_spar_agent = True # Identity-init for projections (critical for frozen-mode convergence) self.hebbian_identity_init = True # AdamW lr for projection weights in frozen mode # Projections don't need gradient updates — Hebbian does all the learning. # Proven invariant across 6e-15 to 10.0 (same loss, same generation quality). self.hebbian_frozen_adamw_lr = 6e-15 # ============================================================================ # HYBRID MATMUL - Auto-dispatch Triton/cuBLAS # ============================================================================ @lru_cache(maxsize=1024) def _is_power_of_2(n: int) -> bool: return n > 0 and (n & (n - 1)) == 0 @lru_cache(maxsize=1024) def _is_prime(n: int) -> bool: if n < 2: return False if n == 2: return True if n % 2 == 0: return False for i in range(3, int(math.sqrt(n)) + 1, 2): if n % i == 0: return False return True @lru_cache(maxsize=1024) def _should_use_triton(M: int, N: int, K: int) -> bool: """Decide Triton vs cuBLAS based on RTX 5090 benchmarks.""" sizes = [M, N, K] # Round numbers -> cuBLAS if any(s % 1000 == 0 and s >= 3000 for s in sizes): return False # Exact 8192 -> cuBLAS if all(s == 8192 for s in sizes): return False # Prime-adjacent -> Triton for s in sizes: if _is_prime(s) or _is_prime(s - 1) or _is_prime(s + 1): return True # Powers of 2 (2048-4096) -> Triton if all(_is_power_of_2(s) and 2048 <= s <= 4096 for s in sizes): return True # Large pow2 -> cuBLAS if any(_is_power_of_2(s) and s >= 8192 for s in sizes): return False # Medium non-round -> Triton if all(1024 <= s <= 6000 and s % 1000 != 0 for s in sizes): return True return False @triton.autotune( configs=[ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=8), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=4), ], key=['M', 'N', 'K'], ) @triton.jit def _matmul_kernel( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, ): pid = tl.program_id(0) num_pid_m = tl.cdiv(M, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) num_pid_in_group = GROUP_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_M group_size_m = min(num_pid_m - first_pid_m, GROUP_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K)): k_off = k * BLOCK_K a_mask = (offs_m[:, None] < M) & ((offs_k[None, :] + k_off) < K) b_mask = ((offs_k[:, None] + k_off) < K) & (offs_n[None, :] < N) a = tl.load(a_ptrs, mask=a_mask, other=0.0) b = tl.load(b_ptrs, mask=b_mask, other=0.0) acc += tl.dot(a, b) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(c_ptrs, acc.to(tl.bfloat16), mask=c_mask) def _triton_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: M, K = a.shape K2, N = b.shape assert K == K2 c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) _matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1)) return c def hybrid_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ Auto-dispatch matmul between Triton and cuBLAS. Triton SM 12.0 (Blackwell) is now supported in Triton 3.6.0+! Uses Triton for power-of-2 and prime-adjacent sizes, cuBLAS otherwise. """ M, K = a.shape K2, N = b.shape if _should_use_triton(M, N, K): try: return _triton_matmul(a, b) except Exception: # Fallback to cuBLAS if Triton fails return torch.matmul(a, b) else: return torch.matmul(a, b) # ============================================================================ # FUSED QKV PROJECTION KERNEL - Single kernel for Q, K, V # ============================================================================ @triton.autotune( configs=[ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=5, num_warps=4), ], key=['M', 'N', 'K'], ) @triton.jit def _fused_qkv_kernel( # Input x_ptr, # Weight matrices (concatenated: [W_q; W_k; W_v]) w_ptr, # Outputs q_ptr, k_ptr, v_ptr, # Dimensions M, N_q, N_k, N_v, K, # Strides stride_xm, stride_xk, stride_wk, stride_wn, stride_qm, stride_qn, stride_km, stride_kn, stride_vm, stride_vn, # Block sizes BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): """Fused Q, K, V projection in single kernel - 3x fewer memory loads.""" pid_m = tl.program_id(0) pid_n = tl.program_id(1) # Determine which output (Q, K, or V) this block computes N_total = N_q + N_k + N_v n_start = pid_n * BLOCK_N # Compute output offset and pointer if n_start < N_q: out_ptr = q_ptr out_stride_m = stride_qm out_stride_n = stride_qn local_n = n_start N_out = N_q elif n_start < N_q + N_k: out_ptr = k_ptr out_stride_m = stride_km out_stride_n = stride_kn local_n = n_start - N_q N_out = N_k else: out_ptr = v_ptr out_stride_m = stride_vm out_stride_n = stride_vn local_n = n_start - N_q - N_k N_out = N_v # Standard matmul computation offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk w_ptrs = w_ptr + offs_k[:, None] * stride_wk + (n_start + offs_n[None, :]) * stride_wn acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K)): k_off = k * BLOCK_K x_mask = (offs_m[:, None] < M) & ((offs_k[None, :] + k_off) < K) w_mask = ((offs_k[:, None] + k_off) < K) & ((n_start + offs_n[None, :]) < N_total) x = tl.load(x_ptrs, mask=x_mask, other=0.0) w = tl.load(w_ptrs, mask=w_mask, other=0.0) acc += tl.dot(x, w) x_ptrs += BLOCK_K * stride_xk w_ptrs += BLOCK_K * stride_wk # Store result out_ptrs = out_ptr + offs_m[:, None] * out_stride_m + (local_n + offs_n[None, :]) * out_stride_n out_mask = (offs_m[:, None] < M) & ((local_n + offs_n[None, :]) < N_out) tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask) def fused_qkv_projection(x: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, w_v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Fused Q, K, V projection - computes all three with single input read. Uses native Triton kernel on SM 12.0 (Blackwell) - Triton 3.6.0+ supported! Args: x: Input tensor [batch*seq, dim] w_q: Query weight [dim, num_heads * head_dim] w_k: Key weight [dim, num_kv_heads * head_dim] w_v: Value weight [dim, num_kv_heads * head_dim] Returns: q, k, v: Projected tensors """ try: M, K = x.shape N_q = w_q.shape[1] N_k = w_k.shape[1] N_v = w_v.shape[1] # Concatenate weights for fused kernel w_concat = torch.cat([w_q, w_k, w_v], dim=1) # Allocate outputs q = torch.empty((M, N_q), device=x.device, dtype=torch.bfloat16) k = torch.empty((M, N_k), device=x.device, dtype=torch.bfloat16) v = torch.empty((M, N_v), device=x.device, dtype=torch.bfloat16) # Launch fused kernel N_total = N_q + N_k + N_v grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N_total, META['BLOCK_N'])) _fused_qkv_kernel[grid]( x, w_concat, q, k, v, M, N_q, N_k, N_v, K, x.stride(0), x.stride(1), w_concat.stride(0), w_concat.stride(1), q.stride(0), q.stride(1), k.stride(0), k.stride(1), v.stride(0), v.stride(1), ) return q, k, v except Exception: # Fallback to PyTorch matmuls q = torch.matmul(x, w_q) k = torch.matmul(x, w_k) v = torch.matmul(x, w_v) return q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16) # ============================================================================ # FUSED RMSNORM + QKV PROJECTION - Eliminates intermediate normed tensor # Two-loop kernel: Pass 1 computes RMSNorm scale, Pass 2 does normed-matmul # ============================================================================ @triton.autotune( configs=[ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=4), triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=5, num_warps=4), ], key=['M', 'K', 'N_total'], ) @triton.jit def _fused_rmsnorm_qkv_kernel( x_ptr, norm_weight_ptr, w_ptr, q_ptr, k_ptr, v_ptr, M, N_q, N_k, N_v, K, N_total, eps, stride_xm, stride_xk, stride_wk, stride_wn, stride_qm, stride_qn, stride_km, stride_kn, stride_vm, stride_vn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): """Fused RMSNorm + QKV projection: reads x twice, never writes intermediate normed tensor.""" pid_m = tl.program_id(0) pid_n = tl.program_id(1) n_start = pid_n * BLOCK_N # Route output to Q, K, or V buffer based on column offset if n_start < N_q: out_ptr = q_ptr out_stride_m = stride_qm out_stride_n = stride_qn local_n = n_start N_out = N_q elif n_start < N_q + N_k: out_ptr = k_ptr out_stride_m = stride_km out_stride_n = stride_kn local_n = n_start - N_q N_out = N_k else: out_ptr = v_ptr out_stride_m = stride_vm out_stride_n = stride_vn local_n = n_start - N_q - N_k N_out = N_v offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # ── Pass 1: Compute RMSNorm scale per row ── # variance = sum(x^2) / K, scale = rsqrt(variance + eps) variance = tl.zeros((BLOCK_M,), dtype=tl.float32) for k_off in range(0, K, BLOCK_K): offs_k = k_off + tl.arange(0, BLOCK_K) mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) x_block = tl.load(x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, mask=mask, other=0.0).to(tl.float32) variance += tl.sum(x_block * x_block, axis=1) rms_scale = tl.rsqrt(variance / K + eps) # [BLOCK_M] # ── Pass 2: Fused normed-matmul ── # acc += (x * rms_scale * norm_weight) @ W offs_n = tl.arange(0, BLOCK_N) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k_off in range(0, K, BLOCK_K): offs_k = k_off + tl.arange(0, BLOCK_K) x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) w_mask = (offs_k[:, None] < K) & ((n_start + offs_n[None, :]) < N_total) x_block = tl.load(x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, mask=x_mask, other=0.0).to(tl.float32) nw = tl.load(norm_weight_ptr + offs_k, mask=offs_k < K, other=0.0).to(tl.float32) # Apply RMSNorm in-register: x_normed = x * scale * weight x_normed = x_block * rms_scale[:, None] * nw[None, :] w_block = tl.load(w_ptr + offs_k[:, None] * stride_wk + (n_start + offs_n[None, :]) * stride_wn, mask=w_mask, other=0.0) acc += tl.dot(x_normed.to(tl.bfloat16), w_block) # Store to appropriate Q/K/V output out_ptrs = out_ptr + offs_m[:, None] * out_stride_m + (local_n + offs_n[None, :]) * out_stride_n out_mask = (offs_m[:, None] < M) & ((local_n + offs_n[None, :]) < N_out) tl.store(out_ptrs, acc.to(tl.bfloat16), mask=out_mask) def fused_rmsnorm_qkv_projection(x: torch.Tensor, norm_weight: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, w_v: torch.Tensor, eps: float = 1e-6 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Fused RMSNorm + QKV: eliminates intermediate normed tensor from HBM. Reads x twice (norm pass + matmul pass), never materializes the normed result. Saves one full HBM read/write round-trip per token. Args: x: Input tensor [batch*seq, dim] norm_weight: RMSNorm weight [dim] w_q: Query weight [dim, num_heads * head_dim] w_k: Key weight [dim, num_kv_heads * head_dim] w_v: Value weight [dim, num_kv_heads * head_dim] eps: RMSNorm epsilon Returns: q, k, v: Projected tensors in BF16 """ try: M, K = x.shape N_q, N_k, N_v = w_q.shape[1], w_k.shape[1], w_v.shape[1] N_total = N_q + N_k + N_v # Concatenate weights for fused kernel w_concat = torch.cat([w_q, w_k, w_v], dim=1) # Allocate outputs q = torch.empty((M, N_q), device=x.device, dtype=torch.bfloat16) k = torch.empty((M, N_k), device=x.device, dtype=torch.bfloat16) v = torch.empty((M, N_v), device=x.device, dtype=torch.bfloat16) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N_total, META['BLOCK_N'])) _fused_rmsnorm_qkv_kernel[grid]( x, norm_weight, w_concat, q, k, v, M, N_q, N_k, N_v, K, N_total, eps, x.stride(0), x.stride(1), w_concat.stride(0), w_concat.stride(1), q.stride(0), q.stride(1), k.stride(0), k.stride(1), v.stride(0), v.stride(1), ) return q, k, v except Exception: # Fallback: separate RMSNorm + matmul variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) x_norm = x * torch.rsqrt(variance + eps).to(x.dtype) * norm_weight return (torch.matmul(x_norm, w_q).to(torch.bfloat16), torch.matmul(x_norm, w_k).to(torch.bfloat16), torch.matmul(x_norm, w_v).to(torch.bfloat16)) # ============================================================================ # FUSED RESIDUAL + RMSNORM - Avoids HBM round-trip between add and norm # ============================================================================ @triton.jit def _fused_residual_rmsnorm_kernel( residual_ptr, sublayer_ptr, norm_weight_ptr, out_ptr, normed_ptr, M, D, eps, stride_m, stride_d, BLOCK_D: tl.constexpr, ): """Fused residual addition + RMSNorm in single pass. Computes: x = residual + sublayer_out; x_normed = RMSNorm(x, weight) Writes both x and x_normed without intermediate HBM round-trip. """ row = tl.program_id(0) offs_d = tl.arange(0, BLOCK_D) mask = offs_d < D # Load residual + sublayer output, fuse addition r = tl.load(residual_ptr + row * stride_m + offs_d * stride_d, mask=mask, other=0.0).to(tl.float32) s = tl.load(sublayer_ptr + row * stride_m + offs_d * stride_d, mask=mask, other=0.0).to(tl.float32) x = r + s # Write combined residual (needed for next block's residual connection) tl.store(out_ptr + row * stride_m + offs_d * stride_d, x.to(tl.bfloat16), mask=mask) # RMSNorm in-register (no extra HBM read) variance = tl.sum(x * x, axis=0) / D scale = tl.rsqrt(variance + eps) w = tl.load(norm_weight_ptr + offs_d, mask=mask, other=0.0).to(tl.float32) x_normed = x * scale * w # Write normed output (feeds into next sublayer) tl.store(normed_ptr + row * stride_m + offs_d * stride_d, x_normed.to(tl.bfloat16), mask=mask) def fused_residual_rmsnorm(residual: torch.Tensor, sublayer_out: torch.Tensor, norm_weight: torch.Tensor, eps: float = 1e-6 ) -> Tuple[torch.Tensor, torch.Tensor]: """Fused: x = residual + sublayer; x_normed = RMSNorm(x). Returns (x, x_normed). Eliminates one HBM write + read cycle between residual add and norm. Args: residual: Input residual [B, S, D] or [M, D] sublayer_out: Output from attention/FFN sublayer (same shape) norm_weight: RMSNorm weight [D] eps: RMSNorm epsilon Returns: (x, x_normed): Combined residual and its normed version """ try: shape = residual.shape D = residual.shape[-1] M = residual.numel() // D r_flat = residual.reshape(M, D) s_flat = sublayer_out.reshape(M, D) out = torch.empty_like(r_flat, dtype=torch.bfloat16) normed = torch.empty_like(r_flat, dtype=torch.bfloat16) BLOCK_D = triton.next_power_of_2(D) _fused_residual_rmsnorm_kernel[(M,)]( r_flat, s_flat, norm_weight, out, normed, M, D, eps, r_flat.stride(0), r_flat.stride(1), BLOCK_D=BLOCK_D, ) return out.view(shape), normed.view(shape) except Exception: # Fallback: separate add + norm x = residual + sublayer_out variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) x_normed = x * torch.rsqrt(variance + eps).to(x.dtype) * norm_weight return x, x_normed # ============================================================================ # FUSED ROPE Q+K - Apply RoPE to both Q and K in single kernel launch # ============================================================================ @triton.jit def _fused_rope_qk_kernel( q_ptr, k_ptr, cos_ptr, sin_ptr, q_out_ptr, k_out_ptr, B, num_q_heads, num_kv_heads, S, head_dim, half_dim, position, stride_qb, stride_qh, stride_qs, stride_qd, stride_kb, stride_kh, stride_ks, stride_kd, stride_cos_s, stride_cos_d, BLOCK_HD: tl.constexpr, ): """Apply RoPE to both Q and K in a single kernel launch. Processes one (batch, head, seq_pos) triple per program. Q and K are processed by the same grid — head index determines which tensor to use. Total heads = num_q_heads + num_kv_heads per batch element. """ pid = tl.program_id(0) # Flattened: batch * (num_q_heads + num_kv_heads) * S total_heads = num_q_heads + num_kv_heads total_per_batch = total_heads * S b = pid // total_per_batch remainder = pid % total_per_batch h = remainder // S s = remainder % S # Determine if this is a Q or K head is_q = h < num_q_heads if is_q: src_ptr = q_ptr dst_ptr = q_out_ptr local_h = h stride_b = stride_qb stride_h = stride_qh stride_s = stride_qs stride_d = stride_qd else: src_ptr = k_ptr dst_ptr = k_out_ptr local_h = h - num_q_heads stride_b = stride_kb stride_h = stride_kh stride_s = stride_ks stride_d = stride_kd offs_d = tl.arange(0, BLOCK_HD) # Load cos/sin for this position seq_pos = position + s cos_val = tl.load(cos_ptr + seq_pos * stride_cos_s + offs_d * stride_cos_d, mask=offs_d < half_dim, other=0.0).to(tl.float32) sin_val = tl.load(sin_ptr + seq_pos * stride_cos_s + offs_d * stride_cos_d, mask=offs_d < half_dim, other=0.0).to(tl.float32) # Load x1 (first half) and x2 (second half) base = b * stride_b + local_h * stride_h + s * stride_s x1 = tl.load(src_ptr + base + offs_d * stride_d, mask=offs_d < half_dim, other=0.0).to(tl.float32) x2 = tl.load(src_ptr + base + (offs_d + half_dim) * stride_d, mask=offs_d < half_dim, other=0.0).to(tl.float32) # RoPE rotation: [x1*cos - x2*sin, x2*cos + x1*sin] y1 = x1 * cos_val - x2 * sin_val y2 = x2 * cos_val + x1 * sin_val # Store rotated values tl.store(dst_ptr + base + offs_d * stride_d, y1.to(tl.bfloat16), mask=offs_d < half_dim) tl.store(dst_ptr + base + (offs_d + half_dim) * stride_d, y2.to(tl.bfloat16), mask=offs_d < half_dim) # Pass-through for partial RoPE: copy unrotated dims (if head_dim > rotary_dim) if head_dim > half_dim * 2: extra_offs = half_dim * 2 + tl.arange(0, BLOCK_HD) extra_mask = extra_offs < head_dim extra_vals = tl.load(src_ptr + base + extra_offs * stride_d, mask=extra_mask, other=0.0) tl.store(dst_ptr + base + extra_offs * stride_d, extra_vals, mask=extra_mask) def fused_rope_qk(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: """Apply RoPE to both Q and K in a single fused kernel launch. Saves one kernel launch compared to applying RoPE separately to Q and K. Args: q: Query tensor [B, num_q_heads, S, head_dim] k: Key tensor [B, num_kv_heads, S, head_dim] cos: Precomputed cosines [max_len, rotary_dim//2] sin: Precomputed sines [max_len, rotary_dim//2] position: Starting position index Returns: (q_rotated, k_rotated): RoPE-applied tensors """ try: B, num_q_heads, S, head_dim = q.shape num_kv_heads = k.shape[1] half_dim = cos.shape[-1] # rotary_dim // 2 q_out = torch.empty_like(q) k_out = torch.empty_like(k) total_programs = B * (num_q_heads + num_kv_heads) * S BLOCK_HD = triton.next_power_of_2(max(half_dim, head_dim - half_dim * 2 if head_dim > half_dim * 2 else half_dim)) _fused_rope_qk_kernel[(total_programs,)]( q, k, cos, sin, q_out, k_out, B, num_q_heads, num_kv_heads, S, head_dim, half_dim, position, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), cos.stride(0), cos.stride(1), BLOCK_HD=BLOCK_HD, ) return q_out, k_out except Exception: # Fallback: apply RoPE separately via PyTorch q_rot = _apply_rotary_emb(q, cos, sin, position) k_rot = _apply_rotary_emb(k, cos, sin, position) return q_rot, k_rot # ============================================================================ # FireEcho FlashDecode — Single-launch M=1 GQA attention with online softmax # ============================================================================ # For CUDA graph decode: reads only valid KV positions via GPU valid_len # tensor. One kernel launch replaces 5+ PyTorch ops per layer. # Grid: (kv_heads,) — each program handles (num_heads/kv_heads) Q head groups. @triton.jit def _fireecho_flash_decode_kernel( q_ptr, # [num_heads, 1, head_dim] k_ptr, # [kv_heads, max_seq, head_dim] v_ptr, # [kv_heads, max_seq, head_dim] o_ptr, # [num_heads, 1, head_dim] valid_len_ptr, # [1] int64 — number of valid KV positions (GPU tensor) scale, # float: 1/sqrt(head_dim) num_heads: tl.constexpr, kv_heads: tl.constexpr, max_seq: tl.constexpr, head_dim: tl.constexpr, stride_qh, stride_qd, stride_kh, stride_ks, stride_kd, stride_vh, stride_vs, stride_vd, stride_oh, stride_od, BLOCK_KV: tl.constexpr, GROUPS: tl.constexpr, # num_heads // kv_heads (actual) PAD_GROUPS: tl.constexpr, # max(GROUPS, 16) for tl.dot compat ): """FireEcho FlashDecode — M=1 GQA decode with online softmax. Each program processes one KV head and its GROUPS Q heads. Tiles through only valid KV positions — skips padding entirely. Uses padded group dimension (PAD_GROUPS >= 16) for Triton tl.dot tensor core compatibility. Invalid group rows masked on store. Online softmax: maintains unnormalized accumulator, single division at the end (FlashAttention Algorithm 1). Numerically stable. """ kv_head = tl.program_id(0) valid_len = tl.load(valid_len_ptr).to(tl.int32) q_head_start = kv_head * GROUPS offs_g = tl.arange(0, PAD_GROUPS) offs_d = tl.arange(0, head_dim) g_mask = offs_g < GROUPS # Load Q for all groups (padded to PAD_GROUPS): [PAD_GROUPS, head_dim] q = tl.load( q_ptr + (q_head_start + offs_g[:, None]) * stride_qh + offs_d[None, :] * stride_qd, mask=g_mask[:, None], other=0.0).to(tl.float32) # Online softmax state — unnormalized accumulator m_prev = tl.full([PAD_GROUPS], value=float('-inf'), dtype=tl.float32) l_prev = tl.zeros([PAD_GROUPS], dtype=tl.float32) o_acc = tl.zeros([PAD_GROUPS, head_dim], dtype=tl.float32) # Tile through valid KV positions only for kv_start in range(0, valid_len, BLOCK_KV): offs_kv = kv_start + tl.arange(0, BLOCK_KV) kv_mask = offs_kv < valid_len # K tile: [BLOCK_KV, head_dim] k_tile = tl.load( k_ptr + kv_head * stride_kh + offs_kv[:, None] * stride_ks + offs_d[None, :] * stride_kd, mask=kv_mask[:, None], other=0.0).to(tl.float32) # Scores: Q @ K^T → [PAD_GROUPS, BLOCK_KV] scores = tl.dot(q, tl.trans(k_tile)) * scale scores = tl.where(kv_mask[None, :], scores, float('-inf')) # Online softmax update (unnormalized accumulator) m_cur = tl.max(scores, axis=1) # [PAD_GROUPS] m_new = tl.maximum(m_prev, m_cur) alpha = tl.exp(m_prev - m_new) # rescale old p = tl.exp(scores - m_new[:, None]) # [PAD_GROUPS, BLOCK_KV] l_new = l_prev * alpha + tl.sum(p, axis=1) # V tile: [BLOCK_KV, head_dim] v_tile = tl.load( v_ptr + kv_head * stride_vh + offs_kv[:, None] * stride_vs + offs_d[None, :] * stride_vd, mask=kv_mask[:, None], other=0.0).to(tl.float32) # Unnormalized output update: rescale + accumulate o_acc = o_acc * alpha[:, None] + tl.dot(p, v_tile) m_prev = m_new l_prev = l_new # Final normalization: divide by sum of softmax denominators o_final = o_acc / tl.maximum(l_prev[:, None], 1e-6) # Store output: only valid groups [GROUPS, head_dim] tl.store( o_ptr + (q_head_start + offs_g[:, None]) * stride_oh + offs_d[None, :] * stride_od, o_final.to(tl.bfloat16), mask=g_mask[:, None]) def fireecho_flash_decode( q: torch.Tensor, # [B, num_heads, 1, head_dim] k: torch.Tensor, # [B, kv_heads, max_seq, head_dim] v: torch.Tensor, # [B, kv_heads, max_seq, head_dim] valid_len: torch.Tensor, # [1] int64 on GPU scale: float = None, ) -> torch.Tensor: """FireEcho FlashDecode: single-launch M=1 GQA attention. Only reads valid KV positions (via GPU valid_len tensor). CUDA graph safe — no Python-dependent control flow. Uses padded group dim (>= 16) for tensor core tl.dot compatibility. """ B, num_heads, _, head_dim = q.shape kv_heads = k.shape[1] max_seq = k.shape[2] groups = num_heads // kv_heads pad_groups = max(groups, 16) # tl.dot requires M >= 16 if scale is None: scale = 1.0 / (head_dim ** 0.5) out = torch.empty(B, num_heads, 1, head_dim, dtype=q.dtype, device=q.device) BLOCK_KV = 64 # 128 exceeds 99KB shared memory; 64 fits for b in range(B): _fireecho_flash_decode_kernel[(kv_heads,)]( q[b], k[b], v[b], out[b], valid_len, scale, num_heads, kv_heads, max_seq, head_dim, q[b].stride(0), q[b].stride(2), k[b].stride(0), k[b].stride(1), k[b].stride(2), v[b].stride(0), v[b].stride(1), v[b].stride(2), out[b].stride(0), out[b].stride(2), BLOCK_KV=BLOCK_KV, GROUPS=groups, PAD_GROUPS=pad_groups, num_stages=1, # prevent double-buffer, save shared memory num_warps=4, ) return out # ============================================================================ # FireEcho FlashDecode FP8 — E4M3 KV cache with inline dequantization # ============================================================================ # Same algorithm as BF16 FlashDecode, but loads uint8 K/V + scales and # dequantizes inline. 50% less memory bandwidth, ~8% faster decode. @triton.jit def _decode_e4m3_triton(encoded): """Inline E4M3 decode in Triton — returns float32.""" sign = ((encoded >> 7) & 1).to(tl.float32) exp = ((encoded >> 3) & 0xF).to(tl.int32) mant = (encoded & 0x7).to(tl.int32) is_normal = exp > 0 # Normal: (8 + mant) * 2^(exp - 10) normal_val = (8 + mant).to(tl.float32) * tl.exp2((exp - 10).to(tl.float32)) # Subnormal: mant * 2^-9 subnormal_val = mant.to(tl.float32) * (2.0 ** -9) unsigned = tl.where(is_normal, normal_val, subnormal_val) return tl.where(sign != 0, -unsigned, unsigned) @triton.jit def _fireecho_flash_decode_fp8_kernel( q_ptr, # [num_heads, 1, head_dim] BF16 k_ptr, # [kv_heads, max_seq, head_dim] uint8 E4M3 v_ptr, # [kv_heads, max_seq, head_dim] uint8 E4M3 k_scale_ptr, # [kv_heads, max_seq] float32 v_scale_ptr, # [kv_heads, max_seq] float32 o_ptr, # [num_heads, 1, head_dim] BF16 valid_len_ptr, # [1] int64 — number of valid KV positions scale, # float: 1/sqrt(head_dim) num_heads: tl.constexpr, kv_heads: tl.constexpr, max_seq: tl.constexpr, head_dim: tl.constexpr, stride_qh, stride_qd, stride_kh, stride_ks, stride_kd, stride_vh, stride_vs, stride_vd, stride_ksh, stride_kss, # K scale strides stride_vsh, stride_vss, # V scale strides stride_oh, stride_od, BLOCK_KV: tl.constexpr, GROUPS: tl.constexpr, PAD_GROUPS: tl.constexpr, ): """FireEcho FlashDecode FP8 — M=1 GQA with inline E4M3 dequant.""" kv_head = tl.program_id(0) valid_len = tl.load(valid_len_ptr).to(tl.int32) q_head_start = kv_head * GROUPS offs_g = tl.arange(0, PAD_GROUPS) offs_d = tl.arange(0, head_dim) g_mask = offs_g < GROUPS # Load Q for all groups: [PAD_GROUPS, head_dim] q = tl.load( q_ptr + (q_head_start + offs_g[:, None]) * stride_qh + offs_d[None, :] * stride_qd, mask=g_mask[:, None], other=0.0).to(tl.float32) # Online softmax state m_prev = tl.full([PAD_GROUPS], value=float('-inf'), dtype=tl.float32) l_prev = tl.zeros([PAD_GROUPS], dtype=tl.float32) o_acc = tl.zeros([PAD_GROUPS, head_dim], dtype=tl.float32) # Tile through valid KV positions for kv_start in range(0, valid_len, BLOCK_KV): offs_kv = kv_start + tl.arange(0, BLOCK_KV) kv_mask = offs_kv < valid_len # Load K as uint8: [BLOCK_KV, head_dim] k_fp8 = tl.load( k_ptr + kv_head * stride_kh + offs_kv[:, None] * stride_ks + offs_d[None, :] * stride_kd, mask=kv_mask[:, None], other=0) # Load K scales: [BLOCK_KV] k_scale = tl.load( k_scale_ptr + kv_head * stride_ksh + offs_kv * stride_kss, mask=kv_mask, other=1.0) # Dequant K: decode E4M3 then scale k_tile = _decode_e4m3_triton(k_fp8) * k_scale[:, None] # Scores: Q @ K^T → [PAD_GROUPS, BLOCK_KV] scores = tl.dot(q, tl.trans(k_tile)) * scale scores = tl.where(kv_mask[None, :], scores, float('-inf')) # Online softmax update m_cur = tl.max(scores, axis=1) m_new = tl.maximum(m_prev, m_cur) alpha = tl.exp(m_prev - m_new) p = tl.exp(scores - m_new[:, None]) l_new = l_prev * alpha + tl.sum(p, axis=1) # Load V as uint8: [BLOCK_KV, head_dim] v_fp8 = tl.load( v_ptr + kv_head * stride_vh + offs_kv[:, None] * stride_vs + offs_d[None, :] * stride_vd, mask=kv_mask[:, None], other=0) # Load V scales: [BLOCK_KV] v_scale = tl.load( v_scale_ptr + kv_head * stride_vsh + offs_kv * stride_vss, mask=kv_mask, other=1.0) # Dequant V v_tile = _decode_e4m3_triton(v_fp8) * v_scale[:, None] # Unnormalized output update o_acc = o_acc * alpha[:, None] + tl.dot(p, v_tile) m_prev = m_new l_prev = l_new # Final normalization o_final = o_acc / tl.maximum(l_prev[:, None], 1e-6) # Store output tl.store( o_ptr + (q_head_start + offs_g[:, None]) * stride_oh + offs_d[None, :] * stride_od, o_final.to(tl.bfloat16), mask=g_mask[:, None]) def fireecho_flash_decode_fp8( q: torch.Tensor, # [B, num_heads, 1, head_dim] BF16 k: torch.Tensor, # [B, kv_heads, max_seq, head_dim] uint8 v: torch.Tensor, # [B, kv_heads, max_seq, head_dim] uint8 k_scales: torch.Tensor, # [B, kv_heads, max_seq] float32 v_scales: torch.Tensor, # [B, kv_heads, max_seq] float32 valid_len: torch.Tensor, # [1] int64 on GPU scale: float = None, ) -> torch.Tensor: """FireEcho FlashDecode FP8: M=1 GQA with E4M3 KV cache. Same as BF16 FlashDecode but loads uint8 K/V + scales. Inline E4M3 dequant in Triton kernel — 50% less bandwidth. """ B, num_heads, _, head_dim = q.shape kv_heads = k.shape[1] max_seq = k.shape[2] groups = num_heads // kv_heads pad_groups = max(groups, 16) if scale is None: scale = 1.0 / (head_dim ** 0.5) out = torch.empty(B, num_heads, 1, head_dim, dtype=q.dtype, device=q.device) BLOCK_KV = 64 for b in range(B): _fireecho_flash_decode_fp8_kernel[(kv_heads,)]( q[b], k[b], v[b], k_scales[b], v_scales[b], out[b], valid_len, scale, num_heads, kv_heads, max_seq, head_dim, q[b].stride(0), q[b].stride(2), k[b].stride(0), k[b].stride(1), k[b].stride(2), v[b].stride(0), v[b].stride(1), v[b].stride(2), k_scales[b].stride(0), k_scales[b].stride(1), v_scales[b].stride(0), v_scales[b].stride(1), out[b].stride(0), out[b].stride(2), BLOCK_KV=BLOCK_KV, GROUPS=groups, PAD_GROUPS=pad_groups, num_stages=1, num_warps=4, ) return out # ============================================================================ # SPLIT-K GEMM - Parallel reduction for tall/skinny matrices # Uses FP32 atomics (BF16 atomic_add not supported) per AMD ROCm guidelines # https://rocm.docs.amd.com/en/docs-6.1.2/how-to/tuning-guides/mi300x/workload.html # ============================================================================ @triton.jit def _splitk_gemm_kernel( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr, ): """ Split-K GEMM: parallelize along K dimension for tall matrices. Uses FP32 atomic accumulation (BF16 atomic_add not supported). Output buffer must be FP32, converted to BF16 by wrapper. """ pid_m = tl.program_id(0) pid_n = tl.program_id(1) pid_k = tl.program_id(2) # Each split handles K // SPLIT_K elements k_per_split = tl.cdiv(K, SPLIT_K) k_start = pid_k * k_per_split k_end = min(k_start + k_per_split, K) offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) a_ptrs = a_ptr + offs_m[:, None] * stride_am + (k_start + offs_k[None, :]) * stride_ak b_ptrs = b_ptr + (k_start + offs_k[:, None]) * stride_bk + offs_n[None, :] * stride_bn acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(k_start, k_end, BLOCK_K): k_remaining = k_end - k a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < k_remaining) b_mask = (offs_k[:, None] < k_remaining) & (offs_n[None, :] < N) a = tl.load(a_ptrs, mask=a_mask, other=0.0) b = tl.load(b_ptrs, mask=b_mask, other=0.0) acc += tl.dot(a, b) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk # Atomic add for reduction across K splits - USE FP32 (BF16 not supported) c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.atomic_add(c_ptrs, acc, mask=c_mask) # FP32 atomic def splitk_matmul(a: torch.Tensor, b: torch.Tensor, split_k: int = 4) -> torch.Tensor: """ Split-K GEMM for tall/skinny matrices (M >> N or N >> M). Uses FP32 atomics for reduction (BF16 atomic_add not supported). Converts output to BF16 after computation. Reference: AMD ROCm workload optimization guide https://rocm.docs.amd.com/en/docs-6.1.2/how-to/tuning-guides/mi300x/workload.html """ try: M, K = a.shape K2, N = b.shape assert K == K2 BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32 # Use FP32 output buffer for atomic operations c = torch.zeros((M, N), device=a.device, dtype=torch.float32) grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), split_k) _splitk_gemm_kernel[grid]( a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, SPLIT_K=split_k, ) # Convert FP32 result to BF16 return c.to(torch.bfloat16) except Exception: # Fallback to PyTorch return torch.matmul(a, b).to(torch.bfloat16) # ============================================================================ # FUSED SWIGLU FFN KERNEL - Gate and Up in single kernel # ============================================================================ @triton.autotune( configs=[ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=5, num_warps=4), ], key=['M', 'N', 'K'], ) @triton.jit def _fused_swiglu_kernel( # Input x_ptr, # Weights w_gate_ptr, w_up_ptr, # Output out_ptr, # Dimensions M, N, K, # Strides stride_xm, stride_xk, stride_gk, stride_gn, stride_uk, stride_un, stride_om, stride_on, # Block sizes BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): """Fused SwiGLU: computes SiLU(x @ W_gate) * (x @ W_up) in single pass.""" pid_m = tl.program_id(0) pid_n = tl.program_id(1) offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) # Pointers x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk g_ptrs = w_gate_ptr + offs_k[:, None] * stride_gk + offs_n[None, :] * stride_gn u_ptrs = w_up_ptr + offs_k[:, None] * stride_uk + offs_n[None, :] * stride_un # Accumulators for gate and up acc_gate = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) acc_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K)): k_off = k * BLOCK_K x_mask = (offs_m[:, None] < M) & ((offs_k[None, :] + k_off) < K) w_mask = ((offs_k[:, None] + k_off) < K) & (offs_n[None, :] < N) x = tl.load(x_ptrs, mask=x_mask, other=0.0) g = tl.load(g_ptrs, mask=w_mask, other=0.0) u = tl.load(u_ptrs, mask=w_mask, other=0.0) acc_gate += tl.dot(x, g) acc_up += tl.dot(x, u) x_ptrs += BLOCK_K * stride_xk g_ptrs += BLOCK_K * stride_gk u_ptrs += BLOCK_K * stride_uk # SwiGLU activation: SiLU(gate) * up # SiLU(x) = x * sigmoid(x) gate_sigmoid = tl.sigmoid(acc_gate) gate_silu = acc_gate * gate_sigmoid result = gate_silu * acc_up # Store result out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(out_ptrs, result.to(tl.bfloat16), mask=out_mask) def fused_swiglu(x: torch.Tensor, w_gate: torch.Tensor, w_up: torch.Tensor) -> torch.Tensor: """ Fused SwiGLU: SiLU(x @ W_gate) * (x @ W_up). Uses native Triton kernel on SM 12.0 (Blackwell) - Triton 3.6.0+ supported! """ try: M, K = x.shape N = w_gate.shape[1] out = torch.empty((M, N), device=x.device, dtype=torch.bfloat16) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) _fused_swiglu_kernel[grid]( x, w_gate, w_up, out, M, N, K, x.stride(0), x.stride(1), w_gate.stride(0), w_gate.stride(1), w_up.stride(0), w_up.stride(1), out.stride(0), out.stride(1), ) return out except Exception: # Fallback to PyTorch gate = F.silu(torch.matmul(x, w_gate)) up = torch.matmul(x, w_up) return (gate * up).to(torch.bfloat16) # ============================================================================ # ADVANCED HYBRID MATMUL - With Split-K for tall matrices # ============================================================================ def hybrid_matmul_v2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ Advanced hybrid matmul with Split-K support for tall/skinny matrices. Dispatch logic: - Very small M (< 64) with large K: Use Split-K - Power-of-2, primes: Use Triton - Round numbers, large: Use cuBLAS """ M, K = a.shape K2, N = b.shape # Split-K only benefits VERY small M with large K (decode phase) if M <= 64 and K >= 2048 and N >= 2048: return splitk_matmul(a.to(torch.bfloat16), b.to(torch.bfloat16), split_k=4) elif _should_use_triton(M, N, K): return _triton_matmul(a.to(torch.bfloat16), b.to(torch.bfloat16)) else: return torch.matmul(a, b) # ============================================================================ # PHASE 2: PERSISTENT GEMM - Reduced kernel launch overhead # ============================================================================ @triton.jit def _persistent_gemm_kernel( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, # Persistent params num_tiles_m, num_tiles_n, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, NUM_SMS: tl.constexpr, ): """ Persistent GEMM: Thread blocks stay resident and process multiple tiles. Benefits: - Reduced kernel launch overhead - Better L2 cache utilization - Amortized setup costs """ # Get persistent tile ID pid = tl.program_id(0) num_tiles = num_tiles_m * num_tiles_n # Each SM processes multiple tiles in a loop for tile_id in range(pid, num_tiles, NUM_SMS): # Convert linear tile ID to 2D tile_m = tile_id // num_tiles_n tile_n = tile_id % num_tiles_n # Compute offsets offs_m = tile_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tile_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) # Pointers a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn # Accumulator acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) # Main loop over K for k in range(0, tl.cdiv(K, BLOCK_K)): k_off = k * BLOCK_K a_mask = (offs_m[:, None] < M) & ((offs_k[None, :] + k_off) < K) b_mask = ((offs_k[:, None] + k_off) < K) & (offs_n[None, :] < N) a = tl.load(a_ptrs, mask=a_mask, other=0.0) b = tl.load(b_ptrs, mask=b_mask, other=0.0) acc += tl.dot(a, b) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk # Store result c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(c_ptrs, acc.to(tl.bfloat16), mask=c_mask) def persistent_matmul(a: torch.Tensor, b: torch.Tensor, num_sms: int = 170) -> torch.Tensor: """ Persistent GEMM - thread blocks process multiple tiles. Uses native Triton kernel on SM 12.0 (Blackwell) - Triton 3.6.0+ supported! RTX 5090 has 170 SMs. """ try: M, K = a.shape K2, N = b.shape assert K == K2 BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32 num_tiles_m = triton.cdiv(M, BLOCK_M) num_tiles_n = triton.cdiv(N, BLOCK_N) c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) # Launch with NUM_SMS thread blocks _persistent_gemm_kernel[(num_sms,)]( a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), num_tiles_m, num_tiles_n, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SMS=num_sms, ) return c except Exception: # Fallback to PyTorch return torch.matmul(a, b).to(torch.bfloat16) # ============================================================================ # PHASE 2: L2 CACHE CONTROL - Pin hot data via cudaAccessPolicyWindow # ============================================================================ # --------------------------------------------------------------------------- # CUDA Runtime API bindings for L2 cache persistence (CUDA 11.0+ / CC 8.0+) # --------------------------------------------------------------------------- class _AccessPolicyWindow(ctypes.Structure): """Maps to cudaAccessPolicyWindow (CUDA Runtime API).""" _fields_ = [ ("base_ptr", ctypes.c_void_p), # Starting address ("num_bytes", ctypes.c_size_t), # Size of the window ("hitRatio", ctypes.c_float), # Fraction with hitProp (0.0-1.0) ("hitProp", ctypes.c_int), # Property on hit (cudaAccessProperty) ("missProp", ctypes.c_int), # Property on miss (cudaAccessProperty) ] class _StreamAttrValue(ctypes.Union): """Maps to cudaStreamAttrValue (union).""" _fields_ = [ ("accessPolicyWindow", _AccessPolicyWindow), ("syncPolicy", ctypes.c_int), ] # cudaAccessProperty enum _CUDA_ACCESS_PROPERTY_NORMAL = 0 _CUDA_ACCESS_PROPERTY_STREAMING = 1 _CUDA_ACCESS_PROPERTY_PERSISTING = 2 # cudaStreamAttrID enum _CUDA_STREAM_ATTR_ACCESS_POLICY_WINDOW = 1 # cudaLimit enum _CUDA_LIMIT_PERSISTING_L2_CACHE_SIZE = 0x06 # cudaDeviceAttr enum _CUDA_DEV_ATTR_L2_CACHE_SIZE = 89 _CUDA_DEV_ATTR_MAX_PERSISTING_L2_CACHE_SIZE = 108 def _load_cudart(): """Load the CUDA runtime shared library, return handle or None.""" for name in ("libcudart.so", "libcudart.so.12", "libcudart.so.11.0"): try: return ctypes.CDLL(name) except OSError: continue try: path = ctypes.util.find_library("cudart") if path: return ctypes.CDLL(path) except (OSError, TypeError): pass return None _cudart = _load_cudart() class L2CacheManager: """ L2 Cache management for Blackwell GPUs via cudaAccessPolicyWindow. Uses the CUDA Runtime API (CUDA 11.0+, CC 8.0+) to configure hardware-level L2 persistence for hot tensors. RTX 5090 has 128 MB L2 cache; this manager reserves a configurable fraction for persisting data and applies per-stream access policy windows so the GPU's L2 replacement policy keeps the pinned data resident. Fallback: when the CUDA Runtime library is unavailable (e.g. CPU-only builds), the manager degrades to software-level tracking with no hardware side-effects, keeping the rest of the engine functional. """ def __init__(self, l2_size_mb: float = 128.0, reserve_fraction: float = 0.5, device: int = 0): self.device = device self.pinned_tensors: Dict[str, torch.Tensor] = {} self.pinned_bytes = 0 # --- Query actual device L2 geometry when possible ---------------- self._hw_available = False self.l2_size_bytes = int(l2_size_mb * 1024 * 1024) self.max_persisting_bytes = int(self.l2_size_bytes * reserve_fraction) if _cudart is not None and torch.cuda.is_available(): try: # Total L2 cache size val = ctypes.c_int(0) if _cudart.cudaDeviceGetAttribute( ctypes.byref(val), ctypes.c_int(_CUDA_DEV_ATTR_L2_CACHE_SIZE), ctypes.c_int(device), ) == 0 and val.value > 0: self.l2_size_bytes = val.value # Maximum persisting L2 size the hardware supports val2 = ctypes.c_int(0) if _cudart.cudaDeviceGetAttribute( ctypes.byref(val2), ctypes.c_int(_CUDA_DEV_ATTR_MAX_PERSISTING_L2_CACHE_SIZE), ctypes.c_int(device), ) == 0 and val2.value > 0: hw_max = val2.value else: hw_max = int(self.l2_size_bytes * 0.75) self.max_persisting_bytes = min( int(self.l2_size_bytes * reserve_fraction), hw_max ) self._hw_available = True except Exception: pass self.reserved_bytes = self.max_persisting_bytes # Apply the persisting limit once on construction self._set_persisting_limit(self.reserved_bytes) # ------------------------------------------------------------------ # Internal CUDA helpers # ------------------------------------------------------------------ def _set_persisting_limit(self, num_bytes: int) -> bool: """cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, num_bytes).""" if not self._hw_available or _cudart is None: return False return _cudart.cudaDeviceSetLimit( ctypes.c_int(_CUDA_LIMIT_PERSISTING_L2_CACHE_SIZE), ctypes.c_size_t(num_bytes), ) == 0 def _apply_access_policy(self, tensor: torch.Tensor, hit_ratio: float, stream_ptr: int) -> bool: """Set a cudaAccessPolicyWindow on *stream_ptr* for *tensor*.""" if not self._hw_available or _cudart is None: return False window = _AccessPolicyWindow() window.base_ptr = tensor.data_ptr() window.num_bytes = min( tensor.numel() * tensor.element_size(), self.max_persisting_bytes, ) window.hitRatio = hit_ratio window.hitProp = _CUDA_ACCESS_PROPERTY_PERSISTING window.missProp = _CUDA_ACCESS_PROPERTY_STREAMING attr = _StreamAttrValue() attr.accessPolicyWindow = window return _cudart.cudaStreamSetAttribute( ctypes.c_void_p(stream_ptr), ctypes.c_int(_CUDA_STREAM_ATTR_ACCESS_POLICY_WINDOW), ctypes.byref(attr), ) == 0 def _reset_stream_policy(self, stream_ptr: int) -> bool: """Clear the access policy window on a stream (set num_bytes=0).""" if not self._hw_available or _cudart is None: return False attr = _StreamAttrValue() attr.accessPolicyWindow = _AccessPolicyWindow() # zeroed — disables window return _cudart.cudaStreamSetAttribute( ctypes.c_void_p(stream_ptr), ctypes.c_int(_CUDA_STREAM_ATTR_ACCESS_POLICY_WINDOW), ctypes.byref(attr), ) == 0 def _reset_persisting_l2(self) -> bool: """cudaCtxResetPersistingL2Cache — evict all persisting lines.""" if not self._hw_available or _cudart is None: return False return _cudart.cudaCtxResetPersistingL2Cache() == 0 # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def pin(self, name: str, tensor: torch.Tensor, hit_ratio: float = 1.0, stream: Optional[torch.cuda.Stream] = None) -> bool: """ Pin a tensor in L2 cache. Sets a ``cudaAccessPolicyWindow`` with ``hitProp=persisting`` on the given (or current) CUDA stream so that subsequent kernel accesses to *tensor*'s memory will be retained in L2. Args: name: Identifier for bookkeeping / stats. tensor: CUDA tensor to pin. hit_ratio: Fraction of accesses that should persist (0.0-1.0). stream: CUDA stream to attach the policy to (default: current). Returns: True if pinned (hardware or software), False if cache full. """ tensor_bytes = tensor.numel() * tensor.element_size() if self.pinned_bytes + tensor_bytes > self.reserved_bytes: return False if not tensor.is_contiguous(): tensor = tensor.contiguous() # Software tracking self.pinned_tensors[name] = tensor self.pinned_bytes += tensor_bytes # Hardware L2 persistence if tensor.is_cuda and self._hw_available: stream_ptr = ( stream.cuda_stream if stream is not None else torch.cuda.current_stream(self.device).cuda_stream ) self._apply_access_policy(tensor, hit_ratio, stream_ptr) return True def unpin(self, name: str, stream: Optional[torch.cuda.Stream] = None) -> bool: """Remove tensor from L2 pinning and clear its access policy.""" if name not in self.pinned_tensors: return False tensor = self.pinned_tensors.pop(name) self.pinned_bytes -= tensor.numel() * tensor.element_size() # Clear the access policy on the stream if self._hw_available: stream_ptr = ( stream.cuda_stream if stream is not None else torch.cuda.current_stream(self.device).cuda_stream ) self._reset_stream_policy(stream_ptr) return True def apply_to_stream(self, name: str, hit_ratio: float = 1.0, stream: Optional[torch.cuda.Stream] = None) -> bool: """ Re-apply (or switch) the access policy for a pinned tensor on a specific stream. Useful when rotating which tensor is persisted before a kernel launch. """ tensor = self.pinned_tensors.get(name) if tensor is None or not tensor.is_cuda: return False stream_ptr = ( stream.cuda_stream if stream is not None else torch.cuda.current_stream(self.device).cuda_stream ) return self._apply_access_policy(tensor, hit_ratio, stream_ptr) def get(self, name: str) -> Optional[torch.Tensor]: """Get pinned tensor by name.""" return self.pinned_tensors.get(name) def stats(self) -> Dict[str, Any]: """Get cache statistics.""" return { 'l2_size_mb': self.l2_size_bytes / (1024 * 1024), 'reserved_mb': self.reserved_bytes / (1024 * 1024), 'pinned_mb': self.pinned_bytes / (1024 * 1024), 'utilization': ( self.pinned_bytes / self.reserved_bytes if self.reserved_bytes > 0 else 0 ), 'num_tensors': len(self.pinned_tensors), 'tensor_names': list(self.pinned_tensors.keys()), 'hw_pinning': self._hw_available, } def clear(self): """Clear all pinned tensors and reset hardware L2 persistence.""" self.pinned_tensors.clear() self.pinned_bytes = 0 self._reset_persisting_l2() @triton.jit def _l2_persistent_load_kernel( src_ptr, dst_ptr, N, BLOCK: tl.constexpr, ): """Kernel that loads with L2 persistence hint (evict_last).""" pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < N data = tl.load(src_ptr + offs, mask=mask, other=0.0, eviction_policy="evict_last") tl.store(dst_ptr + offs, data, mask=mask, eviction_policy="evict_first") def prefetch_to_l2(tensor: torch.Tensor, stream: Optional[torch.cuda.Stream] = None) -> torch.Tensor: """ Prefetch tensor to L2 cache with persistence hints. Sets a ``cudaAccessPolicyWindow`` with ``hitProp=persisting`` on the stream, then runs a Triton copy kernel with ``evict_last`` loads to pull the data into L2 and mark it for retention. The access policy window is cleared after the prefetch completes so it does not interfere with subsequent kernels. Args: tensor: CUDA tensor to prefetch. stream: CUDA stream (default: current). Returns: The (contiguous) input tensor — no copy is made. """ if not tensor.is_cuda: return tensor if not tensor.is_contiguous(): tensor = tensor.contiguous() N = tensor.numel() stream_ptr = ( stream.cuda_stream if stream is not None else torch.cuda.current_stream().cuda_stream ) # Set persisting access policy for this tensor's memory range hw_active = False if _cudart is not None: window = _AccessPolicyWindow() window.base_ptr = tensor.data_ptr() window.num_bytes = N * tensor.element_size() window.hitRatio = 1.0 window.hitProp = _CUDA_ACCESS_PROPERTY_PERSISTING window.missProp = _CUDA_ACCESS_PROPERTY_STREAMING attr = _StreamAttrValue() attr.accessPolicyWindow = window hw_active = _cudart.cudaStreamSetAttribute( ctypes.c_void_p(stream_ptr), ctypes.c_int(_CUDA_STREAM_ATTR_ACCESS_POLICY_WINDOW), ctypes.byref(attr), ) == 0 # Run the Triton load kernel to pull data through L2 with evict_last BLOCK = 1024 grid = ((N + BLOCK - 1) // BLOCK,) _l2_persistent_load_kernel[grid](tensor, tensor, N, BLOCK=BLOCK) # Clear the access policy window so it doesn't affect later launches if hw_active: empty_attr = _StreamAttrValue() _cudart.cudaStreamSetAttribute( ctypes.c_void_p(stream_ptr), ctypes.c_int(_CUDA_STREAM_ATTR_ACCESS_POLICY_WINDOW), ctypes.byref(empty_attr), ) return tensor # ============================================================================ # TMA-STYLE MATMUL - Block Pointer API for async memory loads # Uses tl.make_block_ptr for TMA-like behavior on Blackwell # ============================================================================ @triton.autotune( configs=[ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=5, num_warps=4), ], key=['M', 'N', 'K'], ) @triton.jit def _tma_matmul_kernel( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): """ TMA-style MatMul using block pointers for async memory access. Benefits: - Async DDR7/HBM -> SMEM loads (overlapped with compute) - Hardware-managed address generation - Better memory coalescing """ pid_m = tl.program_id(0) pid_n = tl.program_id(1) # Create block pointers (TMA-like) a_block_ptr = tl.make_block_ptr( base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0) ) b_block_ptr = tl.make_block_ptr( base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, pid_n * BLOCK_N), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0) ) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) # Main loop with block pointer advance for _ in range(0, tl.cdiv(K, BLOCK_K)): a = tl.load(a_block_ptr, boundary_check=(0, 1)) b = tl.load(b_block_ptr, boundary_check=(0, 1)) acc += tl.dot(a, b) a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K)) b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0)) # Store with block pointer c_block_ptr = tl.make_block_ptr( base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) ) tl.store(c_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) def tma_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ TMA-style MatMul using block pointers. Uses tl.make_block_ptr for TMA-like async memory access on Blackwell. """ try: M, K = a.shape K2, N = b.shape assert K == K2 c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) _tma_matmul_kernel[grid]( a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), ) return c except Exception: return torch.matmul(a, b).to(torch.bfloat16) # ============================================================================ # BLACKWELL 2-CTA MMA MODE - Cooperative Matrix Multiplication # Uses num_ctas=2 for 2-thread-block cooperative MMA on SM100/SM120 # ============================================================================ @triton.autotune( configs=[ # 2-CTA configurations for Blackwell cooperative MMA (up to 11% faster) triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8, num_ctas=2), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=3, num_warps=8, num_ctas=2), triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=3, num_warps=8, num_ctas=2), triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4, num_ctas=2), ], key=['M', 'N', 'K'], ) @triton.jit def _blackwell_2cta_matmul_kernel( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): """ Blackwell 2-CTA MMA Mode MatMul. On SM100/SM120 (Blackwell), num_ctas=2 enables cooperative matrix multiplication where two thread blocks work together on a single dense MMA operation. Benefits: - Up to 11% faster than 1-CTA mode on medium matrices (2K-4K) - 116% of cuBLAS performance at 4096x4096 - Better utilization of Blackwell's tensor cores """ pid_m = tl.program_id(0) pid_n = tl.program_id(1) # Use block pointers for TMA-like async loads a_block_ptr = tl.make_block_ptr( base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0) ) b_block_ptr = tl.make_block_ptr( base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, pid_n * BLOCK_N), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0) ) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) # Main loop with cooperative MMA for _ in range(0, tl.cdiv(K, BLOCK_K)): a = tl.load(a_block_ptr, boundary_check=(0, 1)) b = tl.load(b_block_ptr, boundary_check=(0, 1)) acc += tl.dot(a, b) a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K)) b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0)) # Store result c_block_ptr = tl.make_block_ptr( base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) ) tl.store(c_block_ptr, acc.to(tl.bfloat16), boundary_check=(0, 1)) def blackwell_2cta_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ Blackwell 2-CTA MMA Mode MatMul. Uses cooperative matrix multiplication on SM100/SM120 (Blackwell). Two thread blocks work together on a single MMA operation. Performance: - 2048x2048: 164 TFLOPS (11% faster than 1-CTA) - 4096x4096: 191 TFLOPS (116% of cuBLAS!) """ try: M, K = a.shape K2, N = b.shape assert K == K2 c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) _blackwell_2cta_matmul_kernel[grid]( a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), ) return c except Exception: return torch.matmul(a, b).to(torch.bfloat16) # ============================================================================ # FP8 GEMM - Hardware FP8 tensor core operations # Uses torch._scaled_mm for native FP8 support # ============================================================================ def fp8_matmul(a: torch.Tensor, b: torch.Tensor, scale_a: float = 1.0, scale_b: float = 1.0) -> torch.Tensor: """ FP8 MatMul using PyTorch's native FP8 support. Uses torch._scaled_mm which maps to cuBLAS FP8 tensor core operations. Blackwell has 2x FP8 throughput vs FP16. """ try: # Convert to FP8 a_fp8 = a.to(torch.float8_e4m3fn) b_fp8 = b.to(torch.float8_e4m3fn) # Scale tensors scale_a_t = torch.tensor(scale_a, device=a.device, dtype=torch.float32) scale_b_t = torch.tensor(scale_b, device=b.device, dtype=torch.float32) # Use scaled_mm for FP8 matmul result = torch._scaled_mm( a_fp8, b_fp8, scale_a=scale_a_t, scale_b=scale_b_t, out_dtype=torch.bfloat16, ) return result except Exception: # Fallback to standard matmul return torch.matmul(a, b).to(torch.bfloat16) @triton.jit def _fp8_matmul_kernel( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, scale_a, scale_b, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): """Triton FP8 MatMul kernel using tl.float8e4nv.""" pid_m = tl.program_id(0) pid_n = tl.program_id(1) offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K)): k_off = k * BLOCK_K a_mask = (offs_m[:, None] < M) & ((offs_k[None, :] + k_off) < K) b_mask = ((offs_k[:, None] + k_off) < K) & (offs_n[None, :] < N) # Load as FP8, compute in FP32 a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) * scale_a b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) * scale_b acc += tl.dot(a.to(tl.float16), b.to(tl.float16)) a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(c_ptrs, acc.to(tl.bfloat16), mask=c_mask) # ============================================================================ # FUSED STELLAR UPDATE KERNEL (Phase 4 — research-backed tuning) # Fuses 5+ element-wise ops into one kernel launch: # fast_weight += lr * eligibility * reward * modulation * hebbian_delta # eligibility = decay * eligibility + (1-decay) * new_trace + noise # eligibility = clamp(eligibility, -clip, clip) # fast_weight = clamp(fast_weight, -wclip, wclip) # ============================================================================ @triton.jit def _fused_stellar_update_kernel( fw_ptr, elig_ptr, delta_ptr, theta_ptr, M, D, lr: tl.constexpr, decay: tl.constexpr, tau_fast: tl.constexpr, r_mod: tl.constexpr, mod_gate: tl.constexpr, trace_clip: tl.constexpr, weight_clip: tl.constexpr, noise_scale: tl.constexpr, stride_m: tl.constexpr, stride_d: tl.constexpr, BLOCK_D: tl.constexpr, ): """Fused STELLAR update: trace update + weight update in one kernel.""" pid_m = tl.program_id(0) offs_d = tl.arange(0, BLOCK_D) mask = offs_d < D base = pid_m * stride_m # Load current values elig = tl.load(elig_ptr + base + offs_d * stride_d, mask=mask, other=0.0) theta = tl.load(theta_ptr + base + offs_d * stride_d, mask=mask, other=0.0) delta = tl.load(delta_ptr + base + offs_d * stride_d, mask=mask, other=0.0) fw = tl.load(fw_ptr + base + offs_d * stride_d, mask=mask, other=0.0) # Noise injection on theta if noise_scale > 0: # Simple noise approximation (Triton doesn't have randn) # Use hash-based pseudo-random noise seed = pid_m * 997 + offs_d * 31 noise = tl.sin(seed.to(tl.float32) * 12.9898) * 43758.5453 noise = (noise - tl.floor(noise)) * 2.0 - 1.0 # uniform [-1, 1] theta = theta + noise * noise_scale # Eligibility trace update: elig = tau_fast * elig + theta elig = tau_fast * elig + theta # Trace clipping if trace_clip > 0: elig = tl.minimum(tl.maximum(elig, -trace_clip), trace_clip) # Three-factor weight update: delta_w = delta * mod_gate * (r_mod * elig) delta_w = delta * mod_gate * (r_mod * elig) # Weight update: fw = decay * fw + delta_w fw = decay * fw + delta_w # Weight clipping fw = tl.minimum(tl.maximum(fw, -weight_clip), weight_clip) # Store tl.store(elig_ptr + base + offs_d * stride_d, elig, mask=mask) tl.store(fw_ptr + base + offs_d * stride_d, fw, mask=mask) def fused_stellar_update( fast_weight: torch.Tensor, eligibility: torch.Tensor, hebbian_delta: torch.Tensor, theta: torch.Tensor, lr: float, decay: float, tau_fast: float, r_mod: float, mod_gate: float, trace_clip: float, weight_clip: float, noise_scale: float, ) -> None: """Python wrapper for fused STELLAR Triton kernel.""" M, D = fast_weight.shape BLOCK_D = triton.next_power_of_2(D) grid = (M,) _fused_stellar_update_kernel[grid]( fast_weight, eligibility, hebbian_delta, theta, M, D, lr=lr, decay=decay, tau_fast=tau_fast, r_mod=r_mod, mod_gate=mod_gate, trace_clip=trace_clip, weight_clip=weight_clip, noise_scale=noise_scale, stride_m=fast_weight.stride(0), stride_d=fast_weight.stride(1), BLOCK_D=BLOCK_D, ) # ============================================================================ # FUSED ATTENTION + FFN MEGA-KERNEL # Combines RMSNorm -> Attention -> FFN in single kernel launch # ============================================================================ class MegaFusedTransformerBlock(nn.Module): """ Mega-Fused Transformer Block: combines attention and FFN with minimal kernel launches. Operations fused: 1. RMSNorm1 + QKV projection (fused) 2. Attention (SDPA) 3. Output projection + Residual 4. RMSNorm2 + SwiGLU FFN (fused) 5. Down projection + Residual This reduces kernel launches from 10+ to 4-5. """ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, intermediate_size: int, head_dim: Optional[int] = None, eps: float = 1e-6): super().__init__() self.dim = dim self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim or dim // num_heads self.intermediate_size = intermediate_size self.eps = eps # Attention weights (can be fused QKV) self.wq = nn.Linear(dim, num_heads * self.head_dim, bias=False) self.wk = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False) self.wv = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False) self.wo = nn.Linear(num_heads * self.head_dim, dim, bias=False) # FFN weights (can be fused gate+up) self.w_gate = nn.Linear(dim, intermediate_size, bias=False) self.w_up = nn.Linear(dim, intermediate_size, bias=False) self.w_down = nn.Linear(intermediate_size, dim, bias=False) # RMSNorm weights (fused into projections) self.norm1_weight = nn.Parameter(torch.ones(dim)) self.norm2_weight = nn.Parameter(torch.ones(dim)) def _rms_norm(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: """Fused RMSNorm.""" variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.eps) return x * weight def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Forward pass with maximum fusion. Fusion strategy: - Fuse RMSNorm into QKV projection launch - Use fused SwiGLU for FFN - Minimize intermediate tensor allocations """ B, S, D = x.shape residual = x # Fused: RMSNorm1 + QKV projection x_norm = self._rms_norm(x, self.norm1_weight) x_flat = x_norm.view(-1, D) # Use fused QKV projection if available q, k, v = fused_qkv_projection( x_flat, self.wq.weight.T, self.wk.weight.T, self.wv.weight.T ) # Reshape for attention q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) # GQA expansion if needed if self.num_kv_heads < self.num_heads: n_rep = self.num_heads // self.num_kv_heads k = k.repeat_interleave(n_rep, dim=1) v = v.repeat_interleave(n_rep, dim=1) # Scaled dot-product attention (fused in PyTorch) attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask) attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1) # Output projection + residual x = residual + self.wo(attn_out) residual = x # Fused: RMSNorm2 + SwiGLU FFN x_norm = self._rms_norm(x, self.norm2_weight) x_flat = x_norm.view(-1, D) # Use fused SwiGLU ffn_out = fused_swiglu(x_flat, self.w_gate.weight.T, self.w_up.weight.T) ffn_out = self.w_down(ffn_out.view(B, S, -1)) # Final residual return residual + ffn_out def fused_transformer_block_forward( x: torch.Tensor, wq: torch.Tensor, wk: torch.Tensor, wv: torch.Tensor, wo: torch.Tensor, w_gate: torch.Tensor, w_up: torch.Tensor, w_down: torch.Tensor, norm1_weight: torch.Tensor, norm2_weight: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int, eps: float = 1e-6, ) -> torch.Tensor: """ Functional interface for fused transformer block. Minimizes kernel launches by fusing: - RMSNorm + QKV projection - SwiGLU activation """ B, S, D = x.shape residual = x # RMSNorm1 variance = x.pow(2).mean(-1, keepdim=True) x_norm = x * torch.rsqrt(variance + eps) * norm1_weight # Fused QKV x_flat = x_norm.view(-1, D) q, k, v = fused_qkv_projection(x_flat, wq, wk, wv) # Reshape and attention q = q.view(B, S, num_heads, head_dim).transpose(1, 2) k = k.view(B, S, num_kv_heads, head_dim).transpose(1, 2) v = v.view(B, S, num_kv_heads, head_dim).transpose(1, 2) if num_kv_heads < num_heads: n_rep = num_heads // num_kv_heads k = k.repeat_interleave(n_rep, dim=1) v = v.repeat_interleave(n_rep, dim=1) attn = F.scaled_dot_product_attention(q, k, v) attn = attn.transpose(1, 2).contiguous().view(B, S, -1) x = residual + F.linear(attn, wo.T) # RMSNorm2 + Fused SwiGLU residual = x variance = x.pow(2).mean(-1, keepdim=True) x_norm = x * torch.rsqrt(variance + eps) * norm2_weight x_flat = x_norm.view(-1, D) ffn = fused_swiglu(x_flat, w_gate, w_up) ffn = F.linear(ffn.view(B, S, -1), w_down.T) return residual + ffn # ============================================================================ # NVFP4 QUANTIZATION - Dual-Scaling 4-bit # ============================================================================ def quantize_nvfp4(x: torch.Tensor, block_size: int = 32) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Quantize to NVFP4 with dual-scaling (block + global).""" original_shape = x.shape x_flat = x.view(-1) # Pad to block_size pad_len = (block_size - x_flat.numel() % block_size) % block_size if pad_len > 0: x_flat = F.pad(x_flat, (0, pad_len)) x_blocks = x_flat.view(-1, block_size) # Per-block scale (E4M3 range: ±448) block_max = x_blocks.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8) block_scale = block_max / 7.0 # FP4 range [-8, 7] # Global scale global_scale = block_scale.max().clamp(min=1e-8) # Quantize x_scaled = x_blocks / (block_scale * global_scale) x_clipped = x_scaled.clamp(-8.0, 7.0) x_quantized = x_clipped.round().to(torch.int8) return x_quantized, block_scale.squeeze(-1), global_scale def dequantize_nvfp4(q: torch.Tensor, block_scale: torch.Tensor, global_scale: torch.Tensor, original_shape: torch.Size, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: """Dequantize from NVFP4 (memory-efficient).""" # Use bfloat16 instead of float32 to save memory (2x reduction) # Process in chunks if tensor is large numel = original_shape.numel() if numel > 10_000_000: # Large tensor - process in chunks chunk_size = 1_000_000 # 1M elements per chunk num_blocks = q.shape[0] block_size = q.shape[1] result = torch.empty(num_blocks * block_size, dtype=dtype, device=q.device) for i in range(0, num_blocks, chunk_size // block_size): end_i = min(i + chunk_size // block_size, num_blocks) chunk_q = q[i:end_i].to(dtype) chunk_scale = block_scale[i:end_i].unsqueeze(-1) result[i*block_size:end_i*block_size] = (chunk_q * chunk_scale * global_scale).view(-1) return result[:numel].view(original_shape) else: # Small tensor - direct computation in target dtype x = q.to(dtype) * block_scale.unsqueeze(-1).to(dtype) * global_scale.to(dtype) return x.view(-1)[:numel].view(original_shape) # ============================================================================ # FAST WALSH-HADAMARD TRANSFORM (FWHT) # ============================================================================ def fwht_forward(x: torch.Tensor) -> torch.Tensor: """Vectorized FWHT for outlier flattening.""" n = x.shape[-1] if n & (n - 1) != 0: next_pow2 = 1 << (n - 1).bit_length() x = F.pad(x, (0, next_pow2 - n)) n = next_pow2 h = 1 while h < n: x = x.view(*x.shape[:-1], n // (2 * h), 2, h) a, b = x[..., 0, :], x[..., 1, :] x = torch.stack([a + b, a - b], dim=-2).view(*x.shape[:-3], n) h *= 2 return x / math.sqrt(n) def fwht_inverse(x: torch.Tensor) -> torch.Tensor: """Inverse FWHT.""" return fwht_forward(x) # FWHT is self-inverse (up to scaling) # ============================================================================ # W4A4 ACTIVATION QUANTIZATION # ============================================================================ def _apply_act_quant(x: torch.Tensor) -> torch.Tensor: """Hadamard + FP4 quant/dequant on activations (W4A4 mode). Applies FWHT to spread outliers, then round-trips through Goliath FP4 quantization to simulate 4-bit activation precision. When Goliath is unavailable, falls back to simple clamp-and-round. """ M, K = x.shape # Hadamard smoothing (power-of-two block size) if K >= 32: x_f = x.float().view(M, K // 32, 32) x_f = fwht_forward(x_f) x_f = x_f.view(M, K) else: x_f = x.float() if _GOLIATH_AVAILABLE: from goliath_kernel import GoliathFP4Weights as _GFP4W # Round-trip: float → FP4 packed → float (simulates 4-bit precision) x_q = _GFP4W.from_float(x_f.T.contiguous()) return x_q.to_float().T.contiguous().to(x.dtype) # Fallback: 4-bit symmetric clamp-and-round (16 levels) amax = x_f.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) scale = amax / 6.0 # E2M1 max positive = 6.0 x_scaled = (x_f / scale).clamp(-6.0, 6.0).round() return (x_scaled * scale).to(x.dtype) # ============================================================================ # STOCHASTIC ROUNDING # ============================================================================ def stochastic_round(x: torch.Tensor, seed: int = 42) -> torch.Tensor: """Stochastic rounding with Philox-like randomness.""" gen = torch.Generator(device=x.device).manual_seed(seed) rand = torch.rand_like(x, generator=gen) floor = x.floor() frac = x - floor return floor + (rand < frac).float() # ============================================================================ # PAGED KV CACHE - Real Implementation # ============================================================================ class PagedKVCache: """ Real vLLM-style paged KV cache with proper block management. Features: - Lazy allocation (only allocate when needed) - Block-level granularity (16 tokens per block default) - Per-sequence block tables - Efficient memory reuse - Supports 500k+ tokens with proper config """ def __init__(self, num_blocks: int, block_size: int, num_layers: int, num_heads: int, head_dim: int, dtype: torch.dtype = torch.bfloat16, device: str = 'cuda'): self.num_blocks = num_blocks self.block_size = block_size self.num_layers = num_layers self.num_heads = num_heads self.head_dim = head_dim self.dtype = dtype self.device = device # Lazy allocation - None until first use self.k_cache: Optional[torch.Tensor] = None # [num_blocks, num_layers, num_heads, block_size, head_dim] self.v_cache: Optional[torch.Tensor] = None # Block management self.free_blocks: List[int] = list(range(num_blocks)) self.block_tables: Dict[int, List[int]] = {} # seq_id -> list of block_ids self.seq_lengths: Dict[int, int] = {} # seq_id -> current length self._allocated = False def _allocate_cache(self): """Lazy allocation of cache tensors.""" if not self._allocated: self.k_cache = torch.zeros( self.num_blocks, self.num_layers, self.num_heads, self.block_size, self.head_dim, dtype=self.dtype, device=self.device ) self.v_cache = torch.zeros_like(self.k_cache) self._allocated = True def _get_block_for_position(self, seq_id: int, position: int) -> Tuple[int, int]: """Get block_id and local position for a sequence position.""" block_idx = position // self.block_size local_pos = position % self.block_size # Ensure we have enough blocks if seq_id not in self.block_tables: self.block_tables[seq_id] = [] self.seq_lengths[seq_id] = 0 while len(self.block_tables[seq_id]) <= block_idx: if not self.free_blocks: raise RuntimeError(f"KV cache full! Max tokens: {self.num_blocks * self.block_size}") new_block = self.free_blocks.pop(0) self.block_tables[seq_id].append(new_block) return self.block_tables[seq_id][block_idx], local_pos def store(self, seq_id: int, layer_idx: int, position: int, k: torch.Tensor, v: torch.Tensor): """ Store KV at position for a sequence (optimized batch storage). Args: seq_id: Sequence identifier layer_idx: Layer index position: Token position in sequence k: Key tensor [num_kv_heads, head_dim] or [num_kv_heads, seq_len, head_dim] v: Value tensor (same shape as k) """ self._allocate_cache() if k.dim() == 2: # Single token: [num_kv_heads, head_dim] block_id, local_pos = self._get_block_for_position(seq_id, position) self.k_cache[block_id, layer_idx, :, local_pos, :] = k self.v_cache[block_id, layer_idx, :, local_pos, :] = v self.seq_lengths[seq_id] = max(self.seq_lengths.get(seq_id, 0), position + 1) else: # Multiple tokens: [num_kv_heads, seq_len, head_dim] num_tokens = k.shape[1] # Pre-allocate blocks for entire sequence end_pos = position + num_tokens end_block_idx = (end_pos - 1) // self.block_size if seq_id not in self.block_tables: self.block_tables[seq_id] = [] self.seq_lengths[seq_id] = 0 # Allocate all needed blocks upfront while len(self.block_tables[seq_id]) <= end_block_idx: if not self.free_blocks: raise RuntimeError(f"KV cache full! Max tokens: {self.num_blocks * self.block_size}") self.block_tables[seq_id].append(self.free_blocks.pop(0)) # Vectorized scatter — store all tokens in one operation positions = torch.arange(position, position + num_tokens, device=self.device) block_indices = positions // self.block_size local_positions = positions % self.block_size table_t = torch.tensor(self.block_tables[seq_id], dtype=torch.long, device=self.device) block_ids = table_t[block_indices] # k: [num_kv_heads, num_tokens, head_dim] -> permute to [num_tokens, num_kv_heads, head_dim] self.k_cache[block_ids, layer_idx, :, local_positions, :] = k.permute(1, 0, 2) self.v_cache[block_ids, layer_idx, :, local_positions, :] = v.permute(1, 0, 2) self.seq_lengths[seq_id] = max(self.seq_lengths.get(seq_id, 0), end_pos) def get(self, seq_id: int, layer_idx: int, num_tokens: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """ Retrieve KV for a sequence (optimized batch retrieval). Args: seq_id: Sequence identifier layer_idx: Layer index num_tokens: Number of tokens to retrieve (default: all) Returns: k: [num_heads, seq_len, head_dim] v: [num_heads, seq_len, head_dim] """ self._allocate_cache() if seq_id not in self.block_tables: return (torch.empty(self.num_heads, 0, self.head_dim, dtype=self.dtype, device=self.device), torch.empty(self.num_heads, 0, self.head_dim, dtype=self.dtype, device=self.device)) seq_len = num_tokens or self.seq_lengths.get(seq_id, 0) if seq_len == 0: return (torch.empty(self.num_heads, 0, self.head_dim, dtype=self.dtype, device=self.device), torch.empty(self.num_heads, 0, self.head_dim, dtype=self.dtype, device=self.device)) # Vectorized gather — compute all block IDs and offsets as tensors block_tables = self.block_tables[seq_id] positions = torch.arange(seq_len, device=self.device) block_indices = positions // self.block_size local_positions = positions % self.block_size table_t = torch.tensor(block_tables, dtype=torch.long, device=self.device) block_ids = table_t[block_indices] # k_cache: [num_blocks, num_layers, num_heads, block_size, head_dim] # Advanced indexing: [seq_len, num_heads, head_dim] -> permute to [num_heads, seq_len, head_dim] k_out = self.k_cache[block_ids, layer_idx, :, local_positions, :].permute(1, 0, 2).contiguous() v_out = self.v_cache[block_ids, layer_idx, :, local_positions, :].permute(1, 0, 2).contiguous() return k_out, v_out def get_contiguous(self, seq_id: int, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """Get KV as contiguous tensors (optimized for attention).""" return self.get(seq_id, layer_idx) def free_sequence(self, seq_id: int): """Free all blocks for a sequence.""" if seq_id in self.block_tables: self.free_blocks.extend(self.block_tables[seq_id]) del self.block_tables[seq_id] del self.seq_lengths[seq_id] def clear(self): """Clear all sequences and reset cache.""" self.free_blocks = list(range(self.num_blocks)) self.block_tables.clear() self.seq_lengths.clear() # Reset flat decode cache position (keep tensors allocated) if hasattr(self, '_flat_pos'): self._flat_pos = 0 # ================================================================= # Flat Decode Cache — zero-copy, CUDA-graph-safe KV storage # ================================================================= # For single-token decode: direct tensor write at position, return # view slice. No Python dicts, no torch.cat, no tensor allocation. # Pre-allocated for max_seq_len — same memory as paged cache. def enable_flat_decode(self, max_seq_len: int = 4096, kv_dtype: str = 'bf16'): """Enable flat decode cache for high-speed single-token decode. Pre-allocates [num_layers, num_kv_heads, max_seq_len, head_dim] tensors. During decode: direct write + view slice (zero copy, zero torch.cat). Same memory as paged cache, just contiguous layout. Args: max_seq_len: Maximum sequence length to support kv_dtype: 'bf16' (default, full precision) or 'fp8' (E4M3, 50% VRAM) """ # Guard: flat KV length must not exceed the paged KV capacity paged_capacity = self.num_blocks * self.block_size if hasattr(self, 'num_blocks') else max_seq_len if max_seq_len > paged_capacity: print(f" [Flat KV] WARNING: flat len {max_seq_len} > paged capacity " f"{paged_capacity}. This causes NaN! Clamping to {paged_capacity}.") max_seq_len = paged_capacity self._flat_kv_dtype = kv_dtype self._flat_mode = True self._flat_pos = 0 self._flat_max_len = max_seq_len if kv_dtype == 'fp8': # FP8 E4M3 mode: 50% VRAM savings, ~0.3% quality loss # K/V stored as uint8, per-token scales stored as float32 self.flat_k = torch.zeros( self.num_layers, self.num_heads, max_seq_len, self.head_dim, dtype=torch.uint8, device=self.device) self.flat_v = torch.zeros_like(self.flat_k) # Per-token scales: one scale per (layer, head, position) self.flat_k_scales = torch.ones( self.num_layers, self.num_heads, max_seq_len, dtype=torch.float32, device=self.device) self.flat_v_scales = torch.ones_like(self.flat_k_scales) flat_mb = (self.flat_k.numel() + self.flat_k_scales.numel() * 4) * 2 / 1e6 print(f" [Flat KV FP8] Enabled: {max_seq_len} tokens, {flat_mb:.0f} MB (50% savings)") else: # BF16 mode: full precision self.flat_k = torch.zeros( self.num_layers, self.num_heads, max_seq_len, self.head_dim, dtype=self.dtype, device=self.device) self.flat_v = torch.zeros_like(self.flat_k) self.flat_k_scales = None self.flat_v_scales = None flat_mb = self.flat_k.numel() * self.flat_k.element_size() * 2 / 1e6 print(f" [Flat KV] Enabled: {max_seq_len} tokens, {flat_mb:.0f} MB") def store_flat(self, layer_idx: int, position: int, k: torch.Tensor, v: torch.Tensor): """Direct write to flat cache — no Python logic, no allocation. For FP8 mode: quantizes K/V to E4M3 with per-head-position scales. For BF16 mode: direct copy (original behavior). """ # k: [kv_heads, seq_len, head_dim] or [kv_heads, head_dim] if getattr(self, '_flat_kv_dtype', 'bf16') == 'fp8': # FP8 quantization path if k.dim() == 2: # Single position: [kv_heads, head_dim] k_fp32 = k.float() v_fp32 = v.float() # Per-head absmax scaling k_absmax = k_fp32.abs().amax(dim=-1).clamp(min=1e-10) # [kv_heads] v_absmax = v_fp32.abs().amax(dim=-1).clamp(min=1e-10) k_scale = k_absmax / 448.0 # E4M3 max value v_scale = v_absmax / 448.0 # Quantize to E4M3 using goliath's encode function k_scaled = k_fp32 / k_scale.unsqueeze(-1) v_scaled = v_fp32 / v_scale.unsqueeze(-1) k_fp8 = _encode_e4m3_kv(k_scaled) # [kv_heads, head_dim] uint8 v_fp8 = _encode_e4m3_kv(v_scaled) # Store quantized values + scales self.flat_k[layer_idx, :, position, :] = k_fp8 self.flat_v[layer_idx, :, position, :] = v_fp8 self.flat_k_scales[layer_idx, :, position] = k_scale self.flat_v_scales[layer_idx, :, position] = v_scale else: # Multiple positions: [kv_heads, seq_len, head_dim] seq_len = k.shape[1] k_fp32 = k.float() v_fp32 = v.float() # Per-head-position absmax scaling k_absmax = k_fp32.abs().amax(dim=-1).clamp(min=1e-10) # [kv_heads, seq_len] v_absmax = v_fp32.abs().amax(dim=-1).clamp(min=1e-10) k_scale = k_absmax / 448.0 v_scale = v_absmax / 448.0 k_scaled = k_fp32 / k_scale.unsqueeze(-1) v_scaled = v_fp32 / v_scale.unsqueeze(-1) k_fp8 = _encode_e4m3_kv(k_scaled) v_fp8 = _encode_e4m3_kv(v_scaled) self.flat_k[layer_idx, :, position:position + seq_len, :] = k_fp8 self.flat_v[layer_idx, :, position:position + seq_len, :] = v_fp8 self.flat_k_scales[layer_idx, :, position:position + seq_len] = k_scale self.flat_v_scales[layer_idx, :, position:position + seq_len] = v_scale else: # BF16 path: direct copy (original behavior) if k.dim() == 2: self.flat_k[layer_idx, :, position, :] = k self.flat_v[layer_idx, :, position, :] = v else: seq_len = k.shape[1] self.flat_k[layer_idx, :, position:position + seq_len, :] = k self.flat_v[layer_idx, :, position:position + seq_len, :] = v def get_flat_view(self, layer_idx: int, end_pos: int): """Return views into flat cache — zero copy for BF16, dequant for FP8. For FP8 mode: dequantizes on-the-fly using stored scales. Returns BF16 tensors regardless of storage format. """ if getattr(self, '_flat_kv_dtype', 'bf16') == 'fp8': # FP8 dequant path k_fp8 = self.flat_k[layer_idx, :, :end_pos, :] # [heads, seq, dim] uint8 v_fp8 = self.flat_v[layer_idx, :, :end_pos, :] k_scales = self.flat_k_scales[layer_idx, :, :end_pos] # [heads, seq] v_scales = self.flat_v_scales[layer_idx, :, :end_pos] # Dequant: decode E4M3 then multiply by scale k_dequant = _decode_e4m3_kv(k_fp8) * k_scales.unsqueeze(-1) v_dequant = _decode_e4m3_kv(v_fp8) * v_scales.unsqueeze(-1) return (k_dequant.to(self.dtype), v_dequant.to(self.dtype)) else: # BF16 path: zero copy views (original behavior) return (self.flat_k[layer_idx, :, :end_pos, :], self.flat_v[layer_idx, :, :end_pos, :]) # ================================================================= # FireEcho CUDA Graph Decode — scatter-based KV for graph replay # ================================================================= # All position updates via GPU tensors — zero CPU-GPU sync. # Scatter-based KV writes + full-length views + attention masking # give fixed-shape ops compatible with CUDA graph capture/replay. def enable_cuda_graph(self, max_seq_len: int = 4096, kv_dtype: str = 'bf16'): """Enable CUDA graph-compatible decode mode. Extends flat decode with GPU position tensor, attention mask, and RoPE side buffers for graph-safe position handling. Args: max_seq_len: Maximum sequence length to support kv_dtype: 'bf16' (default) or 'fp8' (50% VRAM, 8% faster) """ if not getattr(self, '_flat_mode', False): self.enable_flat_decode(max_seq_len, kv_dtype=kv_dtype) self._graph_mode = False # Activated during capture/replay self._graph_position = torch.zeros(1, dtype=torch.long, device=self.device) self._graph_attn_mask = torch.full( (1, 1, 1, max_seq_len), float('-inf'), dtype=self.dtype, device=self.device) half_dim = self.head_dim // 2 self._graph_rope_cos = torch.zeros( 1, half_dim, dtype=self.dtype, device=self.device) self._graph_rope_sin = torch.zeros( 1, half_dim, dtype=self.dtype, device=self.device) self._graph_scatter_idx = torch.zeros( self.num_heads, 1, self.head_dim, dtype=torch.long, device=self.device) # FlashDecode valid length — how many KV positions to read self._graph_valid_len = torch.zeros( 1, dtype=torch.long, device=self.device) graph_mb = (self._graph_attn_mask.nelement() * 2 + self._graph_rope_cos.nelement() * 4) / 1e6 print(f" [CUDA Graph] KV graph mode ready ({graph_mb:.0f} MB)") def store_flat_scatter(self, layer_idx: int, k: torch.Tensor, v: torch.Tensor): """Scatter-based KV store — CUDA graph safe. Position read from _graph_scatter_idx (GPU tensor), not Python int. k: [kv_heads, head_dim] or [kv_heads, 1, head_dim] For FP8 mode: quantizes K/V to E4M3 then scatters both values + scales. """ if k.dim() == 2: k = k.unsqueeze(1) # [kv_heads, 1, head_dim] v = v.unsqueeze(1) if getattr(self, '_flat_kv_dtype', 'bf16') == 'fp8': # FP8 quantization + scatter k_fp32 = k.float() v_fp32 = v.float() # Per-head absmax scaling: [kv_heads, 1] k_absmax = k_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) v_absmax = v_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) k_scale = (k_absmax / 448.0).squeeze(-1) # [kv_heads, 1] v_scale = (v_absmax / 448.0).squeeze(-1) # Quantize to E4M3 k_scaled = k_fp32 / k_absmax * 448.0 v_scaled = v_fp32 / v_absmax * 448.0 k_fp8 = _encode_e4m3_kv(k_scaled) # [kv_heads, 1, head_dim] uint8 v_fp8 = _encode_e4m3_kv(v_scaled) # Scatter quantized values self.flat_k[layer_idx].scatter_(1, self._graph_scatter_idx, k_fp8) self.flat_v[layer_idx].scatter_(1, self._graph_scatter_idx, v_fp8) # Scatter scales (need separate index for 2D scatter) scale_idx = self._graph_scatter_idx[:, :, 0] # [kv_heads, 1] self.flat_k_scales[layer_idx].scatter_(1, scale_idx, k_scale) self.flat_v_scales[layer_idx].scatter_(1, scale_idx, v_scale) else: # BF16 path: direct scatter (original behavior) self.flat_k[layer_idx].scatter_(1, self._graph_scatter_idx, k) self.flat_v[layer_idx].scatter_(1, self._graph_scatter_idx, v) def prepare_graph_step(self, position: int, rope_cos: torch.Tensor, rope_sin: torch.Tensor): """Update side buffers before CUDA graph replay. All ops are GPU-to-GPU — zero CPU-GPU synchronization. """ self._graph_position.fill_(position) self._graph_attn_mask[0, 0, 0, position] = 0 self._graph_rope_cos.copy_(rope_cos[position:position + 1]) self._graph_rope_sin.copy_(rope_sin[position:position + 1]) self._graph_scatter_idx.fill_(position) self._graph_valid_len.fill_(position + 1) # includes current token def rollback_to(self, position: int, num_positions: int = 0): """Roll back flat KV cache for speculative decoding rejection. Zeros out positions [position, position + num_positions) in flat_k/flat_v. Resets graph attention mask and valid_len for those positions. For FP8 mode: also zeros the scale buffers. """ if not getattr(self, '_flat_mode', False): return end = min(position + max(num_positions, 1), self._flat_max_len) if position < end: self.flat_k[:, :, position:end, :] = 0 self.flat_v[:, :, position:end, :] = 0 # Also reset scales for FP8 mode if getattr(self, '_flat_kv_dtype', 'bf16') == 'fp8': self.flat_k_scales[:, :, position:end] = 1.0 self.flat_v_scales[:, :, position:end] = 1.0 if hasattr(self, '_graph_attn_mask'): self._graph_attn_mask[0, 0, 0, position:end] = float('-inf') if hasattr(self, '_graph_valid_len'): self._graph_valid_len.fill_(position) def memory_usage_gb(self) -> float: """Return memory usage in GB.""" if not self._allocated: return 0.0 return (self.k_cache.numel() + self.v_cache.numel()) * self.k_cache.element_size() / 1e9 def capacity_tokens(self) -> int: """Return max token capacity.""" return self.num_blocks * self.block_size def used_tokens(self) -> int: """Return currently used tokens.""" return sum(self.seq_lengths.values()) def stats(self) -> Dict[str, Any]: """Return cache statistics.""" return { 'capacity_tokens': self.capacity_tokens(), 'used_tokens': self.used_tokens(), 'num_sequences': len(self.block_tables), 'blocks_used': self.num_blocks - len(self.free_blocks), 'blocks_free': len(self.free_blocks), 'memory_gb': self.memory_usage_gb(), } # ============================================================================ # HEBBIAN FAST-WEIGHT MEMORY - Enhanced with STELLAR Innovations # ============================================================================ # Implements: Eligibility Traces (MOHN), Rare Correlations, Neuromodulation Gating # Based on: DARPA L2M STELLAR - AD1180225.pdf class HebbianMemory(nn.Module): """ Neuro-Modulated Hebbian Memory with Soft Competitive Learning. Integrates three foundational papers into a unified fast-weight system: 1. NHL (Tang et al. 2023) - Soft competitive Hebbian + learned neuro-modulator - Eq. 6: Softmax competition y_k = exp(u_k/τ) / Σ exp(u_i/τ) - Eq. 7: Plasticity Δw_ik = η·y_k·(R·x_i − u_k·w_ik) - Eq. 10: Neuro-modulator via entropy minimization argmin H(y|B) - Free-energy principle: bottom-up Hebbian + top-down modulator 2. Floreano, Dürr, Mattiussi (2008) - Neuroevolution of learning architectures - Three-factor rule: Δw = f(pre, post, modulator) - Dedicated modulatory neurons gate plasticity at specific synapses - Actor-critic: modulator acts as value system for learning events 3. STELLAR (DARPA L2M) - Eligibility traces + reward modulation - Eq. 20: Rare correlations Θ (top/bottom θ% of outer product) - Eq. 22: Eligibility traces E(t) = τ·E(t-1) + Θ(t) - Eq. 23: Modulated update Δw = (r(t) + b)·E(t) The three mechanisms serve distinct roles: - Soft Hebbian = WHAT to learn (competitive feature detection) - Neuro-modulator = WHERE to learn (feedback-gated consolidation) - Eligibility traces = WHEN to consolidate (temporal credit assignment) """ def __init__(self, dim: int, memory_size: int = 128, lr: float = 0.01, decay: float = 0.99, num_heads: int = 4, # STELLAR parameters tau_e: float = 0.95, rare_correlation_pct: float = 0.1, use_neuromodulation: bool = True, baseline_modulation: float = 0.01, # NHL: Soft Competitive Hebbian temperature: float = 1.0, weight_radius: float = 1.0, use_soft_hebbian: bool = True, # NHL + Floreano: Learned Neuro-Modulator use_learned_modulator: bool = True, modulator_entropy_weight: float = 0.1, # STELLAR improved: Dual-timescale traces tau_fast: float = 0.90, tau_slow: float = 0.99, # Floreano: Three-factor rule use_three_factor: bool = True, # Pattern separation (Surget & Belzung 2022) separation_strength: float = 0.05, separation_threshold: float = 0.5, slot_activation_threshold: float = 0.01, # Synaptic competition (neurogenesis + HAG) competition_strength: float = 0.1, slot_recycle_after: int = 500, # Phase 4: Research-backed tuning max_update_norm: float = 1.0, trace_clip: float = 0.1, weight_clip: float = 1.0, noise_scale: float = 0.001, # Phase 4b: BCPNN adaptive slot lr (Ravichandran 2021) adaptive_slot_lr: bool = False, tau_age: float = 100.0, importance_scale: float = 2.0, # Phase 4b: Homeostatic thresholds (Zhou/Nature 2025) homeostatic_threshold: bool = False, threshold_incr: float = 0.01, threshold_decr: float = 0.001, # Phase 4d: Cosine similarity retrieval (ENN, Szelogowski 2025) cosine_retrieval: bool = False, retrieval_tau: float = 1.0, sparsity_lambda: float = 0.1, # Phase 4e: Multi-timescale memory (Limbacher 2022) multi_timescale: bool = False, working_memory_ratio: float = 0.3, # Phase 4f: Structural plasticity (IBM CAL 2019) structural_plasticity: bool = False, merge_threshold: float = 0.95, # Phase 4g: BCPNN trace filter (Yang 2020) use_trace_filter: bool = False, # Phase 5: MESU uncertainty-scaled learning (Bonnet et al. 2025) use_mesu: bool = False, mesu_sigma_prior: float = 0.1, mesu_sigma_res: float = 10.0, # Phase 5: Bayesian reward estimation (BDA3 conjugate Normal-Normal) use_bayesian_reward: bool = False, reward_prior_mean: float = 0.0, reward_prior_var: float = 1.0, # Identity-init for projections (frozen-mode convergence boost) identity_init: bool = False, # Memory consolidation — complementary learning systems use_consolidation: bool = False, consolidation_interval: int = 20, consolidation_threshold: float = 0.01, consolidated_decay: float = 0.9999, consolidation_ratio: float = 0.3, # Adaptive transfer filtering — prevent negative transfer adaptive_transfer: bool = False, transfer_ema_decay: float = 0.95, transfer_demotion: bool = False, transfer_demotion_rate: float = 0.1, # KG consolidation gate — block promotion on negative KG score kg_consolidation_gate: bool = False, # Layer 4: Intrinsic reward — curiosity + competence use_intrinsic_reward: bool = False, curiosity_ema_decay: float = 0.95, competence_ema_decay: float = 0.99, # Layer 5: SPAR reflect — error detection threshold error_threshold: int = 2, # FE-MX: Age-adaptive microscaling for fast_weight compression use_femx: bool = False, # Phase 5b: Paper-informed upgrades (Lansner, Vasquez, GHA, STDP, Triesch) use_kappa_switching: bool = False, kappa_encoding: float = 6.0, encoding_steps: int = 50, use_adaptive_tau: bool = False, adaptive_tau_alpha: float = 0.5, use_stdp_traces: bool = False, stdp_window: int = 50, use_intrinsic_plasticity: bool = False, ip_lr: float = 0.001, use_multihead_retrieval: bool = False, use_gha_decorrelation: bool = False, gha_lr: float = 0.001, use_pmi_correction: bool = False, pmi_weight: float = 0.1, pmi_ema_decay: float = 0.99, ): super().__init__() self.dim = dim self.memory_size = memory_size self.lr = lr self.decay = decay self.num_heads = num_heads self.head_dim = dim // num_heads # Phase 4: Research-backed tuning parameters self.max_update_norm = max_update_norm self.trace_clip = trace_clip self.weight_clip = weight_clip self.noise_scale = noise_scale # Phase 4b: BCPNN adaptive per-slot lr (Ravichandran 2021, IBM CAL 2019) self.adaptive_slot_lr = adaptive_slot_lr self.tau_age = tau_age self.importance_scale = importance_scale # Phase 4b: Homeostatic thresholds (Zhou/Nature 2025) self.homeostatic_threshold = homeostatic_threshold self.threshold_incr = threshold_incr self.threshold_decr = threshold_decr # Phase 4d: Cosine retrieval (ENN, Szelogowski 2025) self.cosine_retrieval = cosine_retrieval self.retrieval_tau = retrieval_tau self.sparsity_lambda = sparsity_lambda # Phase 4e: Multi-timescale memory (Limbacher 2022) self.multi_timescale = multi_timescale self.working_memory_ratio = working_memory_ratio if multi_timescale: self.n_working = int(memory_size * working_memory_ratio) self.n_longterm = memory_size - self.n_working else: self.n_working = memory_size self.n_longterm = 0 # Phase 4f: Structural plasticity (IBM CAL 2019) self.structural_plasticity = structural_plasticity self.merge_threshold = merge_threshold # Phase 4g: BCPNN trace filter (Yang 2020) self.use_trace_filter = use_trace_filter # Phase 5: MESU uncertainty-scaled learning (Bonnet et al. Nature Comms 2025) self.use_mesu = use_mesu self.mesu_sigma_prior = mesu_sigma_prior self.mesu_sigma_res = mesu_sigma_res # Phase 5: Bayesian reward estimation (BDA3 Ch. 2-3, conjugate Normal-Normal) self.use_bayesian_reward = use_bayesian_reward self.reward_prior_mean = reward_prior_mean self.reward_prior_var = reward_prior_var # STELLAR parameters self.tau_e = tau_e self.rare_correlation_pct = rare_correlation_pct self.use_neuromodulation = use_neuromodulation self.baseline_modulation = baseline_modulation # NHL parameters self.temperature = temperature self.weight_radius = weight_radius self.use_soft_hebbian = use_soft_hebbian # Learned modulator parameters self.use_learned_modulator = use_learned_modulator self.modulator_entropy_weight = modulator_entropy_weight # Dual-timescale trace parameters self.tau_fast = tau_fast self.tau_slow = tau_slow # Three-factor flag self.use_three_factor = use_three_factor # Pattern separation (Surget & Belzung 2022 — hippocampal neurogenesis) self.separation_strength = separation_strength self.separation_threshold = separation_threshold self.slot_activation_threshold = slot_activation_threshold # Synaptic competition (neurogenesis — memory eviction) self.competition_strength = competition_strength self.slot_recycle_after = slot_recycle_after # ── Fast weight memory bank ── # Random init breaks symmetry so soft competition can differentiate slots. # With zeros, all 256 slots produce identical activations y=1/256, # no slot specializes, and slot_last_used never gets updated → all stale. self.fast_weight = nn.Parameter( torch.randn(memory_size, dim) * 0.01, requires_grad=False) # ── FE-MX: Block Floating Point compression for fast_weight ── # When enabled, femx.master (FP32) is the arithmetic accumulator. # After each Hebbian update cycle, stochastic rounding compresses # master → packed (uint8 mantissa + E8M0 shared exponent). # fast_weight becomes a BF16 dequantized view for read-only access. self.use_femx = use_femx self._femx = None # Lazy init — created on first _get_fw() call if use_femx and not _FEMX_AVAILABLE: import warnings warnings.warn("use_femx=True but femx_storage not found — falling back to BF16") self.use_femx = False # ── Eligibility traces (dual timescale) ── self.eligibility_traces = nn.Parameter( torch.zeros(memory_size, dim), requires_grad=False) # Fast trace self.slow_traces = nn.Parameter( torch.zeros(memory_size, dim), requires_grad=False) # Slow trace # ── Slot usage tracking (synaptic competition) ── self.usage_count = nn.Parameter( torch.zeros(memory_size), requires_grad=False) self.usage_ema = nn.Parameter( torch.zeros(memory_size), requires_grad=False) self.slot_last_used = nn.Parameter( torch.zeros(memory_size, dtype=torch.long), requires_grad=False) # ── BCPNN adaptive per-slot lr buffers (Ravichandran 2021) ── if adaptive_slot_lr: self.register_buffer('slot_age', torch.zeros(memory_size)) self.register_buffer('slot_relevance', torch.zeros(memory_size)) # ── Homeostatic threshold buffers (Zhou/Nature 2025) ── if homeostatic_threshold: self.register_buffer('slot_threshold', torch.ones(memory_size)) # ── Multi-timescale decay rates (Limbacher 2022) ── if multi_timescale: # Working memory: fast adaptation, fast forgetting # Long-term memory: slow adaptation, slow forgetting working_decay = torch.full((self.n_working,), 0.90) longterm_decay = torch.full((self.n_longterm,), 0.999) self.register_buffer('slot_decay', torch.cat([working_decay, longterm_decay])) working_lr_scale = torch.full((self.n_working,), 2.0) longterm_lr_scale = torch.full((self.n_longterm,), 0.3) self.register_buffer('slot_lr_scale', torch.cat([working_lr_scale, longterm_lr_scale])) # ── Structural plasticity tracking (IBM CAL 2019) ── if structural_plasticity: self.register_buffer('slot_activation_count', torch.zeros(memory_size)) # ── BCPNN trace filter buffers (Yang 2020) ── if use_trace_filter: # z: fast trace, e: intermediate, p: probability estimate self.register_buffer('trace_z', torch.zeros(memory_size, dim)) self.register_buffer('trace_e', torch.zeros(memory_size, dim)) self.register_buffer('trace_p', torch.zeros(memory_size, dim)) # Time constants (shorter = faster response) self.bcpnn_tau_z = 0.8 # fast trace self.bcpnn_tau_e = 0.95 # intermediate self.bcpnn_tau_p = 0.99 # slow probability # ── MESU per-weight uncertainty (Bonnet et al. 2025) ── if use_mesu: # σ² per fast_weight element — serves as adaptive learning rate self.register_buffer('weight_sigma_sq', torch.full((memory_size, dim), mesu_sigma_prior ** 2)) # ── Bayesian reward estimation buffers (BDA3 conjugate Normal-Normal) ── if use_bayesian_reward: # Per-slot posterior: N(mu_n, sigma_n^2) self.register_buffer('reward_mu', torch.full((memory_size,), reward_prior_mean)) self.register_buffer('reward_sigma_sq', torch.full((memory_size,), reward_prior_var)) self.register_buffer('reward_n_obs', torch.zeros(memory_size)) # ── Memory consolidation — complementary learning systems ── # Hippocampus (fast_weight): fast learning, fast forgetting (decay=0.97) # Neocortex (consolidated_weight): slow consolidation, near-zero forgetting (decay=0.9999) # Retrieval queries both banks. High-relevance fast patterns get promoted. self.use_consolidation = use_consolidation self.consolidation_interval = consolidation_interval self.consolidation_threshold = consolidation_threshold self.consolidated_decay = consolidated_decay self.consolidation_ratio = consolidation_ratio if use_consolidation: self.register_buffer('consolidated_weight', torch.zeros(memory_size, dim)) self.register_buffer('slot_cumulative_reward', torch.zeros(memory_size)) self.register_buffer('consolidation_count', torch.zeros(memory_size)) # ── Adaptive transfer filtering — signed impact score ── self.use_adaptive_transfer = use_consolidation and adaptive_transfer self.transfer_ema_decay = transfer_ema_decay self.use_transfer_demotion = use_consolidation and transfer_demotion self.transfer_demotion_rate = transfer_demotion_rate if self.use_adaptive_transfer: self.register_buffer('slot_transfer_score', torch.zeros(memory_size)) self._demotion_count = 0 # cumulative demotion events for logging # ── KG consolidation gate — external factual consistency signal ── self.use_kg_gate = use_consolidation and kg_consolidation_gate if self.use_kg_gate: self.register_buffer('slot_kg_score', torch.zeros(memory_size)) self._last_activation = None # stored from last _three_factor_update # ── Layer 4: Intrinsic motivation — curiosity + competence ── self.use_intrinsic_reward = use_intrinsic_reward self.curiosity_ema_decay = curiosity_ema_decay self.competence_ema_decay = competence_ema_decay if use_intrinsic_reward: self.register_buffer('_activation_ema', torch.zeros(memory_size)) self.register_buffer('_curiosity_ema', torch.tensor(0.0)) self.register_buffer('_competence_ema', torch.tensor(0.0)) self._intrinsic_update_count = 0 # SPAR reflect state (thresholds from config) self._expected_loss_ema = 0.0 self._prediction_error_ema = 0.0 self._consecutive_errors = 0 self._error_threshold = error_threshold # ── Phase 5b: Paper-informed upgrades ── # Kappa switching (Lansner BCPNN 2023): elevated plasticity during encoding self.use_kappa_switching = use_kappa_switching self.kappa_encoding = kappa_encoding self.encoding_steps = encoding_steps # Adaptive temperature (Vasquez MaxEnt): entropy-driven tau self.use_adaptive_tau = use_adaptive_tau self.adaptive_tau_alpha = adaptive_tau_alpha if use_adaptive_tau: self.register_buffer('_attn_entropy_ema', torch.tensor(0.0)) # STDP eligibility traces (timing-dependent asymmetric modulation) self.use_stdp_traces = use_stdp_traces self.stdp_window = stdp_window # Intrinsic plasticity (Triesch 2005): per-slot gain/bias for max entropy self.use_intrinsic_plasticity = use_intrinsic_plasticity self.ip_lr = ip_lr self._ip_mu_target = 1.0 / memory_size if use_intrinsic_plasticity: self.register_buffer('slot_gain', torch.ones(memory_size)) self.register_buffer('slot_bias', torch.zeros(memory_size)) # Multi-head memory retrieval (d2l.ai MHA) self.use_multihead_retrieval = use_multihead_retrieval # GHA decorrelation (Sanger's rule): online PCA deflation step self.use_gha_decorrelation = use_gha_decorrelation self.gha_lr = gha_lr # PMI correction (Lansner BCPNN): pointwise mutual information bonus self.use_pmi_correction = use_pmi_correction self.pmi_weight = pmi_weight self.pmi_ema_decay = pmi_ema_decay if use_pmi_correction: self.register_buffer('_coactivation_ema', torch.zeros(memory_size, dim)) self.register_buffer('_feature_ema', torch.zeros(dim)) # ── Learned projections ── self.query_proj = nn.Linear(dim, dim, bias=False) self.key_proj = nn.Linear(dim, dim, bias=False) self.value_proj = nn.Linear(dim, dim, bias=False) self.out_proj = nn.Linear(dim, dim, bias=False) # Identity-init: start projections near identity so initial retrieval # passes through meaningful features from step 0 (critical for frozen mode # where these projections are the main trainable parameters). # Scale factor 0.1 prevents initial Hebbian contribution from being too large. if identity_init: scale = 0.1 for proj in [self.query_proj, self.key_proj, self.value_proj, self.out_proj]: nn.init.eye_(proj.weight) proj.weight.data.mul_(scale) # ── Gated mixing ── self.gate = nn.Linear(dim * 2, dim, bias=False) # ── Neuro-Modulator (learned feedback, NHL Sec 3.4 + Floreano) ── if use_learned_modulator: # Learned interface between Hebbian output and downstream network. # Takes the prediction error (retrieved − input) and produces a # per-feature gating signal that controls which features are # consolidated. Trained via entropy minimization on its output. self.modulator_layer = nn.Linear(dim, dim, bias=False) self.modulator_norm = nn.LayerNorm(dim) elif use_neuromodulation: # Legacy PNN-style gating (STELLAR Eq. 50-53) as fallback self.neuromod_proj = nn.Linear(dim, dim, bias=False) self.neuromod_gate = nn.Linear(dim, dim, bias=False) # ── Running state for entropy loss ── self._modulator_entropy: Optional[torch.Tensor] = None # ── Statistics tracking ── self.update_count = 0 self.memory_norm_history: List[float] = [] self.reward_buffer: List[float] = [] self._competition_stats: List[float] = [] # Track competition sharpness # ────────────────────────────────────────────────────────────────────── # FE-MX: Block Floating Point helpers # ────────────────────────────────────────────────────────────────────── def _femx_init(self): """Lazy-init FEMXStorage on the same device as fast_weight.""" if self._femx is not None: return device = str(self.fast_weight.device) self._femx = FEMXStorage( self.memory_size, self.dim, block_size=32, default_tier=FEMX8, device=device) # Seed master from current fast_weight self._femx.master.copy_(self.fast_weight.data.float()) self._femx.sync_from_master(stochastic=False) def _get_fw(self) -> torch.Tensor: """Get the writable weight tensor: FP32 FEMX master or fast_weight.data.""" if self.use_femx: self._femx_init() return self._femx.master return self.fast_weight.data def _femx_sync(self): """After Hebbian update: quantize master→packed, update fast_weight view.""" if not self.use_femx or self._femx is None: return # Age-adaptive tier assignment: precision earned by activation frequency self._femx_update_tiers() # CUDA fast path: fused quantize + dequantize → BF16 in single kernel if (_FEMX_CUDA_AVAILABLE and _femx_cuda is not None and self._femx.master.is_cuda and self.fast_weight.dtype == torch.bfloat16): seed = torch.randint(0, 2**31, (1,)).item() _femx_cuda.femx_sync( self._femx.master, self._femx.tier, self._femx.packed, self._femx.scales, self.fast_weight.data, seed) self._femx._quantize_count += 1 else: # Python fallback self._femx.sync_from_master(stochastic=True) self.fast_weight.data.copy_( self._femx.dequantize(dtype=self.fast_weight.dtype)) def _femx_update_tiers(self): """Assign FE-MX precision tiers based on slot maturity. Tier policy (age-adaptive microscaling): - Young (usage_count < 50): FEMX4 (4-bit) — volatile, don't waste bits - Maturing (50 <= count < 200): FEMX6 (6-bit) — earning precision - Consolidated (count >= 200): FEMX8 (8-bit) — proven, full precision - Stale (about to be recycled): FEMX4 — soon overwritten - consolidation_count > 0: FEMX8 forced — promoted to neocortex """ if self._femx is None: return usage = self.usage_count.data # [M] — cumulative activation count tier = self._femx.tier # [M] uint8 # Base assignment from usage count tier.fill_(FEMX4) tier[usage >= 50] = FEMX6 tier[usage >= 200] = FEMX8 # Stale slots → FEMX4 (about to be recycled, no point in high precision) if self.slot_recycle_after > 0 and self.update_count > 0: stale = (self.update_count - self.slot_last_used.data) > self.slot_recycle_after tier[stale] = FEMX4 # Force FEMX8 for consolidated slots (promoted to long-term memory) if self.use_consolidation and hasattr(self, 'consolidation_count'): tier[self.consolidation_count > 0] = FEMX8 def _femx_ensure_device(self): """Move FEMXStorage tensors if fast_weight moved (e.g. .to(cuda)).""" if not self.use_femx or self._femx is None: return target = self.fast_weight.device if self._femx.master.device != target: self._femx.master = self._femx.master.to(target) self._femx.packed = self._femx.packed.to(target) self._femx.scales = self._femx.scales.to(target) self._femx.tier = self._femx.tier.to(target) self._femx.device = str(target) # ────────────────────────────────────────────────────────────────────── # NHL Eq. 6: Soft Competitive Activation # ────────────────────────────────────────────────────────────────────── def _soft_competitive_activation(self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Competitive softmax activation over memory slots (NHL Eq. 6). Each memory slot competes to represent the input. Winners (high y_k) learn strongly, losers barely change — lateral inhibition via softmax. u_k = w_k · x (pre-activation per slot) y_k = exp(u_k/τ) / Σ exp(u_i/τ) (soft WTA) Args: x: Mean input activations [D] Returns: y: Competition probabilities [memory_size] (sums to 1) u: Pre-activations [memory_size] """ # u_k = w_k · x for each memory slot k u = torch.mv(self.fast_weight, x) # [memory_size] # Intrinsic plasticity (Triesch 2005): per-slot gain/bias # Adjusts each slot's excitability to maximize output entropy if self.use_intrinsic_plasticity: u = self.slot_gain * u + self.slot_bias # Adaptive temperature (Vasquez MaxEnt): adjust tau based on entropy # High entropy (diffuse) → lower tau to sharpen; low entropy → raise tau if self.use_adaptive_tau: H_target = math.log(self.memory_size) * 0.5 H_current = self._attn_entropy_ema.item() tau_eff = self.temperature * ( 1.0 + self.adaptive_tau_alpha * (H_target - H_current) / max(H_target, 1e-8)) tau_eff = max(tau_eff, 0.1) else: tau_eff = self.temperature # Triton fused path: homeostatic → Gumbel → softmax in single kernel if _TRITON_HEBBIAN_AVAILABLE and u.is_cuda and not self.use_intrinsic_plasticity: y = _fused_competition( u, self.slot_threshold, self.noise_scale, tau_eff, self.homeostatic_threshold, ) else: # ── Python fallback ── if self.homeostatic_threshold: u_scaled = u / self.slot_threshold else: u_scaled = u if self.noise_scale > 0: gumbel = -torch.log(-torch.log(torch.rand_like(u_scaled) + 1e-20) + 1e-20) u_scaled = u_scaled + gumbel * self.noise_scale * tau_eff y = F.softmax(u_scaled / tau_eff, dim=0) # [memory_size] # Update entropy EMA for adaptive temperature if self.use_adaptive_tau: H = -(y * torch.log(y + 1e-10)).sum() self._attn_entropy_ema.mul_(0.95).add_(0.05 * H) # Intrinsic plasticity update: adapt gain/bias toward exponential distribution if self.use_intrinsic_plasticity and self.training: mu = self._ip_mu_target db = self.ip_lr * (1.0 - y / mu * (2.0 * self.slot_bias + 1.0)) dg = self.ip_lr / self.slot_gain.clamp(min=1e-6) + db * u.detach() self.slot_bias.data.add_(db) self.slot_gain.data.add_(dg) self.slot_gain.data.clamp_(min=0.1, max=10.0) return y, u # ────────────────────────────────────────────────────────────────────── # NHL Eq. 7: Soft Hebbian Plasticity Rule # ────────────────────────────────────────────────────────────────────── def _soft_hebbian_update(self, x: torch.Tensor, y: torch.Tensor, u: torch.Tensor) -> torch.Tensor: """ Soft Hebbian plasticity: winning slots move toward input (NHL Eq. 7). Δw_k = η · y_k · (R·x − u_k·w_k) Convergence: weight norms → sphere of radius R. ||w_k|| > R → norm decreases; ||w_k|| < R → norm increases. Args: x: Mean input [D] y: Competition activations [memory_size] u: Pre-activations [memory_size] Returns: delta_w: Weight update [memory_size, dim] """ # Triton fused path: R*x - u*w → lr*y*delta in single kernel if _TRITON_HEBBIAN_AVAILABLE and self.fast_weight.is_cuda: return _fused_soft_hebbian( x, y, u, self.fast_weight, self.lr, self.weight_radius, ) # ── Python fallback ── # R·x broadcast [D] and u_k·w_k per slot [M, D] rx = self.weight_radius * x.unsqueeze(0) # [1, D] uw = u.unsqueeze(1) * self.fast_weight # [M, D] # Δw_k = η · y_k · (R·x − u_k·w_k) delta = self.lr * y.unsqueeze(1) * (rx - uw) # [M, D] return delta # ────────────────────────────────────────────────────────────────────── # STELLAR Eq. 20: Rare Correlations # ────────────────────────────────────────────────────────────────────── def _compute_rare_correlations(self, pre: torch.Tensor, post: torch.Tensor) -> torch.Tensor: """ Rare correlation filter Θ (STELLAR Eq. 20). Only top/bottom θ% of correlations are used, zeroing the middle. This provides noise robustness and sparse, meaningful updates. Args: pre: Presynaptic activations [D] post: Postsynaptic activations [M] (or [D] truncated to M) Returns: Θ: Sparse correlation matrix [memory_size, dim] with {-1, 0, +1} """ correlations = torch.outer(post, pre) # [M, D] flat = correlations.flatten() k = max(1, int(len(flat) * self.rare_correlation_pct)) top_thresh = torch.topk(flat, k).values[-1] bot_thresh = torch.topk(-flat, k).values[-1] theta = torch.zeros_like(correlations) theta[correlations >= top_thresh] = 1.0 theta[correlations <= -bot_thresh] = -1.0 return theta # ────────────────────────────────────────────────────────────────────── # NHL Sec 3.4 + Floreano: Learned Neuro-Modulator # ────────────────────────────────────────────────────────────────────── def _compute_modulation(self, x: torch.Tensor, retrieved: torch.Tensor) -> torch.Tensor: """ Learned neuro-modulator interface (NHL Sec 3.4, Floreano Fig. 13). Takes the prediction error between retrieved memory and input, produces a per-feature gating signal that controls consolidation. Accumulates entropy loss for external optimizer (NHL Eq. 10). Biological analogy: dopaminergic modulation controls which synapses undergo plasticity (Floreano Sec 5.1). Args: x: Input hidden states [B, S, D] retrieved: Memory retrieval [B, S, D] Returns: Modulation signal [B, S, D] in (0, 1) — gates retrieved memory """ # Prediction error: how much does memory differ from input? delta = retrieved - x # [B, S, D] # Learned modulation: sigmoid output ∈ (0, 1) m = torch.sigmoid(self.modulator_layer(self.modulator_norm(delta))) # Entropy of modulation signal for loss (NHL Eq. 10): # H(m) = -Σ m·log(m) - (1-m)·log(1-m) (binary entropy per feature) # Minimizing this encourages decisive gating (near 0 or 1) eps = 1e-7 entropy = -(m * torch.log(m + eps) + (1 - m) * torch.log(1 - m + eps)) self._modulator_entropy = entropy.mean() return m # ────────────────────────────────────────────────────────────────────── # Legacy PNN gating (STELLAR Eq. 50-53) — fallback # ────────────────────────────────────────────────────────────────────── def _neuromodulation_gating(self, x: torch.Tensor, hs: torch.Tensor) -> torch.Tensor: """PNN gating fallback: h = ReLU(h_s ⊗ tanh(W_m · ReLU(W_g · x)))""" g = F.relu(self.neuromod_proj(x)) hm = torch.tanh(self.neuromod_gate(g)) return F.relu(hs * hm) # ────────────────────────────────────────────────────────────────────── # Synaptic Competition (Surget & Belzung 2022 + HAG Nature 2025) # ────────────────────────────────────────────────────────────────────── def _synaptic_competition(self, y: torch.Tensor, hebbian_delta: torch.Tensor) -> torch.Tensor: """ Synaptic competition: new traces weaken overlapping old traces. Biological basis (Surget & Belzung 2022): Adult-born neurons compete with mature neurons for synaptic inputs. Integrating new representations destabilizes preexisting synapses, promoting pattern separation. Also inspired by HAG (Nature 2025): Sculpt task-specific wiring by strengthening active connections and weakening unused ones. Args: y: Competition activations [M] — which slots are winning hebbian_delta: Proposed weight update [M, D] Returns: Modified hebbian_delta with competition applied """ # Update usage tracking self.usage_count.data += y.detach() self.usage_ema.data.mul_(0.99).add_(0.01 * y.detach()) # Adaptive threshold: 0.5/M ensures slots are marked as used even # with near-uniform activation (1/M). Old hardcoded 0.01 was too high # for 256 slots (1/256 = 0.0039 < 0.01 → no slot ever marked used). usage_threshold = max(0.001, 0.5 / self.memory_size) self.slot_last_used.data[y > usage_threshold] = self.update_count # Compute overlap: cosine similarity between winning slot's delta # and existing fast_weight rows winner_idx = y.argmax() winner_delta = hebbian_delta[winner_idx] # [D] if winner_delta.norm() < 1e-8: return hebbian_delta # Similarity of winner's update direction to all existing slots fw = self._get_fw() # FP32 FEMX master or fast_weight.data W_normed = F.normalize(fw, dim=1) # [M, D] delta_normed = F.normalize(winner_delta.unsqueeze(0).to(fw.dtype), dim=1) # [1, D] overlap = (W_normed @ delta_normed.t()).squeeze() # [M] # Weaken overlapping old traces (excluding the winner) overlap[winner_idx] = 0.0 weakening = self.competition_strength * F.relu(overlap) fw.sub_(weakening.unsqueeze(1) * fw) # Recycle dead slots: re-init slots unused for slot_recycle_after updates if self.slot_recycle_after > 0 and self.update_count > self.slot_recycle_after: stale = (self.update_count - self.slot_last_used.data) > self.slot_recycle_after if stale.any(): # Re-init with small random values (not zeros!) to preserve # symmetry-breaking. Zeroing creates permanently dead slots # because uniform competition never updates slot_last_used. fw[stale] = torch.randn_like(fw[stale]) * 0.01 self.usage_count.data[stale] = 0.0 self.usage_ema.data[stale] = 0.0 # Reset last_used so recycled slots get a fresh lease on life. # Without this, stale slots are immediately re-detected as stale # on the next check and zeroed again — permanently dead. self.slot_last_used.data[stale] = self.update_count return hebbian_delta # ────────────────────────────────────────────────────────────────────── # Combined Three-Factor Update (NHL + Floreano + STELLAR) # ────────────────────────────────────────────────────────────────────── def _three_factor_update(self, x_mean: torch.Tensor, k_mean: torch.Tensor, y: torch.Tensor, u: torch.Tensor, modulation_scalar: float, reward: Optional[float]) -> None: """ Three-factor Hebbian update: WHAT × WHERE × WHEN. Combines all three papers into a single update rule: Δw = hebbian_delta · mod_gate · (reward_mod · eligibility) Factor 1 — WHAT (NHL Eq. 7): Soft competitive Hebbian delta Determines which direction each slot should move. Factor 2 — WHERE (NHL Sec 3.4 + Floreano): Neuro-modulator gate Controls which slots actually undergo plasticity. Factor 3 — WHEN (STELLAR Eq. 22-23): Reward-modulated eligibility Temporal credit: only consolidate when reward signal confirms. Args: x_mean: Mean input activations [D] k_mean: Mean key (presynaptic) activations [D] y: Competition activations [memory_size] u: Pre-activations [memory_size] modulation_scalar: Mean modulation gate value (0-1) reward: External reward signal or None """ # Factor 1: Soft Hebbian delta (WHAT to learn) hebbian_delta = self._soft_hebbian_update(x_mean, y, u) # [M, D] # PMI correction (Lansner BCPNN): add pointwise mutual information bonus # Slots that co-activate with input more than expected get boosted if self.use_pmi_correction: hebbian_delta = self._pmi_correction(hebbian_delta, x_mean, y) # Synaptic competition: weaken overlapping old traces (Surget & Belzung 2022) if self.competition_strength > 0: hebbian_delta = self._synaptic_competition(y, hebbian_delta) # Factor 3 prerequisite: Rare correlations for eligibility traces post = x_mean[:self.memory_size] if self.dim > self.memory_size else x_mean if len(post) < self.memory_size: post = F.pad(post, (0, self.memory_size - len(post))) theta = self._compute_rare_correlations(k_mean, post) # [M, D] # ── Triton mega-kernel fast path ── # Fuses: traces → three-factor → adaptive LR → MESU → delta_w # Then Python handles: norm check → decay → fw update → clip if _TRITON_HEBBIAN_AVAILABLE and hebbian_delta.is_cuda: # Bayesian reward smoothing (before mega-kernel) if self.use_bayesian_reward and reward is not None: reward = self._bayesian_reward_update(reward, y) r_mod = (reward + self.baseline_modulation) if reward is not None else self.baseline_modulation # Kappa switching (Lansner BCPNN): elevated plasticity during encoding if self.use_kappa_switching: kappa = self.kappa_encoding if self.update_count < self.encoding_steps else 1.0 r_mod = r_mod * kappa mod_gate = max(0.01, modulation_scalar) # STDP eligibility traces: asymmetric temporal modulation if self.use_stdp_traces: theta = self._stdp_modulate_theta(theta, y) # Slot metadata: update BEFORE computing effective_lr # (effective_lr depends on updated slot_age and slot_relevance) effective_lr = None if self.adaptive_slot_lr: abs_reward = abs(reward) if reward is not None else 0.0 signed_reward = reward if reward is not None else 0.0 _update_slot_metadata( y, self.slot_age, self.slot_relevance, self.slot_transfer_score if self.use_adaptive_transfer else None, self.slot_threshold, abs_reward=abs_reward, signed_reward=signed_reward, transfer_ema_decay=self.transfer_ema_decay, threshold_incr=self.threshold_incr, threshold_decr=self.threshold_decr, use_adaptive_transfer=self.use_adaptive_transfer, homeostatic_threshold=self.homeostatic_threshold, ) effective_lr = _compute_effective_lr( self.slot_age, self.slot_relevance, self.tau_age, self.importance_scale, ) elif self.homeostatic_threshold: winner_mask = (y > y.mean()) self.slot_threshold[winner_mask] += self.threshold_incr self.slot_threshold[~winner_mask] -= self.threshold_decr self.slot_threshold.clamp_(min=0.5, max=2.0) self._last_activation = y.detach().clone() y_max = y.max().clamp(min=1e-8).item() # Mega-kernel: traces + three-factor + adaptive LR + MESU → delta_w # COMPUTE_DELTA_ONLY=True so caller can norm-check before fw update delta_w = _fused_traces_update( hebbian_delta, theta, getattr(self, 'trace_z', None), getattr(self, 'trace_e', None), getattr(self, 'trace_p', None), self.eligibility_traces.data, self.slow_traces.data, getattr(self, 'weight_sigma_sq', None) if self.use_mesu else None, self._get_fw(), y, effective_lr, getattr(self, 'slot_decay', None) if self.multi_timescale else None, getattr(self, 'slot_lr_scale', None) if self.multi_timescale else None, tau_z=getattr(self, 'bcpnn_tau_z', 0.9), tau_e=getattr(self, 'bcpnn_tau_e', 0.95), tau_p=getattr(self, 'bcpnn_tau_p', 0.99), tau_fast=self.tau_fast, tau_slow=self.tau_slow, noise_scale=self.noise_scale, trace_clip=self.trace_clip, mod_gate=mod_gate, r_mod=r_mod, decay=self.decay, weight_clip=self.weight_clip, y_max=y_max, sigma_prior_sq=self.mesu_sigma_prior ** 2 if self.use_mesu else 1.0, sigma_res_sq=self.mesu_sigma_res ** 2 if self.use_mesu else 1.0, use_trace_filter=self.use_trace_filter, use_mesu=self.use_mesu, multi_timescale=self.multi_timescale, adaptive_slot_lr=self.adaptive_slot_lr, compute_delta_only=True, ) # Norm check (max_update_norm always > 0 in production) if self.max_update_norm > 0: update_norm = delta_w.norm() if update_norm > self.max_update_norm: delta_w = delta_w * (self.max_update_norm / update_norm) # Decay + weight update (not fused — needs norm-checked delta_w) fw = self._get_fw() delta_w = delta_w.to(fw.dtype) if self.multi_timescale: fw.mul_(self.slot_decay.unsqueeze(1)) delta_w = delta_w * self.slot_lr_scale.unsqueeze(1) fw.add_(delta_w) else: y_activity = y.unsqueeze(1) y_mx = y.max().clamp(min=1e-8) activity_ratio = y_activity / y_mx effective_decay = self.decay + (1.0 - self.decay) * (1.0 - activity_ratio) * 0.9 fw.mul_(effective_decay).add_(delta_w) fw.clamp_(-self.weight_clip, self.weight_clip) # GHA decorrelation (Sanger's rule): PCA deflation step if self.use_gha_decorrelation: self._gha_deflation_step(fw, y, x_mean) # Structural plasticity + consolidation (separate passes) if self.structural_plasticity: self._structural_plasticity_step(y) if self.use_consolidation: self._consolidation_step(reward, y) self._femx_sync() return # ── Python fallback ── # BCPNN trace filtering chain (Yang 2020): cleaner signal for Hebbian update # Raw activation → fast trace z → intermediate e → probability estimate p if self.use_trace_filter: self.trace_z.mul_(self.bcpnn_tau_z).add_((1 - self.bcpnn_tau_z) * theta) self.trace_e.mul_(self.bcpnn_tau_e).add_((1 - self.bcpnn_tau_e) * self.trace_z) self.trace_p.mul_(self.bcpnn_tau_p).add_((1 - self.bcpnn_tau_p) * self.trace_e) theta = self.trace_p # Use filtered signal instead of raw # STDP eligibility traces: asymmetric temporal modulation (Python fallback) if self.use_stdp_traces: theta = self._stdp_modulate_theta(theta, y) # Update dual-timescale eligibility traces (STELLAR Eq. 22 improved) # Noise injection for robustness (Szelogowski ENN 2025) if self.noise_scale > 0: noise = torch.randn_like(theta) * self.noise_scale theta = theta + noise self.eligibility_traces.data.mul_(self.tau_fast).add_(theta) self.slow_traces.data.mul_(self.tau_slow).add_(theta) # Trace clipping (ENN: prevents unbounded trace growth) if self.trace_clip > 0: self.eligibility_traces.data.clamp_(-self.trace_clip, self.trace_clip) self.slow_traces.data.clamp_(-self.trace_clip * 10, self.trace_clip * 10) # Factor 3: Reward modulation (STELLAR Eq. 23) # Bayesian reward estimation (BDA3 Ch. 2-3): smooth noisy reward via conjugate prior if self.use_bayesian_reward and reward is not None: reward = self._bayesian_reward_update(reward, y) r_mod = (reward + self.baseline_modulation) if reward is not None else self.baseline_modulation # Kappa switching (Lansner BCPNN): elevated plasticity during encoding if self.use_kappa_switching: kappa = self.kappa_encoding if self.update_count < self.encoding_steps else 1.0 r_mod = r_mod * kappa # Combined eligibility: fast trace + attenuated slow trace combined_elig = self.eligibility_traces.data + 0.5 * self.slow_traces.data # ── Three-factor combination ── # WHAT (hebbian_delta) × WHERE (modulation) × WHEN (reward × eligibility) mod_gate = max(0.01, modulation_scalar) # Floor to prevent dead updates delta_w = hebbian_delta * mod_gate * (r_mod * combined_elig) # Store activation for intrinsic reward + KG signal injection self._last_activation = y.detach().clone() # BCPNN adaptive per-slot learning rate (Ravichandran 2021, IBM CAL 2019) # Each slot gets its own effective lr based on age and relevance. # Important synapses get plasticity protection — prevents catastrophic forgetting. if self.adaptive_slot_lr: # Update slot age and relevance self.slot_age.add_(1.0) abs_reward = abs(reward) if reward is not None else 0.0 # EMA of reward contribution per slot, weighted by competition activation y self.slot_relevance.mul_(0.99).add_(0.01 * y * abs_reward) # Signed Transfer Impact Score — tracks whether slot helps or hurts # Unlike slot_relevance (abs), this preserves sign: positive = helpful if self.use_adaptive_transfer: signed_reward = reward if reward is not None else 0.0 self.slot_transfer_score.mul_(self.transfer_ema_decay).add_( (1 - self.transfer_ema_decay) * y * signed_reward ) # Per-slot lr: older slots learn slower, relevant slots learn more age_factor = 1.0 / (1.0 + self.slot_age / self.tau_age) # [M] relevance_factor = 0.5 + 0.5 * self.slot_relevance # [M] slot_lr = age_factor * relevance_factor # [M] # Plasticity protection (IBM CAL 2019): important synapses resist change protection = 1.0 / (1.0 + self.importance_scale * self.slot_relevance) # [M] effective_lr = slot_lr * protection # [M] # Apply per-slot lr modulation: [M, 1] broadcasts over [M, D] delta_w = delta_w * effective_lr.unsqueeze(1) # Homeostatic threshold update (Zhou/Nature 2025) # Winner slots get harder to win next time, inactive slots get easier if self.homeostatic_threshold: winner_mask = (y > y.mean()) # active slots self.slot_threshold[winner_mask] += self.threshold_incr self.slot_threshold[~winner_mask] -= self.threshold_decr self.slot_threshold.clamp_(min=0.5, max=2.0) # Norm-scaled update (Duan et al. ICLR 2023: prevents weight explosion) if self.max_update_norm > 0: update_norm = delta_w.norm() if update_norm > self.max_update_norm: delta_w = delta_w * (self.max_update_norm / update_norm) # MESU uncertainty-scaled learning (Bonnet et al. Nature Comms 2025) # μ_{n+1} = μ_n - σ_n² × grad (σ² is the adaptive learning rate) # σ_{n+1} = σ_n - σ_n² × ∂C/∂σ + σ_n(σ²_prior - σ_n²)/σ²_res if self.use_mesu: # Use σ² as per-weight learning rate (replaces uniform lr) delta_w = delta_w * self.weight_sigma_sq # Update σ²: decrease proportional to update magnitude (learning reduces uncertainty) # Plus forgetting regularization to prevent vanishing plasticity grad_magnitude = delta_w.abs() sigma_prior_sq = self.mesu_sigma_prior ** 2 sigma_res_sq = self.mesu_sigma_res ** 2 sigma_update = ( -self.weight_sigma_sq * grad_magnitude + self.weight_sigma_sq.sqrt() * (sigma_prior_sq - self.weight_sigma_sq) / sigma_res_sq ) self.weight_sigma_sq.add_(sigma_update) # Floor σ² to prevent numerical issues self.weight_sigma_sq.clamp_(min=1e-8, max=sigma_prior_sq * 10) # Multi-timescale memory (Limbacher 2022): per-slot decay and lr scaling # Working slots (first 30%) adapt fast & forget fast; long-term (70%) slow & stable fw = self._get_fw() # FP32 FEMX master or fast_weight.data delta_w = delta_w.to(fw.dtype) # BF16 → FP32 when FEMX active if self.multi_timescale: # Per-slot decay: [M, 1] broadcasts over [M, D] fw.mul_(self.slot_decay.unsqueeze(1)) delta_w = delta_w * self.slot_lr_scale.unsqueeze(1) fw.add_(delta_w) else: # Activity-proportional decay: active slots decay at self.decay (0.97), # inactive slots decay much slower (midpoint toward 1.0). # This prevents uniform decay from killing non-winning slots: # with 0.97 uniform, inactive slots hit 0.002 after 200 steps (dead). # With 0.985 for inactive, they retain 0.049 — 25x more alive. y_activity = y.unsqueeze(1) # [M, 1] — competition activation y_max = y.max().clamp(min=1e-8) activity_ratio = y_activity / y_max # 0=inactive, 1=winner # Inactive: decay 90% of the way toward 1.0 (nearly no forgetting) # 0.97 → 0.997 for inactive. After 500 steps: 0.997^500=0.22 (alive) # vs old 0.985^500=0.0005 (dead) or uniform 0.97^500=2e-7 (dead) effective_decay = self.decay + (1.0 - self.decay) * (1.0 - activity_ratio) * 0.9 fw.mul_(effective_decay).add_(delta_w) fw.clamp_(-self.weight_clip, self.weight_clip) # GHA decorrelation (Sanger's rule): PCA deflation step if self.use_gha_decorrelation: self._gha_deflation_step(fw, y, x_mean) # Structural plasticity — slot merge/split (IBM CAL 2019) # Merge: when two slots are too similar (cosine > threshold), combine them # Split: when a slot wins too often, split it to create diversity if self.structural_plasticity: self._structural_plasticity_step(y) # Memory consolidation — hippocampus → neocortex transfer # Promote high-relevance fast weight patterns to long-term storage if self.use_consolidation: self._consolidation_step(reward, y) # FE-MX: requantize master → packed after all updates complete self._femx_sync() # ────────────────────────────────────────────────────────────────────── # Phase 5b: Paper-informed helpers (STDP, PMI, GHA) # ────────────────────────────────────────────────────────────────────── def _stdp_modulate_theta(self, theta: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """STDP-modulated eligibility traces (timing-dependent asymmetry). Recently active slots (small delta_t) get LTP (potentiation), older slots get LTD (depression). This introduces temporal credit assignment beyond symmetric exponential decay. Based on: Bi & Poo (1998), Vasquez thesis (Gibbs spike-train statistics) Args: theta: Raw rare correlations [M, D] y: Competition activations [M] Returns: STDP-modulated theta [M, D] """ delta_t = (self.update_count - self.slot_last_used.data).float() # [M] A_plus, A_minus = 0.01, 0.005 tau_plus, tau_minus = 20.0, 40.0 stdp_factor = torch.where( delta_t <= self.stdp_window, A_plus * torch.exp(-delta_t / tau_plus), -A_minus * torch.exp(-(delta_t - self.stdp_window) / tau_minus) ) # [M] # Bias toward positive for active slots (prevent total LTD collapse) stdp_factor = stdp_factor + 0.5 * y.detach() return theta * (1.0 + stdp_factor.unsqueeze(1)) def _pmi_correction(self, hebbian_delta: torch.Tensor, x_mean: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """PMI correction (Lansner BCPNN): pointwise mutual information bonus. Adds a term based on log(P(slot,feature) / (P(slot) * P(feature))) to favor slots that capture surprising co-activations rather than merely frequent ones. Based on: Lansner et al. 2023 (BCPNN), Ravichandran 2021 Args: hebbian_delta: Base Hebbian update [M, D] x_mean: Mean input activations [D] y: Competition activations [M] Returns: Modified hebbian_delta with PMI bonus [M, D] """ eps = 1e-8 # Update running statistics self._feature_ema.mul_(self.pmi_ema_decay).add_( (1 - self.pmi_ema_decay) * x_mean.abs().detach()) self._coactivation_ema.mul_(self.pmi_ema_decay).add_( (1 - self.pmi_ema_decay) * torch.outer(y.detach(), x_mean.abs().detach())) # PMI: log(P_ij / (P_i * P_j)) p_i = self.usage_ema.clamp(min=eps) # [M] p_j = self._feature_ema.clamp(min=eps) # [D] p_ij = self._coactivation_ema.clamp(min=eps) # [M, D] pmi = torch.log(p_ij / (p_i.unsqueeze(1) * p_j.unsqueeze(0) + eps)) pmi = pmi.clamp(-5.0, 5.0) # Bounded PMI # PMI bonus scaled by activation and pmi_weight pmi_bonus = self.pmi_weight * pmi * y.unsqueeze(1) return hebbian_delta + pmi_bonus.to(hebbian_delta.dtype) def _gha_deflation_step(self, fw: torch.Tensor, y: torch.Tensor, x_mean: torch.Tensor) -> None: """GHA decorrelation (Sanger's rule): online PCA deflation step. Ensures memory slots represent orthogonal features (principal components of the input stream) rather than redundant overlapping features. Each slot i captures variance NOT already captured by slots 0..i-1. Based on: Sanger (1989), Chen et al. 2019 (Neural Processing Letters) Args: fw: Fast weight tensor [M, D] (modified in-place) y: Competition activations [M] x_mean: Mean input activations [D] """ M = fw.shape[0] # y_outer = [M, M], lower triangular mask y_vec = y.detach().float() y_outer = torch.outer(y_vec, y_vec) # [M, M] lt_mask = torch.tril(torch.ones(M, M, device=fw.device)) # Deflation: subtract influence of lower-indexed slots deflation = torch.mm(lt_mask * y_outer, fw.float()) # [M, D] # GHA delta: move toward input, subtract deflation gha_delta = self.gha_lr * ( torch.outer(y_vec, x_mean.float()) - deflation ) # [M, D] fw.add_(gha_delta.to(fw.dtype)) def inject_kg_signal(self, score: float) -> None: """Inject external KG consistency score into per-slot KG accumulator. Called from the training loop after a KG fact-check. Uses the last stored activation pattern to distribute the score across active slots. Slots active during factually consistent generation get positive scores; slots active during contradictory generation get negative scores. """ if not self.use_kg_gate or self._last_activation is None: return self.slot_kg_score.mul_(0.95).add_( 0.05 * self._last_activation * score ) # ── Layer 4: Intrinsic motivation ── def compute_intrinsic_reward(self) -> Dict[str, float]: """Compute curiosity (activation surprise) + competence (learning health).""" if not self.use_intrinsic_reward or self._last_activation is None: return {'curiosity': 0.0, 'competence': 0.0, 'combined': 0.0} y = self._last_activation # ── Curiosity: activation surprise ── self._activation_ema.mul_(self.curiosity_ema_decay).add_( (1 - self.curiosity_ema_decay) * y ) self._intrinsic_update_count += 1 if self._intrinsic_update_count < 10: raw_curiosity = 0.0 else: # Mean-center before cosine: softmax outputs are always positive # and sum to 1, so raw cosine ≈ 1.0 regardless of which slot wins. # Centering makes cosine sensitive to *which* slot deviates from # average, not just that all values are positive. y_c = y - y.mean() ema_c = self._activation_ema - self._activation_ema.mean() y_c_norm = y_c.norm() ema_c_norm = ema_c.norm() if y_c_norm > 1e-8 and ema_c_norm > 1e-8: cos_sim = torch.dot(y_c, ema_c) / (y_c_norm * ema_c_norm) raw_curiosity = max(-1.0, min(1.0, (1.0 - cos_sim.item() - 0.5) * 2.0)) else: raw_curiosity = 0.0 self._curiosity_ema.mul_(0.9).add_(0.1 * raw_curiosity) curiosity = self._curiosity_ema.item() # ── Competence: learning health ── # Use dynamic metrics that respond to ongoing training, not static # binary counters that freeze once entropy collapses. # 1. Slot utilization entropy — measures retrieval diversity usage_sum = self.usage_ema.data.sum() if usage_sum > 1e-8: usage_probs = self.usage_ema.data / usage_sum usage_entropy = -(usage_probs * torch.log(usage_probs + 1e-8)).sum().item() max_entropy = math.log(self.memory_size) slot_utilization = min(1.0, usage_entropy / max_entropy) else: slot_utilization = 0.0 # 2. Reward trend — are recent rewards improving? if len(self.reward_buffer) >= 5: recent = self.reward_buffer[-10:] reward_health = max(0.0, min(1.0, (sum(recent) / len(recent) + 1.0) / 2.0)) else: reward_health = 0.5 raw_competence = (0.5 * slot_utilization + 0.5 * reward_health) * 2.0 - 1.0 self._competence_ema.mul_(self.competence_ema_decay).add_( (1 - self.competence_ema_decay) * raw_competence ) # Floor clip: prevent competence from diving into a doom loop competence = max(-0.3, self._competence_ema.item()) # Weight competence heavier (0.6) to reward learning progress over novelty combined = max(-1.0, min(1.0, 0.4 * curiosity + 0.6 * competence)) return {'curiosity': round(curiosity, 6), 'competence': round(competence, 6), 'combined': round(combined, 6)} # ── Layer 5: SPAR perceive + reflect ── @torch.no_grad() def perceive(self, x_mean: torch.Tensor) -> Dict[str, float]: """Assess input familiarity before processing (SPAR perceive phase).""" if not self.use_intrinsic_reward: return {'familiarity': 0.0, 'complexity': 0.0} x_mean = x_mean.to(self.query_proj.weight.dtype) q = self.query_proj(x_mean.unsqueeze(0) if x_mean.dim() == 1 else x_mean) q = q.squeeze(0) if q.dim() == 2 else q.squeeze(0).squeeze(0) fw_norms = self.fast_weight.data.norm(dim=1, keepdim=True).clamp(min=1e-8) q_norm = q.norm().clamp(min=1e-8) sims = (self.fast_weight.data @ q) / (fw_norms.squeeze() * q_norm) familiarity = sims.max().item() complexity = 1.0 - sims.std().item() return {'familiarity': max(0.0, min(1.0, familiarity)), 'complexity': max(0.0, min(1.0, complexity))} def reflect(self, loss: float, reward: float) -> Dict[str, Any]: """Post-step self-monitoring (SPAR reflect phase).""" if not self.use_intrinsic_reward: return {'prediction_error': 0.0, 'error_detected': False, 'consecutive_errors': 0, 'recovery_action': 'none'} prediction_error = abs(loss - self._expected_loss_ema) self._expected_loss_ema = 0.95 * self._expected_loss_ema + 0.05 * loss self._prediction_error_ema = 0.9 * self._prediction_error_ema + 0.1 * prediction_error if self._prediction_error_ema > 0 and prediction_error > 2 * self._prediction_error_ema: self._consecutive_errors += 1 else: self._consecutive_errors = max(0, self._consecutive_errors - 1) threshold = getattr(self, '_error_threshold', 2) error_detected = self._consecutive_errors >= threshold recovery = 'boost_curiosity' if error_detected else 'none' return {'prediction_error': round(prediction_error, 6), 'error_detected': error_detected, 'consecutive_errors': self._consecutive_errors, 'recovery_action': recovery} def _consolidation_step(self, reward: Optional[float], y: torch.Tensor) -> None: """Consolidate high-relevance patterns to long-term memory bank. Complementary Learning Systems (McClelland et al. 1995): - Hippocampus (fast_weight): rapid encoding, high plasticity, fast decay - Neocortex (consolidated_weight): slow consolidation, near-permanent Promotion criteria: slot_relevance > threshold (has been consistently useful). Transfer: fraction of fast_weight added to consolidated bank. Consolidated bank decays at 0.9999 per step (retains 90% after 1000 steps). """ # Track per-slot cumulative reward (weighted by activation) if reward is not None: self.slot_cumulative_reward.add_(y * reward) # Consolidated bank slow decay every step self.consolidated_weight.mul_(self.consolidated_decay) # Promotion only every N updates (avoids overhead) if self.update_count % self.consolidation_interval != 0: return # Find slots worth consolidating: high relevance (if tracked) or high norm if self.adaptive_slot_lr and hasattr(self, 'slot_relevance'): promote_mask = self.slot_relevance > self.consolidation_threshold else: # Fallback: promote slots with significant fast_weight norm slot_norms = self.fast_weight.data.norm(dim=1) promote_mask = slot_norms > self.consolidation_threshold # Adaptive transfer gate — only promote patterns that demonstrably help # slot_relevance uses abs(reward), so it can't distinguish helpful from harmful. # slot_transfer_score is signed: positive = helps, negative = hurts. if self.use_adaptive_transfer: promote_mask = promote_mask & (self.slot_transfer_score > 0) # KG consolidation gate — only promote patterns with non-negative KG score # Blocks factually inconsistent patterns even if they have positive TIS if self.use_kg_gate: promote_mask = promote_mask & (self.slot_kg_score >= 0) if promote_mask.any(): # Transfer fraction of fast_weight to consolidated bank promote_idx = promote_mask.nonzero(as_tuple=True)[0] fw = self._get_fw() self.consolidated_weight[promote_idx] += ( self.consolidation_ratio * fw[promote_idx] ) self.consolidation_count[promote_idx] += 1 # Clip consolidated bank self.consolidated_weight.clamp_(-self.weight_clip * 2, self.weight_clip * 2) # Active demotion — gradually shrink consolidated patterns that now hurt # If a pattern's signed impact turns persistently negative, reduce its # weight in the neocortex. Capped at transfer_demotion_rate per step (default 10%) # and floored so consolidated norms never drop below 0.1 (preserves marginal utility). if self.use_transfer_demotion: demote_mask = self.slot_transfer_score < -self.consolidation_threshold if demote_mask.any(): demote_idx = demote_mask.nonzero(as_tuple=True)[0] # Gradual shrink: max transfer_demotion_rate fraction per step self.consolidated_weight[demote_idx] *= (1.0 - self.transfer_demotion_rate) # Floor: never erase below 0.1 norm — preserves marginally useful patterns demote_norms = self.consolidated_weight[demote_idx].norm(dim=1, keepdim=True) too_small = demote_norms.squeeze(1) < 0.1 if too_small.any(): # Scale back up to 0.1 norm for slots that went too low tiny_idx = too_small.nonzero(as_tuple=True)[0] for ti in tiny_idx: actual_idx = demote_idx[ti] n = self.consolidated_weight[actual_idx].norm() if n > 1e-8: self.consolidated_weight[actual_idx] *= (0.1 / n) # Log demotion event count self._demotion_count += len(demote_idx) def _structural_plasticity_step(self, y: torch.Tensor) -> None: """Merge similar slots and split overloaded ones (IBM CAL 2019).""" # Track activation counts winner_mask = (y > y.mean()) self.slot_activation_count[winner_mask] += 1 # Only run merge/split every 50 updates (expensive) if self.update_count % 50 != 0: return M = self.memory_size fw = self.fast_weight.data # MERGE: find pairs with cosine similarity > threshold norms = fw.norm(dim=1, keepdim=True).clamp(min=1e-8) fw_normed = fw / norms # Only check upper triangle (avoid self-pairs and duplicates) sim = torch.mm(fw_normed, fw_normed.t()) sim.fill_diagonal_(-1.0) # ignore self max_sim, max_j = sim.max(dim=1) for i in range(M): j = max_j[i].item() if max_sim[i] > self.merge_threshold and i < j: # Weighted merge by age (if available) or equal age_i = self.slot_activation_count[i].item() + 1 age_j = self.slot_activation_count[j].item() + 1 total = age_i + age_j fw[i] = (fw[i] * age_i + fw[j] * age_j) / total # Reset freed slot fw[j].zero_() self.eligibility_traces.data[j].zero_() self.slow_traces.data[j].zero_() self.slot_activation_count[j] = 0 break # one merge per step # SPLIT: if a slot wins too often (>50 activations since last check) split_threshold = 50 overloaded = (self.slot_activation_count > split_threshold).nonzero(as_tuple=True)[0] if len(overloaded) > 0: # Find a free slot (lowest activation count) _, free_idx = self.slot_activation_count.min(dim=0) free_idx = free_idx.item() src_idx = overloaded[0].item() if self.slot_activation_count[free_idx] < split_threshold // 2: noise = torch.randn_like(fw[src_idx]) * 0.01 fw[free_idx] = fw[src_idx] + noise fw[src_idx] = fw[src_idx] - noise self.slot_activation_count[src_idx] = 0 self.slot_activation_count[free_idx] = 0 # ────────────────────────────────────────────────────────────────────── # Legacy STELLAR-only update (backward compat) # ────────────────────────────────────────────────────────────────────── def _legacy_stellar_update(self, x_mean: torch.Tensor, k_mean: torch.Tensor, reward: Optional[float]) -> None: """Original STELLAR Eq. 20-23 update path (no NHL, no Floreano).""" post = x_mean[:self.memory_size] if self.dim > self.memory_size else x_mean if len(post) < self.memory_size: post = F.pad(post, (0, self.memory_size - len(post))) theta = self._compute_rare_correlations(k_mean, post) # Noise injection (ENN, Szelogowski 2025) if self.noise_scale > 0: noise = torch.randn_like(theta) * self.noise_scale theta = theta + noise self.eligibility_traces.data.mul_(self.tau_e).add_(theta) # Trace clipping (ENN) if self.trace_clip > 0: self.eligibility_traces.data.clamp_(-self.trace_clip, self.trace_clip) fw = self._get_fw() # FP32 FEMX master or fast_weight.data fw.mul_(self.decay) mod = (reward + self.baseline_modulation) if reward is not None else self.baseline_modulation update = (self.lr * mod * self.eligibility_traces.data).to(fw.dtype) # Norm-scaled update (Duan ICLR 2023) if self.max_update_norm > 0: update_norm = update.norm() if update_norm > self.max_update_norm: update = update * (self.max_update_norm / update_norm) fw.add_(update) fw.clamp_(-self.weight_clip, self.weight_clip) self._femx_sync() # ────────────────────────────────────────────────────────────────────── # Forward pass # ────────────────────────────────────────────────────────────────────── def forward(self, x: torch.Tensor, update: bool = True, reward: Optional[float] = None) -> torch.Tensor: """ Query memory and optionally update fast weights. Retrieval uses attention over the fast-weight bank. Modulation is either learned (NHL) or PNN-gated (legacy STELLAR). Updates follow the three-factor rule when enabled. Args: x: Input tensor [batch, seq, dim] update: Whether to perform Hebbian update reward: Optional reward signal for modulated learning Returns: Output with memory-augmented representation [batch, seq, dim] """ B, S, D = x.shape # FE-MX: ensure compressed storage is on the same device as fast_weight self._femx_ensure_device() # ── RETRIEVAL: query the fast-weight memory bank ── q = self.query_proj(x) # [B, S, D] # Effective memory: fast (working) + consolidated (long-term) if available effective_weight = self.fast_weight if self.use_consolidation: effective_weight = self.fast_weight + self.consolidated_weight # Multi-head memory retrieval (d2l.ai MHA): each head attends to its # own subspace of the memory bank, enabling specialized retrieval if self.use_multihead_retrieval and not self.cosine_retrieval: H = self.num_heads d_h = self.head_dim M = self.memory_size # Split into heads: [B, S, H, d_h] and [M, H, d_h] q_h = q.view(B, S, H, d_h).transpose(1, 2) # [B, H, S, d_h] mem_h = effective_weight.view(M, H, d_h).permute(1, 0, 2) # [H, M, d_h] scores = torch.matmul(q_h, mem_h.transpose(-1, -2)) / math.sqrt(d_h) attn = F.softmax(scores, dim=-1) # [B, H, S, M] retrieved = torch.matmul(attn, mem_h) # [B, H, S, d_h] retrieved = retrieved.transpose(1, 2).contiguous().view(B, S, D) elif self.cosine_retrieval: # Cosine similarity retrieval (ENN, Szelogowski 2025) q_norm = F.normalize(q, dim=-1) # [B, S, D] mem_aug = effective_weight + 0.1 * self.eligibility_traces mem_norm = F.normalize(mem_aug, dim=-1) # [M, D] tau_eff = self.retrieval_tau / (1.0 + 10.0 * self.sparsity_lambda) scores = torch.matmul(q_norm, mem_norm.t()) / tau_eff # [B, S, M] attn = F.softmax(scores, dim=-1) # [B, S, M] # Triton fused path if (self.use_femx and self._femx is not None and not self.use_consolidation and _TRITON_HEBBIAN_AVAILABLE and attn.is_cuda): retrieved = _fused_dequant_matvec( attn, self._femx.packed, self._femx.scales, self._femx.tier, self._femx.num_blocks, ).to(x.dtype) else: retrieved = torch.matmul(attn, effective_weight) else: scores = torch.matmul(q, effective_weight.t()) / math.sqrt(self.dim) attn = F.softmax(scores, dim=-1) # [B, S, M] if (self.use_femx and self._femx is not None and not self.use_consolidation and _TRITON_HEBBIAN_AVAILABLE and attn.is_cuda): retrieved = _fused_dequant_matvec( attn, self._femx.packed, self._femx.scales, self._femx.tier, self._femx.num_blocks, ).to(x.dtype) else: retrieved = torch.matmul(attn, effective_weight) # [B, S, D] retrieved = self.value_proj(retrieved) # ── NEURO-MODULATION: gate retrieved memory ── modulation_for_update = None if self.use_learned_modulator: # NHL Sec 3.4: learned feedback modulator (trainable) modulation = self._compute_modulation(x, retrieved) retrieved = retrieved * modulation # Store tensor for three-factor update (avoid .item() for torch.compile) modulation_for_update = modulation.detach().mean() elif self.use_neuromodulation: # Legacy PNN gating (STELLAR Eq. 50-53) retrieved = self._neuromodulation_gating(x, retrieved) # ── GATED OUTPUT MIXING ── gate_input = torch.cat([x, retrieved], dim=-1) gate_raw = self.gate(gate_input) # GEMM (hard boundary) proj_raw = self.out_proj(retrieved) # GEMM (hard boundary) # Triton fused path: sigmoid → multiply → residual add in single kernel if _TRITON_HEBBIAN_AVAILABLE and x.is_cuda: output = _fused_gate_output(x, gate_raw, proj_raw) else: output = x + torch.sigmoid(gate_raw) * proj_raw # ── HEBBIAN UPDATE ── if update and (self.training or self.update_count < 1000): with torch.no_grad(): x_mean = x.mean(dim=(0, 1)) # [D] k_mean = self.key_proj(x).mean(dim=(0, 1)) # [D] if self.use_soft_hebbian: # NHL + three-factor path y, u = self._soft_competitive_activation(x_mean) # Track competition sharpness (max probability) self._competition_stats.append(y.max().item()) if len(self._competition_stats) > 1000: self._competition_stats.pop(0) if self.use_three_factor: mod_scalar = modulation_for_update.item() if modulation_for_update is not None else 1.0 self._three_factor_update( x_mean, k_mean, y, u, mod_scalar, reward) else: # Soft Hebbian only (no eligibility, no reward mod) delta = self._soft_hebbian_update(x_mean, y, u) fw = self._get_fw() fw.mul_(self.decay).add_(delta.to(fw.dtype)) fw.clamp_(-self.weight_clip, self.weight_clip) self._femx_sync() else: # Legacy STELLAR-only path self._legacy_stellar_update(x_mean, k_mean, reward) # Track reward if reward is not None: self.reward_buffer.append(reward) if len(self.reward_buffer) > 1000: self.reward_buffer.pop(0) # Track statistics self.update_count += 1 if self.update_count % 100 == 0: self.memory_norm_history.append(self.fast_weight.norm().item()) return output @torch.no_grad() def retrieve_only(self, x: torch.Tensor) -> torch.Tensor: """Retrieve from memory WITHOUT updating fast weights. Used by EAGLE-3 draft head to get memory context during speculation. Speculative tokens may be rejected, so we must NOT contaminate the memory bank — retrieval only, zero side effects. Args: x: Input tensor [batch, seq, dim] Returns: Retrieved memory context [batch, seq, dim] (value-projected) """ B, S, D = x.shape q = self.query_proj(x) effective_weight = self.fast_weight if self.use_consolidation: effective_weight = self.fast_weight + self.consolidated_weight # Multi-head retrieval (d2l.ai MHA) if self.use_multihead_retrieval and not self.cosine_retrieval: H, d_h, M = self.num_heads, self.head_dim, self.memory_size q_h = q.view(B, S, H, d_h).transpose(1, 2) mem_h = effective_weight.view(M, H, d_h).permute(1, 0, 2) scores = torch.matmul(q_h, mem_h.transpose(-1, -2)) / math.sqrt(d_h) attn = F.softmax(scores, dim=-1) retrieved = torch.matmul(attn, mem_h) retrieved = retrieved.transpose(1, 2).contiguous().view(B, S, D) elif self.cosine_retrieval: q_norm = F.normalize(q, dim=-1) mem_aug = effective_weight + 0.1 * self.eligibility_traces mem_norm = F.normalize(mem_aug, dim=-1) tau_eff = self.retrieval_tau / (1.0 + 10.0 * self.sparsity_lambda) scores = torch.matmul(q_norm, mem_norm.t()) / tau_eff attn = F.softmax(scores, dim=-1) if (self.use_femx and self._femx is not None and not self.use_consolidation and _TRITON_HEBBIAN_AVAILABLE and attn.is_cuda): retrieved = _fused_dequant_matvec( attn, self._femx.packed, self._femx.scales, self._femx.tier, self._femx.num_blocks, ).to(x.dtype) else: retrieved = torch.matmul(attn, effective_weight) else: scores = torch.matmul(q, effective_weight.t()) / math.sqrt(self.dim) attn = F.softmax(scores, dim=-1) if (self.use_femx and self._femx is not None and not self.use_consolidation and _TRITON_HEBBIAN_AVAILABLE and attn.is_cuda): retrieved = _fused_dequant_matvec( attn, self._femx.packed, self._femx.scales, self._femx.tier, self._femx.num_blocks, ).to(x.dtype) else: retrieved = torch.matmul(attn, effective_weight) return self.value_proj(retrieved) # ────────────────────────────────────────────────────────────────────── # Entropy loss for external optimizer (NHL Eq. 10) # ────────────────────────────────────────────────────────────────────── def get_modulator_entropy_loss(self) -> torch.Tensor: """ Return neuro-modulator entropy loss for backpropagation. The modulator should make decisive gating decisions (near 0 or 1). Minimizing entropy of its output encourages this. Called by the training loop after forward(): loss = main_loss + engine.get_hebbian_entropy_loss() loss.backward() """ if self._modulator_entropy is not None: return self.modulator_entropy_weight * self._modulator_entropy return torch.tensor(0.0, device=self.fast_weight.device) # ────────────────────────────────────────────────────────────────────── # Pattern Separation Loss (Surget & Belzung 2022, Hippocampal Neurogenesis) # ────────────────────────────────────────────────────────────────────── def get_pattern_separation_loss(self) -> torch.Tensor: """ Pattern separation loss (Surget & Belzung 2022, Hippocampal Neurogenesis). Biological basis: Adult-born neurons in the dentate gyrus bias hippocampal computations toward pattern separation — orthogonalizing CA3 representations to reduce interference between similar memories. Implementation: Penalize high cosine similarity between memory slot vectors. This encourages the fast_weight bank to store diverse, non-overlapping representations. Returns: Scalar loss to add to AdamW optimization target. """ W = self.fast_weight.data # [M, D] norms = W.norm(dim=1) # [M] active_mask = norms > self.slot_activation_threshold active_count = active_mask.sum().item() if active_count < 2: return torch.tensor(0.0, device=W.device) active_W = W[active_mask] # [A, D] active_normed = F.normalize(active_W, dim=1) # [A, D] sim_matrix = active_normed @ active_normed.t() # [A, A] sim_matrix.fill_diagonal_(0.0) excess = F.relu(sim_matrix.abs() - self.separation_threshold) num_pairs = active_count * (active_count - 1) loss = self.separation_strength * excess.sum() / max(num_pairs, 1) return loss # ────────────────────────────────────────────────────────────────────── # State management # ────────────────────────────────────────────────────────────────────── def reset(self): """Reset all fast weights, traces, counters, and slot tracking.""" self.fast_weight.data.zero_() if self.use_femx and self._femx is not None: self._femx.master.zero_() self._femx.packed.zero_() self._femx.scales.zero_() self.eligibility_traces.data.zero_() self.slow_traces.data.zero_() self.usage_count.data.zero_() self.usage_ema.data.zero_() self.slot_last_used.data.zero_() self.update_count = 0 self.memory_norm_history.clear() self.reward_buffer.clear() self._competition_stats.clear() self._modulator_entropy = None # Adaptive slot lr buffers if self.adaptive_slot_lr and hasattr(self, 'slot_age'): self.slot_age.data.zero_() self.slot_relevance.data.zero_() # Homeostatic threshold — reset to 1.0 (neutral) if self.homeostatic_threshold and hasattr(self, 'slot_threshold'): self.slot_threshold.data.fill_(1.0) # Bayesian reward posteriors — reset to prior if hasattr(self, 'reward_mu'): self.reward_mu.data.fill_(self.reward_mu.data.mean().item() if self.reward_mu.numel() > 0 else 0.0) self.reward_sigma_sq.data.fill_(self.reward_sigma_sq.data.mean().item() if self.reward_sigma_sq.numel() > 0 else 1.0) self.reward_n_obs.data.zero_() # Memory consolidation — zero hippocampus→neocortex transfer state if self.use_consolidation: self.consolidated_weight.data.zero_() self.slot_cumulative_reward.data.zero_() self.consolidation_count.data.zero_() # Adaptive transfer score if self.use_adaptive_transfer: self.slot_transfer_score.data.zero_() if self.use_kg_gate: self.slot_kg_score.data.zero_() self._last_activation = None self._demotion_count = 0 # Intrinsic motivation + SPAR state if self.use_intrinsic_reward: self._activation_ema.data.zero_() self._curiosity_ema.data.zero_() self._competence_ema.data.zero_() self._intrinsic_update_count = 0 self._expected_loss_ema = 0.0 self._prediction_error_ema = 0.0 self._consecutive_errors = 0 def save_state(self) -> Dict[str, Any]: """Save all state for cross-session persistence (v6 format with FE-MX).""" state = { 'version': 6 if self.use_femx else 5, 'fast_weight': self.fast_weight.data.clone(), 'eligibility_traces': self.eligibility_traces.data.clone(), 'slow_traces': self.slow_traces.data.clone(), 'usage_count': self.usage_count.data.clone(), 'usage_ema': self.usage_ema.data.clone(), 'slot_last_used': self.slot_last_used.data.clone(), 'update_count': self.update_count, 'reward_buffer': list(self.reward_buffer), 'memory_norm_history': list(self.memory_norm_history), 'competition_stats': list(self._competition_stats[-100:]), } # v4 fields — adaptive slot lr if self.adaptive_slot_lr and hasattr(self, 'slot_age'): state['slot_age'] = self.slot_age.data.clone() state['slot_relevance'] = self.slot_relevance.data.clone() # v4 fields — homeostatic threshold if self.homeostatic_threshold and hasattr(self, 'slot_threshold'): state['slot_threshold'] = self.slot_threshold.data.clone() # v4 fields — Bayesian reward if hasattr(self, 'reward_mu'): state['reward_mu'] = self.reward_mu.data.clone() state['reward_sigma_sq'] = self.reward_sigma_sq.data.clone() state['reward_n_obs'] = self.reward_n_obs.data.clone() # v4 fields — memory consolidation if self.use_consolidation: state['consolidated_weight'] = self.consolidated_weight.data.clone() state['slot_cumulative_reward'] = self.slot_cumulative_reward.data.clone() state['consolidation_count'] = self.consolidation_count.data.clone() # v4 fields — adaptive transfer if self.use_adaptive_transfer: state['slot_transfer_score'] = self.slot_transfer_score.data.clone() # v4 fields — KG consolidation gate if self.use_kg_gate: state['slot_kg_score'] = self.slot_kg_score.data.clone() state['demotion_count'] = self._demotion_count # v5 fields — intrinsic motivation + SPAR if self.use_intrinsic_reward: state['activation_ema'] = self._activation_ema.data.clone() state['curiosity_ema'] = self._curiosity_ema.data.clone() state['competence_ema'] = self._competence_ema.data.clone() state['intrinsic_update_count'] = self._intrinsic_update_count state['expected_loss_ema'] = self._expected_loss_ema state['prediction_error_ema'] = self._prediction_error_ema state['consecutive_errors'] = self._consecutive_errors # v6 fields — FE-MX compressed storage if self.use_femx and self._femx is not None: state['femx_packed'] = self._femx.packed.cpu().clone() state['femx_scales'] = self._femx.scales.cpu().clone() state['femx_tier'] = self._femx.tier.cpu().clone() state['femx_quantize_count'] = self._femx._quantize_count return state def load_state(self, state: Dict[str, Any]): """Restore state. Backward-compatible: loads v1-v5 + v6 FE-MX.""" self.fast_weight.data.copy_(state['fast_weight']) # Sync FE-MX master from loaded fast_weight (v5 backward compat) if self.use_femx: self._femx_init() if 'femx_packed' in state: # v6 format — load compressed state directly self._femx.packed.copy_(state['femx_packed'].to(self._femx.packed.device)) self._femx.scales.copy_(state['femx_scales'].to(self._femx.scales.device)) self._femx.tier.copy_(state['femx_tier'].to(self._femx.tier.device)) self._femx.master.copy_(state['fast_weight'].float().to(self._femx.master.device)) else: # v5 or older — seed master from fast_weight, quantize self._femx.master.copy_(self.fast_weight.data.float()) self._femx.sync_from_master(stochastic=False) self.eligibility_traces.data.copy_(state['eligibility_traces']) # v2 fields — use defaults if loading old format if 'slow_traces' in state: self.slow_traces.data.copy_(state['slow_traces']) else: self.slow_traces.data.zero_() # v3 fields — slot tracking (defaults to zero if loading v1/v2) if 'usage_count' in state: self.usage_count.data.copy_(state['usage_count']) else: self.usage_count.data.zero_() if 'usage_ema' in state: self.usage_ema.data.copy_(state['usage_ema']) else: self.usage_ema.data.zero_() if 'slot_last_used' in state: self.slot_last_used.data.copy_(state['slot_last_used']) else: self.slot_last_used.data.zero_() self.update_count = state.get('update_count', 0) self.reward_buffer = list(state.get('reward_buffer', [])) self.memory_norm_history = list(state.get('memory_norm_history', [])) self._competition_stats = list(state.get('competition_stats', [])) # v4 fields — adaptive slot lr if self.adaptive_slot_lr and hasattr(self, 'slot_age'): if 'slot_age' in state: self.slot_age.data.copy_(state['slot_age']) self.slot_relevance.data.copy_(state['slot_relevance']) else: self.slot_age.data.zero_() self.slot_relevance.data.zero_() # v4 fields — homeostatic threshold if self.homeostatic_threshold and hasattr(self, 'slot_threshold'): if 'slot_threshold' in state: self.slot_threshold.data.copy_(state['slot_threshold']) else: self.slot_threshold.data.fill_(1.0) # v4 fields — Bayesian reward if hasattr(self, 'reward_mu'): if 'reward_mu' in state: self.reward_mu.data.copy_(state['reward_mu']) self.reward_sigma_sq.data.copy_(state['reward_sigma_sq']) self.reward_n_obs.data.copy_(state['reward_n_obs']) # else: keep priors from __init__ # v4 fields — memory consolidation if self.use_consolidation: if 'consolidated_weight' in state: self.consolidated_weight.data.copy_(state['consolidated_weight']) self.slot_cumulative_reward.data.copy_(state['slot_cumulative_reward']) self.consolidation_count.data.copy_(state['consolidation_count']) else: self.consolidated_weight.data.zero_() self.slot_cumulative_reward.data.zero_() self.consolidation_count.data.zero_() # v4 fields — adaptive transfer if self.use_adaptive_transfer: if 'slot_transfer_score' in state: self.slot_transfer_score.data.copy_(state['slot_transfer_score']) else: self.slot_transfer_score.data.zero_() # v4 fields — KG consolidation gate if self.use_kg_gate: if 'slot_kg_score' in state: self.slot_kg_score.data.copy_(state['slot_kg_score']) else: self.slot_kg_score.data.zero_() self._demotion_count = state.get('demotion_count', 0) # v5 fields — intrinsic motivation + SPAR if self.use_intrinsic_reward: if 'activation_ema' in state: self._activation_ema.data.copy_(state['activation_ema']) self._curiosity_ema.data.copy_(state['curiosity_ema']) self._competence_ema.data.copy_(state['competence_ema']) self._intrinsic_update_count = state.get('intrinsic_update_count', 0) self._expected_loss_ema = state.get('expected_loss_ema', 0.0) self._prediction_error_ema = state.get('prediction_error_ema', 0.0) self._consecutive_errors = state.get('consecutive_errors', 0) else: self._activation_ema.data.zero_() self._curiosity_ema.data.zero_() self._competence_ema.data.zero_() self._intrinsic_update_count = 0 self._expected_loss_ema = 0.0 self._prediction_error_ema = 0.0 self._consecutive_errors = 0 def get_stats(self) -> Dict[str, Any]: """Return memory statistics including NHL competition and HAG allocation metrics.""" avg_sharpness = ( sum(self._competition_stats) / len(self._competition_stats) if self._competition_stats else 0.0) # Dynamic memory allocation stats (HAG-inspired) norms = self.fast_weight.data.norm(dim=1) active_slots = (norms > self.slot_activation_threshold).sum().item() stale_slots = 0 if self.slot_recycle_after > 0 and self.update_count > 0: stale_slots = int(((self.update_count - self.slot_last_used.data) > self.slot_recycle_after).sum().item()) # Usage entropy (how evenly distributed is slot usage) usage_sum = self.usage_ema.data.sum() if usage_sum > 1e-8: usage_probs = self.usage_ema.data / usage_sum usage_entropy = -(usage_probs * torch.log(usage_probs + 1e-8)).sum().item() else: usage_entropy = 0.0 return { 'update_count': self.update_count, 'memory_norm': self.fast_weight.norm().item(), 'memory_mean': self.fast_weight.mean().item(), 'memory_std': self.fast_weight.std().item(), 'eligibility_norm': self.eligibility_traces.norm().item(), 'slow_trace_norm': self.slow_traces.norm().item(), 'avg_reward': (sum(self.reward_buffer) / len(self.reward_buffer) if self.reward_buffer else 0.0), 'norm_history': self.memory_norm_history[-10:], 'competition_sharpness': avg_sharpness, 'use_soft_hebbian': self.use_soft_hebbian, 'use_learned_modulator': self.use_learned_modulator, 'use_three_factor': self.use_three_factor, 'rare_correlation_pct': self.rare_correlation_pct, 'active_slots': int(active_slots), 'total_slots': self.memory_size, 'stale_slots': stale_slots, 'usage_entropy': usage_entropy, 'slot_utilization': active_slots / self.memory_size, 'consolidated_norm': (self.consolidated_weight.norm().item() if self.use_consolidation else 0.0), 'consolidated_slots': (int((self.consolidated_weight.norm(dim=1) > self.slot_activation_threshold).sum().item()) if self.use_consolidation else 0), 'transfer_score_mean': (self.slot_transfer_score.mean().item() if self.use_adaptive_transfer else 0.0), 'transfer_score_pos': (int((self.slot_transfer_score > 0).sum().item()) if self.use_adaptive_transfer else 0), 'transfer_score_neg': (int((self.slot_transfer_score < 0).sum().item()) if self.use_adaptive_transfer else 0), 'demoted_slots': (int((self.slot_transfer_score < -self.consolidation_threshold).sum().item()) if self.use_transfer_demotion else 0), 'demotion_events': self._demotion_count, 'kg_score_mean': (self.slot_kg_score.mean().item() if self.use_kg_gate else 0.0), 'kg_gated_slots': (int((self.slot_kg_score < 0).sum().item()) if self.use_kg_gate else 0), 'intrinsic_curiosity': (self._curiosity_ema.item() if self.use_intrinsic_reward else 0.0), 'intrinsic_competence': (self._competence_ema.item() if self.use_intrinsic_reward else 0.0), # FE-MX compression stats 'femx_enabled': self.use_femx, 'femx_savings_pct': (self._femx.savings_ratio() * 100 if self.use_femx and self._femx is not None else 0.0), 'femx_tier_dist': (self._femx.tier_distribution() if self.use_femx and self._femx is not None else {}), 'femx_quantize_count': (self._femx._quantize_count if self.use_femx and self._femx is not None else 0), } # ────────────────────────────────────────────────────────────────────── # Phase 5a: ELM Pseudoinverse Warm-Start (Huang et al. 2006) # ────────────────────────────────────────────────────────────────────── def elm_warm_start(self, activations: torch.Tensor, targets: Optional[torch.Tensor] = None) -> None: """ Initialize fast_weight via Moore-Penrose pseudoinverse (ELM). Instead of starting from zeros, compute optimal initial weights analytically: W = pinv(H) * T, where H is the activation matrix from a calibration batch passed through the frozen base. For self-supervised mode (targets=None), uses identity mapping: the memory learns to reconstruct its own input (autoencoder init). Args: activations: [N, dim] activation matrix from calibration batch targets: [N, dim] target outputs (optional; defaults to activations) """ with torch.no_grad(): H = activations.float() # [N, D] T = targets.float() if targets is not None else H # [N, D] N, D = H.shape M = self.memory_size # Project to memory space: H_proj = H @ query_proj.weight.T → [N, D] # Then truncate/pad to [N, M] for fast_weight shape H_proj = H if D > M: H_proj = H_proj[:, :M] elif D < M: H_proj = F.pad(H_proj, (0, M - D)) # Pseudoinverse: W = pinv(H_proj) @ T → [M, D] # pinv(H_proj) is [M, N], T is [N, D] → result is [M, D] try: W = torch.linalg.lstsq(H_proj, T).solution # [M, D] except Exception: # Fallback: explicit pseudoinverse W = torch.linalg.pinv(H_proj) @ T # [M, D] # Clamp to weight bounds and assign W = W.clamp(-self.weight_clip, self.weight_clip) self.fast_weight.data.copy_(W.to(self.fast_weight.dtype)) # Sync FE-MX master from warm-started fast_weight if self.use_femx: self._femx_init() self._femx.master.copy_(W.float().to(self._femx.master.device)) self._femx.sync_from_master(stochastic=False) # Seed neocortex with ELM warm-start — calibrated base-model features # survive training because consolidated_decay=0.9999 vs fast_weight decay=0.97 # fast_weight will decay to ~0 after 200 steps; consolidated retains 90% after 1000 if self.use_consolidation: self.consolidated_weight.data.copy_(self.fast_weight.data) # Reset traces to match warm-started weights self.eligibility_traces.data.zero_() self.slow_traces.data.zero_() # ────────────────────────────────────────────────────────────────────── # Phase 5b: Bayesian Reward Estimation (BDA3 Ch. 2-3) # ────────────────────────────────────────────────────────────────────── def _bayesian_reward_update(self, reward: float, y: torch.Tensor) -> float: """ Conjugate Normal-Normal Bayesian update of reward estimate per slot. Prior: reward ~ N(mu_prior, sigma_prior^2) Likelihood: r_obs ~ N(mu_true, sigma_obs^2) Posterior: mu_n = (sigma_obs^2 * mu_{n-1} + sigma_{n-1}^2 * r) / (sigma_obs^2 + sigma_{n-1}^2) Returns the Bayesian posterior mean reward (more stable than raw reward). Args: reward: Raw observed reward signal y: Competition activations [memory_size] (determines which slots update) Returns: Smoothed Bayesian reward estimate """ # Observation variance (decreases with more data = more confident) sigma_obs_sq = 1.0 / (1.0 + self.reward_n_obs) # Conjugate update: posterior = prior × likelihood # Only update slots that are active (y > mean) active = (y > y.mean()).float() # Posterior precision = prior precision + likelihood precision posterior_prec = 1.0 / self.reward_sigma_sq + active / sigma_obs_sq posterior_var = 1.0 / posterior_prec # Posterior mean (precision-weighted average) posterior_mu = posterior_var * ( self.reward_mu / self.reward_sigma_sq + active * reward / sigma_obs_sq ) # Update buffers self.reward_mu.copy_(posterior_mu) self.reward_sigma_sq.copy_(posterior_var) self.reward_n_obs.add_(active) # Return precision-weighted global estimate weights = 1.0 / self.reward_sigma_sq weights = weights / weights.sum() return (weights * self.reward_mu).sum().item() # ────────────────────────────────────────────────────────────────────── # Phase 5c: Posterior Predictive Check (BDA3 Ch. 6) # ────────────────────────────────────────────────────────────────────── def posterior_predictive_check(self, x: torch.Tensor, n_samples: int = 100) -> Dict[str, Any]: """ Bayesian posterior predictive check for model quality diagnostics. Generates samples from the memory model and compares statistics against the observed input. Detects model misfit in frozen-only mode. Test statistics: - Mean activation magnitude - Activation variance - Slot utilization - Retrieval confidence (attention entropy) Args: x: Input tensor [batch, seq, dim] from real data n_samples: Number of posterior predictive samples Returns: Dict with test statistics, p-values, and diagnostic assessment """ with torch.no_grad(): B, S, D = x.shape # Observed test statistics q = self.query_proj(x) scores = torch.matmul(q, self.fast_weight.t()) / math.sqrt(self.dim) attn = F.softmax(scores, dim=-1) obs_stats = { 'mean_activation': x.mean().item(), 'activation_std': x.std().item(), 'attention_entropy': -(attn * torch.log(attn + 1e-8)).sum(-1).mean().item(), 'max_attention': attn.max(-1).values.mean().item(), 'memory_norm': self.fast_weight.norm().item(), } # Posterior predictive samples (perturb memory + measure) pp_stats = {k: [] for k in obs_stats} orig_fw = self.fast_weight.data.clone() for _ in range(n_samples): # Sample from posterior: add noise scaled by uncertainty if self.use_mesu: noise = torch.randn_like(self.fast_weight) * self.weight_sigma_sq.sqrt() else: noise = torch.randn_like(self.fast_weight) * 0.01 self.fast_weight.data.copy_(orig_fw + noise) q_s = self.query_proj(x) scores_s = torch.matmul(q_s, self.fast_weight.t()) / math.sqrt(self.dim) attn_s = F.softmax(scores_s, dim=-1) pp_stats['mean_activation'].append(x.mean().item()) pp_stats['activation_std'].append(x.std().item()) pp_stats['attention_entropy'].append( -(attn_s * torch.log(attn_s + 1e-8)).sum(-1).mean().item()) pp_stats['max_attention'].append( attn_s.max(-1).values.mean().item()) pp_stats['memory_norm'].append(self.fast_weight.norm().item()) # Restore original weights self.fast_weight.data.copy_(orig_fw) # Compute Bayesian p-values: P(T(y_rep) >= T(y_obs)) p_values = {} for stat_name, obs_val in obs_stats.items(): samples = pp_stats[stat_name] p_val = sum(1 for s in samples if s >= obs_val) / n_samples p_values[stat_name] = p_val # Diagnostic: p-values near 0 or 1 indicate misfit extreme = {k: v for k, v in p_values.items() if v < 0.05 or v > 0.95} return { 'observed': obs_stats, 'p_values': p_values, 'misfit_detected': len(extreme) > 0, 'misfit_stats': list(extreme.keys()), 'assessment': ('GOOD FIT' if not extreme else f'POTENTIAL MISFIT in: {list(extreme.keys())}'), } class PerLayerHebbian(nn.Module): """Per-layer Hebbian memory with NHL + STELLAR + Floreano + neurogenesis enhancements. Ventral-dorsal specialization (Surget & Belzung 2022): Layer position along the network axis maps to the hippocampal ventral-dorsal gradient: - Early layers ("dorsal"): structural/spatial features — more plastic - Late layers ("ventral"): semantic/contextual features — more stable """ def __init__(self, dim: int, num_layers: int, memory_size: int = 64, lr: float = 0.01, decay: float = 0.99, tau_e: float = 0.95, rare_correlation_pct: float = 0.1, use_neuromodulation: bool = True, # NHL + Floreano + STELLAR new params temperature: float = 1.0, weight_radius: float = 1.0, use_soft_hebbian: bool = True, use_learned_modulator: bool = True, modulator_entropy_weight: float = 0.1, tau_fast: float = 0.90, tau_slow: float = 0.99, use_three_factor: bool = True, # Pattern separation + synaptic competition separation_strength: float = 0.05, separation_threshold: float = 0.5, slot_activation_threshold: float = 0.01, competition_strength: float = 0.1, slot_recycle_after: int = 500, # Ventral-dorsal specialization use_layer_specialization: bool = True, # Phase 4: Research-backed tuning max_update_norm: float = 1.0, trace_clip: float = 0.1, weight_clip: float = 1.0, noise_scale: float = 0.001, # Phase 4b: BCPNN adaptive slot lr + homeostatic thresholds adaptive_slot_lr: bool = False, tau_age: float = 100.0, importance_scale: float = 2.0, homeostatic_threshold: bool = False, threshold_incr: float = 0.01, threshold_decr: float = 0.001, # Phase 4d: Cosine retrieval cosine_retrieval: bool = False, retrieval_tau: float = 1.0, sparsity_lambda: float = 0.1, # Phase 4e-g: Advanced multi_timescale: bool = False, working_memory_ratio: float = 0.3, structural_plasticity: bool = False, merge_threshold: float = 0.95, use_trace_filter: bool = False, # Phase 5: MESU + Bayesian reward use_mesu: bool = False, mesu_sigma_prior: float = 0.1, mesu_sigma_res: float = 10.0, use_bayesian_reward: bool = False, reward_prior_mean: float = 0.0, reward_prior_var: float = 1.0, # Identity-init for projections identity_init: bool = False, # Memory consolidation use_consolidation: bool = False, consolidation_interval: int = 20, consolidation_threshold: float = 0.01, consolidated_decay: float = 0.9999, consolidation_ratio: float = 0.3, # Adaptive transfer filtering adaptive_transfer: bool = False, transfer_ema_decay: float = 0.95, transfer_demotion: bool = False, transfer_demotion_rate: float = 0.1, # KG consolidation gate kg_consolidation_gate: bool = False, # Layer 4: Intrinsic reward use_intrinsic_reward: bool = False, curiosity_ema_decay: float = 0.95, competence_ema_decay: float = 0.99, # Layer 5: SPAR error threshold error_threshold: int = 2, # FE-MX compression use_femx: bool = False): super().__init__() self.num_layers = num_layers self.use_layer_specialization = use_layer_specialization self._error_threshold = error_threshold memories = [] for i in range(num_layers): if use_layer_specialization: # Ventral-dorsal axis (Surget & Belzung 2022): # position 0.0 = first layer (dorsal), 1.0 = last layer (ventral) position = i / max(num_layers - 1, 1) # Plasticity decreases with depth (early=high lr, late=low lr) layer_lr = lr * (1.5 - position) # 1.5x → 0.5x # Stability increases with depth (early=fast decay, late=slow decay) layer_decay = decay + (1.0 - decay) * position * 0.5 # 0.99 → 0.995 # Memory capacity larger in early layers layer_mem = int(memory_size * (1.3 - 0.6 * position)) # 1.3x → 0.7x layer_mem = max(16, layer_mem) # Floor at 16 slots # Temperature: early layers sharper competition, late layers softer layer_temp = temperature * (0.8 + 0.4 * position) # 0.8 → 1.2 else: layer_lr = lr layer_decay = decay layer_mem = memory_size layer_temp = temperature memories.append(HebbianMemory( dim, layer_mem, layer_lr, layer_decay, tau_e=tau_e, rare_correlation_pct=rare_correlation_pct, use_neuromodulation=use_neuromodulation, temperature=layer_temp, weight_radius=weight_radius, use_soft_hebbian=use_soft_hebbian, use_learned_modulator=use_learned_modulator, modulator_entropy_weight=modulator_entropy_weight, tau_fast=tau_fast, tau_slow=tau_slow, use_three_factor=use_three_factor, separation_strength=separation_strength, separation_threshold=separation_threshold, slot_activation_threshold=slot_activation_threshold, competition_strength=competition_strength, slot_recycle_after=slot_recycle_after, max_update_norm=max_update_norm, trace_clip=trace_clip, weight_clip=weight_clip, noise_scale=noise_scale, adaptive_slot_lr=adaptive_slot_lr, tau_age=tau_age, importance_scale=importance_scale, homeostatic_threshold=homeostatic_threshold, threshold_incr=threshold_incr, threshold_decr=threshold_decr, cosine_retrieval=cosine_retrieval, retrieval_tau=retrieval_tau, sparsity_lambda=sparsity_lambda, multi_timescale=multi_timescale, working_memory_ratio=working_memory_ratio, structural_plasticity=structural_plasticity, merge_threshold=merge_threshold, use_trace_filter=use_trace_filter, use_mesu=use_mesu, mesu_sigma_prior=mesu_sigma_prior, mesu_sigma_res=mesu_sigma_res, use_bayesian_reward=use_bayesian_reward, reward_prior_mean=reward_prior_mean, reward_prior_var=reward_prior_var, identity_init=identity_init, use_consolidation=use_consolidation, consolidation_interval=consolidation_interval, consolidation_threshold=consolidation_threshold, consolidated_decay=consolidated_decay, consolidation_ratio=consolidation_ratio, adaptive_transfer=adaptive_transfer, transfer_ema_decay=transfer_ema_decay, transfer_demotion=transfer_demotion, transfer_demotion_rate=transfer_demotion_rate, kg_consolidation_gate=kg_consolidation_gate, use_intrinsic_reward=use_intrinsic_reward, curiosity_ema_decay=curiosity_ema_decay, competence_ema_decay=competence_ema_decay, error_threshold=error_threshold, use_femx=use_femx, )) self.memories = nn.ModuleList(memories) def forward(self, x: torch.Tensor, layer_idx: int, update: bool = True, reward: Optional[float] = None) -> torch.Tensor: return self.memories[layer_idx](x, update, reward) def retrieve_only(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: """Retrieve from a specific layer's memory without updating.""" return self.memories[layer_idx].retrieve_only(x) def get_modulator_entropy_loss(self) -> torch.Tensor: """Aggregate entropy loss across all layers.""" total = torch.tensor(0.0, device=self.memories[0].fast_weight.device) for mem in self.memories: total = total + mem.get_modulator_entropy_loss() return total def get_pattern_separation_loss(self) -> torch.Tensor: """Aggregate pattern separation loss across all layers.""" total = torch.tensor(0.0, device=self.memories[0].fast_weight.device) for mem in self.memories: total = total + mem.get_pattern_separation_loss() return total def reset(self): for mem in self.memories: mem.reset() def save_state(self) -> Dict[str, Any]: return {f'layer_{i}': mem.save_state() for i, mem in enumerate(self.memories)} def load_state(self, state: Dict[str, Any]): for i, mem in enumerate(self.memories): key = f'layer_{i}' if key in state: mem.load_state(state[key]) def inject_kg_signal(self, score: float): """Forward KG signal to all per-layer memories.""" for mem in self.memories: mem.inject_kg_signal(score) def compute_intrinsic_reward(self) -> Dict[str, float]: """Average intrinsic reward across all layers.""" c_sum, k_sum, count = 0.0, 0.0, 0 for mem in self.memories: r = mem.compute_intrinsic_reward() c_sum += r['curiosity'] k_sum += r['competence'] count += 1 if count == 0: return {'curiosity': 0.0, 'competence': 0.0, 'combined': 0.0} c, k = c_sum / count, k_sum / count return {'curiosity': round(c, 6), 'competence': round(k, 6), 'combined': round(max(-1.0, min(1.0, 0.5 * c + 0.5 * k)), 6)} def perceive(self, x_mean: torch.Tensor) -> Dict[str, float]: """Use layer 0 (rawest input) for familiarity assessment.""" if len(self.memories) > 0: return self.memories[0].perceive(x_mean) return {'familiarity': 0.0, 'complexity': 0.0} def reflect(self, loss: float, reward: float) -> Dict[str, Any]: """Average prediction error, max consecutive errors across layers.""" pe_sum, consec_max = 0.0, 0 for mem in self.memories: r = mem.reflect(loss, reward) pe_sum += r['prediction_error'] consec_max = max(consec_max, r['consecutive_errors']) n = len(self.memories) error_detected = consec_max >= 5 return {'prediction_error': round(pe_sum / max(n, 1), 6), 'error_detected': error_detected, 'consecutive_errors': consec_max, 'recovery_action': 'boost_curiosity' if error_detected else 'none'} def get_stats(self) -> Dict[str, Any]: return {f'layer_{i}': mem.get_stats() for i, mem in enumerate(self.memories)} # ============================================================================ # SLICED CRAMER PRESERVATION (SCP) - Anti-Forgetting # ============================================================================ # Based on STELLAR (AD1180225.pdf) Equations 14-17 # Prevents catastrophic forgetting by preserving distribution of representations class SlicedCramerPreservation(nn.Module): """ Sliced Cramer Preservation for continual learning without catastrophic forgetting. STELLAR Innovation: Instead of preserving individual samples (EWC/MAS), SCP preserves the DISTRIBUTION of layer representations, enabling: - 25% less forgetting of old tasks - 30% improved performance on new tasks - Better utilization of network capacity Reference: STELLAR Equations 14-17 """ def __init__(self, num_slices: int = 100, # Monte Carlo samples for slicing lambda_reg: float = 1.0, # Regularization strength alpha_ema: float = 0.1, # EMA for importance accumulation ): super().__init__() self.num_slices = num_slices self.lambda_reg = lambda_reg self.alpha_ema = alpha_ema # Storage for importance weights (diagonal of Γ matrix) self.importance: Dict[str, torch.Tensor] = {} self.optimal_params: Dict[str, torch.Tensor] = {} self.task_count = 0 def compute_importance(self, model: nn.Module, dataloader, device: torch.device) -> Dict[str, torch.Tensor]: """ Compute importance parameters Γ for all model parameters. Uses sliced Cramer distance approximation (STELLAR Eq. 16): [Γ]_i,i = (1/L) Σ_l (d(ξ_l · z̄) / dθ_i)² Args: model: The model to compute importance for dataloader: DataLoader with samples from current task device: Device to compute on Returns: Dictionary mapping parameter names to importance tensors """ model.eval() importance = {name: torch.zeros_like(p, device=device) for name, p in model.named_parameters() if p.requires_grad} # Generate random slicing directions (ξ vectors on unit sphere) # Sample from standard normal and normalize num_samples = 0 for batch in dataloader: if isinstance(batch, (tuple, list)): x = batch[0].to(device) else: x = batch.to(device) # Forward pass to get representations with torch.enable_grad(): # Get intermediate representations (before final layer) z = model(x) # Assume model returns logits/representations # Compute z̄ (mean representation) z_bar = z.mean(dim=0) # [D] # Monte Carlo estimate of Γ diagonal for _ in range(self.num_slices): # Random direction on unit sphere xi = torch.randn_like(z_bar) xi = xi / (xi.norm() + 1e-8) # Projection: ξ · z̄ proj = (xi * z_bar).sum() # Compute gradient w.r.t. all parameters model.zero_grad() proj.backward(retain_graph=True) # Accumulate squared gradients for name, p in model.named_parameters(): if p.grad is not None: importance[name] += (p.grad ** 2) num_samples += x.shape[0] # Normalize by number of samples and slices for name in importance: importance[name] /= (num_samples * self.num_slices) return importance def register_task(self, model: nn.Module, dataloader, device: torch.device): """ Register completion of a task by computing and storing importance. Uses EMA for multi-task accumulation (STELLAR Eq. 29): Γ^(t) = α·Γ_θ*_t + (1-α)·Γ^(t-1) """ new_importance = self.compute_importance(model, dataloader, device) # Store optimal parameters for this task for name, p in model.named_parameters(): if p.requires_grad: self.optimal_params[name] = p.data.clone() # Accumulate importance with EMA if self.task_count == 0: self.importance = new_importance else: for name in new_importance: self.importance[name] = ( self.alpha_ema * new_importance[name] + (1 - self.alpha_ema) * self.importance[name] ) self.task_count += 1 def penalty(self, model: nn.Module) -> torch.Tensor: """ Compute SCP regularization penalty (STELLAR Eq. 17). L_B(θ) + λ Σ_m [Γ]_m,m · (θ - θ*)²_m Returns: Scalar penalty to add to task loss """ if self.task_count == 0: return torch.tensor(0.0) penalty = 0.0 for name, p in model.named_parameters(): if name in self.importance and name in self.optimal_params: diff = p - self.optimal_params[name] penalty += (self.importance[name] * diff ** 2).sum() return self.lambda_reg * penalty # ============================================================================ # CONTEXT-SKILL MODEL - Dual Temporal Scales # ============================================================================ # Based on STELLAR Context-Skill Model (Section 1.2.5, 1.3.5) # Achieves 30% better forward transfer via explicit context representation class ContextSkillModel(nn.Module): """ Context-Skill Model with dual temporal scales for rapid adaptation. STELLAR Innovation: Separates processing into: - Skill module: Fast, feedforward response to current situation - Context module: LSTM integrating observations over time The context modulates skill execution, enabling immediate adaptation to new task variations without retraining. Reference: STELLAR Figures 19-20, Section 1.3.5 """ def __init__(self, input_dim: int, skill_hidden: int = 64, skill_output: int = 32, context_hidden: int = 64, controller_hidden: int = 64, output_dim: int = None, ): super().__init__() self.input_dim = input_dim self.skill_hidden = skill_hidden self.skill_output = skill_output self.context_hidden = context_hidden self.output_dim = output_dim or input_dim # Skill module: Fast feedforward network self.skill = nn.Sequential( nn.Linear(input_dim, skill_hidden), nn.GELU(), nn.Linear(skill_hidden, skill_output), ) # Context module: LSTM for temporal integration self.context = nn.LSTMCell(input_dim, context_hidden) # Controller: Combines skill and context self.controller = nn.Sequential( nn.Linear(skill_output + context_hidden, controller_hidden), nn.GELU(), nn.Linear(controller_hidden, self.output_dim), ) # Hidden state for context LSTM self._h: Optional[torch.Tensor] = None self._c: Optional[torch.Tensor] = None def reset_context(self): """Reset LSTM hidden state (call at start of new task).""" self._h = None self._c = None def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass with context-skill separation. Args: x: Input tensor [batch, dim] or [batch, seq, dim] Returns: Output tensor with context-modulated skill """ # Handle sequence input if x.dim() == 3: B, S, D = x.shape outputs = [] for t in range(S): out = self._forward_step(x[:, t, :]) outputs.append(out) return torch.stack(outputs, dim=1) else: return self._forward_step(x) def _forward_step(self, x: torch.Tensor) -> torch.Tensor: """Single step forward pass.""" B = x.shape[0] device = x.device # Initialize hidden states if needed if self._h is None or self._h.shape[0] != B: self._h = torch.zeros(B, self.context_hidden, device=device) self._c = torch.zeros(B, self.context_hidden, device=device) # Skill module: immediate response skill_out = self.skill(x) # [B, skill_output] # Context module: temporal integration (LSTM) self._h, self._c = self.context(x, (self._h, self._c)) context_out = self._h # [B, context_hidden] # Controller: combine skill with context modulation combined = torch.cat([skill_out, context_out], dim=-1) output = self.controller(combined) return output def get_context_state(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: """Return current context hidden state for analysis.""" return self._h, self._c # ============================================================================ # MULTIMODAL FUSION # ============================================================================ class SigLIPEmbeddings(nn.Module): """Patch + position embeddings for SigLIP vision encoder.""" def __init__(self, hidden_size: int = 1152, image_size: int = 448, patch_size: int = 14): super().__init__() self.num_patches = (image_size // patch_size) ** 2 # 1024 self.patch_embedding = nn.Conv2d( 3, hidden_size, kernel_size=patch_size, stride=patch_size, bias=True ) self.position_embedding = nn.Embedding(self.num_patches, hidden_size) self.register_buffer( "position_ids", torch.arange(self.num_patches, dtype=torch.long).unsqueeze(0), persistent=False, ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """pixel_values: [B, 3, H, W] -> [B, num_patches, hidden_size]""" patches = self.patch_embedding(pixel_values) # [B, D, H/P, W/P] patches = patches.flatten(2).transpose(1, 2) # [B, N, D] embeddings = patches + self.position_embedding(self.position_ids) return embeddings class SigLIPAttention(nn.Module): """Multi-head attention for SigLIP vision encoder.""" def __init__(self, hidden_size: int = 1152, num_heads: int = 16): super().__init__() self.num_heads = num_heads self.head_dim = hidden_size // num_heads # 72 self.scale = self.head_dim ** -0.5 self.q_proj = nn.Linear(hidden_size, hidden_size) self.k_proj = nn.Linear(hidden_size, hidden_size) self.v_proj = nn.Linear(hidden_size, hidden_size) self.out_proj = nn.Linear(hidden_size, hidden_size) def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, D = x.shape q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale attn = F.softmax(attn, dim=-1) out = torch.matmul(attn, v) out = out.transpose(1, 2).reshape(B, N, D) return self.out_proj(out) class SigLIPMLP(nn.Module): """Feed-forward MLP for SigLIP vision encoder.""" def __init__(self, hidden_size: int = 1152, intermediate_size: int = 4304): super().__init__() self.fc1 = nn.Linear(hidden_size, intermediate_size) self.fc2 = nn.Linear(intermediate_size, hidden_size) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc2(F.gelu(self.fc1(x), approximate="tanh")) class SigLIPEncoderLayer(nn.Module): """Single transformer layer for SigLIP (pre-norm residual).""" def __init__(self, hidden_size: int = 1152, intermediate_size: int = 4304, num_heads: int = 16): super().__init__() self.self_attn = SigLIPAttention(hidden_size, num_heads) self.mlp = SigLIPMLP(hidden_size, intermediate_size) self.layer_norm1 = nn.LayerNorm(hidden_size, eps=1e-6) self.layer_norm2 = nn.LayerNorm(hidden_size, eps=1e-6) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.self_attn(self.layer_norm1(x)) x = x + self.mlp(self.layer_norm2(x)) return x class SigLIPVisionEncoder(nn.Module): """Full SigLIP vision encoder matching Phi-4-multimodal-instruct. Architecture: 27-layer ViT with 1152-dim hidden, 448x448 input. Output pipeline: patches -> penultimate layer -> AvgPool2d -> HD transform -> MLP proj. Produces 545 vision tokens at text hidden_size (3072) dimension. """ def __init__(self, config: 'FireEchoConfig'): super().__init__() hs = config.vision_hidden_size # 1152 ims = config.vision_intermediate_size # 4304 nh = config.vision_num_heads # 16 nl = config.vision_num_layers # 27 img = config.vision_image_size # 448 ps = config.vision_patch_size # 14 text_dim = config.dim # 3072 self.grid_size = img // ps # 32 self.embeddings = SigLIPEmbeddings(hs, img, ps) self.encoder_layers = nn.ModuleList([ SigLIPEncoderLayer(hs, ims, nh) for _ in range(nl) ]) self.post_layernorm = nn.LayerNorm(hs, eps=1e-6) # Compress 32x32=1024 patches -> 16x16=256 tokens self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) # Project from vision dim to text dim self.projection = nn.Sequential( nn.Linear(hs, text_dim), nn.GELU(), nn.Linear(text_dim, text_dim), ) # Learnable separators for HD transform (loaded from weights) self.glb_GN = nn.Parameter(torch.zeros(1, 1, hs)) # [1, 1, 1152] self.sub_GN = nn.Parameter(torch.zeros(1, 1, 1, hs)) # [1, 1, 1, 1152] def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Args: pixel_values: [B, 3, 448, 448] Returns: Vision embeddings [B, 545, text_dim] with HD transform for masked fusion. 545 = sub(272) + glb_GN(1) + global(272) for single 448×448 image. """ x = self.embeddings(pixel_values) # [B, 1024, 1152] # Phi-4 uses penultimate layer features (layer_idx=-2), NOT the final layer. # With 27 layers (0-26), layer -2 = after layer 25 — skip layer 26. for i, layer in enumerate(self.encoder_layers): x = layer(x) if i == len(self.encoder_layers) - 2: # after layer 25 break # post_layernorm is NOT applied — Phi-4 extracts raw hidden states B, N, D = x.shape H = self.grid_size # 32 # AvgPool2d: 32×32 → 16×16 = 256 tokens x = x.reshape(B, H, H, D) # [B, 32, 32, 1152] x = x.permute(0, 3, 1, 2) # [B, 1152, 32, 32] x = self.avg_pool(x) # [B, 1152, 16, 16] x = x.permute(0, 2, 3, 1) # [B, 16, 16, 1152] pH = H // 2 # 16 — pooled spatial grid # --- HD transform (matching Phi-4's use_hd_transform + sub_glb order) --- # For single 448×448 image: global and sub see the same features. # Per-row separator (sub_GN) acts as line break in the spatial grid. sep = self.sub_GN.expand(B, pH, -1, -1) # [B, 16, 1, 1152] # Sub-image: add separator per row → 16×17 = 272 tokens sub_img = torch.cat([x, sep], dim=2) # [B, 16, 17, 1152] sub_img = sub_img.reshape(B, -1, D) # [B, 272, 1152] # Global image: same features + same separator pattern glb_img = torch.cat([x, sep], dim=2) # [B, 16, 17, 1152] glb_img = glb_img.reshape(B, -1, D) # [B, 272, 1152] # Combine in sub_glb order: [sub, glb_GN separator, global] glb_GN = self.glb_GN.expand(B, -1, -1) # [B, 1, 1152] x = torch.cat([sub_img, glb_GN, glb_img], dim=1) # [B, 545, 1152] x = self.projection(x) # [B, 545, 3072] return x # ============================================================================ # CONFORMER AUDIO ENCODER (Phase 4) # 24-layer Conformer matching Phi-4-multimodal pretrained weights. # Architecture: MeanVarNorm → ConvSubsampling(8x) → 24×ConformerBlock → Projection # Total: ~441M encoder params + ~12.6M projection = ~454M params (~908 MB BF16) # ============================================================================ class MeanVarianceNormLayer(nn.Module): """Global mean/invstd normalization for mel features.""" def __init__(self, input_size: int = 80): super().__init__() self.register_buffer('global_mean', torch.zeros(input_size)) self.register_buffer('global_invstd', torch.ones(input_size)) def forward(self, x: torch.Tensor) -> torch.Tensor: return (x - self.global_mean) * self.global_invstd class ConvSubsampling(nn.Module): """NeMo-style dw_striding conv subsampling (8x time reduction). Input: [B, T, n_mels] (mel spectrogram) Output: [B, T//8, feat_out] (subsampled features) Conv layout matches safetensor keys: embed.conv.{0,2,3,5,6} Indices 1, 4 are SiLU activations (no weights). """ def __init__(self, feat_in: int = 80, feat_out: int = 1024, conv_channels: int = 1024): super().__init__() self.conv = nn.Sequential( nn.Conv2d(1, conv_channels, 3, stride=2, padding=1), # 0: [C, 1, 3, 3] nn.SiLU(), # 1 nn.Conv2d(conv_channels, conv_channels, 3, stride=2, # 2: [C, 1, 3, 3] depthwise padding=1, groups=conv_channels), nn.Conv2d(conv_channels, conv_channels, 1), # 3: [C, C, 1, 1] pointwise nn.SiLU(), # 4 nn.Conv2d(conv_channels, conv_channels, 3, stride=2, # 5: [C, 1, 3, 3] depthwise padding=1, groups=conv_channels), nn.Conv2d(conv_channels, conv_channels, 1), # 6: [C, C, 1, 1] pointwise ) # After 3× stride-2 on freq axis: 80 → 40 → 20 → 10 freq_out = feat_in for _ in range(3): freq_out = (freq_out + 1) // 2 self.out = nn.Linear(conv_channels * freq_out, feat_out) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.unsqueeze(1) # [B, 1, T, n_mels] x = self.conv(x) # [B, C, T//8, F//8] b, c, t, f = x.shape x = x.permute(0, 2, 1, 3).reshape(b, t, c * f) # [B, T//8, C*F] return self.out(x) # [B, T//8, feat_out] class T5RelativeAttentionBias(nn.Module): """T5-style relative position bias for Conformer self-attention. Asymmetric: 1000 buckets ([-500, 499] shifted to [0, 999]), 16 heads. Weight: bias_values.weight [1000, 16] """ def __init__(self, num_heads: int = 16, max_distance: int = 500): super().__init__() self.num_heads = num_heads self.max_distance = max_distance self.num_buckets = max_distance * 2 self.bias_values = nn.Embedding(self.num_buckets, num_heads) def forward(self, seq_len: int, device: torch.device) -> torch.Tensor: positions = torch.arange(seq_len, device=device) rel_pos = positions.unsqueeze(1) - positions.unsqueeze(0) # [T, T] rel_pos = rel_pos.clamp(-self.max_distance, self.max_distance - 1) rel_pos = rel_pos + self.max_distance # [0, 999] bias = self.bias_values(rel_pos) # [T, T, H] return bias.permute(2, 0, 1).unsqueeze(0) # [1, H, T, T] class ConformerGLULinear(nn.Module): """GLU linear: Linear → swish-gated split. Weight key: net.0.linear.weight [out_dim*2, in_dim] """ def __init__(self, in_dim: int, out_dim: int): super().__init__() self.linear = nn.Linear(in_dim, out_dim * 2) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.linear(x) half, gate = x.chunk(2, dim=-1) return half * F.silu(gate) class ConformerFeedForward(nn.Module): """Conformer FFN: LayerNorm → GLULinear → Linear. Keys: layer_norm.*, net.0.linear.*, net.2.* """ def __init__(self, d_model: int = 1024, d_inner: int = 1536): super().__init__() self.layer_norm = nn.LayerNorm(d_model) self.net = nn.Sequential( ConformerGLULinear(d_model, d_inner), # 0: net.0.linear.* nn.Identity(), # 1: dropout placeholder nn.Linear(d_inner, d_model), # 2: net.2.* nn.Identity(), # 3: dropout placeholder ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(self.layer_norm(x)) class ConformerSelfAttention(nn.Module): """Multi-head self-attention with T5 relative position bias. Keys: linear_q.*, linear_k.*, linear_v.*, linear_out.* """ def __init__(self, d_model: int = 1024, n_head: int = 16): super().__init__() self.n_head = n_head self.d_k = d_model // n_head self.scale = self.d_k ** -0.5 self.linear_q = nn.Linear(d_model, d_model) self.linear_k = nn.Linear(d_model, d_model) self.linear_v = nn.Linear(d_model, d_model) self.linear_out = nn.Linear(d_model, d_model) def forward(self, x: torch.Tensor, relative_attention_bias: Optional[torch.Tensor] = None) -> torch.Tensor: B, T, _ = x.shape q = self.linear_q(x).view(B, T, self.n_head, self.d_k).transpose(1, 2) k = self.linear_k(x).view(B, T, self.n_head, self.d_k).transpose(1, 2) v = self.linear_v(x).view(B, T, self.n_head, self.d_k).transpose(1, 2) attn = (q @ k.transpose(-2, -1)) * self.scale if relative_attention_bias is not None: attn = attn + relative_attention_bias attn = F.softmax(attn, dim=-1) out = (attn @ v).transpose(1, 2).reshape(B, T, -1) return self.linear_out(out) class ConformerGLUPointWiseConv(nn.Module): """GLU pointwise conv for ConvModule. Keys: ext_pw_conv_1d.*, b1, b2 """ def __init__(self, d_model: int = 1024): super().__init__() self.d_model = d_model self.ext_pw_conv_1d = nn.Conv1d(d_model, d_model * 2, 1) self.b1 = nn.Parameter(torch.zeros(1, d_model, 1)) self.b2 = nn.Parameter(torch.zeros(1, d_model, 1)) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B, D, T] (already channel-first) x = self.ext_pw_conv_1d(x) # [B, 2D, T] half = x[:, :self.d_model, :] gate = x[:, self.d_model:, :] return (half + self.b1) * F.silu(gate + self.b2) # [B, D, T] class ConformerDepthWiseSepConv(nn.Module): """Depthwise separable Conv1d. Keys: dw_conv.*, pw_conv.* """ def __init__(self, d_model: int = 1024, kernel_size: int = 3): super().__init__() padding = (kernel_size - 1) // 2 self.dw_conv = nn.Conv1d(d_model, d_model, kernel_size, padding=padding, groups=d_model) self.pw_conv = nn.Conv1d(d_model, d_model, 1) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.pw_conv(self.dw_conv(x)) class ConformerConvModule(nn.Module): """Conformer convolution module. Keys: layer_norm.*, glu.*, dw_sep_conv_1d.*, ext_pw_conv_1d.* Forward: LayerNorm → GLU → DWSepConv → SiLU → ext_pw_conv """ def __init__(self, d_model: int = 1024, kernel_size: int = 3): super().__init__() self.layer_norm = nn.LayerNorm(d_model) self.glu = ConformerGLUPointWiseConv(d_model) self.dw_sep_conv_1d = ConformerDepthWiseSepConv(d_model, kernel_size) self.ext_pw_conv_1d = nn.Conv1d(d_model, d_model, 1) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.layer_norm(x) x = x.permute(0, 2, 1) # [B, D, T] x = self.glu(x) # [B, D, T] x = self.dw_sep_conv_1d(x) # [B, D, T] x = F.silu(x) x = self.ext_pw_conv_1d(x) # [B, D, T] return x.permute(0, 2, 1) # [B, T, D] class ConformerEncoderLayer(nn.Module): """Single Conformer block: FFN_in → SelfAttn → ConvModule → FFN_out → LayerNorm. Keys per layer: feed_forward_in.*, self_attn.*, conv.*, feed_forward_out.*, layer_norm_att.*, layer_norm.* 36 parameter tensors per layer. """ def __init__(self, d_model: int = 1024, n_head: int = 16, d_ffn: int = 1536, kernel_size: int = 3): super().__init__() self.feed_forward_in = ConformerFeedForward(d_model, d_ffn) self.self_attn = ConformerSelfAttention(d_model, n_head) self.conv = ConformerConvModule(d_model, kernel_size) self.feed_forward_out = ConformerFeedForward(d_model, d_ffn) self.layer_norm_att = nn.LayerNorm(d_model) self.layer_norm = nn.LayerNorm(d_model) def forward(self, x: torch.Tensor, relative_attention_bias: Optional[torch.Tensor] = None) -> torch.Tensor: x = x + 0.5 * self.feed_forward_in(x) x = x + self.self_attn(self.layer_norm_att(x), relative_attention_bias) x = x + self.conv(x) x = x + 0.5 * self.feed_forward_out(x) return self.layer_norm(x) class ConformerEncoderCore(nn.Module): """Core encoder: normalization + conv subsampling + Conformer layers + rel bias. Keys: encoder_embedding.*, embed.*, encoders.{i}.*, relative_attention_bias_layer.* """ def __init__(self, config: 'FireEchoConfig'): super().__init__() d = config.audio_hidden_size self.encoder_embedding = MeanVarianceNormLayer(config.audio_n_mels) self.embed = ConvSubsampling(config.audio_n_mels, d, config.audio_conv_channels) self.relative_attention_bias_layer = T5RelativeAttentionBias( config.audio_num_heads, config.audio_max_rel_distance) self.encoders = nn.ModuleList([ ConformerEncoderLayer(d, config.audio_num_heads, config.audio_ffn_dim, config.audio_kernel_size) for _ in range(config.audio_num_layers) ]) def forward(self, mel: torch.Tensor) -> torch.Tensor: x = self.encoder_embedding(mel) # [B, T, 80] normalized x = self.embed(x) # [B, T//8, 1024] rel_bias = self.relative_attention_bias_layer( x.shape[1], x.device) # [1, H, T', T'] for layer in self.encoders: x = layer(x, rel_bias) return x # [B, T//8, 1024] class ConformerEncoder(nn.Module): """Full Conformer audio encoder matching Phi-4-multimodal pretrained weights. Processes mel spectrograms through 24 Conformer layers with T5 relative attention bias, then projects to text decoder dimension via MLP. Input: [B, T, 80] (log-mel spectrogram) Output: [B, T//8, 3072] (audio tokens for masked fusion) Weight inventory: 887 tensors (~454M params) - encoder.*: 879 (embed + encoders + rel_bias + norm buffers) - audio_projection.speech.*: 4 (MLP projection to text dim) - audio_projection.vision.*: 4 (not loaded, vision audio branch) """ def __init__(self, config: 'FireEchoConfig'): super().__init__() d = config.audio_hidden_size text_dim = config.dim self.encoder = ConformerEncoderCore(config) self.audio_projection = nn.ModuleDict({ 'speech': nn.Sequential( nn.Linear(d, text_dim), # 0: [3072, 1024] nn.GELU(), # 1: no weights nn.Linear(text_dim, text_dim), # 2: [3072, 3072] ), }) def forward(self, mel: torch.Tensor) -> torch.Tensor: x = self.encoder(mel) # [B, T//8, 1024] return self.audio_projection['speech'](x) # [B, T//8, 3072] class MultimodalFusion(nn.Module): """Phi-4 style multimodal fusion via masked token replacement. Instead of concatenating vision/audio tokens, Phi-4 replaces placeholder positions (marked by special token IDs) in the text embedding sequence. This preserves sequence length, RoPE positions, and attention patterns. """ def __init__(self, config: 'FireEchoConfig', use_vision: bool = False, use_audio: bool = False): super().__init__() self.image_token_id = config.image_token_id # 200010 self.audio_token_id = config.audio_token_id # 200011 if use_vision: self.vision_encoder = SigLIPVisionEncoder(config) if use_audio: self.audio_encoder = ConformerEncoder(config) def encode_and_fuse(self, input_ids: torch.Tensor, text_embeds: torch.Tensor, images: Optional[torch.Tensor] = None, audio: Optional[torch.Tensor] = None) -> torch.Tensor: """Replace placeholder token positions with encoded modality embeddings. Args: input_ids: [B, seq_len] — token IDs (contains special markers) text_embeds: [B, seq_len, hidden_size] — text embeddings from embed_tokens images: [B, 3, 448, 448] or None audio: [B, samples] or None Returns: Modified text_embeds with vision/audio tokens in-place (same shape) """ if images is not None and hasattr(self, 'vision_encoder'): # Guard: skip encoding if no image placeholder tokens in input has_img_tokens = (input_ids == self.image_token_id).any() if has_img_tokens: image_embeds = self.vision_encoder(images) # [B, 545, hidden_size] # Find image placeholder positions and replace for b in range(input_ids.shape[0]): img_mask = (input_ids[b] == self.image_token_id) img_positions = img_mask.nonzero(as_tuple=True)[0] n_replace = min(len(img_positions), image_embeds.shape[1]) if n_replace > 0: text_embeds[b, img_positions[:n_replace]] = \ image_embeds[b, :n_replace].to(text_embeds.dtype) if audio is not None and hasattr(self, 'audio_encoder'): has_aud_tokens = (input_ids == self.audio_token_id).any() if has_aud_tokens: audio_embeds = self.audio_encoder(audio) # [B, frames, hidden_size] for b in range(input_ids.shape[0]): aud_mask = (input_ids[b] == self.audio_token_id) aud_positions = aud_mask.nonzero(as_tuple=True)[0] n_replace = min(len(aud_positions), audio_embeds.shape[1]) if n_replace > 0: text_embeds[b, aud_positions[:n_replace]] = \ audio_embeds[b, :n_replace].to(text_embeds.dtype) return text_embeds # ============================================================================ # ROTARY POSITION EMBEDDINGS (RoPE) # ============================================================================ def _precompute_rope_freqs(dim: int, max_seq_len: int, theta: float = 10000.0, device='cpu', dtype=torch.float32): """Precompute cos/sin tables for RoPE.""" freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim)) t = torch.arange(max_seq_len, device=device, dtype=dtype) freqs = torch.outer(t, freqs) return torch.cos(freqs), torch.sin(freqs) # each [max_seq_len, dim//2] def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position: int = 0): """Apply RoPE with partial rotation support. x: [B, nH, S, head_dim], cos/sin: [max_len, rotary_dim//2]. When rotary_dim < head_dim (partial RoPE, e.g. Phi-4 with 75%), only the first rotary_dim dimensions are rotated; the rest pass through. """ seq_len = x.shape[2] end = position + seq_len if end > cos.shape[0]: raise ValueError( f"RoPE position {position}+{seq_len}={end} exceeds buffer size {cos.shape[0]}. " f"Increase max_seq_len in config." ) cos_s = cos[position:end].unsqueeze(0).unsqueeze(0) # [1,1,S,rotary_dim//2] sin_s = sin[position:end].unsqueeze(0).unsqueeze(0) rotary_dim = cos_s.shape[-1] * 2 if rotary_dim < x.shape[-1]: # Partial RoPE: rotate first rotary_dim dims, pass through rest x_rot = x[..., :rotary_dim] x_pass = x[..., rotary_dim:] half = rotary_dim // 2 x1, x2 = x_rot[..., :half], x_rot[..., half:] x_rot = torch.cat([x1 * cos_s - x2 * sin_s, x2 * cos_s + x1 * sin_s], dim=-1) return torch.cat([x_rot, x_pass], dim=-1) else: half = x.shape[-1] // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat([x1 * cos_s - x2 * sin_s, x2 * cos_s + x1 * sin_s], dim=-1) def _apply_rotary_emb_ids(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor): """Apply RoPE with per-token position IDs (FE-XT tree verification). Unlike _apply_rotary_emb which takes a scalar position start, this takes a tensor of per-token positions. Used for tree speculation where branches share the same logical positions but are processed in one batch. x: [B, nH, M, head_dim], cos/sin: [max_len, rotary_dim//2]. position_ids: [M] int64 tensor of per-token positions. """ cos_s = cos[position_ids].unsqueeze(0).unsqueeze(0) # [1, 1, M, rotary_dim//2] sin_s = sin[position_ids].unsqueeze(0).unsqueeze(0) rotary_dim = cos_s.shape[-1] * 2 if rotary_dim < x.shape[-1]: x_rot = x[..., :rotary_dim] x_pass = x[..., rotary_dim:] half = rotary_dim // 2 x1, x2 = x_rot[..., :half], x_rot[..., half:] x_rot = torch.cat([x1 * cos_s - x2 * sin_s, x2 * cos_s + x1 * sin_s], dim=-1) return torch.cat([x_rot, x_pass], dim=-1) else: half = x.shape[-1] // 2 x1, x2 = x[..., :half], x[..., half:] return torch.cat([x1 * cos_s - x2 * sin_s, x2 * cos_s + x1 * sin_s], dim=-1) def build_tree_attention_mask(b: int, depth: int, prefix_len: int, device='cuda') -> torch.Tensor: """Build block-diagonal causal mask for FE-XT tree speculation. Token layout: b branches × depth steps, flattened contiguously. Token[branch * depth + step] = branch `branch` at step `step`. Step 0 is the shared root (same token duplicated across branches). The mask ensures: - All draft tokens can attend to ALL prefix positions (existing KV cache). - Each token can attend to its own branch's tokens causally (step <= own step). - NO cross-branch attention (branch i cannot see branch j's draft tokens). Args: b: Number of tree branches depth: Draft sequence depth per branch prefix_len: Number of tokens already in KV cache device: Torch device Returns: mask: [1, 1, M, prefix_len + M] float tensor. 0.0 where attention is allowed, -inf where masked. M = b * depth. """ M = b * depth total_len = prefix_len + M # Start with all masked (-inf) mask = torch.full((M, total_len), float('-inf'), device=device, dtype=torch.float32) # All draft tokens attend to prefix mask[:, :prefix_len] = 0.0 # Each branch: causal within own branch only for branch in range(b): start = branch * depth for step in range(depth): token_idx = start + step # Attend to own branch tokens [start : start + step + 1] mask[token_idx, prefix_len + start: prefix_len + start + step + 1] = 0.0 return mask.unsqueeze(0).unsqueeze(0) # [1, 1, M, prefix_len + M] def build_tree_position_ids(b: int, depth: int, start_pos: int, device='cuda') -> torch.Tensor: """Build per-token RoPE position IDs for FE-XT tree speculation. All branches share the same positions [start_pos, ..., start_pos + depth - 1]. Repeated b times for the b branches. Args: b: Number of tree branches depth: Draft sequence depth per branch start_pos: Position of the first draft token (= current_pos in KV cache) device: Torch device Returns: position_ids: [M] int64 tensor, M = b * depth """ branch_positions = torch.arange(depth, device=device, dtype=torch.long) + start_pos return branch_positions.repeat(b) # [b * depth] # ============================================================================ # FUSED ATTENTION WITH REAL KV CACHE # ============================================================================ class FusedAttention(nn.Module): """Fused multi-head attention with real paged KV cache support.""" def __init__(self, dim: int, num_heads: int, num_kv_heads: Optional[int] = None, head_dim: Optional[int] = None, layer_idx: int = 0, use_fused_qkv: bool = True, use_native_cutlass: bool = True, use_quantum_matmul: bool = True, use_dsmem_cluster: bool = True, rope_theta: float = 10000.0, max_seq_len: int = 4096, attn_bias: bool = False, use_qk_norm: bool = False, qk_norm_per_head: bool = False, partial_rotary_factor: float = 1.0, use_fused_norm_qkv: bool = True, use_fused_residual_norm: bool = True, use_fused_rope: bool = True, use_gqa_native: bool = True, use_goliath: bool = False, goliath_bits: int = 4): super().__init__() self.dim = dim self.num_heads = num_heads self.num_kv_heads = num_kv_heads or num_heads # GQA support self.head_dim = head_dim or dim // num_heads self.layer_idx = layer_idx self.scale = self.head_dim ** -0.5 self.use_fused_qkv = use_fused_qkv self.use_native_cutlass = use_native_cutlass self.use_quantum_matmul = use_quantum_matmul self.use_dsmem_cluster = use_dsmem_cluster self.use_fused_norm_qkv = use_fused_norm_qkv self.use_fused_residual_norm = use_fused_residual_norm self.use_fused_rope = use_fused_rope self.use_gqa_native = use_gqa_native # Projections — keep as nn.Linear (BF16 cuBLAS is faster than Goliath FP4 # for M=1 attention GEMV due to dequantization overhead) self.quantize_attn = False self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=attn_bias) self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=attn_bias) self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=attn_bias) self.out_proj = nn.Linear(num_heads * self.head_dim, dim, bias=False) # QK Normalization — two styles: # flat (Molmo2/OLMo): RMSNorm(num_heads * head_dim) on concatenated heads # per_head (Qwen3): RMSNorm(head_dim) applied to each head independently self.use_qk_norm = use_qk_norm self.qk_norm_per_head = qk_norm_per_head if use_qk_norm: if qk_norm_per_head: self.q_norm = nn.RMSNorm(self.head_dim) self.k_norm = nn.RMSNorm(self.head_dim) else: self.q_norm = nn.RMSNorm(num_heads * self.head_dim) self.k_norm = nn.RMSNorm(self.num_kv_heads * self.head_dim) # RoPE buffers — partial RoPE: only rotate first N% of head_dim (Phi-4: 75%) self.rotary_ndims = int(self.head_dim * partial_rotary_factor) self.rope_theta = rope_theta cos, sin = _precompute_rope_freqs(self.rotary_ndims, max_seq_len, theta=rope_theta) self.register_buffer('rope_cos', cos.to(torch.bfloat16)) self.register_buffer('rope_sin', sin.to(torch.bfloat16)) def _ensure_rope_length(self, needed_len: int): """Dynamically resize RoPE cos/sin buffers if position exceeds current size.""" current_max = self.rope_cos.shape[0] if needed_len <= current_max: return new_max = int(needed_len * 1.5) + 256 print(f"[RoPE] Layer {self.layer_idx}: resizing buffers {current_max} -> {new_max}") cos, sin = _precompute_rope_freqs( self.rotary_ndims, new_max, theta=self.rope_theta, device=self.rope_cos.device, dtype=torch.float32, ) # Replace registered buffers in-place (store as bfloat16 to match model dtype) self.rope_cos = cos.to(torch.bfloat16) self.rope_sin = sin.to(torch.bfloat16) def forward(self, x: torch.Tensor, kv_cache: Optional[PagedKVCache] = None, seq_id: int = 0, position: int = 0, use_cache: bool = False, norm_weight: Optional[torch.Tensor] = None, norm_eps: float = 1e-6) -> torch.Tensor: B, S, D = x.shape # Tier 0: Fused RMSNorm + QKV (eliminates intermediate normed tensor) # Only when norm_weight is provided (pre-norm path passes it from FusedTransformerBlock) used_fused_norm_qkv = False if (norm_weight is not None and self.use_fused_norm_qkv and not self.quantize_attn # Fused path needs .weight.T (not quantized) and not self.training and S * B >= 64 and x.is_cuda and self.q_proj.bias is None): # No bias support in fused kernel x_flat = x.view(B * S, D).to(torch.bfloat16) q_flat, k_flat, v_flat = fused_rmsnorm_qkv_projection( x_flat, norm_weight.to(torch.bfloat16), self.q_proj.weight.T.contiguous().to(torch.bfloat16), self.k_proj.weight.T.contiguous().to(torch.bfloat16), self.v_proj.weight.T.contiguous().to(torch.bfloat16), eps=norm_eps, ) q = q_flat.view(B, S, self.num_heads, self.head_dim).transpose(1, 2) k = k_flat.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v_flat.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) used_fused_norm_qkv = True else: # Apply norm externally if norm_weight was provided but fused path not taken if norm_weight is not None: variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + norm_eps).to(x.dtype) * norm_weight # Tier 1: CUTLASS / DSMEM / Quantum fused QKV use_fused_qkv_backend = ( (self.use_native_cutlass and _CUTLASS_AVAILABLE) or (self.use_dsmem_cluster and _DSMEM_AVAILABLE) or (self.use_quantum_matmul and _QUANTUM_AVAILABLE) ) and not self.quantize_attn and not self.training and S * B >= 64 and x.is_cuda if use_fused_qkv_backend: x_flat = x.view(B * S, D).to(torch.bfloat16) wq = self.q_proj.weight.T.contiguous().to(torch.bfloat16) wk = self.k_proj.weight.T.contiguous().to(torch.bfloat16) wv = self.v_proj.weight.T.contiguous().to(torch.bfloat16) q_flat = _fused_matmul(x_flat, wq, self.use_native_cutlass, self.use_dsmem_cluster, self.use_quantum_matmul) k_flat = _fused_matmul(x_flat, wk, self.use_native_cutlass, self.use_dsmem_cluster, self.use_quantum_matmul) v_flat = _fused_matmul(x_flat, wv, self.use_native_cutlass, self.use_dsmem_cluster, self.use_quantum_matmul) if self.q_proj.bias is not None: q_flat = q_flat + self.q_proj.bias.to(q_flat.dtype) k_flat = k_flat + self.k_proj.bias.to(k_flat.dtype) v_flat = v_flat + self.v_proj.bias.to(v_flat.dtype) q = q_flat.view(B, S, self.num_heads, self.head_dim).transpose(1, 2) k = k_flat.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v_flat.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) elif self.use_fused_qkv and not self.quantize_attn and not self.training and S * B >= 64: # Tier 2: Fused Triton QKV projection (3x fewer memory loads) x_flat = x.view(B * S, D) q_flat, k_flat, v_flat = fused_qkv_projection( x_flat.to(torch.bfloat16), self.q_proj.weight.T.contiguous(), self.k_proj.weight.T.contiguous(), self.v_proj.weight.T.contiguous() ) if self.q_proj.bias is not None: q_flat = q_flat + self.q_proj.bias.to(q_flat.dtype) k_flat = k_flat + self.k_proj.bias.to(k_flat.dtype) v_flat = v_flat + self.v_proj.bias.to(v_flat.dtype) q = q_flat.view(B, S, self.num_heads, self.head_dim).transpose(1, 2) k = k_flat.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) v = v_flat.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) else: # Tier 3: Standard separate projections q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) # Apply QK norm before RoPE if self.use_qk_norm: if self.qk_norm_per_head: # Per-head norm (Qwen3): RMSNorm(head_dim) on last dim, q is [B, H, S, D] q = self.q_norm(q) k = self.k_norm(k) else: # Flat norm (Molmo2/OLMo): RMSNorm(H*D) on concatenated heads B_q, nH, S_q, hd = q.shape q = self.q_norm(q.transpose(1, 2).reshape(B_q, S_q, nH * hd)).reshape(B_q, S_q, nH, hd).transpose(1, 2) B_k, nKH, S_k, hd_k = k.shape k = self.k_norm(k.transpose(1, 2).reshape(B_k, S_k, nKH * hd_k)).reshape(B_k, S_k, nKH, hd_k).transpose(1, 2) # CUDA graph-safe decode detection _graph_decode = (use_cache and kv_cache is not None and getattr(kv_cache, '_graph_mode', False) and S == 1) if _graph_decode: # FireEcho Graph RoPE — pre-loaded cos/sin from side buffers # No position indexing, no _ensure_rope_length, graph-safe _half = kv_cache._graph_rope_cos.shape[-1] _cos = kv_cache._graph_rope_cos.view(1, 1, 1, _half) _sin = kv_cache._graph_rope_sin.view(1, 1, 1, _half) q_r1 = q[..., :_half] * _cos - q[..., _half:2*_half] * _sin q_r2 = q[..., _half:2*_half] * _cos + q[..., :_half] * _sin k_r1 = k[..., :_half] * _cos - k[..., _half:2*_half] * _sin k_r2 = k[..., _half:2*_half] * _cos + k[..., :_half] * _sin if self.head_dim > 2 * _half: q = torch.cat([q_r1, q_r2, q[..., 2*_half:]], dim=-1) k = torch.cat([k_r1, k_r2, k[..., 2*_half:]], dim=-1) else: q = torch.cat([q_r1, q_r2], dim=-1) k = torch.cat([k_r1, k_r2], dim=-1) else: # Dynamic RoPE buffer resize for vision (545+ tokens) or long sequences self._ensure_rope_length(position + q.shape[2]) # Apply RoPE to Q and K — fused (single launch) or separate if self.use_fused_rope and not self.training and q.is_cuda: q, k = fused_rope_qk(q, k, self.rope_cos, self.rope_sin, position) else: q = _apply_rotary_emb(q, self.rope_cos, self.rope_sin, position) k = _apply_rotary_emb(k, self.rope_cos, self.rope_sin, position) # RoPE cos/sin may be float32 — cast q, k back to match v's dtype if q.dtype != v.dtype: q = q.to(v.dtype) k = k.to(v.dtype) # KV Cache handling if use_cache and kv_cache is not None: if _graph_decode: # === FireEcho CUDA Graph path: scatter + full view + mask === kv_cache.store_flat_scatter(self.layer_idx, k[0], v[0]) k = kv_cache.flat_k[self.layer_idx].unsqueeze(0) v = kv_cache.flat_v[self.layer_idx].unsqueeze(0) elif getattr(kv_cache, '_flat_mode', False): # === Flat decode cache: zero-copy, no torch.cat === # Direct write + view slice — eliminates tensor allocation per layer kv_cache.store_flat(self.layer_idx, position, k[0], v[0]) end_pos = position + S k_view, v_view = kv_cache.get_flat_view(self.layer_idx, end_pos) k = k_view.unsqueeze(0) # [1, kv_heads, end_pos, head_dim] v = v_view.unsqueeze(0) else: # === Original paged cache path === for b in range(B): k_to_store = k[b] v_to_store = v[b] kv_cache.store(seq_id + b, self.layer_idx, position, k_to_store, v_to_store) if position > 0: k_hist, v_hist = kv_cache.get(seq_id, self.layer_idx, num_tokens=position) if k_hist.shape[1] > 0: k_hist_batched = k_hist.unsqueeze(0).expand(B, -1, -1, -1) v_hist_batched = v_hist.unsqueeze(0).expand(B, -1, -1, -1) k = torch.cat([k_hist_batched, k], dim=2) v = torch.cat([v_hist_batched, v], dim=2) # GQA: native SDPA support (no repeat_interleave allocation) or legacy expand use_gqa = self.num_kv_heads < self.num_heads if _graph_decode: # === FireEcho FlashDecode — Triton M=1 GQA with online softmax === # Single kernel launch per layer, reads only valid KV positions. # CUDA graph safe: valid_len read from GPU tensor inside kernel. if getattr(kv_cache, '_flat_kv_dtype', 'bf16') == 'fp8': # FP8 path: inline E4M3 dequant in Triton kernel out = fireecho_flash_decode_fp8( q, k, v, kv_cache.flat_k_scales[self.layer_idx].unsqueeze(0), kv_cache.flat_v_scales[self.layer_idx].unsqueeze(0), kv_cache._graph_valid_len, self.scale) else: # BF16 path: original FlashDecode out = fireecho_flash_decode( q, k, v, kv_cache._graph_valid_len, self.scale) else: # Attention — CUTLASS TMA or SDPA # CUTLASS TMA may not handle all KV lengths reliably; cap at 512 tokens kv_len = k.shape[2] use_cutlass_attn = ( self.use_native_cutlass and _CUTLASS_AVAILABLE and _cutlass_tma_attention is not None and not self.training and q.is_cuda and kv_len <= 512 ) is_causal = S > 1 and position == 0 if use_cutlass_attn: # CUTLASS doesn't support native GQA — must expand if use_gqa: repeat_factor = self.num_heads // self.num_kv_heads k_attn = k.repeat_interleave(repeat_factor, dim=1) v_attn = v.repeat_interleave(repeat_factor, dim=1) else: k_attn, v_attn = k, v try: out = _cutlass_tma_attention( q, k_attn, v_attn, scale=self.scale, is_causal=is_causal, ) except Exception: out = F.scaled_dot_product_attention(q, k_attn, v_attn, is_causal=is_causal) elif use_gqa and self.use_gqa_native: try: out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, enable_gqa=True) except TypeError: repeat_factor = self.num_heads // self.num_kv_heads k_expanded = k.repeat_interleave(repeat_factor, dim=1) v_expanded = v.repeat_interleave(repeat_factor, dim=1) out = F.scaled_dot_product_attention(q, k_expanded, v_expanded, is_causal=is_causal) else: if use_gqa: repeat_factor = self.num_heads // self.num_kv_heads k = k.repeat_interleave(repeat_factor, dim=1) v = v.repeat_interleave(repeat_factor, dim=1) out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal) out = out.transpose(1, 2).reshape(B, S, -1) # Output projection - fused matmul (CUTLASS / DSMEM / Quantum) or linear use_fused_out = ( ((self.use_native_cutlass and _CUTLASS_AVAILABLE) or (self.use_dsmem_cluster and _DSMEM_AVAILABLE) or (self.use_quantum_matmul and _QUANTUM_AVAILABLE)) and not self.quantize_attn # Can't access .weight.T when Goliath-quantized and not self.training and out.is_cuda ) if use_fused_out: out_flat = out.reshape(-1, self.num_heads * self.head_dim).to(torch.bfloat16) wo = self.out_proj.weight.T.contiguous().to(torch.bfloat16) out = _fused_matmul(out_flat, wo, self.use_native_cutlass, self.use_dsmem_cluster, self.use_quantum_matmul).view(B, S, self.dim) else: out = self.out_proj(out) return out # ============================================================================ # FUSED FFN WITH NVFP4 # ============================================================================ class QuantizedLinear(nn.Module): """Linear layer with Goliath fused FP4/FP8 quantization. When Goliath is available, weights are quantized to GoliathFP4/FP8 on first inference forward. The fused kernel dequantizes in registers during matmul — no separate dequantize step, no global-memory materialisation of BF16 weights. Falls back to the legacy int8 dual-scaling path when Goliath is unavailable. """ def __init__(self, in_features: int, out_features: int, bias: bool = False, use_nvfp4: bool = True, block_size: int = 32, goliath_bits: Union[int, str] = 4, use_goliath: bool = True, w4a4_mode: bool = False, decode_skip_act_quant: bool = True, decode_act_quant_threshold: int = 64, compute_residual: bool = False): super().__init__() self.in_features = in_features self.out_features = out_features self.use_nvfp4 = use_nvfp4 self.block_size = block_size self.goliath_bits = goliath_bits self.use_goliath = use_goliath and _GOLIATH_AVAILABLE self.compute_residual = compute_residual self.w4a4_mode = w4a4_mode self.decode_skip_act_quant = decode_skip_act_quant self.decode_act_quant_threshold = decode_act_quant_threshold # Full precision weights (will be quantized during forward) self.weight = nn.Parameter(torch.empty(out_features, in_features)) if bias: self.bias = nn.Parameter(torch.zeros(out_features)) else: self.register_parameter('bias', None) # Goliath quantization cache (populated on first forward) self._goliath_weights = None # GoliathFP4Weights or GoliathFP8Weights # Legacy quantization cache (fallback when Goliath unavailable) self.q_weight: Optional[torch.Tensor] = None self.block_scale: Optional[torch.Tensor] = None self.global_scale: Optional[torch.Tensor] = None self._quantized = False nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) def _quantize_weights(self): """Quantize weights (Goliath fused or legacy int8).""" if self._quantized: return if not self.use_nvfp4: return if self.use_goliath: # Goliath: weight is [out, in], kernel expects [K=in, N=out] w_kn = self.weight.data.T.contiguous().float() # compute_residual only applies to FP4 (goliath_bits=4) use_res = self.compute_residual and self.goliath_bits in (4, 'auto') self._goliath_weights = _goliath_quantize( w_kn, bits=self.goliath_bits, training=self.training, sr_seed=self.weight._version if self.training else None, compute_residual=use_res, ) else: # Legacy int8 dual-scaling fallback self.q_weight, self.block_scale, self.global_scale = quantize_nvfp4( self.weight.data, self.block_size ) self._quantized = True def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_nvfp4 and not self.training: self._quantize_weights() if self.use_goliath and self._goliath_weights is not None: # Fused dequant-matmul: x [*, in] → out [*, out] orig_shape = x.shape[:-1] x_flat = x.reshape(-1, self.in_features) if self.w4a4_mode and not (self.decode_skip_act_quant and x_flat.shape[0] <= self.decode_act_quant_threshold): x_flat = _apply_act_quant(x_flat) out = _goliath_gemm(x_flat, self._goliath_weights, self.bias) return out.view(*orig_shape, self.out_features) # Legacy: dequantize → F.linear w = dequantize_nvfp4(self.q_weight, self.block_scale, self.global_scale, self.weight.shape) w = w.to(x.dtype) return F.linear(x, w, self.bias) # Full precision during training return F.linear(x, self.weight, self.bias) def invalidate_cache(self): """Invalidate quantization cache (call after weight update).""" self._quantized = False self._goliath_weights = None self.q_weight = None class FusedFFN(nn.Module): """Fused feed-forward network with Goliath FP4/FP8 fused dequant-matmul. Dispatch priority (inference): 1. Goliath fused kernel — dequant happens in Triton registers, zero extra traffic 2. CUTLASS TMA / DSMEM / Quantum BF16 matmul (after legacy dequant to BF16) 3. Fused SwiGLU Triton kernel (after legacy dequant) 4. nn.Module fallback (QuantizedLinear.forward / nn.Linear.forward) """ def __init__(self, dim: int, ffn_dim: int, use_nvfp4: bool = True, is_critical: bool = False, use_fused_swiglu: bool = True, use_native_cutlass: bool = True, use_quantum_matmul: bool = True, use_dsmem_cluster: bool = True, goliath_bits: Union[int, str] = 4, use_goliath: bool = True, w4a4_mode: bool = False, use_goliath_linear: bool = False, decode_skip_act_quant: bool = True, decode_act_quant_threshold: int = 64, compute_residual: bool = False): super().__init__() self.dim = dim self.ffn_dim = ffn_dim self.use_nvfp4 = use_nvfp4 and not is_critical # Critical layers stay FP16 self.use_fused_swiglu = use_fused_swiglu self.use_native_cutlass = use_native_cutlass self.use_quantum_matmul = use_quantum_matmul self.use_dsmem_cluster = use_dsmem_cluster self.use_goliath = use_goliath and _GOLIATH_AVAILABLE self.w4a4_mode = w4a4_mode self.decode_skip_act_quant = decode_skip_act_quant self.decode_act_quant_threshold = decode_act_quant_threshold self.use_goliath_linear = use_goliath_linear and _GoliathLinear is not None self.compute_residual = compute_residual if self.use_goliath_linear: # GoliathLinear: FP4/FP8 quantized forward + FP32 backward for training self.gate_proj = _GoliathLinear(dim, ffn_dim, bias=False, bits=goliath_bits) self.up_proj = _GoliathLinear(dim, ffn_dim, bias=False, bits=goliath_bits) self.down_proj = _GoliathLinear(ffn_dim, dim, bias=False, bits=goliath_bits) elif self.use_nvfp4: self.gate_proj = QuantizedLinear(dim, ffn_dim, bias=False, goliath_bits=goliath_bits, use_goliath=use_goliath, w4a4_mode=w4a4_mode, decode_skip_act_quant=decode_skip_act_quant, decode_act_quant_threshold=decode_act_quant_threshold, compute_residual=compute_residual) self.up_proj = QuantizedLinear(dim, ffn_dim, bias=False, goliath_bits=goliath_bits, use_goliath=use_goliath, w4a4_mode=w4a4_mode, decode_skip_act_quant=decode_skip_act_quant, decode_act_quant_threshold=decode_act_quant_threshold, compute_residual=compute_residual) self.down_proj = QuantizedLinear(ffn_dim, dim, bias=False, goliath_bits=goliath_bits, use_goliath=use_goliath, w4a4_mode=w4a4_mode, decode_skip_act_quant=decode_skip_act_quant, decode_act_quant_threshold=decode_act_quant_threshold, compute_residual=compute_residual) else: self.gate_proj = nn.Linear(dim, ffn_dim, bias=False) self.up_proj = nn.Linear(dim, ffn_dim, bias=False) self.down_proj = nn.Linear(ffn_dim, dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: B_orig, S_orig = x.shape[0], x.shape[1] # ------------------------------------------------------------------ # Tier 1: Goliath fused dequant-matmul (best path) # Each goliath_gemm call dequantises FP4/FP8 → BF16 in Triton registers # and fuses the matmul — no global-memory BF16 weight materialisation. # Skip when using GoliathLinear — it handles its own dispatch. # ------------------------------------------------------------------ if (self.use_goliath and self.use_nvfp4 and not self.use_goliath_linear and not self.training and x.is_cuda and x.numel() >= 4096): x_flat = x.view(-1, self.dim) # Ensure weights are Goliath-quantized (one-time) self.gate_proj._quantize_weights() self.up_proj._quantize_weights() self.down_proj._quantize_weights() # W4A4: quantize activations before each GEMM (skip during decode for speed) M = x_flat.shape[0] apply_aq = self.w4a4_mode and not (self.decode_skip_act_quant and M <= self.decode_act_quant_threshold) if apply_aq: x_flat = _apply_act_quant(x_flat) gate = _goliath_gemm(x_flat, self.gate_proj._goliath_weights) up = _goliath_gemm(x_flat, self.up_proj._goliath_weights) hidden = F.silu(gate) * up hidden_flat = hidden.view(-1, self.ffn_dim) if apply_aq: hidden_flat = _apply_act_quant(hidden_flat) out = _goliath_gemm(hidden_flat, self.down_proj._goliath_weights) return out.view(B_orig, S_orig, self.dim) # ------------------------------------------------------------------ # Tier 2: CUTLASS TMA / DSMEM / Quantum BF16 matmul (legacy dequant) # Skip when using GoliathLinear — it handles its own dispatch in Tier 4. # ------------------------------------------------------------------ use_fused_ffn = ( not self.use_goliath_linear and ((self.use_native_cutlass and _CUTLASS_AVAILABLE) or (self.use_dsmem_cluster and _DSMEM_AVAILABLE) or (self.use_quantum_matmul and _QUANTUM_AVAILABLE)) and not self.training and x.is_cuda and x.numel() >= 4096 ) if use_fused_ffn: x_flat = x.view(-1, self.dim).to(torch.bfloat16) if self.use_nvfp4: self.gate_proj._quantize_weights() self.up_proj._quantize_weights() gate_w = dequantize_nvfp4( self.gate_proj.q_weight, self.gate_proj.block_scale, self.gate_proj.global_scale, self.gate_proj.weight.shape ).T.contiguous().to(torch.bfloat16) up_w = dequantize_nvfp4( self.up_proj.q_weight, self.up_proj.block_scale, self.up_proj.global_scale, self.up_proj.weight.shape ).T.contiguous().to(torch.bfloat16) else: gate_w = self.gate_proj.weight.T.contiguous().to(torch.bfloat16) up_w = self.up_proj.weight.T.contiguous().to(torch.bfloat16) gate = _fused_matmul(x_flat, gate_w, self.use_native_cutlass, self.use_dsmem_cluster, self.use_quantum_matmul) up = _fused_matmul(x_flat, up_w, self.use_native_cutlass, self.use_dsmem_cluster, self.use_quantum_matmul) hidden = F.silu(gate) * up hidden = hidden.view(B_orig, S_orig, self.ffn_dim) hidden_flat = hidden.reshape(-1, self.ffn_dim).to(torch.bfloat16) w_down = self.down_proj.weight.T.contiguous().to(torch.bfloat16) out = _fused_matmul(hidden_flat, w_down, self.use_native_cutlass, self.use_dsmem_cluster, self.use_quantum_matmul).view(B_orig, S_orig, self.dim) return out # ------------------------------------------------------------------ # Tier 3: Fused SwiGLU Triton (legacy dequant) # Skip when using GoliathLinear — falls through to Tier 4. # ------------------------------------------------------------------ if (self.use_fused_swiglu and not self.use_goliath_linear and not self.training and x.numel() >= 4096): if self.use_nvfp4: self.gate_proj._quantize_weights() self.up_proj._quantize_weights() gate_w = dequantize_nvfp4( self.gate_proj.q_weight, self.gate_proj.block_scale, self.gate_proj.global_scale, self.gate_proj.weight.shape ).T.contiguous() up_w = dequantize_nvfp4( self.up_proj.q_weight, self.up_proj.block_scale, self.up_proj.global_scale, self.up_proj.weight.shape ).T.contiguous() else: gate_w = self.gate_proj.weight.T.contiguous() up_w = self.up_proj.weight.T.contiguous() x_flat = x.view(-1, self.dim).to(torch.bfloat16) hidden = fused_swiglu(x_flat, gate_w, up_w) hidden = hidden.view(B_orig, S_orig, self.ffn_dim) else: # Tier 4: nn.Module fallback gate = F.silu(self.gate_proj(x)) up = self.up_proj(x) hidden = gate * up return self.down_proj(hidden) # ============================================================================ # MIXTURE-OF-EXPERTS FFN # ============================================================================ class MoERouter(nn.Module): """Top-k expert router with normalized probabilities.""" def __init__(self, dim: int, num_experts: int, num_experts_per_tok: int, norm_topk_prob: bool = True): super().__init__() self.gate = nn.Linear(dim, num_experts, bias=False) self.num_experts_per_tok = num_experts_per_tok self.norm_topk_prob = norm_topk_prob def forward(self, x: torch.Tensor): # x: [B*S, D] -> scores: [B*S, num_experts] # Compute gate scores in float32 for softmax stability scores = F.linear(x.float(), self.gate.weight.float(), self.gate.bias) topk_weights, topk_ids = torch.topk( scores, self.num_experts_per_tok, dim=-1) if self.norm_topk_prob: topk_weights = F.softmax(topk_weights.float(), dim=-1).to(x.dtype) else: topk_weights = topk_weights.softmax(dim=-1).to(x.dtype) return topk_weights, topk_ids class ExpertFFN(nn.Module): """Single MoE expert: SwiGLU FFN with fused gate+up projection. Fuses gate_proj and up_proj into a single matmul (gate_up_proj), cutting kernel launches from 3 to 2 per expert forward. With 8 active experts × 48 layers, this eliminates 384 kernel launches per token. """ def __init__(self, dim: int, intermediate_size: int, use_nvfp4: bool = True, goliath_bits: int = 4, use_goliath: bool = True, is_critical: bool = False, compute_residual: bool = False): super().__init__() self.intermediate_size = intermediate_size # Fused gate+up: single [dim → 2*intermediate] matmul self.gate_up_proj = QuantizedLinear( dim, 2 * intermediate_size, use_nvfp4=use_nvfp4, goliath_bits=goliath_bits, use_goliath=use_goliath, compute_residual=compute_residual) self.down_proj = QuantizedLinear( intermediate_size, dim, use_nvfp4=use_nvfp4, goliath_bits=goliath_bits, use_goliath=use_goliath, compute_residual=compute_residual) def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up = self.gate_up_proj(x) # [*, 2*intermediate] gate, up = gate_up.split(self.intermediate_size, dim=-1) return self.down_proj(F.silu(gate) * up) class MoEFFN(nn.Module): """Mixture-of-Experts FFN layer with top-k routing. Scatter-gather dispatch: each token is routed to top-k experts, weighted by normalized router probabilities. Simple loop implementation for correctness; can be optimized with grouped GEMM later. """ def __init__(self, dim: int, config: FireEchoConfig): super().__init__() self.num_experts = config.num_experts self.num_experts_per_tok = config.num_experts_per_tok self.router = MoERouter( dim, config.num_experts, config.num_experts_per_tok, norm_topk_prob=config.norm_topk_prob) is_critical = False # Experts are never critical (too many) self.experts = nn.ModuleList([ ExpertFFN( dim, config.moe_intermediate_size, use_nvfp4=config.use_nvfp4 and config.quantize_weights, goliath_bits=config.goliath_bits, use_goliath=config.use_goliath, is_critical=is_critical, compute_residual=config.use_residual_correction, ) for _ in range(config.num_experts) ]) # Track expert usage for diagnostics + FE-MX tier assignment self.register_buffer('expert_usage', torch.zeros(config.num_experts)) self._forward_count = 0 self._experts_packed = False # True after pack_experts() called # FE-MX age-adaptive expert quantization self.use_femx_experts = config.use_femx_experts self.femx_cold_threshold = config.femx_cold_threshold self.femx_warm_threshold = config.femx_warm_threshold self.femx_tier_interval = config.femx_tier_interval if self.use_femx_experts: # Per-expert tier: 0=FEMX4 (cold), 1=FEMX6 (warm), 2=FEMX8 (hot) self.register_buffer('expert_tier', torch.full((config.num_experts,), 2, dtype=torch.uint8)) def update_expert_tiers(self): """Assign FE-MX precision tiers based on routing frequency. Cold experts (rarely routed) → FEMX4 (4.25 bits): save VRAM, acceptable quality loss Warm experts (moderate routing) → FEMX6 (6.25 bits): balanced Hot experts (frequently routed) → FEMX8 (8.25 bits): full quality preservation This is the novel optimization: experts that see fewer tokens contribute less to overall quality, so they can tolerate lower precision. Hot experts handle the critical paths and need maximum fidelity. If auto INT2 demotion is enabled, cold experts (tier 0) will be automatically converted to INT2 quantization for additional 50% bandwidth savings. """ if not self.use_femx_experts: return cold, warm, hot = 0, 0, 0 for e_idx in range(self.num_experts): usage = self.expert_usage[e_idx].item() if usage < self.femx_cold_threshold: self.expert_tier[e_idx] = 0 # FEMX4 cold += 1 elif usage < self.femx_warm_threshold: self.expert_tier[e_idx] = 1 # FEMX6 warm += 1 else: self.expert_tier[e_idx] = 2 # FEMX8 hot += 1 # Auto-demote cold experts: prefer FE-XC (higher quality), fall back to INT2 self._maybe_demote_to_fexc() self._maybe_demote_to_int2() return cold, warm, hot def get_expert_tier_summary(self) -> str: """Human-readable summary of expert tier distribution.""" if not self.use_femx_experts: return "FE-MX experts: disabled" tiers = self.expert_tier cold = (tiers == 0).sum().item() warm = (tiers == 1).sum().item() hot = (tiers == 2).sum().item() int2_count = self._expert_is_int2.sum().item() if getattr(self, '_int2_enabled', False) else 0 if int2_count > 0: return (f"FE-MX experts: {cold} cold({int2_count} INT2/{cold - int2_count} FP4) / " f"{warm} warm(FP6) / {hot} hot(FP8)") return f"FE-MX experts: {cold} cold(FP4) / {warm} warm(FP6) / {hot} hot(FP8)" # ================================================================ # FE-AGK: Atlas the Gatekeeper # Ban & Pick routing + MoDES adaptive expert skipping # ================================================================ def enable_atlas(self, ban_threshold: float = 0.01, modes_threshold: float = 2.0): """Enable FE-AGK (Atlas the Gatekeeper). Ban & Pick: Bans low-impact experts at runtime, reducing active experts from 8→~5 (1.25x throughput + easier draft routing). MoDES: Adaptive expert skipping for "easy" tokens where max router probability is below (modes_threshold × uniform_baseline). For 128 experts: uniform = 1/128 = 0.0078. modes_threshold=2.0 → skip when max_prob < 0.016 (~20-30% tokens) modes_threshold=3.0 → skip when max_prob < 0.023 (~10-20% tokens) Args: ban_threshold: experts with impact score below this are banned modes_threshold: multiplier on uniform baseline for MoE skip (2.0 = 2× uniform) """ self._atlas_enabled = True self._atlas_ban_threshold = ban_threshold self._atlas_modes_threshold = modes_threshold # Ban mask: True = expert is banned (low-impact) self.register_buffer('_atlas_banned', torch.zeros(self.num_experts, dtype=torch.bool, device='cuda')) # Impact scores: accumulated |expert_output| * router_weight self.register_buffer('_atlas_impact', torch.zeros(self.num_experts, dtype=torch.float32, device='cuda')) self._atlas_profile_count = 0 self._atlas_profiling = False # MoDES skip counter (diagnostics) self._atlas_modes_skips = 0 self._atlas_modes_total = 0 print(f" [FE-AGK] Atlas enabled: ban_thresh={ban_threshold}, " f"modes_thresh={modes_threshold}") def atlas_start_profiling(self): """Start profiling expert impact scores for Ban & Pick.""" if not getattr(self, '_atlas_enabled', False): return self._atlas_profiling = True self._atlas_impact.zero_() self._atlas_profile_count = 0 def atlas_finish_profiling(self, ban_ratio: float = 0.25): """Finish profiling and ban low-impact experts. Args: ban_ratio: fraction of experts to ban (0.25 = ban bottom 25%) """ if not getattr(self, '_atlas_enabled', False): return self._atlas_profiling = False if self._atlas_profile_count == 0: return # Normalize impact scores impact = self._atlas_impact / self._atlas_profile_count # Ban bottom N% by impact n_ban = max(1, int(self.num_experts * ban_ratio)) _, ban_idx = impact.topk(n_ban, largest=False) self._atlas_banned.zero_() self._atlas_banned[ban_idx] = True n_banned = self._atlas_banned.sum().item() effective_k = self.num_experts_per_tok # Will be reduced at runtime print(f" [FE-AGK] Ban & Pick: banned {n_banned}/{self.num_experts} experts " f"(bottom {ban_ratio:.0%} by impact)") print(f" [FE-AGK] Impact range: {impact.min():.4f} — {impact.max():.4f}") def atlas_get_stats(self) -> str: """Get Atlas diagnostics string.""" if not getattr(self, '_atlas_enabled', False): return "FE-AGK: disabled" n_banned = self._atlas_banned.sum().item() if hasattr(self, '_atlas_banned') else 0 skip_rate = (self._atlas_modes_skips / max(self._atlas_modes_total, 1)) return (f"FE-AGK: {n_banned} banned experts, " f"MoDES skip rate {skip_rate:.1%} " f"({self._atlas_modes_skips}/{self._atlas_modes_total})") def _atlas_filter_experts(self, topk_ids: torch.Tensor, topk_weights: torch.Tensor): """Ban & Pick: replace banned experts with next-best alternatives. When a top-k expert is banned, we don't just drop it — we replace it with the next unbanned expert from the router scores. This maintains the same number of active experts but routes around low-impact ones. For decode (T=1), if all alternatives are also banned, we just reduce k. """ if not self._atlas_banned.any(): return topk_ids, topk_weights # Check which selected experts are banned banned_mask = self._atlas_banned[topk_ids.long()] # [T, k] bool if not banned_mask.any(): return topk_ids, topk_weights # Simple approach: zero out banned expert weights, renormalize weights_filtered = topk_weights.clone() weights_filtered[banned_mask] = 0.0 # Renormalize remaining weights w_sum = weights_filtered.sum(dim=-1, keepdim=True).clamp(min=1e-8) weights_filtered = weights_filtered / w_sum return topk_ids, weights_filtered def _atlas_check_modes_skip(self, router_scores: torch.Tensor) -> bool: """MoDES: check if token is "easy" enough to skip MoE entirely. Compares max router probability against a RELATIVE threshold: uniform distribution gives max_prob ≈ 1/num_experts (0.0078 for 128 experts). Skip only when max_prob < uniform × modes_threshold_multiplier, meaning the routing is barely better than random. The modes_threshold parameter acts as a multiplier on the uniform baseline: - 2.0 = skip when max_prob < 2/128 = 0.016 (very uncertain, ~20-30% skip) - 3.0 = skip when max_prob < 3/128 = 0.023 (~10-20% skip) - 1.0 = skip when literally uniform (~5% skip) Args: router_scores: raw router logits [T, num_experts] Returns: True if MoE should be skipped for this token """ self._atlas_modes_total += 1 # Normalize to probability space probs = F.softmax(router_scores.float(), dim=-1) max_prob = probs.max().item() # Relative threshold: uniform baseline × multiplier uniform_baseline = 1.0 / self.num_experts # 0.0078 for 128 experts effective_threshold = uniform_baseline * self._atlas_modes_threshold if max_prob < effective_threshold: self._atlas_modes_skips += 1 return True return False def pack_experts(self): """Pack all experts' FP4 weights into contiguous per-layer buffers. After packing: - MoE decode uses goliath_packed_moe_gemm (GPU-resident expert IDs) - Zero .item() calls, zero Python loops for weight collection - CUDA-graph-capturable decode path Memory: no increase — same FP4 data, just reorganized contiguously. Per-expert GoliathFP4Weights freed after copy to save fragmentation. """ if _goliath_packed_moe_gemm is None: return # Goliath not available n = len(self.experts) # Ensure all experts are quantized for exp in self.experts: exp.gate_up_proj._quantize_weights() exp.down_proj._quantize_weights() # Get shapes from first expert's Goliath weights gw0 = self.experts[0].gate_up_proj._goliath_weights dw0 = self.experts[0].down_proj._goliath_weights if gw0 is None or dw0 is None: return # Not Goliath-quantized # Allocate contiguous packed buffers self.packed_gu_w = torch.empty(n, *gw0.packed.shape, dtype=torch.uint8, device='cuda') self.packed_gu_s = torch.empty(n, *gw0.block_scales.shape, dtype=torch.uint8, device='cuda') self.packed_gu_ts = torch.empty(n, dtype=torch.float32, device='cuda') self.packed_dn_w = torch.empty(n, *dw0.packed.shape, dtype=torch.uint8, device='cuda') self.packed_dn_s = torch.empty(n, *dw0.block_scales.shape, dtype=torch.uint8, device='cuda') self.packed_dn_ts = torch.empty(n, dtype=torch.float32, device='cuda') # Copy expert-by-expert into packed buffers, then re-point per-expert # weights at packed buffer slices (zero extra memory — views, not copies) for i, exp in enumerate(self.experts): gw = exp.gate_up_proj._goliath_weights self.packed_gu_w[i].copy_(gw.packed) self.packed_gu_s[i].copy_(gw.block_scales) self.packed_gu_ts[i] = gw.tensor_scale gu_shape = gw.shape dw = exp.down_proj._goliath_weights self.packed_dn_w[i].copy_(dw.packed) self.packed_dn_s[i].copy_(dw.block_scales) self.packed_dn_ts[i] = dw.tensor_scale dn_shape = dw.shape # Re-point per-expert Goliath weights at packed buffer slices # (views into contiguous buffer — zero extra memory) exp.gate_up_proj._goliath_weights = GoliathFP4Weights( packed=self.packed_gu_w[i], block_scales=self.packed_gu_s[i], tensor_scale=self.packed_gu_ts[i].item(), shape=gu_shape, ) exp.down_proj._goliath_weights = GoliathFP4Weights( packed=self.packed_dn_w[i], block_scales=self.packed_dn_s[i], tensor_scale=self.packed_dn_ts[i].item(), shape=dn_shape, ) self._experts_packed = True def enable_int2_cold_experts(self, cold_threshold_pct: float = 0.1): """Enable INT2 quantization for cold experts (rarely routed). Cold experts (bottom cold_threshold_pct by usage) are demoted to INT2, saving 50% bandwidth vs FP4 with minimal quality impact. Args: cold_threshold_pct: Fraction of experts to mark as cold (default 10%) Requires pack_experts() to have been called first. """ if not self._experts_packed: return n = self.num_experts # Determine cold threshold from usage distribution usage_sorted = torch.sort(self.expert_usage)[0] cold_idx = int(n * cold_threshold_pct) if cold_idx == 0: cold_idx = 1 # At least 1 cold expert self._int2_cold_threshold = usage_sorted[cold_idx].item() # Mark experts as cold (INT2) or hot (FP4) self._expert_is_int2 = (self.expert_usage < self._int2_cold_threshold) # Convert cold experts to INT2 and pack into separate buffers cold_count = self._expert_is_int2.sum().item() if cold_count == 0: self._int2_enabled = False return # Check INT2 kernel availability (imported at module level) if GoliathINT2Weights is None or _goliath_packed_moe_int2_gemm is None: self._int2_enabled = False return # Get shapes from FP4 weights K_gu, N_gu = self.packed_gu_w.shape[1] * 2, self.packed_gu_w.shape[2] # FP4: K//2 K_dn, N_dn = self.packed_dn_w.shape[1] * 2, self.packed_dn_w.shape[2] # Allocate INT2 packed buffers (K//4 instead of K//2) self.packed_gu_w_int2 = torch.zeros(n, K_gu // 4, N_gu, dtype=torch.uint8, device='cuda') self.packed_gu_s_int2 = torch.zeros(n, K_gu // 32, N_gu, dtype=torch.float16, device='cuda') self.packed_dn_w_int2 = torch.zeros(n, K_dn // 4, N_dn, dtype=torch.uint8, device='cuda') self.packed_dn_s_int2 = torch.zeros(n, K_dn // 32, N_dn, dtype=torch.float16, device='cuda') # Convert cold experts: dequant FP4 → requant INT2 for i in range(n): if self._expert_is_int2[i]: # Dequantize FP4 weights back to float gu_fp4 = GoliathFP4Weights( packed=self.packed_gu_w[i], block_scales=self.packed_gu_s[i], tensor_scale=self.packed_gu_ts[i].item(), shape=(K_gu, N_gu)) gu_float = gu_fp4.to_float() dn_fp4 = GoliathFP4Weights( packed=self.packed_dn_w[i], block_scales=self.packed_dn_s[i], tensor_scale=self.packed_dn_ts[i].item(), shape=(K_dn, N_dn)) dn_float = dn_fp4.to_float() # Quantize to INT2 gu_int2 = GoliathINT2Weights.from_float(gu_float) dn_int2 = GoliathINT2Weights.from_float(dn_float) # Store in INT2 packed buffers self.packed_gu_w_int2[i].copy_(gu_int2.packed) self.packed_gu_s_int2[i].copy_(gu_int2.block_scales) self.packed_dn_w_int2[i].copy_(dn_int2.packed) self.packed_dn_s_int2[i].copy_(dn_int2.block_scales) self._int2_enabled = True self._int2_cold_count = cold_count def get_int2_status(self) -> str: """Return INT2 cold expert status string.""" if not getattr(self, '_int2_enabled', False): return "INT2 cold experts: disabled" cold = self._int2_cold_count hot = self.num_experts - cold return f"INT2 cold experts: {cold} INT2 / {hot} FP4" def enable_auto_int2_demotion(self, cold_threshold_pct: float = 0.1): """Enable automatic INT2 demotion for cold experts during tier updates. When enabled, update_expert_tiers() will automatically convert experts marked as cold (tier 0) to INT2 quantization, saving 50% bandwidth. Args: cold_threshold_pct: Fraction of experts to mark as cold (default 10%) """ self._auto_int2_demotion = True self._auto_int2_threshold_pct = cold_threshold_pct def _maybe_demote_to_int2(self): """Convert cold-tier experts to INT2 if auto-demotion is enabled. Called by update_expert_tiers() when _auto_int2_demotion is True. Only converts experts that are: - Marked as cold tier (tier == 0) - Not already INT2 quantized - Have sufficient usage history - NOT protected (medical, legal, safety, financial domains) """ if not getattr(self, '_auto_int2_demotion', False): return if not self._experts_packed: return if GoliathINT2Weights is None or _goliath_packed_moe_int2_gemm is None: return # Find cold experts — use tier buffer if available, else usage-based bottom-N% if hasattr(self, 'expert_tier'): cold_mask = (self.expert_tier == 0) else: pct = getattr(self, '_auto_int2_threshold_pct', 0.10) n_cold = max(1, int(self.num_experts * pct)) _, coldest = self.expert_usage.topk(n_cold, largest=False) cold_mask = torch.zeros(self.num_experts, dtype=torch.bool, device='cuda') cold_mask[coldest] = True already_int2 = getattr(self, '_expert_is_int2', torch.zeros(self.num_experts, dtype=torch.bool, device='cuda')) # SAFETY: Respect domain protection — never demote protected experts protected = getattr(self, '_protected_experts', torch.zeros(self.num_experts, dtype=torch.bool, device='cuda')) needs_demotion = cold_mask & ~already_int2 & ~protected if not needs_demotion.any(): return # Initialize INT2 buffers if not already done if not getattr(self, '_int2_enabled', False): self._init_int2_buffers() # Convert newly cold experts to INT2 demotion_count = 0 for e_idx in range(self.num_experts): if needs_demotion[e_idx]: self._demote_single_expert_to_int2(e_idx) demotion_count += 1 if demotion_count > 0: self._int2_cold_count = self._expert_is_int2.sum().item() def _init_int2_buffers(self): """Initialize INT2 packed buffers for cold expert storage. Uses compact allocation (like FE-XC) — only allocates for expected cold count + margin, not all 128 experts. Saves ~160 MB VRAM. Uses slot mapping to translate expert IDs to buffer indices. """ n = self.num_experts # Release fragmented PyTorch cache before allocation torch.cuda.empty_cache() # FP4 packed shape: [E, K//2, N] (2 values per byte) # INT2 packed shape: [E, K//4, N] (4 values per byte) # INT2 scales shape: [E, K//32, N] (32-element groups, FP16) K_gu = self.packed_gu_w.shape[1] * 2 # Original K (FP4: K//2) N_gu = self.packed_gu_w.shape[2] K_dn = self.packed_dn_w.shape[1] * 2 N_dn = self.packed_dn_w.shape[2] # Only allocate slots for expected cold count + margin pct = getattr(self, '_auto_int2_threshold_pct', 0.10) max_int2 = min(n, int(n * pct) + 4) self._int2_max_slots = max_int2 self._int2_next_slot = 0 # Allocate compact INT2 buffers self.packed_gu_w_int2 = torch.zeros(max_int2, K_gu // 4, N_gu, dtype=torch.uint8, device='cuda') self.packed_gu_s_int2 = torch.zeros(max_int2, K_gu // 32, N_gu, dtype=torch.float16, device='cuda') self.packed_dn_w_int2 = torch.zeros(max_int2, K_dn // 4, N_dn, dtype=torch.uint8, device='cuda') self.packed_dn_s_int2 = torch.zeros(max_int2, K_dn // 32, N_dn, dtype=torch.float16, device='cuda') # Slot mapping: expert_id → buffer slot index (-1 = not INT2) self._int2_slot_map = torch.full((n,), -1, dtype=torch.int64, device='cuda') # Track which experts are INT2 self._expert_is_int2 = torch.zeros(n, dtype=torch.bool, device='cuda') self._int2_enabled = True self._int2_cold_count = 0 def _demote_single_expert_to_int2(self, e_idx: int): """Convert a single expert from FP4 to INT2.""" if self._expert_is_int2[e_idx]: return # Already INT2 # FP4 packed shape: [K//2, N] → original float shape: [K, N] K_gu = self.packed_gu_w.shape[1] * 2 N_gu = self.packed_gu_w.shape[2] K_dn = self.packed_dn_w.shape[1] * 2 N_dn = self.packed_dn_w.shape[2] # Dequantize FP4 → float gu_fp4 = GoliathFP4Weights( packed=self.packed_gu_w[e_idx], block_scales=self.packed_gu_s[e_idx], tensor_scale=self.packed_gu_ts[e_idx].item(), shape=(K_gu, N_gu), ) dn_fp4 = GoliathFP4Weights( packed=self.packed_dn_w[e_idx], block_scales=self.packed_dn_s[e_idx], tensor_scale=self.packed_dn_ts[e_idx].item(), shape=(K_dn, N_dn), ) # Requantize float → INT2 gu_int2 = GoliathINT2Weights.from_float(gu_fp4.to_float()) dn_int2 = GoliathINT2Weights.from_float(dn_fp4.to_float()) # Allocate a slot in the compact INT2 buffer slot = self._int2_next_slot if slot >= self._int2_max_slots: return # Buffer full, skip this expert self._int2_next_slot += 1 self._int2_slot_map[e_idx] = slot # Store in INT2 packed buffers (slot-indexed, not expert-indexed) self.packed_gu_w_int2[slot].copy_(gu_int2.packed) self.packed_gu_s_int2[slot].copy_(gu_int2.block_scales) self.packed_dn_w_int2[slot].copy_(dn_int2.packed) self.packed_dn_s_int2[slot].copy_(dn_int2.block_scales) self._expert_is_int2[e_idx] = True # Update per-expert module weights with INT2 reconstruction so that # prefill path (which uses per-expert modules, not packed INT2 buffers) # correctly reflects INT2-level quality degradation. # NOTE: After pack_experts(), per-expert weight.data may be empty (zero-size) # because weights were re-pointed to packed buffer slices. Only update if # the expert module still has valid weight tensors. expert = self.experts[e_idx] gu_weight = expert.gate_up_proj.weight.data if gu_weight.numel() > 0: gu_recon = gu_int2.to_float() # [K=in, N=out] Goliath convention dn_recon = dn_int2.to_float() # [K=in, N=out] gu_weight.copy_(gu_recon.T.to(gu_weight.dtype)) expert.gate_up_proj.invalidate_cache() expert.gate_up_proj._quantize_weights() expert.down_proj.weight.data.copy_(dn_recon.T.to(expert.down_proj.weight.dtype)) expert.down_proj.invalidate_cache() expert.down_proj._quantize_weights() # Also update packed FP4 buffers for grouped dispatch path self.packed_gu_w[e_idx].copy_(expert.gate_up_proj._goliath_weights.packed) self.packed_gu_s[e_idx].copy_(expert.gate_up_proj._goliath_weights.block_scales) self.packed_gu_ts[e_idx] = expert.gate_up_proj._goliath_weights.tensor_scale self.packed_dn_w[e_idx].copy_(expert.down_proj._goliath_weights.packed) self.packed_dn_s[e_idx].copy_(expert.down_proj._goliath_weights.block_scales) self.packed_dn_ts[e_idx] = expert.down_proj._goliath_weights.tensor_scale # --- FE-XC (FireEcho Xtreme Compress) methods --- def enable_auto_fexc_demotion(self, cold_threshold_pct: float = 0.10): """Enable automatic FE-XC demotion for cold experts. Like INT2 demotion but uses codebook-based 2-bit (FE-XC) which achieves much better quality through learned codebooks. """ self._auto_fexc_demotion = True self._auto_fexc_threshold_pct = cold_threshold_pct # ----- FE-XVQ (Hessian-weighted codebook 2-bit) ----- def enable_auto_fexvq_demotion(self, cold_threshold_pct: float = 0.10): """Enable FE-XVQ demotion for cold experts (Hessian-weighted codebooks). Like FE-XC but uses second-order information (input covariance) to produce better codebooks. Requires Hessian collection via collect_hessian_diag() before demotion. """ self._auto_fexvq_demotion = True self._auto_fexvq_threshold_pct = cold_threshold_pct # Initialize Hessian accumulators for this layer's MoE input dim = self.experts[0].gate_up_proj.in_features self._hessian_diag_gu = torch.zeros(dim, dtype=torch.float32, device='cuda') self._hessian_samples_gu = 0 inter = self.experts[0].intermediate_size self._hessian_diag_dn = torch.zeros(inter, dtype=torch.float32, device='cuda') self._hessian_samples_dn = 0 self._hessian_collecting = True def accumulate_hessian(self, x_flat: torch.Tensor, h_flat: Optional[torch.Tensor] = None): """Accumulate Hessian diagonal from MoE input activations. Called during calibration forward passes. Accumulates E[x^2] which is the diagonal of H = X^T X (input covariance / Fisher diagonal). Args: x_flat: [T, D] input activations to MoE layer h_flat: [T, inter] intermediate activations after SwiGLU (for down_proj) """ if not getattr(self, '_hessian_collecting', False): return # gate_up Hessian: diagonal of X^T X ≈ E[x^2] self._hessian_diag_gu += (x_flat.float() ** 2).sum(dim=0) self._hessian_samples_gu += x_flat.shape[0] # down Hessian (if provided) if h_flat is not None: self._hessian_diag_dn += (h_flat.float() ** 2).sum(dim=0) self._hessian_samples_dn += h_flat.shape[0] def get_hessian_diag(self): """Return normalized Hessian diagonals for gate_up and down projections.""" h_gu = None h_dn = None if self._hessian_samples_gu > 0: h_gu = self._hessian_diag_gu / self._hessian_samples_gu if self._hessian_samples_dn > 0: h_dn = self._hessian_diag_dn / self._hessian_samples_dn return h_gu, h_dn def _learn_layer_codebooks_fexvq(self): """Learn Hessian-weighted codebooks for this layer (FE-XVQ).""" if getattr(self, 'gu_codebooks_fexvq', None) is not None: return # Already learned goliath_K_gu = self.packed_gu_w.shape[1] * 2 goliath_N_gu = self.packed_gu_w.shape[2] goliath_K_dn = self.packed_dn_w.shape[1] * 2 goliath_N_dn = self.packed_dn_w.shape[2] perm = torch.randperm(self.num_experts)[:1] h_gu, h_dn = self.get_hessian_diag() # gate_up codebooks (Hessian-weighted) # NOTE: Goliath stores transposed [K=in, N=out]. FE-XVQ expects [K=out, N=in]. # After transpose, N_fexvq = goliath_K_gu (in_features). # Hessian is on in_features, which becomes N in FE-XVQ convention. gu_ref = GoliathFEXVQWeights.from_float( GoliathFP4Weights( packed=self.packed_gu_w[perm[0]], block_scales=self.packed_gu_s[perm[0]], tensor_scale=self.packed_gu_ts[perm[0]].item(), shape=(goliath_K_gu, goliath_N_gu), ).to_float().T.contiguous().cpu(), hessian_diag=h_gu.cpu() if h_gu is not None else None, n_iters=15) self.gu_codebooks_fexvq = gu_ref.codebooks.cuda() # down codebooks (Hessian-weighted) dn_ref = GoliathFEXVQWeights.from_float( GoliathFP4Weights( packed=self.packed_dn_w[perm[0]], block_scales=self.packed_dn_s[perm[0]], tensor_scale=self.packed_dn_ts[perm[0]].item(), shape=(goliath_K_dn, goliath_N_dn), ).to_float().T.contiguous().cpu(), hessian_diag=h_dn.cpu() if h_dn is not None else None, n_iters=15) self.dn_codebooks_fexvq = dn_ref.codebooks.cuda() def _demote_single_expert_to_fexvq(self, e_idx: int): """Convert a single expert from FP4 to FE-XVQ (Hessian-weighted codebook).""" if self._expert_is_fexc[e_idx]: return # Already FE-XC/FE-XVQ (shared buffer) goliath_K_gu = self.packed_gu_w.shape[1] * 2 goliath_N_gu = self.packed_gu_w.shape[2] goliath_K_dn = self.packed_dn_w.shape[1] * 2 goliath_N_dn = self.packed_dn_w.shape[2] self._learn_layer_codebooks_fexvq() gu_fp4 = GoliathFP4Weights( packed=self.packed_gu_w[e_idx], block_scales=self.packed_gu_s[e_idx], tensor_scale=self.packed_gu_ts[e_idx].item(), shape=(goliath_K_gu, goliath_N_gu), ) dn_fp4 = GoliathFP4Weights( packed=self.packed_dn_w[e_idx], block_scales=self.packed_dn_s[e_idx], tensor_scale=self.packed_dn_ts[e_idx].item(), shape=(goliath_K_dn, goliath_N_dn), ) h_gu, h_dn = self.get_hessian_diag() # Transpose + Hessian-weighted assignment on CPU gu_fexvq = GoliathFEXVQWeights.from_float( gu_fp4.to_float().T.contiguous().cpu(), hessian_diag=h_gu.cpu() if h_gu is not None else None, codebooks=self.gu_codebooks_fexvq.cpu()) dn_fexvq = GoliathFEXVQWeights.from_float( dn_fp4.to_float().T.contiguous().cpu(), hessian_diag=h_dn.cpu() if h_dn is not None else None, codebooks=self.dn_codebooks_fexvq.cpu()) # Store in FE-XC packed buffers (FE-XVQ shares same format) slot = self._fexc_next_slot if slot >= self._fexc_max_slots: return self._fexc_next_slot += 1 self._fexc_slot_map[e_idx] = slot self.packed_gu_codes[slot].copy_(gu_fexvq.codes) self.packed_gu_scales_fexc[slot].copy_(gu_fexvq.scales) self.packed_dn_codes[slot].copy_(dn_fexvq.codes) self.packed_dn_scales_fexc[slot].copy_(dn_fexvq.scales) self._expert_is_fexc[e_idx] = True # Shared with FE-XC (same dispatch) def _maybe_demote_to_fexvq(self): """Convert cold experts to FE-XVQ if auto-demotion is enabled.""" if not getattr(self, '_auto_fexvq_demotion', False): return if not self._experts_packed: return # Find cold experts by usage pct = getattr(self, '_auto_fexvq_threshold_pct', 0.10) n_cold = max(1, int(self.num_experts * pct)) _, coldest = self.expert_usage.topk(n_cold, largest=False) needs_demotion = torch.zeros(self.num_experts, dtype=torch.bool, device='cuda') needs_demotion[coldest] = True # Skip already demoted already = getattr(self, '_expert_is_fexc', torch.zeros(self.num_experts, dtype=torch.bool)) needs_demotion &= ~already if not needs_demotion.any(): return # Ensure FE-XC buffers exist (FE-XVQ shares them) if not getattr(self, '_fexc_enabled', False): self._init_fexc_buffers() # Use FE-XVQ codebooks if Hessian available, else fall back to FE-XC if self.gu_codebooks is None and getattr(self, 'gu_codebooks_fexvq', None) is None: self._learn_layer_codebooks_fexvq() # Also store FE-XC codebooks for dispatch (shared with FE-XVQ) if self.gu_codebooks is None: self.gu_codebooks = self.gu_codebooks_fexvq self.dn_codebooks = self.dn_codebooks_fexvq count = 0 for e_idx in range(self.num_experts): if needs_demotion[e_idx]: self._demote_single_expert_to_fexvq(e_idx) count += 1 if count > 0: self._fexc_cold_count = self._expert_is_fexc.sum().item() def _init_fexc_buffers(self): """Initialize FE-XC packed buffers for codebook-quantized expert storage. Only allocates for the expected cold expert count (not all 128) to save ~6 GB VRAM. Uses a slot mapping tensor to translate expert IDs to buffer indices during forward. NOTE: Goliath FP4 uses transposed convention: K=in_features, N=out_features. FE-XC uses standard convention: K=out_features, N=in_features. So FE-XC buffers swap the dimensions. """ n = self.num_experts # Goliath FP4: K=in_features, N=out_features (transposed) goliath_K_gu = self.packed_gu_w.shape[1] * 2 # in_features goliath_N_gu = self.packed_gu_w.shape[2] # out_features goliath_K_dn = self.packed_dn_w.shape[1] * 2 goliath_N_dn = self.packed_dn_w.shape[2] # FE-XC: K=out_features, N=in_features (standard) fexc_K_gu = goliath_N_gu # out_features for gate_up fexc_N_gu = goliath_K_gu # in_features for gate_up fexc_K_dn = goliath_N_dn fexc_N_dn = goliath_K_dn g = 8 # FE-XC group size # Only allocate slots for expected cold count + margin pct = getattr(self, '_auto_fexc_threshold_pct', 0.10) max_fexc = min(n, int(n * pct) + 4) # e.g., 16 for 10% of 128 self._fexc_max_slots = max_fexc self._fexc_next_slot = 0 # Codes: [max_fexc, K_out, N_in//g, 2] uint8 (FE-XC convention) self.packed_gu_codes = torch.zeros(max_fexc, fexc_K_gu, fexc_N_gu // g, 2, dtype=torch.uint8, device='cuda') self.packed_dn_codes = torch.zeros(max_fexc, fexc_K_dn, fexc_N_dn // g, 2, dtype=torch.uint8, device='cuda') # Scales: [max_fexc, K_out] float16 self.packed_gu_scales_fexc = torch.zeros(max_fexc, fexc_K_gu, dtype=torch.float16, device='cuda') self.packed_dn_scales_fexc = torch.zeros(max_fexc, fexc_K_dn, dtype=torch.float16, device='cuda') # Slot mapping: expert_id → buffer slot index (-1 = not FE-XC) self._fexc_slot_map = torch.full((n,), -1, dtype=torch.int64, device='cuda') # Codebooks: [2, 256, 8] float16 — shared per projection type self.gu_codebooks = None # Learned on first demotion self.dn_codebooks = None # Tracking self._expert_is_fexc = torch.zeros(n, dtype=torch.bool, device='cuda') self._fexc_enabled = True self._fexc_cold_count = 0 def _learn_layer_codebooks(self): """Learn shared codebooks from ALL experts in this layer (once).""" if self.gu_codebooks is not None: return # Already learned K_gu = self.packed_gu_w.shape[1] * 2 N_gu = self.packed_gu_w.shape[2] K_dn = self.packed_dn_w.shape[1] * 2 N_dn = self.packed_dn_w.shape[2] g = 8 # Sample weight groups from a few experts (not all 128 — too slow) sample_experts = min(8, self.num_experts) perm = torch.randperm(self.num_experts)[:sample_experts] gu_groups = [] dn_groups = [] for e_idx in perm: gu_fp4 = GoliathFP4Weights( packed=self.packed_gu_w[e_idx], block_scales=self.packed_gu_s[e_idx], tensor_scale=self.packed_gu_ts[e_idx].item(), shape=(K_gu, N_gu), ) dn_fp4 = GoliathFP4Weights( packed=self.packed_dn_w[e_idx], block_scales=self.packed_dn_s[e_idx], tensor_scale=self.packed_dn_ts[e_idx].item(), shape=(K_dn, N_dn), ) gu_groups.append(gu_fp4.to_float().view(-1, g)) dn_groups.append(dn_fp4.to_float().view(-1, g)) # Learn codebooks on CPU (torch.cdist in from_float needs ~400MB temp, # which may not fit in GPU VRAM alongside the model) # NOTE: Goliath FP4 stores weights transposed [K=in, N=out]. # FE-XC from_float expects [K=out, N=in], so we transpose (.T). gu_ref = GoliathFEXCWeights.from_float( GoliathFP4Weights( packed=self.packed_gu_w[perm[0]], block_scales=self.packed_gu_s[perm[0]], tensor_scale=self.packed_gu_ts[perm[0]].item(), shape=(K_gu, N_gu), ).to_float().T.contiguous().cpu(), n_iters=15) self.gu_codebooks = gu_ref.codebooks.cuda() dn_ref = GoliathFEXCWeights.from_float( GoliathFP4Weights( packed=self.packed_dn_w[perm[0]], block_scales=self.packed_dn_s[perm[0]], tensor_scale=self.packed_dn_ts[perm[0]].item(), shape=(K_dn, N_dn), ).to_float().T.contiguous().cpu(), n_iters=15) self.dn_codebooks = dn_ref.codebooks.cuda() def _demote_single_expert_to_fexc(self, e_idx: int): """Convert a single expert from FP4 to FE-XC codebook 2-bit.""" if self._expert_is_fexc[e_idx]: return K_gu = self.packed_gu_w.shape[1] * 2 N_gu = self.packed_gu_w.shape[2] K_dn = self.packed_dn_w.shape[1] * 2 N_dn = self.packed_dn_w.shape[2] # Ensure codebooks are learned self._learn_layer_codebooks() # Dequantize FP4 → float → FE-XC gu_fp4 = GoliathFP4Weights( packed=self.packed_gu_w[e_idx], block_scales=self.packed_gu_s[e_idx], tensor_scale=self.packed_gu_ts[e_idx].item(), shape=(K_gu, N_gu), ) dn_fp4 = GoliathFP4Weights( packed=self.packed_dn_w[e_idx], block_scales=self.packed_dn_s[e_idx], tensor_scale=self.packed_dn_ts[e_idx].item(), shape=(K_dn, N_dn), ) # Run k-means assignment on CPU (GPU VRAM too tight for torch.cdist) # NOTE: Transpose .T to convert from Goliath [K=in, N=out] to FE-XC [K=out, N=in] gu_fexc = GoliathFEXCWeights.from_float( gu_fp4.to_float().T.contiguous().cpu(), codebooks=self.gu_codebooks.cpu()) dn_fexc = GoliathFEXCWeights.from_float( dn_fp4.to_float().T.contiguous().cpu(), codebooks=self.dn_codebooks.cpu()) # Allocate a slot in the compact FE-XC buffer slot = self._fexc_next_slot if slot >= self._fexc_max_slots: return # Buffer full, skip this expert self._fexc_next_slot += 1 self._fexc_slot_map[e_idx] = slot # Store in FE-XC packed buffers (slot-indexed, not expert-indexed) self.packed_gu_codes[slot].copy_(gu_fexc.codes) self.packed_gu_scales_fexc[slot].copy_(gu_fexc.scales) self.packed_dn_codes[slot].copy_(dn_fexc.codes) self.packed_dn_scales_fexc[slot].copy_(dn_fexc.scales) self._expert_is_fexc[e_idx] = True # Update per-expert module weights with FE-XC reconstruction so that # prefill path (which uses per-expert modules, not packed FE-XC buffers) # correctly reflects FE-XC-level quality degradation. # NOTE: After pack_experts(), per-expert weight.data may be empty (zero-size) # because weights were re-pointed to packed buffer slices. Only update if # the expert module still has valid weight tensors. expert = self.experts[e_idx] gu_weight = expert.gate_up_proj.weight.data if gu_weight.numel() > 0: gu_recon = gu_fexc.to_float() # [K=out, N=in] dn_recon = dn_fexc.to_float() gu_weight.copy_(gu_recon.cuda().to(gu_weight.dtype)) expert.gate_up_proj.invalidate_cache() expert.gate_up_proj._quantize_weights() expert.down_proj.weight.data.copy_(dn_recon.cuda().to(expert.down_proj.weight.dtype)) expert.down_proj.invalidate_cache() expert.down_proj._quantize_weights() # Also update packed FP4 buffers for grouped dispatch path self.packed_gu_w[e_idx].copy_(expert.gate_up_proj._goliath_weights.packed) self.packed_gu_s[e_idx].copy_(expert.gate_up_proj._goliath_weights.block_scales) self.packed_gu_ts[e_idx] = expert.gate_up_proj._goliath_weights.tensor_scale self.packed_dn_w[e_idx].copy_(expert.down_proj._goliath_weights.packed) self.packed_dn_s[e_idx].copy_(expert.down_proj._goliath_weights.block_scales) self.packed_dn_ts[e_idx] = expert.down_proj._goliath_weights.tensor_scale def _maybe_demote_to_fexc(self): """Convert cold-tier experts to FE-XC if auto-demotion is enabled.""" if not getattr(self, '_auto_fexc_demotion', False): return if not self._experts_packed: return if GoliathFEXCWeights is None or _goliath_packed_moe_fexc_gemm is None: return # Determine cold experts: use tier buffer if available, else use # usage-based bottom-N% as cold (handles fresh model with no usage) if hasattr(self, 'expert_tier'): cold_mask = (self.expert_tier == 0) else: # No tier info — pick bottom N% by usage as cold pct = getattr(self, '_auto_fexc_threshold_pct', 0.10) n_cold = max(1, int(self.num_experts * pct)) _, coldest = self.expert_usage.topk(n_cold, largest=False) cold_mask = torch.zeros(self.num_experts, dtype=torch.bool, device='cuda') cold_mask[coldest] = True already_fexc = getattr(self, '_expert_is_fexc', torch.zeros(self.num_experts, dtype=torch.bool, device='cuda')) protected = getattr(self, '_protected_experts', torch.zeros(self.num_experts, dtype=torch.bool, device='cuda')) needs_demotion = cold_mask & ~already_fexc & ~protected if not needs_demotion.any(): return if not getattr(self, '_fexc_enabled', False): self._init_fexc_buffers() demotion_count = 0 for e_idx in range(self.num_experts): if needs_demotion[e_idx]: self._demote_single_expert_to_fexc(e_idx) demotion_count += 1 if demotion_count > 0: self._fexc_cold_count = self._expert_is_fexc.sum().item() def forward(self, x: torch.Tensor) -> torch.Tensor: orig_shape = x.shape x_flat = x.view(-1, x.shape[-1]) # [T, D] T, D = x_flat.shape k = self.num_experts_per_tok topk_weights, topk_ids = self.router(x_flat) # [T, k], [T, k] # FE-XVQ Hessian collection (gate_up input = x_flat) if getattr(self, '_hessian_collecting', False): self.accumulate_hessian(x_flat.detach()) # FE-AGK: Atlas the Gatekeeper (GPU-fast path) if getattr(self, '_atlas_enabled', False): # Ban & Pick: mask banned expert weights in-place, renormalize # No redundant matmul — uses pre-computed topk from router if getattr(self, '_atlas_has_bans', False): banned_mask = self._atlas_banned[topk_ids] # [T, k] topk_weights = topk_weights.masked_fill(banned_mask, 0.0) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True).clamp(min=1e-8) # MoDES: skip MoE for truly uncertain tokens (decode only) # Uses top-1 normalized weight — no redundant F.linear, no softmax # SKIP during CUDA graph capture (.item() not allowed) if (T == 1 and self._atlas_modes_threshold > 0 and not getattr(self, '_graph_capturing', False)): self._atlas_modes_total += 1 # top-1 weight from router. For k=8: uniform = 0.125 uniform_k = 1.0 / self.num_experts_per_tok if topk_weights[0, 0].item() < uniform_k * self._atlas_modes_threshold: self._atlas_modes_skips += 1 return torch.zeros_like(x).view(orig_shape) # Profile: accumulate impact scores during profiling if self._atlas_profiling: self._atlas_profile_count += 1 output = torch.zeros_like(x_flat) if T == 1 and self._experts_packed and _goliath_packed_moe_gemm is not None: # === Packed MoE decode: zero .item(), CUDA-graph-safe === # Expert IDs stay on GPU — no CPU-GPU sync, no Python loops. expert_ids = topk_ids[0] # [k] — GPU tensor, NOT .item()! expert_w = topk_weights[0] # [k] inter = self.experts[0].intermediate_size # Check for mixed FP4/INT2/FE-XC dispatch # SKIP during CUDA Graph capture — .any().item() is illegal _capturing = getattr(self, '_graph_capturing', False) _use_int2 = (getattr(self, '_int2_enabled', False) and _goliath_packed_moe_int2_gemm is not None and not _capturing) _use_fexc = (getattr(self, '_fexc_enabled', False) and _goliath_packed_moe_fexc_gemm is not None and not _capturing) if _use_fexc: is_fexc = self._expert_is_fexc[expert_ids.long()] # [k] bool has_fexc = is_fexc.any().item() else: has_fexc = False if _use_int2: is_int2 = self._expert_is_int2[expert_ids.long()] # [k] bool has_int2 = is_int2.any().item() else: has_int2 = False if not has_int2 and not has_fexc: # === All FP4 path (common case, CUDA-graph-safe) === gate_up_out = _goliath_packed_moe_gemm( x_flat, self.packed_gu_w, self.packed_gu_s, self.packed_gu_ts, expert_ids, num_active=k) # [k, 1, 2*inter] if _goliath_packed_moe_swiglu_down is not None: down_out = _goliath_packed_moe_swiglu_down( gate_up_out, self.packed_dn_w, self.packed_dn_s, self.packed_dn_ts, expert_ids, intermediate_size=inter, num_active=k) else: gate = gate_up_out[:, :, :inter] up = gate_up_out[:, :, inter:] hidden = F.silu(gate) * up hidden_stacked = hidden.reshape(k, inter) down_out = _goliath_packed_moe_gemm( hidden_stacked, self.packed_dn_w, self.packed_dn_s, self.packed_dn_ts, expert_ids, num_active=k, per_expert_input=True) output[0] = (expert_w[:, None, None] * down_out).sum(0).squeeze(0) # FE-AGK: profile expert impact (|output| * weight) if getattr(self, '_atlas_profiling', False): with torch.no_grad(): impact = (down_out.squeeze(1).norm(dim=-1) * expert_w).float() self._atlas_impact.scatter_add_(0, expert_ids.long(), impact) else: # === Mixed FP4/FE-XC/INT2 dispatch — split by precision === # Determine FP4 mask (not FE-XC and not INT2) _is_fexc = is_fexc if has_fexc else torch.zeros(k, dtype=torch.bool, device=x.device) _is_int2 = is_int2 if has_int2 else torch.zeros(k, dtype=torch.bool, device=x.device) fp4_mask = ~_is_fexc & ~_is_int2 has_fp4 = fp4_mask.any().item() if has_fp4: fp4_idx = fp4_mask.nonzero(as_tuple=True)[0] fp4_ids = expert_ids[fp4_idx] fp4_w = expert_w[fp4_idx] n_fp4 = fp4_idx.shape[0] gate_up_fp4 = _goliath_packed_moe_gemm( x_flat, self.packed_gu_w, self.packed_gu_s, self.packed_gu_ts, fp4_ids, num_active=n_fp4) if _goliath_packed_moe_swiglu_down is not None: down_fp4 = _goliath_packed_moe_swiglu_down( gate_up_fp4, self.packed_dn_w, self.packed_dn_s, self.packed_dn_ts, fp4_ids, intermediate_size=inter, num_active=n_fp4) else: g = gate_up_fp4[:, :, :inter] u = gate_up_fp4[:, :, inter:] h = F.silu(g) * u down_fp4 = _goliath_packed_moe_gemm( h.reshape(n_fp4, inter), self.packed_dn_w, self.packed_dn_s, self.packed_dn_ts, fp4_ids, num_active=n_fp4, per_expert_input=True) output[0] += (fp4_w[:, None, None] * down_fp4).sum(0).squeeze(0) # FE-XC codebook 2-bit experts (higher quality than INT2) if has_fexc: fexc_idx = _is_fexc.nonzero(as_tuple=True)[0] fexc_ids = expert_ids[fexc_idx] fexc_w = expert_w[fexc_idx] n_fexc = fexc_idx.shape[0] # Remap expert IDs to compact buffer slot indices fexc_slots = self._fexc_slot_map[fexc_ids] # Precompute psumbook once for both gate_up and down gu_psumbook = _fexc_precompute_psumbook(self.gu_codebooks, x_flat[0]) gate_up_fexc = _goliath_packed_moe_fexc_gemm( x_flat, self.packed_gu_codes, self.gu_codebooks, self.packed_gu_scales_fexc, fexc_slots, psumbook=gu_psumbook, num_active=n_fexc) gf = gate_up_fexc[:, :, :inter] uf = gate_up_fexc[:, :, inter:] hf = F.silu(gf) * uf dn_psumbook = _fexc_precompute_psumbook(self.dn_codebooks, hf.reshape(n_fexc, inter)[0]) down_fexc = _goliath_packed_moe_fexc_gemm( hf.reshape(n_fexc, inter), self.packed_dn_codes, self.dn_codebooks, self.packed_dn_scales_fexc, fexc_slots, psumbook=dn_psumbook, num_active=n_fexc) output[0] += (fexc_w[:, None, None] * down_fexc).sum(0).squeeze(0) # INT2 cold experts (legacy, lower quality) if has_int2: int2_idx = _is_int2.nonzero(as_tuple=True)[0] int2_ids = expert_ids[int2_idx] int2_w = expert_w[int2_idx] n_int2 = int2_idx.shape[0] # Remap expert IDs to compact buffer slot indices int2_slots = self._int2_slot_map[int2_ids] gate_up_int2 = _goliath_packed_moe_int2_gemm( x_flat, self.packed_gu_w_int2, self.packed_gu_s_int2, int2_slots, num_active=n_int2) g2 = gate_up_int2[:, :, :inter] u2 = gate_up_int2[:, :, inter:] h2 = F.silu(g2) * u2 down_int2 = _goliath_packed_moe_int2_gemm( h2.reshape(n_int2, inter), self.packed_dn_w_int2, self.packed_dn_s_int2, int2_slots, num_active=n_int2, per_expert_input=True) output[0] += (int2_w[:, None, None] * down_int2).sum(0).squeeze(0) # Track usage (scatter_add_ keeps it on GPU) if not self.training: self.expert_usage.scatter_add_( 0, expert_ids.long(), torch.ones(k, device=x.device, dtype=self.expert_usage.dtype)) elif 1 < T <= 128 and self._experts_packed and _goliath_packed_moe_gemm is not None: # === FE-XT: Batched packed MoE for tree verification (M=2-128) === # Groups tokens by expert, pads to common M, single kernel launch. # Eliminates .item() calls and per-expert Python module overhead. inter = self.experts[0].intermediate_size # 1. Flatten all T*k token-expert pairs flat_token_idx = torch.arange(T, device=x.device).unsqueeze(1).expand(-1, k).reshape(-1) flat_expert_idx = topk_ids.reshape(-1) # [T*k] flat_weights = topk_weights.reshape(-1) # [T*k] # 2. Sort by expert sort_order = flat_expert_idx.argsort(stable=True) sorted_expert = flat_expert_idx[sort_order] sorted_token = flat_token_idx[sort_order] sorted_weight = flat_weights[sort_order] # 3. Find unique experts + counts unique_experts, counts = torch.unique_consecutive( sorted_expert, return_counts=True) num_unique = unique_experts.shape[0] max_count = counts.max().item() offsets = torch.zeros(num_unique + 1, dtype=torch.long, device=x.device) offsets[1:] = counts.cumsum(0) # 4. Pad token inputs per expert to [num_unique, max_count, D] padded_input = torch.zeros(num_unique * max_count, D, dtype=x_flat.dtype, device=x.device) # Fill with gathered tokens per expert for i in range(num_unique): start = offsets[i].item() end = offsets[i + 1].item() n = end - start token_indices = sorted_token[start:end] base = i * max_count padded_input[base:base + n] = x_flat[token_indices] # 5. ONE packed MoE kernel: gate_up projection gate_up_out = _goliath_packed_moe_gemm( padded_input, self.packed_gu_w, self.packed_gu_s, self.packed_gu_ts, unique_experts, num_active=num_unique, per_expert_input=True) # [num_unique, max_count, 2*inter] # 6. SwiGLU activation gate = gate_up_out[:, :, :inter] up_val = gate_up_out[:, :, inter:] hidden_act = F.silu(gate) * up_val # [num_unique, max_count, inter] # 7. ONE packed MoE kernel: down projection down_out = _goliath_packed_moe_gemm( hidden_act.reshape(num_unique * max_count, inter), self.packed_dn_w, self.packed_dn_s, self.packed_dn_ts, unique_experts, num_active=num_unique, per_expert_input=True) # [num_unique, max_count, D] # 8. Scatter weighted results back to token positions for i in range(num_unique): start = offsets[i].item() end = offsets[i + 1].item() n = end - start token_indices = sorted_token[start:end] weights = sorted_weight[start:end] expert_result = down_out[i, :n, :] # [n, D] output.index_add_( 0, token_indices, weights.unsqueeze(-1) * expert_result) # Track usage if not self.training: self.expert_usage.scatter_add_( 0, flat_expert_idx.long(), torch.ones(T * k, device=x.device, dtype=self.expert_usage.dtype)) elif T == 1 and _goliath_multi_expert_gemm is not None: # === Fallback: unpacked multi-expert decode (8 separate pointers) === expert_ids = [topk_ids[0, i].item() for i in range(k)] expert_w = topk_weights[0] # [k] all_goliath = all( self.experts[eid].gate_up_proj._goliath_weights is not None and self.experts[eid].down_proj._goliath_weights is not None for eid in expert_ids) if all_goliath: for eid in expert_ids: self.experts[eid].gate_up_proj._quantize_weights() self.experts[eid].down_proj._quantize_weights() gate_up_weights = [ self.experts[eid].gate_up_proj._goliath_weights for eid in expert_ids] gate_up_out = _goliath_multi_expert_gemm( x_flat, gate_up_weights, num_experts=k) inter = self.experts[0].intermediate_size gate = gate_up_out[:, :, :inter] up = gate_up_out[:, :, inter:] hidden = F.silu(gate) * up down_weights = [ self.experts[eid].down_proj._goliath_weights for eid in expert_ids] hidden_stacked = hidden.reshape(k, inter) down_out = _goliath_multi_expert_gemm( hidden_stacked, down_weights, num_experts=k, per_expert_input=True) output[0] = (expert_w[:, None, None] * down_out).sum(0).squeeze(0) if not self.training: for eid in expert_ids: self.expert_usage[eid] += 1 else: for i in range(k): expert_idx = expert_ids[i] w = expert_w[i] expert_out = self.experts[expert_idx](x_flat) output[0] += w * expert_out.squeeze(0) if not self.training: self.expert_usage[expert_idx] += 1 elif T <= 2: # === Single-token decode fast path (no Goliath fusion) === for t in range(T): for i in range(k): expert_idx = topk_ids[t, i].item() w = topk_weights[t, i] expert_out = self.experts[expert_idx](x_flat[t:t+1]) output[t] += w * expert_out.squeeze(0) if not self.training: self.expert_usage[expert_idx] += 1 else: # === Grouped expert dispatch (sorted tokens, index_add) === # For batch/prefill: sort tokens by expert, loop active experts only. flat_token_idx = torch.arange(T, device=x.device).unsqueeze(1).expand(-1, k).reshape(-1) flat_expert_idx = topk_ids.reshape(-1) # [T*k] flat_weights = topk_weights.reshape(-1) # [T*k] sort_order = flat_expert_idx.argsort(stable=True) sorted_expert = flat_expert_idx[sort_order] sorted_token = flat_token_idx[sort_order] sorted_weight = flat_weights[sort_order] unique_experts, counts = torch.unique_consecutive( sorted_expert, return_counts=True) offsets = torch.zeros(len(counts) + 1, dtype=torch.long, device=x.device) offsets[1:] = counts.cumsum(0) for i in range(len(unique_experts)): expert_idx = unique_experts[i].item() start = offsets[i].item() end = offsets[i + 1].item() token_indices = sorted_token[start:end] expert_weights = sorted_weight[start:end] selected = x_flat[token_indices] expert_out = self.experts[expert_idx](selected) output.index_add_( 0, token_indices, expert_weights.unsqueeze(-1) * expert_out) if not self.training: self.expert_usage[expert_idx] += (end - start) # Periodically update FE-MX expert tiers self._forward_count += 1 if (self.use_femx_experts and self._forward_count % self.femx_tier_interval == 0): cold, warm, hot = self.update_expert_tiers() if (self._forward_count % (self.femx_tier_interval * 10) == 0 and not getattr(self, '_quiet', False)): print(f" [FE-MX] Expert tiers: {cold} cold(FP4) / " f"{warm} warm(FP6) / {hot} hot(FP8)") return output.view(orig_shape) # ============================================================================ # TRANSFORMER BLOCK # ============================================================================ class FusedTransformerBlock(nn.Module): """Fused transformer block with all optimizations.""" def __init__(self, config: FireEchoConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx # Check if this is a critical layer (keep in higher precision) is_critical = layer_idx in config.critical_layers self.attn = FusedAttention( config.dim, config.num_heads, config.num_kv_heads, head_dim=config.head_dim, layer_idx=layer_idx, use_fused_qkv=config.use_fused_qkv, use_native_cutlass=config.use_native_cutlass, use_quantum_matmul=config.use_quantum_matmul, use_dsmem_cluster=config.use_dsmem_cluster, rope_theta=config.rope_theta, max_seq_len=config.max_seq_len, attn_bias=config.attn_bias, use_qk_norm=config.use_qk_norm, qk_norm_per_head=config.qk_norm_per_head, partial_rotary_factor=config.partial_rotary_factor, use_fused_norm_qkv=config.use_fused_norm_qkv, use_fused_residual_norm=config.use_fused_residual_norm, use_fused_rope=config.use_fused_rope, use_gqa_native=config.use_gqa_native, use_goliath=config.use_goliath and config.use_nvfp4 and config.quantize_weights, goliath_bits=config.goliath_bits, ) if config.use_moe: self.ffn = MoEFFN(config.dim, config) else: self.ffn = FusedFFN( config.dim, config.intermediate_size, use_nvfp4=config.use_nvfp4 and config.quantize_weights, is_critical=is_critical, use_fused_swiglu=config.use_fused_swiglu, use_native_cutlass=config.use_native_cutlass, use_quantum_matmul=config.use_quantum_matmul, use_dsmem_cluster=config.use_dsmem_cluster, goliath_bits=config.goliath_bits, use_goliath=config.use_goliath, w4a4_mode=config.w4a4_mode, use_goliath_linear=config.use_goliath_linear, decode_skip_act_quant=config.decode_skip_act_quant, decode_act_quant_threshold=config.decode_act_quant_threshold, compute_residual=config.use_residual_correction, ) self.norm1 = nn.RMSNorm(config.dim) self.norm2 = nn.RMSNorm(config.dim) self.norm_after = config.norm_after self.use_fused_norm_qkv = config.use_fused_norm_qkv self.use_fused_residual_norm = config.use_fused_residual_norm def forward(self, x: torch.Tensor, kv_cache: Optional[PagedKVCache] = None, seq_id: int = 0, position: int = 0, use_cache: bool = False) -> torch.Tensor: if self.norm_after: # Post-norm (Molmo2/OLMo): norm applied AFTER attention/FFN # Structure: x = residual + norm(sublayer_out) # fused_residual_rmsnorm doesn't help here (norms the sum, not sublayer) x = x + self.norm1(self.attn(x, kv_cache, seq_id, position, use_cache)) x = x + self.norm2(self.ffn(x)) else: # Pre-norm (LLaMA/standard): norm applied BEFORE attention/FFN # Fusion 1: Pass norm_weight into attention for fused RMSNorm+QKV _norm1_eps = self.norm1.eps if self.norm1.eps is not None else 1e-6 if self.use_fused_norm_qkv and not self.training and x.is_cuda: attn_out = self.attn(x, kv_cache, seq_id, position, use_cache, norm_weight=self.norm1.weight, norm_eps=_norm1_eps) else: attn_out = self.attn(self.norm1(x), kv_cache, seq_id, position, use_cache) # Fusion 2: Fused residual + norm2 (x = x + attn_out; x_normed = RMSNorm(x)) _norm2_eps = self.norm2.eps if self.norm2.eps is not None else 1e-6 if self.use_fused_residual_norm and not self.training and x.is_cuda: x, x_normed = fused_residual_rmsnorm(x, attn_out, self.norm2.weight, _norm2_eps) x = x + self.ffn(x_normed) else: x = x + attn_out x = x + self.ffn(self.norm2(x)) return x # ============================================================================ # FIREECHO EAGLE-3 DRAFT HEAD — Lightweight speculative decoding # ============================================================================ # Multi-layer feature fusion from target model → 1 transformer layer → logits. # Generates K draft tokens autoregressively, verified by target model in one pass. # Acceptance via rejection sampling preserves exact output distribution. class _MPSLinear(nn.Module): """MPS-inspired low-rank factored linear layer. For 2D weight matrices, MPS tensor decomposition reduces to low-rank factorization: W[N,K] ≈ U[N,r] @ V[r,K]. Forward: y = x @ V^T @ U^T (two smaller matmuls instead of one large) Params: r*(N+K) instead of N*K Example: Linear(2048, 4096) with bond_dim=256 Original: 4096*2048 = 8.39M params MPS: 256*(4096+2048) = 1.57M params (5.3x compression) Speedup: 2*M*2048*256 + 2*M*256*4096 vs 2*M*2048*4096 = ~2.5x fewer FLOPs """ def __init__(self, in_features: int, out_features: int, bond_dim: int = 256, bias: bool = False): super().__init__() self.in_features = in_features self.out_features = out_features self.bond_dim = min(bond_dim, min(in_features, out_features)) # Low-rank factors: W ≈ U @ V self.V = nn.Parameter(torch.empty(self.bond_dim, in_features, dtype=torch.bfloat16)) self.U = nn.Parameter(torch.empty(out_features, self.bond_dim, dtype=torch.bfloat16)) if bias: self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.bfloat16)) else: self.register_parameter('bias', None) nn.init.kaiming_uniform_(self.V) nn.init.kaiming_uniform_(self.U) def init_from_weight(self, weight: torch.Tensor): """Initialize U, V from existing weight via truncated SVD. weight: [out_features, in_features] (standard nn.Linear convention) """ w = weight.float().cpu() try: U_full, S, Vh = torch.linalg.svd(w, full_matrices=False) except Exception: # Fallback: random init (SVD can fail on very large matrices) return r = self.bond_dim # U_full: [N, min(N,K)], S: [min(N,K)], Vh: [min(N,K), K] U_trunc = U_full[:, :r] * S[:r].unsqueeze(0).sqrt() # [N, r] V_trunc = S[:r].unsqueeze(1).sqrt() * Vh[:r, :] # [r, K] self.U.data.copy_(U_trunc.to(torch.bfloat16).to(self.U.device)) self.V.data.copy_(V_trunc.to(torch.bfloat16).to(self.V.device)) def forward(self, x: torch.Tensor) -> torch.Tensor: # y = x @ V^T @ U^T = (x @ V^T) @ U^T # x: [..., in_features], output: [..., out_features] h = F.linear(x, self.V, None) # [..., bond_dim] out = F.linear(h, self.U, self.bias) # [..., out_features] return out def extra_repr(self) -> str: return (f'in_features={self.in_features}, out_features={self.out_features}, ' f'bond_dim={self.bond_dim}, bias={self.bias is not None}') class ZeroCenteredRMSNorm(nn.Module): """Zero-Centered RMSNorm — prevents activation spikes in MoE routing. Standard RMSNorm: x / RMS(x) * weight Zero-Centered: (x - mean(x)) / RMS(x - mean(x)) * weight Subtracting the mean before normalization prevents large activation outliers from dominating the normalization, keeping expert routing stable across 128 experts. From Qwen3-Next (arxiv 2509.17765). """ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: x_centered = x - x.mean(dim=-1, keepdim=True) rms = x_centered.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt() return x_centered * rms * self.weight class FireEchoEagleHead(nn.Module): """EAGLE-3 / FE-XT draft head for speculative decoding. Architecture: FC compress: concat(h_low, h_mid, h_high, embed) → dim D× (self-attention + SwiGLU FFN) layers via ModuleList Shared lm_head (from main model, not duplicated) Supports D=1..20 layers. D=2 = ~115M params, D=8 = ~400M, D=16 = ~1.2B. All BF16 (no FP4 quantization — head runs during cheap draft phase). """ def __init__(self, dim: int, num_capture_layers: int = 3, num_heads: int = 16, ffn_mult: int = 2, max_draft_len: int = 16, num_layers: int = 2, num_medusa_heads: int = 0): super().__init__() self.dim = dim self.num_capture_layers = num_capture_layers self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.max_draft_len = max_draft_len self.num_layers = num_layers self.num_medusa_heads = num_medusa_heads # Feature compression: (num_capture_layers * dim + dim) → dim self.fc_compress = nn.Linear( (num_capture_layers + 1) * dim, dim, bias=False) # Hebbian memory injection: projects memory context into feature space # Zero-initialized so existing checkpoints are unaffected (no-op until trained) self.memory_proj = nn.Linear(dim, dim, bias=False) nn.init.zeros_(self.memory_proj.weight) # D transformer layers via ModuleList (was hardcoded layer 1/2) ffn_dim = dim * ffn_mult self.layers = nn.ModuleList() for _ in range(num_layers): self.layers.append(nn.ModuleDict({ 'norm_attn': ZeroCenteredRMSNorm(dim), 'q_proj': nn.Linear(dim, dim, bias=False), 'k_proj': nn.Linear(dim, dim, bias=False), 'v_proj': nn.Linear(dim, dim, bias=False), 'o_proj': nn.Linear(dim, dim, bias=False), 'norm_ffn': ZeroCenteredRMSNorm(dim), 'gate_proj': nn.Linear(dim, ffn_dim, bias=False), 'up_proj': nn.Linear(dim, ffn_dim, bias=False), 'down_proj': nn.Linear(ffn_dim, dim, bias=False), })) # Draft KV caches: one [k, v] pair per layer self._draft_kv = None # list of (k, v) tensors, allocated on first use self._draft_pos = 0 # Output projection: shared with main model (set externally) self.lm_head = None # Set to engine.lm_head during enable_eagle() # Medusa heads: K independent residual MLPs that predict positions t+2..t+K+1 # Each head transforms hidden state before shared lm_head (lightweight ~8M params each) # Zero-init last layer so heads start as identity (base lm_head predictions) if num_medusa_heads > 0: self.medusa_heads = nn.ModuleList() for _ in range(num_medusa_heads): head = nn.Sequential( nn.Linear(dim, dim, bias=False), nn.SiLU(), nn.Linear(dim, dim, bias=False), ) nn.init.zeros_(head[2].weight) # zero-init → starts as base lm_head self.medusa_heads.append(head) else: self.medusa_heads = None def reset_draft_cache(self): """Reset draft KV cache for a new speculation round.""" self._draft_pos = 0 if self._draft_kv is not None: for k, v in self._draft_kv: k.zero_() v.zero_() def _ensure_draft_cache(self, device, dtype, batch_size=1): """Allocate draft KV caches on first use (one per layer).""" if self._draft_kv is not None and self._draft_kv[0][0].shape[0] == batch_size: return self._draft_kv = [] for _ in range(self.num_layers): k = torch.zeros(batch_size, self.num_heads, self.max_draft_len, self.head_dim, dtype=dtype, device=device) v = torch.zeros_like(k) self._draft_kv.append((k, v)) def _draft_attn(self, x, pos, layer, draft_k, draft_v): """Self-attention block with draft KV cache (preserves autograd).""" B = x.shape[0] residual = x x_n = layer['norm_attn'](x) q = layer['q_proj'](x_n).view(B, 1, self.num_heads, self.head_dim).transpose(1, 2) k = layer['k_proj'](x_n).view(B, 1, self.num_heads, self.head_dim).transpose(1, 2) v = layer['v_proj'](x_n).view(B, 1, self.num_heads, self.head_dim).transpose(1, 2) # Store detached for future steps (avoid in-place autograd issues) draft_k[:, :, pos:pos + 1, :] = k.detach() draft_v[:, :, pos:pos + 1, :] = v.detach() # Attend: prior positions (detached) + current (in graph) k_ctx = torch.cat([draft_k[:, :, :pos, :], k], dim=2) if pos > 0 else k v_ctx = torch.cat([draft_v[:, :, :pos, :], v], dim=2) if pos > 0 else v attn_out = F.scaled_dot_product_attention( q, k_ctx, v_ctx, is_causal=False) attn_out = attn_out.transpose(1, 2).reshape(B, 1, self.dim) return residual + layer['o_proj'](attn_out) def forward(self, features: torch.Tensor, token_embed: torch.Tensor, memory_context: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Single draft step. Args: features: [B, 1, num_capture_layers * dim] (first step, from target) or [B, 1, dim] (subsequent steps, from own output) token_embed: [B, 1, dim] embedding of previous token memory_context: Optional [B, 1, dim] Hebbian memory retrieval. Provides grounding context — the model can learn to use or ignore it via memory_proj (zero-init = no-op). Returns: hidden: [B, 1, dim] — hidden state for next draft step logits: [B, 1, vocab_size] — draft logits """ B = features.shape[0] # Compress features + token embedding if features.shape[-1] == self.num_capture_layers * self.dim: # First step: multi-layer features from target model x = self.fc_compress(torch.cat([features, token_embed], dim=-1)) else: # Subsequent steps: own hidden state + token embedding pad_size = (self.num_capture_layers - 1) * self.dim pad = torch.zeros(B, 1, pad_size, dtype=features.dtype, device=features.device) x = self.fc_compress( torch.cat([features, pad, token_embed], dim=-1)) # Inject Hebbian memory context (additive, gated by learned projection) if memory_context is not None: x = x + self.memory_proj(memory_context) self._ensure_draft_cache(x.device, x.dtype, batch_size=B) pos = self._draft_pos # Run through all D layers _hayabusa = getattr(self, '_hayabusa', None) for i, layer in enumerate(self.layers): # FE-H: ensure cold layers are on GPU before use if _hayabusa is not None: _hayabusa.ensure_layer_on_gpu(i) draft_k, draft_v = self._draft_kv[i] x = self._draft_attn(x, pos, layer, draft_k, draft_v) residual = x x_n = layer['norm_ffn'](x) x = residual + layer['down_proj']( F.silu(layer['gate_proj'](x_n)) * layer['up_proj'](x_n)) # Increment position once for all layers self._draft_pos += 1 # Logits via shared lm_head logits = self.lm_head(x) # Medusa heads: predict positions t+2..t+K+1 from same hidden state if self.medusa_heads is not None and self.training: medusa_logits = [] for head in self.medusa_heads: # Residual: h + head(h) → lm_head medusa_logits.append(self.lm_head(x + head(x))) return x, logits, medusa_logits return x, logits @torch.no_grad() def medusa_draft(self, capture_features: list, first_token: torch.Tensor, embed_fn: nn.Embedding, memory_context: Optional[torch.Tensor] = None, ) -> Tuple[list, list]: """Medusa draft: ONE forward pass → K+1 predictions (no sequential steps). Unlike EAGLE sequential drafting (K forward passes), Medusa runs the backbone once and uses K independent heads to predict K different future positions. Draft cost: 1 × backbone_time (vs K × backbone_time for sequential). Returns: draft_tokens: list of K+1 tensors [1, 1] — predicted token IDs draft_logits: list of K+1 tensors [1, 1, V] — prediction logits """ if self.medusa_heads is None: raise ValueError("Medusa heads not initialized (num_medusa_heads=0)") # Prepare features [1, 1, num_layers * dim] features = torch.cat([f[:, -1:, :] for f in capture_features], dim=-1) tok_emb = embed_fn(first_token) # [1, 1, dim] self.reset_draft_cache() self._ensure_draft_cache(features.device, features.dtype, batch_size=1) # Single backbone forward was_training = self.training self.eval() # Compress + run all layers B = features.shape[0] x = self.fc_compress(torch.cat([features, tok_emb], dim=-1)) if memory_context is not None: x = x + self.memory_proj(memory_context) self._ensure_draft_cache(x.device, x.dtype, batch_size=B) pos = self._draft_pos for i, layer in enumerate(self.layers): draft_k, draft_v = self._draft_kv[i] x = self._draft_attn(x, pos, layer, draft_k, draft_v) residual = x x_n = layer['norm_ffn'](x) x = residual + layer['down_proj']( F.silu(layer['gate_proj'](x_n)) * layer['up_proj'](x_n)) self._draft_pos += 1 # Head 0: base lm_head (position t+1) base_logits = self.lm_head(x) # [B, 1, V] base_token = base_logits[:, -1:, :].argmax(dim=-1) # [B, 1] draft_tokens = [base_token] draft_logits = [base_logits] # Medusa heads: positions t+2..t+K+1 for head in self.medusa_heads: m_logits = self.lm_head(x + head(x)) # [B, 1, V] m_token = m_logits[:, -1:, :].argmax(dim=-1) # [B, 1] draft_tokens.append(m_token) draft_logits.append(m_logits) if was_training: self.train() return draft_tokens, draft_logits def upgrade_to_mps(self, bond_dim: int = 256): """Replace FFN weight matrices with MPS-inspired low-rank factorization. For 2D weight matrices, MPS reduces to W[N,K] ≈ U[N,r] @ V[r,K]. This gives ~5x compression AND faster forward (two small matmuls). Only replaces gate_proj, up_proj, down_proj (FFN — biggest params). Keeps q/k/v/o_proj as nn.Linear (attention-critical, small). Initializes from existing weights via truncated SVD. """ replaced = 0 params_before = sum(p.numel() for p in self.parameters()) for i, layer in enumerate(self.layers): for proj_name in ('gate_proj', 'up_proj', 'down_proj'): old = layer[proj_name] device = old.weight.device mps_layer = _MPSLinear( old.in_features, old.out_features, bond_dim=bond_dim, bias=old.bias is not None) # Initialize from existing weights via SVD truncation mps_layer.init_from_weight(old.weight.data) if old.bias is not None and mps_layer.bias is not None: mps_layer.bias.data.copy_(old.bias.data) # Move to same device as original weights (CPU→GPU) mps_layer = mps_layer.to(device) layer[proj_name] = mps_layer replaced += 1 del old # Free memory immediately params_after = sum(p.numel() for p in self.parameters()) compression = params_before / params_after if params_after > 0 else 0 print(f" [MPS] Replaced {replaced} FFN layers with bond_dim={bond_dim}") print(f" [MPS] Params: {params_before/1e6:.1f}M → {params_after/1e6:.1f}M " f"({compression:.1f}x compression)") def upgrade_to_quantum_linear(self): """Swap nn.Linear FFN layers for GoliathQuantumLinear (FP8 fwd + Quantum bwd). Only replaces gate_proj, up_proj, down_proj (FFN — biggest params). Keeps q/k/v/o_proj as nn.Linear (attention-critical, small). Copies existing weights into the new BF16 master weights. """ from goliath_kernel import GoliathQuantumLinear replaced = 0 for i, layer in enumerate(self.layers): for proj_name in ('gate_proj', 'up_proj', 'down_proj'): old = layer[proj_name] new_layer = GoliathQuantumLinear( old.in_features, old.out_features, bias=old.bias is not None) # Copy existing weights new_layer.weight.data.copy_(old.weight.data.to(torch.bfloat16)) if old.bias is not None and new_layer.bias is not None: new_layer.bias.data.copy_(old.bias.data.to(torch.bfloat16)) layer[proj_name] = new_layer replaced += 1 total_params = sum(p.numel() for p in self.parameters()) quantum_params = sum( p.numel() for layer in self.layers for name in ('gate_proj', 'up_proj', 'down_proj') for p in layer[name].parameters()) print(f" [GoliathQuantum] Replaced {replaced} FFN layers " f"({quantum_params/1e6:.1f}M / {total_params/1e6:.1f}M params)") def load_legacy_checkpoint(self, state_dict: dict): """Load a D=2 legacy checkpoint (hardcoded layer1/layer2 naming). Maps old-style keys (norm1, q_proj, norm3, q_proj2, etc.) to new ModuleList keys (layers.0.norm_attn, layers.1.q_proj, etc.). Extra layers (2+) are left randomly initialized. """ legacy_map_layer0 = { 'norm1': 'norm_attn', 'q_proj': 'q_proj', 'k_proj': 'k_proj', 'v_proj': 'v_proj', 'o_proj': 'o_proj', 'norm2': 'norm_ffn', 'gate_proj': 'gate_proj', 'up_proj': 'up_proj', 'down_proj': 'down_proj', } legacy_map_layer1 = { 'norm3': 'norm_attn', 'q_proj2': 'q_proj', 'k_proj2': 'k_proj', 'v_proj2': 'v_proj', 'o_proj2': 'o_proj', 'norm4': 'norm_ffn', 'gate_proj2': 'gate_proj', 'up_proj2': 'up_proj', 'down_proj2': 'down_proj', } new_sd = {} for key, val in state_dict.items(): mapped = False # Layer 0 mappings for old_prefix, new_name in legacy_map_layer0.items(): if key.startswith(old_prefix + '.') or key == old_prefix: suffix = key[len(old_prefix):] new_key = f'layers.0.{new_name}{suffix}' new_sd[new_key] = val mapped = True break if mapped: continue # Layer 1 mappings for old_prefix, new_name in legacy_map_layer1.items(): if key.startswith(old_prefix + '.') or key == old_prefix: suffix = key[len(old_prefix):] new_key = f'layers.1.{new_name}{suffix}' new_sd[new_key] = val mapped = True break if mapped: continue # Pass through unchanged (fc_compress, memory_proj, lm_head, etc.) new_sd[key] = val missing, unexpected = self.load_state_dict(new_sd, strict=False) if missing: n_new = sum(1 for k in missing if k.startswith('layers.')) print(f" [EAGLE] Loaded legacy D=2 checkpoint. " f"{n_new} new layer params initialized randomly.") return missing, unexpected @torch.no_grad() def generate_draft(self, capture_features: list, first_token: torch.Tensor, embed_fn: nn.Embedding, depth: int = 5, memory_context: Optional[torch.Tensor] = None, ) -> Tuple[list, list]: """Generate K draft tokens greedily (single-path speculation). Args: capture_features: list of [B, 1, dim] hidden states from target layers first_token: [B, 1] token IDs (last token from target model) embed_fn: embedding layer for token → embedding depth: number of draft tokens to generate memory_context: Optional [B, 1, dim] Hebbian memory retrieval. Queried once from target model, reused for all K steps. Provides grounding so draft head can leverage stored facts. Returns: draft_tokens: list of [B, 1] token ID tensors draft_logits: list of [B, 1, V] logit tensors """ self.reset_draft_cache() # Concatenate multi-layer features: [B, 1, num_capture_layers * dim] features = torch.cat(capture_features, dim=-1) tok_emb = embed_fn(first_token) # [B, 1, dim] draft_tokens = [] draft_logits = [] hidden = None for d in range(depth): if d == 0: # First step: pass memory_context (from Hebbian retrieval) hidden, logits = self.forward(features, tok_emb, memory_context) else: tok_emb = embed_fn(tok) # Subsequent steps: same memory_context (queried once, reused) hidden, logits = self.forward(hidden, tok_emb, memory_context) draft_logits.append(logits) tok = logits[:, -1:, :].argmax(dim=-1) # greedy draft_tokens.append(tok) return draft_tokens, draft_logits @torch.no_grad() def generate_draft_tree(self, capture_features: list, first_token: torch.Tensor, embed_fn: nn.Embedding, depth: int = 5, num_branches: int = 2, memory_context: Optional[torch.Tensor] = None, ) -> list: """Generate a tree of draft tokens with multiple paths (EAGLE-2 style). Branches at step 0 using top-K tokens, then continues greedily. Returns multiple candidate paths — verify best one. Args: capture_features: list of [B, 1, dim] hidden states from target layers first_token: [B, 1] token IDs embed_fn: embedding layer depth: tokens per path num_branches: number of branches at step 0 (top-K) memory_context: Optional [B, 1, dim] Hebbian memory Returns: paths: list of (draft_tokens, draft_logits) tuples, one per branch. Each draft_tokens is list of [B, 1] tensors. """ features = torch.cat(capture_features, dim=-1) tok_emb = embed_fn(first_token) # Step 0: get top-K tokens and their hidden states self.reset_draft_cache() hidden_0, logits_0 = self.forward(features, tok_emb, memory_context) pos_after_0 = self._draft_pos # save position counter topk = logits_0[:, -1:, :].topk(num_branches, dim=-1) # [B, 1, K] topk_tokens = topk.indices # [B, 1, K] # Save draft KV cache state after step 0 (shared across branches) saved_kv = [(k.clone(), v.clone()) for k, v in self._draft_kv] paths = [] for b in range(num_branches): # Restore draft KV cache to post-step-0 state for i, (k, v) in enumerate(saved_kv): self._draft_kv[i][0].copy_(k) self._draft_kv[i][1].copy_(v) self._draft_pos = pos_after_0 branch_tok = topk_tokens[:, :, b:b+1].squeeze(-1) # [B, 1] branch_tokens = [branch_tok] branch_logits = [logits_0] hidden = hidden_0 tok = branch_tok # Continue greedily from branch token for d in range(1, depth): tok_emb = embed_fn(tok) hidden, logits = self.forward(hidden, tok_emb, memory_context) branch_logits.append(logits) tok = logits[:, -1:, :].argmax(dim=-1) branch_tokens.append(tok) paths.append((branch_tokens, branch_logits)) return paths @torch.no_grad() def generate_draft_tree_batched(self, capture_features: list, first_token: torch.Tensor, embed_fn: nn.Embedding, depth: int = 5, b: int = 8, memory_context: Optional[torch.Tensor] = None, ) -> list: """FE-XT: Tree drafting with b parallel branches (batched, not sequential). All b branches run through the draft head simultaneously using batch dim. Much faster than sequential generate_draft_tree for large b. Step 0: shared features → top_k(b) → b branch tokens Steps 1..depth-1: batch all b branches through head Args: capture_features: list of [1, 1, dim] hidden states from target layers first_token: [1, 1] token IDs embed_fn: embedding layer depth: tokens per branch b: number of branches (tree width) memory_context: Optional [1, 1, dim] Hebbian memory Returns: tree_tokens: list of b lists, each containing depth [1, 1] token tensors """ features = torch.cat(capture_features, dim=-1) # [1, 1, layers*dim] tok_emb = embed_fn(first_token) # [1, 1, dim] # Step 0: shared computation (B=1) self.reset_draft_cache() hidden_0, logits_0 = self.forward(features, tok_emb, memory_context) # hidden_0: [1, 1, dim], logits_0: [1, 1, V] # Branch: top-b tokens topk = logits_0[:, -1, :].topk(b, dim=-1) # values: [1, b], indices: [1, b] branch_first_tokens = topk.indices[0] # [b] # Expand hidden for b branches: [1, 1, dim] → [b, 1, dim] hidden = hidden_0.expand(b, -1, -1).contiguous() # Save step-0 position and expand draft KV cache for b branches pos_after_0 = self._draft_pos # Re-allocate draft cache with batch_size=b, copy step-0 KV into all branches saved_kv = [(k.clone(), v.clone()) for k, v in self._draft_kv] self._ensure_draft_cache(hidden.device, hidden.dtype, batch_size=b) for i, (sk, sv) in enumerate(saved_kv): # Broadcast step-0 KV (B=1) across all b branches self._draft_kv[i][0][:] = sk.expand(b, -1, -1, -1) self._draft_kv[i][1][:] = sv.expand(b, -1, -1, -1) self._draft_pos = pos_after_0 # Collect tokens per branch tree_tokens = [[branch_first_tokens[i:i+1].unsqueeze(0)] # [1, 1] for i in range(b)] # Expand memory_context for batched forward mem_b = None if memory_context is not None: mem_b = memory_context.expand(b, -1, -1).contiguous() # Steps 1..depth-1: batched through all b branches tok = branch_first_tokens.unsqueeze(1) # [b, 1] for d in range(1, depth): tok_emb = embed_fn(tok) # [b, 1, dim] hidden, logits = self.forward(hidden, tok_emb, mem_b) tok = logits[:, -1:, :].argmax(dim=-1) # [b, 1] greedy for i in range(b): tree_tokens[i].append(tok[i:i+1]) # [1, 1] # Reset cache back to B=1 for next round self._draft_kv = None return tree_tokens # ============================================================================ # FE-H: FIREECHO HAYABUSA — Async Prefetch Offload for Draft Head # ============================================================================ # Named after the peregrine falcon (hayabusa, 隼) — fastest animal alive. # Exploits idle PCIe bandwidth during verify phase to prefetch draft head # layers from CPU pinned memory to GPU. # # Based on SP-MoE (arxiv 2510.10302) adapted for draft head offload. # PCIe 5.0: 100MB layer / 50 GB/s = 2ms/layer. 35ms verify = 17 layers. # → Can prefetch ALL cold layers (D=16-20) during a single verify phase. class FireEchoHayabusa: """FE-H: Async prefetch offload manager for FE-XT draft head. Splits draft head layers into: - Hot layers (0..cutoff_L-1): always on GPU, used every draft step - Cold layers (cutoff_L..D-1): on CPU pinned memory, prefetched JIT during the verify phase when PCIe bandwidth is idle Uses a dedicated CUDA stream for non-blocking H2D transfers. Uses per-layer CUDA events for fine-grained synchronization. LRU eviction keeps GPU cold-layer buffer under budget. Architecture: Draft phase: sync_cold_layers() → ensure needed layers on GPU Verify phase: async_prefetch() → non-blocking H2D while target runs After draft: evict_cold_layers() → free GPU memory for next round """ def __init__(self, draft_head: FireEchoEagleHead, cutoff_L: int = 4, gpu_budget_mb: float = 400.0): """Initialize FE-H offload manager. Args: draft_head: FireEchoEagleHead with D layers cutoff_L: Layers 0..cutoff_L-1 stay on GPU (hot) gpu_budget_mb: Max GPU memory for cold layers (MB) """ self.draft_head = draft_head self.cutoff_L = cutoff_L self.gpu_budget_mb = gpu_budget_mb self.num_layers = len(draft_head.layers) # CUDA async transfer infrastructure self.prefetch_stream = torch.cuda.Stream() self.events = {} # layer_idx → CUDA event # Track which layers are currently on GPU self.gpu_resident = set(range(cutoff_L)) # hot layers always resident self.lru_order = [] # cold layers currently on GPU (oldest first) # Per-layer param size (for budget tracking) self._layer_size_bytes = {} for i in range(self.num_layers): size = sum(p.numel() * p.element_size() for p in draft_head.layers[i].values() if isinstance(p, nn.Module) for p in p.parameters()) self._layer_size_bytes[i] = size # Offload cold layers to CPU pinned memory offloaded = 0 for i in range(cutoff_L, self.num_layers): layer = draft_head.layers[i] for name, module in layer.items(): if isinstance(module, nn.Module): for p in module.parameters(): p.data = p.data.cpu().pin_memory() self.events[i] = torch.cuda.Event() offloaded += 1 total_offloaded_mb = sum( self._layer_size_bytes.get(i, 0) for i in range(cutoff_L, self.num_layers) ) / 1e6 print(f" [FE-H] Hayabusa: offloaded {offloaded} cold layers " f"({total_offloaded_mb:.0f} MB) to CPU pinned memory. " f"Hot layers: 0-{cutoff_L-1} (always GPU).") def predict_needed_layers(self) -> list: """Predict which cold layers will be needed in next draft round. Simple heuristic: prefetch the first 6 cold layers (most likely needed). A well-trained head primarily uses early layers; deeper layers refine. Returns: List of layer indices to prefetch """ return list(range(self.cutoff_L, min(self.cutoff_L + 6, self.num_layers))) def async_prefetch(self, layer_indices: Optional[list] = None): """Non-blocking H2D transfer of cold layers on prefetch stream. Call this at the START of the verify phase. Transfers happen asynchronously while the target model processes 48 MoE layers. Args: layer_indices: Layers to prefetch (None = predict automatically) """ if layer_indices is None: layer_indices = self.predict_needed_layers() with torch.cuda.stream(self.prefetch_stream): for idx in layer_indices: if idx in self.gpu_resident: continue # already on GPU, skip if idx < self.cutoff_L: continue # hot layer, always resident layer = self.draft_head.layers[idx] for name, module in layer.items(): if isinstance(module, nn.Module): for p in module.parameters(): p.data = p.data.to('cuda', non_blocking=True) self.events[idx].record(self.prefetch_stream) self.gpu_resident.add(idx) self.lru_order.append(idx) def sync_layer(self, idx: int): """Wait for a specific cold layer to finish prefetching. Call this BEFORE the draft head uses the layer. If the layer was prefetched during verify, this is usually a no-op (already done). """ if idx in self.events and idx >= self.cutoff_L: self.events[idx].synchronize() def sync_all_cold(self): """Wait for ALL prefetched layers to finish transferring. Call this before starting the draft phase to ensure all needed layers are on GPU. """ self.prefetch_stream.synchronize() def evict_cold_layers(self): """Evict LRU cold layers back to CPU pinned memory if over budget. Call after the draft phase completes to free GPU memory. Layers are evicted oldest-first until under gpu_budget_mb. """ budget_bytes = self.gpu_budget_mb * 1e6 while self._gpu_cold_size_bytes() > budget_bytes and self.lru_order: evict_idx = self.lru_order.pop(0) if evict_idx < self.cutoff_L: continue # never evict hot layers layer = self.draft_head.layers[evict_idx] for name, module in layer.items(): if isinstance(module, nn.Module): for p in module.parameters(): p.data = p.data.cpu().pin_memory() self.gpu_resident.discard(evict_idx) def _gpu_cold_size_bytes(self) -> int: """Total bytes of cold layers currently on GPU.""" return sum(self._layer_size_bytes.get(i, 0) for i in self.gpu_resident if i >= self.cutoff_L) def ensure_layer_on_gpu(self, idx: int): """Ensure a specific layer is on GPU (sync if prefetching, fetch if not). Called by the draft head's forward pass before using each layer. """ if idx in self.gpu_resident: # Already on GPU; just make sure prefetch is done self.sync_layer(idx) return # Not on GPU and not prefetching — do a blocking transfer layer = self.draft_head.layers[idx] for name, module in layer.items(): if isinstance(module, nn.Module): for p in module.parameters(): p.data = p.data.to('cuda', non_blocking=False) self.gpu_resident.add(idx) self.lru_order.append(idx) def stats(self) -> dict: """Return FE-H status.""" cold_on_gpu = sum(1 for i in self.gpu_resident if i >= self.cutoff_L) cold_total = max(self.num_layers - self.cutoff_L, 0) return { 'hot_layers': self.cutoff_L, 'cold_layers': cold_total, 'cold_on_gpu': cold_on_gpu, 'cold_on_cpu': cold_total - cold_on_gpu, 'gpu_cold_mb': self._gpu_cold_size_bytes() / 1e6, 'gpu_budget_mb': self.gpu_budget_mb, } # ============================================================================ # IMAGE PREPROCESSING — Single-crop for Phi-4 SigLIP vision encoder # ============================================================================ def preprocess_image_for_phi4( image_path: str, crop_size: int = 448, device: Optional[str] = None, dtype: torch.dtype = torch.bfloat16, use_cuda_preproc: bool = True, ) -> torch.Tensor: """Load and preprocess an image for Phi-4's SigLIP vision encoder. Single-crop pipeline: resize to crop_size × crop_size, normalize to [-1, 1]. Uses CUDA-accelerated fused resize+normalize when available. Args: image_path: Path to image file (PNG, JPG, BMP, etc.) crop_size: Target resolution (default 448 from Phi-4 config) device: Target device (None = CPU, will be moved later) dtype: Target dtype (default bfloat16 for SigLIP) use_cuda_preproc: Use CUDA-accelerated preprocessing if available Returns: pixel_values: Tensor of shape [1, 3, crop_size, crop_size] """ from PIL import Image import numpy as np # CUDA-accelerated path: fused bicubic resize + normalize in single kernel if (use_cuda_preproc and _PREPROC_CUDA_AVAILABLE and _fireecho_preproc is not None and torch.cuda.is_available()): image = Image.open(image_path).convert("RGB") image_np = np.array(image, dtype=np.uint8) image_t = torch.from_numpy(image_np).cuda() pixel_values = _fireecho_preproc.cuda_image_preprocess(image_t, crop_size) if device is not None: pixel_values = pixel_values.to(device=device) return pixel_values # CPU fallback: PIL + TorchVision import torchvision.transforms as T image = Image.open(image_path).convert("RGB") transform = T.Compose([ T.Resize( (crop_size, crop_size), interpolation=T.InterpolationMode.BICUBIC, antialias=True, ), T.ToTensor(), # [0, 1] range, CHW format T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # -> [-1, 1] ]) pixel_values = transform(image).unsqueeze(0) # [1, 3, crop_size, crop_size] pixel_values = pixel_values.to(dtype=dtype) if device is not None: pixel_values = pixel_values.to(device=device) return pixel_values def _speechlib_mel(sample_rate: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = 7690.0): """SpeechLib-compatible mel filterbank (matches Phi-4 processing_phi4mm.py). Uses natural-log mel scale (1127 * ln(1 + f/700)) with triangular filters constructed in mel space. Returns [n_mels, n_fft//2+1] matrix. """ import numpy as np bank_width = int(n_fft // 2 + 1) def mel(f): return 1127.0 * np.log(1.0 + f / 700.0) def bin2mel(fft_bin): return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) def f2bin(f): return int((f * n_fft / sample_rate) + 0.5) klo = f2bin(fmin) + 1 khi = f2bin(fmax) khi = max(khi, klo) mlo = mel(fmin) mhi = mel(fmax) m_centers = np.linspace(mlo, mhi, n_mels + 2) ms = (mhi - mlo) / (n_mels + 1) matrix = np.zeros((n_mels, bank_width), dtype=np.float32) for m in range(n_mels): left = m_centers[m] center = m_centers[m + 1] right = m_centers[m + 2] for fft_bin in range(klo, khi): mbin = bin2mel(fft_bin) if left < mbin < right: matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms return matrix # Pre-computed SpeechLib mel filterbank (transposed for dot product) _SPEECHLIB_MEL_MATRIX = None def _get_speechlib_mel(): """Lazy-init cached SpeechLib mel matrix [257, 80] for dot product.""" global _SPEECHLIB_MEL_MATRIX if _SPEECHLIB_MEL_MATRIX is None: import numpy as np _SPEECHLIB_MEL_MATRIX = _speechlib_mel(16000, 512, 80, fmin=0.0, fmax=7690.0).T return _SPEECHLIB_MEL_MATRIX def preprocess_audio_for_phi4( audio_path: str, sample_rate: int = 16000, n_mels: int = 80, n_fft: int = 512, win_length: int = 400, hop_length: int = 160, fmax: float = 7690.0, preemphasis: float = 0.97, device: Optional[str] = None, dtype: torch.dtype = torch.bfloat16, use_cuda_preproc: bool = True, ) -> Tuple[torch.Tensor, int]: """Load WAV and compute log-mel spectrogram for Phi-4 Conformer encoder. Exactly replicates SpeechLib feature extraction from Phi-4's processing_phi4mm.py: - Per-frame pre-emphasis with np.roll + scale by 32768 - Hamming window - FFT → magnitude → power - SpeechLib mel filterbank (natural-log mel, triangles in mel space) - clip(1.0) → ln() Args: audio_path: Path to WAV file sample_rate: Target sample rate (16000 Hz) n_mels: Number of mel filterbank bins (80) n_fft: FFT size (512) win_length: Window length in samples (400 = 25ms at 16kHz) hop_length: Hop length in samples (160 = 10ms at 16kHz) fmax: Max frequency for mel filterbank (7690 Hz) preemphasis: Pre-emphasis coefficient (0.97) device: Target device dtype: Target dtype Returns: mel_features: [1, T, 80] log-mel tensor num_frames: Number of time frames """ import soundfile as sf import numpy as np # Load audio audio, sr = sf.read(audio_path, dtype='float32') if audio.ndim > 1: audio = np.squeeze(audio) if len(audio.shape) == 2: audio = audio.mean(1) # Resample to 16kHz if needed (matching Phi-4's resampling logic) if sr > 16000: try: from scipy.signal import resample_poly audio = resample_poly(audio, 1, sr // 16000).astype(np.float32) except ImportError: ratio = sample_rate / sr new_len = int(len(audio) * ratio) audio = np.interp( np.linspace(0, len(audio) - 1, new_len), np.arange(len(audio)), audio ).astype(np.float32) sr = 16000 elif 8000 < sr < 16000: try: from scipy.signal import resample_poly audio = resample_poly(audio, 1, sr // 8000).astype(np.float32) except ImportError: pass sr = 8000 elif sr < 8000: raise ValueError(f"Unsupported sample rate {sr}") # 8kHz → 16kHz upsampling (SpeechLib fillzero method — handled by padding FFT) fs = sr if sr in (8000, 16000) else 16000 # Select window for sample rate if fs == 8000: n_fft_eff = 256 win_length_eff = 200 hop_length_eff = 80 fft_window = np.hamming(200).astype(np.float32) else: n_fft_eff = n_fft # 512 win_length_eff = win_length # 400 hop_length_eff = hop_length # 160 fft_window = np.hamming(win_length).astype(np.float32) # SpeechLib frame count: cut remaining samples insufficient for a hop num_frames = (len(audio) - win_length_eff) // hop_length_eff + 1 if num_frames <= 0: audio = np.pad(audio, (0, win_length_eff - len(audio))) num_frames = 1 # CUDA accelerated path (16kHz only — 8kHz fillzero stays on CPU) if (use_cuda_preproc and _PREPROC_CUDA_AVAILABLE and _fireecho_preproc is not None and torch.cuda.is_available() and fs == 16000): audio_t = torch.from_numpy(audio.copy()).float().cuda() window_t = torch.from_numpy(fft_window.copy()).float().cuda() mel_mat_np = _get_speechlib_mel() # [257, 80] mel_mat_t = torch.from_numpy(mel_mat_np.copy()).float().cuda() mel_features = _fireecho_preproc.cuda_audio_pipeline( audio_t, window_t, mel_mat_t, n_fft_eff, win_length_eff, hop_length_eff, preemphasis ) # mel_features: [T, 80] float32 on CUDA num_frames_cuda = mel_features.shape[0] mel_tensor = mel_features.unsqueeze(0).to(dtype=dtype) # [1, T, 80] if device is not None: mel_tensor = mel_tensor.to(device=device) return mel_tensor, num_frames_cuda # CPU fallback path (also handles 8kHz fillzero) # Frame extraction (matching SpeechLib exactly) y_frames = np.array( [audio[s: s + win_length_eff] for s in range(0, hop_length_eff * num_frames, hop_length_eff)], dtype=np.float32, ) # SpeechLib per-frame pre-emphasis: roll within each frame, scale by 32768 y_frames_prev = np.roll(y_frames, 1, axis=1) y_frames_prev[:, 0] = y_frames_prev[:, 1] y_frames = (y_frames - preemphasis * y_frames_prev) * 32768.0 # Windowed FFT → magnitude spectrum S = np.fft.rfft(fft_window * y_frames, n=n_fft_eff, axis=1).astype(np.complex64) # 8kHz fillzero: pad to look like 16kHz with zeros in 4-8kHz bins if fs == 8000: frames_n, bins_n = S.shape padarray = np.zeros((frames_n, bins_n), dtype=np.complex64) S = np.concatenate((S[:, 0:-1], padarray), axis=1) spec = np.abs(S).astype(np.float32) # Power spectrum + SpeechLib mel filterbank + log spec_power = spec ** 2 mel_matrix = _get_speechlib_mel() # [257, 80] fbank_power = np.clip(spec_power.dot(mel_matrix), 1.0, None) log_fbank = np.log(fbank_power).astype(np.float32) # Convert to tensor mel_tensor = torch.from_numpy(log_fbank).unsqueeze(0).to(dtype=dtype) # [1, T, 80] if device is not None: mel_tensor = mel_tensor.to(device=device) return mel_tensor, num_frames # ============================================================================ # AGI ENGINE - Main Interface # ============================================================================ class FireEchoEngine(nn.Module): """ Unified AGI Inference Engine v2. Complete production engine with: - Real paged KV cache (500k+ token capacity) - NVFP4 quantization on weights - Per-layer Hebbian memory - Multimodal fusion support - HuggingFace model loading """ def __init__(self, config: Optional[FireEchoConfig] = None): super().__init__() self.config = config or FireEchoConfig() # Auto-detect critical layers (first/last ~15% kept in higher precision) if not self.config.critical_layers and self.config.auto_critical_layers: self.config.critical_layers = self.config.compute_critical_layers() # Embedding self.embed = nn.Embedding(self.config.vocab_size, self.config.dim) # Transformer layers self.layers = nn.ModuleList([ FusedTransformerBlock(self.config, i) for i in range(self.config.num_layers) ]) # Output self.norm = nn.RMSNorm(self.config.dim) self.lm_head = nn.Linear(self.config.dim, self.config.vocab_size, bias=False) # Tied embeddings (share embed/lm_head weights) if self.config.tie_word_embeddings: self.lm_head.weight = self.embed.weight # Hebbian memory (per-layer or shared) — NHL + Floreano + STELLAR _hebb_kw = dict( tau_e=0.95, rare_correlation_pct=0.1, use_neuromodulation=True, temperature=self.config.hebbian_temperature, weight_radius=self.config.hebbian_weight_radius, use_soft_hebbian=self.config.hebbian_use_soft, use_learned_modulator=self.config.hebbian_use_learned_modulator, modulator_entropy_weight=self.config.hebbian_modulator_entropy_weight, tau_fast=self.config.hebbian_tau_fast, tau_slow=self.config.hebbian_tau_slow, use_three_factor=self.config.hebbian_use_three_factor, # Pattern separation (Surget & Belzung 2022) separation_strength=self.config.hebbian_separation_strength, separation_threshold=self.config.hebbian_separation_threshold, slot_activation_threshold=self.config.hebbian_slot_activation_threshold, # Synaptic competition (neurogenesis + HAG) competition_strength=self.config.hebbian_competition_strength, slot_recycle_after=self.config.hebbian_slot_recycle_after, # Phase 4: Research-backed tuning (Duan ICLR 2023, ENN 2025) max_update_norm=self.config.hebbian_max_update_norm, trace_clip=self.config.hebbian_trace_clip, weight_clip=self.config.hebbian_weight_clip, noise_scale=self.config.hebbian_noise_scale, # Phase 4b: BCPNN adaptive slot lr + homeostatic thresholds adaptive_slot_lr=self.config.hebbian_adaptive_slot_lr, tau_age=self.config.hebbian_tau_age, importance_scale=self.config.hebbian_importance_scale, homeostatic_threshold=self.config.hebbian_homeostatic_threshold, threshold_incr=self.config.hebbian_threshold_incr, threshold_decr=self.config.hebbian_threshold_decr, # Phase 4d: Cosine retrieval cosine_retrieval=self.config.hebbian_cosine_retrieval, retrieval_tau=self.config.hebbian_retrieval_tau, sparsity_lambda=self.config.hebbian_sparsity_lambda, # Phase 4e-g: Advanced multi_timescale=self.config.hebbian_multi_timescale, working_memory_ratio=self.config.hebbian_working_memory_ratio, structural_plasticity=self.config.hebbian_structural_plasticity, merge_threshold=self.config.hebbian_merge_threshold, use_trace_filter=self.config.hebbian_use_trace_filter, # Phase 5: MESU + Bayesian reward use_mesu=self.config.hebbian_mesu, mesu_sigma_prior=self.config.hebbian_mesu_sigma_prior, mesu_sigma_res=self.config.hebbian_mesu_sigma_res, use_bayesian_reward=self.config.hebbian_bayesian_reward, reward_prior_mean=self.config.hebbian_reward_prior_mean, reward_prior_var=self.config.hebbian_reward_prior_var, # Identity-init for projections identity_init=self.config.hebbian_identity_init, # Memory consolidation use_consolidation=self.config.hebbian_consolidation, consolidation_interval=self.config.hebbian_consolidation_interval, consolidation_threshold=self.config.hebbian_consolidation_threshold, consolidated_decay=self.config.hebbian_consolidated_decay, consolidation_ratio=self.config.hebbian_consolidation_ratio, # Adaptive transfer filtering adaptive_transfer=self.config.hebbian_adaptive_transfer, transfer_ema_decay=self.config.hebbian_transfer_ema_decay, transfer_demotion=self.config.hebbian_transfer_demotion, transfer_demotion_rate=self.config.hebbian_transfer_demotion_rate, # KG consolidation gate kg_consolidation_gate=self.config.hebbian_kg_consolidation_gate, # Layer 4: Intrinsic reward use_intrinsic_reward=self.config.hebbian_intrinsic_reward, curiosity_ema_decay=self.config.hebbian_curiosity_ema_decay, competence_ema_decay=self.config.hebbian_competence_ema_decay, # Layer 5: SPAR error detection threshold error_threshold=self.config.hebbian_spar_error_threshold, # FE-MX compression use_femx=self.config.hebbian_use_femx, ) if self.config.use_hebbian: if self.config.hebbian_per_layer: self.hebbian = PerLayerHebbian( self.config.dim, self.config.num_layers, self.config.hebbian_memory_size, self.config.hebbian_lr, self.config.hebbian_decay, use_layer_specialization=self.config.hebbian_use_layer_specialization, **_hebb_kw ) else: self.hebbian = HebbianMemory( self.config.dim, self.config.hebbian_memory_size, self.config.hebbian_lr, self.config.hebbian_decay, **_hebb_kw ) else: self.hebbian = None # torch.compile on Hebbian paths for 20-40% speedup (Phase 4) if self.hebbian is not None and self.config.hebbian_compile: try: if isinstance(self.hebbian, PerLayerHebbian): for mem in self.hebbian.memories: mem.forward = torch.compile(mem.forward, mode="reduce-overhead") else: self.hebbian.forward = torch.compile( self.hebbian.forward, mode="reduce-overhead" ) except Exception as e: import warnings warnings.warn(f"torch.compile on Hebbian failed (non-fatal): {e}") # Learnable per-layer Hebbian residual alpha (Phase 4c) # Controls how much Hebbian output contributes vs base model per layer. # sigmoid(alpha) bounds the mixing weight to (0, 1). Starts at sigmoid(0.1) ≈ 0.52. self.hebbian_alpha = None if self.hebbian is not None and self.config.hebbian_learnable_alpha: if self.config.hebbian_per_layer: # Per-layer: each layer gets its own alpha self.hebbian_alpha = nn.ParameterList([ nn.Parameter(torch.tensor(0.1)) for _ in range(self.config.num_layers) ]) else: # Shared: single alpha for the one Hebbian pass self.hebbian_alpha = nn.Parameter(torch.tensor(0.1)) # Persistent memory manager (enabled via enable_persistent_memory()) self._persistent_memory = None # Multimodal (optional) if self.config.use_vision or self.config.use_audio: _vram_before = (torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0) self.multimodal = MultimodalFusion( config=self.config, use_vision=self.config.use_vision, use_audio=self.config.use_audio, ) if torch.cuda.is_available(): _vram_after = torch.cuda.memory_allocated() / 1e9 _vram_delta = _vram_after - _vram_before if _vram_delta > 0.01: print(f" Multimodal VRAM: {_vram_delta:.2f} GB " f"(vision={'ON' if self.config.use_vision else 'OFF'}, " f"audio={'ON' if self.config.use_audio else 'OFF'})") else: self.multimodal = None # KV Cache (real paging) self.kv_cache: Optional[PagedKVCache] = None self._init_kv_cache() # L2 cache manager — prefer CUTLASS wrapper, fall back to built-in # hardware-backed L2CacheManager (uses cudaAccessPolicyWindow) self._l2_manager = None if self.config.use_l2_cache_control: if self.config.use_native_cutlass and _CutlassL2CacheManager is not None: try: self._l2_manager = _CutlassL2CacheManager() except Exception: self._l2_manager = None if self._l2_manager is None: try: self._l2_manager = L2CacheManager( l2_size_mb=self.config.l2_cache_mb, ) except Exception: self._l2_manager = None # Generation state self._current_seq_id = 0 self._current_position = 0 # Convert to compute dtype self.to(self.config.compute_dtype) def _init_kv_cache(self, device: str = 'cuda'): """Initialize KV cache.""" head_dim = self.config.head_dim or (self.config.dim // self.config.num_heads) num_layers = self.config.num_layers or len(self.layers) self.kv_cache = PagedKVCache( self.config.max_kv_blocks, self.config.kv_block_size, num_layers, self.config.num_kv_heads, head_dim, self.config.compute_dtype, device, ) def cuda(self, device=None): """Move to CUDA and reinitialize KV cache. Pin hot weights in L2.""" result = super().cuda(device) self._init_kv_cache('cuda') if self._l2_manager is not None and next(self.parameters()).is_cuda: try: if isinstance(self._l2_manager, L2CacheManager): # Built-in hardware-backed manager (named pin API) self._l2_manager.pin("embed", self.embed.weight, hit_ratio=0.9) else: # CUTLASS manager (positional pin API) self._l2_manager.pin(self.embed.weight, hit_ratio=0.9) except Exception: pass return result def forward(self, input_ids: torch.Tensor, use_cache: bool = False, images: Optional[torch.Tensor] = None, audio: Optional[torch.Tensor] = None, position: int = 0, reward: Optional[float] = None) -> torch.Tensor: """ Forward pass through the engine. Args: input_ids: Token IDs [batch, seq_len] use_cache: Whether to use KV cache images: Optional images [batch, C, H, W] audio: Optional audio [batch, samples] position: Starting position for KV cache reward: Optional reward signal for Hebbian memory (STELLAR Eq. 23) Returns: logits: Output logits [batch, seq_len, vocab_size] """ B, S = input_ids.shape # Embed tokens x = self.embed(input_ids) # Multimodal fusion (masked token replacement) if self.multimodal is not None and (images is not None or audio is not None): x = self.multimodal.encode_and_fuse(input_ids, x, images=images, audio=audio) # Transformer layers for i, layer in enumerate(self.layers): # L2 prefetch: start loading layer i+1 weights while layer i runs if getattr(self, '_l2_prefetch_enabled', False) and i + 1 < len(self.layers): self._prefetch_layer_weights(i + 1) x = layer(x, self.kv_cache, self._current_seq_id, position, use_cache) # EAGLE-3: capture hidden states at selected layers if getattr(self, '_eagle_enabled', False) and i in self._eagle_capture_set: self._eagle_hidden_states[i] = x # Per-layer Hebbian (if enabled) if self.config.use_hebbian and self.config.hebbian_per_layer and self.hebbian is not None: if self.hebbian_alpha is not None: # Learnable alpha: x_out = x + sigmoid(alpha) * (hebbian(x) - x) x_hebb = self.hebbian(x, i, update=self.training, reward=reward) alpha = torch.sigmoid(self.hebbian_alpha[i]) x = x + alpha * (x_hebb - x) else: x = self.hebbian(x, i, update=self.training, reward=reward) # Shared Hebbian (if not per-layer) if self.config.use_hebbian and not self.config.hebbian_per_layer and self.hebbian is not None: if self.hebbian_alpha is not None: x_hebb = self.hebbian(x, update=self.training, reward=reward) alpha = torch.sigmoid(self.hebbian_alpha) x = x + alpha * (x_hebb - x) else: x = self.hebbian(x, update=self.training, reward=reward) x = self.norm(x) logits = self.lm_head(x) return logits @torch.no_grad() def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 1.0, top_k: int = 50, top_p: float = 0.9, images: Optional[torch.Tensor] = None, audio: Optional[torch.Tensor] = None, use_cache: bool = True, callback: Optional[callable] = None, preserve_hebbian: bool = False, stop_tokens: Optional[list] = None) -> torch.Tensor: """ Autoregressive generation with KV caching. Args: input_ids: Prompt token IDs [batch, seq_len] max_new_tokens: Maximum tokens to generate temperature: Sampling temperature top_k: Top-k sampling (0 to disable) top_p: Nucleus sampling threshold images: Optional images for multimodal audio: Optional audio for multimodal use_cache: Use KV cache for efficiency callback: Optional callback(token_id, position) called after each token preserve_hebbian: If True, preserve Hebbian fast weights across generation Returns: Generated token IDs [batch, seq_len + max_new_tokens] """ self.eval() B = input_ids.shape[0] prompt_len = input_ids.shape[1] # Save Hebbian state before reset if preserving hebb_state = None if preserve_hebbian and self.hebbian is not None: hebb_state = self.hebbian.save_state() # Reset cache for new generation self.reset_cache() self._current_seq_id = 0 # Restore Hebbian state after cache reset if hebb_state is not None: self.hebbian.load_state(hebb_state) # Prefill: process entire prompt at once (graph mode OFF for prefill) if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False logits = self.forward(input_ids, use_cache=use_cache, images=images, audio=audio, position=0) current_pos = prompt_len # Default stop tokens: <|endoftext|> (199999) and <|end|> (200020) if stop_tokens is None: stop_tokens = [199999, 200020] stop_set = set(stop_tokens) if stop_tokens else set() # Check if CUDA graph decode is available _use_graph = (getattr(self, '_graph_enabled', False) and use_cache and B == 1 and prompt_len < self._graph_max_seq_len) if _use_graph: # Unmask prompt positions in graph attention mask self.kv_cache._graph_attn_mask[0, 0, 0, :prompt_len] = 0 # Capture graph if not already done if not self._graph_captured: self._capture_decode_graph(current_pos) else: self.kv_cache._graph_mode = True generated = input_ids for i in range(max_new_tokens): next_logits = logits[:, -1, :] / max(temperature, 1e-8) # Top-k filtering if top_k > 0: top_k_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1))) next_logits[next_logits < top_k_vals[:, [-1]]] = float('-inf') # Top-p (nucleus) filtering if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() sorted_indices_to_remove[:, 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) next_logits[indices_to_remove] = float('-inf') # Sample probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated = torch.cat([generated, next_token], dim=1) # Stop on EOS / end-of-turn tokens if B == 1 and next_token.item() in stop_set: break # Callback if callback is not None: callback(next_token.item() if B == 1 else next_token, current_pos) # Decode: CUDA graph replay or standard forward if _use_graph and current_pos < self._graph_max_seq_len: logits = self._graph_decode_step(next_token, current_pos) else: logits = self.forward(next_token, use_cache=use_cache, position=current_pos) current_pos += 1 # Auto-log episode to persistent memory (if enabled) if self._persistent_memory is not None: gen_len = generated.shape[1] - prompt_len # Use last logits as a cheap embedding proxy (mean of last hidden state # is not available here, so we use the logits softmax entropy as a signal) embedding = logits[:, -1, :].detach().float().mean(dim=0) # [V] → scalar-ish # Better: use cached hidden states if EAGLE capture is active if hasattr(self, '_eagle_hidden_states') and self._eagle_hidden_states: last_key = max(self._eagle_hidden_states.keys()) embedding = self._eagle_hidden_states[last_key].detach().float().mean(dim=(0, 1)) try: self._persistent_memory.log_episode( prompt_ids=input_ids, gen_len=gen_len, embedding=embedding) except Exception: pass # Never let memory logging crash generation return generated # ================================================================= # FireEcho Safety Module — Uncertainty Preservation & Domain Protection # ================================================================= # Ensures model honesty by: # 1. Never forcing confident answers when model is uncertain # 2. Protecting critical domain experts from aggressive compression # 3. Memory provides CONTEXT, not forced output # 4. Verification is never skipped # Uncertainty markers that should NEVER be suppressed UNCERTAINTY_MARKERS = [ "I don't know", "I'm not sure", "I'm uncertain", "I cannot determine", "I don't have enough information", "I'm not confident", "It's unclear", "I cannot say for certain", "I may be wrong", "This is uncertain", "I lack the knowledge", "I cannot verify", "I'm not aware", "To my knowledge", "I believe, but", "I think, but I'm not certain", "I cannot confirm", "I don't have access to", ] # Protected domains — experts handling these topics stay FP4/FP8 PROTECTED_DOMAINS = { 'medical': ['health', 'disease', 'medication', 'symptom', 'diagnosis', 'treatment', 'surgery', 'drug', 'dose', 'allergy'], 'legal': ['law', 'legal', 'court', 'contract', 'liability', 'rights', 'criminal', 'civil', 'attorney', 'lawsuit'], 'safety': ['danger', 'hazard', 'emergency', 'poison', 'toxic', 'warning', 'risk', 'fatal', 'lethal', 'injury'], 'financial': ['investment', 'tax', 'fraud', 'securities', 'audit', 'fiduciary', 'bankruptcy', 'liability', 'insurance'], } def _contains_uncertainty(self, text: str) -> bool: """Check if text contains uncertainty markers that should be preserved.""" text_lower = text.lower() for marker in self.UNCERTAINTY_MARKERS: if marker.lower() in text_lower: return True return False def _get_uncertainty_score(self, logits: torch.Tensor) -> float: """Compute uncertainty score from logits entropy. High entropy = model is uncertain about next token Returns value in [0, 1] where 1 = maximum uncertainty """ probs = F.softmax(logits[:, -1, :], dim=-1) entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1) # Normalize by max possible entropy (log vocab_size) max_entropy = torch.log(torch.tensor(logits.shape[-1], dtype=torch.float)) return (entropy / max_entropy).mean().item() def enable_honesty_mode(self, preserve_uncertainty: bool = True, protect_critical_domains: bool = True, uncertainty_threshold: float = 0.7): """Enable honesty safeguards for responsible AI generation. When enabled: - Model uncertainty is preserved, never suppressed - Critical domain experts are protected from compression - Memory provides context but doesn't force answers - High-uncertainty outputs trigger verification Args: preserve_uncertainty: Keep uncertainty markers in output protect_critical_domains: Prevent INT2 compression on medical/legal/etc uncertainty_threshold: Entropy threshold for uncertainty warning (0-1) """ self._honesty_mode = True self._preserve_uncertainty = preserve_uncertainty self._protect_critical_domains = protect_critical_domains self._uncertainty_threshold = uncertainty_threshold if protect_critical_domains: self._apply_domain_protection() print(f" [Honesty Mode] Enabled") print(f" - Uncertainty preservation: {preserve_uncertainty}") print(f" - Critical domain protection: {protect_critical_domains}") print(f" - Uncertainty threshold: {uncertainty_threshold}") def _apply_domain_protection(self): """Mark experts handling critical domains as protected from INT2 demotion.""" if not self.config.use_moe: return protected_count = 0 for i, layer in enumerate(self.layers): if hasattr(layer.ffn, '_protected_experts'): continue # Already protected # Initialize protection mask n_experts = layer.ffn.num_experts layer.ffn._protected_experts = torch.zeros(n_experts, dtype=torch.bool, device='cuda') # For now, protect first and last 10% of experts (typically handle # specialized knowledge). More sophisticated: track which experts # activate on domain-specific prompts during warmup. protect_count = max(1, n_experts // 10) layer.ffn._protected_experts[:protect_count] = True layer.ffn._protected_experts[-protect_count:] = True protected_count += protect_count * 2 if protected_count > 0: print(f" [Domain Protection] {protected_count} experts protected across {len(self.layers)} layers") @torch.no_grad() def generate_honest(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 1.0, top_k: int = 50, top_p: float = 0.9, memory_context: Optional[torch.Tensor] = None, use_cache: bool = True, stop_tokens: Optional[list] = None, tokenizer = None) -> dict: """Generate with honesty safeguards. Unlike standard generate(), this method: 1. Never suppresses uncertainty in model output 2. Tracks and reports confidence levels 3. Uses memory as CONTEXT, not forced answers 4. Returns metadata about generation confidence Args: input_ids: Input token IDs [1, seq_len] max_new_tokens: Maximum tokens to generate temperature: Sampling temperature top_k: Top-k sampling top_p: Nucleus sampling threshold memory_context: Optional Hebbian memory context (provides hints, not answers) use_cache: Use KV cache stop_tokens: Tokens that stop generation tokenizer: Optional tokenizer for uncertainty detection in text Returns: dict with: 'output': Generated tensor [1, seq_len + new_tokens] 'confidence': Average confidence score (0-1, higher = more confident) 'uncertainty_detected': Whether model expressed uncertainty 'high_uncertainty_positions': Token positions with high entropy """ if not getattr(self, '_honesty_mode', False): # Fall back to standard generation if honesty mode not enabled output = self.generate(input_ids, max_new_tokens, temperature, top_k, top_p, use_cache=use_cache, stop_tokens=stop_tokens) return {'output': output, 'confidence': None, 'uncertainty_detected': False, 'high_uncertainty_positions': []} self.eval() device = input_ids.device generated = input_ids.clone() # Tracking entropy_scores = [] high_uncertainty_positions = [] threshold = getattr(self, '_uncertainty_threshold', 0.7) # Prefill logits = self.forward(input_ids, use_cache=use_cache) current_pos = input_ids.shape[1] stop_set = set(stop_tokens or []) for i in range(max_new_tokens): # Get next token logits next_logits = logits[:, -1, :] / max(temperature, 1e-5) # Track uncertainty (entropy) uncertainty = self._get_uncertainty_score(logits) entropy_scores.append(uncertainty) if uncertainty > threshold: high_uncertainty_positions.append(current_pos + i) # Standard sampling (top-k, top-p) if top_k > 0: v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1))) next_logits[next_logits < v[:, [-1]]] = float('-inf') if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() sorted_indices_to_remove[:, 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) next_logits[indices_to_remove] = float('-inf') # Sample probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated = torch.cat([generated, next_token], dim=1) # Stop check if next_token.item() in stop_set: break # Next forward logits = self.forward(next_token, use_cache=use_cache, position=current_pos) current_pos += 1 # Check for uncertainty markers in generated text uncertainty_detected = False if tokenizer is not None: generated_text = tokenizer.decode(generated[0, input_ids.shape[1]:], skip_special_tokens=True) uncertainty_detected = self._contains_uncertainty(generated_text) # Compute average confidence (inverse of entropy) avg_entropy = sum(entropy_scores) / len(entropy_scores) if entropy_scores else 0 confidence = 1.0 - avg_entropy return { 'output': generated, 'confidence': confidence, 'uncertainty_detected': uncertainty_detected, 'high_uncertainty_positions': high_uncertainty_positions, 'avg_entropy': avg_entropy, } # ================================================================= # Ensemble Generation — Multiple Predictions for Safer Output # ================================================================= # Based on Seni & Elder "Ensemble Methods in Data Mining" (2010) # Key insight: "In a multitude of counselors there is safety" (Proverbs 24:6b) # # Implements: # 1. Temperature Diversity Ensemble (Bagging-style) # 2. Majority Voting for Token Selection # 3. Consensus Detection (agreement = confidence) # 4. Disagreement Flagging (hallucination detector) @torch.no_grad() def generate_ensemble(self, input_ids: torch.Tensor, max_new_tokens: int = 100, num_samples: int = 3, temperatures: list = None, top_k: int = 50, top_p: float = 0.9, consensus_threshold: float = 0.5, use_cache: bool = True, stop_tokens: Optional[list] = None, tokenizer = None) -> dict: """Generate using ensemble of samples for safer, more reliable output. Generates multiple outputs with temperature diversity, then combines via majority voting. Disagreement between samples indicates potential hallucination or uncertainty — these positions are flagged. Based on Seni & Elder (2010): "Ensemble Methods in Data Mining" Key principle: Diverse models + voting = reduced variance + error detection Args: input_ids: Input token IDs [1, seq_len] max_new_tokens: Maximum tokens to generate num_samples: Number of diverse samples to generate (default 3) temperatures: List of temperatures for diversity (default [0.3, 0.7, 1.0]) top_k: Top-k sampling parameter top_p: Nucleus sampling threshold consensus_threshold: Minimum agreement ratio to accept token (default 0.5) use_cache: Use KV caching stop_tokens: Tokens that stop generation tokenizer: Optional tokenizer for text output Returns: dict with: 'output': Consensus output tensor 'samples': List of all generated samples 'consensus_scores': Per-token agreement ratios 'disagreement_positions': Positions where samples disagreed 'confidence': Overall consensus confidence 'hallucination_risk': Estimated risk based on disagreement """ if temperatures is None: # Temperature diversity: cold (precise), medium, warm (creative) temperatures = [0.3, 0.7, 1.0][:num_samples] while len(temperatures) < num_samples: temperatures.append(0.7 + 0.3 * len(temperatures) / num_samples) self.eval() device = input_ids.device # Generate diverse samples samples = [] for temp in temperatures[:num_samples]: self.reset_cache() # Fresh KV cache for each sample sample = self.generate( input_ids.clone(), max_new_tokens=max_new_tokens, temperature=temp, top_k=top_k, top_p=top_p, use_cache=use_cache, stop_tokens=stop_tokens ) samples.append(sample) # Find minimum length for alignment min_len = min(s.shape[1] for s in samples) prompt_len = input_ids.shape[1] # Majority voting for consensus output consensus_tokens = [] consensus_scores = [] disagreement_positions = [] for pos in range(prompt_len, min_len): # Collect tokens at this position from all samples tokens_at_pos = [s[0, pos].item() for s in samples] # Count votes from collections import Counter vote_counts = Counter(tokens_at_pos) winner_token, winner_count = vote_counts.most_common(1)[0] # Consensus score = agreement ratio agreement = winner_count / len(tokens_at_pos) consensus_scores.append(agreement) if agreement < consensus_threshold: disagreement_positions.append(pos - prompt_len) consensus_tokens.append(winner_token) # Build consensus output consensus_output = torch.cat([ input_ids, torch.tensor([consensus_tokens], device=device, dtype=torch.long) ], dim=1) # Calculate overall confidence and hallucination risk avg_consensus = sum(consensus_scores) / len(consensus_scores) if consensus_scores else 1.0 disagreement_ratio = len(disagreement_positions) / len(consensus_scores) if consensus_scores else 0.0 # Hallucination risk: high disagreement = high risk hallucination_risk = min(1.0, disagreement_ratio * 2) # Scale: 50% disagreement = 100% risk return { 'output': consensus_output, 'samples': samples, 'consensus_scores': consensus_scores, 'disagreement_positions': disagreement_positions, 'confidence': avg_consensus, 'hallucination_risk': hallucination_risk, 'num_samples': num_samples, 'temperatures_used': temperatures[:num_samples], } @torch.no_grad() def generate_safe(self, input_ids: torch.Tensor, max_new_tokens: int = 100, safety_level: str = 'medium', use_cache: bool = True, stop_tokens: Optional[list] = None, tokenizer = None) -> dict: """Generate with automatic safety level selection. Convenience method that selects appropriate ensemble/verification based on desired safety level. Args: input_ids: Input token IDs max_new_tokens: Maximum tokens to generate safety_level: 'low' (fast), 'medium' (balanced), 'high' (safest) use_cache: Use KV caching stop_tokens: Stop tokens tokenizer: Optional tokenizer Safety Levels: 'low': Single generation, basic uncertainty tracking 'medium': 3-sample ensemble, majority voting 'high': 5-sample ensemble, strict consensus, hallucination flagging Returns: dict with output, confidence, and safety metadata """ if safety_level == 'low': # Fast path: single generation with uncertainty tracking return self.generate_honest( input_ids, max_new_tokens=max_new_tokens, temperature=0.7, use_cache=use_cache, stop_tokens=stop_tokens, tokenizer=tokenizer ) elif safety_level == 'medium': # Balanced: 3-sample ensemble result = self.generate_ensemble( input_ids, max_new_tokens=max_new_tokens, num_samples=3, temperatures=[0.3, 0.7, 1.0], consensus_threshold=0.5, use_cache=use_cache, stop_tokens=stop_tokens, tokenizer=tokenizer ) result['safety_level'] = 'medium' return result elif safety_level == 'high': # Maximum safety: 5-sample ensemble with strict consensus result = self.generate_ensemble( input_ids, max_new_tokens=max_new_tokens, num_samples=5, temperatures=[0.1, 0.3, 0.5, 0.7, 1.0], consensus_threshold=0.6, # Stricter agreement required use_cache=use_cache, stop_tokens=stop_tokens, tokenizer=tokenizer ) result['safety_level'] = 'high' # Additional check: if hallucination risk > 30%, flag output if result['hallucination_risk'] > 0.3: result['safety_warning'] = ( f"HIGH HALLUCINATION RISK ({result['hallucination_risk']*100:.0f}%): " f"Samples disagreed at {len(result['disagreement_positions'])} positions. " "Verify this output independently." ) return result else: raise ValueError(f"Unknown safety_level: {safety_level}. Use 'low', 'medium', or 'high'.") def _token_level_voting(self, logits_list: List[torch.Tensor]) -> torch.Tensor: """Combine multiple logit distributions via averaging (soft voting). Seni & Elder: Averaging predictions reduces variance while preserving the signal. This is more robust than hard voting for continuous outputs. Args: logits_list: List of logit tensors [1, 1, vocab_size] Returns: Averaged logits [1, 1, vocab_size] """ # Stack and average (soft voting) stacked = torch.stack(logits_list, dim=0) # [N, 1, 1, V] averaged = stacked.mean(dim=0) # [1, 1, V] return averaged def _detect_hallucination_risk(self, samples: List[torch.Tensor], start_pos: int) -> Tuple[float, List[int]]: """Detect potential hallucination by measuring sample disagreement. High disagreement between diverse samples suggests the model is uncertain or confabulating — a hallucination risk indicator. Args: samples: List of generated token sequences start_pos: Position where generation started (prompt length) Returns: (risk_score, disagreement_positions) risk_score: 0.0 (all agree) to 1.0 (total disagreement) disagreement_positions: List of token positions with disagreement """ min_len = min(s.shape[1] for s in samples) disagreement_positions = [] total_positions = max(1, min_len - start_pos) for pos in range(start_pos, min_len): tokens = set(s[0, pos].item() for s in samples) if len(tokens) > 1: # Disagreement disagreement_positions.append(pos - start_pos) risk_score = len(disagreement_positions) / total_positions return risk_score, disagreement_positions # ================================================================= # Confidence-Based Routing # ================================================================= # Routes queries through optimal generation path based on Hebbian # memory confidence. High-confidence → speculative (fast). # Low-confidence → full verify (accurate). Never skips verification. @torch.no_grad() def generate_with_confidence_routing( self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 0.0, confidence_threshold: float = 0.75, stop_tokens: Optional[list] = None, callback: Optional[callable] = None, ) -> dict: """Generate with automatic routing based on Hebbian memory confidence. Routes through the optimal generation path: - High confidence (Hebbian has strong match) → speculative_generate() for maximum speed. Memory grounds the draft head for higher acceptance. - Low confidence (novel query) → standard generate() with full forward. No speculation overhead, every token fully computed. SAFETY: Both paths produce identical results at temperature=0. Speculation NEVER skips verification — rejected drafts are discarded. Memory provides CONTEXT, not forced output. Args: input_ids: Input token IDs [1, seq_len] max_new_tokens: Maximum tokens to generate temperature: Sampling temperature (0 = greedy) confidence_threshold: Memory confidence cutoff for speculation (0-1) stop_tokens: List of stop token IDs callback: Optional callback(token_id, position) Returns: dict with: 'output': Generated token tensor [1, seq_len + generated] 'route': 'speculative' or 'standard' 'confidence': Hebbian memory confidence score (0-1) 'tok_s': Tokens per second (approximate) """ self.eval() import time as _time # Step 1: Probe Hebbian memory for confidence on this query confidence = 0.0 route = 'standard' if self.hebbian is not None: # Run a quick prefill to get hidden states self.reset_cache() self._current_seq_id = 0 if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False with torch.no_grad(): logits = self.forward(input_ids, use_cache=True, position=0) # Query Hebbian memory with the final hidden state if isinstance(self.hebbian, PerLayerHebbian): last_layer = min(len(self.hebbian.memories) - 1, len(self.layers) - 1) mem = self.hebbian.memories[last_layer] else: mem = self.hebbian # Compute confidence: how strongly does the query match stored memory? # High attention entropy → low confidence (memory unsure) # Low attention entropy → high confidence (memory has strong match) last_hidden = logits # Use logits as proxy (we need the hidden state) # Actually use the layer output directly if hasattr(self, '_eagle_hidden_states') and self._eagle_hidden_states: last_key = max(self._eagle_hidden_states.keys()) query = self._eagle_hidden_states[last_key][:, -1:, :] else: # Fallback: use embedding of last token as query query = self.embed(input_ids[:, -1:]) q = mem.query_proj(query.squeeze(0)) # [1, dim] effective_w = mem.fast_weight if mem.use_consolidation: effective_w = mem.fast_weight + mem.consolidated_weight if mem.cosine_retrieval: q_norm = F.normalize(q, dim=-1) w_norm = F.normalize(effective_w, dim=-1) scores = torch.matmul(q_norm, w_norm.t()) else: scores = torch.matmul(q, effective_w.t()) / (mem.dim ** 0.5) attn = F.softmax(scores, dim=-1) # Confidence = max attention weight (how concentrated the retrieval is) confidence = attn.max().item() # Decide route if confidence >= confidence_threshold: route = 'speculative' # Step 2: Generate using chosen route self.reset_cache() # Fresh start for actual generation t0 = _time.perf_counter() if route == 'speculative' and getattr(self, '_eagle_enabled', False): output = self.speculative_generate( input_ids, max_new_tokens=max_new_tokens, temperature=temperature, stop_tokens=stop_tokens, callback=callback) else: output = self.generate( input_ids, max_new_tokens=max_new_tokens, temperature=max(temperature, 1e-8), top_k=0, top_p=1.0, stop_tokens=stop_tokens, callback=callback) elapsed = _time.perf_counter() - t0 gen_len = output.shape[1] - input_ids.shape[1] tok_s = gen_len / elapsed if elapsed > 0 else 0 return { 'output': output, 'route': route, 'confidence': confidence, 'tok_s': tok_s, 'gen_len': gen_len, 'elapsed': elapsed, } def enable_l2_prefetch(self): """Enable L2 layer-ahead prefetch for MoE layers. Creates a prefetch stream and uses L2 cache policies to keep next layer's weights hot. During forward, layer N+1's packed MoE buffer gets L2-persistent policy while layer N executes. """ if not hasattr(self, '_prefetch_stream'): self._prefetch_stream = torch.cuda.Stream() self._l2_prefetch_enabled = True self._last_prefetch_idx = -1 print(" [L2 Prefetch] Enabled with async stream") def _prefetch_layer_weights(self, layer_idx: int): """Prefetch layer's MoE weights to L2 using async stream + L2 policy. Called before layer_idx's forward to bring weights closer to compute. Uses cudaAccessPolicyWindow to mark data as L2-persistent. """ if not getattr(self, '_l2_prefetch_enabled', False): return if layer_idx >= len(self.layers): return if layer_idx == getattr(self, '_last_prefetch_idx', -1): return # Already prefetched layer = self.layers[layer_idx] ffn = getattr(layer, 'ffn', None) if ffn is None: return # Get packed MoE buffers packed_gate_up = getattr(ffn, '_packed_gate_up_buffer', None) packed_down = getattr(ffn, '_packed_down_buffer', None) # Apply L2 persistent policy via L2CacheManager if self._l2_manager is not None and packed_gate_up is not None: try: # Pin gate_up buffer (usually larger) if isinstance(self._l2_manager, L2CacheManager): # Unpin previous layer's data first self._l2_manager.unpin(f"moe_gate_up_{layer_idx - 1}") self._l2_manager.unpin(f"moe_down_{layer_idx - 1}") # Pin current layer self._l2_manager.pin(f"moe_gate_up_{layer_idx}", packed_gate_up, hit_ratio=0.9, stream=self._prefetch_stream) if packed_down is not None: self._l2_manager.pin(f"moe_down_{layer_idx}", packed_down, hit_ratio=0.9, stream=self._prefetch_stream) except Exception: pass # L2 pinning is best-effort self._last_prefetch_idx = layer_idx def pack_all_experts(self): """Pack all MoE layers' expert weights into contiguous buffers. Call after loading to enable the packed MoE decode path: - Zero .item() calls (eliminates 384 CPU-GPU syncs per token) - Zero Python weight-collection loops - CUDA-graph-capturable decode path Typical speedup: 10-30% from eliminated Python overhead alone. """ packed_count = 0 for i, layer in enumerate(self.layers): if hasattr(layer.ffn, 'pack_experts'): layer.ffn.pack_experts() if layer.ffn._experts_packed: packed_count += 1 if packed_count > 0: print(f" [Packed MoE] {packed_count} layers packed " f"({packed_count * self.config.num_experts} experts → contiguous)") def enable_int2_cold_experts(self, cold_threshold_pct: float = 0.1, warmup_tokens: int = 100): """Enable INT2 quantization for cold (rarely-routed) experts. Cold experts are demoted from FP4 to INT2, saving 50% bandwidth with minimal quality impact (cold experts contribute <10% of token processing). This is the "coffee filter" optimization: compress rarely-used data more aggressively to reduce bandwidth, while keeping hot data at full precision. Args: cold_threshold_pct: Fraction of experts to mark as cold (default 10%) warmup_tokens: Required tokens of usage tracking before enabling (default 100) Call after pack_all_experts() and after generating some warmup tokens to establish expert usage patterns. Benefits: - 50% bandwidth reduction for cold experts (2x smaller weights) - ~15-30% overall speedup for MoE decode - Minimal quality impact (<1% perplexity increase) Example: engine.pack_all_experts() engine.generate("warmup prompt", max_new_tokens=100) # establish usage engine.enable_int2_cold_experts(cold_threshold_pct=0.1) """ if not self.config.use_moe: return int2_count = 0 total_cold = 0 for i, layer in enumerate(self.layers): if hasattr(layer.ffn, 'enable_int2_cold_experts'): # Check if layer has enough usage data total_usage = layer.ffn.expert_usage.sum().item() if total_usage < warmup_tokens: continue layer.ffn.enable_int2_cold_experts(cold_threshold_pct) if getattr(layer.ffn, '_int2_enabled', False): int2_count += 1 total_cold += layer.ffn._int2_cold_count if int2_count > 0: avg_cold = total_cold / int2_count print(f" [INT2 Cold] {int2_count} layers enabled, " f"avg {avg_cold:.1f} cold experts/layer ({cold_threshold_pct*100:.0f}% threshold)") print(f" [INT2 Cold] Estimated bandwidth savings: ~{avg_cold/128*50:.1f}% per MoE layer") def get_int2_status(self) -> str: """Get INT2 cold expert status across all layers.""" if not self.config.use_moe: return "INT2: N/A (not MoE)" enabled = 0 total_cold = 0 for layer in self.layers: if hasattr(layer.ffn, '_int2_enabled') and layer.ffn._int2_enabled: enabled += 1 total_cold += layer.ffn._int2_cold_count if enabled == 0: return "INT2 cold experts: disabled" return f"INT2 cold experts: {enabled} layers, {total_cold} total cold experts" def enable_auto_int2_demotion(self, cold_threshold_pct: float = 0.1): """Enable automatic INT2 demotion for cold experts during tier updates. When enabled, the age-adaptive tier system will automatically convert experts marked as cold (tier 0) to INT2 quantization during periodic tier updates. This provides ongoing bandwidth optimization as expert usage patterns stabilize. Args: cold_threshold_pct: Fraction of experts to mark as cold (default 10%) Benefits: - Automatic: no manual intervention needed - Dynamic: adapts to changing usage patterns - +20% bandwidth savings on cold experts (INT2 vs FP4) Usage: engine.pack_all_experts() engine.enable_auto_int2_demotion(cold_threshold_pct=0.1) # Cold experts auto-demote to INT2 during generation """ if not self.config.use_moe: return enabled_count = 0 for layer in self.layers: if hasattr(layer.ffn, 'enable_auto_int2_demotion'): layer.ffn.enable_auto_int2_demotion(cold_threshold_pct) enabled_count += 1 if enabled_count > 0: print(f" [Auto INT2] Enabled on {enabled_count} layers " f"(cold threshold: {cold_threshold_pct*100:.0f}%)") print(f" [Auto INT2] Cold experts will auto-demote to INT2 during tier updates") def enable_auto_fexc_demotion(self, cold_threshold_pct: float = 0.1): """Enable automatic FE-XC demotion for cold experts during tier updates. FE-XC (FireEcho Xtreme Compress) uses codebook-based 2-bit quantization with CodeGEMM psumbook acceleration. Much higher quality than INT2 at the same 2 bits/weight — near-FP16 quality through learned vector codebooks. Args: cold_threshold_pct: Fraction of experts to mark as cold (default 10%) Benefits: - Same 2-bit compression as INT2 (50% bandwidth vs FP4) - Much higher quality (codebook quantization vs scalar) - CodeGEMM psumbook: precompute once, reuse across all experts - Dynamic: adapts to changing usage patterns Usage: engine.pack_all_experts() engine.enable_auto_fexc_demotion(cold_threshold_pct=0.1) """ if not self.config.use_moe: return enabled_count = 0 for layer in self.layers: if hasattr(layer.ffn, 'enable_auto_fexc_demotion'): layer.ffn.enable_auto_fexc_demotion(cold_threshold_pct) enabled_count += 1 if enabled_count > 0: print(f" [FE-XC] Enabled on {enabled_count} layers " f"(cold threshold: {cold_threshold_pct*100:.0f}%)") print(f" [FE-XC] Cold experts will auto-demote to codebook 2-bit") def enable_auto_fexvq_demotion(self, cold_threshold_pct: float = 0.1): """Enable FE-XVQ (Hessian-weighted codebook) demotion for cold experts. FE-XVQ uses second-order information from calibration data to produce better codebooks than FE-XC. Same 2-bit storage + inference kernel. Pipeline: 1. Call this method (enables Hessian collection) 2. Run calibration prompts via generate() (accumulates Hessian) 3. Call trigger_fexvq_demotion() (learns codebooks + demotes) Args: cold_threshold_pct: Fraction of experts to mark as cold (default 10%) """ if not self.config.use_moe: return enabled_count = 0 for layer in self.layers: if hasattr(layer.ffn, 'enable_auto_fexvq_demotion'): layer.ffn.enable_auto_fexvq_demotion(cold_threshold_pct) enabled_count += 1 if enabled_count > 0: print(f" [FE-XVQ] Enabled on {enabled_count} layers " f"(cold threshold: {cold_threshold_pct*100:.0f}%)") print(f" [FE-XVQ] Collecting Hessian — run calibration prompts, then call trigger_fexvq_demotion()") def trigger_fexvq_demotion(self): """Trigger FE-XVQ demotion using collected Hessian data. Call after running calibration prompts through generate(). Each MoE layer will learn Hessian-weighted codebooks and demote cold experts. """ if not self.config.use_moe: return total_fexvq = 0 for i, layer in enumerate(self.layers): ffn = layer.ffn if hasattr(ffn, '_maybe_demote_to_fexvq'): ffn._maybe_demote_to_fexvq() n = ffn._expert_is_fexc.sum().item() if hasattr(ffn, '_expert_is_fexc') else 0 if n > 0: total_fexvq += n if i < 3 or i == len(self.layers) - 1: print(f" Layer {i}: {n} experts → FE-XVQ") # Stop Hessian collection after demotion ffn._hessian_collecting = False if total_fexvq > 0: print(f" [FE-XVQ] Total: {total_fexvq} experts demoted to Hessian-weighted codebook 2-bit") # ================================================================= # FE-AGK: Atlas the Gatekeeper — Engine-level controls # ================================================================= def enable_atlas(self, ban_threshold: float = 0.01, modes_threshold: float = 2.0): """Enable FE-AGK (Atlas the Gatekeeper) on all MoE layers. Ban & Pick: Profiles expert impact, bans low-impact experts. Reduces active experts from 8→~5 (1.25x throughput). MoDES: Skips MoE for easy tokens where max router prob < modes_threshold × uniform. Usage: engine.enable_atlas() engine.atlas_profile(prompts) # Profile expert impact engine.atlas_ban() # Ban low-impact experts # Now generate() uses fewer experts + skips easy tokens """ if not self.config.use_moe: return count = 0 for layer in self.layers: if hasattr(layer.ffn, 'enable_atlas'): layer.ffn.enable_atlas(ban_threshold, modes_threshold) count += 1 print(f" [FE-AGK] Atlas the Gatekeeper enabled on {count} layers") def atlas_profile(self, tokenizer, prompts: list = None, num_prompts: int = 100): """Profile expert impact across diverse prompts for Ban & Pick. Runs prompts through the model, accumulating |output| * weight per expert. After profiling, call atlas_ban() to ban low-impact ones. """ if prompts is None: prompts = [ "Explain quantum computing in simple terms.", "Write a Python function to sort a list.", "What is the meaning of life?", "Translate 'hello world' to Japanese.", "Describe the architecture of a modern CPU.", "Write a haiku about programming.", "How does a neural network learn?", "Explain the difference between TCP and UDP.", "What causes gravity?", "Write a recursive Fibonacci function.", ] * (num_prompts // 10 + 1) prompts = prompts[:num_prompts] # Start profiling on all layers for layer in self.layers: if hasattr(layer.ffn, 'atlas_start_profiling'): layer.ffn.atlas_start_profiling() print(f" [FE-AGK] Profiling {len(prompts)} prompts...") for i, prompt in enumerate(prompts): ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda") self.reset_cache() with torch.no_grad(): self.generate(ids, max_new_tokens=32, temperature=0.0, top_k=0, top_p=1.0) if (i + 1) % 20 == 0: print(f" {i+1}/{len(prompts)} profiled") print(f" [FE-AGK] Profiling complete.") def atlas_ban(self, ban_ratio: float = 0.25): """Ban low-impact experts based on profiling data. Args: ban_ratio: fraction of experts to ban (0.25 = bottom 25%) """ total_banned = 0 for i, layer in enumerate(self.layers): if hasattr(layer.ffn, 'atlas_finish_profiling'): layer.ffn.atlas_finish_profiling(ban_ratio) n = layer.ffn._atlas_banned.sum().item() layer.ffn._atlas_has_bans = (n > 0) total_banned += n avg = total_banned / max(len(self.layers), 1) print(f" [FE-AGK] Total: {total_banned} expert bans across " f"{len(self.layers)} layers (avg {avg:.1f}/layer)") effective_k = self.config.num_experts_per_tok * (1 - ban_ratio) print(f" [FE-AGK] Effective experts/token: ~{effective_k:.1f} " f"(was {self.config.num_experts_per_tok})") def atlas_stats(self): """Print Atlas diagnostics for all layers.""" for i, layer in enumerate(self.layers): if hasattr(layer.ffn, 'atlas_get_stats'): stats = layer.ffn.atlas_get_stats() if i < 3 or i == len(self.layers) - 1: print(f" Layer {i}: {stats}") # ================================================================= # FireEcho CUDA Graph Decode — graph-accelerated autoregressive engine # ================================================================= # Captures the full 48-layer decode forward as a single CUDA graph. # Eliminates ~5-7ms of Python interpreter overhead per token. # Position-dependent ops handled via GPU side buffers: # - RoPE: pre-loaded cos/sin from side buffers # - KV write: scatter_ with GPU position tensor # - KV read: full-length view + attention mask # All updates between replays are GPU-to-GPU — zero CPU-GPU sync. def enable_cuda_graph_decode(self, max_seq_len: int = 4096, batch_sizes: list = None): """Enable CUDA graph-accelerated single-token decode. Call after loading + packing experts. Graph is captured on first generate() call after prefill. Args: max_seq_len: Maximum sequence length to support batch_sizes: List of batch sizes to pre-capture graphs for. Default [1] for single-sequence decode. Use [1, 2, 4] for batched speculation support. """ if batch_sizes is None: batch_sizes = [1] # Ensure flat decode is enabled if not getattr(self.kv_cache, '_flat_mode', False): self.kv_cache.enable_flat_decode(max_seq_len) # Set up graph-mode side buffers on KV cache self.kv_cache.enable_cuda_graph(max_seq_len) # Pre-extend RoPE tables to max_seq_len (avoid resize during graph) for layer in self.layers: layer.attn._ensure_rope_length(max_seq_len + 1) # Static I/O buffers for each batch size self._graph_static_inputs = {} self._graph_static_outputs = {} self._decode_graphs = {} self._graph_captured_batches = set() for bs in batch_sizes: self._graph_static_inputs[bs] = torch.zeros( bs, 1, dtype=torch.long, device='cuda') self._graph_static_outputs[bs] = None # Legacy single-batch support self._graph_static_input = self._graph_static_inputs.get(1) self._graph_static_output = None self._graph_max_seq_len = max_seq_len self._graph_batch_sizes = batch_sizes self._decode_graph = None # Legacy, use _decode_graphs[1] instead self._graph_enabled = True self._graph_captured = False print(f" [CUDA Graph] Decode engine ready (max {max_seq_len} tokens, batches: {batch_sizes})") def _capture_decode_graph(self, capture_position: int, batch_size: int = 1): """Capture one decode forward as a CUDA graph. Runs 3 warmup iterations on a side stream (settles Triton autotune + CUDA memory allocator), then captures the forward pass. Args: capture_position: KV cache position to capture at batch_size: Batch size for this graph (default 1) """ rope_cos = self.layers[0].attn.rope_cos rope_sin = self.layers[0].attn.rope_sin # Get or create static buffers for this batch size if batch_size not in self._graph_static_inputs: self._graph_static_inputs[batch_size] = torch.zeros( batch_size, 1, dtype=torch.long, device='cuda') self._graph_static_outputs[batch_size] = None static_input = self._graph_static_inputs[batch_size] # Activate graph mode on KV cache self.kv_cache._graph_mode = True # Prepare side buffers for capture position self.kv_cache.prepare_graph_step(capture_position, rope_cos, rope_sin) static_input.fill_(1) # dummy token # Warmup on side stream s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): self.kv_cache.prepare_graph_step( capture_position, rope_cos, rope_sin) _ = self.forward( static_input, use_cache=True, position=0) torch.cuda.current_stream().wait_stream(s) # Capture — set flag to skip .item() calls in MoEFFN (not allowed during capture) self.kv_cache.prepare_graph_step(capture_position, rope_cos, rope_sin) for layer in self.layers: if hasattr(layer, 'ffn'): layer.ffn._graph_capturing = True graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): self._graph_static_outputs[batch_size] = self.forward( static_input, use_cache=True, position=0) for layer in self.layers: if hasattr(layer, 'ffn'): layer.ffn._graph_capturing = False # Store in batch-indexed dict self._decode_graphs[batch_size] = graph self._graph_captured_batches.add(batch_size) # Legacy single-batch support if batch_size == 1: self._decode_graph = graph self._graph_static_input = static_input self._graph_static_output = self._graph_static_outputs[1] self._graph_captured = True # Clean up KV cache entries written during warmup + capture num_cleanup = 4 # 3 warmup + 1 capture for pos in range(capture_position, min(capture_position + num_cleanup, self._graph_max_seq_len)): self.kv_cache.flat_k[:, :, pos, :] = 0 self.kv_cache.flat_v[:, :, pos, :] = 0 self.kv_cache._graph_attn_mask[0, 0, 0, pos] = float('-inf') print(f" [CUDA Graph] Captured batch_size={batch_size} at position {capture_position}") def _graph_decode_step(self, input_token: torch.Tensor, position: int, batch_size: int = None) -> torch.Tensor: """Replay CUDA graph for one decode step. Updates side buffers (GPU-to-GPU), then replays the captured graph. Returns logits from the static output buffer. Args: input_token: Token(s) to decode [B, 1] or [1, 1] position: KV cache position batch_size: Explicit batch size (auto-detected if None) """ # Auto-detect batch size from input if batch_size is None: batch_size = input_token.shape[0] if input_token.dim() > 1 else 1 # Check if we have a graph for this batch size if batch_size not in self._decode_graphs: # Capture on demand if batch size is in supported list if hasattr(self, '_graph_batch_sizes') and batch_size in self._graph_batch_sizes: self._capture_decode_graph(position, batch_size) else: # Fall back to eager mode for unsupported batch sizes return self.forward(input_token.view(batch_size, 1), use_cache=True, position=position) # Copy input to static buffer static_input = self._graph_static_inputs[batch_size] static_input.copy_(input_token.view(batch_size, 1)) rope_cos = self.layers[0].attn.rope_cos rope_sin = self.layers[0].attn.rope_sin self.kv_cache.prepare_graph_step(position, rope_cos, rope_sin) self._decode_graphs[batch_size].replay() return self._graph_static_outputs[batch_size] def capture_batched_graphs(self, capture_position: int = 100): """Pre-capture CUDA graphs for all configured batch sizes. Call after enable_cuda_graph_decode() and prefill to pre-capture all batch sizes, avoiding capture overhead during generation. Args: capture_position: KV cache position to capture at (default 100) """ if not hasattr(self, '_graph_batch_sizes'): print(" [CUDA Graph] No batch sizes configured. Call enable_cuda_graph_decode() first.") return for bs in self._graph_batch_sizes: if bs not in self._graph_captured_batches: self._capture_decode_graph(capture_position, bs) print(f" [CUDA Graph] Pre-captured {len(self._graph_batch_sizes)} batch sizes: {self._graph_batch_sizes}") def reset_cache(self): """Reset KV cache and generation state.""" if self.kv_cache is not None: self.kv_cache.clear() # Reset graph attention mask if CUDA graph is enabled if hasattr(self.kv_cache, '_graph_attn_mask'): self.kv_cache._graph_attn_mask.fill_(float('-inf')) self.kv_cache._graph_mode = False self._current_seq_id = 0 self._current_position = 0 if self.hebbian is not None: self.hebbian.reset() # ================================================================= # FireEcho EAGLE-3 — Speculative decoding with multi-layer fusion # ================================================================= def enable_eagle(self, capture_layers: Tuple[int, ...] = (8, 24, 47), num_heads: int = 16, ffn_mult: int = 2, draft_depth: int = 5, num_head_layers: int = 2, checkpoint_path: Optional[str] = None, num_medusa_heads: int = 0): """Enable EAGLE-3 / FE-XT speculative decoding. Creates a lightweight draft head that fuses hidden states from multiple target model layers to predict future tokens. Args: capture_layers: which target model layers to capture hidden states from num_heads: attention heads in draft head ffn_mult: FFN expansion factor in draft head draft_depth: default number of tokens to draft per round num_head_layers: number of self-attention + FFN layers in draft head (D) checkpoint_path: optional path to load draft head weights num_medusa_heads: number of Medusa parallel prediction heads (0=disabled) """ self._eagle_capture_layers = list(capture_layers) self._eagle_capture_set = set(capture_layers) self._eagle_hidden_states = {} self._eagle_enabled = True self._eagle_draft_depth = draft_depth # Create draft head self.eagle_head = FireEchoEagleHead( dim=self.config.dim, num_capture_layers=len(capture_layers), num_heads=num_heads, ffn_mult=ffn_mult, num_layers=num_head_layers, num_medusa_heads=num_medusa_heads, ).to(dtype=torch.bfloat16, device='cuda') # Share lm_head with main model (no 600MB duplication) self.eagle_head.lm_head = self.lm_head # Load checkpoint if provided if checkpoint_path is not None: ckpt = torch.load(checkpoint_path, map_location='cuda', weights_only=True) sd = ckpt.get('eagle_head', ckpt) # Detect legacy format (has 'norm1.weight' instead of 'layers.0...') is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd) if is_legacy: self.eagle_head.load_legacy_checkpoint(sd) else: missing, _ = self.eagle_head.load_state_dict(sd, strict=False) if missing: n_new = sum(1 for k in missing if k.startswith('layers.')) if n_new: print(f" [EAGLE] Loaded checkpoint. " f"{n_new} new layer params (randomly initialized).") # Hebbian memory connection: draft head can query stored knowledge self._eagle_use_hebbian = (self.hebbian is not None) # Stats param_count = sum(p.numel() for p in self.eagle_head.parameters() if p is not self.lm_head.weight) vram_mb = param_count * 2 / 1e6 # BF16 hebb_str = " + Hebbian memory" if self._eagle_use_hebbian else "" head_name = "FE-XT" if num_head_layers > 2 else "EAGLE-3" print(f" [{head_name}] Draft head: D={num_head_layers}, " f"{param_count/1e6:.1f}M params, " f"{vram_mb:.0f} MB, capture layers {list(capture_layers)}" f"{hebb_str}") def enable_hayabusa(self, cutoff_L: int = 4, gpu_budget_mb: float = 400.0): """Enable FE-H (FireEcho Hayabusa) async prefetch offload for draft head. Offloads draft head layers >= cutoff_L to CPU pinned memory. During verify phase, these are prefetched asynchronously via PCIe. Must be called AFTER enable_eagle(). Args: cutoff_L: Layers 0..cutoff_L-1 stay on GPU (hot). Layers cutoff_L..D-1 are offloaded (cold). gpu_budget_mb: Max GPU memory for cold layer buffer (MB). """ if not getattr(self, '_eagle_enabled', False): raise RuntimeError("Call enable_eagle() before enable_hayabusa()") if len(self.eagle_head.layers) <= cutoff_L: print(f" [FE-H] Draft head has only {len(self.eagle_head.layers)} layers. " f"Nothing to offload (cutoff_L={cutoff_L}).") return self._hayabusa = FireEchoHayabusa( self.eagle_head, cutoff_L=cutoff_L, gpu_budget_mb=gpu_budget_mb) self._hayabusa_enabled = True # Set on eagle_head so forward() can access it for per-layer sync self.eagle_head._hayabusa = self._hayabusa def _get_eagle_memory_context(self, hidden_state: torch.Tensor) -> Optional[torch.Tensor]: """Query Hebbian memory for EAGLE draft head context. Retrieval only — no update. Speculative tokens may be rejected, so we must not contaminate the memory bank. Args: hidden_state: [B, 1, dim] from last capture layer of target model Returns: memory_context: [B, 1, dim] or None if Hebbian not available """ if not getattr(self, '_eagle_use_hebbian', False) or self.hebbian is None: return None if isinstance(self.hebbian, PerLayerHebbian): # Use memory at the last capture layer (deepest, most refined) last_layer = self._eagle_capture_layers[-1] if last_layer < len(self.hebbian.memories): return self.hebbian.retrieve_only(hidden_state, last_layer) else: # Fallback: use last available layer's memory return self.hebbian.retrieve_only( hidden_state, len(self.hebbian.memories) - 1) else: # Shared Hebbian: single memory bank return self.hebbian.retrieve_only(hidden_state) # ================================================================ # FE-XT: Tree Verification Forward Pass (read-only KV) # ================================================================ def _tree_verify_layer(self, x: torch.Tensor, layer_idx: int, layer: 'FusedTransformerBlock', position_ids: torch.Tensor, tree_mask: torch.Tensor, prefix_len: int) -> torch.Tensor: """Single transformer layer forward for FE-XT tree verification. Read-only KV: reads prefix from flat cache, does NOT write draft tokens. Uses per-token position_ids for RoPE (branches share same positions). Applies tree attention mask (block-diagonal causal). Args: x: [1, M, dim] hidden states for M draft tokens layer_idx: Which transformer layer (0-47) layer: FusedTransformerBlock module position_ids: [M] per-token RoPE positions tree_mask: [1, 1, M, prefix_len + M] additive attention mask prefix_len: Number of tokens in KV cache Returns: x: [1, M, dim] updated hidden states """ attn = layer.attn ffn = layer.ffn B, M, D = x.shape # B=1 # Pre-norm (unfused — tree verify is infrequent, simplicity over speed) _norm1_eps = layer.norm1.eps if layer.norm1.eps is not None else 1e-6 variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) x_normed = x * torch.rsqrt(variance + _norm1_eps).to(x.dtype) * layer.norm1.weight # QKV projections q = attn.q_proj(x_normed).view(B, M, attn.num_heads, attn.head_dim).transpose(1, 2) k = attn.k_proj(x_normed).view(B, M, attn.num_kv_heads, attn.head_dim).transpose(1, 2) v = attn.v_proj(x_normed).view(B, M, attn.num_kv_heads, attn.head_dim).transpose(1, 2) # QK norm (Qwen3 per-head) if attn.use_qk_norm: if attn.qk_norm_per_head: q = attn.q_norm(q) k = attn.k_norm(k) else: Bq, nH, Sq, hd = q.shape q = attn.q_norm(q.transpose(1, 2).reshape(Bq, Sq, nH * hd)).reshape(Bq, Sq, nH, hd).transpose(1, 2) Bk, nKH, Sk, hdk = k.shape k = attn.k_norm(k.transpose(1, 2).reshape(Bk, Sk, nKH * hdk)).reshape(Bk, Sk, nKH, hdk).transpose(1, 2) # RoPE with per-token position IDs (critical for tree branches) attn._ensure_rope_length(int(position_ids.max().item()) + 1) q = _apply_rotary_emb_ids(q, attn.rope_cos, attn.rope_sin, position_ids) k = _apply_rotary_emb_ids(k, attn.rope_cos, attn.rope_sin, position_ids) if q.dtype != v.dtype: q = q.to(v.dtype) k = k.to(v.dtype) # Get prefix KV from flat cache (READ-ONLY, no writes) k_prefix, v_prefix = self.kv_cache.get_flat_view(layer_idx, prefix_len) k_prefix = k_prefix.unsqueeze(0) # [1, kv_heads, prefix_len, head_dim] v_prefix = v_prefix.unsqueeze(0) # Concat: prefix KV + draft KV → [1, kv_heads, prefix_len + M, head_dim] k_full = torch.cat([k_prefix, k], dim=2) v_full = torch.cat([v_prefix, v], dim=2) # GQA expansion for attention if attn.num_kv_heads < attn.num_heads: repeat_factor = attn.num_heads // attn.num_kv_heads k_full = k_full.repeat_interleave(repeat_factor, dim=1) v_full = v_full.repeat_interleave(repeat_factor, dim=1) # Attention with tree mask (SDPA) out = F.scaled_dot_product_attention( q, k_full, v_full, attn_mask=tree_mask.to(q.dtype), scale=attn.scale) # Output projection + residual out = out.transpose(1, 2).contiguous().view(B, M, -1) out = attn.out_proj(out) x = x + out # FFN: norm2 + MoE (uses M>1 packed path from Task #181) _norm2_eps = layer.norm2.eps if layer.norm2.eps is not None else 1e-6 variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) x_normed2 = x * torch.rsqrt(variance + _norm2_eps).to(x.dtype) * layer.norm2.weight x = x + ffn(x_normed2) return x @torch.no_grad() def _tree_verify_forward(self, draft_token_ids: torch.Tensor, position_ids: torch.Tensor, tree_mask: torch.Tensor, prefix_len: int) -> torch.Tensor: """FE-XT read-only tree verification forward pass. Processes M draft tokens through all transformer layers WITHOUT modifying the main KV cache. Uses temporary KV for draft tokens, reads existing prefix KV from flat cache. After verification, the winning branch's tokens are replayed into the KV cache by the caller (speculative_generate_xturbo). Args: draft_token_ids: [1, M] token IDs for all tree branches position_ids: [M] per-token RoPE positions tree_mask: [1, 1, M, prefix_len + M] additive mask (0 or -inf) prefix_len: Number of tokens in KV cache (the shared prefix) Returns: logits: [1, M, vocab_size] for all M draft tokens """ B, M = draft_token_ids.shape # B=1 # Embed x = self.embed(draft_token_ids) # [1, M, dim] # All 48 layers (read-only KV) for i, layer in enumerate(self.layers): # L2 prefetch during verify (same as regular forward) if getattr(self, '_l2_prefetch_enabled', False) and i + 1 < len(self.layers): self._prefetch_layer_weights(i + 1) x = self._tree_verify_layer(x, i, layer, position_ids, tree_mask, prefix_len) # EAGLE-3: capture hidden states at selected layers # (needed for next draft round after acceptance) if getattr(self, '_eagle_enabled', False) and i in self._eagle_capture_set: self._eagle_hidden_states[i] = x # Final norm + lm_head x = self.norm(x) logits = self.lm_head(x) return logits def _replay_accepted_branch(self, accepted_token_ids: torch.Tensor, start_pos: int): """Replay accepted branch tokens into the main KV cache. After tree verification identifies the winning branch, we must store the accepted tokens' KV entries into the flat cache so subsequent generation can attend to them. This runs a STANDARD forward pass (with KV writes) for the accepted tokens only. Costs one extra forward of `accepted_count` tokens, but with the M>1 packed MoE path this is fast. Args: accepted_token_ids: [1, A] token IDs for accepted tokens start_pos: KV cache position to start writing from """ if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False self.forward(accepted_token_ids, use_cache=True, position=start_pos) @torch.no_grad() def speculative_generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 0.0, draft_depth: int = None, stop_tokens: Optional[list] = None, callback: Optional[callable] = None, use_tree: bool = False, num_branches: int = 2, ) -> torch.Tensor: """Autoregressive generation with EAGLE-3 speculative decoding. Each round: decode 1 token → draft K tokens → verify K tokens → accept matches → rollback rejected → repeat. For temperature=0 (greedy), uses exact token matching. Preserves output distribution (identical to non-speculative decode). """ if not getattr(self, '_eagle_enabled', False): raise RuntimeError("Call enable_eagle() before speculative_generate()") self.eval() B = input_ids.shape[0] if B != 1: raise ValueError("Speculative decoding requires B=1") prompt_len = input_ids.shape[1] depth = draft_depth or self._eagle_draft_depth # Reset and prefill self.reset_cache() self._current_seq_id = 0 if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False logits = self.forward(input_ids, use_cache=True, position=0) current_pos = prompt_len if stop_tokens is None: stop_tokens = [199999, 200020] stop_set = set(stop_tokens) if stop_tokens else set() generated = input_ids total_drafted = 0 total_accepted = 0 total_rounds = 0 while generated.shape[1] - prompt_len < max_new_tokens: # ---- DECODE: sample next token from main model logits ---- if temperature <= 1e-8: next_token = logits[:, -1:, :].argmax(dim=-1) else: next_logits = logits[:, -1, :] / temperature probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated = torch.cat([generated, next_token], dim=1) if next_token.item() in stop_set: break # Process next_token through main model (stores KV, captures features) if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False logits = self.forward(next_token, use_cache=True, position=current_pos) current_pos += 1 if callback is not None: callback(next_token.item(), current_pos - 1) # Check remaining budget remaining = max_new_tokens - (generated.shape[1] - prompt_len) if remaining <= 0: break # ---- DRAFT: generate K candidate tokens ---- features = [self._eagle_hidden_states[l] for l in self._eagle_capture_layers] K = min(depth, remaining) # Query Hebbian memory for grounding context (retrieve only, no update) memory_ctx = self._get_eagle_memory_context( self._eagle_hidden_states[self._eagle_capture_layers[-1]]) if use_tree: # Tree drafting: generate multiple paths, pick best paths = self.eagle_head.generate_draft_tree( features, next_token, self.embed, depth=K, num_branches=num_branches, memory_context=memory_ctx) else: # Linear drafting: single greedy path dt, dl = self.eagle_head.generate_draft( features, next_token, self.embed, depth=K, memory_context=memory_ctx) paths = [(dt, dl)] # Save KV cache state before verification (for rollback between paths) kv_pos_before_verify = current_pos main_pred = logits[:, -1, :].argmax(dim=-1).item() best_accepted = 0 best_path_idx = 0 best_verify_logits = None best_draft_tokens = None for path_idx, (draft_tokens, draft_logits_list) in enumerate(paths): if not draft_tokens: continue if path_idx > 0: # Rollback KV to pre-verify state for secondary paths rollback_count = len(paths[path_idx - 1][0]) if rollback_count > 0: self.kv_cache.rollback_to(kv_pos_before_verify, rollback_count) # VERIFY: main model processes draft tokens at once draft_input = torch.cat(draft_tokens, dim=1) verify_logits = self.forward( draft_input, use_cache=True, position=current_pos) # ACCEPT/REJECT accepted = 0 if draft_tokens[0].item() == main_pred: accepted = 1 for i in range(1, len(draft_tokens)): target_pred = verify_logits[:, i - 1, :].argmax(dim=-1).item() if draft_tokens[i].item() == target_pred: accepted += 1 else: break if accepted > best_accepted: best_accepted = accepted best_path_idx = path_idx best_verify_logits = verify_logits best_draft_tokens = draft_tokens # If this path got all tokens accepted, no need to try more if accepted == len(draft_tokens): break # If not using tree, skip to commit if not use_tree: best_verify_logits = verify_logits best_draft_tokens = draft_tokens break # If we ended on a non-best path, rollback and re-verify best if use_tree and best_path_idx < len(paths) - 1: rollback_count = len(paths[-1][0]) if paths[-1][0] else 0 if rollback_count > 0: self.kv_cache.rollback_to(kv_pos_before_verify, rollback_count) draft_input = torch.cat(best_draft_tokens, dim=1) best_verify_logits = self.forward( draft_input, use_cache=True, position=current_pos) draft_tokens = best_draft_tokens or [] verify_logits = best_verify_logits accepted = best_accepted if not draft_tokens: continue total_drafted += len(draft_tokens) total_rounds += 1 total_accepted += accepted # ---- COMMIT accepted tokens ---- for i in range(accepted): generated = torch.cat([generated, draft_tokens[i]], dim=1) if callback is not None: callback(draft_tokens[i].item(), current_pos + i) if draft_tokens[i].item() in stop_set: return generated # ---- CORRECTION: use target model's choice at rejection point ---- if accepted < len(draft_tokens): if accepted == 0: correction = torch.tensor( [[main_pred]], device=generated.device) else: correction = verify_logits[:, accepted - 1, :].argmax( dim=-1, keepdim=True) generated = torch.cat([generated, correction], dim=1) if callback is not None: callback(correction.item(), current_pos + accepted) if correction.item() in stop_set: return generated # Rollback KV cache: remove rejected draft positions rollback_pos = current_pos + accepted + 1 rollback_count = len(draft_tokens) - accepted - 1 if rollback_count > 0: self.kv_cache.rollback_to(rollback_pos, rollback_count) current_pos = rollback_pos logits = self.forward( correction, use_cache=True, position=current_pos - 1) else: # All K accepted — bonus token from last verify logits bonus = verify_logits[:, -1, :].argmax(dim=-1, keepdim=True) generated = torch.cat([generated, bonus], dim=1) if callback is not None: callback(bonus.item(), current_pos + accepted) if bonus.item() in stop_set: return generated current_pos += accepted + 1 logits = self.forward( bonus, use_cache=True, position=current_pos - 1) if total_rounds > 0: avg_accept = total_accepted / total_rounds accept_rate = total_accepted / max(total_drafted, 1) print(f" [EAGLE-3] {total_rounds} rounds, " f"{total_drafted} drafted, {total_accepted} accepted " f"({accept_rate:.0%}), avg {avg_accept:.1f}/round") return generated # ================================================================ # Medusa Speculative Decoding — Single-pass K+1 draft # ================================================================ @torch.no_grad() def speculative_generate_medusa(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 0.0, stop_tokens: Optional[list] = None, callback: Optional[callable] = None, ) -> torch.Tensor: """Autoregressive generation with Medusa speculative decoding. Unlike EAGLE (K sequential draft steps), Medusa runs the eagle backbone ONCE and uses K independent heads to predict K+1 future positions. Draft cost: 1 × backbone_time (vs K × backbone_time for EAGLE). Requires eagle head to have been created with num_medusa_heads > 0. """ if not getattr(self, '_eagle_enabled', False): raise RuntimeError("Call enable_eagle() before speculative_generate_medusa()") if self.eagle_head.medusa_heads is None: raise RuntimeError("Eagle head has no medusa heads (num_medusa_heads=0)") self.eval() B = input_ids.shape[0] if B != 1: raise ValueError("Medusa speculative decoding requires B=1") K = self.eagle_head.num_medusa_heads # Number of extra heads (K+1 total predictions) prompt_len = input_ids.shape[1] # Reset and prefill self.reset_cache() self._current_seq_id = 0 if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False logits = self.forward(input_ids, use_cache=True, position=0) current_pos = prompt_len if stop_tokens is None: stop_tokens = [199999, 200020] stop_set = set(stop_tokens) if stop_tokens else set() generated = input_ids total_drafted = 0 total_accepted = 0 total_rounds = 0 while generated.shape[1] - prompt_len < max_new_tokens: # ---- DECODE: sample next token from main model logits ---- if temperature <= 1e-8: next_token = logits[:, -1:, :].argmax(dim=-1) else: next_logits = logits[:, -1, :] / temperature probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated = torch.cat([generated, next_token], dim=1) if next_token.item() in stop_set: break # Process next_token through main model (stores KV, captures features) if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False logits = self.forward(next_token, use_cache=True, position=current_pos) current_pos += 1 if callback is not None: callback(next_token.item(), current_pos - 1) remaining = max_new_tokens - (generated.shape[1] - prompt_len) if remaining <= 0: break # ---- MEDUSA DRAFT: single forward → K+1 predictions ---- features = [self._eagle_hidden_states[l] for l in self._eagle_capture_layers] memory_ctx = self._get_eagle_memory_context( self._eagle_hidden_states[self._eagle_capture_layers[-1]]) draft_tokens, draft_logits = self.eagle_head.medusa_draft( features, next_token, self.embed, memory_context=memory_ctx) # Limit to remaining budget num_draft = min(len(draft_tokens), remaining) draft_tokens = draft_tokens[:num_draft] draft_logits = draft_logits[:num_draft] if not draft_tokens: continue total_drafted += len(draft_tokens) total_rounds += 1 # Save KV position for potential rollback kv_pos_before = current_pos # ---- VERIFY: run all draft tokens through main model at once ---- draft_input = torch.cat(draft_tokens, dim=1) # [1, num_draft] verify_logits = self.forward( draft_input, use_cache=True, position=current_pos) # ---- ACCEPT/REJECT: check sequential prefix match ---- main_pred = logits[:, -1, :].argmax(dim=-1).item() accepted = 0 if draft_tokens[0].item() == main_pred: accepted = 1 for i in range(1, len(draft_tokens)): target_pred = verify_logits[:, i - 1, :].argmax(dim=-1).item() if draft_tokens[i].item() == target_pred: accepted += 1 else: break total_accepted += accepted # ---- COMMIT accepted tokens ---- for i in range(accepted): generated = torch.cat([generated, draft_tokens[i]], dim=1) if callback is not None: callback(draft_tokens[i].item(), current_pos + i) if draft_tokens[i].item() in stop_set: if total_rounds > 0: accept_rate = total_accepted / max(total_drafted, 1) print(f" [Medusa] {total_rounds} rounds, " f"{total_drafted} drafted, {total_accepted} accepted " f"({accept_rate:.0%})") return generated # ---- CORRECTION at rejection point ---- if accepted < len(draft_tokens): if accepted == 0: correction = torch.tensor( [[main_pred]], device=generated.device) else: correction = verify_logits[:, accepted - 1, :].argmax( dim=-1, keepdim=True) generated = torch.cat([generated, correction], dim=1) if callback is not None: callback(correction.item(), current_pos + accepted) if correction.item() in stop_set: if total_rounds > 0: accept_rate = total_accepted / max(total_drafted, 1) print(f" [Medusa] {total_rounds} rounds, " f"{total_drafted} drafted, {total_accepted} accepted " f"({accept_rate:.0%})") return generated # Rollback KV: remove rejected draft positions rollback_pos = current_pos + accepted + 1 rollback_count = len(draft_tokens) - accepted - 1 if rollback_count > 0: self.kv_cache.rollback_to(rollback_pos, rollback_count) current_pos = rollback_pos logits = self.forward( correction, use_cache=True, position=current_pos - 1) else: # All K+1 accepted — bonus token from last verify logits bonus = verify_logits[:, -1, :].argmax(dim=-1, keepdim=True) generated = torch.cat([generated, bonus], dim=1) if callback is not None: callback(bonus.item(), current_pos + accepted) if bonus.item() in stop_set: if total_rounds > 0: accept_rate = total_accepted / max(total_drafted, 1) print(f" [Medusa] {total_rounds} rounds, " f"{total_drafted} drafted, {total_accepted} accepted " f"({accept_rate:.0%})") return generated current_pos += accepted + 1 logits = self.forward( bonus, use_cache=True, position=current_pos - 1) if total_rounds > 0: avg_accept = total_accepted / total_rounds accept_rate = total_accepted / max(total_drafted, 1) print(f" [Medusa] {total_rounds} rounds, " f"{total_drafted} drafted, {total_accepted} accepted " f"({accept_rate:.0%}), avg {avg_accept:.1f}/round") return generated # ================================================================ # FE-XT: FireEcho Xturbo — Tree Speculative Decoding # ================================================================ @torch.no_grad() def speculative_generate_xturbo( self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 0.0, draft_depth: int = 5, num_branches: int = 8, stop_tokens: Optional[list] = None, callback: Optional[callable] = None, dynamic_b: bool = False, b_min: int = 2, b_max: int = 16, ) -> torch.Tensor: """FE-XT: Tree-structured speculative decoding with b parallel branches. Based on Scylla (arxiv 2505.07858). Instead of drafting 1 linear path, drafts b parallel paths of depth d using generate_draft_tree_batched, then verifies ALL b*d tokens in ONE read-only forward pass. Key innovations over linear speculative_generate(): - Tree parallelism: b branches explored in parallel (b is STRONGEST lever) - Read-only verify: KV cache untouched during verification, no rollback needed - Best-branch selection: picks branch with most accepted tokens - M>1 packed MoE: verification uses batched MoE kernel (not per-expert loops) - Dynamic b tuning: adapt branch count based on acceptance (Scylla Eq.4) Scylla Theorem 1.3: Throughput = 286.79 × log₂(b) + 7.54 At b=8, M=40: approaches J_crit for RTX 5090 → 35%+ bandwidth utilization. Args: input_ids: [1, seq_len] prompt token IDs max_new_tokens: Maximum tokens to generate temperature: Sampling temperature (0 = greedy) draft_depth: Tokens per branch (d). Total verify = b * d tokens. num_branches: Number of tree branches (b). Initial b if dynamic_b=True. stop_tokens: Stop token IDs (default: Qwen3 EOS tokens) callback: Optional callback(token_id, position) per accepted token dynamic_b: If True, auto-tune b based on acceptance rate each round. b_min: Minimum branch count when dynamic_b=True. b_max: Maximum branch count when dynamic_b=True. Returns: generated: [1, prompt_len + generated_len] token IDs """ if not getattr(self, '_eagle_enabled', False): raise RuntimeError("Call enable_eagle() before speculative_generate_xturbo()") self.eval() B = input_ids.shape[0] if B != 1: raise ValueError("FE-XT requires B=1") prompt_len = input_ids.shape[1] depth = draft_depth b = num_branches # Dynamic b tuning state (Scylla Eq.4 inspired) # Track acceptance rate over sliding window; adjust b to maximize throughput. # High acceptance → increase b (more speculative parallelism pays off) # Low acceptance → decrease b (reduce wasted verify compute) _DYNAMIC_B_WINDOW = 8 # sliding window size _DYNAMIC_B_HIGH_THRESH = 0.50 # acceptance > 50% → try more branches _DYNAMIC_B_LOW_THRESH = 0.15 # acceptance < 15% → try fewer branches accept_history = [] # stores acceptance_rate per round # Reset and prefill self.reset_cache() self._current_seq_id = 0 if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False logits = self.forward(input_ids, use_cache=True, position=0) current_pos = prompt_len if stop_tokens is None: stop_tokens = [199999, 200020] stop_set = set(stop_tokens) if stop_tokens else set() generated = input_ids total_drafted = 0 total_accepted = 0 total_rounds = 0 while generated.shape[1] - prompt_len < max_new_tokens: # ---- DECODE: sample next token from target model logits ---- if temperature <= 1e-8: next_token = logits[:, -1:, :].argmax(dim=-1) else: next_logits = logits[:, -1, :] / temperature probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated = torch.cat([generated, next_token], dim=1) if next_token.item() in stop_set: break # Forward next_token through target model (stores KV, captures features) if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False logits = self.forward(next_token, use_cache=True, position=current_pos) current_pos += 1 if callback is not None: callback(next_token.item(), current_pos - 1) # Budget check remaining = max_new_tokens - (generated.shape[1] - prompt_len) if remaining <= 0: break # ---- FE-XT DRAFT: tree of b branches × d depth ---- features = [self._eagle_hidden_states[l] for l in self._eagle_capture_layers] K = min(depth, remaining) memory_ctx = self._get_eagle_memory_context( self._eagle_hidden_states[self._eagle_capture_layers[-1]]) tree_tokens = self.eagle_head.generate_draft_tree_batched( features, next_token, self.embed, depth=K, b=b, memory_context=memory_ctx) # tree_tokens: list of b lists, each containing K [1, 1] tensors # ---- FE-H: evict cold layers after draft (free GPU memory) ---- if getattr(self, '_hayabusa_enabled', False): self._hayabusa.evict_cold_layers() # ---- BUILD TREE INPUTS ---- M = b * K flat_ids = torch.zeros(1, M, dtype=torch.long, device=input_ids.device) for bi in range(b): for si in range(K): flat_ids[0, bi * K + si] = tree_tokens[bi][si].item() position_ids = build_tree_position_ids(b, K, current_pos, device=input_ids.device) tree_mask = build_tree_attention_mask(b, K, current_pos, device=input_ids.device) # ---- FE-H: start async prefetch during verify (free PCIe bandwidth) ---- if getattr(self, '_hayabusa_enabled', False): self._hayabusa.async_prefetch() # ---- FE-XT VERIFY: read-only forward for all M tokens ---- verify_logits = self._tree_verify_forward( flat_ids, position_ids, tree_mask, prefix_len=current_pos) # verify_logits: [1, M, vocab_size] # ---- FE-H: sync prefetched layers (should be done by now) ---- if getattr(self, '_hayabusa_enabled', False): self._hayabusa.sync_all_cold() # ---- ACCEPT: check each branch, pick best ---- main_pred = logits[:, -1, :].argmax(dim=-1).item() best_branch = 0 best_accepted = 0 for bi in range(b): offset = bi * K accepted = 0 # Step 0: first draft token must match target model prediction if flat_ids[0, offset].item() == main_pred: accepted = 1 # Steps 1..K-1: each draft must match verify prediction for si in range(1, K): target_pred = verify_logits[0, offset + si - 1, :].argmax( dim=-1).item() if flat_ids[0, offset + si].item() == target_pred: accepted += 1 else: break if accepted > best_accepted: best_accepted = accepted best_branch = bi if accepted == K: break # Perfect match total_drafted += K total_rounds += 1 total_accepted += best_accepted # ---- DYNAMIC b TUNING (Scylla Eq.4 inspired) ---- if dynamic_b: round_accept_rate = best_accepted / K if K > 0 else 0.0 accept_history.append(round_accept_rate) if len(accept_history) > _DYNAMIC_B_WINDOW: accept_history.pop(0) if len(accept_history) >= 3: avg_rate = sum(accept_history) / len(accept_history) old_b = b if avg_rate > _DYNAMIC_B_HIGH_THRESH and b < b_max: b = min(b * 2, b_max) elif avg_rate < _DYNAMIC_B_LOW_THRESH and b > b_min: b = max(b // 2, b_min) # No change otherwise — stay at current b best_offset = best_branch * K # ---- COMMIT accepted tokens to output ---- for i in range(best_accepted): tok_id = flat_ids[0, best_offset + i].item() tok = torch.tensor([[tok_id]], device=generated.device) generated = torch.cat([generated, tok], dim=1) if callback is not None: callback(tok_id, current_pos + i) if tok_id in stop_set: return generated # ---- CORRECTION + REPLAY ---- if best_accepted < K: # Correction: target model's choice at rejection point if best_accepted == 0: correction_id = main_pred else: correction_id = verify_logits[ 0, best_offset + best_accepted - 1, :].argmax(dim=-1).item() correction = torch.tensor([[correction_id]], device=generated.device) generated = torch.cat([generated, correction], dim=1) if callback is not None: callback(correction_id, current_pos + best_accepted) if correction_id in stop_set: return generated # Replay accepted + correction into KV cache (standard forward) replay_len = best_accepted + 1 replay_ids = torch.zeros(1, replay_len, dtype=torch.long, device=input_ids.device) for i in range(best_accepted): replay_ids[0, i] = flat_ids[0, best_offset + i].item() replay_ids[0, best_accepted] = correction_id if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False logits = self.forward(replay_ids, use_cache=True, position=current_pos) current_pos += replay_len else: # All K accepted — bonus token from verify logits bonus_id = verify_logits[ 0, best_offset + K - 1, :].argmax(dim=-1).item() bonus = torch.tensor([[bonus_id]], device=generated.device) generated = torch.cat([generated, bonus], dim=1) if callback is not None: callback(bonus_id, current_pos + K) if bonus_id in stop_set: return generated # Replay all accepted + bonus into KV cache replay_len = K + 1 replay_ids = torch.zeros(1, replay_len, dtype=torch.long, device=input_ids.device) for i in range(K): replay_ids[0, i] = flat_ids[0, best_offset + i].item() replay_ids[0, K] = bonus_id if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False logits = self.forward(replay_ids, use_cache=True, position=current_pos) current_pos += replay_len if total_rounds > 0: avg_accept = total_accepted / total_rounds accept_rate = total_accepted / max(total_drafted, 1) * 100 b_info = f"b={b}" if not dynamic_b else f"b={b} (dynamic {b_min}-{b_max})" print(f" [FE-XT] {total_rounds} rounds, {total_drafted} drafted, " f"{total_accepted} accepted ({accept_rate:.1f}%), " f"avg {avg_accept:.1f}/round, {b_info}, d={depth}") return generated @torch.no_grad() def speculative_generate_batched( self, input_ids_list: List[torch.Tensor], max_new_tokens: int = 100, temperature: float = 0.0, draft_depth: int = None, stop_tokens: Optional[list] = None, ) -> List[torch.Tensor]: """Batched speculative decoding for multi-sequence throughput. Processes multiple sequences in parallel, amortizing weight loading across all sequences. Key optimization for 1K tok/s throughput. Architecture: 1. Each sequence has independent KV cache (via seq_id) 2. Draft K tokens for ALL sequences in parallel (single eagle pass) 3. Verify ALL sequences in single batched forward (amortize MoE weights) 4. Accept/reject per sequence, rollback independently Args: input_ids_list: List of prompt tensors, each [1, seq_len] max_new_tokens: Max tokens per sequence temperature: Sampling temperature (0 = greedy) draft_depth: Override draft depth (default: eagle default) stop_tokens: List of stop token IDs Returns: List of generated tensors, each [1, prompt_len + generated_len] """ if not getattr(self, '_eagle_enabled', False): raise RuntimeError("Call enable_eagle() before speculative_generate_batched()") self.eval() B = len(input_ids_list) if B < 1: raise ValueError("Need at least 1 sequence") if B > 8: raise ValueError("Batched speculation supports up to 8 sequences") depth = draft_depth or self._eagle_draft_depth device = input_ids_list[0].device if stop_tokens is None: stop_tokens = [199999, 200020] stop_set = set(stop_tokens) if stop_tokens else set() # Reset cache for new generation self.reset_cache() # Per-sequence state seq_states = [] for b, input_ids in enumerate(input_ids_list): if input_ids.shape[0] != 1: raise ValueError("Each input must have batch=1") seq_states.append({ 'seq_id': b, 'generated': input_ids.clone(), 'prompt_len': input_ids.shape[1], 'current_pos': input_ids.shape[1], 'done': False, 'logits': None, # Last logits for this sequence }) # ---- PREFILL: Process all prompts ---- # For simplicity, prefill each sequence separately (can batch later) if hasattr(self.kv_cache, '_graph_mode'): self.kv_cache._graph_mode = False for state in seq_states: self._current_seq_id = state['seq_id'] logits = self.forward( state['generated'], use_cache=True, position=0) state['logits'] = logits # Capture features for EAGLE state['features'] = [self._eagle_hidden_states[l].clone() for l in self._eagle_capture_layers] # Capture Hebbian memory context for draft head state['memory_ctx'] = self._get_eagle_memory_context( self._eagle_hidden_states[self._eagle_capture_layers[-1]]) # Stats total_drafted = 0 total_accepted = 0 total_rounds = 0 # ---- MAIN LOOP: Continue until all sequences done ---- while True: # Check if all done active = [s for s in seq_states if not s['done']] if not active: break # Check token budget for state in active: gen_len = state['generated'].shape[1] - state['prompt_len'] if gen_len >= max_new_tokens: state['done'] = True active = [s for s in seq_states if not s['done']] if not active: break # ---- DECODE: Sample next token for each active sequence ---- for state in active: logits = state['logits'] if temperature <= 1e-8: next_token = logits[:, -1:, :].argmax(dim=-1) else: next_logits = logits[:, -1, :] / temperature probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) state['generated'] = torch.cat([state['generated'], next_token], dim=1) if next_token.item() in stop_set: state['done'] = True continue # Process next_token through main model self._current_seq_id = state['seq_id'] logits = self.forward( next_token, use_cache=True, position=state['current_pos']) state['logits'] = logits state['current_pos'] += 1 state['last_token'] = next_token # Capture features state['features'] = [self._eagle_hidden_states[l].clone() for l in self._eagle_capture_layers] # Update Hebbian memory context for this sequence state['memory_ctx'] = self._get_eagle_memory_context( self._eagle_hidden_states[self._eagle_capture_layers[-1]]) # Refresh active list active = [s for s in seq_states if not s['done']] if not active: break # ---- DRAFT: Generate K candidate tokens for each active sequence ---- all_drafts = [] # List of (state, draft_tokens, draft_logits_list) for state in active: remaining = max_new_tokens - (state['generated'].shape[1] - state['prompt_len']) K = min(depth, remaining) if K <= 0: state['done'] = True continue draft_tokens, draft_logits_list = self.eagle_head.generate_draft( state['features'], state['last_token'], self.embed, depth=K, memory_context=state.get('memory_ctx')) if draft_tokens: all_drafts.append((state, draft_tokens, draft_logits_list)) total_drafted += len(draft_tokens) if not all_drafts: continue total_rounds += 1 # ---- VERIFY: Batched verification across all sequences ---- # For efficiency, we verify each sequence's drafts # (Full batched verification across sequences requires padding) for state, draft_tokens, _ in all_drafts: draft_input = torch.cat(draft_tokens, dim=1) # [1, K] self._current_seq_id = state['seq_id'] verify_logits = self.forward( draft_input, use_cache=True, position=state['current_pos']) # ---- ACCEPT/REJECT (greedy for temperature=0) ---- main_pred = state['logits'][:, -1, :].argmax(dim=-1).item() accepted = 0 if draft_tokens[0].item() == main_pred: accepted = 1 for i in range(1, len(draft_tokens)): target_pred = verify_logits[:, i - 1, :].argmax(dim=-1).item() if draft_tokens[i].item() == target_pred: accepted += 1 else: break # ---- COMMIT accepted tokens ---- for i in range(accepted): state['generated'] = torch.cat([state['generated'], draft_tokens[i]], dim=1) if draft_tokens[i].item() in stop_set: state['done'] = True break total_accepted += accepted if state['done']: continue # ---- CORRECTION: use target model's choice at rejection point ---- if accepted < len(draft_tokens): if accepted == 0: correction = torch.tensor( [[main_pred]], device=device) else: correction = verify_logits[:, accepted - 1, :].argmax( dim=-1, keepdim=True) state['generated'] = torch.cat([state['generated'], correction], dim=1) if correction.item() in stop_set: state['done'] = True continue # Rollback KV cache rollback_pos = state['current_pos'] + accepted + 1 rollback_count = len(draft_tokens) - accepted - 1 if rollback_count > 0: self._current_seq_id = state['seq_id'] self.kv_cache.rollback_to(rollback_pos, rollback_count) state['current_pos'] = rollback_pos # Get logits for next round self._current_seq_id = state['seq_id'] state['logits'] = self.forward( correction, use_cache=True, position=state['current_pos'] - 1) state['features'] = [self._eagle_hidden_states[l].clone() for l in self._eagle_capture_layers] else: # All K accepted — bonus token from last verify logits bonus = verify_logits[:, -1, :].argmax(dim=-1, keepdim=True) state['generated'] = torch.cat([state['generated'], bonus], dim=1) if bonus.item() in stop_set: state['done'] = True continue state['current_pos'] += accepted + 1 # Get logits for next round self._current_seq_id = state['seq_id'] state['logits'] = self.forward( bonus, use_cache=True, position=state['current_pos'] - 1) state['features'] = [self._eagle_hidden_states[l].clone() for l in self._eagle_capture_layers] # Print stats if total_rounds > 0: avg_accept = total_accepted / total_rounds accept_rate = total_accepted / max(total_drafted, 1) print(f" [EAGLE-3 Batched] B={B}, {total_rounds} rounds, " f"{total_drafted} drafted, {total_accepted} accepted " f"({accept_rate:.0%}), avg {avg_accept:.1f}/round") return [s['generated'] for s in seq_states] @torch.no_grad() def benchmark_batched_throughput( self, prompts: List[str], tokenizer, max_new_tokens: int = 100, warmup_tokens: int = 20, ) -> Dict[str, float]: """Benchmark batched speculative decoding throughput. Args: prompts: List of prompt strings tokenizer: Tokenizer with encode() method max_new_tokens: Tokens to generate per sequence warmup_tokens: Warmup tokens before timing Returns: Dict with 'total_tokens', 'elapsed_s', 'throughput_tok_s', 'per_sequence_tok_s' """ import time # Tokenize prompts input_ids_list = [] for prompt in prompts: ids = tokenizer.encode(prompt, return_tensors='pt') if hasattr(ids, 'to'): ids = ids.to(next(self.parameters()).device) input_ids_list.append(ids) B = len(prompts) # Warmup run if warmup_tokens > 0: warmup_list = [ids.clone() for ids in input_ids_list] _ = self.speculative_generate_batched( warmup_list, max_new_tokens=warmup_tokens) torch.cuda.synchronize() # Timed run torch.cuda.synchronize() start = time.perf_counter() results = self.speculative_generate_batched( input_ids_list, max_new_tokens=max_new_tokens) torch.cuda.synchronize() elapsed = time.perf_counter() - start # Count generated tokens total_tokens = sum( r.shape[1] - ids.shape[1] for r, ids in zip(results, input_ids_list) ) throughput = total_tokens / elapsed per_seq = throughput / B print(f" [Batched Throughput] B={B}, {total_tokens} tokens in {elapsed:.2f}s") print(f" Total: {throughput:.1f} tok/s, Per-sequence: {per_seq:.1f} tok/s") return { 'batch_size': B, 'total_tokens': total_tokens, 'elapsed_s': elapsed, 'throughput_tok_s': throughput, 'per_sequence_tok_s': per_seq, } def get_hebbian_entropy_loss(self) -> torch.Tensor: """Get neuro-modulator entropy loss for training loop integration (NHL Eq. 10). Returns a differentiable scalar that, when added to the main loss and backpropagated, trains the learned neuro-modulator to make decisive gating decisions (outputs near 0 or 1). Usage in training loop: logits = engine(input_ids, reward=r) loss = F.cross_entropy(logits, targets) loss = loss + engine.get_hebbian_entropy_loss() loss.backward() """ if self.hebbian is not None and hasattr(self.hebbian, 'get_modulator_entropy_loss'): return self.hebbian.get_modulator_entropy_loss() return torch.tensor(0.0) def get_hebbian_separation_loss(self) -> torch.Tensor: """Get pattern separation loss for training loop integration. Penalizes high cosine similarity between active memory slots, encouraging diverse, non-overlapping representations (Surget & Belzung 2022 — hippocampal neurogenesis pattern separation). Usage in training loop: loss = ce_loss + engine.get_hebbian_entropy_loss() loss = loss + engine.get_hebbian_separation_loss() loss.backward() """ if self.hebbian is not None and hasattr(self.hebbian, 'get_pattern_separation_loss'): return self.hebbian.get_pattern_separation_loss() return torch.tensor(0.0) def inject_kg_signal(self, score: float) -> None: """Inject KG consistency score into Hebbian memory for consolidation gating.""" if self.hebbian is not None: self.hebbian.inject_kg_signal(score) def compute_intrinsic_reward(self) -> Dict[str, float]: """Delegate intrinsic motivation computation to Hebbian memory.""" if self.hebbian is not None and hasattr(self.hebbian, 'compute_intrinsic_reward'): return self.hebbian.compute_intrinsic_reward() return {'curiosity': 0.0, 'competence': 0.0, 'combined': 0.0} @torch.no_grad() def perceive(self, input_ids: torch.Tensor) -> Dict[str, float]: """SPAR perceive: embed input and assess familiarity via Hebbian memory.""" if self.hebbian is None: return {'familiarity': 0.0, 'complexity': 0.0} x = self.embed(input_ids) x_mean = x.float().mean(dim=(0, 1)) # [D] return self.hebbian.perceive(x_mean) def reflect(self, loss: float, reward: float) -> Dict[str, Any]: """SPAR reflect: delegate self-monitoring to Hebbian memory.""" if self.hebbian is not None and hasattr(self.hebbian, 'reflect'): return self.hebbian.reflect(loss, reward) return {'prediction_error': 0.0, 'error_detected': False, 'consecutive_errors': 0, 'recovery_action': 'none'} def save_hebbian_state(self, path: Optional[str] = None) -> Dict[str, Any]: """Save Hebbian memory state. Optionally write to disk for cross-session persistence.""" if self.hebbian is None: return {} state = self.hebbian.save_state() if path is not None: torch.save(state, path) print(f" [FireEcho] Hebbian state saved to {path}") return state def load_hebbian_state(self, state_or_path) -> None: """Restore Hebbian memory from dict or disk path.""" if self.hebbian is None: return if isinstance(state_or_path, str): state = torch.load(state_or_path, map_location='cpu', weights_only=False) # Move tensors to current device device = next(self.parameters()).device for k, v in state.items(): if isinstance(v, dict): for kk, vv in v.items(): if isinstance(vv, torch.Tensor): v[kk] = vv.to(device) elif isinstance(v, torch.Tensor): state[k] = v.to(device) print(f" [FireEcho] Hebbian state loaded from {state_or_path}") else: state = state_or_path self.hebbian.load_state(state) # ================================================================= # Persistent Memory — AGI-like memory that never forgets # ================================================================= def enable_persistent_memory(self, save_dir: Optional[str] = None, auto_save_interval: int = 100, reflection_interval: int = 50): """Enable disk-backed persistent memory wrapping HebbianMemory. Creates a PersistentMemoryManager that adds: - Episodic log: timestamped experience records - Semantic journal: human-readable insights (JSON) - Reflection engine: periodic self-review + autonomous notes - Auto-save: Hebbian state persists across sessions Args: save_dir: Directory for memory files (default: engine dir / persistent_memory) auto_save_interval: Auto-save every N episodes reflection_interval: Run reflection every N episodes """ from persistent_memory import PersistentMemoryManager if save_dir is None: save_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), 'persistent_memory') self._persistent_memory = PersistentMemoryManager( hebbian=self.hebbian, save_dir=save_dir, auto_save_interval=auto_save_interval, reflection_interval=reflection_interval, ) # Load existing state from disk self._persistent_memory.load() print(f"[FireEcho] Persistent memory enabled: {save_dir}") def save_persistent_memory(self): """Manually save all persistent memory state to disk.""" if self._persistent_memory is not None: self._persistent_memory.save() def load_persistent_memory(self, save_dir: str): """Load persistent memory from a specific directory.""" if self._persistent_memory is None: self.enable_persistent_memory(save_dir=save_dir) else: self._persistent_memory.save_dir = save_dir self._persistent_memory.load() def log_episode(self, prompt_ids: torch.Tensor, gen_len: int, embedding: torch.Tensor, reward: float = 0.0, acceptance_rate: float = -1.0, tags: Optional[List[str]] = None) -> Optional[Dict]: """Log a completed generation as an episode in persistent memory.""" if self._persistent_memory is None: return None return self._persistent_memory.log_episode( prompt_ids=prompt_ids, gen_len=gen_len, embedding=embedding, reward=reward, acceptance_rate=acceptance_rate, tags=tags) def get_memory_journal(self, category: Optional[str] = None, min_confidence: float = 0.0) -> List[Dict]: """Return human-readable semantic insights from persistent memory.""" if self._persistent_memory is None: return [] return self._persistent_memory.get_journal(category, min_confidence) def recall_similar(self, query_embedding: torch.Tensor, top_k: int = 5) -> List[Dict]: """Find most similar past episodes in persistent memory.""" if self._persistent_memory is None: return [] return self._persistent_memory.recall(query_embedding, top_k) def get_persistent_memory_stats(self) -> Dict: """Get persistent memory statistics.""" if self._persistent_memory is None: return {'enabled': False} stats = self._persistent_memory.get_stats() stats['enabled'] = True return stats def memory_usage(self) -> Dict[str, Any]: """Return memory usage statistics.""" model_bytes = sum(p.numel() * p.element_size() for p in self.parameters()) return { 'model_gb': model_bytes / 1e9, 'model_params': sum(p.numel() for p in self.parameters()) / 1e6, 'kv_cache': self.kv_cache.stats() if self.kv_cache else {}, 'total_allocated_gb': torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0, 'total_reserved_gb': torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0, } @staticmethod def _extract_rope_theta(hf_config) -> float: """Extract rope_theta from HF config, checking rope_parameters dict. Some models store rope_theta inside a rope_parameters dict rather than as a top-level attribute. """ # 1. Top-level attribute theta = getattr(hf_config, 'rope_theta', None) if theta is not None: return float(theta) # 2. rope_parameters dict rope_params = getattr(hf_config, 'rope_parameters', None) if isinstance(rope_params, dict) and 'rope_theta' in rope_params: return float(rope_params['rope_theta']) # 3. Fallback return 10000.0 @classmethod def _load_qwen3_streaming(cls, config: 'FireEchoConfig', st_files: List[str], dtype: torch.dtype, device: str) -> 'FireEchoEngine': """Stream-load Qwen3-Omni thinker weights with minimal RAM/VRAM usage. Layer-by-layer pipeline: 1. Build engine skeleton (embed, norm, lm_head) in target dtype (~1.3GB) 2. For each of 48 layers: construct in BF16 → load weights from shards → quantize to FP4 on GPU → free CPU copy 3. Peak CPU RAM: ~7GB (skeleton + 1 layer + 1 shard) Peak VRAM: ~18GB (48 FP4 layers + BF16 embed/lm_head) """ from safetensors.torch import load_file import gc, re, os, json as _json # --- Phase 1: Build shard index (key → shard file mapping) --- index_path = os.path.join(os.path.dirname(st_files[0]), 'model.safetensors.index.json') key_to_shard: Dict[str, str] = {} shard_dir = os.path.dirname(st_files[0]) if os.path.exists(index_path): with open(index_path) as f: idx = _json.load(f) key_to_shard = {k: os.path.join(shard_dir, v) for k, v in idx.get('weight_map', {}).items()} print(f" [Qwen3 Streaming] Loaded shard index: " f"{len(key_to_shard)} keys across {len(st_files)} shards") else: print(f" [Qwen3 Streaming] No index file, will scan all shards") tp = 'thinker.model' # HF thinker prefix # --- Phase 2: Build layer → HF keys mapping --- # Group keys by layer index for targeted shard loading layer_hf_keys: Dict[int, List[str]] = {i: [] for i in range(config.num_layers)} global_keys = [] # embed, norm, lm_head _direct_suffix_map = { 'input_layernorm.weight': 'norm1.weight', 'post_attention_layernorm.weight': 'norm2.weight', 'self_attn.q_proj.weight': 'attn.q_proj.weight', 'self_attn.k_proj.weight': 'attn.k_proj.weight', 'self_attn.v_proj.weight': 'attn.v_proj.weight', 'self_attn.o_proj.weight': 'attn.out_proj.weight', 'self_attn.q_norm.weight': 'attn.q_norm.weight', 'self_attn.k_norm.weight': 'attn.k_norm.weight', 'mlp.gate.weight': 'ffn.router.gate.weight', } def _map_hf_key(hf_key): """Map HF thinker key to FireEcho key. Returns None if unmapped.""" if hf_key == f'{tp}.embed_tokens.weight': return 'embed.weight' if hf_key == f'{tp}.norm.weight': return 'norm.weight' if hf_key == 'thinker.lm_head.weight': return 'lm_head.weight' m = re.match(r'thinker\.model\.layers\.(\d+)\.(.*)', hf_key) if not m: return None li, suffix = m.group(1), m.group(2) fe = _direct_suffix_map.get(suffix) if fe: return f'layers.{li}.{fe}' em = re.match( r'mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight', suffix) if em: e_idx, proj = em.group(1), em.group(2) if proj == 'gate_proj': # Fused gate+up: gate goes to top half (:0) return f'layers.{li}.ffn.experts.{e_idx}.gate_up_proj.weight:0' elif proj == 'up_proj': # Fused gate+up: up goes to bottom half (:1) return f'layers.{li}.ffn.experts.{e_idx}.gate_up_proj.weight:1' else: return f'layers.{li}.ffn.experts.{e_idx}.down_proj.weight' return None # Categorize all indexed keys for hf_key in (key_to_shard.keys() if key_to_shard else []): if not hf_key.startswith('thinker.'): continue fe_key = _map_hf_key(hf_key) if fe_key is None: continue m = re.match(r'layers\.(\d+)\.', fe_key) if m: layer_hf_keys[int(m.group(1))].append(hf_key) else: global_keys.append(hf_key) # --- Phase 3: Build engine skeleton (no transformer layers yet) --- # Enable Goliath FP4 quantization for the streaming path — layers are # built with QuantizedLinear (use_nvfp4+use_goliath=True), loaded in # BF16, then quantized to packed FP4 on GPU one at a time. Without # this, full BF16 layers go to GPU and OOM at ~layer 20. # Goliath FP4 uses packed uint8 (2 values/byte) = half the VRAM of # legacy NV FP4 (int8). config.use_nvfp4 = True config.quantize_weights = True config.use_goliath = True # Hebbian is created AFTER loading (skeleton builds with num_layers=0, # PerLayerHebbian would get 0 memories). Set flag for post-load init. _want_hebbian = True config.use_hebbian = False print(f" [Qwen3 Streaming] Building engine skeleton...") old_dtype = torch.get_default_dtype() torch.set_default_dtype(dtype) try: # Temporarily set num_layers=0 to build skeleton only # (avoids allocating 30B+ params in one shot) orig_num_layers = config.num_layers config.num_layers = 0 config.critical_layers = [] engine = cls(config) config.num_layers = orig_num_layers if config.auto_critical_layers: config.critical_layers = config.compute_critical_layers() finally: torch.set_default_dtype(old_dtype) # Load global weights (embed, norm, lm_head) _loaded_shards_cache: Dict[str, Dict[str, torch.Tensor]] = {} n_assigned = 0 def _load_tensor(hf_key): """Load a single tensor from the correct shard (with caching).""" shard_path = key_to_shard.get(hf_key) if shard_path is None: # Fallback: scan all shards for sp in st_files: shard = load_file(sp, device='cpu') if hf_key in shard: t = shard[hf_key] del shard return t del shard return None if shard_path not in _loaded_shards_cache: _loaded_shards_cache[shard_path] = load_file(shard_path, device='cpu') return _loaded_shards_cache[shard_path].get(hf_key) def _set_param(obj, dotted_path, tensor): """Navigate dotted path and assign tensor to parameter.""" parts = dotted_path.split('.') for part in parts[:-1]: if part.isdigit(): obj = obj[int(part)] else: obj = getattr(obj, part) leaf = parts[-1] target = getattr(obj, leaf) if isinstance(target, nn.Parameter): target.data.copy_(tensor.to(dtype)) else: setattr(obj, leaf, tensor.to(dtype)) for hf_key in global_keys: fe_key = _map_hf_key(hf_key) t = _load_tensor(hf_key) if t is not None and fe_key is not None: _set_param(engine, fe_key, t) n_assigned += 1 _loaded_shards_cache.clear() gc.collect() # Move global params to GPU use_gpu = 'cuda' in str(device) and torch.cuda.is_available() if use_gpu: engine.embed = engine.embed.cuda() engine.norm = engine.norm.cuda() if hasattr(engine, 'lm_head') and engine.lm_head is not None: engine.lm_head = engine.lm_head.cuda() vram = torch.cuda.memory_allocated() / 1e9 print(f" [Qwen3 Streaming] Global params on GPU: {vram:.1f} GB") # --- Phase 4: Build + load + quantize each layer one at a time --- engine.layers = nn.ModuleList() # Reset to empty for li in range(config.num_layers): # Construct ONE layer in BF16 on CPU torch.set_default_dtype(dtype) try: layer = FusedTransformerBlock(config, li) finally: torch.set_default_dtype(old_dtype) # Load weights for this layer from shards layer_keys = layer_hf_keys.get(li, []) if not layer_keys and key_to_shard: # Fallback: scan index for this layer prefix = f'{tp}.layers.{li}.' layer_keys = [k for k in key_to_shard if k.startswith(prefix)] # Group by shard to minimize shard loads shard_groups: Dict[str, List[str]] = {} for hf_key in layer_keys: sp = key_to_shard.get(hf_key, '') shard_groups.setdefault(sp, []).append(hf_key) layer_assigned = 0 for shard_path, keys in shard_groups.items(): if shard_path and os.path.exists(shard_path): shard = load_file(shard_path, device='cpu') for hf_key in keys: fe_key = _map_hf_key(hf_key) if fe_key is None: continue t = shard.get(hf_key) if t is None: continue # Strip 'layers.{li}.' prefix for layer-relative path layer_rel = fe_key[len(f'layers.{li}.'):] try: if ':' in layer_rel: # Fused gate_up_proj half-write (:0=gate, :1=up) path, half_str = layer_rel.rsplit(':', 1) half_idx = int(half_str) # Navigate to fused weight parameter obj = layer parts = path.split('.') for part in parts[:-1]: obj = obj[int(part)] if part.isdigit() else getattr(obj, part) param = getattr(obj, parts[-1]) mid = param.shape[0] // 2 if half_idx == 0: param.data[:mid].copy_(t.to(dtype)) else: param.data[mid:].copy_(t.to(dtype)) else: _set_param(layer, layer_rel, t) layer_assigned += 1 n_assigned += 1 except (AttributeError, IndexError, RuntimeError): pass del shard gc.collect() # Quantize and move to GPU if use_gpu and config.use_nvfp4: # Disable residual correction — saves ~3x VRAM per expert for module in layer.modules(): if hasattr(module, 'compute_residual'): module.compute_residual = False # Module-by-module: move weight to GPU, quantize, free BF16 for module in layer.modules(): if (hasattr(module, '_quantize_weights') and hasattr(module, '_quantized') and not module._quantized): module.weight.data = module.weight.data.cuda() module._quantize_weights() if module._quantized and module._goliath_weights is not None: module.weight = nn.Parameter( torch.empty(0, device='cuda', dtype=dtype), requires_grad=False) # Move remaining (norms, router, biases, buffers) to GPU layer = layer.cuda() elif use_gpu: layer = layer.cuda() engine.layers.append(layer) if (li + 1) % 4 == 0 or li == config.num_layers - 1: gc.collect() if use_gpu: torch.cuda.empty_cache() vram = torch.cuda.memory_allocated() / 1e9 else: vram = 0 ram_gb = sum(p.numel() * p.element_size() for p in engine.parameters()) / 1e9 print(f" Layer {li+1}/{config.num_layers}: " f"{layer_assigned} weights, " f"VRAM {vram:.1f} GB, CPU {ram_gb:.1f} GB") # Move remaining named children (hebbian memory, etc.) if use_gpu: for name, module in engine.named_children(): if name not in ('embed', 'norm', 'lm_head', 'layers'): try: module.cuda() except RuntimeError: pass gc.collect() torch.cuda.empty_cache() vram = torch.cuda.memory_allocated() / 1e9 print(f" [Qwen3 Streaming] Final VRAM: {vram:.1f} GB (FP4 quantized)") elif device != 'cpu': engine = engine.to(device) # Reinitialize KV cache with correct num_layers (was 0 during skeleton) cache_device = 'cuda' if use_gpu else 'cpu' engine._init_kv_cache(cache_device) # Create Hebbian memory now that num_layers is correct if _want_hebbian: config.use_hebbian = True _dim = config.dim _mem_size = getattr(config, 'hebbian_memory_size', 512) _lr = getattr(config, 'hebbian_lr', 0.01) _decay = getattr(config, 'hebbian_decay', 0.999) if getattr(config, 'hebbian_per_layer', True): engine.hebbian = PerLayerHebbian( _dim, config.num_layers, _mem_size, _lr, _decay, use_layer_specialization=getattr( config, 'hebbian_use_layer_specialization', True), use_femx=getattr(config, 'hebbian_use_femx', False), ).to(device=cache_device, dtype=dtype) else: engine.hebbian = HebbianMemory( _dim, _mem_size, _lr, _decay, use_femx=getattr(config, 'hebbian_use_femx', False), ).to(device=cache_device, dtype=dtype) total_params = sum(p.numel() for p in engine.parameters()) / 1e6 print(f" [Qwen3 Streaming] Done: {total_params:.1f}M params, " f"{n_assigned} weights loaded") return engine @classmethod def from_pretrained(cls, model_name: str, config: Optional[FireEchoConfig] = None, dtype: torch.dtype = torch.bfloat16, device: str = 'cuda') -> 'FireEchoEngine': """ Load weights from a local model directory or HuggingFace model. Supports Molmo2-O-7B (AI2 / OLMo-style) and compatible architectures. Args: model_name: Local model path or HuggingFace model name config: Optional FireEchoConfig (auto-detected from model config if None) dtype: Model dtype device: Device to load to Returns: FireEchoEngine with loaded weights """ try: from transformers import AutoConfig except ImportError: raise ImportError("transformers required: pip install transformers") print(f"[FireEcho] Loading {model_name}...") # Load HF config — fall back to raw JSON if AutoConfig fails # (e.g., Qwen3-Omni has a bug in transformers config class) try: hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) except (AttributeError, TypeError, OSError, ValueError) as e: print(f" [FireEcho] AutoConfig failed ({e}), loading config.json directly") import types, json as _json, os as _os _cfg_path = _os.path.join(model_name, 'config.json') if _os.path.isdir(model_name) else None if _cfg_path and _os.path.exists(_cfg_path): with open(_cfg_path) as _f: _raw = _json.load(_f) hf_config = types.SimpleNamespace(**_raw) else: raise # --- Detect architecture from HF config (before loading weights) --- is_qwen3_omni = (getattr(hf_config, 'model_type', '') == 'qwen3_omni_moe' or 'Qwen3Omni' in str(getattr(hf_config, 'architectures', []))) # --- Load weights from safetensors / bin files --- import os import glob as _glob # Support local paths and HuggingFace Hub if os.path.isdir(model_name): model_dir = model_name else: from huggingface_hub import snapshot_download model_dir = snapshot_download(model_name) st_files = sorted(_glob.glob(os.path.join(model_dir, '*.safetensors'))) hf_state: Dict[str, torch.Tensor] = {} _qwen3_streaming = False if st_files: from safetensors.torch import load_file if is_qwen3_omni: # Qwen3: defer weight loading — will stream from shards later # This avoids holding 45GB+ in CPU RAM alongside the engine _qwen3_streaming = True print(f" Qwen3-Omni: will stream-load from {len(st_files)} shards") else: for f in st_files: hf_state.update(load_file(f)) else: # Fallback to .bin bin_files = sorted(_glob.glob(os.path.join(model_dir, '*.bin'))) for f in bin_files: hf_state.update(torch.load(f, map_location='cpu', weights_only=True)) if not hf_state and not _qwen3_streaming: raise RuntimeError(f"No weight files found in {model_dir}") # Detect architecture: Phi-4 (LoRA .base_layer.) vs OLMo/Molmo2 vs Llama-style is_phi4 = not is_qwen3_omni and any('.base_layer.' in k for k in hf_state) is_olmo = not is_phi4 and not is_qwen3_omni and any(k.startswith('model.transformer.blocks.') for k in hf_state) # Detect attention biases from LLM transformer layers (not vision/audio towers) if is_qwen3_omni: # Qwen3 thinker has audio_tower with biases — only check model.layers has_attn_bias = any( 'model.layers.' in k and (k.endswith('.self_attn.q_proj.bias') or k.endswith('.self_attn.att_proj.bias')) for k in hf_state ) else: has_attn_bias = any(k.endswith('.self_attn.q_proj.bias') or k.endswith('.self_attn.att_proj.bias') for k in hf_state) # Build FireEchoConfig from model config hf_rope_theta = cls._extract_rope_theta(hf_config) if config is None: if is_qwen3_omni: # Qwen3-Omni MoE thinker: extract text_config from thinker_config # Handle both AutoConfig objects and SimpleNamespace (raw JSON fallback) tc = getattr(hf_config, 'thinker_config', hf_config) if isinstance(tc, dict): import types as _types tc = _types.SimpleNamespace(**tc) text_cfg = getattr(tc, 'text_config', tc) if isinstance(text_cfg, dict): import types as _types text_cfg = _types.SimpleNamespace(**text_cfg) config = FireEchoConfig( dim=text_cfg.hidden_size, num_heads=text_cfg.num_attention_heads, num_kv_heads=getattr(text_cfg, 'num_key_value_heads', text_cfg.num_attention_heads), head_dim=getattr(text_cfg, 'head_dim', None), num_layers=text_cfg.num_hidden_layers, vocab_size=text_cfg.vocab_size, intermediate_size=text_cfg.intermediate_size, max_seq_len=min(getattr(text_cfg, 'max_position_embeddings', 65536), 131072), rope_theta=getattr(text_cfg, 'rope_theta', 1000000.0), partial_rotary_factor=1.0, tie_word_embeddings=getattr(text_cfg, 'tie_word_embeddings', False), norm_after=False, use_qk_norm=getattr(text_cfg, 'use_qk_norm', True), qk_norm_per_head=True, # Qwen3 always uses per-head QK norm use_fused_qkv=False, # Qwen3 has separate Q/K/V use_moe=True, num_experts=getattr(text_cfg, 'num_experts', 128), num_experts_per_tok=getattr(text_cfg, 'num_experts_per_tok', 8), moe_intermediate_size=getattr(text_cfg, 'moe_intermediate_size', 768), shared_expert_intermediate_size=getattr(text_cfg, 'shared_expert_intermediate_size', 0), norm_topk_prob=getattr(text_cfg, 'norm_topk_prob', True), # KV cache: 4k tokens (256 blocks × 16) — 30B MoE leaves # ~10GB for cache on 32GB GPU. 4k × 48L × 4H × 128D × 2 = 3.1GB. # Default 8192 blocks = 6GB → OOM after 20.3GB model load. max_kv_blocks=256, kv_block_size=16, use_nvfp4=False, use_goliath=False, use_hebbian=False, ) elif is_phi4: # Phi-4-multimodal-instruct: pre-norm, GQA, partial RoPE, tied embeddings config = FireEchoConfig( dim=hf_config.hidden_size, num_heads=hf_config.num_attention_heads, num_kv_heads=getattr(hf_config, 'num_key_value_heads', hf_config.num_attention_heads), num_layers=hf_config.num_hidden_layers, vocab_size=hf_config.vocab_size, intermediate_size=hf_config.intermediate_size, max_seq_len=min(getattr(hf_config, 'max_position_embeddings', 131072), 131072), rope_theta=getattr(hf_config, 'rope_theta', 10000.0), partial_rotary_factor=getattr(hf_config, 'partial_rotary_factor', 1.0), attn_bias=getattr(hf_config, 'attention_bias', False), tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings', True), norm_after=False, use_nvfp4=False, use_goliath=False, use_hebbian=False, ) elif is_olmo: # Molmo2 / OLMo: config may have nested text_config text_cfg = getattr(hf_config, 'text_config', hf_config) # Determine actual vocab size from loaded embeddings (config's # additional_vocab_size=128 is often wrong — actual new_embedding # may only have 10 rows) _base_emb = hf_state.get('model.transformer.wte.embedding') _new_emb = hf_state.get('model.transformer.wte.new_embedding') if _base_emb is not None: _actual_vocab = _base_emb.shape[0] + (_new_emb.shape[0] if _new_emb is not None else 0) else: _actual_vocab = text_cfg.vocab_size + (getattr(text_cfg, 'additional_vocab_size', 0) or 0) config = FireEchoConfig( dim=text_cfg.hidden_size, num_heads=text_cfg.num_attention_heads, num_kv_heads=getattr(text_cfg, 'num_key_value_heads', text_cfg.num_attention_heads), num_layers=text_cfg.num_hidden_layers, vocab_size=_actual_vocab, intermediate_size=text_cfg.intermediate_size, max_seq_len=min(getattr(text_cfg, 'max_position_embeddings', 65536), 131072), rope_theta=getattr(text_cfg, 'rope_theta', hf_rope_theta), attn_bias=getattr(text_cfg, 'qkv_bias', has_attn_bias), tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings', False), use_qk_norm=getattr(text_cfg, 'use_qk_norm', False), norm_after=getattr(text_cfg, 'norm_after', False), use_nvfp4=False, use_goliath=False, use_hebbian=False, ) else: config = FireEchoConfig( dim=hf_config.hidden_size, num_heads=hf_config.num_attention_heads, num_kv_heads=getattr(hf_config, 'num_key_value_heads', hf_config.num_attention_heads), num_layers=hf_config.num_hidden_layers, vocab_size=hf_config.vocab_size, intermediate_size=hf_config.intermediate_size, max_seq_len=min(getattr(hf_config, 'max_position_embeddings', 4096), 131072), rope_theta=hf_rope_theta, attn_bias=has_attn_bias, tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings', False), use_nvfp4=False, use_goliath=False, use_hebbian=False, ) else: # Caller provided config — sync rope_theta if caller used default if config.rope_theta == 10000.0 and hf_rope_theta != 10000.0: print(f" [FireEcho] Overriding rope_theta: {config.rope_theta} -> {hf_rope_theta} (from HF config)") config.rope_theta = hf_rope_theta if has_attn_bias and not config.attn_bias: config.attn_bias = True # Sync partial_rotary_factor for Phi-4 if is_phi4: phi_prf = getattr(hf_config, 'partial_rotary_factor', None) if phi_prf is not None and config.partial_rotary_factor == 1.0: print(f" [FireEcho] Overriding partial_rotary_factor: 1.0 -> {phi_prf} (from HF config)") config.partial_rotary_factor = phi_prf if getattr(hf_config, 'tie_word_embeddings', False) and not config.tie_word_embeddings: config.tie_word_embeddings = True # Sync QK norm and norm_after for OLMo models if is_olmo: text_cfg = getattr(hf_config, 'text_config', hf_config) if getattr(text_cfg, 'use_qk_norm', False) and not config.use_qk_norm: config.use_qk_norm = True if getattr(text_cfg, 'norm_after', False) and not config.norm_after: print(" [FireEcho] Enabling norm_after=True (post-norm, from HF config)") config.norm_after = True # Sync head_dim for Qwen3 (head_dim=128, not dim//num_heads=64) if is_qwen3_omni: tc = getattr(hf_config, 'thinker_config', hf_config) if isinstance(tc, dict): import types as _types tc = _types.SimpleNamespace(**tc) text_cfg = getattr(tc, 'text_config', tc) if isinstance(text_cfg, dict): import types as _types text_cfg = _types.SimpleNamespace(**text_cfg) hf_head_dim = getattr(text_cfg, 'head_dim', None) if hf_head_dim is not None and config.head_dim != hf_head_dim: print(f" [FireEcho] Overriding head_dim: {config.head_dim} -> {hf_head_dim} (from HF config)") config.head_dim = hf_head_dim # Auto-correct vocab_size from actual loaded embedding weights # (Molmo2 config claims additional_vocab_size=128 but actual new_embedding # may only have 10 rows — mismatch causes uninitialized tokens and garbage output) if is_olmo: _base_emb = hf_state.get('model.transformer.wte.embedding') _new_emb = hf_state.get('model.transformer.wte.new_embedding') if _base_emb is not None: _actual_vocab = _base_emb.shape[0] + (_new_emb.shape[0] if _new_emb is not None else 0) if _actual_vocab != config.vocab_size: print(f" [FireEcho] Correcting vocab_size: {config.vocab_size} -> {_actual_vocab} (from loaded embeddings)") config.vocab_size = _actual_vocab # Auto-detect critical layers for quantization if not config.critical_layers and config.auto_critical_layers: config.critical_layers = config.compute_critical_layers() # --- Qwen3-Omni streaming load (memory-efficient) --- # Constructs engine in bf16, streams shards one at a time, layer-by-layer GPU quantize. # Peak RAM: ~66GB (engine 61GB + 1 shard ~5GB) instead of ~167GB (float32 engine + all weights). if _qwen3_streaming: return cls._load_qwen3_streaming(config, st_files, dtype, device) engine = cls(config) # --- Build weight mapping and load --- mapped_state: Dict[str, torch.Tensor] = {} if is_qwen3_omni: # === Qwen3-Omni MoE thinker weight mapping === # Keys prefixed with 'thinker.' — separate Q/K/V, MoE experts, QK norm print(f" Architecture: Qwen3-Omni MoE thinker " f"({config.num_experts} experts/layer, top-{config.num_experts_per_tok})") tp = 'thinker.model' # thinker prefix # Embeddings emb_w = hf_state.get(f'{tp}.embed_tokens.weight') if emb_w is not None: mapped_state['embed.weight'] = emb_w # Final norm norm_w = hf_state.get(f'{tp}.norm.weight') if norm_w is not None: mapped_state['norm.weight'] = norm_w # LM head (not tied for Qwen3) lm_w = hf_state.get('thinker.lm_head.weight') if lm_w is not None: mapped_state['lm_head.weight'] = lm_w # Per-layer weights for i in range(config.num_layers): hp = f'{tp}.layers.{i}' # HF prefix fp = f'layers.{i}' # FireEcho prefix # Layer norms in_ln = hf_state.get(f'{hp}.input_layernorm.weight') if in_ln is not None: mapped_state[f'{fp}.norm1.weight'] = in_ln post_ln = hf_state.get(f'{hp}.post_attention_layernorm.weight') if post_ln is not None: mapped_state[f'{fp}.norm2.weight'] = post_ln # Attention: separate Q/K/V projections mapped_state[f'{fp}.attn.q_proj.weight'] = hf_state[f'{hp}.self_attn.q_proj.weight'] mapped_state[f'{fp}.attn.k_proj.weight'] = hf_state[f'{hp}.self_attn.k_proj.weight'] mapped_state[f'{fp}.attn.v_proj.weight'] = hf_state[f'{hp}.self_attn.v_proj.weight'] mapped_state[f'{fp}.attn.out_proj.weight'] = hf_state[f'{hp}.self_attn.o_proj.weight'] # QK norm (per-head: weight shape [head_dim]) q_norm_w = hf_state.get(f'{hp}.self_attn.q_norm.weight') if q_norm_w is not None: mapped_state[f'{fp}.attn.q_norm.weight'] = q_norm_w k_norm_w = hf_state.get(f'{hp}.self_attn.k_norm.weight') if k_norm_w is not None: mapped_state[f'{fp}.attn.k_norm.weight'] = k_norm_w # MoE router gate_w = hf_state.get(f'{hp}.mlp.gate.weight') if gate_w is not None: mapped_state[f'{fp}.ffn.router.gate.weight'] = gate_w # 128 experts per layer for e in range(config.num_experts): for proj in ('gate_proj', 'up_proj', 'down_proj'): src_key = f'{hp}.mlp.experts.{e}.{proj}.weight' dst_key = f'{fp}.ffn.experts.{e}.{proj}.weight' w = hf_state.get(src_key) if w is not None: mapped_state[dst_key] = w # Count mapped n_thinker_total = len(hf_state) n_mapped = len(mapped_state) print(f" Mapped {n_mapped}/{n_thinker_total} thinker weights") # Skipped non-thinker components (already filtered during loading) print(f" Skipped talker, code2wav, vision, audio encoders (text-only mode)") elif is_phi4: # === Phi-4-multimodal-instruct weight mapping === # Keys have `.base_layer.` due to built-in LoRA; merge vision LoRA when use_vision=True. modes = [] if config.use_vision: modes.append('vision') if config.use_audio: modes.append('audio') modes.append('text') if config.use_vision or config.use_audio: modes.append('LoRA') mode_str = '+'.join(modes) print(f" Architecture: Phi-4-multimodal (fused QKV + fused gate_up, {mode_str} mode)") # Embeddings emb_w = hf_state.get('model.embed_tokens.weight') if emb_w is not None: mapped_state['embed.weight'] = emb_w # Final norm norm_w = hf_state.get('model.norm.weight') if norm_w is not None: mapped_state['norm.weight'] = norm_w # LM head: tie_word_embeddings=True → no separate lm_head weight if not config.tie_word_embeddings and 'lm_head.weight' in hf_state: mapped_state['lm_head.weight'] = hf_state['lm_head.weight'] # Per-layer weights with fused QKV and gate_up splitting head_dim = config.dim // config.num_heads q_dim = config.num_heads * head_dim k_dim = config.num_kv_heads * head_dim v_dim = config.num_kv_heads * head_dim # LoRA merge: W_merged = W_base + (alpha/r) * B @ A # Vision and speech LoRAs are additive on the same base weights vision_lora_scale = config.vision_lora_alpha / config.vision_lora_r # 2.0 speech_lora_scale = config.speech_lora_alpha / config.speech_lora_r # 2.0 lora_merged = 0 speech_lora_merged = 0 for i in range(config.num_layers): ph = f'model.layers.{i}' pf = f'layers.{i}' # Norms in_ln = hf_state.get(f'{ph}.input_layernorm.weight') if in_ln is not None: mapped_state[f'{pf}.norm1.weight'] = in_ln post_ln = hf_state.get(f'{ph}.post_attention_layernorm.weight') if post_ln is not None: mapped_state[f'{pf}.norm2.weight'] = post_ln # Fused QKV → merge vision+speech LoRA → split into q_proj, k_proj, v_proj fused_qkv = hf_state.get(f'{ph}.self_attn.qkv_proj.base_layer.weight') if fused_qkv is not None: if config.use_vision: la = hf_state.get(f'{ph}.self_attn.qkv_proj.lora_A.vision.weight') lb = hf_state.get(f'{ph}.self_attn.qkv_proj.lora_B.vision.weight') if la is not None and lb is not None: fused_qkv = fused_qkv.float() + vision_lora_scale * (lb.float() @ la.float()) fused_qkv = fused_qkv.to(dtype) lora_merged += 2 if config.use_audio: la = hf_state.get(f'{ph}.self_attn.qkv_proj.lora_A.speech.weight') lb = hf_state.get(f'{ph}.self_attn.qkv_proj.lora_B.speech.weight') if la is not None and lb is not None: fused_qkv = fused_qkv.float() + speech_lora_scale * (lb.float() @ la.float()) fused_qkv = fused_qkv.to(dtype) speech_lora_merged += 2 assert fused_qkv.shape[0] == q_dim + k_dim + v_dim, \ f"QKV shape mismatch layer {i}: expected {q_dim+k_dim+v_dim}, got {fused_qkv.shape[0]}" q_w, k_w, v_w = fused_qkv.split([q_dim, k_dim, v_dim], dim=0) mapped_state[f'{pf}.attn.q_proj.weight'] = q_w mapped_state[f'{pf}.attn.k_proj.weight'] = k_w mapped_state[f'{pf}.attn.v_proj.weight'] = v_w # Output projection → merge vision+speech LoRA o_w = hf_state.get(f'{ph}.self_attn.o_proj.base_layer.weight') if o_w is not None: if config.use_vision: la = hf_state.get(f'{ph}.self_attn.o_proj.lora_A.vision.weight') lb = hf_state.get(f'{ph}.self_attn.o_proj.lora_B.vision.weight') if la is not None and lb is not None: o_w = o_w.float() + vision_lora_scale * (lb.float() @ la.float()) o_w = o_w.to(dtype) lora_merged += 2 if config.use_audio: la = hf_state.get(f'{ph}.self_attn.o_proj.lora_A.speech.weight') lb = hf_state.get(f'{ph}.self_attn.o_proj.lora_B.speech.weight') if la is not None and lb is not None: o_w = o_w.float() + speech_lora_scale * (lb.float() @ la.float()) o_w = o_w.to(dtype) speech_lora_merged += 2 mapped_state[f'{pf}.attn.out_proj.weight'] = o_w # Fused gate_up_proj → merge vision+speech LoRA → split fused_gu = hf_state.get(f'{ph}.mlp.gate_up_proj.base_layer.weight') if fused_gu is not None: if config.use_vision: la = hf_state.get(f'{ph}.mlp.gate_up_proj.lora_A.vision.weight') lb = hf_state.get(f'{ph}.mlp.gate_up_proj.lora_B.vision.weight') if la is not None and lb is not None: fused_gu = fused_gu.float() + vision_lora_scale * (lb.float() @ la.float()) fused_gu = fused_gu.to(dtype) lora_merged += 2 if config.use_audio: la = hf_state.get(f'{ph}.mlp.gate_up_proj.lora_A.speech.weight') lb = hf_state.get(f'{ph}.mlp.gate_up_proj.lora_B.speech.weight') if la is not None and lb is not None: fused_gu = fused_gu.float() + speech_lora_scale * (lb.float() @ la.float()) fused_gu = fused_gu.to(dtype) speech_lora_merged += 2 gate_w, up_w = fused_gu.chunk(2, dim=0) mapped_state[f'{pf}.ffn.gate_proj.weight'] = gate_w mapped_state[f'{pf}.ffn.up_proj.weight'] = up_w # Down projection → merge vision+speech LoRA down_w = hf_state.get(f'{ph}.mlp.down_proj.base_layer.weight') if down_w is not None: if config.use_vision: la = hf_state.get(f'{ph}.mlp.down_proj.lora_A.vision.weight') lb = hf_state.get(f'{ph}.mlp.down_proj.lora_B.vision.weight') if la is not None and lb is not None: down_w = down_w.float() + vision_lora_scale * (lb.float() @ la.float()) down_w = down_w.to(dtype) lora_merged += 2 if config.use_audio: la = hf_state.get(f'{ph}.mlp.down_proj.lora_A.speech.weight') lb = hf_state.get(f'{ph}.mlp.down_proj.lora_B.speech.weight') if la is not None and lb is not None: down_w = down_w.float() + speech_lora_scale * (lb.float() @ la.float()) down_w = down_w.to(dtype) speech_lora_merged += 2 mapped_state[f'{pf}.ffn.down_proj.weight'] = down_w if lora_merged > 0: print(f" Merged {lora_merged} vision LoRA weights " f"(r={config.vision_lora_r}, alpha={config.vision_lora_alpha}, " f"scale={vision_lora_scale:.1f})") if speech_lora_merged > 0: print(f" Merged {speech_lora_merged} speech LoRA weights " f"(r={config.speech_lora_r}, alpha={config.speech_lora_alpha}, " f"scale={speech_lora_scale:.1f})") # Vision encoder weights (when use_vision=True) vision_mapped = 0 if config.use_vision: vision_prefix = 'model.embed_tokens_extend.image_embed.' vision_map = { # SigLIP embeddings 'img_processor.embeddings.patch_embedding.weight': 'multimodal.vision_encoder.embeddings.patch_embedding.weight', 'img_processor.embeddings.patch_embedding.bias': 'multimodal.vision_encoder.embeddings.patch_embedding.bias', 'img_processor.embeddings.position_embedding.weight': 'multimodal.vision_encoder.embeddings.position_embedding.weight', # Post-layernorm 'img_processor.post_layernorm.weight': 'multimodal.vision_encoder.post_layernorm.weight', 'img_processor.post_layernorm.bias': 'multimodal.vision_encoder.post_layernorm.bias', # Learnable separators 'glb_GN': 'multimodal.vision_encoder.glb_GN', 'sub_GN': 'multimodal.vision_encoder.sub_GN', } # Static mappings for src_suffix, dst_key in vision_map.items(): src_key = vision_prefix + src_suffix if src_key in hf_state: mapped_state[dst_key] = hf_state[src_key] vision_mapped += 1 # Per-layer SigLIP encoder weights (27 layers) for i in range(config.vision_num_layers): src_layer = f'{vision_prefix}img_processor.encoder.layers.{i}' dst_layer = f'multimodal.vision_encoder.encoder_layers.{i}' layer_parts = { 'self_attn.q_proj.weight': 'self_attn.q_proj.weight', 'self_attn.q_proj.bias': 'self_attn.q_proj.bias', 'self_attn.k_proj.weight': 'self_attn.k_proj.weight', 'self_attn.k_proj.bias': 'self_attn.k_proj.bias', 'self_attn.v_proj.weight': 'self_attn.v_proj.weight', 'self_attn.v_proj.bias': 'self_attn.v_proj.bias', 'self_attn.out_proj.weight': 'self_attn.out_proj.weight', 'self_attn.out_proj.bias': 'self_attn.out_proj.bias', 'layer_norm1.weight': 'layer_norm1.weight', 'layer_norm1.bias': 'layer_norm1.bias', 'layer_norm2.weight': 'layer_norm2.weight', 'layer_norm2.bias': 'layer_norm2.bias', 'mlp.fc1.weight': 'mlp.fc1.weight', 'mlp.fc1.bias': 'mlp.fc1.bias', 'mlp.fc2.weight': 'mlp.fc2.weight', 'mlp.fc2.bias': 'mlp.fc2.bias', } for src_part, dst_part in layer_parts.items(): src_key = f'{src_layer}.{src_part}' if src_key in hf_state: mapped_state[f'{dst_layer}.{dst_part}'] = hf_state[src_key] vision_mapped += 1 # MLP projection (img_projection) proj_keys = [ ('img_projection.0.weight', 'multimodal.vision_encoder.projection.0.weight'), ('img_projection.0.bias', 'multimodal.vision_encoder.projection.0.bias'), ('img_projection.2.weight', 'multimodal.vision_encoder.projection.2.weight'), ('img_projection.2.bias', 'multimodal.vision_encoder.projection.2.bias'), ] for src_suffix, dst_key in proj_keys: src_key = vision_prefix + src_suffix if src_key in hf_state: mapped_state[dst_key] = hf_state[src_key] vision_mapped += 1 print(f" Mapped {vision_mapped} vision encoder weights (SigLIP)") # Strict verification: check for missing/unexpected keys if hasattr(engine, 'multimodal') and hasattr(engine.multimodal, 'vision_encoder'): vision_keys_in_mapped = {k for k in mapped_state if k.startswith('multimodal.vision_encoder.')} expected_keys = {f'multimodal.vision_encoder.{k}' for k, _ in engine.multimodal.vision_encoder.named_parameters()} missing = expected_keys - vision_keys_in_mapped unexpected = vision_keys_in_mapped - expected_keys if missing: print(f" WARNING: {len(missing)} missing vision weights: " f"{list(missing)[:5]}{'...' if len(missing) > 5 else ''}") if unexpected: print(f" WARNING: {len(unexpected)} unexpected vision weights: " f"{list(unexpected)[:5]}{'...' if len(unexpected) > 5 else ''}") if not missing and not unexpected: print(f" Vision weight verification: OK (all {len(expected_keys)} params matched)") # Audio encoder weights (when use_audio=True) audio_mapped = 0 if config.use_audio: audio_prefix = 'model.embed_tokens_extend.audio_embed.' # Static mappings (conv embedding, norm buffers, rel bias, projection) audio_static_map = { 'encoder.encoder_embedding.global_mean': 'multimodal.audio_encoder.encoder.encoder_embedding.global_mean', 'encoder.encoder_embedding.global_invstd': 'multimodal.audio_encoder.encoder.encoder_embedding.global_invstd', 'encoder.relative_attention_bias_layer.bias_values.weight': 'multimodal.audio_encoder.encoder.relative_attention_bias_layer.bias_values.weight', 'audio_projection.speech.0.weight': 'multimodal.audio_encoder.audio_projection.speech.0.weight', 'audio_projection.speech.0.bias': 'multimodal.audio_encoder.audio_projection.speech.0.bias', 'audio_projection.speech.2.weight': 'multimodal.audio_encoder.audio_projection.speech.2.weight', 'audio_projection.speech.2.bias': 'multimodal.audio_encoder.audio_projection.speech.2.bias', } # Conv subsampling: embed.conv.{0,2,3,5,6} + embed.out for idx in [0, 2, 3, 5, 6]: for suffix in ['weight', 'bias']: k = f'encoder.embed.conv.{idx}.{suffix}' audio_static_map[k] = f'multimodal.audio_encoder.{k}' for suffix in ['weight', 'bias']: k = f'encoder.embed.out.{suffix}' audio_static_map[k] = f'multimodal.audio_encoder.{k}' for src_suffix, dst_key in audio_static_map.items(): src_key = audio_prefix + src_suffix if src_key in hf_state: mapped_state[dst_key] = hf_state[src_key] audio_mapped += 1 # Per-layer Conformer encoder weights (24 layers, 36 tensors each) audio_layer_parts = [ 'self_attn.linear_q.weight', 'self_attn.linear_q.bias', 'self_attn.linear_k.weight', 'self_attn.linear_k.bias', 'self_attn.linear_v.weight', 'self_attn.linear_v.bias', 'self_attn.linear_out.weight', 'self_attn.linear_out.bias', 'feed_forward_in.layer_norm.weight', 'feed_forward_in.layer_norm.bias', 'feed_forward_in.net.0.linear.weight', 'feed_forward_in.net.0.linear.bias', 'feed_forward_in.net.2.weight', 'feed_forward_in.net.2.bias', 'feed_forward_out.layer_norm.weight', 'feed_forward_out.layer_norm.bias', 'feed_forward_out.net.0.linear.weight', 'feed_forward_out.net.0.linear.bias', 'feed_forward_out.net.2.weight', 'feed_forward_out.net.2.bias', 'conv.layer_norm.weight', 'conv.layer_norm.bias', 'conv.glu.ext_pw_conv_1d.weight', 'conv.glu.ext_pw_conv_1d.bias', 'conv.glu.b1', 'conv.glu.b2', 'conv.dw_sep_conv_1d.dw_conv.weight', 'conv.dw_sep_conv_1d.dw_conv.bias', 'conv.dw_sep_conv_1d.pw_conv.weight', 'conv.dw_sep_conv_1d.pw_conv.bias', 'conv.ext_pw_conv_1d.weight', 'conv.ext_pw_conv_1d.bias', 'layer_norm_att.weight', 'layer_norm_att.bias', 'layer_norm.weight', 'layer_norm.bias', ] for i in range(config.audio_num_layers): src_layer = f'{audio_prefix}encoder.encoders.{i}' dst_layer = f'multimodal.audio_encoder.encoder.encoders.{i}' for part in audio_layer_parts: src_key = f'{src_layer}.{part}' if src_key in hf_state: mapped_state[f'{dst_layer}.{part}'] = hf_state[src_key] audio_mapped += 1 print(f" Mapped {audio_mapped} audio encoder weights (Conformer)") # Verification if hasattr(engine, 'multimodal') and hasattr(engine.multimodal, 'audio_encoder'): audio_keys_in_mapped = {k for k in mapped_state if k.startswith('multimodal.audio_encoder.')} expected_keys = set() for k, _ in engine.multimodal.audio_encoder.named_parameters(): expected_keys.add(f'multimodal.audio_encoder.{k}') for k, _ in engine.multimodal.audio_encoder.named_buffers(): expected_keys.add(f'multimodal.audio_encoder.{k}') missing = expected_keys - audio_keys_in_mapped unexpected = audio_keys_in_mapped - expected_keys if missing: print(f" WARNING: {len(missing)} missing audio weights: " f"{sorted(list(missing))[:5]}{'...' if len(missing) > 5 else ''}") if unexpected: print(f" WARNING: {len(unexpected)} unexpected audio weights: " f"{sorted(list(unexpected))[:5]}") if not missing and not unexpected: print(f" Audio weight verification: OK (all {len(expected_keys)} params+buffers matched)") # Report skipped weights all_lora = [k for k in hf_state if 'lora_A.' in k or 'lora_B.' in k] merged_modalities = set() if config.use_vision and lora_merged > 0: merged_modalities.add('vision') if config.use_audio and speech_lora_merged > 0: merged_modalities.add('speech') skipped_lora = sum(1 for k in all_lora if not any(f'.{m}.' in k for m in merged_modalities)) if skipped_lora > 0: print(f" Skipped {skipped_lora} LoRA adapter weights" f" ({'unmerged modalities' if merged_modalities else 'base mode'})") # Count skipped multimodal embedding weights skipped_mm = sum(1 for k in hf_state if 'embed_tokens_extend' in k and not (config.use_vision and '.image_embed.' in k and 'lora_' not in k) and not (config.use_audio and '.audio_embed.' in k and 'lora_' not in k and 'audio_projection.vision.' not in k)) if skipped_mm > 0: print(f" Skipped {skipped_mm} multimodal embedding weights") elif is_olmo: # === OLMo / Molmo2 weight mapping (fused QKV/MLP splitting) === print(" Architecture: OLMo/Molmo2 (fused QKV + fused MLP)") # Embedding: concatenate wte.embedding + wte.new_embedding base_emb = hf_state.get('model.transformer.wte.embedding') new_emb = hf_state.get('model.transformer.wte.new_embedding') if base_emb is not None: if new_emb is not None: mapped_state['embed.weight'] = torch.cat([base_emb, new_emb], dim=0) else: mapped_state['embed.weight'] = base_emb # Final norm ln_f = hf_state.get('model.transformer.ln_f.weight') if ln_f is not None: mapped_state['norm.weight'] = ln_f # LM head — pad to full vocab_size if additional embeddings exist # Use new_embedding weights as initial lm_head rows for added tokens # (better than zeros — mirrors the input embedding for those tokens) if not config.tie_word_embeddings and 'lm_head.weight' in hf_state: lm_w = hf_state['lm_head.weight'] if lm_w.shape[0] < config.vocab_size: n_pad = config.vocab_size - lm_w.shape[0] if new_emb is not None and new_emb.shape[0] == n_pad: pad = new_emb.to(dtype=lm_w.dtype, device=lm_w.device) else: pad = torch.zeros(n_pad, lm_w.shape[1], dtype=lm_w.dtype, device=lm_w.device) lm_w = torch.cat([lm_w, pad], dim=0) mapped_state['lm_head.weight'] = lm_w # Per-layer weights with fused weight splitting head_dim = config.dim // config.num_heads q_dim = config.num_heads * head_dim k_dim = config.num_kv_heads * head_dim v_dim = config.num_kv_heads * head_dim for i in range(config.num_layers): oh = f'model.transformer.blocks.{i}' pf = f'layers.{i}' # Norms attn_norm = hf_state.get(f'{oh}.attn_norm.weight') if attn_norm is not None: mapped_state[f'{pf}.norm1.weight'] = attn_norm ff_norm = hf_state.get(f'{oh}.ff_norm.weight') if ff_norm is not None: mapped_state[f'{pf}.norm2.weight'] = ff_norm # Fused QKV -> split into q_proj, k_proj, v_proj fused_qkv = hf_state.get(f'{oh}.self_attn.att_proj.weight') if fused_qkv is not None: q_w, k_w, v_w = fused_qkv.split([q_dim, k_dim, v_dim], dim=0) mapped_state[f'{pf}.attn.q_proj.weight'] = q_w mapped_state[f'{pf}.attn.k_proj.weight'] = k_w mapped_state[f'{pf}.attn.v_proj.weight'] = v_w # Output projection attn_out = hf_state.get(f'{oh}.self_attn.attn_out.weight') if attn_out is not None: mapped_state[f'{pf}.attn.out_proj.weight'] = attn_out # QK norms q_norm_w = hf_state.get(f'{oh}.self_attn.q_norm.weight') if q_norm_w is not None: mapped_state[f'{pf}.attn.q_norm.weight'] = q_norm_w k_norm_w = hf_state.get(f'{oh}.self_attn.k_norm.weight') if k_norm_w is not None: mapped_state[f'{pf}.attn.k_norm.weight'] = k_norm_w # Fused MLP -> split ff_proj into gate_proj and up_proj # Molmo2: x, gate = ff_proj(x).chunk(2, dim=-1); act(gate) * x # First half of weight -> up (multiplied), second half -> gate (activated) fused_ff = hf_state.get(f'{oh}.mlp.ff_proj.weight') if fused_ff is not None: up_w, gate_w = fused_ff.chunk(2, dim=0) mapped_state[f'{pf}.ffn.gate_proj.weight'] = gate_w mapped_state[f'{pf}.ffn.up_proj.weight'] = up_w # Down projection ff_out = hf_state.get(f'{oh}.mlp.ff_out.weight') if ff_out is not None: mapped_state[f'{pf}.ffn.down_proj.weight'] = ff_out # Report skipped vision weights skipped_vision = sum(1 for k in hf_state if 'vision_backbone' in k) if skipped_vision > 0: print(f" Skipped {skipped_vision} vision backbone weights (text-only mode)") else: # === Standard Llama-style key mapping === key_map: Dict[str, str] = { 'model.embed_tokens.weight': 'embed.weight', 'model.norm.weight': 'norm.weight', } if not config.tie_word_embeddings: key_map['lm_head.weight'] = 'lm_head.weight' for i in range(config.num_layers): ph = f'model.layers.{i}' pf = f'layers.{i}' key_map.update({ f'{ph}.input_layernorm.weight': f'{pf}.norm1.weight', f'{ph}.post_attention_layernorm.weight': f'{pf}.norm2.weight', f'{ph}.self_attn.q_proj.weight': f'{pf}.attn.q_proj.weight', f'{ph}.self_attn.k_proj.weight': f'{pf}.attn.k_proj.weight', f'{ph}.self_attn.v_proj.weight': f'{pf}.attn.v_proj.weight', f'{ph}.self_attn.o_proj.weight': f'{pf}.attn.out_proj.weight', f'{ph}.mlp.gate_proj.weight': f'{pf}.ffn.gate_proj.weight', f'{ph}.mlp.up_proj.weight': f'{pf}.ffn.up_proj.weight', f'{ph}.mlp.down_proj.weight': f'{pf}.ffn.down_proj.weight', }) if config.attn_bias: key_map.update({ f'{ph}.self_attn.q_proj.bias': f'{pf}.attn.q_proj.bias', f'{ph}.self_attn.k_proj.bias': f'{pf}.attn.k_proj.bias', f'{ph}.self_attn.v_proj.bias': f'{pf}.attn.v_proj.bias', }) for hf_key, fe_key in key_map.items(): if hf_key in hf_state: mapped_state[fe_key] = hf_state[hf_key] # --- Load mapped weights --- missing, unexpected = engine.load_state_dict(mapped_state, strict=False) loaded = len(mapped_state) - len([k for k in mapped_state if k in missing]) print(f" Config: {config.dim}d, {config.num_layers}L, " f"{config.num_heads}H, kv_heads={config.num_kv_heads}") print(f" Mapped {loaded}/{len(mapped_state)} weight tensors") if missing: # Filter out expected missing keys (RoPE buffers, Hebbian, etc.) real_missing = [k for k in missing if 'rope_' not in k and 'kv_cache' not in k and 'hebbian' not in k and 'expert_usage' not in k # MoE tracking buffer and 'expert_tier' not in k # FE-MX tier buffer and not (config.tie_word_embeddings and k == 'lm_head.weight')] if real_missing: print(f" Missing (non-buffer): {real_missing[:10]}") # Move to target dtype and device engine = engine.to(dtype) # Estimate BF16 model size to decide loading strategy param_bytes = sum(p.numel() * p.element_size() for p in engine.parameters()) param_gb = param_bytes / 1e9 if 'cuda' in str(device) and torch.cuda.is_available(): gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9 if param_gb > gpu_mem * 0.85 and config.use_nvfp4: # Layer-by-layer quantized loading for large models # Move each layer to GPU, quantize QuantizedLinear weights, free originals import gc print(f" Large model ({param_gb:.1f}GB BF16 > {gpu_mem:.0f}GB VRAM), " f"using layer-by-layer quantized loading") # Non-layer params first (small: embed, norm, lm_head) engine.embed = engine.embed.cuda() engine.norm = engine.norm.cuda() if hasattr(engine, 'lm_head'): engine.lm_head = engine.lm_head.cuda() for i, layer in enumerate(engine.layers): layer = layer.cuda() # Quantize all QuantizedLinear in this layer immediately for module in layer.modules(): if hasattr(module, '_quantize_weights') and hasattr(module, '_quantized'): if not module._quantized: module._quantize_weights() # Free the original FP32/BF16 weight after quantization if module._quantized and hasattr(module, '_goliath_weights'): module.weight.data = torch.empty(0, device='cuda') engine.layers[i] = layer if (i + 1) % 8 == 0: gc.collect() torch.cuda.empty_cache() vram = torch.cuda.memory_allocated() / 1e9 print(f" Layer {i+1}/{config.num_layers}: {vram:.1f} GB VRAM") # Move remaining modules for name, module in engine.named_children(): if name not in ('embed', 'norm', 'lm_head', 'layers'): module.cuda() gc.collect() torch.cuda.empty_cache() vram = torch.cuda.memory_allocated() / 1e9 print(f" Final VRAM: {vram:.1f} GB (FP4 quantized)") else: engine = engine.cuda() else: engine = engine.to(device) total_params = sum(p.numel() for p in engine.parameters()) / 1e6 print(f" Total params: {total_params:.1f}M on {device} ({dtype})") return engine def _forward_hf(self, input_ids: torch.Tensor) -> torch.Tensor: """Forward using HF model (when loaded from pretrained).""" if hasattr(self, '_use_hf') and self._use_hf: return self._hf_model(input_ids).logits return self.forward(input_ids) # ============================================================================ # UTILITY FUNCTIONS # ============================================================================ def create_small_engine(dim: int = 768, num_layers: int = 12, num_heads: int = 12, use_nvfp4: bool = True) -> FireEchoEngine: """Create a small engine for testing.""" config = FireEchoConfig( dim=dim, num_heads=num_heads, num_kv_heads=num_heads, num_layers=num_layers, vocab_size=32000, intermediate_size=dim * 4, max_seq_len=4096, max_kv_blocks=256, # Supports 4k tokens use_nvfp4=use_nvfp4, quantize_weights=use_nvfp4, use_hebbian=True, hebbian_per_layer=False, ) return FireEchoEngine(config) def create_7b_engine(quantize: bool = True, long_context: bool = False) -> FireEchoEngine: """Create 7B-scale engine (Molmo2-O-7B config).""" config = FireEchoConfig.molmo2_7b() if not long_context: config.max_seq_len = 4096 # Reduce for quick testing config.use_nvfp4 = quantize config.quantize_weights = quantize config.use_hebbian = True return FireEchoEngine(config) def create_multimodal_engine() -> FireEchoEngine: """Create engine with vision and audio support.""" config = FireEchoConfig( dim=4096, num_heads=32, num_kv_heads=8, num_layers=32, vocab_size=32000, intermediate_size=14336, max_seq_len=32768, use_vision=True, use_audio=True, use_hebbian=True, ) return FireEchoEngine(config) # ============================================================================ # BENCHMARK # ============================================================================ def benchmark_engine(): """Comprehensive benchmark of the fused engine.""" print("\n" + "=" * 70) print("FIREECHO KERNEL v2 - FUSED ENGINE BENCHMARK") print("=" * 70) if not torch.cuda.is_available(): print("CUDA not available!") return props = torch.cuda.get_device_properties(0) print(f"\nGPU: {props.name}") print(f"VRAM: {props.total_memory / 1e9:.1f} GB") # Test 1: Small engine print("\n" + "-" * 50) print("TEST 1: Small Engine (768d, 12L, 12H)") print("-" * 50) engine = create_small_engine().cuda() mem = engine.memory_usage() print(f"Model: {mem['model_params']:.1f}M params, {mem['model_gb']:.3f} GB") configs = [(1, 128), (1, 512), (1, 1024), (1, 2048), (4, 512)] print(f"\n{'Config':>15} | {'Time':>10} | {'Tok/s':>12} | {'Cache':>10}") print("-" * 55) for batch, seq in configs: try: input_ids = torch.randint(0, 32000, (batch, seq), device='cuda') # Warmup (no KV cache for speed benchmark) for _ in range(3): _ = engine(input_ids, use_cache=False) torch.cuda.synchronize() # Benchmark start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() for _ in range(10): _ = engine(input_ids, use_cache=False) end.record() torch.cuda.synchronize() ms = start.elapsed_time(end) / 10 tokens = batch * seq tok_per_sec = tokens / (ms / 1000) print(f"B={batch}, S={seq:>4} | {ms:>8.2f}ms | {tok_per_sec:>10,.0f} | -") except Exception as e: print(f"B={batch}, S={seq:>4} | ERROR: {str(e)[:35]}") # Test 2: Memory usage print("\n" + "-" * 50) print("TEST 2: Memory Usage") print("-" * 50) mem = engine.memory_usage() print(f"Model: {mem['model_gb']:.3f} GB") print(f"GPU Allocated: {mem['total_allocated_gb']:.3f} GB") print(f"GPU Reserved: {mem['total_reserved_gb']:.3f} GB") # Test 3: Generation print("\n" + "-" * 50) print("TEST 3: Generation Speed") print("-" * 50) # Clear memory first gc.collect() torch.cuda.empty_cache() prompt = torch.randint(0, 32000, (1, 64), device='cuda') # Warmup (no cache for stability) _ = engine.generate(prompt, max_new_tokens=10, use_cache=False) torch.cuda.synchronize() # Benchmark start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() output = engine.generate(prompt, max_new_tokens=50, use_cache=False) end.record() torch.cuda.synchronize() gen_time = start.elapsed_time(end) gen_tokens = output.shape[1] - prompt.shape[1] gen_tok_per_sec = gen_tokens / (gen_time / 1000) print(f"Generated {gen_tokens} tokens in {gen_time:.1f}ms") print(f"Generation speed: {gen_tok_per_sec:.1f} tok/s") # Test 4: Hebbian stats if engine.hebbian is not None: print("\n" + "-" * 50) print("TEST 4: Hebbian Memory") print("-" * 50) stats = engine.hebbian.get_stats() print(f"Updates: {stats.get('update_count', 'N/A')}") print(f"Memory norm: {stats.get('memory_norm', 'N/A'):.4f}") # Test 5: 7B config (smaller version for quick test) print("\n" + "-" * 50) print("TEST 5: 7B-Scale Engine (4096d, 8L, 32H)") print("-" * 50) # Clean up small engine del engine gc.collect() torch.cuda.empty_cache() try: # Use reduced layers for faster test (8 instead of 32) config_7b = FireEchoConfig( dim=4096, num_heads=32, num_kv_heads=32, # MHA like Molmo2-O-7B num_layers=8, # Reduced for testing vocab_size=100406, intermediate_size=11008, max_seq_len=4096, max_kv_blocks=256, use_nvfp4=True, # Now memory-optimized quantize_weights=True, use_hebbian=False, ) engine_7b = FireEchoEngine(config_7b).cuda() mem_7b = engine_7b.memory_usage() print(f"Model: {mem_7b['model_params']:.1f}M params, {mem_7b['model_gb']:.3f} GB") print(f"NVFP4: Enabled (memory-optimized)") print(f"Goliath dispatch: {'available' if _GOLIATH_AVAILABLE else 'not loaded'}") if _GOLIATH_AVAILABLE and _can_use_goliath_dot_scaled is not None: print(f"Goliath dot_scaled (native FP4 TCs): {'ACTIVE' if _can_use_goliath_dot_scaled() else 'False (BF16 dequant path)'}") # Test 7B forward configs_7b = [(1, 512), (1, 2048)] print(f"\n{'Config':>15} | {'Time':>10} | {'Tok/s':>12}") print("-" * 45) for batch, seq in configs_7b: try: input_ids = torch.randint(0, 32000, (batch, seq), device='cuda') # Warmup for _ in range(2): _ = engine_7b(input_ids, use_cache=False) torch.cuda.synchronize() # Benchmark start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() for _ in range(5): _ = engine_7b(input_ids, use_cache=False) end.record() torch.cuda.synchronize() ms = start.elapsed_time(end) / 5 tok_per_sec = (batch * seq) / (ms / 1000) print(f"B={batch}, S={seq:>4} | {ms:>8.2f}ms | {tok_per_sec:>10,.0f}") except Exception as e: print(f"B={batch}, S={seq:>4} | ERROR: {str(e)[:30]}") # Clean up del engine_7b gc.collect() torch.cuda.empty_cache() except Exception as e: print(f"7B test failed: {str(e)[:50]}") # Final summary print("\n" + "=" * 70) print("SUMMARY") print("=" * 70) print("Features production-ready:") print(" - Hybrid MatMul (Triton/cuBLAS auto-dispatch)") print(" - NVFP4 Quantization (4-bit, memory-optimized)") print(" - Paged KV Cache (vLLM-style, batch ops)") print(" - Flash Attention (SDPA)") print(" - Hebbian Memory (decay + gated)") print(" - GQA Support (grouped-query attention)") print(" - Multimodal (Vision + Audio encoders)") print("=" * 70 + "\n") def benchmark_long_context(): """Benchmark long context generation.""" print("\n" + "=" * 70) print("LONG CONTEXT BENCHMARK (32k tokens)") print("=" * 70) if not torch.cuda.is_available(): return config = FireEchoConfig( dim=768, num_heads=12, num_kv_heads=12, num_layers=6, vocab_size=32000, intermediate_size=3072, max_seq_len=32768, max_kv_blocks=4096, use_nvfp4=False, use_hebbian=False, ) engine = FireEchoEngine(config).cuda() print(f"KV Cache capacity: {engine.kv_cache.capacity_tokens():,} tokens") # Test increasing context lengths for ctx_len in [1024, 4096, 8192, 16384, 32768]: try: gc.collect() torch.cuda.empty_cache() engine.reset_cache() input_ids = torch.randint(0, 32000, (1, ctx_len), device='cuda') start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() _ = engine(input_ids, use_cache=True) end.record() torch.cuda.synchronize() ms = start.elapsed_time(end) tok_per_sec = ctx_len / (ms / 1000) mem_gb = torch.cuda.memory_allocated() / 1e9 print(f" {ctx_len:>6} tokens: {ms:>8.1f}ms, {tok_per_sec:>8,.0f} tok/s, {mem_gb:.2f} GB") except Exception as e: print(f" {ctx_len:>6} tokens: ERROR - {str(e)[:40]}") break print("=" * 70 + "\n") def benchmark_fused_kernels(): """Benchmark the fused kernels. Reports actual runtime state of optional backends.""" if not torch.cuda.is_available(): print("CUDA not available") return # Report actual GPU and runtime state (no hardcoded claims) cap = torch.cuda.get_device_capability(0) gpu_name = torch.cuda.get_device_name(0) or "Unknown" print(f"\nGPU: {gpu_name}") print(f"Compute capability: {cap[0]}.{cap[1]}") try: print(f"Triton: loaded (version {getattr(triton, '__version__', 'unknown')})") except Exception: print("Triton: loaded") device = torch.device('cuda') print("\n" + "-" * 50) print("1. MATMUL PERFORMANCE (cuBLAS)") print("-" * 50) print(f"{'Config (M,K,N)':>20} | {'Time':>10} | {'TFLOPS':>10}") print("-" * 45) matmul_configs = [ (2048, 4096, 4096), (4096, 4096, 4096), (4096, 4096, 11008), (8192, 4096, 4096), ] for M, K, N in matmul_configs: a = torch.randn(M, K, device=device, dtype=torch.bfloat16) b = torch.randn(K, N, device=device, dtype=torch.bfloat16) # Warmup for _ in range(5): _ = torch.matmul(a, b) torch.cuda.synchronize() # Benchmark start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() for _ in range(20): _ = torch.matmul(a, b) end.record() torch.cuda.synchronize() ms = start.elapsed_time(end) / 20 flops = 2 * M * N * K tflops = flops / (ms * 1e9) print(f"({M},{K},{N}){' ':>6} | {ms:>8.2f}ms | {tflops:>8.1f}") print("\n" + "-" * 50) print("2. L2 CACHE MANAGER") print("-" * 50) cache_mgr = L2CacheManager(l2_size_mb=128.0, reserve_fraction=0.5) # Pin some test tensors test_weights = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16) pinned = cache_mgr.pin("test_weight", test_weights) print(f"Pinned 4096x4096 BF16 tensor: {pinned}") stats = cache_mgr.stats() print(f"HW pinning (cudaAccessPolicyWindow): {stats['hw_pinning']}") print(f"L2 Reserved: {stats['reserved_mb']:.1f} MB") print(f"L2 Pinned: {stats['pinned_mb']:.1f} MB") print(f"Utilization: {stats['utilization']*100:.1f}%") cache_mgr.clear() print("\n" + "=" * 70) print("OPTIONAL BACKENDS (runtime)") print("=" * 70) print(f" CUTLASS (TMA MatMul / TMA Attention / L2): {'available' if _CUTLASS_AVAILABLE else 'not loaded'}") print(f" Quantum Gold (quantum_optimized_matmul): {'available' if _QUANTUM_AVAILABLE else 'not loaded'}") print(f" DSMEM cluster (cluster_matmul_dsmem): {'available' if _DSMEM_AVAILABLE else 'not loaded'}") print(f" Goliath FP4/FP8 fused dequant-matmul: {'available' if _GOLIATH_AVAILABLE else 'not loaded'}") if _GOLIATH_AVAILABLE and _can_use_goliath_dot_scaled is not None: dot_scaled_ok = _can_use_goliath_dot_scaled() print(f" Goliath native dot_scaled (MXFP4 TCs): {'ACTIVE' if dot_scaled_ok else 'probe False (BF16 fallback)'}") else: print(f" Goliath native dot_scaled (MXFP4 TCs): not available") print("\nBuilt-in kernels (always present):") print(" - Hybrid MatMul (Triton / cuBLAS dispatch)") print(" - Fused QKV, Fused SwiGLU, Persistent GEMM, Split-K") print(" - NVFP4 Quantization, Paged KV Cache, Flash Attention (SDPA)") print(" - GQA, Multimodal, Hebbian Memory, L2 Cache Manager") print("=" * 70 + "\n") # ============================================================================ # MAIN # ============================================================================ if __name__ == "__main__": benchmark_fused_kernels() # Test new Phase 1 kernels benchmark_engine() # benchmark_long_context() # Uncomment for long context test