sanjanatule commited on
Commit
fe0548e
·
1 Parent(s): 23cccbd

Upload loss.py

Browse files
Files changed (1) hide show
  1. loss.py +80 -0
loss.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of Yolo Loss Function similar to the one in Yolov3 paper,
3
+ the difference from what I can tell is I use CrossEntropy for the classes
4
+ instead of BinaryCrossEntropy.
5
+ """
6
+ import random
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from utils import intersection_over_union
11
+
12
+
13
+ class YoloLoss(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+ self.mse = nn.MSELoss()
17
+ self.bce = nn.BCEWithLogitsLoss()
18
+ self.entropy = nn.CrossEntropyLoss()
19
+ self.sigmoid = nn.Sigmoid()
20
+
21
+ # Constants signifying how much to pay for each respective part of the loss
22
+ self.lambda_class = 1
23
+ self.lambda_noobj = 10
24
+ self.lambda_obj = 1
25
+ self.lambda_box = 10
26
+
27
+ def forward(self, predictions, target, anchors):
28
+ # Check where obj and noobj (we ignore if target == -1)
29
+ obj = target[..., 0] == 1 # in paper this is Iobj_i
30
+ noobj = target[..., 0] == 0 # in paper this is Inoobj_i
31
+
32
+ # ======================= #
33
+ # FOR NO OBJECT LOSS #
34
+ # ======================= #
35
+
36
+ no_object_loss = self.bce(
37
+ (predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]),
38
+ )
39
+
40
+ # ==================== #
41
+ # FOR OBJECT LOSS #
42
+ # ==================== #
43
+
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ anchors = anchors.reshape(1, 3, 1, 1, 2).to(device)
46
+ box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1)
47
+ ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
48
+ object_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj])
49
+
50
+ # ======================== #
51
+ # FOR BOX COORDINATES #
52
+ # ======================== #
53
+
54
+ predictions[..., 1:3] = self.sigmoid(predictions[..., 1:3]) # x,y coordinates
55
+ target[..., 3:5] = torch.log(
56
+ (1e-16 + target[..., 3:5] / anchors)
57
+ ) # width, height coordinates
58
+ box_loss = self.mse(predictions[..., 1:5][obj], target[..., 1:5][obj])
59
+
60
+ # ================== #
61
+ # FOR CLASS LOSS #
62
+ # ================== #
63
+
64
+ class_loss = self.entropy(
65
+ (predictions[..., 5:][obj]), (target[..., 5][obj].long()),
66
+ )
67
+
68
+ #print("__________________________________")
69
+ #print(self.lambda_box * box_loss)
70
+ #print(self.lambda_obj * object_loss)
71
+ #print(self.lambda_noobj * no_object_loss)
72
+ #print(self.lambda_class * class_loss)
73
+ #print("\n")
74
+
75
+ return (
76
+ self.lambda_box * box_loss
77
+ + self.lambda_obj * object_loss
78
+ + self.lambda_noobj * no_object_loss
79
+ + self.lambda_class * class_loss
80
+ )