sanjanatule commited on
Commit
00f975e
·
1 Parent(s): 4152812

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -107
app.py CHANGED
@@ -38,112 +38,7 @@ from utils import (
38
  get_loaders
39
  )
40
  from loss import YoloLoss
41
-
42
-
43
- # custom functions for yolo
44
- # loss function for yolov3
45
- loss_fn = YoloLoss()
46
-
47
- def model_criterion(out, y,anchors):
48
- loss = ( loss_fn(out[0], y[0], anchors[0])
49
- + loss_fn(out[1], y[1], anchors[1])
50
- + loss_fn(out[2], y[2], anchors[2])
51
- )
52
- return loss
53
-
54
-
55
- # accuracy function for yolov3
56
- def accuracy_fn(y, out, threshold,correct_class, correct_obj,correct_noobj, tot_class_preds,tot_obj, tot_noobj):
57
-
58
- for i in range(3):
59
-
60
- obj = y[i][..., 0] == 1 # in paper this is Iobj_i
61
- noobj = y[i][..., 0] == 0 # in paper this is Iobj_i
62
-
63
- correct_class += torch.sum(
64
- torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
65
- )
66
- tot_class_preds += torch.sum(obj)
67
-
68
- obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
69
- correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
70
- tot_obj += torch.sum(obj)
71
- correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
72
- tot_noobj += torch.sum(noobj)
73
-
74
- return((correct_class/(tot_class_preds+1e-16))*100,
75
- (correct_noobj/(tot_noobj+1e-16))*100,
76
- (correct_obj/(tot_obj+1e-16))*100)
77
-
78
- # pytorch lightning
79
- class LitYolo(LightningModule):
80
- def __init__(self, num_classes=config.NUM_CLASSES, lr=1E-3,weight_decay=config.WEIGHT_DECAY,threshold=config.CONF_THRESHOLD):
81
- super().__init__()
82
-
83
- self.save_hyperparameters()
84
- self.model = YOLOv3(num_classes=self.hparams.num_classes)
85
- self.criterion = model_criterion
86
- self.accuracy_fn = accuracy_fn
87
- self.scaled_anchors = (torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
88
- self.tot_class_preds, self.correct_class = 0, 0
89
- self.tot_noobj, self.correct_noobj = 0, 0
90
- self.tot_obj, self.correct_obj = 0, 0
91
-
92
- def forward(self, x):
93
- out = self.model(x)
94
- return out
95
-
96
- def training_step(self, batch, batch_idx):
97
- x, y = batch
98
- out = self(x)
99
- loss = self.criterion(out,y,self.scaled_anchors)
100
- acc = self.accuracy_fn(y,out,self.hparams.threshold,self.correct_class,
101
- self.correct_obj,
102
- self.correct_noobj,
103
- self.tot_class_preds,
104
- self.tot_obj,
105
- self.tot_noobj)
106
-
107
- self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
108
- self.log_dict({"class_accuracy": acc[0], "no_object_accuracy": acc[1], "object_accuracy":acc[2]},prog_bar=True,on_step=False, on_epoch=True)
109
- return loss
110
-
111
-
112
- def evaluate(self, batch, stage=None):
113
- x, y = batch
114
- out = self(x)
115
- loss = self.criterion(out,y,self.scaled_anchors)
116
- acc = self.accuracy_fn(y,out,self.hparams.threshold,self.correct_class,
117
- self.correct_obj,
118
- self.correct_noobj,
119
- self.tot_class_preds,
120
- self.tot_obj,
121
- self.tot_noobj)
122
-
123
- if stage:
124
- self.log(f"{stage}_loss", loss, prog_bar=True)
125
- self.log_dict({"class_accuracy": acc[0], "no_object_accuracy": acc[1], "object_accuracy":acc[2]},prog_bar=True)
126
-
127
- def test_step(self, batch, batch_idx):
128
- self.evaluate(batch, "test")
129
-
130
- def validation_step(self, batch, batch_idx):
131
- self.evaluate(batch, "val")
132
-
133
- def configure_optimizers(self):
134
- optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
135
- scheduler = OneCycleLR(
136
- optimizer,
137
- max_lr= 1E-3,
138
- pct_start = 5/self.trainer.max_epochs,
139
- epochs=self.trainer.max_epochs,
140
- steps_per_epoch=len(train_loader),
141
- div_factor=100,verbose=True,
142
- three_phase=False
143
- )
144
- return ([optimizer],[scheduler])
145
-
146
-
147
 
148
 
149
  # gradio
@@ -176,7 +71,7 @@ with gr.Blocks() as demo:
176
  def yolo3_inference(input_img,gradcam=True,gradcam_opa=0.5): # function for yolo inference
177
 
178
  # load model
179
- yololit = LitYolo()
180
  inference_model = yololit.load_from_checkpoint("yolo3_model.ckpt")
181
 
182
  # bboxes, gradcam
 
38
  get_loaders
39
  )
40
  from loss import YoloLoss
41
+ import litmodelclass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
 
44
  # gradio
 
71
  def yolo3_inference(input_img,gradcam=True,gradcam_opa=0.5): # function for yolo inference
72
 
73
  # load model
74
+ yololit = litmodelclass.LitYolo()
75
  inference_model = yololit.load_from_checkpoint("yolo3_model.ckpt")
76
 
77
  # bboxes, gradcam