File size: 624 Bytes
80497bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import unittest
import torch
from model import OmniCoreXModel
class ModelTest(unittest.TestCase):
    """Smoke tests for the forward pass of ``OmniCoreXModel``."""

    def test_forward_output_shape(self):
        """The forward pass returns a ``(batch_size, seq_len, embed_dim)`` tensor.

        ``embed_dim`` is deliberately chosen to differ from every per-stream
        input width. With ``embed_dim=128`` the shape check could not
        distinguish a correct projection to ``embed_dim`` from a model that
        merely echoes the 128-wide ``"text"`` stream.
        """
        # Fixed seed so the random inputs (and any random init) are identical
        # on every run, keeping failures reproducible.
        torch.manual_seed(0)

        stream_configs = {"text": 128, "image": 256, "sensor": 64}
        batch_size = 2
        seq_len = 10
        num_heads = 4
        # Not equal to any stream width above. Kept divisible by num_heads,
        # which multi-head attention typically requires (assumption; confirm
        # against OmniCoreXModel's constraints).
        embed_dim = 96

        model = OmniCoreXModel(
            stream_configs,
            embed_dim=embed_dim,
            num_layers=2,
            num_heads=num_heads,
        )
        # One random tensor per stream, with that stream's configured width.
        inputs = {
            name: torch.randn(batch_size, seq_len, dim)
            for name, dim in stream_configs.items()
        }

        # Only the output shape is checked, so skip autograd bookkeeping.
        with torch.no_grad():
            output = model(inputs)

        self.assertEqual(tuple(output.shape), (batch_size, seq_len, embed_dim))
if __name__ == "__main__":
    # Executed only when this file is run directly (not when imported by a
    # test collector): hand control to the stdlib unittest command-line runner,
    # which runs the TestCase classes defined in this module.
    unittest.main(verbosity=1)  # 1 is unittest.main's default verbosity
|