"""Unit tests for ``model.OmniCoreXModel``."""

import unittest

import torch

from model import OmniCoreXModel


class ModelTest(unittest.TestCase):
    """Shape checks for the multi-stream ``OmniCoreXModel`` forward pass."""

    def test_forward_output_shape(self):
        """A forward pass over every configured stream returns (batch, seq_len, embed_dim)."""
        # Maps stream name -> per-stream input feature width; each input
        # tensor below is built with that width as its last dimension.
        stream_configs = {"text": 128, "image": 256, "sensor": 64}
        batch_size = 2
        seq_len = 10
        # NOTE(review): embed_dim equals the "text" input width, so this shape
        # check alone cannot distinguish an embed_dim-wide output from one that
        # merely echoes the text stream's width. Consider a distinct value once
        # the model's constraints on embed_dim are confirmed.
        embed_dim = 128

        # Fixed seed so any failure (e.g. from parameter init or random
        # inputs) is reproducible run-to-run.
        torch.manual_seed(0)

        model = OmniCoreXModel(
            stream_configs,
            embed_dim=embed_dim,
            num_layers=2,
            num_heads=4,
        )
        inputs = {
            name: torch.randn(batch_size, seq_len, dim)
            for name, dim in stream_configs.items()
        }

        # Only the output shape is inspected, so skip building an autograd graph.
        with torch.no_grad():
            output = model(inputs)

        self.assertEqual(tuple(output.shape), (batch_size, seq_len, embed_dim))


if __name__ == "__main__":
    unittest.main()