File size: 546 Bytes
a5be142
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import lightning

def test_model_trainer_fit(multimodal_model, sample_train_val_datamodule):
    """Smoke-test a forward pass and a short end-to-end training run on CPU.

    Args:
        multimodal_model: Fixture providing the model under test. It is called
            directly on one batch and then handed to ``Trainer.fit``.
        sample_train_val_datamodule: Fixture providing the datamodule. Its
            ``train_dataloader()`` supplies the sample batch, and the
            datamodule itself is handed to ``Trainer.fit``.
    """
    # Pull a single batch from the training dataloader.
    batch = next(iter(sample_train_val_datamodule.train_dataloader()))

    # Forward pass outside the trainer: a failure here points at a
    # model/data mismatch rather than at the training loop.
    output = multimodal_model(batch)
    assert output is not None, "forward pass returned no output"

    # fast_dev_run=True makes the trainer run only a single batch per stage,
    # and accelerator="cpu" keeps the test independent of GPU availability.
    trainer = lightning.pytorch.Trainer(fast_dev_run=True, accelerator="cpu")
    trainer.fit(model=multimodal_model, datamodule=sample_train_val_datamodule)

    # Make the success condition explicit instead of relying only on
    # "fit() did not raise".
    assert trainer.state.finished, f"Training failed with {trainer.state}"