File size: 1,616 Bytes
1c70d34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python3
"""
Debug script to test grid mismatches in Quillan model
"""

import torch
import os
import sys

# Add the model directory to path
sys.path.insert(0, os.path.dirname(__file__))

from __init__ import QuillanSOTA, Config

def test_grid_matches():
    """Run one forward pass through QuillanSOTA with synthetic multimodal inputs.

    Builds random text/image/audio/video tensors matching the data loader's
    shapes, runs the model in eval mode under ``torch.no_grad()``, and prints
    the input and output shapes.

    Returns:
        bool: True if the forward pass succeeded, False if it raised
        (the exception and traceback are printed, not re-raised).
    """
    print("🧪 Testing Quillan model grid matches...")

    # Default configuration; eval() so dropout/batchnorm behave deterministically.
    config = Config()
    model = QuillanSOTA(config)
    model.eval()

    # Synthetic inputs matching our data loader shapes.
    batch_size = 1
    text = torch.randint(0, config.vocab_size, (batch_size, 128))  # 128 sequence length
    img = torch.randn(batch_size, 3, 256, 256)  # 256x256 images
    aud = torch.randn(batch_size, 1, 2048)  # 2048 audio length
    vid = torch.randn(batch_size, 3, 8, 32, 32)  # 8 frames, 32x32 video

    print(f"Input shapes:")
    print(f"  Text: {text.shape}")
    print(f"  Image: {img.shape}")
    print(f"  Audio: {aud.shape}")
    print(f"  Video: {vid.shape}")

    # Keep the try body minimal: only the forward pass can legitimately fail
    # with a grid/shape mismatch. Success-path reporting happens after, so an
    # error while printing outputs is never mislabeled as a model failure.
    try:
        with torch.no_grad():
            outputs = model(text, img, aud, vid)
    except Exception as e:
        # Broad catch is deliberate: this is a diagnostic script and we want
        # the full traceback for whatever the model raised.
        print(f"❌ Grid mismatch error: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("✅ Forward pass successful!")
    print(f"Output shapes:")
    for key, value in outputs.items():
        if isinstance(value, torch.Tensor):
            print(f"  {key}: {value.shape}")
        else:
            print(f"  {key}: {type(value)}")

    return True

if __name__ == "__main__":
    test_grid_matches()