| """ |
| MR-JEPA Architecture Validation Test. |
| |
| Tests the complete forward pass with synthetic data to verify: |
| 1. All modules instantiate correctly |
| 2. Tensor shapes are consistent throughout |
| 3. JEPA loss computes correctly |
| 4. Target encoder EMA updates work |
| 5. Both MC and open-ended heads produce valid output |
| 6. Ablation controls work (no-JEPA, no-rollout, no-evidence-gate) |
| 7. Loss function variants (smooth_l1, mse, cosine) |
| 8. Anti-collapse regularizations (SIGReg, VICReg) |
| 9. Parameter counting is correct |
| |
| Run from repo root: python test_architecture.py |
| """ |
|
|
| import os |
| import sys |
| |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from mr_jepa.configs.model_config import ( |
| MRJEPAConfig, VisualBackboneConfig, TextEncoderConfig, |
| EvidenceMemoryConfig, LatentRolloutConfig, JEPAObjectiveConfig, |
| AnswerHeadConfig, TrainingPhaseConfig, |
| ) |
| from mr_jepa.models.evidence_memory import EvidenceMemory |
| from mr_jepa.models.latent_rollout import LatentRolloutModule |
| from mr_jepa.models.target_encoder import TargetEncoder, JEPALoss, SIGRegLoss, VICRegLoss |
| from mr_jepa.models.answer_heads import DiscriminativeHead, GenerativeHead |
|
|
|
|
def test_evidence_memory():
    """EvidenceMemory must map visual + text tokens to a fixed-size evidence set.

    The output ``evidence_tokens`` is expected to have shape
    (batch, num_evidence_tokens, hidden_dim) regardless of the input token counts.
    """
    print("\n=== Test: Evidence Memory ===")
    batch, n_visual, n_text = 4, 49, 32
    visual_dim, text_dim = 512, 384
    config = EvidenceMemoryConfig(
        hidden_dim=256,
        num_evidence_tokens=16,
        num_cross_attn_layers=2,
        num_heads=4,
        dropout=0.1,
    )
    model = EvidenceMemory(config, visual_dim=visual_dim, text_dim=text_dim)

    visual_tokens = torch.randn(batch, n_visual, visual_dim)
    text_tokens = torch.randn(batch, n_text, text_dim)
    # Zero out the last 5 text positions (presumably 0 = padding; TODO confirm).
    text_mask = torch.ones(batch, n_text)
    text_mask[:, -5:] = 0

    evidence = model(visual_tokens, text_tokens, text_mask)['evidence_tokens']
    expected_shape = (batch, config.num_evidence_tokens, config.hidden_dim)
    assert evidence.shape == expected_shape
    print(f" Evidence shape: {evidence.shape}")
    print(" ✓ passed!")
|
|
|
|
def test_latent_rollout():
    """LatentRolloutModule must return consistently shaped trajectory outputs.

    Checks ``trajectory`` (K + 1 states), ``z_final`` (a single state) and
    ``z_projected`` (same shape as the trajectory).
    """
    print("\n=== Test: Latent Rollout ===")
    config = LatentRolloutConfig(
        hidden_dim=256,
        num_state_tokens=8,
        K=3,
        num_predictor_layers=2,
        num_heads=4,
        ffn_dim=512,
        dropout=0.1,
        use_evidence_gate=True,
        gate_type="sigmoid",
        use_step_embedding=True,
    )
    batch, n_evidence = 4, 16
    rollout = LatentRolloutModule(config)
    result = rollout(torch.randn(batch, n_evidence, config.hidden_dim))

    trajectory = result['trajectory']
    state_shape = (config.num_state_tokens, config.hidden_dim)
    # K rollout steps plus the initial state z0 -> K + 1 entries along dim 1.
    assert trajectory.shape == (batch, config.K + 1, *state_shape)
    assert result['z_final'].shape == (batch, *state_shape)
    assert result['z_projected'].shape == trajectory.shape
    print(f" Trajectory: {trajectory.shape}")
    print(" ✓ passed!")
|
|
|
|
def test_target_encoder_and_jepa_loss():
    """The target encoder's EMA must track the online weights, and JEPALoss must backprop.

    Steps:
      1. Snapshot one target-rollout parameter, perturb the online rollout,
         run one EMA update, and require that the target parameter moved.
      2. Run the target encoder forward and check the trajectory shape.
      3. Feed a random prediction into JEPALoss and require a gradient on it.
    """
    print("\n=== Test: Target Encoder + JEPA Loss ===")
    D, N_e, N_s, K, B = 256, 16, 8, 3, 4
    visual_dim, text_dim = 512, 384

    evidence_cfg = EvidenceMemoryConfig(hidden_dim=D, num_evidence_tokens=N_e, num_cross_attn_layers=2, num_heads=4)
    rollout_cfg = LatentRolloutConfig(hidden_dim=D, num_state_tokens=N_s, K=K, num_predictor_layers=2, num_heads=4, ffn_dim=512)
    jepa_cfg = JEPAObjectiveConfig(ema_momentum_base=0.996, ema_momentum_end=1.0, use_sigreg=True, sigreg_weight=0.1)

    evidence_mem = EvidenceMemory(evidence_cfg, visual_dim, text_dim)
    rollout = LatentRolloutModule(rollout_cfg)
    target_enc = TargetEncoder(evidence_mem, rollout, jepa_cfg)

    # --- EMA update check ---
    snapshot = next(iter(target_enc.target_rollout.parameters())).clone()
    with torch.no_grad():
        for param in rollout.parameters():
            param.add_(torch.randn_like(param) * 0.1)
    target_enc.update_ema(evidence_mem, rollout, step=100, total_steps=1000)
    # Re-fetch the parameter after the update rather than reusing a stale handle.
    refreshed = next(iter(target_enc.target_rollout.parameters()))
    assert not torch.allclose(snapshot, refreshed), "EMA did not update!"
    print(f" EMA momentum: {target_enc._current_momentum:.6f}")

    # --- Target forward ---
    visual = torch.randn(B, 49, visual_dim)
    text = torch.randn(B, 32, text_dim)
    target_trajectory = target_enc(visual, text, torch.ones(B, 32))['target_trajectory']
    assert target_trajectory.shape == (B, K + 1, N_s, D)

    # --- Loss + backward ---
    jepa_loss_fn = JEPALoss(jepa_cfg, D)
    pred_traj = torch.randn(B, K + 1, N_s, D, requires_grad=True)
    total = jepa_loss_fn(pred_traj, target_trajectory, torch.tensor(1.5))['total_loss']
    total.backward()
    assert pred_traj.grad is not None, "No gradients!"
    print(f" Total loss: {total.item():.4f}, grad norm: {pred_traj.grad.norm().item():.4f}")
    print(" ✓ passed!")
|
|
|
|
def test_answer_heads():
    """Validate the discriminative (multiple-choice) and generative answer heads.

    Discriminative head: logits must have one entry per option slot,
    shape (B, max_opts). Every slot with ``option_mask == False`` must be
    ``-inf``, and every valid slot must be finite.

    Generative head: the teacher-forced ``loss`` must be finite, and
    ``generate`` must run.
    """
    print("\n=== Test: Answer Heads ===")
    D, text_dim, B, N_s, max_opts, vocab_size = 256, 384, 4, 8, 4, 1000
    head_config = AnswerHeadConfig(
        disc_hidden_dim=256, disc_num_layers=2, max_num_options=max_opts,
        gen_hidden_dim=256, gen_num_layers=2, gen_num_heads=4,
        gen_vocab_size=vocab_size, gen_max_answer_length=32,
    )

    disc_head = DiscriminativeHead(head_config, hidden_dim=D, text_dim=text_dim)
    z_final = torch.randn(B, N_s, D)
    option_mask = torch.tensor([
        [True, True, True, True],
        [True, True, True, False],
        [True, True, False, False],
        [True, True, True, True],
    ])
    disc_output = disc_head(z_final, torch.randn(B, max_opts, text_dim), option_mask)
    logits = disc_output['logits']
    assert logits.shape == (B, max_opts), f"Unexpected logits shape: {tuple(logits.shape)}"
    # Check every masked slot, not just one: rows 1 and 2 mask three slots in total.
    assert torch.all(logits[~option_mask] == float('-inf')), "Masked option should be -inf!"
    assert torch.isfinite(logits[option_mask]).all(), "Valid options should have finite logits!"

    gen_head = GenerativeHead(head_config, hidden_dim=D, vocab_size=vocab_size)
    gen_output = gen_head(z_final, torch.randint(0, vocab_size, (B, 16)))
    assert torch.isfinite(gen_output['loss']).all(), "Generative loss is not finite!"
    generated = gen_head.generate(z_final, start_token_id=1, max_length=10)
    print(f" Disc logits: {disc_output['logits'].shape}, Gen loss: {gen_output['loss'].item():.4f}, Generated: {generated.shape}")
    print(" ✓ passed!")
|
|
|
|
def test_sigreg_and_vicreg():
    """Both anti-collapse regularizers must score a collapsed batch above a random one.

    ``z_coll`` is a constant tensor (every token identical), the extreme case
    of representation collapse. ``z_rand`` is i.i.d. Gaussian. SIGReg and
    VICReg should each return a larger loss for ``z_coll``.
    """
    print("\n=== Test: SIGReg + VICReg ===")
    D, B, N = 256, 32, 8
    sigreg = SIGRegLoss(D, num_projections=64)
    z_rand = torch.randn(B, N, D)
    z_coll = torch.ones(B, N, D)
    loss_rand = sigreg(z_rand)
    loss_coll = sigreg(z_coll)
    assert loss_coll.item() > loss_rand.item(), "SIGReg should penalize collapsed representations more!"

    vicreg = VICRegLoss(var_weight=1.0, cov_weight=0.04)
    loss_vic = vicreg(z_rand)
    # VICReg was previously only printed; pin its anti-collapse behaviour as well.
    # A constant batch has zero per-dimension spread, which the variance term penalizes.
    loss_vic_coll = vicreg(z_coll)
    assert loss_vic_coll.item() > loss_vic.item(), "VICReg should penalize collapsed representations more!"
    print(f" SIGReg random={loss_rand.item():.4f}, collapsed={loss_coll.item():.4f}; VICReg={loss_vic.item():.4f}")
    print(" ✓ passed!")
|
|
|
|
def test_parameter_counting():
    """Report parameter counts and sanity-check them.

    Both modules must have a positive parameter count. Adding one predictor
    layer to the rollout (2 -> 3) must strictly increase its count.
    """
    print("\n=== Test: Parameter Counting ===")
    D = 256

    def num_params(module):
        # Total element count over all registered parameters.
        return sum(p.numel() for p in module.parameters())

    def make_rollout(num_layers):
        # Build a rollout that differs from the others only in predictor depth.
        return LatentRolloutModule(LatentRolloutConfig(
            hidden_dim=D, num_state_tokens=8, K=3,
            num_predictor_layers=num_layers, num_heads=4, ffn_dim=512,
        ))

    ev = EvidenceMemory(
        EvidenceMemoryConfig(hidden_dim=D, num_evidence_tokens=16, num_cross_attn_layers=2, num_heads=4),
        visual_dim=512, text_dim=384,
    )
    ro = make_rollout(3)
    ev_count, ro_count = num_params(ev), num_params(ro)
    print(f" Evidence: {ev_count:,}, Rollout: {ro_count:,}")

    assert ev_count > 0, "EvidenceMemory has no parameters!"
    assert ro_count > 0, "LatentRolloutModule has no parameters!"
    shallower = num_params(make_rollout(2))
    assert ro_count > shallower, f"3-layer rollout ({ro_count:,}) should exceed 2-layer ({shallower:,})"
    print(" ✓ passed!")
|
|
|
|
def test_trajectory_metrics():
    """A trajectory whose step sizes shrink geometrically must report convergence_rate < 1.

    Also smoke-tests ``visualize_trajectory`` with PCA on the first sample; its
    return value is not inspected.
    """
    print("\n=== Test: Trajectory Metrics ===")
    from mr_jepa.utils.visualization import compute_trajectory_metrics, visualize_trajectory
    B, K, N_s, D = 4, 3, 8, 256
    trajectory = torch.randn(B, K + 1, N_s, D)
    # Each step adds noise scaled by 0.5**step, so successive deltas halve in size.
    for step in range(1, K + 1):
        noise = torch.randn(B, N_s, D) * (0.5 ** step)
        trajectory[:, step] = trajectory[:, step - 1] + noise

    metrics = compute_trajectory_metrics(trajectory)
    visualize_trajectory(trajectory[0], method='pca')
    rate = metrics['convergence_rate']
    assert rate < 1.0
    print(f" Convergence rate: {rate:.4f}")
    print(" ✓ passed!")
|
|
|
|
def test_evaluation_metrics():
    """Smoke-test the evaluation metric helpers.

    Only ``compute_accuracy`` is checked against an expected value: 3 of the 4
    positions agree, giving 75.0. The ANLS, VQA-accuracy and relaxed-accuracy
    helpers are only called, to make sure they run.
    NOTE(review): consider pinning their outputs once their return schema is confirmed.
    """
    print("\n=== Test: Evaluation Metrics ===")
    from mr_jepa.evaluation.metrics import (
        compute_accuracy,
        compute_anls,
        compute_relaxed_accuracy,
        compute_vqa_accuracy,
    )
    # Argument order is presumably (predictions, references) — TODO confirm.
    predicted, gold = [0, 1, 2, 0], [0, 1, 1, 0]
    assert compute_accuracy(predicted, gold)['accuracy'] == 75.0

    compute_anls(["hello world", "test"], [["hello world"], ["testing"]])
    compute_vqa_accuracy(["cat"], [["cat"] * 10])
    exact = ["100", "hello"]
    compute_relaxed_accuracy(exact, ["100", "hello"], types=["human_test", "human_test"])
    print(" All metrics compute correctly")
    print(" ✓ passed!")
|
|
|
|
def test_end_to_end_forward():
    """Run one full training-style step and check that the loss is finite and gradients flow.

    Pipeline: EvidenceMemory -> LatentRolloutModule -> DiscriminativeHead and
    GenerativeHead. The TargetEncoder supplies the JEPA target trajectory.
    After ``backward()``, at least one parameter of both the evidence memory
    and the rollout must hold a gradient, and one EMA update must complete.
    """
    print("\n=== Test: End-to-End Forward Pass ===")
    D, B, N_v, N_t, N_e, N_s, K = 256, 2, 49, 32, 16, 8, 3
    max_opts, vocab_size, visual_dim, text_dim = 4, 100, 512, 384
    ev_cfg = EvidenceMemoryConfig(hidden_dim=D, num_evidence_tokens=N_e, num_cross_attn_layers=2, num_heads=4)
    ro_cfg = LatentRolloutConfig(hidden_dim=D, num_state_tokens=N_s, K=K, num_predictor_layers=2, num_heads=4, ffn_dim=512)
    j_cfg = JEPAObjectiveConfig(use_sigreg=True, sigreg_weight=0.1)
    h_cfg = AnswerHeadConfig(disc_hidden_dim=D, gen_hidden_dim=D, gen_num_layers=2, gen_num_heads=4, gen_vocab_size=vocab_size, gen_max_answer_length=16)

    evidence_mem = EvidenceMemory(ev_cfg, visual_dim, text_dim)
    rollout = LatentRolloutModule(ro_cfg)
    target_enc = TargetEncoder(evidence_mem, rollout, j_cfg)
    disc_head = DiscriminativeHead(h_cfg, D, text_dim)
    gen_head = GenerativeHead(h_cfg, D, vocab_size)
    jepa_loss_fn = JEPALoss(j_cfg, D)

    vis = torch.randn(B, N_v, visual_dim)
    txt = torch.randn(B, N_t, text_dim)
    mask = torch.ones(B, N_t)

    # Online branch.
    evidence = evidence_mem(vis, txt, mask)['evidence_tokens']
    rollout_out = rollout(evidence)
    # Target branch on the same inputs.
    target_out = target_enc(vis, txt, mask)

    option_feats = torch.randn(B, max_opts, text_dim)
    option_mask = torch.ones(B, max_opts, dtype=torch.bool)
    disc_out = disc_head(rollout_out['z_final'], option_feats, option_mask)
    labels = torch.tensor([1, 3])  # one option index per batch item (B == 2)
    task_loss = nn.functional.cross_entropy(disc_out['logits'], labels)
    gen_out = gen_head(rollout_out['z_final'], torch.randint(0, vocab_size, (B, 16)), evidence)

    loss_dict = jepa_loss_fn(rollout_out['z_projected'], target_out['target_trajectory'], task_loss, gen_out['loss'])
    total_loss = loss_dict['total_loss']
    assert torch.isfinite(total_loss).all(), f"Non-finite total loss: {total_loss.item()}"
    total_loss.backward()
    target_enc.update_ema(evidence_mem, rollout, step=1, total_steps=100)

    ev_grads = sum(1 for p in evidence_mem.parameters() if p.grad is not None)
    ro_grads = sum(1 for p in rollout.parameters() if p.grad is not None)
    print(f" Total loss: {total_loss.item():.4f}, EV grads: {ev_grads}, RO grads: {ro_grads}")
    # Previously these counts were only printed; a silently detached branch would have passed.
    assert ev_grads > 0, "No gradients reached the evidence memory!"
    assert ro_grads > 0, "No gradients reached the latent rollout!"
    print(" ✓ passed!")
|
|
|
|
| |
| |
| |
|
|
def test_ablation_no_rollout():
    """With K=0 the trajectory must contain exactly one state (z0)."""
    print("\n=== Ablation: --no_rollout (K=0) ===")
    D, B, N_e, N_s = 256, 2, 16, 8
    rollout = LatentRolloutModule(LatentRolloutConfig(
        hidden_dim=D, num_state_tokens=N_s, K=0,
        num_predictor_layers=2, num_heads=4, ffn_dim=512,
    ))
    trajectory = rollout(torch.randn(B, N_e, D))['trajectory']
    num_states = trajectory.shape[1]
    assert num_states == 1, f"Expected 1, got {num_states}"
    print(f" Trajectory: {trajectory.shape} (K=0 → 1 step)")
    print(" ✓ passed!")
|
|
|
|
def test_ablation_no_evidence_gate():
    """With use_evidence_gate=False every predictor layer's gate must report gate_type 'none'.

    Also checks that the rollout still produces a (B, K + 1, N_s, D) trajectory.
    """
    print("\n=== Ablation: --no_evidence_gate ===")
    D, B, N_e, N_s, K = 256, 2, 16, 8, 3
    rollout = LatentRolloutModule(LatentRolloutConfig(
        hidden_dim=D, num_state_tokens=N_s, K=K,
        num_predictor_layers=2, num_heads=4, ffn_dim=512,
        use_evidence_gate=False,
    ))
    layers = rollout.predictor_layers

    for idx, layer in enumerate(layers):
        kind = layer.evidence_gate.gate_type
        assert kind == "none", f"Layer {idx}: expected gate_type='none', got '{kind}'"

    trajectory = rollout(torch.randn(B, N_e, D))['trajectory']
    assert trajectory.shape == (B, K + 1, N_s, D)
    print(f" All {len(layers)} layers have gate_type='none'")
    print(" ✓ passed!")
|
|
|
|
def test_ablation_k_variants():
    """Rollout depth K must set the trajectory length: K steps give K + 1 states."""
    print("\n=== Ablation: K variants (1, 5, 7) ===")
    D, B, N_e, N_s = 256, 2, 16, 8
    for depth in (1, 5, 7):
        module = LatentRolloutModule(LatentRolloutConfig(
            hidden_dim=D, num_state_tokens=N_s, K=depth,
            num_predictor_layers=2, num_heads=4, ffn_dim=512,
        ))
        length = module(torch.randn(B, N_e, D))['trajectory'].shape[1]
        assert length == depth + 1
        print(f" K={depth}: trajectory len={length} ✓")
    print(" ✓ passed!")
|
|
|
|
def test_ablation_loss_functions():
    """Each supported ``jepa_loss_fn`` ('smooth_l1', 'mse', 'cosine') must yield a positive total loss."""
    print("\n=== Ablation: loss_fn variants ===")
    D, K, B, N_s = 256, 3, 2, 8
    shape = (B, K + 1, N_s, D)
    pred = torch.randn(shape)
    target = torch.randn(shape)
    task = torch.tensor(1.0)

    for loss_name in ("smooth_l1", "mse", "cosine"):
        loss_fn = JEPALoss(JEPAObjectiveConfig(jepa_loss_fn=loss_name, use_sigreg=False), D)
        out = loss_fn(pred, target, task)
        jepa_val = out['jepa_loss'].item()
        total_val = out['total_loss'].item()
        print(f" {loss_name}: jepa={jepa_val:.4f}, total={total_val:.4f}")
        assert total_val > 0
    print(" ✓ passed!")
|
|
|
|
def test_ablation_sigreg_vs_vicreg():
    """Enabling SIGReg, VICReg, or both must give a strictly positive ``reg_loss``."""
    print("\n=== Ablation: SIGReg vs VICReg ===")
    D, K, B, N_s = 256, 3, 2, 8
    shape = (B, K + 1, N_s, D)
    pred = torch.randn(shape)
    target = torch.randn(shape)
    task = torch.tensor(1.0)

    # label -> (use_sigreg, use_vicreg)
    variants = {
        "SIGReg": (True, False),
        "VICReg": (False, True),
        "Both": (True, True),
    }
    for label, (use_sigreg, use_vicreg) in variants.items():
        cfg = JEPAObjectiveConfig(
            use_sigreg=use_sigreg, sigreg_weight=0.1,
            use_vicreg=use_vicreg, vicreg_var_weight=1.0, vicreg_cov_weight=0.04,
        )
        reg = JEPALoss(cfg, D)(pred, target, task)['reg_loss'].item()
        print(f" {label}: reg={reg:.4f}")
        assert reg > 0, f"{label} reg should be > 0"
    print(" ✓ passed!")
|
|
|
|
def test_ablation_no_jepa():
    """Placeholder for the --no_jepa ablation.

    NOTE(review): this does not exercise any no_jepa code path. It only checks
    that JEPALoss runs on random tensors and prints the value. The second
    printed line describes the intended model behaviour, which is not verified
    here.
    """
    print("\n=== Ablation: --no_jepa ===")
    D, K, B, N_s = 256, 3, 2, 8
    loss_fn = JEPALoss(JEPAObjectiveConfig(use_sigreg=True, sigreg_weight=0.1), D)
    shape = (B, K + 1, N_s, D)
    pred = torch.randn(*shape, requires_grad=True)
    target = torch.randn(*shape)
    jepa_value = loss_fn(pred, target, torch.tensor(1.5))['jepa_loss'].item()
    print(f" JEPA loss computes: {jepa_value:.4f}")
    print(" In no_jepa mode, model forward skips this and uses task_loss directly")
    print(" ✓ passed!")
|
|
|
|
def test_ablation_purist_config():
    """Check the purist-branch preset returned by ``get_purist_config``.

    Expected values: rollout K=5, cosine JEPA loss, SIGReg enabled, VICReg
    disabled, and a visual backbone whose model name contains "base".
    """
    print("\n=== Ablation: purist config ===")
    from mr_jepa.configs.model_config import get_purist_config
    c = get_purist_config()
    assert c.rollout.K == 5, f"K should be 5, got {c.rollout.K}"
    assert c.jepa.jepa_loss_fn == "cosine", f"Loss should be cosine, got {c.jepa.jepa_loss_fn}"
    # Truthiness asserts instead of `== True` / `== False` (PEP 8, E712), with diagnostics.
    assert c.jepa.use_sigreg, f"SIGReg should be enabled, got {c.jepa.use_sigreg}"
    assert not c.jepa.use_vicreg, f"VICReg should be disabled, got {c.jepa.use_vicreg}"
    assert "base" in c.visual.model_name, f"Should use base model, got {c.visual.model_name}"
    print(f" K={c.rollout.K}, loss={c.jepa.jepa_loss_fn}, SIGReg={c.jepa.use_sigreg}, backbone={c.visual.model_name}")
    print(" ✓ passed!")
|
|
|
|
def test_ablation_dinov2_config():
    """The DINOv2 ablation preset must select a DINOv2 backbone with image_size 518 and patch_size 14."""
    print("\n=== Ablation: dinov2 config ===")
    from mr_jepa.configs.model_config import get_dinov2_ablation_config
    visual = get_dinov2_ablation_config().visual
    assert visual.backbone_type == "dinov2"
    assert "dinov2" in visual.model_name
    assert visual.image_size == 518
    assert visual.patch_size == 14
    print(f" backbone={visual.model_name}, size={visual.image_size}, patch={visual.patch_size}")
    print(" ✓ passed!")
|
|
|
|
if __name__ == "__main__":
    # Fixed seeds so reruns draw the same random tensors and weight inits.
    torch.manual_seed(0)
    np.random.seed(0)

    # Run order matters only for RNG consumption; each test is independent otherwise.
    core_tests = (
        test_evidence_memory,
        test_latent_rollout,
        test_target_encoder_and_jepa_loss,
        test_answer_heads,
        test_sigreg_and_vicreg,
        test_parameter_counting,
        test_trajectory_metrics,
        test_evaluation_metrics,
        test_end_to_end_forward,
    )
    ablation_tests = (
        test_ablation_no_jepa,
        test_ablation_no_rollout,
        test_ablation_no_evidence_gate,
        test_ablation_k_variants,
        test_ablation_loss_functions,
        test_ablation_sigreg_vs_vicreg,
        test_ablation_purist_config,
        test_ablation_dinov2_config,
    )

    print("=" * 60)
    print("MR-JEPA Architecture Validation")
    print("=" * 60)
    for test_fn in core_tests:
        test_fn()  # fail-fast: the first failing assertion aborts the run

    print("\n" + "=" * 60)
    print("Ablation Tests")
    print("=" * 60)
    for test_fn in ablation_tests:
        test_fn()

    # Derive the summary counts from the lists so they cannot drift from reality.
    n_core, n_ablation = len(core_tests), len(ablation_tests)
    print("\n" + "=" * 60)
    print(f"ALL TESTS PASSED ✓ ({n_core} core + {n_ablation} ablation = {n_core + n_ablation} total)")
    print("=" * 60)
|
|