# arcisvlm/tests/test_hypernetwork.py
# Last commit 7a564e3 by Hardik Sanghvi:
#   "feat: integrate Gemma 4 E2B backbone for production-quality VLM inference"
"""Tests for HyperNetwork and ConditionEncoder."""
import pytest
import torch
from model.hypernetwork import HyperNetwork
from model.condition_encoder import ConditionEncoder
from model.lora import LoRAConfig, compute_total_lora_params
class TestConditionEncoder:
    """Test conditioning vector generation."""

    def setup_method(self, method):
        """Seed torch's global RNG before every test in this class.

        The tests build their inputs with ``torch.randn`` and construct a fresh
        encoder each time, so without a fixed seed a run is not reproducible
        and a failure cannot be replayed.
        """
        torch.manual_seed(0)

    def test_forward_shape(self):
        """forward(camera_ids, scene, query) returns one out_dim vector per batch row."""
        enc = ConditionEncoder(
            n_cameras=100,
            scene_input_dim=2048,
            query_input_dim=2048,
            out_dim=256,
        )
        cam_ids = torch.tensor([0, 1, 2])
        scene = torch.randn(3, 2048)
        query = torch.randn(3, 2048)
        out = enc(cam_ids, scene, query)
        assert out.shape == (3, 256)

    def test_forward_no_camera(self):
        """forward_no_camera(scene, query) also returns (batch, out_dim)."""
        # NOTE(review): the 2048-wide inputs assume the encoder's default
        # scene/query input dims are 2048 (as passed explicitly above) -- confirm.
        enc = ConditionEncoder(n_cameras=100, out_dim=256)
        scene = torch.randn(2, 2048)
        query = torch.randn(2, 2048)
        out = enc.forward_no_camera(scene, query)
        assert out.shape == (2, 256)

    def test_different_cameras_different_outputs(self):
        """Changing only the camera id must change the conditioning vector."""
        enc = ConditionEncoder(n_cameras=100, out_dim=256)
        # Disable any train-mode stochasticity (e.g. dropout). Otherwise the two
        # calls below could differ even if the camera id were ignored, and the
        # test would pass vacuously.
        enc.eval()
        scene = torch.randn(1, 2048)
        query = torch.randn(1, 2048)
        out_cam0 = enc(torch.tensor([0]), scene, query)
        out_cam1 = enc(torch.tensor([1]), scene, query)
        # Control: repeating the camera-0 call reproduces its output, so the
        # difference checked next can only come from the camera id.
        assert torch.allclose(enc(torch.tensor([0]), scene, query), out_cam0)
        # Different cameras should produce different conditioning
        assert not torch.allclose(out_cam0, out_cam1, atol=1e-3)
class TestHyperNetwork:
    """Test LoRA parameter generation."""

    def setup_method(self, method):
        """Seed torch's global RNG before every test in this class.

        Conditioning vectors come from ``torch.randn`` and each test builds a
        fresh HyperNetwork, so a fixed seed makes runs reproducible and any
        failure repeatable.
        """
        torch.manual_seed(0)

    def test_output_shapes(self):
        """params has one column per generated LoRA value; sigma is (batch, 1)."""
        lora_config = LoRAConfig(rank=8, targets=("q", "v"))
        hn = HyperNetwork(
            cond_dim=256,
            hidden_dim=128,
            lora_config=lora_config,
            num_decoder_blocks=6,
            decoder_embed_dim=512,
        )
        cond = torch.randn(2, 256)
        params, sigma = hn(cond)
        # Expected width is computed independently with the same block count,
        # embed dim, rank and targets that the network was built with.
        expected_count = compute_total_lora_params(6, 512, 8, ("q", "v"))
        assert params.shape == (2, expected_count)
        assert sigma.shape == (2, 1)

    def test_sigma_positive(self):
        """The sigma output is strictly positive for every batch row."""
        hn = HyperNetwork(cond_dim=256, hidden_dim=128, num_decoder_blocks=3, decoder_embed_dim=256)
        cond = torch.randn(4, 256)
        _, sigma = hn(cond)
        assert (sigma > 0).all()

    def test_confidence_range(self):
        """compute_confidence(sigma) yields values in the closed interval [0, 1]."""
        hn = HyperNetwork(cond_dim=256, hidden_dim=128, num_decoder_blocks=3, decoder_embed_dim=256)
        cond = torch.randn(4, 256)
        _, sigma = hn(cond)
        conf = hn.compute_confidence(sigma)
        assert (conf >= 0).all() and (conf <= 1).all()

    def test_summary(self):
        """summary() reports the generated-parameter count plus positive size stats."""
        lora_config = LoRAConfig(rank=16, targets=("q", "v"))
        hn = HyperNetwork(lora_config=lora_config, num_decoder_blocks=12, decoder_embed_dim=1024)
        summary = hn.summary()
        assert summary["generated_params"] == compute_total_lora_params(12, 1024, 16, ("q", "v"))
        assert summary["own_params"] > 0
        assert summary["generation_ratio"] > 0

    def test_different_conditions_different_params(self):
        """Different conditioning vectors must produce different LoRA parameters."""
        hn = HyperNetwork(cond_dim=256, hidden_dim=128, num_decoder_blocks=3, decoder_embed_dim=256)
        # Disable any train-mode stochasticity (e.g. dropout) so that an output
        # difference cannot come from randomness inside the network itself.
        hn.eval()
        cond1 = torch.randn(1, 256)
        cond2 = torch.randn(1, 256)
        params1, _ = hn(cond1)
        params2, _ = hn(cond2)
        # Control: the same condition reproduces the same parameters, so the
        # inequality below is attributable to the differing conditions.
        assert torch.allclose(hn(cond1)[0], params1)
        assert not torch.allclose(params1, params2, atol=1e-3)