File size: 1,051 Bytes
0ca20d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import tempfile
import unittest

import torch
from torch.utils.data import DataLoader, TensorDataset

from model import OmniCoreXModel
from trainer import Trainer

class TrainerTest(unittest.TestCase):
    """Smoke tests for ``Trainer`` driving an ``OmniCoreXModel``."""

    def test_basic_training_loop(self):
        """Run one epoch on a tiny random dataset and check the model survives it.

        The test passes when ``Trainer.fit`` completes without raising and every
        model parameter is still finite afterwards. ``save_dir`` points at a
        temporary directory that is deleted when the test ends, so anything the
        trainer writes there does not pollute the current working directory.
        """
        # Fixed seed so the random data and the model initialisation are
        # reproducible and a failing run can be replayed.
        torch.manual_seed(0)

        # Minimal random dataset: 10 samples -> 5 batches of 2.
        inputs = torch.randn(10, 5, 128)
        labels = torch.randint(0, 128, (10, 5))
        dataset = TensorDataset(inputs, labels)
        loader = DataLoader(dataset, batch_size=2)

        stream_configs = {"dummy": 128}
        model = OmniCoreXModel(stream_configs, embed_dim=128, num_layers=1, num_heads=4)
        device = torch.device("cpu")
        model.to(device)

        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(
                model=model,
                train_loader=loader,
                valid_loader=None,
                # Keep a trailing separator (as the original "./" had) in case
                # Trainer builds file paths by plain string concatenation.
                save_dir=os.path.join(tmp_dir, ""),
                lr=1e-3,
                total_steps=10,
                warmup_steps=2,
                mixed_precision=False,
            )
            trainer.fit(epochs=1)

        # Completing without an exception is not enough: a run that diverged
        # to NaN/inf would otherwise pass silently.
        for name, param in model.named_parameters():
            self.assertTrue(
                bool(torch.isfinite(param.detach()).all()),
                f"parameter {name!r} contains non-finite values after training",
            )

# Script entry point: only true when this file is executed directly
# (e.g. `python <this file>`); test runners import the module instead.
if __name__ == "__main__":
    # Collect and run every unittest.TestCase defined in this module.
    unittest.main()