Spaces:
Sleeping
Sleeping
| from typing import Any,List,Tuple,Dict | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from torchvision.utils import make_grid | |
| from torch.optim import Optimizer,Adam,SGD | |
| from lightning import LightningModule | |
| from torchmetrics import Accuracy,F1Score,AUROC,ConfusionMatrix | |
# Prefer the GPU whenever CUDA is usable; otherwise run on the CPU.
# NOTE(review): set_default_device is a process-wide, import-time side effect —
# every tensor factory call made after importing this module defaults to
# `device`. Lightning normally handles device placement itself; confirm this
# global default is really wanted.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
| from .mnist_model import Net | |
# Public names exported by `from <this module> import *`.
__all__: List[str] = [
    "LitMNISTModel",
]
class LitMNISTModel(LightningModule):
    """Lightning wrapper around ``Net`` for multiclass image classification.

    Tracks accuracy and F1 (torchmetrics, ``task="multiclass"``) separately for
    the train / val / test splits and logs them under the ``train/``, ``val/``
    and ``test/`` prefixes. Optimization uses Adam driven by a three-phase
    ``OneCycleLR`` schedule that is stepped after every optimizer step.

    Args:
        learning_rate: Starting learning rate of the one-cycle schedule. The
            peak rate is ``100 * learning_rate`` (``max_lr / div_factor`` with
            ``div_factor=100`` brings it back to ``learning_rate``).
        num_classes: Number of target classes used by the metrics.
        dropout_rate: Forwarded to ``Net`` through its ``config`` dict.
        bias: Forwarded to ``Net`` through its ``config`` dict.
        momentum: Stored on the module and saved as a hyperparameter; not used
            by the current Adam/OneCycleLR optimizer setup.
        *args, **kwargs: Accepted for compatibility; not used directly here.
    """

    def __init__(
        self,
        learning_rate: float = 3e-4,
        num_classes: int = 10,
        dropout_rate: float = 0.01,
        bias: bool = False,
        momentum: float = 0.9,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Records the constructor arguments in `self.hparams` (and checkpoints).
        self.save_hyperparameters()
        self.learning_rate: float = learning_rate
        self.num_class: int = num_classes
        self.momentum: float = momentum

        # One metric instance per split so their internal states never mix.
        self.train_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.train_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.val_f1 = F1Score(task="multiclass", num_classes=num_classes)
        self.test_f1 = F1Score(task="multiclass", num_classes=num_classes)

        self.model = Net(config={"dropout_rate": dropout_rate, "bias": bias})

    def forward(self, x) -> Any:
        """Run the wrapped ``Net`` on ``x`` and return its raw output."""
        return self.model(x)

    def _shared_step(
        self, x: torch.Tensor, y: torch.Tensor, accuracy: Any, f1_score: Any
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute ``(loss, logits, preds, acc, f1)`` for one labelled batch."""
        logits = self(x)
        # NOTE(review): F.nll_loss expects log-probabilities, so this assumes
        # Net's output already went through log_softmax — confirm in Net.
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        f1 = f1_score(preds, y)
        return loss, logits, preds, acc, f1

    def _log_images(self, tag: str, images: torch.Tensor) -> None:
        """Log ``images`` as one grid if the attached logger can store images."""
        experiment = getattr(self.logger, "experiment", None)
        # Only experiments exposing `add_image` (e.g. a TensorBoard writer) can
        # take images; with no logger or another backend, skip instead of crashing.
        if hasattr(experiment, "add_image"):
            experiment.add_image(tag, make_grid(images), self.current_epoch)

    def training_step(
        self, batch, batch_idx, *args: Any, **kwargs: Any
    ) -> Dict[str, torch.Tensor]:
        """Train on one ``(x, y)`` batch; log loss/acc/F1 and, on batch 0, images."""
        x, y = batch
        loss, logits, preds, acc, f1 = self._shared_step(
            x, y, self.train_accuracy, self.train_f1
        )
        self.log("train/loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log("train/acc", acc, prog_bar=True, on_epoch=False, on_step=True)
        self.log("train/train_f1", f1, prog_bar=True, on_epoch=False, on_step=True)
        if batch_idx == 0:
            self._log_images("train_imgs", x)
        return {"loss": loss, "logits": logits, "preds": preds}

    def validation_step(
        self, batch, batch_idx, *args: Any, **kwargs: Any
    ) -> Dict[str, torch.Tensor]:
        """Evaluate one ``(x, y)`` batch; log loss/acc/F1 and, on batch 0, images."""
        x, y = batch
        loss, logits, preds, acc, f1 = self._shared_step(
            x, y, self.val_accuracy, self.val_f1
        )
        self.log("val/loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log("val/acc", acc, prog_bar=True, on_epoch=True, on_step=True)
        self.log("val/val_f1", f1, prog_bar=True, on_epoch=True, on_step=False)
        if batch_idx == 0:
            self._log_images("val_imgs", x)
        return {"loss": loss, "logits": logits, "preds": preds}

    def predict_step(
        self, x: torch.Tensor, *args: Any, **kwargs: Any
    ) -> Dict[str, torch.Tensor]:
        """Return the top-class probability (``prob``) and index (``predict``).

        ``x`` may be an input tensor or an ``(inputs, targets)`` batch, in which
        case only the inputs are used.
        """
        if isinstance(x, (tuple, list)):
            x = x[0]
        with torch.no_grad():
            logits = self(x)
            probs, indices = torch.max(F.softmax(logits, dim=1), dim=1)
        return {"prob": probs, "predict": indices}

    def test_step(
        self, batch, batch_idx: int = 0, *args: Any, **kwargs: Any
    ) -> Dict[str, torch.Tensor]:
        """Evaluate one ``(x, y)`` test batch and log loss/acc/F1."""
        x, y = batch
        loss, logits, preds, acc, f1 = self._shared_step(
            x, y, self.test_accuracy, self.test_f1
        )
        self.log("test/loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log("test/acc", acc, prog_bar=True, on_epoch=True, on_step=True)
        self.log("test/test_f1", f1, prog_bar=True, on_epoch=True, on_step=False)
        return {"loss": loss, "logits": logits, "preds": preds}

    def configure_optimizers(self):
        """Adam plus a per-step, three-phase OneCycleLR schedule."""
        # OneCycleLR overwrites the optimizer's lr with max_lr / div_factor
        # (= self.learning_rate here); passing the same value keeps this honest.
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=1e2 * self.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.3,
            cycle_momentum=True,
            div_factor=100,
            final_div_factor=1e10,
            three_phase=True,
        )
        # total_steps counts optimizer steps, so the scheduler must be stepped
        # every batch. A bare scheduler defaults to interval="epoch" in
        # Lightning, which would leave the schedule stuck in warm-up.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]