"""
Sovereign Hive HSAQ: KV-cache sensitivity profiler
==================================================

Sweeps each attention layer across a curated 11-config grid of
(k_bits, v_bits, quantizer) combinations and measures the resulting
attention-output drift on calibration data. Emits ProfileRow objects
ready for insertion into the `kv_sensitivity_profile` Vault table
(migration 003).
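
Worked example of the default drift metric (illustrative sketch):
  The mse_normalised metric implemented in compute_drift is the mean
  squared error divided by the mean squared baseline value. The function
  below is a hypothetical pure-Python stand-in that works on lists, not
  the torch code path, and the numbers are made up for illustration.

```python
def mse_normalised(quanted: list[float], baseline: list[float]) -> float:
    # Mean squared error divided by the mean squared baseline value.
    n = len(baseline)
    mse = sum((q - b) ** 2 for q, b in zip(quanted, baseline)) / n
    denom = sum(b * b for b in baseline) / n
    return mse / max(denom, 1e-8)


baseline = [1.0, 2.0, 3.0, 4.0]
quanted = [1.1, 1.9, 3.2, 3.8]
drift = mse_normalised(quanted, baseline)

# Scaling both tensors by the same factor leaves the metric unchanged.
scaled = mse_normalised([10 * q for q in quanted], [10 * b for b in baseline])
print(f"drift={drift:.4e} scaled={scaled:.4e}")
```

  The second call illustrates the scale invariance that makes this metric
  comparable across layers with different activation magnitudes.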

Why this file exists:
  The KV cache is a bigger lever than weight quantization for MHA
  models on 12 GB cards (e.g. OLMo: ~3.3 GB KV at 4K ctx fp16, more than
  the weights). The allocator needs measured drift values to pick which
  layers can tolerate aggressive KV quant. This profiler is what feeds it.
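
Worked example of the KV size estimate (illustrative sketch):
  The fp16 KV cache holds one K and one V vector per token, per layer.
  The helper name and the model shape below (40 layers, 40 KV heads,
  head_dim 128) are assumptions chosen to resemble a 13B MHA model; they
  are not read from any config in this module.

```python
def kv_cache_bytes_fp16(n_layers: int, num_kv_heads: int, head_dim: int,
                        seq_len: int, batch_size: int = 1) -> int:
    # K and V each hold num_kv_heads * head_dim fp16 values (2 bytes) per token.
    per_token_per_layer = 2 * num_kv_heads * head_dim * 2
    return per_token_per_layer * n_layers * seq_len * batch_size


# Assumed shape: 40 layers, 40 KV heads (MHA), head_dim 128, 4K context.
total = kv_cache_bytes_fp16(n_layers=40, num_kv_heads=40, head_dim=128,
                            seq_len=4096)
print(f"{total / 1e9:.2f} GB at fp16")
```

  Under these assumptions the result is about 3.36 GB for a single
  sequence, in line with the rough figure quoted above.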

Pure-logic conventions (matching candidate_record.py):
  - No Vault writes. Caller persists the emitted ProfileRows.
  - No PermissionGate routing here. Caller (the hunter agent) handles that.
  - All disk/network I/O is the caller's job. We accept a loaded model
    and tokenizer; we return data structures.

What this DOES touch:
  - GPU forward passes (calls `model(**batch)` under torch.no_grad).
  - kv_intercept.kv_quant_active context manager to install hooks.
  - Hook lifecycle (always cleaned up via context manager; verified in
    smoke_test_v3.py test #6).

What this does NOT do:
  - It does not handle inference-queue gating. If two profilers run
    concurrently on the same GPU, you'll OOM. Caller must serialize.
  - It does not write to disk. Output is in-memory ProfileRow objects.
  - It does not bump pipeline_version. The constant lives here as a
    reference value; callers wanting a different lineage pass their own.

Config sweep rationale (11 per layer):
  Full Cartesian (5 k_bits × 5 v_bits × 4 quantizers) = 100/layer is
  overkill: most combinations carry no signal the allocator can use.
  This curated 11 covers the decision boundaries that actually matter:

    Symmetric K=V (the common case):       (8,8), (4,4), (3,3), (2,2)  hqq_g64
    K-cheaper-than-V (K more tolerant):    (4,8), (3,8), (2,8), (2,4)  hqq_g64
    Quantizer comparison:                  (4,4) scaled_uniform
                                           (4,4) scaled_per_head
                                           (8,8) scaled_uniform
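
Worked example of per-config cost (illustrative sketch):
  To see what the hqq_g64 entries of the grid cost in memory, the helper
  below restates the hqq_g64 branch of kv_bytes_per_token from this module
  so the snippet runs on its own. The helper name and the shape
  (40 KV heads, head_dim 128) are assumptions for illustration.

```python
def hqq_bytes_per_token(num_kv_heads: int, head_dim: int, k_bits: int,
                        v_bits: int, group_size: int = 64) -> float:
    # Packed payload plus one fp16 scale and one fp16 zero (4 bytes) per
    # group along head_dim, counted separately for K and V.
    elems = num_kv_heads * head_dim
    groups_per_row = max(1, head_dim // group_size)
    overhead = num_kv_heads * groups_per_row * 4
    return (elems * k_bits / 8 + overhead) + (elems * v_bits / 8 + overhead)


# Assumed shape: 40 KV heads, head_dim 128. fp16 baseline: K + V at 2 bytes/elem.
fp16_bytes = 2 * 40 * 128 * 2
hqq_pairs = [(8, 8), (4, 4), (3, 3), (2, 2), (4, 8), (3, 8), (2, 8), (2, 4)]
costs = {kv: hqq_bytes_per_token(40, 128, *kv) for kv in hqq_pairs}
for (k, v), b in costs.items():
    print(f"k={k} v={v}: {b:.0f} B/token/layer ({b / fp16_bytes:.1%} of fp16)")
```

  With this shape, (3,3) and (2,4) cost the same number of bytes, so the
  allocator can only separate them by measured drift.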

Strategy: "joint" (default) vs "isolated"

  "joint" hooks every layer at the same config in ONE forward pass via
    kv_quant_active_multi. Total cost: 11 forwards. Drift values reflect
    CUMULATIVE CASCADE: each layer's attention sees the quanted output of
    earlier layers' attention. This mirrors deployment conditions.

  "isolated" hooks ONE layer at a time, runs a forward per (layer, config)
    pair. Total cost: 11 × n_layers forwards. Drift values reflect MARGINAL
    per-layer sensitivity: what happens when only that layer is quanted.

  Empirically (A100, OLMo-2-13B, 32 prompts, 1 GB KV budget; see
  mxguru1/hsaq-strategy-comparison/report.json):
    - joint took 64.3 s, isolated took 665.6 s (10.4× slower)
    - isolated drift values are uniformly LOWER (ratios 0.12-0.61 by config)
      because the joint pass compounds drift through the residual stream
    - BUT both strategies produced IDENTICAL per-layer allocations (40/40
      agreement) when piped through assign_kv_bits: the greedy allocator
      only cares about pairwise drift-per-byte ratios, and the ordering of
      configs by drift is preserved across strategies

  Default is "joint" because: same allocation, 10× faster, drift values
  are cascade-realistic. Switch to "isolated" only when you want the
  marginal per-layer signal for analysis (e.g. ranking pruning candidates,
  understanding which layers cause the most ripple downstream).
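
Worked example of forward-pass counts (illustrative sketch):
  The cost difference between the two strategies follows from how many
  forward passes each one needs. The helper below is hypothetical and
  assumes the sweep contains no fp16_passthrough entry (that quantizer is
  recorded without a forward pass). The 40-layer figure is an assumption
  taken from the 40/40 agreement noted above.

```python
def forward_passes(n_configs: int, n_layers: int, strategy: str) -> int:
    # One extra forward captures the full-precision baseline in both cases.
    if strategy == "joint":
        return 1 + n_configs
    if strategy == "isolated":
        return 1 + n_configs * n_layers
    raise ValueError(f"unknown strategy: {strategy!r}")


joint = forward_passes(11, 40, "joint")
isolated = forward_passes(11, 40, "isolated")
print(f"joint={joint} isolated={isolated}")
```

  Wall-clock ratios need not match the pass-count ratio, since each joint
  pass runs quantization hooks on every layer while an isolated pass runs
  them on only one.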
"""

from __future__ import annotations

import hashlib
import json
import logging
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Callable, Iterable, Literal, Optional

logger = logging.getLogger("HSAQ.KVProfiler")

# Bump when changing the config sweep, drift metric implementation, or
# hook semantics. Cached rows under earlier versions become lookup-misses
# so they're safely ignored rather than silently reused.
PIPELINE_VERSION = "1.0.0"


# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------


DriftMetric = Literal["mse_normalised", "kl_softmax"]
KVQuantizer = Literal["hqq_g64", "scaled_uniform", "scaled_per_head", "fp16_passthrough"]


@dataclass(frozen=True)
class SweepConfig:
    """One (k_bits, v_bits, quantizer) probe to measure on every layer."""
    k_bits: int
    v_bits: int
    quantizer: KVQuantizer
    group_size: int = 64


@dataclass
class ProfileRow:
    """One row of measured KV sensitivity. Shape matches the Vault table
    kv_sensitivity_profile (migration 003) exactly, so the caller can insert
    the to_vault_payload() dict directly or pass it through a Vault adapter."""
    model_hash: str
    calibration_hash: str
    pipeline_version: str
    layer_idx: int
    k_bits: int
    v_bits: int
    quantizer: str
    drift_attn_output: float
    drift_metric: str
    bytes_per_kv_token: float
    max_seq_len_observed: int
    num_kv_heads: int
    head_dim: int
    profiled_at: str
    profiled_by_agent_id: str
    profiled_by_agent_tier: int

    def to_vault_payload(self) -> dict:
        """Row-shaped dict ready for INSERT into kv_sensitivity_profile."""
        return {
            "model_hash": self.model_hash,
            "calibration_hash": self.calibration_hash,
            "pipeline_version": self.pipeline_version,
            "layer_idx": self.layer_idx,
            "k_bits": self.k_bits,
            "v_bits": self.v_bits,
            "quantizer": self.quantizer,
            "drift_attn_output": self.drift_attn_output,
            "drift_metric": self.drift_metric,
            "bytes_per_kv_token": self.bytes_per_kv_token,
            "max_seq_len_observed": self.max_seq_len_observed,
            "num_kv_heads": self.num_kv_heads,
            "head_dim": self.head_dim,
            "profiled_at": self.profiled_at,
            "profiled_by_agent_id": self.profiled_by_agent_id,
            "profiled_by_agent_tier": self.profiled_by_agent_tier,
        }


# ---------------------------------------------------------------------------
# Default sweep: the curated 11 configs
# ---------------------------------------------------------------------------


DEFAULT_SWEEP: tuple[SweepConfig, ...] = (
    # ── Symmetric K=V (most common case, 4 configs) ──────────────────────
    SweepConfig(k_bits=8, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=4, v_bits=4, quantizer="hqq_g64"),
    SweepConfig(k_bits=3, v_bits=3, quantizer="hqq_g64"),
    SweepConfig(k_bits=2, v_bits=2, quantizer="hqq_g64"),
    # ── K-cheaper-than-V (K is more error-tolerant, 4 configs) ───────────
    SweepConfig(k_bits=4, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=3, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=2, v_bits=8, quantizer="hqq_g64"),
    SweepConfig(k_bits=2, v_bits=4, quantizer="hqq_g64"),
    # ── Quantizer comparison (3 configs) ──────────────────────────────────
    SweepConfig(k_bits=4, v_bits=4, quantizer="scaled_uniform"),
    SweepConfig(k_bits=4, v_bits=4, quantizer="scaled_per_head"),
    SweepConfig(k_bits=8, v_bits=8, quantizer="scaled_uniform"),
)


# ---------------------------------------------------------------------------
# Bytes-per-KV-token accounting
# ---------------------------------------------------------------------------


def kv_bytes_per_token(
    num_kv_heads: int,
    head_dim: int,
    k_bits: int,
    v_bits: int,
    quantizer: KVQuantizer,
    group_size: int = 64,
) -> float:
    """Bytes per token of KV cache for one layer at the given config.

    Includes quantizer-specific overhead (scales / zeros). The allocator
    reads this value cached in the Vault rather than recomputing, so the
    formula here is the single source of truth. If you change it, bump
    PIPELINE_VERSION.
    """
    elems = num_kv_heads * head_dim

    if quantizer == "fp16_passthrough":
        # K and V are each stored at fp16 (2 bytes/elem).
        return float(2 * elems * 2)

    if quantizer == "scaled_uniform":
        # One fp16 scale per row of K and V each
        k_payload = elems * k_bits / 8
        v_payload = elems * v_bits / 8
        return (k_payload + 2) + (v_payload + 2)

    if quantizer == "scaled_per_head":
        # fp16 scale per head, per K/V
        k_payload = elems * k_bits / 8
        v_payload = elems * v_bits / 8
        return (k_payload + num_kv_heads * 2) + (v_payload + num_kv_heads * 2)

    if quantizer == "hqq_g64":
        # Groups along the head_dim axis (per row of the KV cache).
        # NOTE: this is the corrected accounting. An earlier stub grouped
        # along (num_kv_heads × head_dim), which underestimated overhead.
        groups_per_row = max(1, head_dim // group_size)
        # zero + scale per group, fp16 each = 4 bytes per group
        overhead_per_row = num_kv_heads * groups_per_row * 4
        k_payload = elems * k_bits / 8
        v_payload = elems * v_bits / 8
        return (k_payload + overhead_per_row) + (v_payload + overhead_per_row)

    raise ValueError(f"unknown quantizer: {quantizer}")


# ---------------------------------------------------------------------------
# Drift metrics
# ---------------------------------------------------------------------------


def compute_drift(quanted, baseline, metric: DriftMetric) -> float:
    """Compute attention-output drift between quanted and fp16 baseline.

    mse_normalised: ((quanted - baseline)^2).mean() / (baseline^2).mean()
        Relative squared error, scale-invariant. Recommended for allocator
        inputs. KIVI-style.

    kl_softmax: KL divergence treating attention outputs as logits over the
        last dim. PROVISIONAL: attention output isn't a distribution, so
        this is a proxy metric only. Kept for compatibility with the
        Vault schema; do NOT use as allocator input without validation.
    """
    import torch
    import torch.nn.functional as F

    if metric == "mse_normalised":
        # Upcast first: squaring fp16/bf16 values can overflow or lose precision.
        q, b = quanted.float(), baseline.float()
        mse = ((q - b) ** 2).mean().item()
        denom = (b ** 2).mean().item()
        return mse / max(denom, 1e-8)

    if metric == "kl_softmax":
        a = F.log_softmax(quanted.float().flatten(0, -2), dim=-1)
        b = F.softmax(baseline.float().flatten(0, -2), dim=-1)
        return float(F.kl_div(a, b, reduction="batchmean"))

    raise ValueError(f"unknown drift metric: {metric}")


# ---------------------------------------------------------------------------
# Calibration hash
# ---------------------------------------------------------------------------


def compute_calibration_hash(
    calibration_texts: list[str],
    max_seq_len: int,
) -> str:
    """Deterministic hash of calibration inputs for cache invalidation."""
    payload = json.dumps({
        "n_texts": len(calibration_texts),
        "max_seq_len": max_seq_len,
        # First/last text content matters for distinguishing calibration sets;
        # hash a sample rather than every text to keep this cheap.
        "first_text": calibration_texts[0][:200] if calibration_texts else "",
        "last_text": calibration_texts[-1][:200] if calibration_texts else "",
        "total_chars": sum(len(t) for t in calibration_texts),
    }, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()[:16]


# ---------------------------------------------------------------------------
# Per-layer attention-output capture
# ---------------------------------------------------------------------------


def _capture_attn_outputs(
    model,
    attn_modules_by_layer: dict[int, object],
    batch,
) -> dict[int, "torch.Tensor"]:
    """Run one forward pass with hooks that capture attention output per layer.

    HF attention modules return either `attn_output` directly or a tuple
    starting with `attn_output`. We grab the first tensor element in either
    case.
    """
    import torch

    captured: dict[int, torch.Tensor] = {}
    handles: list = []

    def make_hook(li: int):
        def hook(_mod, _inp, out):
            t = out[0] if isinstance(out, tuple) else out
            if isinstance(t, torch.Tensor):
                # Move to CPU to bound GPU memory across all captured layers
                captured[li] = t.detach().to("cpu")
        return hook

    try:
        for li, attn in attn_modules_by_layer.items():
            handles.append(attn.register_forward_hook(make_hook(li)))
        with torch.no_grad():
            model(**batch, use_cache=False)
    finally:
        for h in handles:
            h.remove()

    return captured


# ---------------------------------------------------------------------------
# Profiler
# ---------------------------------------------------------------------------


def profile_kv_sensitivity(
    model,
    tokenizer,
    calibration_texts: list[str],
    *,
    model_hash: str,
    profiled_by_agent_id: str,
    profiled_by_agent_tier: int,
    sweep: Iterable[SweepConfig] = DEFAULT_SWEEP,
    drift_metric: DriftMetric = "mse_normalised",
    max_seq_len: int = 1024,
    pipeline_version: str = PIPELINE_VERSION,
    strategy: Literal["joint", "isolated"] = "joint",
    progress_cb: Optional[Callable[[str], None]] = None,
) -> list[ProfileRow]:
    """Measure attention-output drift per (layer × sweep config).

    Args:
        model: HF model in eval mode on its target device.
        tokenizer: Matching HF tokenizer.
        calibration_texts: List of natural-language strings to feed through
            the model. Length controls statistical robustness; quality of
            the distribution match controls allocator decisions on out-of-
            distribution data. 32-256 samples is the sweet spot.
        model_hash: Deterministic identifier of the model (caller-computed).
            Used as the first part of the cache invalidation key.
        profiled_by_agent_id, profiled_by_agent_tier: Audit chain. Caller
            is whatever agent invoked the profile (e.g. hunter at tier 2).
        sweep: Iterable of SweepConfigs to measure. Defaults to the curated
            11-config sweep documented in the module header.
        drift_metric: "mse_normalised" (recommended) or "kl_softmax"
            (provisional; see compute_drift docstring).
        max_seq_len: Tokenizer truncation length. The max_seq_len_observed
            column records the padded length of the tokenized batch, which
            can be shorter.
        pipeline_version: Override only when running a deliberately
            distinct lineage (e.g. comparing two metric implementations).
        strategy: "joint" (default) hooks every layer at the same config in
            one forward pass: fast, cumulative cascade drift. "isolated"
            loops per (layer, config): slower, marginal per-layer drift.
            See module docstring for the empirical equivalence finding.
        progress_cb: Optional callback for progress messages.

    Returns:
        List of ProfileRow objects, one per (layer, sweep_config) pair.
        Caller persists these into kv_sensitivity_profile.

    Raises:
        RuntimeError: If the model doesn't expose Llama-family attention.
        Any tokenizer / forward-pass errors propagate.
    """
    # Local imports keep the module importable without torch present.
    import torch  # noqa: F401
    from kv_intercept import (
        KVQuantSpec,
        find_attention_modules,
        kv_quant_active,
        kv_quant_active_multi,
    )

    if strategy not in ("joint", "isolated"):
        raise ValueError(f"unknown strategy: {strategy!r}; want 'joint' or 'isolated'")

    def log(msg: str) -> None:
        if progress_cb:
            progress_cb(msg)
        else:
            logger.info(msg)

    sweep_list = list(sweep)
    if not sweep_list:
        raise ValueError("Empty sweep: at least one SweepConfig required")
    if not calibration_texts:
        raise ValueError("No calibration_texts provided")

    # ── 1. Locate attention modules ─────────────────────────────────────
    attn_modules = find_attention_modules(model)
    n_layers = len(attn_modules)
    log(f"[kv-prof] discovered {n_layers} attention layers")

    # ── 2. Tokenize calibration set ─────────────────────────────────────
    batch = tokenizer(
        calibration_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_len,
    ).to(model.device)
    actual_seq_len = batch.input_ids.shape[1]
    log(f"[kv-prof] tokenised {len(calibration_texts)} prompts, "
        f"seq_len={actual_seq_len}")

    # ── 3. Compute calibration hash ─────────────────────────────────────
    calibration_hash = compute_calibration_hash(calibration_texts, max_seq_len)
    log(f"[kv-prof] calibration_hash={calibration_hash}")

    # ── 4. Capture full-precision baseline attention outputs ────────────
    log("[kv-prof] capturing fp baseline attention outputs (1 forward pass)")
    t_start = time.time()
    baseline_outputs = _capture_attn_outputs(model, attn_modules, batch)
    log(f"[kv-prof] baseline captured in {time.time() - t_start:.1f}s")

    if len(baseline_outputs) != n_layers:
        logger.warning(
            "Captured baselines for %d/%d layers; some hooks didn't fire",
            len(baseline_outputs), n_layers,
        )

    # ── 5. Per-layer dimensions ─────────────────────────────────────────
    # Pulled from model.config; Llama-family models report these directly.
    # If a model has per-layer variation (very rare), this is the first
    # place to patch.
    num_kv_heads = getattr(
        model.config, "num_key_value_heads",
        getattr(model.config, "num_attention_heads", 1),
    )
    head_dim = getattr(
        model.config, "head_dim", None,
    ) or (model.config.hidden_size // model.config.num_attention_heads)
    log(f"[kv-prof] num_kv_heads={num_kv_heads}, head_dim={head_dim}")

    # ── 6. Sweep loop ───────────────────────────────────────────────────
    # joint:    outer = configs, inner = (one multi-layer hook + 1 forward).
    #           Total len(sweep) forwards. Drift is cumulative cascade.
    # isolated: outer = configs, middle = layers, inner = (one single-layer
    #           hook + 1 forward). Total len(sweep) * n_layers forwards.
    #           Drift is marginal per-layer.
    rows: list[ProfileRow] = []
    n_configs = len(sweep_list)
    if strategy == "joint":
        log(f"[kv-prof] strategy=joint: {n_configs} configs × {n_layers} layers "
            f"= {n_configs * n_layers} measurements ({n_configs} forward passes)")
    else:
        log(f"[kv-prof] strategy=isolated: {n_configs} configs × {n_layers} layers "
            f"= {n_configs * n_layers} measurements "
            f"({n_configs * n_layers} forward passes)")

    profiled_at = datetime.now(UTC).isoformat()

    def _record(li: int, cfg: SweepConfig, bpt: float, quanted_output) -> None:
        drift = compute_drift(
            quanted_output,
            baseline_outputs[li],
            drift_metric,
        )
        rows.append(ProfileRow(
            model_hash=model_hash,
            calibration_hash=calibration_hash,
            pipeline_version=pipeline_version,
            layer_idx=li,
            k_bits=cfg.k_bits,
            v_bits=cfg.v_bits,
            quantizer=cfg.quantizer,
            drift_attn_output=float(drift),
            drift_metric=drift_metric,
            bytes_per_kv_token=float(bpt),
            max_seq_len_observed=actual_seq_len,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            profiled_at=profiled_at,
            profiled_by_agent_id=profiled_by_agent_id,
            profiled_by_agent_tier=profiled_by_agent_tier,
        ))

    for cfg_idx, cfg in enumerate(sweep_list):
        log(f"[kv-prof] config {cfg_idx + 1}/{n_configs}: "
            f"k={cfg.k_bits}b v={cfg.v_bits}b {cfg.quantizer}")
        t_cfg = time.time()

        spec = KVQuantSpec(
            k_bits=cfg.k_bits,
            v_bits=cfg.v_bits,
            quantizer=cfg.quantizer,
            group_size=cfg.group_size,
        )
        bpt = kv_bytes_per_token(
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            k_bits=cfg.k_bits,
            v_bits=cfg.v_bits,
            quantizer=cfg.quantizer,
            group_size=cfg.group_size,
        )

        # Special case: fp16_passthrough drift is zero by definition for both strategies.
        if cfg.quantizer == "fp16_passthrough":
            for li in attn_modules:
                if li in baseline_outputs:
                    _record(li, cfg, bpt, baseline_outputs[li])
            log(f"[kv-prof]   config done in {time.time() - t_cfg:.1f}s (passthrough)")
            continue

        if strategy == "joint":
            specs_by_layer = {li: spec for li in attn_modules}
            with kv_quant_active_multi(attn_modules, specs_by_layer):
                quanted_outputs = _capture_attn_outputs(
                    model, attn_modules, batch,
                )
            for li in attn_modules:
                if li not in quanted_outputs or li not in baseline_outputs:
                    logger.debug("layer %d: missing capture, skipping", li)
                    continue
                _record(li, cfg, bpt, quanted_outputs[li])
        else:  # isolated
            for li, attn in attn_modules.items():
                with kv_quant_active(attn, spec):
                    captured = _capture_attn_outputs(
                        model, attn_modules, batch,
                    )
                if li not in captured or li not in baseline_outputs:
                    logger.debug("layer %d: missing capture, skipping", li)
                    continue
                _record(li, cfg, bpt, captured[li])

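        # Guard against an empty result list (e.g. every capture was skipped).
        sample = f"{rows[-1].drift_attn_output:.4e}" if rows else "n/a"
        log(f"[kv-prof]   config done in {time.time() - t_cfg:.1f}s, "
            f"sample drift = {sample}")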

    log(f"[kv-prof] complete: {len(rows)} rows total")
    return rows


# ---------------------------------------------------------------------------
# Bridge to allocator
# ---------------------------------------------------------------------------


def rows_to_kv_candidates(rows: list[ProfileRow]):
    """Group ProfileRows by layer into KVCandidate objects that
    assign_kv_bits (in assignment_v2.py) consumes directly.

    Carries num_kv_heads and head_dim through from the rows so the
    allocator gets the real per-layer dimensions, not placeholders.
    """
    from assignment_v2 import KVCandidate, KVOption
    from collections import defaultdict

    by_layer: dict[int, list[ProfileRow]] = defaultdict(list)
    for r in rows:
        by_layer[r.layer_idx].append(r)

    candidates = []
    for layer_idx, layer_rows in sorted(by_layer.items()):
        # All rows for one layer share num_kv_heads/head_dim by construction
        first = layer_rows[0]
        options = [
            KVOption(
                k_bits=r.k_bits,
                v_bits=r.v_bits,
                quantizer=r.quantizer,
                drift=r.drift_attn_output,
                bytes_per_kv_token=r.bytes_per_kv_token,
            )
            for r in layer_rows
        ]
        candidates.append(KVCandidate(
            layer_idx=layer_idx,
            num_kv_heads=first.num_kv_heads,
            head_dim=first.head_dim,
            options=options,
        ))
    return candidates