sanjanatule commited on
Commit
4152812
·
1 Parent(s): ee9ceac

Create litmodelclass.py

Browse files
Files changed (1) hide show
  1. litmodelclass.py +138 -0
litmodelclass.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import datasets, transforms
2
+ import albumentations as Al
3
+ from albumentations.pytorch import ToTensorV2
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import pandas as pd
8
+ from torch.optim.lr_scheduler import OneCycleLR
9
+ from pytorch_lightning import LightningModule, Trainer, seed_everything
10
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
11
+ from pytorch_lightning.callbacks.progress import TQDMProgressBar
12
+ from pytorch_lightning.loggers import CSVLogger,TensorBoardLogger
13
+ from tqdm import tqdm
14
+ import torch
15
+ import torch.optim as optim
16
+ import matplotlib
17
+ import cv2
18
+
19
+ # my files
20
+ import utils
21
+ import config
22
+ from model import YOLOv3
23
+ from utils import (
24
+ mean_average_precision,
25
+ cells_to_bboxes,
26
+ get_evaluation_bboxes,
27
+ save_checkpoint,
28
+ load_checkpoint,
29
+ check_class_accuracy,
30
+ plot_couple_examples,
31
+ accuracy_fn,
32
+ get_loaders
33
+ )
34
+ from loss import YoloLoss
35
+
36
+
37
+ # custom functions for yolo
38
+ # loss function for yolov3
39
+ loss_fn = YoloLoss()
40
+
41
+ def model_criterion(out, y,anchors):
42
+ loss = ( loss_fn(out[0], y[0], anchors[0])
43
+ + loss_fn(out[1], y[1], anchors[1])
44
+ + loss_fn(out[2], y[2], anchors[2])
45
+ )
46
+ return loss
47
+
48
+
49
+ # accuracy function for yolov3
50
+ def accuracy_fn(y, out, threshold,correct_class, correct_obj,correct_noobj, tot_class_preds,tot_obj, tot_noobj):
51
+
52
+ for i in range(3):
53
+
54
+ obj = y[i][..., 0] == 1 # in paper this is Iobj_i
55
+ noobj = y[i][..., 0] == 0 # in paper this is Iobj_i
56
+
57
+ correct_class += torch.sum(
58
+ torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
59
+ )
60
+ tot_class_preds += torch.sum(obj)
61
+
62
+ obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
63
+ correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
64
+ tot_obj += torch.sum(obj)
65
+ correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
66
+ tot_noobj += torch.sum(noobj)
67
+
68
+ return((correct_class/(tot_class_preds+1e-16))*100,
69
+ (correct_noobj/(tot_noobj+1e-16))*100,
70
+ (correct_obj/(tot_obj+1e-16))*100)
71
+
72
+ # pytorch lightning
73
+ class LitYolo(LightningModule):
74
+ def __init__(self, num_classes=config.NUM_CLASSES, lr=1E-3,weight_decay=config.WEIGHT_DECAY,threshold=config.CONF_THRESHOLD):
75
+ super().__init__()
76
+
77
+ self.save_hyperparameters()
78
+ self.model = YOLOv3(num_classes=self.hparams.num_classes)
79
+ self.criterion = model_criterion
80
+ self.accuracy_fn = accuracy_fn
81
+ self.scaled_anchors = (torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
82
+ self.tot_class_preds, self.correct_class = 0, 0
83
+ self.tot_noobj, self.correct_noobj = 0, 0
84
+ self.tot_obj, self.correct_obj = 0, 0
85
+
86
+ def forward(self, x):
87
+ out = self.model(x)
88
+ return out
89
+
90
+ def training_step(self, batch, batch_idx):
91
+ x, y = batch
92
+ out = self(x)
93
+ loss = self.criterion(out,y,self.scaled_anchors)
94
+ acc = self.accuracy_fn(y,out,self.hparams.threshold,self.correct_class,
95
+ self.correct_obj,
96
+ self.correct_noobj,
97
+ self.tot_class_preds,
98
+ self.tot_obj,
99
+ self.tot_noobj)
100
+
101
+ self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
102
+ self.log_dict({"class_accuracy": acc[0], "no_object_accuracy": acc[1], "object_accuracy":acc[2]},prog_bar=True,on_step=False, on_epoch=True)
103
+ return loss
104
+
105
+
106
+ def evaluate(self, batch, stage=None):
107
+ x, y = batch
108
+ out = self(x)
109
+ loss = self.criterion(out,y,self.scaled_anchors)
110
+ acc = self.accuracy_fn(y,out,self.hparams.threshold,self.correct_class,
111
+ self.correct_obj,
112
+ self.correct_noobj,
113
+ self.tot_class_preds,
114
+ self.tot_obj,
115
+ self.tot_noobj)
116
+
117
+ if stage:
118
+ self.log(f"{stage}_loss", loss, prog_bar=True)
119
+ self.log_dict({"class_accuracy": acc[0], "no_object_accuracy": acc[1], "object_accuracy":acc[2]},prog_bar=True)
120
+
121
+ def test_step(self, batch, batch_idx):
122
+ self.evaluate(batch, "test")
123
+
124
+ def validation_step(self, batch, batch_idx):
125
+ self.evaluate(batch, "val")
126
+
127
+ def configure_optimizers(self):
128
+ optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
129
+ scheduler = OneCycleLR(
130
+ optimizer,
131
+ max_lr= 1E-3,
132
+ pct_start = 5/self.trainer.max_epochs,
133
+ epochs=self.trainer.max_epochs,
134
+ steps_per_epoch=len(train_loader),
135
+ div_factor=100,verbose=True,
136
+ three_phase=False
137
+ )
138
+ return ([optimizer],[scheduler])