# mnist_model/train.py — MNIST training script (author: tchauffi, commit 116c743)
from typing import Any
import torch
import lightning as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
from lightning.pytorch.callbacks.model_summary import ModelSummary
from lightning.pytorch import seed_everything
from tqdm import tqdm
from mnist_model.model import ResNet
from mnist_model.database import get_dataset
# Seed every RNG for reproducible runs; workers=True also seeds the
# DataLoader worker processes so data order is deterministic too.
seed_everything(42, workers=True)
class MNISTModel(pl.LightningModule):
    """LightningModule wrapping a ResNet classifier for MNIST.

    Args:
        lr: Learning rate for the Adam optimizer (read back from
            ``self.hparams`` in :meth:`configure_optimizers`).
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        self.model = ResNet()
        # Stores `lr` into self.hparams and into checkpoints.
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def _accuracy(self, pred, y):
        # Fraction of samples whose argmax class matches the label.
        return torch.sum(torch.argmax(pred, dim=1) == y).item() / len(y)

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_accuracy', self._accuracy(pred, y), prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        self.log('test_loss', loss)
        # Fix: report test accuracy as well, mirroring validation_step —
        # previously only the loss was logged at test time.
        self.log('test_accuracy', self._accuracy(pred, y))
        return loss

    def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
        # Inference-only path: `batch` is assumed to be a tensor of images
        # with no labels attached — TODO confirm against the predict caller.
        return self(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
def train(split=0.2, batch_size=64, max_epochs=10):
    """Train an MNISTModel on the MNIST training set.

    Args:
        split: Fraction of the training data held out for validation.
        batch_size: Mini-batch size for both the train and val loaders.
        max_epochs: Upper bound on the number of training epochs
            (early stopping may end training sooner).

    Returns:
        Tuple ``(model, trainer)`` — the fitted module and its trainer.
    """
    model = MNISTModel()
    train_dataset, _ = get_dataset()
    train_size = int((1 - split) * len(train_dataset))
    train_dataset, val_dataset = torch.utils.data.random_split(
        train_dataset, [train_size, len(train_dataset) - train_size]
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, persistent_workers=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, persistent_workers=True,
    )
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[
            # Fix: monitor the validation loss, not the training loss —
            # stopping on train_loss can never detect overfitting, which is
            # the whole point of holding out the validation split above.
            EarlyStopping(monitor='val_loss'),
            LearningRateFinder(min_lr=1e-6, max_lr=1e-2),
            ModelSummary(),
        ],
        deterministic=True,
        num_sanity_val_steps=2,
    )
    trainer.fit(model, train_loader, val_loader)
    return model, trainer
def test(model, trainer, batch_size=64):
    """Evaluate *model* on the held-out MNIST test split via *trainer*."""
    _, test_dataset = get_dataset()
    loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    trainer.test(model, loader)
def save(model, trainer):
    """Write the training state to ``model.pt`` via Lightning checkpointing.

    Note: *model* is not read here — the trainer already holds a reference
    to the fitted module; the parameter is kept for call-site symmetry.
    """
    trainer.save_checkpoint("model.pt")
if __name__ == '__main__':
    # Full pipeline: fit on the train split, evaluate on the test split,
    # then checkpoint the result.
    fitted_model, fitted_trainer = train()
    test(fitted_model, fitted_trainer)
    save(fitted_model, fitted_trainer)