File size: 5,976 Bytes
0769ff3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
ENGRAM Protocol — Model Architecture Registry


Contains ModelCacheSpec definitions for known models and utilities
to look up specs by model_id or infer model family from string.

D3: extraction_layers set to middle-to-deep (8-31 for 32-layer models)
per ShadowKV validation. Early layers (0-7) and final layer preserved.
"""

from __future__ import annotations

from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec

# ── Pre-registered Model Specs ────────────────────────────────────────────────

# Llama 3.1 8B — Primary Phase 1 target (D1, D6)
# GQA: 32 query heads, 8 KV heads, head_dim 128
# extraction_layers covers layers 8-31, i.e. it skips the first quarter
# (layers 0-7) of the 32-layer stack.
LLAMA_3_1_8B = ModelCacheSpec(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    model_family="llama",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),  # layers 8-31 (D3)
)

# Llama 3.1 8B base (non-instruct)
# Every field except model_id is identical to LLAMA_3_1_8B; it is registered
# separately because registry lookup is by exact model_id.
LLAMA_3_1_8B_BASE = ModelCacheSpec(
    model_id="meta-llama/Llama-3.1-8B",
    model_family="llama",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),
)

# Phi-3-Mini-128K — Secondary Phase 1 target
# ShadowKV validated SVD on this model (D3)
# MHA: 32 query heads, 32 KV heads (no GQA), head_dim 96
# extraction_layers covers layers 8-31 of the 32-layer stack.
PHI_3_MINI = ModelCacheSpec(
    model_id="microsoft/Phi-3-mini-128k-instruct",
    model_family="phi",
    n_layers=32,
    n_heads=32,
    n_kv_heads=32,  # Phi-3-Mini uses MHA, not GQA
    head_dim=96,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),
)

# Gemma 2 2B — NOTE: QK-Norm model, SVD behavior may differ (T3 caveat)
# GQA: 8 query heads, 4 KV heads, head_dim 256
# extraction_layers covers layers 6-25 of the 26-layer stack.
GEMMA_2_2B = ModelCacheSpec(
    model_id="google/gemma-2-2b-it",
    model_family="gemma",
    n_layers=26,
    n_heads=8,
    n_kv_heads=4,
    head_dim=256,
    rope_enabled=True,
    extraction_layers=tuple(range(6, 26)),
)

# Qwen 2.5 7B
# GQA: 28 query heads, 4 KV heads, head_dim 128
# extraction_layers covers layers 7-27 of the 28-layer stack.
QWEN_2_5_7B = ModelCacheSpec(
    model_id="Qwen/Qwen2.5-7B-Instruct",
    model_family="qwen",
    n_layers=28,
    n_heads=28,
    n_kv_heads=4,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(7, 28)),
)

# Mistral 7B v0.3
# GQA: 32 query heads, 8 KV heads, head_dim 128
# extraction_layers covers layers 8-31 of the 32-layer stack.
MISTRAL_7B = ModelCacheSpec(
    model_id="mistralai/Mistral-7B-Instruct-v0.3",
    model_family="mistral",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),
)


# Gemma 4 26B-A4B — ISWA model (Interleaved Sliding Window Attention)
# Dual KV cache: Global (full context) + SWA (sliding window 1024 tokens)
# MoE: 128 experts, 8 active — does NOT affect KV cache (FFN-only)
# Reverse-engineered from llama.cpp b5200+ state blob format.
#
# The top-level n_kv_heads / head_dim mirror the SLIDING section; the FULL
# section uses a different geometry (2 KV heads, head_dim 512). The two
# sections' n_layers (5 + 25) sum to the top-level n_layers of 30.
# NOTE(review): extraction_layers starts at 8, whereas the n_layers // 4 rule
# in make_spec_from_metadata would start at 7 for 30 layers — confirm intended.
GEMMA_4_26B_A4B = ModelCacheSpec(
    model_id="google/gemma-4-26b-a4b-it",
    model_family="gemma",
    n_layers=30,   # total: 5 global + 25 SWA
    n_heads=32,
    n_kv_heads=8,  # dominant section (SWA)
    head_dim=256,  # dominant section (SWA)
    rope_enabled=True,
    extraction_layers=tuple(range(8, 30)),
    cache_sections=(
        CacheSection(
            attention_type=AttentionType.FULL,
            n_layers=5,
            n_kv_heads=2,
            head_dim=512,
        ),
        CacheSection(
            attention_type=AttentionType.SLIDING,
            n_layers=25,
            n_kv_heads=8,
            head_dim=256,
            window_size=1024,
        ),
    ),
)


# ── Registry ──────────────────────────────────────────────────────────────────

# Exact-match lookup table keyed by each spec's own "model_id" entry.
# Seeded with the pre-registered specs; register_model_spec() adds to it
# at runtime.
_REGISTRY: dict[str, ModelCacheSpec] = {
    builtin_spec["model_id"]: builtin_spec
    for builtin_spec in (
        LLAMA_3_1_8B,
        LLAMA_3_1_8B_BASE,
        PHI_3_MINI,
        GEMMA_2_2B,
        GEMMA_4_26B_A4B,
        QWEN_2_5_7B,
        MISTRAL_7B,
    )
}

# Substring -> family table used by infer_model_family(), which lowercases the
# model_id and returns the family of the FIRST key found anywhere inside it.
# Insertion order therefore acts as match priority.
# NOTE(review): matching is by substring, so short keys can hit unrelated
# names (e.g. "phi" occurs inside "dolphin") — confirm this is acceptable.
# NOTE(review): the longer keys ("meta-llama", "microsoft/phi", "google/gemma")
# each contain a shorter key mapping to the same family, so they look redundant.
_FAMILY_MAP: dict[str, str] = {
    "llama": "llama",
    "meta-llama": "llama",
    "phi": "phi",
    "microsoft/phi": "phi",
    "gemma": "gemma",
    "google/gemma": "gemma",
    "qwen": "qwen",
    "mistral": "mistral",
    "deepseek": "deepseek",
}


def get_model_spec(model_id: str) -> ModelCacheSpec | None:
    """Return the registered spec whose key equals ``model_id`` exactly.

    The id is used as-is (no case folding or fuzzy matching). When it is
    not a key of the registry, ``None`` is returned instead of raising.
    """
    registered = _REGISTRY.get(model_id, None)
    return registered


def register_model_spec(spec: ModelCacheSpec) -> None:
    """Add ``spec`` to the runtime registry under its own ``"model_id"``.

    An existing entry with the same id is overwritten.
    """
    registry_key = spec["model_id"]
    _REGISTRY[registry_key] = spec


def infer_model_family(model_id: str) -> str:
    """Guess the model family by substring-matching ``model_id``.

    The id is lowercased and each key of ``_FAMILY_MAP`` is tried in
    insertion order; the family of the first key contained in the id is
    returned. If no key matches, the literal string ``"unknown"`` is returned.
    """
    haystack = model_id.lower()
    # Lazy generator: evaluation stops at the first matching key.
    candidates = (
        family for needle, family in _FAMILY_MAP.items() if needle in haystack
    )
    return next(candidates, "unknown")


def make_spec_from_metadata(
    model_id: str,
    n_layers: int,
    n_heads: int,
    n_kv_heads: int,
    head_dim: int,
    rope_enabled: bool = True,
) -> ModelCacheSpec:
    """Build a ModelCacheSpec for a model that is not pre-registered.

    ``model_family`` is derived from ``model_id`` via ``infer_model_family``.
    ``extraction_layers`` is the middle-to-deep range (D3): it starts at
    ``n_layers // 4`` (but never below layer 1) and runs through the last
    layer, ``n_layers - 1``.
    """
    # Skip the first quarter of the stack, and always at least layer 0.
    first_extracted = n_layers // 4
    if first_extracted < 1:
        first_extracted = 1
    layers_to_extract = tuple(range(first_extracted, n_layers))

    family = infer_model_family(model_id)
    return ModelCacheSpec(
        model_id=model_id,
        model_family=family,
        n_layers=n_layers,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        head_dim=head_dim,
        rope_enabled=rope_enabled,
        extraction_layers=layers_to_extract,
    )


def is_iswa_spec(spec: ModelCacheSpec) -> bool:
    """Report whether ``spec`` declares an ISWA (multi-section) cache.

    Only the presence of the ``"cache_sections"`` key is checked; its value
    is not inspected.
    """
    declares_sections = "cache_sections" in spec
    return declares_sections


def validate_kv_shape(
    spec: ModelCacheSpec,
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
) -> bool:
    """Check observed KV tensor dimensions against a model spec.

    Returns True only when the spec's top-level ``"n_layers"``,
    ``"n_kv_heads"`` and ``"head_dim"`` entries all equal the corresponding
    arguments. Only these three top-level entries are compared.
    """
    observed = (
        ("n_layers", n_layers),
        ("n_kv_heads", n_kv_heads),
        ("head_dim", head_dim),
    )
    # The generator is lazy, so later fields are not read after a mismatch.
    return all(spec[field] == value for field, value in observed)