Spaces:
Sleeping
Sleeping
| from torchvision import datasets, transforms | |
| import albumentations as Al | |
| from albumentations.pytorch import ToTensorV2 | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| from torch.optim.lr_scheduler import OneCycleLR | |
| from pytorch_lightning import LightningModule, Trainer, seed_everything | |
| from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint | |
| from pytorch_lightning.callbacks.progress import TQDMProgressBar | |
| from pytorch_lightning.loggers import CSVLogger,TensorBoardLogger | |
| from tqdm import tqdm | |
| import torch | |
| import torch.optim as optim | |
| import matplotlib | |
| import cv2 | |
| # my files | |
| import utils | |
| import config | |
| from model import YOLOv3 | |
| from utils import ( | |
| mean_average_precision, | |
| cells_to_bboxes, | |
| get_evaluation_bboxes, | |
| save_checkpoint, | |
| load_checkpoint, | |
| check_class_accuracy, | |
| plot_couple_examples, | |
| accuracy_fn, | |
| get_loaders | |
| ) | |
| from loss import YoloLoss | |
# Custom helpers for YOLOv3 training.
# One YoloLoss instance, reused for every prediction scale below.
loss_fn = YoloLoss()


def model_criterion(out, y, anchors):
    """Return the total YOLOv3 loss summed over the three prediction scales.

    Args:
        out: indexable per-scale predictions (entries 0, 1 and 2 are used).
        y: indexable per-scale targets, aligned with ``out``.
        anchors: indexable per-scale anchors, aligned with ``out``.

    Returns:
        ``loss_fn(out[0], y[0], anchors[0]) + loss_fn(out[1], ...) +
        loss_fn(out[2], ...)``, accumulated left to right.
    """
    per_scale = [loss_fn(out[k], y[k], anchors[k]) for k in range(3)]
    total = per_scale[0]
    for scale_loss in per_scale[1:]:
        total = total + scale_loss
    return total
# accuracy function for yolov3
def accuracy_fn(y, out, threshold, correct_class, correct_obj, correct_noobj,
                tot_class_preds, tot_obj, tot_noobj):
    """Compute class / no-object / object accuracy, in percent, over 3 scales.

    Args:
        y: per-scale targets (entries 0..2 are used). ``y[i][..., 0]`` is the
            objectness flag (1 = object, 0 = no object) and ``y[i][..., 5]``
            is the class index.
        out: per-scale raw predictions aligned with ``y``.
            ``out[i][..., 0]`` is the objectness logit and ``out[i][..., 5:]``
            holds the class scores.
        threshold: a cell is predicted to contain an object when
            ``sigmoid(out[i][..., 0]) > threshold``.
        correct_class, correct_obj, correct_noobj, tot_class_preds, tot_obj,
        tot_noobj: starting values for the running counts. They are added to
            with ``+=``. For plain ints (as ``LitYolo`` passes) this only
            rebinds the local names, so the caller's values are unchanged.

    Returns:
        Tuple ``(class_acc, no_object_acc, object_acc)`` in percent. Each
        entry is 0 when its denominator is 0.
    """
    # These are monitoring metrics only: skip autograd bookkeeping.
    with torch.no_grad():
        for i in range(3):
            target, pred = y[i], out[i]
            obj = target[..., 0] == 1    # I^obj_i in the paper
            noobj = target[..., 0] == 0  # I^noobj_i in the paper
            correct_class += torch.sum(
                torch.argmax(pred[..., 5:][obj], dim=-1) == target[..., 5][obj]
            )
            tot_class_preds += torch.sum(obj)
            obj_preds = torch.sigmoid(pred[..., 0]) > threshold
            correct_obj += torch.sum(obj_preds[obj] == target[..., 0][obj])
            tot_obj += torch.sum(obj)
            correct_noobj += torch.sum(obj_preds[noobj] == target[..., 0][noobj])
            tot_noobj += torch.sum(noobj)
    # The 1e-16 keeps an empty category from dividing by zero.
    return ((correct_class / (tot_class_preds + 1e-16)) * 100,
            (correct_noobj / (tot_noobj + 1e-16)) * 100,
            (correct_obj / (tot_obj + 1e-16)) * 100)
# pytorch lightning
class LitYolo(LightningModule):
    """LightningModule that trains YOLOv3 with the three-scale YOLO loss.

    Args:
        num_classes: number of object classes passed to ``YOLOv3``.
        lr: Adam learning rate and the peak (``max_lr``) of the one-cycle
            schedule.
        weight_decay: Adam weight decay.
        threshold: objectness probability threshold passed to ``accuracy_fn``.
    """

    def __init__(self, num_classes=config.NUM_CLASSES, lr=1E-3,
                 weight_decay=config.WEIGHT_DECAY,
                 threshold=config.CONF_THRESHOLD):
        super().__init__()
        self.save_hyperparameters()
        self.model = YOLOv3(num_classes=self.hparams.num_classes)
        self.criterion = model_criterion
        self.accuracy_fn = accuracy_fn
        # Scale each anchor set by its grid size S. It is registered as a
        # buffer so Lightning moves it to the model's device; a plain
        # attribute would stay on CPU and clash with GPU predictions inside
        # the loss. persistent=False keeps checkpoint state_dicts unchanged.
        scaled_anchors = torch.tensor(config.ANCHORS) * torch.tensor(
            config.S
        ).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        self.register_buffer("scaled_anchors", scaled_anchors, persistent=False)
        # Starting counts handed to accuracy_fn. They are ints, which
        # accuracy_fn only rebinds locally, so every call reports per-batch
        # accuracy and epoch values come from Lightning's on_epoch reduction.
        self.tot_class_preds, self.correct_class = 0, 0
        self.tot_noobj, self.correct_noobj = 0, 0
        self.tot_obj, self.correct_obj = 0, 0

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        """Run the model on ``batch``; return ``(loss, (cls, noobj, obj) acc)``."""
        x, y = batch
        out = self(x)
        loss = self.criterion(out, y, self.scaled_anchors)
        # Accuracy is only monitored, never back-propagated.
        with torch.no_grad():
            acc = self.accuracy_fn(y, out, self.hparams.threshold,
                                   self.correct_class,
                                   self.correct_obj,
                                   self.correct_noobj,
                                   self.tot_class_preds,
                                   self.tot_obj,
                                   self.tot_noobj)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log_dict({"class_accuracy": acc[0], "no_object_accuracy": acc[1], "object_accuracy": acc[2]},
                      prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracies and log them under ``stage`` if given."""
        loss, acc = self._shared_step(batch)
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            # Prefix with the stage so val/test accuracies do not collide with
            # the training metrics logged under the bare names.
            self.log_dict({f"{stage}_class_accuracy": acc[0],
                           f"{stage}_no_object_accuracy": acc[1],
                           f"{stage}_object_accuracy": acc[2]},
                          prog_bar=True)

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def configure_optimizers(self):
        """Adam + OneCycleLR, with the scheduler stepped on every optimizer step.

        Assumes ``trainer.max_epochs`` is a positive int (used to size the
        roughly 5-epoch warm-up).
        """
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr,
                               weight_decay=self.hparams.weight_decay)
        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.hparams.lr,
            # About 5 warm-up epochs. OneCycleLR rejects pct_start > 1, which
            # happens for runs shorter than 5 epochs, so cap it.
            pct_start=min(1.0, 5 / self.trainer.max_epochs),
            # Total optimizer steps for the run, computed by Lightning. This
            # accounts for gradient accumulation and batch limits.
            total_steps=self.trainer.estimated_stepping_batches,
            div_factor=100,
            three_phase=False,
        )
        # OneCycleLR is sized in optimizer steps, so it must be stepped per
        # batch. Lightning's default interval is "epoch".
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }