# MR-JEPA / test_architecture.py
# JorgeAV's picture
# fix: test_architecture.py — use os.path.dirname(__file__) instead of hardcoded /app for sys.path
# 83e1328 verified
"""
MR-JEPA Architecture Validation Test.
Tests the complete forward pass with synthetic data to verify:
1. All modules instantiate correctly
2. Tensor shapes are consistent throughout
3. JEPA loss computes correctly
4. Target encoder EMA updates work
5. Both MC and open-ended heads produce valid output
6. Ablation controls work (no-JEPA, no-rollout, no-evidence-gate)
7. Loss function variants (smooth_l1, mse, cosine)
8. Anti-collapse regularizations (SIGReg, VICReg)
9. Parameter counting is correct
Run from repo root: python test_architecture.py
"""
import os
import sys
# Ensure the repo root is on the path (where mr_jepa/ package lives)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
import torch.nn as nn
import numpy as np
from mr_jepa.configs.model_config import (
MRJEPAConfig, VisualBackboneConfig, TextEncoderConfig,
EvidenceMemoryConfig, LatentRolloutConfig, JEPAObjectiveConfig,
AnswerHeadConfig, TrainingPhaseConfig,
)
from mr_jepa.models.evidence_memory import EvidenceMemory
from mr_jepa.models.latent_rollout import LatentRolloutModule
from mr_jepa.models.target_encoder import TargetEncoder, JEPALoss, SIGRegLoss, VICRegLoss
from mr_jepa.models.answer_heads import DiscriminativeHead, GenerativeHead
def test_evidence_memory():
    """Check that EvidenceMemory returns evidence tokens of the configured shape."""
    print("\n=== Test: Evidence Memory ===")
    cfg = EvidenceMemoryConfig(
        hidden_dim=256,
        num_evidence_tokens=16,
        num_cross_attn_layers=2,
        num_heads=4,
        dropout=0.1,
    )
    batch, num_visual, num_text = 4, 49, 32
    dims = {"visual_dim": 512, "text_dim": 384}
    memory = EvidenceMemory(cfg, **dims)

    visual = torch.randn(batch, num_visual, dims["visual_dim"])
    text = torch.randn(batch, num_text, dims["text_dim"])
    # Mask of ones with the last five positions of every row zeroed.
    mask = torch.ones(batch, num_text)
    mask[:, -5:] = 0

    tokens = memory(visual, text, mask)['evidence_tokens']
    expected_shape = (batch, cfg.num_evidence_tokens, cfg.hidden_dim)
    assert tokens.shape == expected_shape
    print(f" Evidence shape: {tokens.shape}")
    print(" ✓ passed!")
def test_latent_rollout():
    """Check the shapes of the trajectory, final state and projection for a K=3 rollout."""
    print("\n=== Test: Latent Rollout ===")
    cfg = LatentRolloutConfig(
        hidden_dim=256,
        num_state_tokens=8,
        K=3,
        num_predictor_layers=2,
        num_heads=4,
        ffn_dim=512,
        dropout=0.1,
        use_evidence_gate=True,
        gate_type="sigmoid",
        use_step_embedding=True,
    )
    batch, num_evidence = 4, 16
    module = LatentRolloutModule(cfg)
    result = module(torch.randn(batch, num_evidence, cfg.hidden_dim))

    trajectory = result['trajectory']
    state_shape = (cfg.num_state_tokens, cfg.hidden_dim)
    # The trajectory is expected to hold K + 1 entries along dim 1.
    assert trajectory.shape == (batch, cfg.K + 1, *state_shape)
    assert result['z_final'].shape == (batch, *state_shape)
    assert result['z_projected'].shape == trajectory.shape
    print(f" Trajectory: {trajectory.shape}")
    print(" ✓ passed!")
def test_target_encoder_and_jepa_loss():
    """Exercise the TargetEncoder EMA update and a JEPALoss backward pass.

    First perturbs the online rollout weights and verifies that
    ``update_ema`` changes the target rollout's first parameter. Then runs
    the target encoder on random inputs, checks the target trajectory
    shape, and confirms that the total loss back-propagates into a random
    predicted trajectory.
    """
    print("\n=== Test: Target Encoder + JEPA Loss ===")
    hidden, num_evidence, num_state, steps, batch = 256, 16, 8, 3, 4
    visual_dim, text_dim = 512, 384

    memory_cfg = EvidenceMemoryConfig(
        hidden_dim=hidden,
        num_evidence_tokens=num_evidence,
        num_cross_attn_layers=2,
        num_heads=4,
    )
    rollout_cfg = LatentRolloutConfig(
        hidden_dim=hidden,
        num_state_tokens=num_state,
        K=steps,
        num_predictor_layers=2,
        num_heads=4,
        ffn_dim=512,
    )
    jepa_cfg = JEPAObjectiveConfig(
        ema_momentum_base=0.996,
        ema_momentum_end=1.0,
        use_sigreg=True,
        sigreg_weight=0.1,
    )

    memory = EvidenceMemory(memory_cfg, visual_dim, text_dim)
    online_rollout = LatentRolloutModule(rollout_cfg)
    target = TargetEncoder(memory, online_rollout, jepa_cfg)

    def first_target_param():
        return next(iter(target.target_rollout.parameters()))

    # Snapshot, perturb the online weights, then expect the EMA step to move the target.
    snapshot = first_target_param().clone()
    with torch.no_grad():
        for param in online_rollout.parameters():
            param.add_(torch.randn_like(param) * 0.1)
    target.update_ema(memory, online_rollout, step=100, total_steps=1000)
    assert not torch.allclose(snapshot, first_target_param()), "EMA did not update!"
    print(f" EMA momentum: {target._current_momentum:.6f}")

    visual = torch.randn(batch, 49, visual_dim)
    text = torch.randn(batch, 32, text_dim)
    text_mask = torch.ones(batch, 32)
    target_traj = target(visual, text, text_mask)['target_trajectory']
    traj_shape = (batch, steps + 1, num_state, hidden)
    assert target_traj.shape == traj_shape

    criterion = JEPALoss(jepa_cfg, hidden)
    predicted = torch.randn(*traj_shape, requires_grad=True)
    # Third argument is a stand-in scalar task loss.
    losses = criterion(predicted, target_traj, torch.tensor(1.5))
    total = losses['total_loss']
    total.backward()
    assert predicted.grad is not None, "No gradients!"
    print(f" Total loss: {total.item():.4f}, grad norm: {predicted.grad.norm().item():.4f}")
    print(" ✓ passed!")
def test_answer_heads():
    """Smoke-test the discriminative and generative answer heads.

    Verifies that a masked-out option receives a ``-inf`` logit from the
    discriminative head, then runs the generative head's forward and
    ``generate`` paths and prints the resulting shapes and loss.
    """
    print("\n=== Test: Answer Heads ===")
    hidden, text_dim, batch, num_state, max_opts, vocab_size = 256, 384, 4, 8, 4, 1000
    cfg = AnswerHeadConfig(
        disc_hidden_dim=256,
        disc_num_layers=2,
        max_num_options=max_opts,
        gen_hidden_dim=256,
        gen_num_layers=2,
        gen_num_heads=4,
        gen_vocab_size=vocab_size,
        gen_max_answer_length=32,
    )

    discriminative = DiscriminativeHead(cfg, hidden_dim=hidden, text_dim=text_dim)
    final_state = torch.randn(batch, num_state, hidden)
    # Rows keep their first 4, 3, 2 and 4 options respectively; the rest are masked out.
    valid_counts = torch.tensor([4, 3, 2, 4])
    option_mask = torch.arange(max_opts).unsqueeze(0) < valid_counts.unsqueeze(1)
    option_feats = torch.randn(batch, max_opts, text_dim)
    disc_result = discriminative(final_state, option_feats, option_mask)
    logits = disc_result['logits']
    assert logits[2, 2] == float('-inf'), "Masked option should be -inf!"

    generative = GenerativeHead(cfg, hidden_dim=hidden, vocab_size=vocab_size)
    answer_ids = torch.randint(0, vocab_size, (batch, 16))
    gen_result = generative(final_state, answer_ids)
    sampled = generative.generate(final_state, start_token_id=1, max_length=10)
    print(f" Disc logits: {logits.shape}, Gen loss: {gen_result['loss'].item():.4f}, Generated: {sampled.shape}")
    print(" ✓ passed!")
def test_sigreg_and_vicreg():
    """Check SIGReg scores constant inputs above random ones; smoke-run VICReg."""
    print("\n=== Test: SIGReg + VICReg ===")
    dim, batch, tokens = 256, 32, 8
    sig_loss = SIGRegLoss(dim, num_projections=64)

    random_z = torch.randn(batch, tokens, dim)
    collapsed_z = torch.ones(batch, tokens, dim)  # every entry identical
    random_score = sig_loss(random_z)
    collapsed_score = sig_loss(collapsed_z)
    assert collapsed_score > random_score, "SIGReg should penalize collapsed representations more!"

    vic_loss = VICRegLoss(var_weight=1.0, cov_weight=0.04)
    vic_score = vic_loss(random_z)
    print(f" SIGReg random={random_score.item():.4f}, collapsed={collapsed_score.item():.4f}; VICReg={vic_score.item():.4f}")
    print(" ✓ passed!")
def test_parameter_counting():
    """Report parameter totals for EvidenceMemory and LatentRolloutModule.

    The counts are printed for inspection and each is asserted to be
    positive, so the test fails if either module stops exposing its
    parameters through ``.parameters()``.
    """
    print("\n=== Test: Parameter Counting ===")
    D = 256
    ev = EvidenceMemory(
        EvidenceMemoryConfig(hidden_dim=D, num_evidence_tokens=16, num_cross_attn_layers=2, num_heads=4),
        visual_dim=512,
        text_dim=384,
    )
    ro = LatentRolloutModule(
        LatentRolloutConfig(hidden_dim=D, num_state_tokens=8, K=3, num_predictor_layers=3, num_heads=4, ffn_dim=512)
    )
    ev_params = sum(p.numel() for p in ev.parameters())
    ro_params = sum(p.numel() for p in ro.parameters())
    # Print before asserting so the counts are visible even when a check fails.
    print(f" Evidence: {ev_params:,}, Rollout: {ro_params:,}")
    assert ev_params > 0, "EvidenceMemory exposes no parameters!"
    assert ro_params > 0, "LatentRolloutModule exposes no parameters!"
    print(" ✓ passed!")
def test_trajectory_metrics():
    """Check trajectory metrics on a synthetic trajectory with shrinking steps.

    Builds a random trajectory whose successive steps differ by noise scaled
    by ``0.5 ** k``, asserts the reported ``convergence_rate`` is below 1.0,
    and smoke-calls the PCA visualisation on the first batch element.
    """
    print("\n=== Test: Trajectory Metrics ===")
    from mr_jepa.utils.visualization import compute_trajectory_metrics, visualize_trajectory
    B, K, N_s, D = 4, 3, 8, 256
    trajectory = torch.randn(B, K + 1, N_s, D)
    # Overwrite steps 1..K so each one is its predecessor plus geometrically shrinking noise.
    for k in range(1, K + 1):
        trajectory[:, k] = trajectory[:, k - 1] + torch.randn(B, N_s, D) * (0.5 ** k)
    metrics = compute_trajectory_metrics(trajectory)
    # Smoke call only: the return value is not inspected, so it is not bound to a name.
    visualize_trajectory(trajectory[0], method='pca')
    assert metrics['convergence_rate'] < 1.0
    print(f" Convergence rate: {metrics['convergence_rate']:.4f}")
    print(" ✓ passed!")
def test_evaluation_metrics():
    """Smoke-test the evaluation metric helpers; only accuracy is checked for an exact value."""
    print("\n=== Test: Evaluation Metrics ===")
    from mr_jepa.evaluation.metrics import compute_accuracy, compute_anls, compute_vqa_accuracy, compute_relaxed_accuracy

    # Three of the four predictions match the references.
    predictions, references = [0, 1, 2, 0], [0, 1, 1, 0]
    accuracy_result = compute_accuracy(predictions, references)
    assert accuracy_result['accuracy'] == 75.0

    # The remaining helpers are only required to run without raising.
    compute_anls(["hello world", "test"], [["hello world"], ["testing"]])
    compute_vqa_accuracy(["cat"], [["cat"] * 10])
    compute_relaxed_accuracy(
        ["100", "hello"],
        ["100", "hello"],
        types=["human_test", "human_test"],
    )
    print(" All metrics compute correctly")
    print(" ✓ passed!")
def test_end_to_end_forward():
    """Run one full forward/backward pass through every module on synthetic data.

    Chains evidence memory -> rollout -> both answer heads -> JEPA loss,
    back-propagates the total loss, applies one EMA update, and prints how
    many evidence-memory and rollout parameters ended up with a gradient.
    """
    print("\n=== Test: End-to-End Forward Pass ===")
    hidden, batch, num_visual, num_text = 256, 2, 49, 32
    num_evidence, num_state, steps = 16, 8, 3
    max_opts, vocab_size, visual_dim, text_dim = 4, 100, 512, 384

    memory_cfg = EvidenceMemoryConfig(
        hidden_dim=hidden,
        num_evidence_tokens=num_evidence,
        num_cross_attn_layers=2,
        num_heads=4,
    )
    rollout_cfg = LatentRolloutConfig(
        hidden_dim=hidden,
        num_state_tokens=num_state,
        K=steps,
        num_predictor_layers=2,
        num_heads=4,
        ffn_dim=512,
    )
    jepa_cfg = JEPAObjectiveConfig(use_sigreg=True, sigreg_weight=0.1)
    head_cfg = AnswerHeadConfig(
        disc_hidden_dim=hidden,
        gen_hidden_dim=hidden,
        gen_num_layers=2,
        gen_num_heads=4,
        gen_vocab_size=vocab_size,
        gen_max_answer_length=16,
    )

    memory = EvidenceMemory(memory_cfg, visual_dim, text_dim)
    online_rollout = LatentRolloutModule(rollout_cfg)
    target = TargetEncoder(memory, online_rollout, jepa_cfg)
    discriminative = DiscriminativeHead(head_cfg, hidden, text_dim)
    generative = GenerativeHead(head_cfg, hidden, vocab_size)
    criterion = JEPALoss(jepa_cfg, hidden)

    visual = torch.randn(batch, num_visual, visual_dim)
    text = torch.randn(batch, num_text, text_dim)
    text_mask = torch.ones(batch, num_text)

    # Online branch and target branch see the same inputs.
    evidence_tokens = memory(visual, text, text_mask)['evidence_tokens']
    online = online_rollout(evidence_tokens)
    target_result = target(visual, text, text_mask)

    # Multiple-choice head: all options valid, labels fixed at [1, 3].
    option_feats = torch.randn(batch, max_opts, text_dim)
    all_valid = torch.ones(batch, max_opts, dtype=torch.bool)
    disc_result = discriminative(online['z_final'], option_feats, all_valid)
    task_loss = nn.functional.cross_entropy(disc_result['logits'], torch.tensor([1, 3]))

    # Open-ended head, additionally given the evidence tokens.
    answer_ids = torch.randint(0, vocab_size, (batch, 16))
    gen_result = generative(online['z_final'], answer_ids, evidence_tokens)

    losses = criterion(
        online['z_projected'],
        target_result['target_trajectory'],
        task_loss,
        gen_result['loss'],
    )
    total = losses['total_loss']
    total.backward()
    target.update_ema(memory, online_rollout, step=1, total_steps=100)

    def count_with_grad(module):
        return sum(1 for param in module.parameters() if param.grad is not None)

    memory_grads = count_with_grad(memory)
    rollout_grads = count_with_grad(online_rollout)
    print(f" Total loss: {total.item():.4f}, EV grads: {memory_grads}, RO grads: {rollout_grads}")
    print(" ✓ passed!")
# ──────────────────────────────────────────────────────────
# ABLATION TESTS
# ──────────────────────────────────────────────────────────
def test_ablation_no_rollout():
    """With K=0 the trajectory must contain exactly one entry along dim 1."""
    print("\n=== Ablation: --no_rollout (K=0) ===")
    hidden, batch, num_evidence, num_state = 256, 2, 16, 8
    cfg = LatentRolloutConfig(
        hidden_dim=hidden,
        num_state_tokens=num_state,
        K=0,
        num_predictor_layers=2,
        num_heads=4,
        ffn_dim=512,
    )
    module = LatentRolloutModule(cfg)
    trajectory = module(torch.randn(batch, num_evidence, hidden))['trajectory']
    assert trajectory.shape[1] == 1, f"Expected 1, got {trajectory.shape[1]}"
    print(f" Trajectory: {trajectory.shape} (K=0 → 1 step)")
    print(" ✓ passed!")
def test_ablation_no_evidence_gate():
    """With use_evidence_gate=False every predictor layer must report gate_type 'none'."""
    print("\n=== Ablation: --no_evidence_gate ===")
    hidden, batch, num_evidence, num_state, steps = 256, 2, 16, 8, 3
    cfg = LatentRolloutConfig(
        hidden_dim=hidden,
        num_state_tokens=num_state,
        K=steps,
        num_predictor_layers=2,
        num_heads=4,
        ffn_dim=512,
        use_evidence_gate=False,
    )
    module = LatentRolloutModule(cfg)

    # Every layer's gate should be configured as "none".
    for index, predictor in enumerate(module.predictor_layers):
        gate_type = predictor.evidence_gate.gate_type
        assert gate_type == "none", f"Layer {index}: expected gate_type='none', got '{gate_type}'"

    # The forward pass must still yield a full-length trajectory.
    result = module(torch.randn(batch, num_evidence, hidden))
    assert result['trajectory'].shape == (batch, steps + 1, num_state, hidden)
    print(f" All {len(module.predictor_layers)} layers have gate_type='none'")
    print(" ✓ passed!")
def test_ablation_k_variants():
    """For K in (1, 5, 7) the trajectory must hold K + 1 entries along dim 1."""
    print("\n=== Ablation: K variants (1, 5, 7) ===")
    hidden, batch, num_evidence, num_state = 256, 2, 16, 8
    for depth in (1, 5, 7):
        cfg = LatentRolloutConfig(
            hidden_dim=hidden,
            num_state_tokens=num_state,
            K=depth,
            num_predictor_layers=2,
            num_heads=4,
            ffn_dim=512,
        )
        module = LatentRolloutModule(cfg)
        result = module(torch.randn(batch, num_evidence, hidden))
        length = result['trajectory'].shape[1]
        assert length == depth + 1
        print(f" K={depth}: trajectory len={length} ✓")
    print(" ✓ passed!")
def test_ablation_loss_functions():
    """Each jepa_loss_fn variant (smooth_l1, mse, cosine) must yield a positive total loss."""
    print("\n=== Ablation: loss_fn variants ===")
    hidden, steps, batch, num_state = 256, 3, 2, 8
    traj_shape = (batch, steps + 1, num_state, hidden)
    predicted = torch.randn(*traj_shape)
    reference = torch.randn(*traj_shape)
    task_loss = torch.tensor(1.0)

    for variant in ("smooth_l1", "mse", "cosine"):
        objective = JEPAObjectiveConfig(jepa_loss_fn=variant, use_sigreg=False)
        criterion = JEPALoss(objective, hidden)
        losses = criterion(predicted, reference, task_loss)
        print(f" {variant}: jepa={losses['jepa_loss'].item():.4f}, total={losses['total_loss'].item():.4f}")
        assert losses['total_loss'].item() > 0
    print(" ✓ passed!")
def test_ablation_sigreg_vs_vicreg():
    """SIGReg alone, VICReg alone, and both together must each give a positive reg_loss."""
    print("\n=== Ablation: SIGReg vs VICReg ===")
    hidden, steps, batch, num_state = 256, 3, 2, 8
    traj_shape = (batch, steps + 1, num_state, hidden)
    predicted = torch.randn(*traj_shape)
    reference = torch.randn(*traj_shape)
    task_loss = torch.tensor(1.0)

    # (label, use_sigreg, use_vicreg)
    cases = (
        ("SIGReg", True, False),
        ("VICReg", False, True),
        ("Both", True, True),
    )
    for name, with_sigreg, with_vicreg in cases:
        objective = JEPAObjectiveConfig(
            use_sigreg=with_sigreg,
            sigreg_weight=0.1,
            use_vicreg=with_vicreg,
            vicreg_var_weight=1.0,
            vicreg_cov_weight=0.04,
        )
        losses = JEPALoss(objective, hidden)(predicted, reference, task_loss)
        reg_value = losses['reg_loss'].item()
        print(f" {name}: reg={reg_value:.4f}")
        assert reg_value > 0, f"{name} reg should be > 0"
    print(" ✓ passed!")
def test_ablation_no_jepa():
    """Smoke-check that JEPALoss still computes; the no_jepa path itself is not exercised here.

    This test only calls the standalone loss function and prints its
    ``jepa_loss``. The printed note records that, in no_jepa mode, the model
    forward is expected to bypass this loss and use the task loss directly.
    """
    print("\n=== Ablation: --no_jepa ===")
    D, K, B, N_s = 256, 3, 2, 8
    cfg = JEPAObjectiveConfig(use_sigreg=True, sigreg_weight=0.1)
    loss_fn = JEPALoss(cfg, D)
    pred = torch.randn(B, K + 1, N_s, D, requires_grad=True)
    target = torch.randn(B, K + 1, N_s, D)
    task = torch.tensor(1.5)
    loss_dict = loss_fn(pred, target, task)
    print(f" JEPA loss computes: {loss_dict['jepa_loss'].item():.4f}")
    # Plain string: there are no placeholders, so no f-prefix is needed.
    print(" In no_jepa mode, model forward skips this and uses task_loss directly")
    print(" ✓ passed!")
def test_ablation_purist_config():
    """Check the values returned by ``get_purist_config``.

    Expects rollout depth K=5, the cosine JEPA loss, SIGReg enabled, VICReg
    disabled, and a visual backbone whose model name contains "base".
    """
    print("\n=== Ablation: purist config ===")
    from mr_jepa.configs.model_config import get_purist_config
    c = get_purist_config()
    assert c.rollout.K == 5, f"K should be 5, got {c.rollout.K}"
    assert c.jepa.jepa_loss_fn == "cosine", f"Loss should be cosine, got {c.jepa.jepa_loss_fn}"
    # Truthiness checks instead of `== True` / `== False` (PEP 8), with messages like the other asserts.
    assert c.jepa.use_sigreg, f"SIGReg should be enabled, got {c.jepa.use_sigreg}"
    assert not c.jepa.use_vicreg, f"VICReg should be disabled, got {c.jepa.use_vicreg}"
    assert "base" in c.visual.model_name, f"Should use base model, got {c.visual.model_name}"
    print(f" K={c.rollout.K}, loss={c.jepa.jepa_loss_fn}, SIGReg={c.jepa.use_sigreg}, backbone={c.visual.model_name}")
    print(" ✓ passed!")
def test_ablation_dinov2_config():
    """Check the visual-backbone fields returned by ``get_dinov2_ablation_config``."""
    print("\n=== Ablation: dinov2 config ===")
    from mr_jepa.configs.model_config import get_dinov2_ablation_config

    visual = get_dinov2_ablation_config().visual
    assert visual.backbone_type == "dinov2"
    assert "dinov2" in visual.model_name
    assert visual.image_size == 518
    assert visual.patch_size == 14
    print(f" backbone={visual.model_name}, size={visual.image_size}, patch={visual.patch_size}")
    print(" ✓ passed!")
if __name__ == "__main__":
    # Script entry point: run the core tests, then the ablation tests, in a fixed order.
    banner = "=" * 60

    core_tests = (
        test_evidence_memory,
        test_latent_rollout,
        test_target_encoder_and_jepa_loss,
        test_answer_heads,
        test_sigreg_and_vicreg,
        test_parameter_counting,
        test_trajectory_metrics,
        test_evaluation_metrics,
        test_end_to_end_forward,
    )
    ablation_tests = (
        test_ablation_no_jepa,
        test_ablation_no_rollout,
        test_ablation_no_evidence_gate,
        test_ablation_k_variants,
        test_ablation_loss_functions,
        test_ablation_sigreg_vs_vicreg,
        test_ablation_purist_config,
        test_ablation_dinov2_config,
    )

    print(banner)
    print("MR-JEPA Architecture Validation")
    print(banner)
    for run_test in core_tests:
        run_test()

    print("\n" + banner)
    print("Ablation Tests")
    print(banner)
    for run_test in ablation_tests:
        run_test()

    # Reached only if no test raised.
    print("\n" + banner)
    print("ALL TESTS PASSED ✓ (9 core + 8 ablation = 17 total)")
    print(banner)