Kosasih commited on
Commit
0ca20d6
·
verified ·
1 Parent(s): 998f95a

Create tests/test_trainer.py

Browse files
Files changed (1) hide show
  1. tests/test_trainer.py +35 -0
tests/test_trainer.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import torch
3
+ from torch.utils.data import DataLoader, TensorDataset
4
+ from model import OmniCoreXModel
5
+ from trainer import Trainer
6
+
7
+ class TrainerTest(unittest.TestCase):
8
+ def test_basic_training_loop(self):
9
+ # Minimal random dataset
10
+ inputs = torch.randn(10, 5, 128)
11
+ labels = torch.randint(0, 128, (10, 5))
12
+ dataset = TensorDataset(inputs, labels)
13
+ loader = DataLoader(dataset, batch_size=2)
14
+
15
+ stream_configs = {"dummy": 128}
16
+ model = OmniCoreXModel(stream_configs, embed_dim=128, num_layers=1, num_heads=4)
17
+ device = torch.device("cpu")
18
+ model.to(device)
19
+
20
+ trainer = Trainer(
21
+ model=model,
22
+ train_loader=loader,
23
+ valid_loader=None,
24
+ save_dir="./",
25
+ lr=1e-3,
26
+ total_steps=10,
27
+ warmup_steps=2,
28
+ mixed_precision=False
29
+ )
30
+ trainer.fit(epochs=1)
31
+
32
+ self.assertTrue(True) # If training completes without errors
33
+
34
+ if __name__ == "__main__":
35
+ unittest.main()