# OmniCoreX / tests/test_trainer.py
# Uploaded by Kosasih
# Commit 0ca20d6 (verified): "Create tests/test_trainer.py"
import tempfile
import unittest

import torch
from torch.utils.data import DataLoader, TensorDataset

from model import OmniCoreXModel
from trainer import Trainer
class TrainerTest(unittest.TestCase):
    """Smoke tests that drive ``Trainer`` with a tiny ``OmniCoreXModel`` on CPU."""

    def test_basic_training_loop(self):
        """Run a single epoch on a small random dataset.

        The test passes when ``Trainer.fit`` completes without raising and
        every model parameter still holds finite values afterwards.
        """
        # Fixed seed so the random inputs/labels (and any failure) are reproducible.
        torch.manual_seed(0)

        # Minimal random dataset: inputs of shape (10, 5, 128) and integer
        # labels in [0, 128) of shape (10, 5), served in batches of 2.
        inputs = torch.randn(10, 5, 128)
        labels = torch.randint(0, 128, (10, 5))
        loader = DataLoader(TensorDataset(inputs, labels), batch_size=2)

        stream_configs = {"dummy": 128}
        model = OmniCoreXModel(
            stream_configs, embed_dim=128, num_layers=1, num_heads=4
        )
        model.to(torch.device("cpu"))

        # Hand the trainer a throwaway directory instead of "./" so that
        # whatever it writes under ``save_dir`` (presumably checkpoints --
        # verify against Trainer) does not pollute the working directory and
        # is removed when the test finishes.
        with tempfile.TemporaryDirectory() as save_dir:
            trainer = Trainer(
                model=model,
                train_loader=loader,
                valid_loader=None,
                save_dir=save_dir,
                lr=1e-3,
                total_steps=10,
                warmup_steps=2,
                mixed_precision=False,
            )
            trainer.fit(epochs=1)

        # Reaching this point already means training completed without errors.
        # NOTE(review): the check below assumes OmniCoreXModel is a
        # torch.nn.Module (it is moved with ``.to(device)`` above) -- confirm.
        for name, param in model.named_parameters():
            self.assertTrue(
                torch.isfinite(param).all().item(),
                f"non-finite values in parameter {name!r} after training",
            )
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()