"""Video/latent comprehension tests — verify VideoHead on CPU. Tests: VideoHead forward, cross-attention conditioning, ACT halting, latent shape compatibility with pig-vae. Runs on CPU — quick smoke tests only (no full video decode). """ import os, sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) import torch from arbitor.kernel.ternary_scale import TScaleType device = "cpu" FAILED = 0 def check(name, condition, detail=""): global FAILED if condition: print(f" ✓ {name}") else: print(f" ✗ {name} — {detail}") FAILED += 1 print("\n=== Video / Latent Comprehension ===\n") from arbitor import VideoHead, HIDDEN_DIM # 1. VideoHead forward head = VideoHead() relational = torch.randn(2, 10, HIDDEN_DIM) latents = head(relational) check("VideoHead forward", latents is not None, "got None") check("Latent shape", latents.shape == (2, 16, 1, 32, 32), f"got {latents.shape}") check("No NaN in latents", not torch.isnan(latents).any()) check("Latents finite", torch.isfinite(latents).all()) # 2. ACT halting (should stop early for clear conditioning) head2 = VideoHead(max_steps=6) relational_clear = torch.randn(2, 10, HIDDEN_DIM) * 10 # strong signal latents2 = head2(relational_clear) check("ACT halting with strong signal", latents2 is not None) # 3. Latents with different batch sizes latents3 = head(torch.randn(1, 5, HIDDEN_DIM)) check("Batch=1 works", latents3.shape == (1, 16, 1, 32, 32)) latents4 = head(torch.randn(4, 8, HIDDEN_DIM)) check("Batch=4 works", latents4.shape == (4, 16, 1, 32, 32)) # 4. pig-vae latent compatibility check # The pig-vae expects [B, 16, T, H, W] latents check("Latent channels = 16", latents.shape[1] == 16) check("Latent spatial = 32", latents.shape[3] == 32 and latents.shape[4] == 32) check("Latent temporal = 1 (single frame)", latents.shape[2] == 1) # 5. Model with VideoHead from arbitor import ARBModel model = ARBModel(enable_image=False, enable_audio=False, enable_vq=False, enable_graph=False, enable_memory_modules=False, enable_moe=False) model.eval() x = torch.randint(0, 256, (2, 10)) with torch.no_grad(): video_latents = model.video_head(relational) check("VideoHead in model pipeline", video_latents.shape == (2, 16, 1, 32, 32)) # 6. Quantization effects (VideoHead uses TernaryScaleTensor internally) params = sum(p.numel() for p in head.parameters() if not hasattr(p, 'T_packed')) ternary_buffers = sum(b.numel() for n, b in head.named_buffers() if 'T_packed' in n) check("VideoHead has ternary weights", ternary_buffers > 0, f"{ternary_buffers} packed ternary entries") check("VideoHead minimal float params", params < 5000, f"{params} float params") print(f"\n{'='*50}") if FAILED == 0: print("✓ All video comprehension tests passed!") else: print(f"✗ {FAILED} test(s) failed") sys.exit(FAILED)