"""Smoke tests for the training loop in ``trainer.Trainer``."""

import os
import tempfile
import unittest

import torch
from torch.utils.data import DataLoader, TensorDataset

from model import OmniCoreXModel
from trainer import Trainer


class TrainerTest(unittest.TestCase):
    """End-to-end checks that ``Trainer`` can drive ``OmniCoreXModel``."""

    def test_basic_training_loop(self):
        """Run one epoch on a tiny random dataset and check weights stay finite.

        Any exception raised while building the model or inside
        ``Trainer.fit`` fails the test. After training, every model
        parameter must be free of NaN/inf values.
        """
        # Fixed seed so the random data, and therefore any failure, is reproducible.
        torch.manual_seed(0)

        # Minimal random dataset: 10 samples of shape (5, 128) with float
        # features and integer labels in [0, 128). Batched 2 at a time.
        inputs = torch.randn(10, 5, 128)
        labels = torch.randint(0, 128, (10, 5))
        dataset = TensorDataset(inputs, labels)
        loader = DataLoader(dataset, batch_size=2)

        # NOTE(review): assumes the "dummy" stream size (128) is meant to match
        # both the last dimension of `inputs` and the label range.
        # Confirm against OmniCoreXModel.
        stream_configs = {"dummy": 128}
        model = OmniCoreXModel(stream_configs, embed_dim=128, num_layers=1, num_heads=4)
        device = torch.device("cpu")
        model.to(device)

        # Send anything the trainer saves to a throwaway directory rather than
        # the current working directory, so the test leaves no files behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(
                model=model,
                train_loader=loader,
                valid_loader=None,
                # Keep a trailing separator, like the original "./", in case
                # Trainer builds file paths by plain string concatenation.
                save_dir=os.path.join(tmp_dir, ""),
                lr=1e-3,
                total_steps=10,
                warmup_steps=2,
                mixed_precision=False,
            )
            trainer.fit(epochs=1)

        # A training loop that silently diverged would leave NaN/inf weights.
        for name, param in model.named_parameters():
            self.assertTrue(
                torch.isfinite(param).all().item(),
                f"parameter {name!r} contains non-finite values after training",
            )


if __name__ == "__main__":
    unittest.main()