"""Tests for HyperNetwork and ConditionEncoder.""" import pytest import torch from model.hypernetwork import HyperNetwork from model.condition_encoder import ConditionEncoder from model.lora import LoRAConfig, compute_total_lora_params class TestConditionEncoder: """Test conditioning vector generation.""" def test_forward_shape(self): enc = ConditionEncoder(n_cameras=100, scene_input_dim=2048, query_input_dim=2048, out_dim=256) cam_ids = torch.tensor([0, 1, 2]) scene = torch.randn(3, 2048) query = torch.randn(3, 2048) out = enc(cam_ids, scene, query) assert out.shape == (3, 256) def test_forward_no_camera(self): enc = ConditionEncoder(n_cameras=100, out_dim=256) scene = torch.randn(2, 2048) query = torch.randn(2, 2048) out = enc.forward_no_camera(scene, query) assert out.shape == (2, 256) def test_different_cameras_different_outputs(self): enc = ConditionEncoder(n_cameras=100, out_dim=256) scene = torch.randn(1, 2048) query = torch.randn(1, 2048) out_cam0 = enc(torch.tensor([0]), scene, query) out_cam1 = enc(torch.tensor([1]), scene, query) # Different cameras should produce different conditioning assert not torch.allclose(out_cam0, out_cam1, atol=1e-3) class TestHyperNetwork: """Test LoRA parameter generation.""" def test_output_shapes(self): lora_config = LoRAConfig(rank=8, targets=("q", "v")) hn = HyperNetwork( cond_dim=256, hidden_dim=128, lora_config=lora_config, num_decoder_blocks=6, decoder_embed_dim=512, ) cond = torch.randn(2, 256) params, sigma = hn(cond) expected_count = compute_total_lora_params(6, 512, 8, ("q", "v")) assert params.shape == (2, expected_count) assert sigma.shape == (2, 1) def test_sigma_positive(self): hn = HyperNetwork(cond_dim=256, hidden_dim=128, num_decoder_blocks=3, decoder_embed_dim=256) cond = torch.randn(4, 256) _, sigma = hn(cond) assert (sigma > 0).all() def test_confidence_range(self): hn = HyperNetwork(cond_dim=256, hidden_dim=128, num_decoder_blocks=3, decoder_embed_dim=256) cond = torch.randn(4, 256) _, sigma = hn(cond) conf = hn.compute_confidence(sigma) assert (conf >= 0).all() and (conf <= 1).all() def test_summary(self): lora_config = LoRAConfig(rank=16, targets=("q", "v")) hn = HyperNetwork(lora_config=lora_config, num_decoder_blocks=12, decoder_embed_dim=1024) summary = hn.summary() assert summary["generated_params"] == compute_total_lora_params(12, 1024, 16, ("q", "v")) assert summary["own_params"] > 0 assert summary["generation_ratio"] > 0 def test_different_conditions_different_params(self): hn = HyperNetwork(cond_dim=256, hidden_dim=128, num_decoder_blocks=3, decoder_embed_dim=256) cond1 = torch.randn(1, 256) cond2 = torch.randn(1, 256) params1, _ = hn(cond1) params2, _ = hn(cond2) assert not torch.allclose(params1, params2, atol=1e-3)