"""
FlashVQ Correctness Tests — CPU path, GPU path, and CPU vs GPU equivalence.

Test structure follows testing/test_tscale.py pattern:
- Each test is a standalone function
- Manual runner at bottom for direct execution
- CUDA/Triton tests skip gracefully when unavailable

Tests 1-7: CPU path correctness (Task 1)
Tests 8-10: GPU path correctness + CPU vs GPU equivalence (Task 2)
Tests 11-12: integration with VQAdapter and MultimodalVQBridge
"""

import os
import sys

# Put the directory two levels above this file (presumably the repository
# root) on sys.path so the project imports below resolve when this file is
# executed directly. Must run before those imports.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

import torch
import torch.nn.functional as F

# NOTE(review): `flash_vq` is not referenced anywhere in this file; confirm a
# top-level `flash_vq` module is really importable from here.
import flash_vq
from arbitor.kernel.flash_vq import FlashVQCodebook, _HAS_TRITON

# The integration tests (11-12) need the model package; they print a SKIP
# line and return when it cannot be imported.
try:
    from arbitor.main import VQAdapter, MultimodalVQBridge, HIDDEN_DIM, CODEBOOK_DIM
    from arbitor.kernel.ternary_scale import TScaleType
    _HAS_TRIGRAM = True
except ImportError:
    _HAS_TRIGRAM = False


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_cpu_vq(codebook_size=8192, codebook_dim=32, seed=42, rotation_trick=True):
    """Build a reproducible CPU ``FlashVQCodebook`` for the tests.

    The global torch RNG is reseeded with ``seed`` immediately before the
    module is constructed. Construction is therefore repeatable, presumably
    because ``FlashVQCodebook`` draws its initial codebook from that RNG.
    All other hyper-parameters are fixed and shared by every test.

    Args:
        codebook_size: Number of codebook entries.
        codebook_dim: Dimensionality of each entry.
        seed: Value passed to ``torch.manual_seed`` before construction.
        rotation_trick: Forwarded unchanged to ``FlashVQCodebook``.

    Returns:
        The newly constructed ``FlashVQCodebook``.
    """
    torch.manual_seed(seed)
    fixed_hparams = dict(
        decay=0.99,
        commitment_weight=1.0,
        threshold_ema_dead_code=2,
        kmeans_init=False,
        kmeans_iters=10,
    )
    return FlashVQCodebook(
        codebook_size=codebook_size,
        codebook_dim=codebook_dim,
        rotation_trick=rotation_trick,
        **fixed_hparams,
    )


# ---------------------------------------------------------------------------
# CPU path tests (Tests 1-7)
# ---------------------------------------------------------------------------

def test_flash_vq_cpu_forward_shapes():
    """
    Test 1: the CPU forward on a random batch returns a
    (quantized, indices, commitment_loss) triple with the expected shapes,
    and every index addresses a valid codebook entry.
    """
    codebook = _make_cpu_vq()
    batch = torch.randn(4, 16, 32)
    n_rows = 4 * 16
    quantized, indices, loss = codebook._cpu_forward(batch.reshape(-1, 32))

    # One quantized vector and one code index per flattened input row.
    assert quantized.shape == (n_rows, 32), f"quantized shape: {quantized.shape}"
    assert indices.shape == (n_rows,), f"indices shape: {indices.shape}"

    # The commitment loss is a 0-d scalar tensor.
    assert loss.numel() == 1, f"loss shape: {loss.shape}"
    assert loss.dim() == 0, f"loss dim: {loss.dim()}"

    # Indices must fall inside [0, codebook_size).
    assert indices.min() >= 0, f"negative index: {indices.min()}"
    assert indices.max() < codebook.codebook_size, f"index too large: {indices.max()}"

    assert quantized.shape[-1] == 32, f"quantized last dim: {quantized.shape[-1]}"

    print(" PASS test_flash_vq_cpu_forward_shapes")
|
|
|
|
def test_flash_vq_cpu_quantized_matches_codebook():
    """
    Test 2: the CPU forward's quantized output has the value of
    ``codebook[indices]`` (straight-through-estimator form).
    """
    codebook = _make_cpu_vq()
    inputs = torch.randn(4, 16, 32).reshape(-1, 32)

    # Snapshot before the forward pass: the forward can update the codebook
    # in place (see Test 4), which would change what we compare against.
    codebook_before = codebook.embed.clone()

    quantized, indices, _ = codebook._cpu_forward(inputs)

    # Compare residuals relative to the input: quantized - x must equal
    # embed[indices] - x, i.e. the forward value is the selected entry.
    residual_actual = quantized - inputs
    residual_expected = (codebook_before[indices] - inputs).detach()
    assert torch.allclose(residual_actual, residual_expected, atol=1e-6), \
        "STE: quantized - x should equal (embed[indices] - x).detach()"

    print(" PASS test_flash_vq_cpu_quantized_matches_codebook")
|
|
|
|
def test_flash_vq_cpu_cosine_sim():
    """
    Test 3: the CPU forward picks, for every row, the codebook entry with the
    highest cosine similarity, i.e.
    argmax(F.normalize(x) @ F.normalize(codebook).T).
    """
    codebook = _make_cpu_vq()
    inputs = torch.randn(4, 16, 32).reshape(-1, 32)

    # Reference must use the codebook as it was before any in-place update.
    codebook_before = codebook.embed.clone()

    _, indices, _ = codebook._cpu_forward(inputs)

    similarity = F.normalize(inputs, dim=-1) @ F.normalize(codebook_before, dim=-1).T
    reference_indices = similarity.argmax(dim=-1)

    assert torch.equal(indices, reference_indices), \
        f"Indices differ! First 10 indices: {indices[:10]} vs {reference_indices[:10]}"

    print(" PASS test_flash_vq_cpu_cosine_sim")
|
|
|
|
def test_flash_vq_cpu_ema_update():
    """
    Test 4: the EMA update moves both ``embed`` and ``cluster_size``.

    Part 1 calls ``_ema_update`` directly with a fixed assignment, so only
    entries 0-3 receive rows. Part 2 checks that a full CPU forward pass also
    changes the codebook. rotation_trick=False is used throughout.
    """
    codebook = _make_cpu_vq(rotation_trick=False)
    embed_before = codebook.embed.clone()
    cluster_size_before = codebook.cluster_size.clone()

    # Part 1: round-robin assignment of the 16 rows to entries 0, 1, 2, 3.
    inputs = torch.randn(2, 8, 32).reshape(-1, 32)
    assignment = torch.arange(16) % 4
    codebook._ema_update(inputs, assignment)

    assert not torch.equal(embed_before, codebook.embed), \
        "Embed did not change after EMA update"
    assert not torch.equal(cluster_size_before, codebook.cluster_size), \
        "cluster_size did not change after EMA update"

    # Only entries 0-3 were assigned. The zero check below additionally
    # relies on cluster_size starting at zero.
    assert (codebook.cluster_size[:4] > 0).all(), \
        "Assigned entries should have non-zero cluster_size"
    assert (codebook.cluster_size[4:] == 0).all(), \
        "Unassigned entries should have zero cluster_size"

    # Part 2: a full forward pass on a fresh codebook must also move embed.
    fresh = _make_cpu_vq(rotation_trick=False)
    fresh_embed_before = fresh.embed.clone()
    fresh._cpu_forward(torch.randn(4, 16, 32).reshape(-1, 32))
    assert not torch.equal(fresh_embed_before, fresh.embed), \
        "Embed did not change after full forward pass"

    print(" PASS test_flash_vq_cpu_ema_update")
|
|
|
|
def test_flash_vq_cpu_dead_code_reset():
    """
    Test 5: FlashVQCodebook CPU dead code reset replaces inactive codebook entries.

    Entries 0-9 are made live (cluster_size=5, above the threshold of 2 set in
    ``_make_cpu_vq``) and every other entry dead (cluster_size=0). After
    ``_dead_code_reset`` at least one of the entries that were dead
    beforehand must have a new embedding.
    """
    vq = _make_cpu_vq()
    # Only the first 10 entries are live; the rest are dead.
    vq.cluster_size[:] = 0.0
    vq.cluster_size[:10] = 5.0

    x = torch.randn(2, 8, 32)
    x_flat = x.reshape(-1, 32)

    embed_before = vq.embed.clone()
    n_dead_before = vq.get_dead_code_count()
    assert n_dead_before == vq.codebook_size - 10, \
        f"Expected {vq.codebook_size - 10} dead entries, got {n_dead_before}"

    # Record the dead entries BEFORE the reset. The reset may also refresh
    # cluster_size, so recomputing this afterwards could select no entries
    # at all and make the check below pass vacuously.
    dead_mask = vq.cluster_size == 0

    vq._dead_code_reset(x_flat)

    # An entry counts as replaced if any component of its embedding changed.
    changed = (vq.embed != embed_before).any(dim=-1)
    assert changed[dead_mask].any(), \
        "No dead codebook entry was replaced by _dead_code_reset"

    print(" PASS test_flash_vq_cpu_dead_code_reset")
|
|
|
|
def test_flash_vq_cpu_rotation_trick_grad():
    """
    Test 6: FlashVQCodebook CPU rotation trick gradient flows correctly.
    Gradient should not be zero, and should differ from STE gradient.
    """
    torch.manual_seed(42)

    # Rotation-trick codebook plus a leaf copy of the input that collects grads.
    vq_rot = _make_cpu_vq(rotation_trick=True, seed=42)
    x = torch.randn(2, 4, 32)
    x_flat = x.reshape(-1, 32).detach().clone().requires_grad_(True)

    quantized_rot, indices_rot, _ = vq_rot._cpu_forward(x_flat)
    quantized_rot.sum().backward()

    # Check for a missing gradient before using it: calling .clone() on None
    # would raise AttributeError instead of this descriptive assertion.
    assert x_flat.grad is not None, "Rotation trick gradient is None"
    rot_grad = x_flat.grad.clone()
    assert rot_grad.abs().sum().item() > 0, "Rotation trick gradient is all zeros"

    # Reference run: same seed -> same initial codebook, but plain STE.
    torch.manual_seed(42)
    vq_ste = _make_cpu_vq(rotation_trick=False, seed=42)
    x_flat2 = x.reshape(-1, 32).detach().clone().requires_grad_(True)

    quantized_ste, indices_ste, _ = vq_ste._cpu_forward(x_flat2)
    quantized_ste.sum().backward()
    assert x_flat2.grad is not None, "STE gradient is None"
    ste_grad = x_flat2.grad.clone()

    # The gradients are only comparable when both runs selected the same codes.
    if torch.equal(indices_rot, indices_ste):
        grad_diff = (rot_grad - ste_grad).abs().max().item()
        assert grad_diff > 1e-8, \
            f"Rotation trick gradient equals STE gradient (diff={grad_diff})"

    print(" PASS test_flash_vq_cpu_rotation_trick_grad")
|
|
|
|
def test_flash_vq_cpu_commitment_loss():
    """
    Test 7: the CPU commitment loss is a non-negative 0-d scalar equal to the
    MSE between the input and the detached quantized output.
    """
    codebook = _make_cpu_vq(rotation_trick=False)
    inputs = torch.randn(4, 16, 32).reshape(-1, 32)

    quantized, _, loss = codebook._cpu_forward(inputs)

    loss_value = loss.item()
    assert loss_value >= 0.0, f"Commitment loss is negative: {loss_value}"
    assert loss.dim() == 0, f"Loss is not scalar: {loss.shape}"

    # _make_cpu_vq uses commitment_weight=1.0, so the loss should be plain MSE.
    reference = F.mse_loss(inputs, quantized.detach())
    assert torch.allclose(loss, reference, atol=1e-6), \
        f"Loss mismatch: {loss.item()} vs {reference.item()}"

    print(" PASS test_flash_vq_cpu_commitment_loss")


# ---------------------------------------------------------------------------
# GPU path tests (Tests 8-10)
# ---------------------------------------------------------------------------

def _make_gpu_vq(codebook_size=8192, codebook_dim=32, seed=42, rotation_trick=True):
    """Build a codebook via ``_make_cpu_vq`` and move it to the CUDA device.

    Construction goes through ``_make_cpu_vq`` with the same arguments. Given
    deterministic construction, the GPU module therefore starts from the same
    state as a CPU module built with identical arguments.
    """
    return _make_cpu_vq(codebook_size, codebook_dim, seed, rotation_trick).cuda()
|
|
|
|
def test_flash_vq_gpu_vs_cpu_forward():
    """
    Test 8: the Triton (GPU) forward agrees with the CPU forward.
    Quantized values and loss must match within 1e-3, and indices must match
    exactly.
    """
    if not (torch.cuda.is_available() and _HAS_TRITON):
        print(" SKIP test_flash_vq_gpu_vs_cpu_forward (CUDA/Triton unavailable)")
        return

    torch.manual_seed(42)
    cpu_codebook = _make_cpu_vq(rotation_trick=False)
    gpu_codebook = _make_gpu_vq(rotation_trick=False)

    inputs = torch.randn(2, 8, 32).reshape(-1, 32)

    q_cpu, idx_cpu, loss_cpu = cpu_codebook._cpu_forward(inputs)
    q_gpu, idx_gpu, loss_gpu = gpu_codebook._triton_forward(inputs.detach().clone().cuda())

    # Quantized vectors: max abs difference within tolerance.
    fwd_diff = (q_cpu - q_gpu.cpu()).abs().max().item()
    assert fwd_diff < 1e-3, \
        f"CPU vs GPU quantized max diff: {fwd_diff} (exceeds 1e-3)"

    # Code assignment must be identical.
    assert torch.equal(idx_cpu, idx_gpu.cpu()), \
        "CPU vs GPU indices differ"

    loss_diff = abs(loss_cpu.item() - loss_gpu.cpu().item())
    assert loss_diff < 1e-3, \
        f"CPU vs GPU loss diff: {loss_diff}"

    print(f" PASS test_flash_vq_gpu_vs_cpu_forward (fwd_diff={fwd_diff:.6f})")
|
|
|
|
def test_flash_vq_gpu_vs_cpu_gradients():
    """
    Test 9: input gradients through the rotation-trick backward agree between
    the GPU (Triton) and CPU paths to within 1e-3.
    """
    if not (torch.cuda.is_available() and _HAS_TRITON):
        print(" SKIP test_flash_vq_gpu_vs_cpu_gradients (CUDA/Triton unavailable)")
        return

    torch.manual_seed(42)
    cpu_codebook = _make_cpu_vq(rotation_trick=True, seed=42)
    gpu_codebook = _make_gpu_vq(rotation_trick=True, seed=42)

    cpu_leaf = torch.randn(2, 4, 32).reshape(-1, 32).detach().clone().requires_grad_(True)

    # Backward through the CPU path.
    q_cpu, _, _ = cpu_codebook._cpu_forward(cpu_leaf)
    q_cpu.sum().backward()
    cpu_grad = cpu_leaf.grad.clone()

    # Same input values on the GPU, as a fresh leaf tensor.
    gpu_leaf = cpu_leaf.detach().clone().cuda().requires_grad_(True)
    q_gpu, _, _ = gpu_codebook._triton_forward(gpu_leaf)
    q_gpu.sum().backward()
    gpu_grad = gpu_leaf.grad.clone()

    bwd_diff = (cpu_grad - gpu_grad.cpu()).abs().max().item()
    assert bwd_diff < 1e-3, \
        f"CPU vs GPU gradient max diff: {bwd_diff} (exceeds 1e-3)"

    print(f" PASS test_flash_vq_gpu_vs_cpu_gradients (bwd_diff={bwd_diff:.6f})")
|
|
|
|
def test_flash_vq_gpu_small_codebook():
    """
    Test 10: with codebook_size=4096 the GPU path still matches the CPU path
    (multi-codebook support per D-102).
    """
    if not (torch.cuda.is_available() and _HAS_TRITON):
        print(" SKIP test_flash_vq_gpu_small_codebook (CUDA/Triton unavailable)")
        return

    torch.manual_seed(42)
    size = 4096
    cpu_codebook = _make_cpu_vq(codebook_size=size, rotation_trick=False)
    gpu_codebook = _make_gpu_vq(codebook_size=size, rotation_trick=False)

    inputs = torch.randn(2, 8, 32).reshape(-1, 32)

    q_cpu, idx_cpu, _ = cpu_codebook._cpu_forward(inputs)
    q_gpu, idx_gpu, _ = gpu_codebook._triton_forward(inputs.detach().clone().cuda())

    fwd_diff = (q_cpu - q_gpu.cpu()).abs().max().item()
    assert fwd_diff < 1e-3, \
        f"CPU vs GPU (4096) quantized max diff: {fwd_diff}"
    assert torch.equal(idx_cpu, idx_gpu.cpu()), \
        "CPU vs GPU (4096) indices differ"

    print(f" PASS test_flash_vq_gpu_small_codebook (fwd_diff={fwd_diff:.6f})")


# ---------------------------------------------------------------------------
# Integration tests (Tests 11-12)
# ---------------------------------------------------------------------------

def test_flash_vq_in_vqadapter():
    """
    Test 11: a VQAdapter backed by FlashVQCodebook produces correctly shaped
    outputs. Its helpers (get_codebook_utilization, get_dead_code_count,
    l2_distance_matching) must also return sane values.
    """
    if not _HAS_TRIGRAM:
        print(" SKIP test_flash_vq_in_vqadapter (trigram.py not importable)")
        return

    adapter = VQAdapter(codebook_size=128, codebook_dim=32, tscale_type=TScaleType.T4)
    # Small random codebook and zeroed usage statistics before the forward.
    adapter.vq.embed.data = torch.randn(128, 32) * 0.02
    adapter.vq.cluster_size.data.zero_()
    adapter.eval()

    hidden = torch.randn(2, 8, 512)
    with torch.no_grad():
        output, vq_loss, indices = adapter(hidden)

    # Forward outputs.
    assert output.shape == (2, 8, 512), f"output shape: {output.shape}"
    assert vq_loss.numel() == 1, f"vq_loss shape: {vq_loss.shape}"
    assert indices.shape == (2, 8), f"indices shape: {indices.shape}"
    assert indices.min() >= 0, f"negative index: {indices.min()}"
    assert indices.max() < 128, f"index too large: {indices.max()}"

    # Utilization is a fraction in [0, 1].
    util = adapter.get_codebook_utilization()
    assert isinstance(util, float), f"util type: {type(util)}"
    assert 0.0 <= util <= 1.0, f"util out of range: {util}"

    # Dead-code count is a non-negative Python int. `.item()` of an integer
    # tensor is also `int`, so a plain `int` check covers both.
    dead = adapter.get_dead_code_count()
    assert isinstance(dead, int), f"dead type: {type(dead)}"
    dead_val = int(dead)
    assert dead_val >= 0, f"dead count negative: {dead_val}"

    # L2 matching is fed codebook-dim (32) slices of the input.
    with torch.no_grad():
        l2_idx, l2_dist = adapter.l2_distance_matching(hidden[..., :32])
    assert l2_idx.shape == (2, 8), f"l2 indices shape: {l2_idx.shape}"
    assert l2_dist.shape == (2, 8), f"l2 distances shape: {l2_dist.shape}"
    assert l2_dist.min() >= 0.0, "l2 distance should be non-negative"

    print(" PASS test_flash_vq_in_vqadapter")
|
|
|
|
def test_flash_vq_multimodal_bridge():
    """
    Test 12: every VQAdapter in a MultimodalVQBridge (text, image, audio) runs
    and produces in-range outputs. The bridge's per-modality utilization
    report must also be sane.
    """
    if not _HAS_TRIGRAM:
        print(" SKIP test_flash_vq_multimodal_bridge (trigram.py not importable)")
        return

    bridge = MultimodalVQBridge(
        text_codebook_size=256,
        image_codebook_size=128,
        audio_codebook_size=128,
        codebook_dim=32,
        enable_image=True,
        enable_audio=True,
    )
    bridge.eval()

    hidden = torch.randn(2, 8, 512)
    with torch.no_grad():
        # Insertion order keeps the text -> image -> audio call order.
        results = {
            "text": bridge.text_vq(hidden),
            "image": bridge.image_vq(hidden),
            "audio": bridge.audio_vq(hidden),
        }
    codebook_sizes = {"text": 256, "image": 128, "audio": 128}

    # All output shapes are checked first, then all index ranges.
    for name, (out, _loss, _idx) in results.items():
        assert out.shape == (2, 8, 512), f"{name} output shape: {out.shape}"
    for name, (_out, _loss, idx) in results.items():
        assert idx.max() < codebook_sizes[name], f"{name} index too large: {idx.max()}"

    all_util = bridge.get_codebook_utilization()
    for modality in ("text", "image", "audio"):
        assert modality in all_util
    for mod, u in all_util.items():
        assert 0.0 <= u <= 1.0, f"{mod} utilization out of range: {u}"

    print(" PASS test_flash_vq_multimodal_bridge")


# ---------------------------------------------------------------------------
# Manual runner
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import contextlib
    import io
    import traceback

    cpu_tests = [
        test_flash_vq_cpu_forward_shapes,
        test_flash_vq_cpu_quantized_matches_codebook,
        test_flash_vq_cpu_cosine_sim,
        test_flash_vq_cpu_ema_update,
        test_flash_vq_cpu_dead_code_reset,
        test_flash_vq_cpu_rotation_trick_grad,
        test_flash_vq_cpu_commitment_loss,
    ]
    gpu_tests = [
        test_flash_vq_gpu_vs_cpu_forward,
        test_flash_vq_gpu_vs_cpu_gradients,
        test_flash_vq_gpu_small_codebook,
    ]
    integration_tests = [
        test_flash_vq_in_vqadapter,
        test_flash_vq_multimodal_bridge,
    ]
    all_tests = cpu_tests + gpu_tests + integration_tests

    print("Running FlashVQ tests...\n")
    passed = 0
    failed = 0
    skipped = 0
    for test in all_tests:
        # The tests report a skip by printing a " SKIP ..." line and returning
        # normally, which keeps them pytest-compatible. A clean return alone
        # therefore does not mean "passed"; capture stdout to tell them apart.
        captured = io.StringIO()
        try:
            with contextlib.redirect_stdout(captured):
                test()
        except Exception as e:
            print(captured.getvalue(), end="")
            msg = str(e)
            # Kept for tests that signal a skip by raising instead.
            if msg.startswith(" SKIP"):
                print(msg)
                skipped += 1
            else:
                print(f" FAIL {test.__name__}: {e}")
                traceback.print_exc()
                failed += 1
            continue

        output = captured.getvalue()
        print(output, end="")
        if any(line.lstrip().startswith("SKIP") for line in output.splitlines()):
            skipped += 1
        else:
            passed += 1

    total_run = passed + failed
    print(f"\n{passed} passed, {failed} failed, {skipped} skipped out of {len(all_tests)} tests (attempted {total_run})")
    sys.exit(1 if failed > 0 else 0)
|
|