Upload test_v2.py with huggingface_hub
Browse files- test_v2.py +278 -0
test_v2.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CPU smoke tests for v2 novel techniques: QAT-fused cooldown, mixed-precision GPTQ, nuclear-norm reg."""
|
| 2 |
+
import math, torch, torch.nn.functional as F
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
# ---- Minimal model for testing ----
|
| 6 |
+
class CastedLinear(nn.Linear):
    """``nn.Linear`` that casts its parameters to the input dtype on each call.

    Three private attributes control optional quantization-aware training:
    ``_qat_enabled`` (off by default), ``_qat_bits`` and ``_qat_clip_sigmas``.
    When QAT is enabled *and* the module is in training mode, the cast weight
    is passed through ``fake_quantize_ste`` before the matmul; otherwise the
    layer behaves like a plain linear layer on the cast weight.
    """

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        # QAT is opt-in: callers toggle these attributes on the instance.
        self._qat_enabled = False
        self._qat_bits = 6
        self._qat_clip_sigmas = 12.85

    def forward(self, x):
        """Apply the linear map to ``x`` using parameters cast to ``x.dtype``."""
        weight = self.weight.to(x.dtype)
        if self._qat_enabled and self.training:
            # Fake-quantize only while training, so eval sees the raw weight.
            weight = fake_quantize_ste(weight, self._qat_bits, self._qat_clip_sigmas)
        if self.bias is None:
            bias = None
        else:
            bias = self.bias.to(x.dtype)
        return F.linear(x, weight, bias)
|
| 15 |
+
|
| 16 |
+
class FakeQuantize(torch.autograd.Function):
    """Per-row symmetric fake quantization with a straight-through gradient.

    The forward pass snaps ``w`` onto a signed integer grid and immediately
    de-quantizes it; the backward pass returns the incoming gradient
    unchanged for ``w`` and ``None`` for the two non-tensor arguments.
    """

    @staticmethod
    def forward(ctx, w, bits, clip_sigmas):
        """Return ``w`` rounded to a ``bits``-bit grid, in ``w``'s dtype.

        The step size is computed per row (reduction over ``dim=1``) as
        ``clip_sigmas * std / qmax`` and floored at 1e-10.
        """
        # Largest magnitude of the symmetric signed grid, e.g. 31 for 6 bits.
        qmax = 2 ** (bits - 1) - 1
        sigma = w.float().std(dim=1, keepdim=True)
        step = (clip_sigmas * sigma / qmax).clamp_min(1e-10)
        levels = torch.clamp(torch.round(w / step), -qmax, qmax)
        dequantized = levels * step
        return dequantized.to(w.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient straight through to ``w``."""
        return grad_output, None, None
|
| 27 |
+
|
| 28 |
+
def fake_quantize_ste(w, bits, clip_sigmas):
    """Fake-quantize ``w`` via ``FakeQuantize`` and return the result.

    Thin functional wrapper around ``FakeQuantize.apply``; ``bits`` and
    ``clip_sigmas`` are forwarded unchanged.
    """
    quantized = FakeQuantize.apply(w, bits, clip_sigmas)
    return quantized
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_1_ste_fake_quant():
    """Smoke-test ``fake_quantize_ste``: forward grid and identity backward.

    Checks that the quantized tensor differs from the input, that every
    value is an integer multiple of its row's step size and lies within the
    signed INT6 range, and that the gradient of ``wq.sum()`` with respect to
    the input is all ones.

    Returns:
        True when every assertion holds (an AssertionError is raised otherwise).
    """
    print("TEST 1: STE Fake Quantization")
    bits, clip_sigmas = 6, 12.85
    clip_range = 2 ** (bits - 1) - 1  # 31 for INT6
    w = torch.randn(128, 256, requires_grad=True)
    wq = fake_quantize_ste(w, bits, clip_sigmas)

    # Forward: quantized weights should differ from original
    assert not torch.allclose(w, wq), "Quantized weights should differ from original"

    # Check quantization is valid: values should be on a grid.
    # Recompute the per-row step size the same way the forward pass does.
    row_std = w.float().detach().std(dim=1, keepdim=True)
    scale = (clip_sigmas * row_std / clip_range).clamp_min(1e-10)
    levels = wq.detach() / scale
    grid_vals = levels.round()
    # Rounding first and then range-checking would accept off-grid values,
    # so compare the un-rounded levels against their rounded counterparts.
    assert torch.allclose(levels, grid_vals, atol=1e-3), "Values should lie on the quantization grid"
    assert (grid_vals.abs() <= clip_range).all(), "Values should be within INT6 range"

    # Backward: STE should pass gradient through unchanged
    loss = wq.sum()
    loss.backward()
    assert w.grad is not None, "Gradient should flow through STE"
    assert torch.allclose(w.grad, torch.ones_like(w)), "STE gradient should be identity"

    print(" ✓ Forward quantizes correctly (values on INT6 grid)")
    print(" ✓ Backward passes gradient unchanged (STE)")
    return True
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_2_qat_castedlinear():
    """Smoke-test the QAT switch on ``CastedLinear``.

    Verifies that enabling QAT in training mode changes the output by a
    small relative amount (mean absolute error below 15% of the mean absolute
    full-precision output), that eval mode ignores the QAT flag, and that
    gradients reach the layer weight through the fake-quantized path.

    Returns:
        True when every assertion holds (an AssertionError is raised otherwise).
    """
    print("\nTEST 2: QAT-aware CastedLinear")
    # Seed the RNG so the random layer and input, and therefore the
    # thresholded error check below, are reproducible from run to run.
    torch.manual_seed(0)
    layer = CastedLinear(64, 128, bias=False)
    x = torch.randn(4, 16, 64)

    # Without QAT
    layer._qat_enabled = False
    layer.train()
    out_fp = layer(x)

    # With QAT
    layer._qat_enabled = True
    layer._qat_bits = 6
    out_qat = layer(x)

    # Outputs should differ (quantization noise)
    assert not torch.allclose(out_fp, out_qat, atol=1e-6), "QAT should change outputs"

    # But not by too much (< 15% relative error, matching the threshold below)
    rel_err = (out_fp - out_qat).abs().mean() / out_fp.abs().mean()
    print(f" QAT relative error: {rel_err:.4f}")
    assert rel_err < 0.15, "QAT error should be small"

    # Eval mode: QAT should be inactive
    layer.eval()
    out_eval = layer(x)
    assert torch.allclose(out_fp, out_eval, atol=1e-7), "QAT should be off in eval mode"

    # Gradient should flow through QAT (out_qat was built in training mode)
    layer.train()
    loss = out_qat.sum()
    loss.backward()
    assert layer.weight.grad is not None, "Gradient must flow through QAT"
    assert layer.weight.grad.abs().sum() > 0, "Gradient must be non-zero"

    print(" ✓ QAT changes forward output")
    print(" ✓ QAT inactive in eval mode")
    print(" ✓ Gradients flow through QAT")
    return True
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_3_nuclear_norm_penalty():
    """Smoke-test the Frobenius² penalty used as a nuclear-norm proxy.

    Three things are exercised:

    * the raw Frobenius² of a rank-8 product and of a dense random matrix is
      printed for information only (no ordering is asserted on it);
    * after rescaling both matrices to unit standard deviation, the true
      nuclear norm of the rank-8 matrix must be smaller than that of the
      dense one;
    * ten SGD steps on the Frobenius² penalty must shrink a parameter's norm.

    Returns:
        True when every assertion holds (an AssertionError is raised otherwise).
    """
    print("\nTEST 3: Nuclear-norm Regularization")

    low_rank = torch.randn(128, 8) @ torch.randn(8, 256)  # rank at most 8
    full_rank = torch.randn(128, 256)

    # Frobenius norm squared as nuclear-norm proxy
    pen_low = low_rank.norm() ** 2
    pen_full = full_rank.norm() ** 2

    # Bring both matrices to unit std so their Frobenius norms are comparable;
    # at matched scale the sum of singular values is what separates a
    # low-rank matrix from a dense one.
    low_rank_scaled = low_rank / low_rank.std()
    full_rank_scaled = full_rank / full_rank.std()
    nuc_low = torch.linalg.matrix_norm(low_rank_scaled, ord="nuc")
    nuc_full = torch.linalg.matrix_norm(full_rank_scaled, ord="nuc")

    print(f" Low-rank Frobenius²: {pen_low.item():.1f}")
    print(f" Full-rank Frobenius²: {pen_full.item():.1f}")
    print(f" Nuclear norm at unit std: low-rank={nuc_low.item():.1f} full-rank={nuc_full.item():.1f}")
    assert nuc_low < nuc_full, "Low-rank matrix should have the smaller nuclear norm at equal scale"

    # Test that regularization actually modifies weights
    W = nn.Parameter(torch.randn(64, 128))
    opt = torch.optim.SGD([W], lr=0.01)

    norm_before = W.norm().item()
    for _ in range(10):
        opt.zero_grad()
        penalty = W.float().norm() ** 2
        penalty.backward()
        opt.step()
    norm_after = W.norm().item()

    assert norm_after < norm_before, "Nuclear reg should decrease weight norm"
    print(f" Weight norm: {norm_before:.3f} → {norm_after:.3f} (decrease: {100*(1-norm_after/norm_before):.1f}%)")
    print(" ✓ Regularization decreases weight magnitude")
    return True
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def test_4_mixed_precision_classify():
    """Smoke-test the name-based routing of parameters to bit widths.

    Parameter names are mapped to a category ('embed', 'mlp', 'attn' or
    'other') by substring match, and each category to a bit width (unknown
    categories fall back to 6). A size estimate for storing two 512x2048
    MLP matrices at 6 vs 4 bits per parameter is printed afterwards.

    Returns:
        True when every assertion holds (an AssertionError is raised otherwise).
    """
    print("\nTEST 4: Mixed-Precision Classification (MLP=INT4 vs Attn=INT6)")

    def classify_param(name):
        # Embedding-like names take priority over the block-level markers.
        if 'tok_emb' in name or 'lm_head' in name:
            return 'embed'
        for marker, category in (('.mlp.', 'mlp'), ('.attn.', 'attn')):
            if marker in name:
                return category
        return 'other'

    mlp_bits, attn_bits, embed_bits = 4, 6, 8
    bits_by_category = {'mlp': mlp_bits, 'attn': attn_bits, 'embed': embed_bits}

    # (parameter name, expected category, expected bit width)
    expectations = [
        ('blocks.0.mlp.fc.weight', 'mlp', 4),
        ('blocks.0.mlp.proj.weight', 'mlp', 4),
        ('blocks.0.attn.c_q.weight', 'attn', 6),
        ('blocks.0.attn.c_k.weight', 'attn', 6),
        ('blocks.0.attn.proj.weight', 'attn', 6),
        ('tok_emb.weight', 'embed', 8),
    ]

    for name, expected_cat, expected_bits in expectations:
        cat = classify_param(name)
        bits = bits_by_category.get(cat, 6)
        assert cat == expected_cat, f"{name}: expected {expected_cat}, got {cat}"
        assert bits == expected_bits, f"{name}: expected INT{expected_bits}, got INT{bits}"
        print(f" {name:40s} → {cat:6s} → INT{bits}")

    print(" ✓ All parameters classified correctly")

    # Storage estimate: bits / 8 bytes per parameter, for fc + proj together.
    n_mlp_params = 512 * 2048 * 2
    bytes_int6 = n_mlp_params * 6 / 8
    bytes_int4 = n_mlp_params * 4 / 8
    saved = bytes_int6 - bytes_int4
    print(f"\n Per-layer MLP savings: {saved/1024:.0f} KB ({saved/bytes_int6*100:.0f}% reduction)")
    print(f" Over 11 layers: {saved*11/1024:.0f} KB saved → room for more params")
    print(" ✓ Mixed precision saves significant space")
    return True
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def test_5_qat_improves_quantized_quality():
    """Compare post-quantization loss of a model trained with and without QAT.

    Two identically seeded ``TinyModel`` instances are trained for 100 Adam
    steps on a fixed next-token pattern; the second one has QAT switched on
    from step 50. Both are then evaluated in full precision and after
    post-hoc INT6 quantization of the linear weight, and the gaps are printed.

    Whether QAT narrows the gap is reported but not asserted; the test only
    requires the quantized losses to be finite.

    Returns:
        True when every assertion holds (an AssertionError is raised otherwise).
    """
    print("\nTEST 5: QAT Training Improves Post-Quantization Quality")

    torch.manual_seed(42)
    V, D = 32, 64

    # Simple task: learn embedding → linear → predict
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(V, D)
            self.linear = CastedLinear(D, V, bias=False)

        def forward(self, x, y):
            h = self.emb(x)
            logits = self.linear(h)
            return F.cross_entropy(logits.reshape(-1, V), y.reshape(-1))

    def quantize_and_eval(model, x, y):
        """Simulate post-hoc INT6 quantization and evaluate.

        Puts ``model`` in eval mode, temporarily swaps in a quantized copy of
        ``model.linear.weight`` and returns the loss as a Python float. The
        original weight is restored even if the forward pass raises.
        """
        model.eval()
        with torch.no_grad():
            # Quantize the linear weight
            weight = model.linear.weight
            w = weight.float()
            std = w.std(dim=1, keepdim=True)
            scale = (12.85 * std / 31).clamp_min(1e-10)
            # Cast back to the parameter's own dtype, not the float copy's.
            wq = ((w / scale).round().clamp(-31, 31) * scale).to(weight.dtype)

            # Replace weight temporarily
            orig = weight.data.clone()
            model.linear.weight.data = wq
            try:
                loss = model(x, y).item()
            finally:
                model.linear.weight.data = orig
        return loss

    # Training data: simple pattern
    x = torch.arange(V).unsqueeze(0).expand(8, -1)
    y = (x + 1) % V

    # Train WITHOUT QAT
    torch.manual_seed(42)
    model_no_qat = TinyModel()
    opt1 = torch.optim.Adam(model_no_qat.parameters(), lr=1e-2)
    for _ in range(100):
        opt1.zero_grad()
        model_no_qat(x, y).backward()
        opt1.step()

    # Train WITH QAT (enable at step 50)
    torch.manual_seed(42)
    model_qat = TinyModel()
    opt2 = torch.optim.Adam(model_qat.parameters(), lr=1e-2)
    for step in range(100):
        if step == 50:
            model_qat.linear._qat_enabled = True
            model_qat.linear._qat_bits = 6
            model_qat.linear._qat_clip_sigmas = 12.85
            model_qat.train()
        opt2.zero_grad()
        model_qat(x, y).backward()
        opt2.step()

    # Evaluate both AFTER quantization
    loss_no_qat = quantize_and_eval(model_no_qat, x, y)
    model_qat.linear._qat_enabled = False  # disable for fair eval
    loss_qat = quantize_and_eval(model_qat, x, y)
    # Minimal sanity check so the test can fail on a broken pipeline.
    assert math.isfinite(loss_no_qat), "Quantized loss without QAT must be finite"
    assert math.isfinite(loss_qat), "Quantized loss with QAT must be finite"

    # Also measure FP quality (before quantization)
    with torch.no_grad():
        model_no_qat.eval()
        fp_no_qat = model_no_qat(x, y).item()
        model_qat.eval()
        fp_qat = model_qat(x, y).item()

    print(f" No-QAT: FP loss={fp_no_qat:.4f} Quantized loss={loss_no_qat:.4f} (gap={loss_no_qat-fp_no_qat:+.4f})")
    print(f" QAT: FP loss={fp_qat:.4f} Quantized loss={loss_qat:.4f} (gap={loss_qat-fp_qat:+.4f})")

    qat_gap = loss_qat - fp_qat
    no_qat_gap = loss_no_qat - fp_no_qat

    if qat_gap < no_qat_gap:
        print(f" ✓ QAT reduces quantization gap by {no_qat_gap - qat_gap:.4f}")
    else:
        print(f" ⚠ QAT did not reduce gap on this toy task (expected at larger scale)")

    return True
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == '__main__':
    # Script entry point: run every smoke test, print a summary, and exit
    # with a non-zero status if any test failed.
    print("Parameter Golf v2 — Novel Technique Smoke Tests")
    print("=" * 60)

    # (label, test function) pairs, executed in this order.
    tests = [
        ("STE Fake Quant", test_1_ste_fake_quant),
        ("QAT CastedLinear", test_2_qat_castedlinear),
        ("Nuclear-norm Reg", test_3_nuclear_norm_penalty),
        ("Mixed Precision", test_4_mixed_precision_classify),
        ("QAT Quality", test_5_qat_improves_quantized_quality),
    ]

    results = []
    for label, test_fn in tests:
        # The tests signal failure by raising (usually AssertionError), so
        # catch here to keep running the remaining tests and still reach
        # the summary below.
        try:
            ok = bool(test_fn())
        except Exception as exc:
            print(f" ✗ {label} raised {type(exc).__name__}: {exc}")
            ok = False
        results.append((label, ok))

    print("\n" + "=" * 60)
    print("SUMMARY")
    for name, ok in results:
        print(f" {'✓' if ok else '✗'} {name}")
    all_ok = all(ok for _, ok in results)
    print(f"\n{'All passed!' if all_ok else 'FAILURES!'}")
    if not all_ok:
        raise SystemExit(1)
|