import os
import shutil
import tempfile

import torch

from tests.utils.mnist import MnistModel, MnistModelConfig
from trainer import Trainer, TrainerArgs
# True when PyTorch reports an available CUDA device; evaluated once at import time.
is_cuda: bool = torch.cuda.is_available()
def test_train_mnist():
    """Train the MNIST toy model twice and check that the average train loss drops.

    ``Trainer.fit()`` is run twice on the same model instance. The
    ``avg_loss`` recorded after the second run must be strictly lower than
    after the first, i.e. training made progress.
    """
    # Write run artifacts to a private temporary directory instead of the
    # current working directory, so test runs do not litter the checkout.
    output_path = tempfile.mkdtemp(prefix="test_train_mnist_")
    try:
        model = MnistModel()
        trainer = Trainer(
            TrainerArgs(),
            MnistModelConfig(),
            model=model,
            output_path=output_path,
            gpu=0 if is_cuda else None,
        )

        trainer.fit()
        loss1 = trainer.keep_avg_train["avg_loss"]

        # The second fit reuses the same `model` object trained above.
        trainer.fit()
        loss2 = trainer.keep_avg_train["avg_loss"]

        assert loss1 > loss2, (
            f"avg_loss did not decrease: first fit {loss1}, second fit {loss2}"
        )
    finally:
        # ignore_errors: NOTE(review) the trainer may still hold log files
        # open in output_path (e.g. on Windows); cleanup is best-effort.
        shutil.rmtree(output_path, ignore_errors=True)