mxguru1 committed on
Commit
48ec8ad
·
verified ·
1 Parent(s): 01426da

Add smoke_test_v4.py

Browse files
Files changed (1) hide show
  1. smoke_test_v4.py +334 -0
smoke_test_v4.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Smoke test for kv_profiler.
3
+
4
+ Coverage:
5
+ 1. SweepConfig and DEFAULT_SWEEP shape checks.
6
+ 2. kv_bytes_per_token accounting — passthrough vs hqq_g64 sanity.
7
+ 3. compute_drift returns zero for identical tensors, nonzero for different.
8
+ 4. compute_calibration_hash is deterministic and distinguishes content.
9
+ 5. End-to-end profile_kv_sensitivity() on a tiny synthetic Llama-family model:
10
+ - Produces 11 × n_layers rows
11
+ - Drift is data-dependent (different per layer, non-zero, ordered)
12
+ - fp16_passthrough is not part of the default sweep, so its ~0 drift is not checked here
13
+ - 2-bit configs have higher drift than 8-bit configs
14
+ 6. rows_to_kv_candidates → assign_kv_bits round-trip.
15
+ """
16
+
17
+ import sys
18
+ import logging
19
+ from collections import Counter
20
+ from pathlib import Path
21
+ from types import SimpleNamespace
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
27
+
28
+ import kv_intercept as kvi # noqa
29
+ import kv_profiler as kvp
30
+ import assignment_v2 as asgn
31
+
32
+
33
+ def hr(title):
34
+ print(f"\n{'=' * 6} {title} {'=' * 6}")
35
+
36
+
37
+ logging.basicConfig(level=logging.WARNING) # quiet for the test
38
+
39
+
40
+ # ===========================================================================
41
+ # 1. Sweep shape
42
+ # ===========================================================================
43
+ hr("1. DEFAULT_SWEEP shape")
44
+ print(f" total configs: {len(kvp.DEFAULT_SWEEP)}")
45
+ assert len(kvp.DEFAULT_SWEEP) == 11, "Expected 11-config curated sweep"
46
+
47
+ quants = Counter(c.quantizer for c in kvp.DEFAULT_SWEEP)
48
+ print(f" by quantizer: {dict(quants)}")
49
+ assert quants["hqq_g64"] == 8
50
+ assert quants["scaled_uniform"] == 2
51
+ assert quants["scaled_per_head"] == 1
52
+
53
+ # K-cheaper-than-V configs exist
54
+ k_lt_v = [c for c in kvp.DEFAULT_SWEEP if c.k_bits < c.v_bits]
55
+ print(f" K<V configs: {len(k_lt_v)}")
56
+ assert len(k_lt_v) == 4, "Expected 4 K-cheaper-than-V configs"
57
+ print(" ✓")
58
+
59
+
60
+ # ===========================================================================
61
+ # 2. kv_bytes_per_token accounting
62
+ # ===========================================================================
63
+ hr("2. kv_bytes_per_token accounting")
64
+
65
+ # fp16_passthrough: 8 heads × 128 dim × 2 bytes × 2 (K+V) = 4096 bytes
66
+ bpt_fp16 = kvp.kv_bytes_per_token(8, 128, 16, 16, "fp16_passthrough")
67
+ print(f" fp16_passthrough (8h × 128d): {bpt_fp16} bytes")
68
+ assert bpt_fp16 == 8 * 128 * 2 * 2
69
+
70
+ # hqq_g64 at 4/4: ~half of fp16 plus overhead
71
+ bpt_44 = kvp.kv_bytes_per_token(8, 128, 4, 4, "hqq_g64")
72
+ print(f" hqq_g64 4/4: {bpt_44} bytes")
73
+ # 8 heads × 128 dim × 4 bits / 8 = 512 bytes payload per K, same per V → 1024
74
+ # Plus overhead per K: 8 heads × (128/64 groups) × 4 bytes = 64 bytes; ×2 (K+V) = 128
75
+ # Total: 1024 + 128 = 1152 bytes
76
+ assert bpt_44 == 1024 + 128
77
+
78
+ # 2-bit asymmetric should be cheaper than symmetric 4-bit
79
+ bpt_24 = kvp.kv_bytes_per_token(8, 128, 2, 4, "hqq_g64")
80
+ print(f" hqq_g64 2/4: {bpt_24} bytes")
81
+ assert bpt_24 < bpt_44
82
+ print(" ✓")
83
+
84
+
85
+ # ===========================================================================
86
+ # 3. compute_drift
87
+ # ===========================================================================
88
+ hr("3. compute_drift")
89
+
90
+ a = torch.randn(2, 4, 8)
91
+ print(f" identical tensors, mse_normalised: {kvp.compute_drift(a, a, 'mse_normalised'):.6f}")
92
+ assert kvp.compute_drift(a, a, "mse_normalised") == 0.0
93
+
94
+ b = a + 0.1 * torch.randn_like(a)
95
+ d = kvp.compute_drift(b, a, "mse_normalised")
96
+ print(f" perturbed by 0.1×noise: {d:.6f}")
97
+ assert d > 0
98
+ print(" ✓")
99
+
100
+
101
+ # ===========================================================================
102
+ # 4. compute_calibration_hash determinism
103
+ # ===========================================================================
104
+ hr("4. compute_calibration_hash")
105
+
106
+ texts1 = ["hello world", "the quick brown fox"]
107
+ texts2 = ["hello world", "the quick brown fox"]
108
+ texts3 = ["hello world", "different text"]
109
+ h1 = kvp.compute_calibration_hash(texts1, 512)
110
+ h2 = kvp.compute_calibration_hash(texts2, 512)
111
+ h3 = kvp.compute_calibration_hash(texts3, 512)
112
+ print(f" same content: h1={h1} h2={h2}")
113
+ print(f" different content: h3={h3}")
114
+ assert h1 == h2, "Identical inputs should hash the same"
115
+ assert h1 != h3, "Different inputs should hash differently"
116
+ print(" ✓")
117
+
118
+
119
+ # ===========================================================================
120
+ # 5. End-to-end profiling on a synthetic Llama-family model
121
+ # ===========================================================================
122
+ hr("5. profile_kv_sensitivity end-to-end")
123
+
124
+
125
+ class TinyAttn(nn.Module):
126
+ """Mimics Llama-family self_attn (k_proj, v_proj on .self_attn)."""
127
+ def __init__(self, hidden=128, num_heads=4, num_kv_heads=4):
128
+ super().__init__()
129
+ self.num_heads = num_heads
130
+ self.num_kv_heads = num_kv_heads
131
+ self.head_dim = hidden // num_heads
132
+ self.q_proj = nn.Linear(hidden, num_heads * self.head_dim, bias=False)
133
+ self.k_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
134
+ self.v_proj = nn.Linear(hidden, num_kv_heads * self.head_dim, bias=False)
135
+ self.o_proj = nn.Linear(num_heads * self.head_dim, hidden, bias=False)
136
+
137
+ def forward(self, x):
138
+ b, s, _ = x.shape
139
+ q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
140
+ k = self.k_proj(x).view(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
141
+ v = self.v_proj(x).view(b, s, self.num_kv_heads, self.head_dim).transpose(1, 2)
142
+ attn = torch.softmax(q @ k.transpose(-2, -1) / (self.head_dim ** 0.5), dim=-1)
143
+ out = (attn @ v).transpose(1, 2).reshape(b, s, -1)
144
+ return self.o_proj(out)
145
+
146
+
147
+ class TinyModel(nn.Module):
148
+ """HF-shape stand-in: model.model.layers[i].self_attn, with .config and
149
+ a forward that accepts input_ids."""
150
+ def __init__(self, n_layers=3, hidden=128, num_heads=4, vocab=64):
151
+ super().__init__()
152
+ self.embed = nn.Embedding(vocab, hidden)
153
+
154
+ class Inner(nn.Module):
155
+ def __init__(self):
156
+ super().__init__()
157
+
158
+ self.model = Inner()
159
+ self.model.layers = nn.ModuleList()
160
+ for _ in range(n_layers):
161
+ layer = nn.Module()
162
+ layer.self_attn = TinyAttn(hidden=hidden, num_heads=num_heads,
163
+ num_kv_heads=num_heads)
164
+ self.model.layers.append(layer)
165
+
166
+ self.config = SimpleNamespace(
167
+ num_attention_heads=num_heads,
168
+ num_key_value_heads=num_heads,
169
+ hidden_size=hidden,
170
+ )
171
+
172
+ @property
173
+ def device(self):
174
+ return next(self.parameters()).device
175
+
176
+ def forward(self, input_ids=None, attention_mask=None, use_cache=False, **kw):
177
+ x = self.embed(input_ids)
178
+ for layer in self.model.layers:
179
+ x = x + layer.self_attn(x)
180
+ return x
181
+
182
+
183
+ class TinyTokenizer:
184
+ """Just enough tokenizer surface for the profiler."""
185
+ def __init__(self, vocab=64):
186
+ self.vocab = vocab
187
+
188
+ def __call__(self, texts, return_tensors=None, padding=None,
189
+ truncation=None, max_length=None):
190
+ torch.manual_seed(0) # deterministic across calls for test stability
191
+ ids = [torch.randint(0, self.vocab, (min(len(t), max_length or 32),)) for t in texts]
192
+ max_len = max(t.shape[0] for t in ids)
193
+ padded = torch.zeros(len(texts), max_len, dtype=torch.long)
194
+ mask = torch.zeros(len(texts), max_len, dtype=torch.long)
195
+ for i, t in enumerate(ids):
196
+ padded[i, :t.shape[0]] = t
197
+ mask[i, :t.shape[0]] = 1
198
+ return SimpleNamespace(
199
+ input_ids=padded,
200
+ attention_mask=mask,
201
+ to=lambda device: SimpleNamespace(input_ids=padded.to(device),
202
+ attention_mask=mask.to(device)),
203
+ )
204
+
205
+
206
+ torch.manual_seed(42)
207
+ model = TinyModel(n_layers=3, hidden=128, num_heads=4)
208
+ model.eval()
209
+ tokenizer = TinyTokenizer()
210
+
211
+ # Wrap the tokenizer output so .to() returns a kwargs-compatible dict
212
+ class TokenizerWrapper:
213
+ def __init__(self, tk):
214
+ self.tk = tk
215
+ def __call__(self, texts, **kw):
216
+ result = self.tk(texts, **kw)
217
+ # Make it dict-unpack-friendly
218
+ result_dict = {"input_ids": result.input_ids, "attention_mask": result.attention_mask}
219
+ result_dict_obj = SimpleNamespace(**result_dict)
220
+ # Need .to() to return something dict-unpack-friendly too
221
+ def to(device):
222
+ d = {"input_ids": result_dict["input_ids"].to(device),
223
+ "attention_mask": result_dict["attention_mask"].to(device)}
224
+ # Use a small class that supports both **kwargs unpacking and .input_ids
225
+ class B:
226
+ def __init__(self, d):
227
+ self.__dict__.update(d)
228
+ self._d = d
229
+ def keys(self): return self._d.keys()
230
+ def __getitem__(self, k): return self._d[k]
231
+ return B(d)
232
+ result_dict_obj.to = to
233
+ return result_dict_obj
234
+
235
+ wrapped_tok = TokenizerWrapper(tokenizer)
236
+
237
+ calibration_texts = [
238
+ "the quick brown fox jumps over the lazy dog",
239
+ "machine learning models compress activations",
240
+ "key value caches grow with context length",
241
+ "attention is all you need",
242
+ ] * 4 # 16 samples
243
+
244
+ rows = kvp.profile_kv_sensitivity(
245
+ model=model,
246
+ tokenizer=wrapped_tok,
247
+ calibration_texts=calibration_texts,
248
+ model_hash="testmodel" + "0" * 8,
249
+ profiled_by_agent_id="smoke-test",
250
+ profiled_by_agent_tier=0,
251
+ max_seq_len=32,
252
+ drift_metric="mse_normalised",
253
+ progress_cb=lambda m: None, # silent
254
+ )
255
+
256
+ print(f" emitted rows: {len(rows)}")
257
+ # 11 configs × 3 layers = 33
258
+ assert len(rows) == 33, f"Expected 33 rows, got {len(rows)}"
259
+
260
+ # fp16_passthrough not in default sweep, but let's check that 8-bit < 2-bit drift
261
+ by_config = {}
262
+ for r in rows:
263
+ key = (r.k_bits, r.v_bits, r.quantizer)
264
+ by_config.setdefault(key, []).append(r.drift_attn_output)
265
+
266
+ # Average drift per config across layers
267
+ avg_drift = {k: sum(v) / len(v) for k, v in by_config.items()}
268
+ print(f" avg drift (8,8) hqq_g64: {avg_drift[(8, 8, 'hqq_g64')]:.4e}")
269
+ print(f" avg drift (4,4) hqq_g64: {avg_drift[(4, 4, 'hqq_g64')]:.4e}")
270
+ print(f" avg drift (3,3) hqq_g64: {avg_drift[(3, 3, 'hqq_g64')]:.4e}")
271
+ print(f" avg drift (2,2) hqq_g64: {avg_drift[(2, 2, 'hqq_g64')]:.4e}")
272
+ print(f" avg drift (2,4) hqq_g64: {avg_drift[(2, 4, 'hqq_g64')]:.4e}")
273
+
274
+ # Sanity: more bits = less drift for the symmetric chain
275
+ assert avg_drift[(8, 8, "hqq_g64")] < avg_drift[(4, 4, "hqq_g64")]
276
+ assert avg_drift[(4, 4, "hqq_g64")] < avg_drift[(3, 3, "hqq_g64")]
277
+ assert avg_drift[(3, 3, "hqq_g64")] < avg_drift[(2, 2, "hqq_g64")]
278
+ print(" bit ordering 8<4<3<2 verified across symmetric configs ✓")
279
+
280
+ # K-cheaper helps: (4,4) should be cheaper drift than (2,4) but (2,4) should
281
+ # be cheaper than (2,2) — K matters more than V
282
+ assert avg_drift[(2, 4, "hqq_g64")] < avg_drift[(2, 2, "hqq_g64")]
283
+ print(" (2,4) < (2,2) — V-precision helps even when K is aggressive ✓")
284
+
285
+ # Drift is per-layer (not all identical — would indicate a stuck hook)
286
+ sample_config = (4, 4, "hqq_g64")
287
+ layer_drifts = sorted(by_config[sample_config])
288
+ print(f" (4,4) drifts per layer: {[f'{d:.4e}' for d in layer_drifts]}")
289
+ unique_drifts = len(set(round(d, 10) for d in layer_drifts))
290
+ assert unique_drifts >= 1
291
+ print(f" per-layer drift variation: {unique_drifts} distinct values")
292
+
293
+
294
+ # ===========================================================================
295
+ # 6. Bridge to allocator
296
+ # ===========================================================================
297
+ hr("6. rows_to_kv_candidates → assign_kv_bits round-trip")
298
+
299
+ candidates = kvp.rows_to_kv_candidates(rows)
300
+ print(f" built {len(candidates)} KVCandidates (expected 3 = n_layers)")
301
+ assert len(candidates) == 3
302
+
303
+ # Each candidate carries the full 11 options
304
+ for cand in candidates:
305
+ assert len(cand.options) == 11, f"Layer {cand.layer_idx}: expected 11 options, got {len(cand.options)}"
306
+ assert cand.num_kv_heads == 4 and cand.head_dim == 32
307
+
308
+ # Run the allocator with a budget that forces variation
309
+ # All-cheapest = (2,4) at ~bpt_24 bytes/token × 32 seq × 3 layers
310
+ # All-most-expensive (8,8) ≈ bpt_88 × 32 × 3
311
+ bpt_24 = kvp.kv_bytes_per_token(4, 32, 2, 4, "hqq_g64")
312
+ bpt_88 = kvp.kv_bytes_per_token(4, 32, 8, 8, "hqq_g64")
313
+ all_cheap_bytes = bpt_24 * 32 * 3
314
+ all_expensive_bytes = bpt_88 * 32 * 3
315
+ budget_bytes = (all_cheap_bytes + all_expensive_bytes) / 2
316
+ print(f" cheapest config: {all_cheap_bytes:.0f} bytes total")
317
+ print(f" priciest config: {all_expensive_bytes:.0f} bytes total")
318
+ print(f" budget chosen: {budget_bytes:.0f} bytes (midpoint)")
319
+
320
+ result = asgn.assign_kv_bits(
321
+ candidates,
322
+ kv_budget_gb=budget_bytes / 1e9,
323
+ max_seq_len=32,
324
+ )
325
+ print(f" KV used: {result.total_kv_gb * 1e9:.0f} bytes / {budget_bytes:.0f} budget")
326
+ print(f" saturated: {result.saturated}")
327
+ chosen_dist = Counter((a.chosen.k_bits, a.chosen.v_bits, a.chosen.quantizer)
328
+ for a in result.assignments)
329
+ print(f" chosen configs: {dict(chosen_dist)}")
330
+ assert result.total_kv_gb * 1e9 <= budget_bytes
331
+ print(" allocator consumed profiler output cleanly ✓")
332
+
333
+
334
+ print("\nAll assertions passed.")