File size: 624 Bytes
80497bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import unittest
import torch
from model import OmniCoreXModel

class ModelTest(unittest.TestCase):
    """Smoke tests for the forward pass of ``OmniCoreXModel``."""

    def test_forward_output_shape(self):
        """Forward pass returns a ``(batch_size, seq_len, embed_dim)`` tensor.

        One random tensor of shape ``(batch_size, seq_len, stream_dim)`` is
        built per configured stream and the dict is passed to the model in a
        single call; the output must keep the batch and sequence axes and
        have ``embed_dim`` features.
        """
        # Fixed seed so a failure reproduces with identical inputs and weights.
        torch.manual_seed(0)

        embed_dim = 128
        # Deliberately, no stream width equals embed_dim: otherwise the shape
        # assertion could pass even if one stream's raw features leaked
        # through unprojected.
        stream_configs = {"text": 32, "image": 256, "sensor": 64}
        batch_size = 2
        seq_len = 10
        model = OmniCoreXModel(
            stream_configs,
            embed_dim=embed_dim,
            num_layers=2,
            num_heads=4,
        )

        inputs = {
            name: torch.randn(batch_size, seq_len, dim)
            for name, dim in stream_configs.items()
        }
        # Only the output shape is checked, so no autograd graph is needed.
        with torch.no_grad():
            output = model(inputs)

        self.assertEqual(tuple(output.shape), (batch_size, seq_len, embed_dim))

# Run the test suite when this file is executed directly as a script.
if "__main__" == __name__:
    unittest.main()