OmniCoreX / tests /test_model.py
Kosasih's picture
Rename test/test_model.py to tests/test_model.py
a88f4f9 verified
raw
history blame contribute delete
624 Bytes
import unittest
import torch
from model import OmniCoreXModel
class ModelTest(unittest.TestCase):
    """Smoke tests for the ``OmniCoreXModel`` forward pass."""

    def test_forward_output_shape(self):
        """Forward a random multi-stream batch and check the output shape.

        Builds one random input tensor per stream, shaped
        ``(batch_size, seq_len, stream_dim)``, passes the dict of streams to
        the model, and expects a single output of shape
        ``(batch_size, seq_len, embed_dim)`` containing only finite values.
        """
        # Fixed seed so the random inputs (and any random parameter init
        # inside the model) are reproducible when the test fails.
        torch.manual_seed(0)

        stream_configs = {"text": 128, "image": 256, "sensor": 64}
        batch_size = 2
        seq_len = 10
        # Single source of truth for the expected feature width, shared by
        # the constructor call and the shape assertion below.
        embed_dim = 128
        # NOTE(review): embed_dim equals the "text" stream width (128), so the
        # shape check alone cannot distinguish a projected output from one
        # that passes the text stream through unchanged. Consider a distinct
        # embed_dim once the model's constraints on it are confirmed.

        model = OmniCoreXModel(
            stream_configs, embed_dim=embed_dim, num_layers=2, num_heads=4
        )
        inputs = {
            name: torch.randn(batch_size, seq_len, dim)
            for name, dim in stream_configs.items()
        }

        # Only the forward result is inspected; skip autograd bookkeeping.
        with torch.no_grad():
            output = model(inputs)

        self.assertEqual(tuple(output.shape), (batch_size, seq_len, embed_dim))
        self.assertTrue(
            torch.isfinite(output).all().item(),
            "model output contains NaN or Inf values",
        )
if __name__ == "__main__":
    # Script entry point: discover and run the TestCase classes in this
    # module. verbosity=1 is unittest's default; a "-v" flag on the command
    # line still raises it.
    unittest.main(verbosity=1)