mxguru1 commited on
Commit
fa1f4fa
·
verified ·
1 Parent(s): 55f5f5e

Add KV interception hooks + generalised allocator + smoke tests (3/3: smoke_test_v3.py)

Browse files
Files changed (1) hide show
  1. smoke_test_v3.py +289 -0
smoke_test_v3.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Smoke test for assignment_v2 + kv_intercept.
3
+
4
+ Coverage:
5
+ 1. Back-compat: old assign_bit_widths API still works.
6
+ 2. Generic core: assign_greedy with arbitrary (cost, unit) pairs.
7
+ 3. KV allocator: assign_kv_bits respects KV-cache budget at max_seq_len.
8
+ 4. Two-budget combined: assign_combined runs both independently.
9
+ 5. KV interception hook: forward hooks on k_proj/v_proj actually
10
+ modify attention output, with drift ordered by bit width.
11
+ 6. KV interception multi-layer: kv_quant_active_multi installs and
12
+ tears down cleanly.
13
+
14
+ Run: place assignment_v2.py and kv_intercept.py in the same directory,
15
+ then `python smoke_test_v3.py`.
16
+ """
17
+
18
+ import sys
19
+ import logging
20
+ from collections import Counter
21
+ from pathlib import Path
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
+ # Make sibling modules importable regardless of where the script is run from.
27
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
28
+
29
+ import assignment_v2 as asgn
30
+ import kv_intercept as kvi
31
+
32
+
33
+ def hr(title):
34
+ print(f"\n{'=' * 6} {title} {'=' * 6}")
35
+
36
+
37
+ # ===========================================================================
38
+ # 1. Back-compat: old API
39
+ # ===========================================================================
40
+ hr("1. Back-compat: assign_bit_widths still works")
41
+
42
+ def opts(d2, d3, d4):
43
+ return [
44
+ asgn.LayerOption(bits=2, quantizer="hqq", drift=d2, bytes_per_param=(2/8)*1.07),
45
+ asgn.LayerOption(bits=3, quantizer="hqq", drift=d3, bytes_per_param=(3/8)*1.07),
46
+ asgn.LayerOption(bits=4, quantizer="hqq", drift=d4, bytes_per_param=(4/8)*1.07),
47
+ ]
48
+
49
+ candidates = [
50
+ asgn.LayerCandidate(0, "attn", 100_000_000, opts(1.20, 0.40, 0.05)),
51
+ asgn.LayerCandidate(1, "mlp", 200_000_000, opts(0.50, 0.15, 0.08)),
52
+ asgn.LayerCandidate(2, "attn", 100_000_000, opts(0.12, 0.10, 0.09)),
53
+ ]
54
+
55
+ result = asgn.assign_bit_widths(candidates, weight_budget_gb=0.15)
56
+ print(f" total_weights_gb: {result.total_weights_gb:.4f}")
57
+ print(f" total_drift: {result.total_drift:.3f}")
58
+ print(f" saturated: {result.saturated}")
59
+ for a in result.assignments:
60
+ print(f" L{a.layer_idx} {a.component:<5} -> {a.chosen.bits}-bit ({a.chosen.quantizer}), drift={a.chosen.drift:.3f}")
61
+
62
+ # Sensitive layer (0) should be at least as high-bit as tolerant layer (2)
63
+ by_layer = result.by_layer
64
+ assert by_layer[(0, "attn")].chosen.bits >= by_layer[(2, "attn")].chosen.bits
65
+ print(" v1 API back-compat verified ✓")
66
+
67
+
68
+ # ===========================================================================
69
+ # 2. Generic core
70
+ # ===========================================================================
71
+ hr("2. Generic assign_greedy with arbitrary cost/unit pairs")
72
+
73
+ # Simulate something completely unlike weights: 5 candidates each with
74
+ # unit_count = 1 (so cost_per_unit IS the total cost) and different drifts.
75
+ gcands = [
76
+ asgn.GenericCandidate(
77
+ candidate_id=("act", i),
78
+ unit_count=1,
79
+ options=[
80
+ asgn.GenericOption(cost_per_unit=1.0e8, drift=1.0, label=("a",)),
81
+ asgn.GenericOption(cost_per_unit=2.0e8, drift=0.5, label=("b",)),
82
+ asgn.GenericOption(cost_per_unit=4.0e8, drift=0.1, label=("c",)),
83
+ ],
84
+ )
85
+ for i in range(5)
86
+ ]
87
+ # Budget for 5 × cheap (0.5 GB) + room for 2 upgrades to 'b' (extra 0.2 GB)
88
+ gen_result = asgn.assign_greedy(gcands, budget_bytes=0.7e9)
89
+ print(f" total_bytes: {gen_result.total_bytes / 1e9:.3f} GB / {gen_result.budget_gb:.3f} GB budget")
90
+ print(f" drift: {gen_result.total_drift:.3f}")
91
+ print(f" saturated: {gen_result.saturated}")
92
+ labels = Counter(a.chosen.label for a in gen_result.assignments)
93
+ print(f" label distribution: {dict(labels)}")
94
+ assert gen_result.total_bytes <= gen_result.budget_bytes
95
+
96
+
97
+ # ===========================================================================
98
+ # 3. KV allocator
99
+ # ===========================================================================
100
+ hr("3. assign_kv_bits respects KV-cache budget at max_seq_len")
101
+
102
+ # OLMo-like shape: 40 layers, 40 KV heads, 128 head_dim → 800 KB/token fp16
103
+ NUM_KV_HEADS = 40
104
+ HEAD_DIM = 128
105
+ NUM_LAYERS = 40
106
+ MAX_SEQ = 4096
107
+
108
+ def kv_opts(num_kv_heads, head_dim):
109
+ """Generate 4 KV options per layer: fp16, 8-bit, 4-bit, 2-bit hqq_g64."""
110
+ elems = num_kv_heads * head_dim
111
+ group_size = 64
112
+ groups = max(1, elems // group_size)
113
+ hqq_overhead = groups * 2 * 2 # 2 (zero+scale) × 2 bytes per group
114
+
115
+ def bpt(k_bits, v_bits, overhead):
116
+ k = elems * k_bits / 8 + overhead
117
+ v = elems * v_bits / 8 + overhead
118
+ return k + v
119
+
120
+ return [
121
+ # Drift values: arbitrary but ordered. Reality would have these measured.
122
+ asgn.KVOption(16, 16, "fp16_passthrough", drift=0.000, bytes_per_kv_token=elems*4),
123
+ asgn.KVOption(8, 8, "hqq_g64", drift=0.005, bytes_per_kv_token=bpt(8, 8, hqq_overhead)),
124
+ asgn.KVOption(4, 4, "hqq_g64", drift=0.030, bytes_per_kv_token=bpt(4, 4, hqq_overhead)),
125
+ asgn.KVOption(2, 4, "hqq_g64", drift=0.080, bytes_per_kv_token=bpt(2, 4, hqq_overhead)),
126
+ ]
127
+
128
+ kv_cands = [
129
+ asgn.KVCandidate(layer_idx=i, num_kv_heads=NUM_KV_HEADS, head_dim=HEAD_DIM,
130
+ options=kv_opts(NUM_KV_HEADS, HEAD_DIM))
131
+ for i in range(NUM_LAYERS)
132
+ ]
133
+
134
+ # Budget: 2.0 GB (between all-2/4-bit (~0.8 GB) and all-fp16 (~3.3 GB))
135
+ kv_result = asgn.assign_kv_bits(kv_cands, kv_budget_gb=2.0, max_seq_len=MAX_SEQ)
136
+ print(f" Layers: {NUM_LAYERS}, max_seq_len: {MAX_SEQ}")
137
+ print(f" KV used: {kv_result.total_kv_gb:.3f} / {kv_result.budget_gb:.3f} GB")
138
+ print(f" drift: {kv_result.total_drift:.4f}")
139
+ print(f" saturated: {kv_result.saturated}")
140
+ bits_hist = Counter((a.chosen.k_bits, a.chosen.v_bits) for a in kv_result.assignments)
141
+ print(f" (k_bits, v_bits) distribution: {dict(bits_hist)}")
142
+ assert kv_result.total_kv_gb <= kv_result.budget_gb
143
+ assert len(bits_hist) >= 1
144
+
145
+
146
+ # ===========================================================================
147
+ # 4. Two-budget combined
148
+ # ===========================================================================
149
+ hr("4. assign_combined: weights and KV under independent budgets")
150
+
151
+ # Reuse candidates from earlier
152
+ combined = asgn.assign_combined(
153
+ weight_candidates=[c.to_generic() for c in candidates],
154
+ kv_candidates=[c.to_generic(MAX_SEQ) for c in kv_cands],
155
+ weight_budget_bytes=0.15e9,
156
+ kv_budget_bytes=2.0e9,
157
+ )
158
+ print(f" weight total: {combined.weights.total_gb:.3f} GB / 0.15 GB budget")
159
+ print(f" KV total: {combined.kv.total_gb:.3f} GB / 2.00 GB budget")
160
+ print(f" combined drift: {combined.total_drift:.4f}")
161
+ # The two are independent — verify by checking the totals match the sums
162
+ assert abs(combined.total_drift - (combined.weights.total_drift + combined.kv.total_drift)) < 1e-9
163
+ assert combined.weights.total_bytes <= 0.15e9
164
+ assert combined.kv.total_bytes <= 2.0e9
165
+ print(" weight and KV pools independent, both within budget ✓")
166
+
167
+
168
+ # ===========================================================================
169
+ # 5. KV interception hook actually modifies attention output
170
+ # ===========================================================================
171
+ hr("5. K/V interception hook changes attention output")
172
+
173
+ # Build a minimal Llama-family attention module: just k_proj and v_proj as
174
+ # nn.Linear, then a fake attention computation that uses them.
175
+ class TinyAttn(nn.Module):
176
+ """Mimics Llama-family self_attn surface: q_proj, k_proj, v_proj, o_proj."""
177
+ def __init__(self, hidden=128, num_heads=4, num_kv_heads=4):
178
+ super().__init__()
179
+ self.hidden = hidden
180
+ self.num_heads = num_heads
181
+ self.num_kv_heads = num_kv_heads
182
+ self.head_dim = hidden // num_heads
183
+ self.q_proj = nn.Linear(hidden, num_heads * self.head_dim, bias=False)
184
+ self.k_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
185
+ self.v_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
186
+ self.o_proj = nn.Linear(num_heads * self.head_dim, hidden, bias=False)
187
+
188
+ def forward(self, x):
189
+ b, s, _ = x.shape
190
+ q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
191
+ k = self.k_proj(x).view(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
192
+ v = self.v_proj(x).view(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
193
+ # Simple scaled-dot-product attention (no GQA repeat needed since
194
+ # num_heads == num_kv_heads here for the test)
195
+ attn = torch.softmax(q @ k.transpose(-2, -1) / (self.head_dim ** 0.5), dim=-1)
196
+ out = (attn @ v).transpose(1, 2).reshape(b, s, -1)
197
+ return self.o_proj(out)
198
+
199
+
200
+ torch.manual_seed(0)
201
+ attn = TinyAttn()
202
+ attn.eval()
203
+
204
+ x = torch.randn(1, 32, 128)
205
+
206
+ with torch.no_grad():
207
+ baseline = attn(x).clone()
208
+
209
+ # 4-bit KV
210
+ spec_4 = kvi.KVQuantSpec(k_bits=4, v_bits=4, quantizer="hqq_g64", group_size=64)
211
+ with kvi.kv_quant_active(attn, spec_4), torch.no_grad():
212
+ out_4bit = attn(x).clone()
213
+
214
+ # 2-bit KV
215
+ spec_2 = kvi.KVQuantSpec(k_bits=2, v_bits=2, quantizer="hqq_g64", group_size=64)
216
+ with kvi.kv_quant_active(attn, spec_2), torch.no_grad():
217
+ out_2bit = attn(x).clone()
218
+
219
+ # After context exits, hooks should be removed — verify by re-running
220
+ with torch.no_grad():
221
+ after = attn(x).clone()
222
+
223
+ drift_4 = ((out_4bit - baseline) ** 2).mean().item()
224
+ drift_2 = ((out_2bit - baseline) ** 2).mean().item()
225
+ drift_after = ((after - baseline) ** 2).mean().item()
226
+
227
+ print(f" attention output drift at 4-bit KV: {drift_4:.6e}")
228
+ print(f" attention output drift at 2-bit KV: {drift_2:.6e}")
229
+ print(f" drift after context exit (should be 0): {drift_after:.6e}")
230
+
231
+ assert drift_4 > 0, "4-bit hook had no effect on attention output"
232
+ assert drift_2 > drift_4, f"2-bit drift ({drift_2}) should exceed 4-bit drift ({drift_4})"
233
+ assert drift_after == 0.0, f"Hook leaked past context manager: drift {drift_after}"
234
+ print(" hook activates, drift ordered, cleans up on exit ✓")
235
+
236
+
237
+ # ===========================================================================
238
+ # 6. Multi-layer hook installation
239
+ # ===========================================================================
240
+ hr("6. kv_quant_active_multi installs and tears down cleanly")
241
+
242
+ # Build a tiny "model" with 3 attention modules
243
+ class TinyModelShim:
244
+ """Stand-in for a HF model with model.layers[i].self_attn structure."""
245
+ def __init__(self, n=3):
246
+ # Match the discovery path: model.model.layers[i].self_attn
247
+ layers = []
248
+ for _ in range(n):
249
+ class Layer:
250
+ pass
251
+ layer = Layer()
252
+ layer.self_attn = TinyAttn()
253
+ layers.append(layer)
254
+ class Inner:
255
+ pass
256
+ self.model = Inner()
257
+ self.model.layers = layers
258
+
259
+
260
+ m = TinyModelShim(n=3)
261
+ attns = kvi.find_attention_modules(m)
262
+ print(f" Discovered {len(attns)} attention modules: layer indices {sorted(attns.keys())}")
263
+ assert len(attns) == 3
264
+
265
+ # Activate on layers 0 and 2 only, then verify only those have hooks during
266
+ # the context and ALL are hookless afterward.
267
+ specs = {
268
+ 0: kvi.KVQuantSpec(k_bits=4, v_bits=4, quantizer="hqq_g64"),
269
+ 2: kvi.KVQuantSpec(k_bits=2, v_bits=2, quantizer="hqq_g64"),
270
+ }
271
+
272
+ # Capture hook counts during context
273
+ def count_hooks(attn):
274
+ return len(attn.k_proj._forward_hooks) + len(attn.v_proj._forward_hooks)
275
+
276
+ before = {i: count_hooks(a) for i, a in attns.items()}
277
+ with kvi.kv_quant_active_multi(attns, specs):
278
+ during = {i: count_hooks(a) for i, a in attns.items()}
279
+ after = {i: count_hooks(a) for i, a in attns.items()}
280
+
281
+ print(f" hooks before: {before}")
282
+ print(f" hooks during: {during}")
283
+ print(f" hooks after: {after}")
284
+ assert during[0] == 2 and during[1] == 0 and during[2] == 2
285
+ assert before == after == {0: 0, 1: 0, 2: 0}
286
+ print(" multi-layer hook lifecycle clean ✓")
287
+
288
+
289
+ print("\nAll assertions passed.")