| | |
| | """ |
| | Debug script to test grid mismatches in Quillan model |
| | """ |
| |
|
| | import torch |
| | import os |
| | import sys |
| |
|
| | |
# Put this script's own directory first on the import path so that the
# `from __init__ import ...` statement below resolves against the sibling
# module when the script is run directly. Must stay before that import.
sys.path.insert(0, os.path.dirname(__file__))
| |
|
| | from __init__ import QuillanSOTA, Config |
| |
|
def test_grid_matches(
    config=None,
    model=None,
    *,
    batch_size=1,
    text_len=128,
    image_size=256,
    audio_len=2048,
    video_frames=8,
    video_size=32,
):
    """Run one forward pass of the Quillan model and report shape problems.

    Builds random text / image / audio / video inputs, feeds them through the
    model under ``torch.no_grad()`` and prints the input and output shapes.
    Any exception raised while building the model or running the forward pass
    is printed together with its traceback instead of propagating.

    Args:
        config: Model configuration; must expose ``vocab_size``. Defaults to
            a fresh ``Config()``.
        model: Module to exercise. Defaults to ``QuillanSOTA(config)``. It is
            switched to eval mode before the forward pass.
        batch_size: Leading dimension of every input tensor.
        text_len: Number of random token ids per text sample.
        image_size: Spatial size of the ``(batch, 3, size, size)`` image input.
        audio_len: Last-dimension length of the ``(batch, 1, len)`` audio input.
        video_frames: Third dimension of the ``(batch, 3, frames, size, size)``
            video input.
        video_size: Last two dimensions of the video input.

    NOTE(review): the input layouts reproduce the original hard-coded shapes;
    reading axis 1 as channels and the video's axis 2 as frames is an
    assumption — confirm against ``QuillanSOTA.forward``.

    Returns:
        True if the model was built and the forward pass completed,
        otherwise False.
    """
    import traceback

    print("🧪 Testing Quillan model grid matches...")

    # Construct inside the guard as well: shape/grid errors raised by the
    # model's constructor are reported the same way as forward-pass errors
    # instead of escaping as an unhandled crash.
    try:
        if config is None:
            config = Config()
        if model is None:
            model = QuillanSOTA(config)
        model.eval()
    except Exception as e:
        print(f"❌ Model construction failed: {e}")
        traceback.print_exc()
        return False

    # Random inputs; token ids are drawn from [0, vocab_size).
    text = torch.randint(0, config.vocab_size, (batch_size, text_len))
    img = torch.randn(batch_size, 3, image_size, image_size)
    aud = torch.randn(batch_size, 1, audio_len)
    vid = torch.randn(batch_size, 3, video_frames, video_size, video_size)

    print("Input shapes:")
    print(f" Text: {text.shape}")
    print(f" Image: {img.shape}")
    print(f" Audio: {aud.shape}")
    print(f" Video: {vid.shape}")

    # Keep the guarded region to the forward pass itself, so a failure while
    # *reporting* outputs is not mislabelled as a grid mismatch.
    try:
        with torch.no_grad():
            outputs = model(text, img, aud, vid)
    except Exception as e:
        print(f"❌ Grid mismatch error: {e}")
        traceback.print_exc()
        return False

    print("✅ Forward pass successful!")
    print("Output shapes:")
    _print_output_shapes(outputs)
    return True


def _print_output_shapes(outputs):
    """Print the shape (or type, for non-tensors) of each forward output.

    Accepts a dict-like output (one line per key), a bare tensor, a
    tuple/list (one line per index), or any other object (its type).
    """
    if isinstance(outputs, torch.Tensor):
        entries = [("output", outputs)]
    elif hasattr(outputs, "items"):
        entries = outputs.items()
    elif isinstance(outputs, (tuple, list)):
        entries = enumerate(outputs)
    else:
        entries = [("output", outputs)]

    for key, value in entries:
        if isinstance(value, torch.Tensor):
            print(f" {key}: {value.shape}")
        else:
            print(f" {key}: {type(value)}")
| |
|
if __name__ == "__main__":
    # Propagate the result as the process exit status (0 = forward pass OK,
    # 1 = failure) so shells/CI can detect a failing run; previously the
    # return value was discarded and the script always exited 0.
    raise SystemExit(0 if test_grid_matches() else 1)
| |
|