File size: 16,430 Bytes
dba2c56
 
 
 
 
 
 
 
 
 
e158637
 
 
83e1328
 
dba2c56
 
83e1328
dba2c56
83e1328
 
dba2c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67a4198
 
dba2c56
 
 
67a4198
dba2c56
 
 
67a4198
dba2c56
 
 
 
67a4198
 
dba2c56
67a4198
 
 
 
 
dba2c56
 
 
 
67a4198
 
 
 
 
 
 
 
 
dba2c56
67a4198
dba2c56
67a4198
dba2c56
67a4198
 
 
dba2c56
67a4198
dba2c56
 
67a4198
 
dba2c56
 
 
 
67a4198
 
dba2c56
 
67a4198
 
dba2c56
 
67a4198
dba2c56
67a4198
 
dba2c56
 
 
 
67a4198
dba2c56
67a4198
 
 
 
 
dba2c56
67a4198
 
 
dba2c56
 
 
 
 
67a4198
 
 
 
dba2c56
 
 
 
 
67a4198
dba2c56
 
 
 
 
67a4198
 
 
dba2c56
 
 
 
67a4198
 
 
 
 
 
 
dba2c56
 
 
67a4198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dba2c56
67a4198
 
 
 
dba2c56
 
e158637
 
 
 
 
67a4198
 
 
 
 
 
 
 
 
e158637
 
 
67a4198
 
 
 
 
 
e158637
67a4198
 
 
 
 
e158637
 
 
67a4198
 
 
e158637
67a4198
 
 
 
 
e158637
 
 
67a4198
 
 
 
 
 
 
 
 
 
 
 
e158637
 
 
67a4198
 
 
 
 
 
e158637
67a4198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e158637
 
 
67a4198
 
e158637
67a4198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e158637
 
dba2c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e158637
 
 
 
 
 
 
 
 
 
67a4198
e158637
dba2c56
67a4198
dba2c56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
"""
MR-JEPA Architecture Validation Test.

Tests the complete forward pass with synthetic data to verify:
1. All modules instantiate correctly
2. Tensor shapes are consistent throughout
3. JEPA loss computes correctly
4. Target encoder EMA updates work
5. Both MC and open-ended heads produce valid output
6. Ablation controls work (no-JEPA, no-rollout, no-evidence-gate)
7. Loss function variants (smooth_l1, mse, cosine)
8. Anti-collapse regularizations (SIGReg, VICReg)
9. Parameter counting is correct

Run from repo root:  python test_architecture.py
"""

import os
import sys
# Ensure the repo root is on the path (where mr_jepa/ package lives)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import torch
import torch.nn as nn
import numpy as np
from mr_jepa.configs.model_config import (
    MRJEPAConfig, VisualBackboneConfig, TextEncoderConfig,
    EvidenceMemoryConfig, LatentRolloutConfig, JEPAObjectiveConfig,
    AnswerHeadConfig, TrainingPhaseConfig,
)
from mr_jepa.models.evidence_memory import EvidenceMemory
from mr_jepa.models.latent_rollout import LatentRolloutModule
from mr_jepa.models.target_encoder import TargetEncoder, JEPALoss, SIGRegLoss, VICRegLoss
from mr_jepa.models.answer_heads import DiscriminativeHead, GenerativeHead


def test_evidence_memory():
    """Check that EvidenceMemory returns evidence tokens of the configured shape."""
    print("\n=== Test: Evidence Memory ===")
    config = EvidenceMemoryConfig(
        hidden_dim=256,
        num_evidence_tokens=16,
        num_cross_attn_layers=2,
        num_heads=4,
        dropout=0.1,
    )
    batch, n_visual, n_text = 4, 49, 32
    visual_dim, text_dim = 512, 384
    model = EvidenceMemory(config, visual_dim=visual_dim, text_dim=text_dim)

    visual_tokens = torch.randn(batch, n_visual, visual_dim)
    text_tokens = torch.randn(batch, n_text, text_dim)
    # Zero out the last five positions of every row of the text mask.
    text_mask = torch.ones(batch, n_text)
    text_mask[:, -5:] = 0

    evidence = model(visual_tokens, text_tokens, text_mask)['evidence_tokens']
    expected_shape = (batch, config.num_evidence_tokens, config.hidden_dim)
    assert evidence.shape == expected_shape
    print(f"  Evidence shape: {evidence.shape}")
    print("  ✓ passed!")


def test_latent_rollout():
    """Check the shapes of the trajectory, final state and projection of a rollout."""
    print("\n=== Test: Latent Rollout ===")
    config = LatentRolloutConfig(
        hidden_dim=256,
        num_state_tokens=8,
        K=3,
        num_predictor_layers=2,
        num_heads=4,
        ffn_dim=512,
        dropout=0.1,
        use_evidence_gate=True,
        gate_type="sigmoid",
        use_step_embedding=True,
    )
    batch, n_evidence = 4, 16
    model = LatentRolloutModule(config)
    evidence = torch.randn(batch, n_evidence, config.hidden_dim)
    output = model(evidence)

    trajectory = output['trajectory']
    state_shape = (config.num_state_tokens, config.hidden_dim)
    # The trajectory is expected to hold K + 1 states along dim 1.
    assert trajectory.shape == (batch, config.K + 1, *state_shape)
    assert output['z_final'].shape == (batch, *state_shape)
    assert output['z_projected'].shape == trajectory.shape
    print(f"  Trajectory: {trajectory.shape}")
    print("  ✓ passed!")


def test_target_encoder_and_jepa_loss():
    """Exercise the EMA update, the target forward pass and a JEPALoss backward pass."""
    print("\n=== Test: Target Encoder + JEPA Loss ===")
    hidden, n_evidence, n_state, depth, batch = 256, 16, 8, 3, 4
    visual_dim, text_dim = 512, 384
    ev_cfg = EvidenceMemoryConfig(
        hidden_dim=hidden, num_evidence_tokens=n_evidence,
        num_cross_attn_layers=2, num_heads=4,
    )
    ro_cfg = LatentRolloutConfig(
        hidden_dim=hidden, num_state_tokens=n_state, K=depth,
        num_predictor_layers=2, num_heads=4, ffn_dim=512,
    )
    j_cfg = JEPAObjectiveConfig(
        ema_momentum_base=0.996, ema_momentum_end=1.0,
        use_sigreg=True, sigreg_weight=0.1,
    )
    evidence_mem = EvidenceMemory(ev_cfg, visual_dim, text_dim)
    rollout = LatentRolloutModule(ro_cfg)
    target_enc = TargetEncoder(evidence_mem, rollout, j_cfg)

    def first_target_param():
        # First parameter of the target rollout, used as the EMA probe.
        return next(iter(target_enc.target_rollout.parameters()))

    before = first_target_param().clone()
    # Perturb the online rollout weights in place before running the EMA step.
    with torch.no_grad():
        for param in rollout.parameters():
            param.add_(torch.randn_like(param) * 0.1)
    target_enc.update_ema(evidence_mem, rollout, step=100, total_steps=1000)
    assert not torch.allclose(before, first_target_param()), "EMA did not update!"
    print(f"  EMA momentum: {target_enc._current_momentum:.6f}")

    visual = torch.randn(batch, 49, visual_dim)
    text = torch.randn(batch, 32, text_dim)
    text_mask = torch.ones(batch, 32)
    target_traj = target_enc(visual, text, text_mask)['target_trajectory']
    assert target_traj.shape == (batch, depth + 1, n_state, hidden)

    jepa_loss_fn = JEPALoss(j_cfg, hidden)
    pred_traj = torch.randn(batch, depth + 1, n_state, hidden, requires_grad=True)
    total = jepa_loss_fn(pred_traj, target_traj, torch.tensor(1.5))['total_loss']
    total.backward()
    assert pred_traj.grad is not None, "No gradients!"
    grad_norm = pred_traj.grad.norm().item()
    print(f"  Total loss: {total.item():.4f}, grad norm: {grad_norm:.4f}")
    print("  ✓ passed!")


def test_answer_heads():
    """Run the discriminative and generative heads and check option masking."""
    print("\n=== Test: Answer Heads ===")
    hidden, text_dim, batch = 256, 384, 4
    n_state, max_opts, vocab_size = 8, 4, 1000
    head_config = AnswerHeadConfig(
        disc_hidden_dim=256,
        disc_num_layers=2,
        max_num_options=max_opts,
        gen_hidden_dim=256,
        gen_num_layers=2,
        gen_num_heads=4,
        gen_vocab_size=vocab_size,
        gen_max_answer_length=32,
    )

    # Discriminative head: rows of the mask enable 4, 3, 2 and 4 options respectively.
    disc_head = DiscriminativeHead(head_config, hidden_dim=hidden, text_dim=text_dim)
    z_final = torch.randn(batch, n_state, hidden)
    option_mask = torch.tensor([
        [True, True, True, True],
        [True, True, True, False],
        [True, True, False, False],
        [True, True, True, True],
    ])
    option_embeddings = torch.randn(batch, max_opts, text_dim)
    disc_logits = disc_head(z_final, option_embeddings, option_mask)['logits']
    assert disc_logits[2, 2] == float('-inf'), "Masked option should be -inf!"

    # Generative head: one forward call with random token ids, then generation.
    gen_head = GenerativeHead(head_config, hidden_dim=hidden, vocab_size=vocab_size)
    answer_ids = torch.randint(0, vocab_size, (batch, 16))
    gen_loss = gen_head(z_final, answer_ids)['loss']
    generated = gen_head.generate(z_final, start_token_id=1, max_length=10)
    print(f"  Disc logits: {disc_logits.shape}, Gen loss: {gen_loss.item():.4f}, Generated: {generated.shape}")
    print("  ✓ passed!")


def test_sigreg_and_vicreg():
    """Check SIGReg scores constant inputs above random ones and that VICReg runs."""
    print("\n=== Test: SIGReg + VICReg ===")
    hidden, batch, n_tokens = 256, 32, 8
    sigreg = SIGRegLoss(hidden, num_projections=64)

    shape = (batch, n_tokens, hidden)
    z_rand = torch.randn(*shape)
    # An all-ones tensor stands in for a fully collapsed representation.
    z_coll = torch.ones(*shape)
    loss_rand = sigreg(z_rand)
    loss_coll = sigreg(z_coll)
    assert loss_coll > loss_rand, "SIGReg should penalize collapsed representations more!"

    vicreg = VICRegLoss(var_weight=1.0, cov_weight=0.04)
    loss_vic = vicreg(z_rand)
    print(f"  SIGReg random={loss_rand.item():.4f}, collapsed={loss_coll.item():.4f}; VICReg={loss_vic.item():.4f}")
    print("  ✓ passed!")


def test_parameter_counting():
    """Count the parameters of EvidenceMemory and LatentRolloutModule.

    Builds both modules with small configs, sums ``numel()`` over their
    parameters, and asserts that each module has at least one parameter, so
    the test can fail instead of only printing.
    """
    print("\n=== Test: Parameter Counting ===")
    D = 256
    ev = EvidenceMemory(
        EvidenceMemoryConfig(hidden_dim=D, num_evidence_tokens=16, num_cross_attn_layers=2, num_heads=4),
        visual_dim=512,
        text_dim=384,
    )
    ro = LatentRolloutModule(
        LatentRolloutConfig(hidden_dim=D, num_state_tokens=8, K=3, num_predictor_layers=3, num_heads=4, ffn_dim=512)
    )
    ev_count = sum(p.numel() for p in ev.parameters())
    ro_count = sum(p.numel() for p in ro.parameters())
    # A module reporting zero parameters means it was built empty.
    assert ev_count > 0, "EvidenceMemory reports no parameters!"
    assert ro_count > 0, "LatentRolloutModule reports no parameters!"
    print(f"  Evidence: {ev_count:,}, Rollout: {ro_count:,}")
    print("  ✓ passed!")


def test_trajectory_metrics():
    """Compute metrics and a PCA visualization for a synthetic trajectory with shrinking steps."""
    print("\n=== Test: Trajectory Metrics ===")
    from mr_jepa.utils.visualization import compute_trajectory_metrics, visualize_trajectory

    batch, depth, n_state, hidden = 4, 3, 8, 256
    trajectory = torch.randn(batch, depth + 1, n_state, hidden)
    # Each state is the previous one plus noise scaled by 0.5 ** step.
    for step in range(1, depth + 1):
        noise = torch.randn(batch, n_state, hidden)
        trajectory[:, step] = trajectory[:, step - 1] + noise * (0.5 ** step)

    metrics = compute_trajectory_metrics(trajectory)
    # Called only to confirm it runs; the result is not inspected.
    visualize_trajectory(trajectory[0], method='pca')
    rate = metrics['convergence_rate']
    assert rate < 1.0
    print(f"  Convergence rate: {rate:.4f}")
    print("  ✓ passed!")


def test_evaluation_metrics():
    """Call each evaluation metric once; only the accuracy value is asserted."""
    print("\n=== Test: Evaluation Metrics ===")
    from mr_jepa.evaluation.metrics import (
        compute_accuracy,
        compute_anls,
        compute_vqa_accuracy,
        compute_relaxed_accuracy,
    )

    # Three of the four predictions match their labels.
    predictions, labels = [0, 1, 2, 0], [0, 1, 1, 0]
    accuracy = compute_accuracy(predictions, labels)['accuracy']
    assert accuracy == 75.0

    # The remaining metrics are only checked for running without raising.
    compute_anls(["hello world", "test"], [["hello world"], ["testing"]])
    compute_vqa_accuracy(["cat"], [["cat"] * 10])
    compute_relaxed_accuracy(
        ["100", "hello"],
        ["100", "hello"],
        types=["human_test", "human_test"],
    )
    print("  All metrics compute correctly")
    print("  ✓ passed!")


def test_end_to_end_forward():
    """Run one full forward/backward pass through all modules on synthetic data.

    Chains EvidenceMemory -> LatentRolloutModule -> both answer heads and
    JEPALoss, backpropagates the total loss, performs one EMA update of the
    target encoder, and asserts that gradients reached both the evidence
    memory and the rollout module.
    """
    print("\n=== Test: End-to-End Forward Pass ===")
    D, B, N_v, N_t, N_e, N_s, K = 256, 2, 49, 32, 16, 8, 3
    max_opts, vocab_size, visual_dim, text_dim = 4, 100, 512, 384
    ev_cfg = EvidenceMemoryConfig(hidden_dim=D, num_evidence_tokens=N_e, num_cross_attn_layers=2, num_heads=4)
    ro_cfg = LatentRolloutConfig(hidden_dim=D, num_state_tokens=N_s, K=K, num_predictor_layers=2, num_heads=4, ffn_dim=512)
    j_cfg = JEPAObjectiveConfig(use_sigreg=True, sigreg_weight=0.1)
    h_cfg = AnswerHeadConfig(disc_hidden_dim=D, gen_hidden_dim=D, gen_num_layers=2, gen_num_heads=4, gen_vocab_size=vocab_size, gen_max_answer_length=16)
    evidence_mem = EvidenceMemory(ev_cfg, visual_dim, text_dim)
    rollout = LatentRolloutModule(ro_cfg)
    target_enc = TargetEncoder(evidence_mem, rollout, j_cfg)
    disc_head = DiscriminativeHead(h_cfg, D, text_dim)
    gen_head = GenerativeHead(h_cfg, D, vocab_size)
    jepa_loss_fn = JEPALoss(j_cfg, D)

    # Synthetic inputs.
    vis = torch.randn(B, N_v, visual_dim)
    txt = torch.randn(B, N_t, text_dim)
    mask = torch.ones(B, N_t)

    # Online branch and target branch on the same inputs.
    evidence = evidence_mem(vis, txt, mask)['evidence_tokens']
    rollout_out = rollout(evidence)
    target_out = target_enc(vis, txt, mask)

    # Task losses from both heads; the class labels are fixed for B == 2.
    disc_out = disc_head(rollout_out['z_final'], torch.randn(B, max_opts, text_dim), torch.ones(B, max_opts, dtype=torch.bool))
    task_loss = nn.functional.cross_entropy(disc_out['logits'], torch.tensor([1, 3]))
    gen_out = gen_head(rollout_out['z_final'], torch.randint(0, vocab_size, (B, 16)), evidence)

    loss_dict = jepa_loss_fn(rollout_out['z_projected'], target_out['target_trajectory'], task_loss, gen_out['loss'])
    loss_dict['total_loss'].backward()
    target_enc.update_ema(evidence_mem, rollout, step=1, total_steps=100)

    ev_grads = sum(1 for p in evidence_mem.parameters() if p.grad is not None)
    ro_grads = sum(1 for p in rollout.parameters() if p.grad is not None)
    # Without these checks a detached graph would pass silently.
    assert ev_grads > 0, "No gradients reached EvidenceMemory!"
    assert ro_grads > 0, "No gradients reached LatentRolloutModule!"
    print(f"  Total loss: {loss_dict['total_loss'].item():.4f}, EV grads: {ev_grads}, RO grads: {ro_grads}")
    print("  ✓ passed!")


# ──────────────────────────────────────────────────────────
# ABLATION TESTS
# ──────────────────────────────────────────────────────────

def test_ablation_no_rollout():
    """With K=0 the trajectory should contain exactly one step."""
    print("\n=== Ablation: --no_rollout (K=0) ===")
    hidden, batch, n_evidence, n_state = 256, 2, 16, 8
    config = LatentRolloutConfig(
        hidden_dim=hidden,
        num_state_tokens=n_state,
        K=0,
        num_predictor_layers=2,
        num_heads=4,
        ffn_dim=512,
    )
    rollout = LatentRolloutModule(config)
    trajectory = rollout(torch.randn(batch, n_evidence, hidden))['trajectory']
    n_steps = trajectory.shape[1]
    assert n_steps == 1, f"Expected 1, got {n_steps}"
    print(f"  Trajectory: {trajectory.shape} (K=0 → 1 step)")
    print("  ✓ passed!")


def test_ablation_no_evidence_gate():
    """With use_evidence_gate=False every predictor layer should report gate_type 'none'."""
    print("\n=== Ablation: --no_evidence_gate ===")
    hidden, batch, n_evidence, n_state, depth = 256, 2, 16, 8, 3
    config = LatentRolloutConfig(
        hidden_dim=hidden,
        num_state_tokens=n_state,
        K=depth,
        num_predictor_layers=2,
        num_heads=4,
        ffn_dim=512,
        use_evidence_gate=False,
    )
    rollout = LatentRolloutModule(config)

    # Inspect each predictor layer's gate before running the forward pass.
    for index, layer in enumerate(rollout.predictor_layers):
        gate_type = layer.evidence_gate.gate_type
        assert gate_type == "none", f"Layer {index}: expected gate_type='none', got '{gate_type}'"

    trajectory = rollout(torch.randn(batch, n_evidence, hidden))['trajectory']
    assert trajectory.shape == (batch, depth + 1, n_state, hidden)
    print(f"  All {len(rollout.predictor_layers)} layers have gate_type='none'")
    print("  ✓ passed!")


def test_ablation_k_variants():
    """For several rollout depths K, the trajectory length should be K + 1."""
    print("\n=== Ablation: K variants (1, 5, 7) ===")
    hidden, batch, n_evidence, n_state = 256, 2, 16, 8
    for depth in (1, 5, 7):
        config = LatentRolloutConfig(
            hidden_dim=hidden,
            num_state_tokens=n_state,
            K=depth,
            num_predictor_layers=2,
            num_heads=4,
            ffn_dim=512,
        )
        model = LatentRolloutModule(config)
        n_steps = model(torch.randn(batch, n_evidence, hidden))['trajectory'].shape[1]
        assert n_steps == depth + 1
        print(f"  K={depth}: trajectory len={n_steps} ✓")
    print("  ✓ passed!")


def test_ablation_loss_functions():
    """Each jepa_loss_fn variant should yield a positive total loss."""
    print("\n=== Ablation: loss_fn variants ===")
    hidden, depth, batch, n_state = 256, 3, 2, 8
    shape = (batch, depth + 1, n_state, hidden)
    pred = torch.randn(*shape)
    target = torch.randn(*shape)
    task = torch.tensor(1.0)
    for variant in ("smooth_l1", "mse", "cosine"):
        # SIGReg is switched off for this comparison.
        cfg = JEPAObjectiveConfig(jepa_loss_fn=variant, use_sigreg=False)
        result = JEPALoss(cfg, hidden)(pred, target, task)
        total = result['total_loss'].item()
        print(f"  {variant}: jepa={result['jepa_loss'].item():.4f}, total={total:.4f}")
        assert total > 0
    print("  ✓ passed!")


def test_ablation_sigreg_vs_vicreg():
    """SIGReg alone, VICReg alone, and both together should each give a positive reg loss."""
    print("\n=== Ablation: SIGReg vs VICReg ===")
    hidden, depth, batch, n_state = 256, 3, 2, 8
    shape = (batch, depth + 1, n_state, hidden)
    pred = torch.randn(*shape)
    target = torch.randn(*shape)
    task = torch.tensor(1.0)

    # label -> (use_sigreg, use_vicreg)
    variants = {
        "SIGReg": (True, False),
        "VICReg": (False, True),
        "Both": (True, True),
    }
    for label, (use_sigreg, use_vicreg) in variants.items():
        cfg = JEPAObjectiveConfig(
            use_sigreg=use_sigreg,
            sigreg_weight=0.1,
            use_vicreg=use_vicreg,
            vicreg_var_weight=1.0,
            vicreg_cov_weight=0.04,
        )
        reg = JEPALoss(cfg, hidden)(pred, target, task)['reg_loss'].item()
        print(f"  {label}: reg={reg:.4f}")
        assert reg > 0, f"{label} reg should be > 0"
    print("  ✓ passed!")


def test_ablation_no_jepa():
    """Confirm JEPALoss computes; the no_jepa skip in the model forward is not run here."""
    print("\n=== Ablation: --no_jepa ===")
    hidden, depth, batch, n_state = 256, 3, 2, 8
    cfg = JEPAObjectiveConfig(use_sigreg=True, sigreg_weight=0.1)
    loss_fn = JEPALoss(cfg, hidden)

    shape = (batch, depth + 1, n_state, hidden)
    pred = torch.randn(*shape, requires_grad=True)
    target = torch.randn(*shape)
    task = torch.tensor(1.5)

    jepa_value = loss_fn(pred, target, task)['jepa_loss'].item()
    print(f"  JEPA loss computes: {jepa_value:.4f}")
    print("  In no_jepa mode, model forward skips this and uses task_loss directly")
    print("  ✓ passed!")


def test_ablation_purist_config():
    """Check the values returned by ``get_purist_config``.

    Asserts rollout depth K == 5, the cosine JEPA loss, SIGReg enabled,
    VICReg disabled, and a visual model name containing "base".
    """
    print("\n=== Ablation: purist config ===")
    from mr_jepa.configs.model_config import get_purist_config
    c = get_purist_config()
    assert c.rollout.K == 5, f"K should be 5, got {c.rollout.K}"
    assert c.jepa.jepa_loss_fn == "cosine", f"Loss should be cosine, got {c.jepa.jepa_loss_fn}"
    # Truthiness checks instead of `== True` / `== False` (PEP 8), with messages.
    assert c.jepa.use_sigreg, f"SIGReg should be enabled, got {c.jepa.use_sigreg}"
    assert not c.jepa.use_vicreg, f"VICReg should be disabled, got {c.jepa.use_vicreg}"
    assert "base" in c.visual.model_name, f"Should use base model, got {c.visual.model_name}"
    print(f"  K={c.rollout.K}, loss={c.jepa.jepa_loss_fn}, SIGReg={c.jepa.use_sigreg}, backbone={c.visual.model_name}")
    print("  ✓ passed!")


def test_ablation_dinov2_config():
    """Check the visual-backbone fields returned by get_dinov2_ablation_config."""
    print("\n=== Ablation: dinov2 config ===")
    from mr_jepa.configs.model_config import get_dinov2_ablation_config

    visual = get_dinov2_ablation_config().visual
    assert visual.backbone_type == "dinov2"
    assert "dinov2" in visual.model_name
    assert visual.image_size == 518
    assert visual.patch_size == 14
    print(f"  backbone={visual.model_name}, size={visual.image_size}, patch={visual.patch_size}")
    print("  ✓ passed!")


if __name__ == "__main__":
    # Script entry point: run the core tests, then the ablation tests, in a
    # fixed order. The summary counts are derived from these tuples so the
    # final line cannot drift from what actually ran.
    core_tests = (
        test_evidence_memory,
        test_latent_rollout,
        test_target_encoder_and_jepa_loss,
        test_answer_heads,
        test_sigreg_and_vicreg,
        test_parameter_counting,
        test_trajectory_metrics,
        test_evaluation_metrics,
        test_end_to_end_forward,
    )
    ablation_tests = (
        test_ablation_no_jepa,
        test_ablation_no_rollout,
        test_ablation_no_evidence_gate,
        test_ablation_k_variants,
        test_ablation_loss_functions,
        test_ablation_sigreg_vs_vicreg,
        test_ablation_purist_config,
        test_ablation_dinov2_config,
    )
    banner = "=" * 60

    print(banner)
    print("MR-JEPA Architecture Validation")
    print(banner)

    for run_test in core_tests:
        run_test()

    print("\n" + banner)
    print("Ablation Tests")
    print(banner)
    for run_test in ablation_tests:
        run_test()

    n_core, n_ablation = len(core_tests), len(ablation_tests)
    print("\n" + banner)
    print(f"ALL TESTS PASSED ✓ ({n_core} core + {n_ablation} ablation = {n_core + n_ablation} total)")
    print(banner)