File size: 2,424 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Image comprehension tests — verify the image pipeline on CPU.

Tests: ImageSequencer forward, DINOv2 feature extraction,
      patch_proj → unfold → projection → norm pipeline.
Runs on CPU — downloads a large model on the first run.
"""
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
import torch
from arbitor.kernel.ternary_scale import TScaleType

device = "cpu"  # NOTE(review): not referenced below; tensors use torch's default device
FAILED = 0  # running count of failed checks; reported in the summary and exit status


def check(name, condition, detail=""):
    """Print a pass/fail line for one check and tally failures.

    Args:
        name: Short human-readable label for the check.
        condition: Any truthy/falsy value. A single-element torch tensor
            (e.g. the result of ``.any()`` / ``.all()``) is also accepted,
            since it is converted with ``bool()``.
        detail: Optional context appended to the failure line.

    Returns:
        ``True`` if the check passed, ``False`` otherwise. Failures also
        increment the module-level ``FAILED`` counter.
    """
    global FAILED
    passed = bool(condition)
    if passed:
        print(f"  ✓ {name}")
    else:
        # Only emit the separator when there is detail to show after it.
        suffix = f" — {detail}" if detail else ""
        print(f"  ✗ {name}{suffix}")
        FAILED += 1
    return passed


print("\n=== Image Comprehension ===\n")
print("Loading ImageSequencer (downloads DINOv2-small on first run)...")

from arbitor import ARBModel, HIDDEN_DIM

# Fixed seed so randomly initialised layers and the synthetic inputs below
# are identical across runs, keeping the variance/finiteness checks stable.
torch.manual_seed(0)

# 1. ImageSequencer forward with a synthetic image
model = ARBModel(enable_image=True, enable_audio=False,
                 enable_vq=False, enable_graph=False,
                 enable_memory_modules=False, enable_moe=False)
model.eval()
img = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    seq_out = model.image_sequencer(img)
check("ImageSequencer forward", seq_out is not None, "got None")

# The feature checks below dereference seq_out; skip them when the forward
# returned None (already recorded as a failure above) instead of crashing
# before the summary is printed.
if seq_out is not None:
    check("Output shape", seq_out.shape[-1] == HIDDEN_DIM,
          f"last dim={seq_out.shape[-1]}")
    check("No NaN in image features", not torch.isnan(seq_out).any())

    # 2. Image features are finite and reasonable
    check("Image features finite", torch.isfinite(seq_out).all())
    feat_std = seq_out.std().item()
    check("Image features have variance", feat_std > 0.001,
          f"std={feat_std}")

# 3. Full model with image input
with torch.no_grad():
    logits, losses, _, _ = model(x=None, images=img,
                                 targets=torch.randint(0, 256, (1, 100)))
check("Model with image forward", logits is not None)
if losses is not None:
    # NOTE(review): the type of losses.total isn't visible here; as_tensor
    # accepts either a tensor or a plain Python number, and .all() reduces
    # any multi-element result to a single bool.
    check("Image loss is finite",
          torch.isfinite(torch.as_tensor(losses.total)).all())

# 4. Image path with the VQ, graph and MoE options enabled.
# NOTE(review): this section was headed "Modality gate", but only
# image_sequencer is exercised here.
del model  # free memory before building the second model
model2 = ARBModel(enable_image=True, enable_audio=False,
                  enable_vq=True, enable_graph=True,
                  enable_memory_modules=False, enable_moe=True)
model2.eval()
with torch.no_grad():
    seq_out2 = model2.image_sequencer(img)
check("Image with VQ pipeline", seq_out2 is not None, "got None")

print(f"\n{'='*50}")
if FAILED == 0:
    print("✓ All image comprehension tests passed!")
else:
    print(f"✗ {FAILED} test(s) failed")
# POSIX exit statuses are taken modulo 256; clamp so a large failure count
# can never be reported as success.
sys.exit(min(FAILED, 255))