# NOTE: this file was recovered from a Hugging Face Spaces "Build error"
# page; the surrounding table markup has been stripped to restore plain Python.
| from typing import Any | |
| import torch | |
| import lightning as pl | |
| from lightning.pytorch.callbacks.early_stopping import EarlyStopping | |
| from lightning.pytorch.callbacks.lr_finder import LearningRateFinder | |
| from lightning.pytorch.callbacks.model_summary import ModelSummary | |
| from lightning.pytorch import seed_everything | |
| from tqdm import tqdm | |
| from mnist_model.model import ResNet | |
| from mnist_model.database import get_dataset | |
# Global determinism: seeds Python, NumPy and torch RNGs (workers=True also
# seeds DataLoader worker processes), so random_split and training are reproducible.
seed_everything(42, workers=True)
class MNISTModel(pl.LightningModule):
    """LightningModule wrapping a ResNet classifier trained with cross-entropy.

    Logs ``train_loss``, ``val_loss``/``val_accuracy`` and ``test_loss``.
    The learning rate is recorded via ``save_hyperparameters`` so it is
    stored in checkpoints and available as ``self.hparams.lr``.
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        self.model = ResNet()
        # Records ``lr`` into self.hparams (consumed by configure_optimizers).
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def _forward_loss(self, batch):
        """Shared per-batch work: forward pass + cross-entropy loss.

        Returns (pred, y, loss) so callers can derive extra metrics.
        """
        x, y = batch
        pred = self(x)
        loss = torch.nn.functional.cross_entropy(pred, y)
        return pred, y, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self._forward_loss(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        pred, y, loss = self._forward_loss(batch)
        self.log('val_loss', loss, prog_bar=True)
        # Fraction of correctly classified samples in this batch.
        accuracy = (torch.argmax(pred, dim=1) == y).float().mean().item()
        self.log('val_accuracy', accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        _, _, loss = self._forward_loss(batch)
        self.log('test_loss', loss)
        return loss

    def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
        # Inference only: ``batch`` is the raw input tensor, no labels.
        return self(batch)

    def configure_optimizers(self):
        # lr comes from the hyperparameters saved in __init__.
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
def train(split=0.2, batch_size=64):
    """Train ``MNISTModel`` on MNIST with a held-out validation split.

    Args:
        split: Fraction of the training set reserved for validation.
        batch_size: Batch size for both the train and validation loaders.

    Returns:
        Tuple ``(model, trainer)`` after fitting completes.
    """
    model = MNISTModel()
    train_dataset, _ = get_dataset()
    # Deterministic split thanks to the module-level seed_everything call.
    train_size = int((1 - split) * len(train_dataset))
    train_dataset, val_dataset = torch.utils.data.random_split(
        train_dataset, [train_size, len(train_dataset) - train_size]
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, persistent_workers=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, persistent_workers=True,
    )
    trainer = pl.Trainer(
        max_epochs=10,
        callbacks=[
            # BUG FIX: monitor validation loss, not training loss.
            # Early-stopping on 'train_loss' cannot detect overfitting and
            # defeats the purpose of the validation split; 'val_loss' is
            # logged once per epoch in validation_step.
            EarlyStopping(monitor='val_loss'),
            LearningRateFinder(1e-6, 1e-2),
            ModelSummary(),
        ],
        deterministic=True,
        num_sanity_val_steps=2,
    )
    trainer.fit(model, train_loader, val_loader)
    return model, trainer
def test(model, trainer, batch_size=64):
    """Evaluate ``model`` on the held-out MNIST test split via ``trainer``."""
    _, test_dataset = get_dataset()
    loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    trainer.test(model, loader)
def save(model, trainer, path="model.pt"):
    """Persist a full training checkpoint (model weights + trainer state).

    Args:
        model: Unused here — kept for signature symmetry with ``test``; the
            checkpoint is written from ``trainer``, which already holds the
            fitted model.
        trainer: The Trainer that fitted the model.
        path: Destination checkpoint file (default preserves the original
            hard-coded "model.pt").
    """
    trainer.save_checkpoint(path)
if __name__ == '__main__':
    # Full pipeline: fit on the train/validation split, evaluate on the
    # held-out test set, then write a checkpoint to model.pt.
    model, trainer = train()
    test(model, trainer)
    save(model, trainer)