| import lightning | |
def test_model_trainer_fit(multimodal_model, sample_train_val_datamodule):
    """Smoke-test the model end to end: one forward pass, then a short fit.

    The test makes no assertions on outputs; it passes when neither the
    forward call nor ``Trainer.fit`` raises.

    Args:
        multimodal_model: Model under test. It is called directly on a batch
            and then handed to ``Trainer.fit`` as ``model``.
        sample_train_val_datamodule: Data source. Its ``train_dataloader()``
            supplies the sample batch, and it is passed to ``Trainer.fit`` as
            ``datamodule``.
    """
    # Pull the first training batch to exercise the forward pass.
    batch = next(iter(sample_train_val_datamodule.train_dataloader()))

    # The output is deliberately not bound: only "does not raise" is checked.
    multimodal_model(batch)

    # fast_dev_run=True keeps the fit to a minimal debugging run; the CPU
    # accelerator is requested explicitly so no other hardware is needed.
    trainer = lightning.pytorch.Trainer(fast_dev_run=True, accelerator="cpu")
    trainer.fit(model=multimodal_model, datamodule=sample_train_val_datamodule)