"""
ENGRAM Protocol — Model Architecture Registry

Contains ModelCacheSpec definitions for known models and utilities
to look up specs by model_id or infer model family from string.

D3: extraction_layers set to middle-to-deep (8-31 for 32-layer models)
per ShadowKV validation. Early layers (0-7) and final layer preserved.
"""
from __future__ import annotations
from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
# ── Pre-registered Model Specs ─────────────────────────────────────────────────
# Llama 3.1 8B — Primary Phase 1 target (D1, D6).
# GQA: 32 query heads, 8 KV heads, head_dim 128.
LLAMA_3_1_8B = ModelCacheSpec(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    model_family="llama",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),  # layers 8-31 (D3)
)
# Llama 3.1 8B base (non-instruct) — identical cache geometry to the
# Instruct variant above; only the model_id differs.
LLAMA_3_1_8B_BASE = ModelCacheSpec(
    model_id="meta-llama/Llama-3.1-8B",
    model_family="llama",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),  # layers 8-31 (D3)
)
# Phi-3-Mini-128K — Secondary Phase 1 target.
# ShadowKV validated SVD on this model (D3).
# MHA: 32 query heads, 32 KV heads (no GQA), head_dim 96.
PHI_3_MINI = ModelCacheSpec(
    model_id="microsoft/Phi-3-mini-128k-instruct",
    model_family="phi",
    n_layers=32,
    n_heads=32,
    n_kv_heads=32,  # Phi-3-Mini uses MHA, not GQA
    head_dim=96,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),  # layers 8-31 (D3)
)
# Gemma 2 2B — NOTE: QK-Norm model, SVD behavior may differ (T3 caveat).
GEMMA_2_2B = ModelCacheSpec(
    model_id="google/gemma-2-2b-it",
    model_family="gemma",
    n_layers=26,
    n_heads=8,
    n_kv_heads=4,
    head_dim=256,
    rope_enabled=True,
    extraction_layers=tuple(range(6, 26)),  # skip first ~quarter of 26 layers (D3)
)
# Qwen 2.5 7B — GQA: 28 query heads, 4 KV heads, head_dim 128.
QWEN_2_5_7B = ModelCacheSpec(
    model_id="Qwen/Qwen2.5-7B-Instruct",
    model_family="qwen",
    n_layers=28,
    n_heads=28,
    n_kv_heads=4,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(7, 28)),  # skip first quarter of 28 layers (D3)
)
# Mistral 7B v0.3 — GQA: 32 query heads, 8 KV heads, head_dim 128.
MISTRAL_7B = ModelCacheSpec(
    model_id="mistralai/Mistral-7B-Instruct-v0.3",
    model_family="mistral",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),  # layers 8-31 (D3)
)
# Gemma 4 26B-A4B — ISWA model (Interleaved Sliding Window Attention).
# Dual KV cache: Global (full context) + SWA (sliding window 1024 tokens).
# MoE: 128 experts, 8 active — does NOT affect KV cache (FFN-only).
# Reverse-engineered from llama.cpp b5200+ state blob format.
GEMMA_4_26B_A4B = ModelCacheSpec(
    model_id="google/gemma-4-26b-a4b-it",
    model_family="gemma",
    n_layers=30,  # total: 5 global + 25 SWA
    n_heads=32,
    n_kv_heads=8,  # dominant section (SWA)
    head_dim=256,  # dominant section (SWA)
    rope_enabled=True,
    extraction_layers=tuple(range(8, 30)),
    # Per-section cache geometry. The top-level n_kv_heads/head_dim above
    # mirror the dominant (SWA) section; consumers that need exact layout
    # must read cache_sections instead.
    cache_sections=(
        CacheSection(
            attention_type=AttentionType.FULL,  # 5 global full-context layers
            n_layers=5,
            n_kv_heads=2,
            head_dim=512,
        ),
        CacheSection(
            attention_type=AttentionType.SLIDING,  # 25 sliding-window layers
            n_layers=25,
            n_kv_heads=8,
            head_dim=256,
            window_size=1024,  # tokens retained per SWA layer
        ),
    ),
)
# ── Registry ───────────────────────────────────────────────────────────────────
# Maps exact model_id strings to their specs. Mutable at runtime via
# register_model_spec(). Item access (spec["model_id"]) suggests
# ModelCacheSpec is a mapping-like type (e.g. TypedDict) — defined in
# kvcos.core.types.
_REGISTRY: dict[str, ModelCacheSpec] = {
    spec["model_id"]: spec
    for spec in [
        LLAMA_3_1_8B,
        LLAMA_3_1_8B_BASE,
        PHI_3_MINI,
        GEMMA_2_2B,
        GEMMA_4_26B_A4B,
        QWEN_2_5_7B,
        MISTRAL_7B,
    ]
}
# Substring → family table used by infer_model_family(). Keys are matched
# case-insensitively, in insertion order, by *substring* containment. The
# org-qualified keys ("meta-llama", "microsoft/phi", "google/gemma") are
# shadowed by the bare family names that appear earlier, but both map to
# the same family, so the result is unchanged.
_FAMILY_MAP: dict[str, str] = {
    "llama": "llama",
    "meta-llama": "llama",
    "phi": "phi",
    "microsoft/phi": "phi",
    "gemma": "gemma",
    "google/gemma": "gemma",
    "qwen": "qwen",
    "mistral": "mistral",
    "deepseek": "deepseek",
}
def get_model_spec(model_id: str) -> ModelCacheSpec | None:
    """Return the registered ModelCacheSpec for *model_id*, or None.

    Lookup is an exact string match against the registry key; no family
    inference or normalization is performed here.
    """
    try:
        return _REGISTRY[model_id]
    except KeyError:
        return None
def register_model_spec(spec: ModelCacheSpec) -> None:
    """Add *spec* to the runtime registry, keyed by its model_id.

    An existing entry with the same model_id is overwritten.
    """
    key = spec["model_id"]
    _REGISTRY[key] = spec
def infer_model_family(model_id: str) -> str:
    """Infer the model family (e.g. "llama", "phi") from a model_id string.

    Performs a case-insensitive substring scan over _FAMILY_MAP in
    insertion order; returns "unknown" when no known prefix occurs.
    """
    needle = model_id.lower()
    return next(
        (family for prefix, family in _FAMILY_MAP.items() if prefix in needle),
        "unknown",
    )
def make_spec_from_metadata(
    model_id: str,
    n_layers: int,
    n_heads: int,
    n_kv_heads: int,
    head_dim: int,
    rope_enabled: bool = True,
) -> ModelCacheSpec:
    """Create a ModelCacheSpec from raw architecture parameters.

    Automatically sets extraction_layers to the middle-to-deep range (D3):
    roughly the first quarter of layers is skipped, with a minimum of one
    skipped layer. The model family is inferred from the id string.
    """
    # Skip ~25% of the shallowest layers, but never fewer than one.
    first_kept = max(1, n_layers // 4)
    return ModelCacheSpec(
        model_id=model_id,
        model_family=infer_model_family(model_id),
        n_layers=n_layers,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        head_dim=head_dim,
        rope_enabled=rope_enabled,
        extraction_layers=tuple(range(first_kept, n_layers)),
    )
def is_iswa_spec(spec: ModelCacheSpec) -> bool:
    """Return True when *spec* describes an ISWA (multi-section) cache.

    A spec is multi-section exactly when the optional "cache_sections"
    key was supplied at construction time.
    """
    return any(key == "cache_sections" for key in spec)
def validate_kv_shape(
    spec: ModelCacheSpec,
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
) -> bool:
    """Validate that KV tensor dimensions match the model spec.

    Compares only the top-level geometry fields. NOTE(review): for ISWA
    specs the top-level values mirror the dominant section, so per-section
    shapes are not checked here — confirm callers handle that case.
    """
    expected = (spec["n_layers"], spec["n_kv_heads"], spec["head_dim"])
    return expected == (n_layers, n_kv_heads, head_dim)