Spaces:
Sleeping
Sleeping
Commit
·
2f51281
1
Parent(s):
70a1c01
Added files
Browse files- README.md +14 -5
- YOLOv3.pth +3 -0
- app.py +195 -0
- batch_sampler.py +47 -0
- config.py +103 -0
- dataset.py +215 -0
- dataset_org.py +127 -0
- examples/1.jpg +0 -0
- examples/2.jpg +0 -0
- loss.py +79 -0
- model.py +218 -0
- requirements.txt +9 -0
- train.py +180 -0
- utils.py +588 -0
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 3.40.1
|
| 8 |
app_file: app.py
|
|
@@ -10,4 +10,13 @@ pinned: false
|
|
| 10 |
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Object Detection With Yolov3
|
| 3 |
+
emoji: 📈
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 3.40.1
|
| 8 |
app_file: app.py
|
|
|
|
| 10 |
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# YOLOv3 Object Detection App
|
| 14 |
+
Welcome to the YOLOv3 Object Detection App! This repository showcases an interactive application that combines the power of YOLOv3, a state-of-the-art object detection model, with the elegance of Gradio.
|
| 15 |
+
|
| 16 |
+
## What is YOLOv3?
|
| 17 |
+
YOLO (You Only Look Once) is an advanced object detection algorithm that stands out for its real-time performance. YOLOv3, the third iteration of YOLO, further refines its predecessor's accuracy and speed by leveraging a series of convolutional layers to predict bounding boxes and class probabilities.
|
| 18 |
+
|
| 19 |
+
## How does the App Work?
|
| 20 |
+
The YOLOv3 Object Detection App allows you to experience the magic of YOLOv3 firsthand. Simply upload an image, and watch as the app processes it through the YOLOv3 model to identify and highlight objects of interest. The app then presents you with a visual output, displaying the image with bounding boxes around the detected objects.
|
| 21 |
+
|
| 22 |
+
Link to Github: https://github.com/selvaraj-sembulingam/ERA-V1/tree/main/Assignments/S13
|
YOLOv3.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:14ad7d1fda29ed91ed955e38615e8b9c66a42e376ad115a6f5e1140f7aece657
|
| 3 |
+
size 250325919
|
app.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision import datasets, transforms
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from train import YOLOv3Lightning
|
| 8 |
+
from utils import non_max_suppression, plot_image, cells_to_bboxes
|
| 9 |
+
from dataset import YOLODataset
|
| 10 |
+
import config
|
| 11 |
+
import albumentations as A
|
| 12 |
+
from albumentations.pytorch import ToTensorV2
|
| 13 |
+
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import matplotlib.patches as patches
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Load the model
|
| 19 |
+
model = YOLOv3Lightning(config)
|
| 20 |
+
model.load_state_dict(torch.load('YOLOv3.pth', map_location=torch.device('cpu')), strict=False)
|
| 21 |
+
model.eval()
|
| 22 |
+
|
| 23 |
+
# Anchor
|
| 24 |
+
scaled_anchors = (
|
| 25 |
+
torch.tensor(config.ANCHORS)
|
| 26 |
+
* torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
|
| 27 |
+
).to("cpu")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
test_transforms = A.Compose(
|
| 31 |
+
[
|
| 32 |
+
A.LongestMaxSize(max_size=416),
|
| 33 |
+
A.PadIfNeeded(
|
| 34 |
+
min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
|
| 35 |
+
),
|
| 36 |
+
A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
|
| 37 |
+
ToTensorV2(),
|
| 38 |
+
]
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class GradCAM:
|
| 43 |
+
def __init__(self, model, target_layer):
|
| 44 |
+
self.model = model
|
| 45 |
+
self.target_layer = target_layer
|
| 46 |
+
self.gradients = None
|
| 47 |
+
|
| 48 |
+
self.model.eval()
|
| 49 |
+
self._register_hooks()
|
| 50 |
+
|
| 51 |
+
def _register_hooks(self):
|
| 52 |
+
def forward_hook(module, input, output):
|
| 53 |
+
self.feature_map = output
|
| 54 |
+
|
| 55 |
+
def backward_hook(module, grad_input, grad_output):
|
| 56 |
+
self.gradients = grad_output[0]
|
| 57 |
+
|
| 58 |
+
target_module = self.model
|
| 59 |
+
for name in self.target_layer.split("."):
|
| 60 |
+
target_module = target_module._modules[name]
|
| 61 |
+
|
| 62 |
+
target_module.register_forward_hook(forward_hook)
|
| 63 |
+
target_module.register_backward_hook(backward_hook)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _get_gradients_and_features(self, image):
|
| 67 |
+
self.model.zero_grad()
|
| 68 |
+
outputs = self.model(image)
|
| 69 |
+
|
| 70 |
+
gradients_list = []
|
| 71 |
+
for output in outputs:
|
| 72 |
+
self.gradients = None # Reset gradients
|
| 73 |
+
output.backward(gradient=output, retain_graph=True)
|
| 74 |
+
gradients_list.append(self.gradients)
|
| 75 |
+
|
| 76 |
+
return gradients_list, self.feature_map
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def generate_heatmap(self, image):
|
| 80 |
+
gradients_list, feature_map = self._get_gradients_and_features(image)
|
| 81 |
+
|
| 82 |
+
for gradients, fmap in zip(gradients_list, feature_map):
|
| 83 |
+
if gradients is not None:
|
| 84 |
+
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
|
| 85 |
+
for i in range(len(pooled_gradients)):
|
| 86 |
+
fmap[:, i, :, :] *= pooled_gradients[i]
|
| 87 |
+
|
| 88 |
+
heatmap = torch.mean(feature_map, dim=1).squeeze().detach().numpy()
|
| 89 |
+
heatmap = np.maximum(heatmap, 0)
|
| 90 |
+
heatmap /= np.max(heatmap)
|
| 91 |
+
|
| 92 |
+
return heatmap
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def plot_image(image, boxes):
|
| 96 |
+
"""Plots predicted bounding boxes on the image"""
|
| 97 |
+
cmap = plt.get_cmap("tab20b")
|
| 98 |
+
class_labels = config.PASCAL_CLASSES
|
| 99 |
+
colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
|
| 100 |
+
im = np.array(image)
|
| 101 |
+
height, width, _ = im.shape
|
| 102 |
+
|
| 103 |
+
# Create figure and axes
|
| 104 |
+
fig, ax = plt.subplots(1)
|
| 105 |
+
# Display the image
|
| 106 |
+
ax.imshow(im)
|
| 107 |
+
|
| 108 |
+
# box[0] is x midpoint, box[2] is width
|
| 109 |
+
# box[1] is y midpoint, box[3] is height
|
| 110 |
+
|
| 111 |
+
# Create a Rectangle patch
|
| 112 |
+
for box in boxes:
|
| 113 |
+
assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
|
| 114 |
+
class_pred = box[0]
|
| 115 |
+
box = box[2:]
|
| 116 |
+
upper_left_x = box[0] - box[2] / 2
|
| 117 |
+
upper_left_y = box[1] - box[3] / 2
|
| 118 |
+
rect = patches.Rectangle(
|
| 119 |
+
(upper_left_x * width, upper_left_y * height),
|
| 120 |
+
box[2] * width,
|
| 121 |
+
box[3] * height,
|
| 122 |
+
linewidth=2,
|
| 123 |
+
edgecolor=colors[int(class_pred)],
|
| 124 |
+
facecolor="none",
|
| 125 |
+
)
|
| 126 |
+
# Add the patch to the Axes
|
| 127 |
+
ax.add_patch(rect)
|
| 128 |
+
plt.text(
|
| 129 |
+
upper_left_x * width,
|
| 130 |
+
upper_left_y * height,
|
| 131 |
+
s=class_labels[int(class_pred)],
|
| 132 |
+
color="white",
|
| 133 |
+
verticalalignment="top",
|
| 134 |
+
bbox={"color": colors[int(class_pred)], "pad": 0},
|
| 135 |
+
)
|
| 136 |
+
ax.axis('off')
|
| 137 |
+
plt.savefig('inference.png', bbox_inches='tight', pad_inches=0)
|
| 138 |
+
|
| 139 |
+
# Inference function
|
| 140 |
+
def inference(inp_image):
|
| 141 |
+
org_image = inp_image
|
| 142 |
+
transform = test_transforms
|
| 143 |
+
x = transform(image=inp_image)["image"].unsqueeze(0)
|
| 144 |
+
out = model(x)
|
| 145 |
+
|
| 146 |
+
bboxes = [[] for _ in range(x.shape[0])]
|
| 147 |
+
|
| 148 |
+
for i in range(3):
|
| 149 |
+
batch_size, A, S, _, _ = out[i].shape
|
| 150 |
+
anchor = scaled_anchors[i]
|
| 151 |
+
boxes_scale_i = cells_to_bboxes(
|
| 152 |
+
out[i], anchor, S=S, is_preds=True
|
| 153 |
+
)
|
| 154 |
+
for idx, (box) in enumerate(boxes_scale_i):
|
| 155 |
+
bboxes[idx] += box
|
| 156 |
+
|
| 157 |
+
nms_boxes = non_max_suppression(
|
| 158 |
+
bboxes[0], iou_threshold=0.5, threshold=0.6, box_format="midpoint",
|
| 159 |
+
)
|
| 160 |
+
plot_image(cv2.resize(org_image,(416,416)), nms_boxes)
|
| 161 |
+
plotted_img = 'inference.png'
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# GradCAM
|
| 165 |
+
grad_cam = GradCAM(model, target_layer="model.layers.27.layers.0")
|
| 166 |
+
image = cv2.cvtColor(org_image, cv2.COLOR_BGR2RGB)
|
| 167 |
+
image = cv2.resize(image, (416, 416))
|
| 168 |
+
image = image.transpose(2, 0, 1)
|
| 169 |
+
image = torch.from_numpy(image).unsqueeze(0).float() / 255.0
|
| 170 |
+
heatmap = grad_cam.generate_heatmap(image)
|
| 171 |
+
heatmap = cv2.resize(heatmap, (image.shape[3], image.shape[2]))
|
| 172 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
|
| 173 |
+
overlay = heatmap * 0.4 + (image.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) * 0.6
|
| 174 |
+
overlay = np.clip(overlay, 0, 255).astype(np.uint8)
|
| 175 |
+
overlay_bgr = cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)
|
| 176 |
+
output_path = "gradcam.png"
|
| 177 |
+
plt.imshow(overlay_bgr)
|
| 178 |
+
plt.axis('off')
|
| 179 |
+
plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
|
| 180 |
+
plt.close()
|
| 181 |
+
gradcam_img = 'gradcam.png'
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
return plotted_img, gradcam_img
|
| 185 |
+
|
| 186 |
+
inputs = gr.inputs.Image(label="Original Image")
|
| 187 |
+
outputs = gr.outputs.Image(type="pil",label="Output Image")
|
| 188 |
+
gradcam = gr.outputs.Image(type="pil",label="GradCAM Image")
|
| 189 |
+
title = "YOLOv3 trained on PASCAL VOC"
|
| 190 |
+
description = """YOLOv3 Gradio demo for object detection
|
| 191 |
+
- Classes supported = aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, tvmonitor
|
| 192 |
+
"""
|
| 193 |
+
examples = [['examples/1.jpg'], ['examples/2.jpg']]
|
| 194 |
+
gr.Interface(inference, inputs, [outputs,gradcam], title=title, examples=examples, description=description, theme='abidlabs/dracula_revamped').launch(
|
| 195 |
+
debug=True)
|
batch_sampler.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import Sampler,RandomSampler,SequentialSampler
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class BatchSampler(object):
|
| 5 |
+
def __init__(self, sampler, batch_size, drop_last,multiscale_step=None,img_sizes = None):
|
| 6 |
+
if not isinstance(sampler, Sampler):
|
| 7 |
+
raise ValueError("sampler should be an instance of "
|
| 8 |
+
"torch.utils.data.Sampler, but got sampler={}"
|
| 9 |
+
.format(sampler))
|
| 10 |
+
if not isinstance(drop_last, bool):
|
| 11 |
+
raise ValueError("drop_last should be a boolean value, but got "
|
| 12 |
+
"drop_last={}".format(drop_last))
|
| 13 |
+
self.sampler = sampler
|
| 14 |
+
self.batch_size = batch_size
|
| 15 |
+
self.drop_last = drop_last
|
| 16 |
+
if multiscale_step is not None and multiscale_step < 1 :
|
| 17 |
+
raise ValueError("multiscale_step should be > 0, but got "
|
| 18 |
+
"multiscale_step={}".format(multiscale_step))
|
| 19 |
+
if multiscale_step is not None and img_sizes is None:
|
| 20 |
+
raise ValueError("img_sizes must a list, but got img_sizes={} ".format(img_sizes))
|
| 21 |
+
|
| 22 |
+
self.multiscale_step = multiscale_step
|
| 23 |
+
self.img_sizes = img_sizes
|
| 24 |
+
|
| 25 |
+
def __iter__(self):
|
| 26 |
+
num_batch = 0
|
| 27 |
+
batch = []
|
| 28 |
+
size = 416
|
| 29 |
+
for idx in self.sampler:
|
| 30 |
+
batch.append([idx,size])
|
| 31 |
+
if len(batch) == self.batch_size:
|
| 32 |
+
# print("Batch size reached:", batch)
|
| 33 |
+
yield batch
|
| 34 |
+
num_batch+=1
|
| 35 |
+
batch = []
|
| 36 |
+
if self.multiscale_step and num_batch % self.multiscale_step == 0 :
|
| 37 |
+
size = np.random.choice(self.img_sizes)
|
| 38 |
+
# print("Changing image size:", size)
|
| 39 |
+
if len(batch) > 0 and not self.drop_last:
|
| 40 |
+
# print("Last batch:", batch)
|
| 41 |
+
yield batch
|
| 42 |
+
|
| 43 |
+
def __len__(self):
|
| 44 |
+
if self.drop_last:
|
| 45 |
+
return len(self.sampler) // self.batch_size
|
| 46 |
+
else:
|
| 47 |
+
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
config.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import albumentations as A
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from albumentations.pytorch import ToTensorV2
|
| 7 |
+
from utils import seed_everything
|
| 8 |
+
|
| 9 |
+
DATASET = '/storage/PASCAL_VOC'
|
| 10 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
+
seed_everything() # If you want deterministic behavior
|
| 12 |
+
NUM_WORKERS = os.cpu_count()-1
|
| 13 |
+
BATCH_SIZE = 16
|
| 14 |
+
IMAGE_SIZE = 416
|
| 15 |
+
NUM_CLASSES = 20
|
| 16 |
+
LEARNING_RATE = 1e-5
|
| 17 |
+
WEIGHT_DECAY = 1e-4
|
| 18 |
+
NUM_EPOCHS = 40
|
| 19 |
+
CONF_THRESHOLD = 0.05
|
| 20 |
+
MAP_IOU_THRESH = 0.5
|
| 21 |
+
NMS_IOU_THRESH = 0.45
|
| 22 |
+
S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
|
| 23 |
+
PIN_MEMORY = True
|
| 24 |
+
LOAD_MODEL = False
|
| 25 |
+
SAVE_MODEL = True
|
| 26 |
+
CHECKPOINT_FILE = "checkpoint.pth.tar"
|
| 27 |
+
IMG_DIR = DATASET + "/images/"
|
| 28 |
+
LABEL_DIR = DATASET + "/labels/"
|
| 29 |
+
|
| 30 |
+
ANCHORS = [
|
| 31 |
+
[(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
|
| 32 |
+
[(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
|
| 33 |
+
[(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
|
| 34 |
+
] # Note these have been rescaled to be between [0, 1]
|
| 35 |
+
|
| 36 |
+
means = [0.485, 0.456, 0.406]
|
| 37 |
+
|
| 38 |
+
scale = 1.1
|
| 39 |
+
train_transforms = A.Compose(
|
| 40 |
+
[
|
| 41 |
+
A.LongestMaxSize(max_size=int(IMAGE_SIZE * scale)),
|
| 42 |
+
A.PadIfNeeded(
|
| 43 |
+
min_height=int(IMAGE_SIZE * scale),
|
| 44 |
+
min_width=int(IMAGE_SIZE * scale),
|
| 45 |
+
border_mode=cv2.BORDER_CONSTANT,
|
| 46 |
+
),
|
| 47 |
+
A.Rotate(limit = 10, interpolation=1, border_mode=4),
|
| 48 |
+
A.RandomCrop(width=IMAGE_SIZE, height=IMAGE_SIZE),
|
| 49 |
+
A.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
|
| 50 |
+
A.OneOf(
|
| 51 |
+
[
|
| 52 |
+
A.ShiftScaleRotate(
|
| 53 |
+
rotate_limit=20, p=0.5, border_mode=cv2.BORDER_CONSTANT
|
| 54 |
+
),
|
| 55 |
+
# A.Affine(shear=15, p=0.5, mode="constant"),
|
| 56 |
+
],
|
| 57 |
+
p=1.0,
|
| 58 |
+
),
|
| 59 |
+
A.HorizontalFlip(p=0.5),
|
| 60 |
+
A.Blur(p=0.1),
|
| 61 |
+
A.CLAHE(p=0.1),
|
| 62 |
+
A.Posterize(p=0.1),
|
| 63 |
+
A.ToGray(p=0.1),
|
| 64 |
+
A.ChannelShuffle(p=0.05),
|
| 65 |
+
A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
|
| 66 |
+
ToTensorV2(),
|
| 67 |
+
],
|
| 68 |
+
bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[],),
|
| 69 |
+
)
|
| 70 |
+
test_transforms = A.Compose(
|
| 71 |
+
[
|
| 72 |
+
A.LongestMaxSize(max_size=IMAGE_SIZE),
|
| 73 |
+
A.PadIfNeeded(
|
| 74 |
+
min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
|
| 75 |
+
),
|
| 76 |
+
A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
|
| 77 |
+
ToTensorV2(),
|
| 78 |
+
],
|
| 79 |
+
bbox_params=A.BboxParams(format="yolo", min_visibility=0.4, label_fields=[]),
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
PASCAL_CLASSES = [
|
| 83 |
+
"aeroplane",
|
| 84 |
+
"bicycle",
|
| 85 |
+
"bird",
|
| 86 |
+
"boat",
|
| 87 |
+
"bottle",
|
| 88 |
+
"bus",
|
| 89 |
+
"car",
|
| 90 |
+
"cat",
|
| 91 |
+
"chair",
|
| 92 |
+
"cow",
|
| 93 |
+
"diningtable",
|
| 94 |
+
"dog",
|
| 95 |
+
"horse",
|
| 96 |
+
"motorbike",
|
| 97 |
+
"person",
|
| 98 |
+
"pottedplant",
|
| 99 |
+
"sheep",
|
| 100 |
+
"sofa",
|
| 101 |
+
"train",
|
| 102 |
+
"tvmonitor"
|
| 103 |
+
]
|
dataset.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Creates a Pytorch dataset to load the Pascal VOC & MS COCO datasets
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import config
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
from utils import xywhn2xyxy, xyxy2xywhn
|
| 11 |
+
import random
|
| 12 |
+
import torchvision.transforms as transforms
|
| 13 |
+
from batch_sampler import BatchSampler,RandomSampler,SequentialSampler
|
| 14 |
+
|
| 15 |
+
from PIL import Image, ImageFile
|
| 16 |
+
from torch.utils.data import Dataset, DataLoader
|
| 17 |
+
from utils import (
|
| 18 |
+
cells_to_bboxes,
|
| 19 |
+
iou_width_height as iou,
|
| 20 |
+
non_max_suppression as nms,
|
| 21 |
+
plot_image
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 25 |
+
|
| 26 |
+
class YOLODataset(Dataset):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
csv_file,
|
| 30 |
+
img_dir,
|
| 31 |
+
label_dir,
|
| 32 |
+
anchors,
|
| 33 |
+
image_size=416,
|
| 34 |
+
S=[13, 26, 52],
|
| 35 |
+
C=20,
|
| 36 |
+
transform=None,
|
| 37 |
+
):
|
| 38 |
+
self.annotations = pd.read_csv(csv_file)
|
| 39 |
+
self.img_dir = img_dir
|
| 40 |
+
self.label_dir = label_dir
|
| 41 |
+
self.image_size = image_size
|
| 42 |
+
self.mosaic_border = [image_size // 2, image_size // 2]
|
| 43 |
+
self.transform = transform
|
| 44 |
+
self.S = S
|
| 45 |
+
self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2]) # for all 3 scales
|
| 46 |
+
self.num_anchors = self.anchors.shape[0]
|
| 47 |
+
self.num_anchors_per_scale = self.num_anchors // 3
|
| 48 |
+
self.C = C
|
| 49 |
+
self.ignore_iou_thresh = 0.5
|
| 50 |
+
|
| 51 |
+
def __len__(self):
|
| 52 |
+
return len(self.annotations)
|
| 53 |
+
|
| 54 |
+
def load_mosaic(self, index):
|
| 55 |
+
# YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
|
| 56 |
+
labels4 = []
|
| 57 |
+
s = self.image_size
|
| 58 |
+
yc, xc = (int(random.uniform(x, 2 * s - x)) for x in self.mosaic_border) # mosaic center x, y
|
| 59 |
+
indices = [index] + random.choices(range(len(self)), k=3) # 3 additional image indices
|
| 60 |
+
random.shuffle(indices)
|
| 61 |
+
for i, index in enumerate(indices):
|
| 62 |
+
# Load image
|
| 63 |
+
label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
|
| 64 |
+
bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1).tolist()
|
| 65 |
+
img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
|
| 66 |
+
img = np.array(Image.open(img_path).convert("RGB"))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
h, w = img.shape[0], img.shape[1]
|
| 70 |
+
labels = np.array(bboxes)
|
| 71 |
+
|
| 72 |
+
# place img in img4
|
| 73 |
+
if i == 0: # top left
|
| 74 |
+
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
|
| 75 |
+
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
|
| 76 |
+
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
|
| 77 |
+
elif i == 1: # top right
|
| 78 |
+
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
|
| 79 |
+
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
|
| 80 |
+
elif i == 2: # bottom left
|
| 81 |
+
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
|
| 82 |
+
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
|
| 83 |
+
elif i == 3: # bottom right
|
| 84 |
+
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
|
| 85 |
+
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
|
| 86 |
+
|
| 87 |
+
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
|
| 88 |
+
padw = x1a - x1b
|
| 89 |
+
padh = y1a - y1b
|
| 90 |
+
|
| 91 |
+
# Labels
|
| 92 |
+
if labels.size:
|
| 93 |
+
labels[:, :-1] = xywhn2xyxy(labels[:, :-1], w, h, padw, padh) # normalized xywh to pixel xyxy format
|
| 94 |
+
labels4.append(labels)
|
| 95 |
+
|
| 96 |
+
# Concat/clip labels
|
| 97 |
+
labels4 = np.concatenate(labels4, 0)
|
| 98 |
+
for x in (labels4[:, :-1],):
|
| 99 |
+
np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
|
| 100 |
+
# img4, labels4 = replicate(img4, labels4) # replicate
|
| 101 |
+
labels4[:, :-1] = xyxy2xywhn(labels4[:, :-1], 2 * s, 2 * s)
|
| 102 |
+
labels4[:, :-1] = np.clip(labels4[:, :-1], 0, 1)
|
| 103 |
+
labels4 = labels4[labels4[:, 2] > 0]
|
| 104 |
+
labels4 = labels4[labels4[:, 3] > 0]
|
| 105 |
+
return img4, labels4
|
| 106 |
+
|
| 107 |
+
def resize(self, img, size):
|
| 108 |
+
# Image resizing for Multi-resolution training
|
| 109 |
+
transform = transforms.Resize((size, size))
|
| 110 |
+
img = transform(img)
|
| 111 |
+
return img
|
| 112 |
+
|
| 113 |
+
def __getitem__(self, index):
|
| 114 |
+
sizee = None
|
| 115 |
+
if isinstance(index, list):
|
| 116 |
+
sizee = index[1]
|
| 117 |
+
index = index[0]
|
| 118 |
+
|
| 119 |
+
# apply mosaic 50% of the times
|
| 120 |
+
if random.random() >= 0.5:
|
| 121 |
+
image, bboxes = self.load_mosaic(index)
|
| 122 |
+
|
| 123 |
+
else:
|
| 124 |
+
label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
|
| 125 |
+
bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1).tolist()
|
| 126 |
+
img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
|
| 127 |
+
image = np.array(Image.open(img_path).convert("RGB"))
|
| 128 |
+
|
| 129 |
+
if self.transform:
|
| 130 |
+
augmentations = self.transform(image=image, bboxes=bboxes)
|
| 131 |
+
image = augmentations["image"]
|
| 132 |
+
bboxes = augmentations["bboxes"]
|
| 133 |
+
|
| 134 |
+
if sizee:
|
| 135 |
+
image = self.resize(image, sizee)
|
| 136 |
+
|
| 137 |
+
# Below assumes 3 scale predictions (as paper) and same num of anchors per scale
|
| 138 |
+
targets = [torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]
|
| 139 |
+
for box in bboxes:
|
| 140 |
+
iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)
|
| 141 |
+
anchor_indices = iou_anchors.argsort(descending=True, dim=0)
|
| 142 |
+
x, y, width, height, class_label = box
|
| 143 |
+
has_anchor = [False] * 3 # each scale should have one anchor
|
| 144 |
+
for anchor_idx in anchor_indices:
|
| 145 |
+
scale_idx = anchor_idx // self.num_anchors_per_scale
|
| 146 |
+
anchor_on_scale = anchor_idx % self.num_anchors_per_scale
|
| 147 |
+
S = self.S[scale_idx]
|
| 148 |
+
i, j = int(S * y), int(S * x) # which cell
|
| 149 |
+
anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
|
| 150 |
+
if not anchor_taken and not has_anchor[scale_idx]:
|
| 151 |
+
targets[scale_idx][anchor_on_scale, i, j, 0] = 1
|
| 152 |
+
x_cell, y_cell = S * x - j, S * y - i # both between [0,1]
|
| 153 |
+
width_cell, height_cell = (
|
| 154 |
+
width * S,
|
| 155 |
+
height * S,
|
| 156 |
+
) # can be greater than 1 since it's relative to cell
|
| 157 |
+
box_coordinates = torch.tensor(
|
| 158 |
+
[x_cell, y_cell, width_cell, height_cell]
|
| 159 |
+
)
|
| 160 |
+
targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
|
| 161 |
+
targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
|
| 162 |
+
has_anchor[scale_idx] = True
|
| 163 |
+
|
| 164 |
+
elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
|
| 165 |
+
targets[scale_idx][anchor_on_scale, i, j, 0] = -1 # ignore prediction
|
| 166 |
+
|
| 167 |
+
return image, tuple(targets)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def test():
|
| 171 |
+
anchors = config.ANCHORS
|
| 172 |
+
|
| 173 |
+
transform = config.test_transforms
|
| 174 |
+
|
| 175 |
+
dataset = YOLODataset(
|
| 176 |
+
config.DATASET + "/train.csv",
|
| 177 |
+
config.DATASET + "/images/",
|
| 178 |
+
config.DATASET + "/labels/",
|
| 179 |
+
S=[13, 26, 52],
|
| 180 |
+
anchors=anchors,
|
| 181 |
+
transform=transform,
|
| 182 |
+
)
|
| 183 |
+
S = [13, 26, 52]
|
| 184 |
+
scaled_anchors = torch.tensor(anchors) / (
|
| 185 |
+
1 / torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
|
| 186 |
+
)
|
| 187 |
+
# loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
|
| 188 |
+
|
| 189 |
+
loader = DataLoader(dataset=dataset,
|
| 190 |
+
batch_sampler= BatchSampler(SequentialSampler(dataset),
|
| 191 |
+
batch_size=1,
|
| 192 |
+
drop_last=True,
|
| 193 |
+
multiscale_step=1,
|
| 194 |
+
img_sizes=list(range(320, 608 + 1, 32))
|
| 195 |
+
),
|
| 196 |
+
# num_workers=4
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
for x, y in loader:
|
| 200 |
+
boxes = []
|
| 201 |
+
|
| 202 |
+
for i in range(y[0].shape[1]):
|
| 203 |
+
anchor = scaled_anchors[i]
|
| 204 |
+
print(anchor.shape)
|
| 205 |
+
print(y[i].shape)
|
| 206 |
+
boxes += cells_to_bboxes(
|
| 207 |
+
y[i], is_preds=False, S=y[i].shape[2], anchors=anchor
|
| 208 |
+
)[0]
|
| 209 |
+
boxes = nms(boxes, iou_threshold=1, threshold=0.7, box_format="midpoint")
|
| 210 |
+
print(boxes)
|
| 211 |
+
plot_image(x[0].permute(1, 2, 0).to("cpu"), boxes)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
if __name__ == "__main__":
|
| 215 |
+
test()
|
dataset_org.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Creates a Pytorch dataset to load the Pascal VOC & MS COCO datasets
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import config
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from PIL import Image, ImageFile
|
| 12 |
+
from torch.utils.data import Dataset, DataLoader
|
| 13 |
+
from utils import (
|
| 14 |
+
cells_to_bboxes,
|
| 15 |
+
iou_width_height as iou,
|
| 16 |
+
non_max_suppression as nms,
|
| 17 |
+
plot_image
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 21 |
+
|
| 22 |
+
class YOLODataset(Dataset):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
csv_file,
|
| 26 |
+
img_dir,
|
| 27 |
+
label_dir,
|
| 28 |
+
anchors,
|
| 29 |
+
image_size=416,
|
| 30 |
+
S=[13, 26, 52],
|
| 31 |
+
C=20,
|
| 32 |
+
transform=None,
|
| 33 |
+
):
|
| 34 |
+
self.annotations = pd.read_csv(csv_file)
|
| 35 |
+
self.img_dir = img_dir
|
| 36 |
+
self.label_dir = label_dir
|
| 37 |
+
self.image_size = image_size
|
| 38 |
+
self.transform = transform
|
| 39 |
+
self.S = S
|
| 40 |
+
self.anchors = torch.tensor(anchors[0] + anchors[1] + anchors[2]) # for all 3 scales
|
| 41 |
+
self.num_anchors = self.anchors.shape[0]
|
| 42 |
+
self.num_anchors_per_scale = self.num_anchors // 3
|
| 43 |
+
self.C = C
|
| 44 |
+
self.ignore_iou_thresh = 0.5
|
| 45 |
+
|
| 46 |
+
def __len__(self):
|
| 47 |
+
return len(self.annotations)
|
| 48 |
+
|
| 49 |
+
def __getitem__(self, index):
|
| 50 |
+
label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
|
| 51 |
+
bboxes = np.roll(np.loadtxt(fname=label_path, delimiter=" ", ndmin=2), 4, axis=1).tolist()
|
| 52 |
+
img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
|
| 53 |
+
image = np.array(Image.open(img_path).convert("RGB"))
|
| 54 |
+
|
| 55 |
+
if self.transform:
|
| 56 |
+
augmentations = self.transform(image=image, bboxes=bboxes)
|
| 57 |
+
image = augmentations["image"]
|
| 58 |
+
bboxes = augmentations["bboxes"]
|
| 59 |
+
|
| 60 |
+
# Below assumes 3 scale predictions (as paper) and same num of anchors per scale
|
| 61 |
+
targets = [torch.zeros((self.num_anchors // 3, S, S, 6)) for S in self.S]
|
| 62 |
+
for box in bboxes:
|
| 63 |
+
iou_anchors = iou(torch.tensor(box[2:4]), self.anchors)
|
| 64 |
+
anchor_indices = iou_anchors.argsort(descending=True, dim=0)
|
| 65 |
+
x, y, width, height, class_label = box
|
| 66 |
+
has_anchor = [False] * 3 # each scale should have one anchor
|
| 67 |
+
for anchor_idx in anchor_indices:
|
| 68 |
+
scale_idx = anchor_idx // self.num_anchors_per_scale
|
| 69 |
+
anchor_on_scale = anchor_idx % self.num_anchors_per_scale
|
| 70 |
+
S = self.S[scale_idx]
|
| 71 |
+
i, j = int(S * y), int(S * x) # which cell
|
| 72 |
+
anchor_taken = targets[scale_idx][anchor_on_scale, i, j, 0]
|
| 73 |
+
if not anchor_taken and not has_anchor[scale_idx]:
|
| 74 |
+
targets[scale_idx][anchor_on_scale, i, j, 0] = 1
|
| 75 |
+
x_cell, y_cell = S * x - j, S * y - i # both between [0,1]
|
| 76 |
+
width_cell, height_cell = (
|
| 77 |
+
width * S,
|
| 78 |
+
height * S,
|
| 79 |
+
) # can be greater than 1 since it's relative to cell
|
| 80 |
+
box_coordinates = torch.tensor(
|
| 81 |
+
[x_cell, y_cell, width_cell, height_cell]
|
| 82 |
+
)
|
| 83 |
+
targets[scale_idx][anchor_on_scale, i, j, 1:5] = box_coordinates
|
| 84 |
+
targets[scale_idx][anchor_on_scale, i, j, 5] = int(class_label)
|
| 85 |
+
has_anchor[scale_idx] = True
|
| 86 |
+
|
| 87 |
+
elif not anchor_taken and iou_anchors[anchor_idx] > self.ignore_iou_thresh:
|
| 88 |
+
targets[scale_idx][anchor_on_scale, i, j, 0] = -1 # ignore prediction
|
| 89 |
+
|
| 90 |
+
return image, tuple(targets)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test():
|
| 94 |
+
anchors = config.ANCHORS
|
| 95 |
+
|
| 96 |
+
transform = config.test_transforms
|
| 97 |
+
|
| 98 |
+
dataset = YOLODataset(
|
| 99 |
+
"COCO/train.csv",
|
| 100 |
+
"COCO/images/images/",
|
| 101 |
+
"COCO/labels/labels_new/",
|
| 102 |
+
S=[13, 26, 52],
|
| 103 |
+
anchors=anchors,
|
| 104 |
+
transform=transform,
|
| 105 |
+
)
|
| 106 |
+
S = [13, 26, 52]
|
| 107 |
+
scaled_anchors = torch.tensor(anchors) / (
|
| 108 |
+
1 / torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
|
| 109 |
+
)
|
| 110 |
+
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
|
| 111 |
+
for x, y in loader:
|
| 112 |
+
boxes = []
|
| 113 |
+
|
| 114 |
+
for i in range(y[0].shape[1]):
|
| 115 |
+
anchor = scaled_anchors[i]
|
| 116 |
+
print(anchor.shape)
|
| 117 |
+
print(y[i].shape)
|
| 118 |
+
boxes += cells_to_bboxes(
|
| 119 |
+
y[i], is_preds=False, S=y[i].shape[2], anchors=anchor
|
| 120 |
+
)[0]
|
| 121 |
+
boxes = nms(boxes, iou_threshold=1, threshold=0.7, box_format="midpoint")
|
| 122 |
+
print(boxes)
|
| 123 |
+
plot_image(x[0].permute(1, 2, 0).to("cpu"), boxes)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
test()
|
examples/1.jpg
ADDED
|
examples/2.jpg
ADDED
|
loss.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Implementation of Yolo Loss Function similar to the one in Yolov3 paper,
|
| 3 |
+
the difference from what I can tell is I use CrossEntropy for the classes
|
| 4 |
+
instead of BinaryCrossEntropy.
|
| 5 |
+
"""
|
| 6 |
+
import random
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from utils import intersection_over_union
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class YoloLoss(nn.Module):
|
| 14 |
+
def __init__(self):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.mse = nn.MSELoss()
|
| 17 |
+
self.bce = nn.BCEWithLogitsLoss()
|
| 18 |
+
self.entropy = nn.CrossEntropyLoss()
|
| 19 |
+
self.sigmoid = nn.Sigmoid()
|
| 20 |
+
|
| 21 |
+
# Constants signifying how much to pay for each respective part of the loss
|
| 22 |
+
self.lambda_class = 1
|
| 23 |
+
self.lambda_noobj = 10
|
| 24 |
+
self.lambda_obj = 1
|
| 25 |
+
self.lambda_box = 10
|
| 26 |
+
|
| 27 |
+
def forward(self, predictions, target, anchors):
|
| 28 |
+
# Check where obj and noobj (we ignore if target == -1)
|
| 29 |
+
obj = target[..., 0] == 1 # in paper this is Iobj_i
|
| 30 |
+
noobj = target[..., 0] == 0 # in paper this is Inoobj_i
|
| 31 |
+
|
| 32 |
+
# ======================= #
|
| 33 |
+
# FOR NO OBJECT LOSS #
|
| 34 |
+
# ======================= #
|
| 35 |
+
|
| 36 |
+
no_object_loss = self.bce(
|
| 37 |
+
(predictions[..., 0:1][noobj]), (target[..., 0:1][noobj]),
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# ==================== #
|
| 41 |
+
# FOR OBJECT LOSS #
|
| 42 |
+
# ==================== #
|
| 43 |
+
|
| 44 |
+
anchors = anchors.reshape(1, 3, 1, 1, 2)
|
| 45 |
+
box_preds = torch.cat([self.sigmoid(predictions[..., 1:3]), torch.exp(predictions[..., 3:5]) * anchors], dim=-1)
|
| 46 |
+
ious = intersection_over_union(box_preds[obj], target[..., 1:5][obj]).detach()
|
| 47 |
+
object_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), ious * target[..., 0:1][obj])
|
| 48 |
+
|
| 49 |
+
# ======================== #
|
| 50 |
+
# FOR BOX COORDINATES #
|
| 51 |
+
# ======================== #
|
| 52 |
+
|
| 53 |
+
predictions[..., 1:3] = self.sigmoid(predictions[..., 1:3]) # x,y coordinates
|
| 54 |
+
target[..., 3:5] = torch.log(
|
| 55 |
+
(1e-16 + target[..., 3:5] / anchors)
|
| 56 |
+
) # width, height coordinates
|
| 57 |
+
box_loss = self.mse(predictions[..., 1:5][obj], target[..., 1:5][obj])
|
| 58 |
+
|
| 59 |
+
# ================== #
|
| 60 |
+
# FOR CLASS LOSS #
|
| 61 |
+
# ================== #
|
| 62 |
+
|
| 63 |
+
class_loss = self.entropy(
|
| 64 |
+
(predictions[..., 5:][obj]), (target[..., 5][obj].long()),
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
#print("__________________________________")
|
| 68 |
+
#print(self.lambda_box * box_loss)
|
| 69 |
+
#print(self.lambda_obj * object_loss)
|
| 70 |
+
#print(self.lambda_noobj * no_object_loss)
|
| 71 |
+
#print(self.lambda_class * class_loss)
|
| 72 |
+
#print("\n")
|
| 73 |
+
|
| 74 |
+
return (
|
| 75 |
+
self.lambda_box * box_loss
|
| 76 |
+
+ self.lambda_obj * object_loss
|
| 77 |
+
+ self.lambda_noobj * no_object_loss
|
| 78 |
+
+ self.lambda_class * class_loss
|
| 79 |
+
)
|
model.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Implementation of YOLOv3 architecture
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
Information about architecture config:
|
| 10 |
+
Tuple is structured by (filters, kernel_size, stride)
|
| 11 |
+
Every conv is a same convolution.
|
| 12 |
+
List is structured by "B" indicating a residual block followed by the number of repeats
|
| 13 |
+
"S" is for scale prediction block and computing the yolo loss
|
| 14 |
+
"U" is for upsampling the feature map and concatenating with a previous layer
|
| 15 |
+
"""
|
| 16 |
+
config = [
|
| 17 |
+
(32, 3, 1),
|
| 18 |
+
(64, 3, 2),
|
| 19 |
+
["B", 1],
|
| 20 |
+
(128, 3, 2),
|
| 21 |
+
["B", 2],
|
| 22 |
+
(256, 3, 2),
|
| 23 |
+
["B", 8],
|
| 24 |
+
(512, 3, 2),
|
| 25 |
+
["B", 8],
|
| 26 |
+
(1024, 3, 2),
|
| 27 |
+
["B", 4], # To this point is Darknet-53
|
| 28 |
+
(512, 1, 1),
|
| 29 |
+
(1024, 3, 1),
|
| 30 |
+
"S1",
|
| 31 |
+
(256, 1, 1),
|
| 32 |
+
"U",
|
| 33 |
+
(256, 1, 1),
|
| 34 |
+
(512, 3, 1),
|
| 35 |
+
"S2",
|
| 36 |
+
(128, 1, 1),
|
| 37 |
+
"U",
|
| 38 |
+
(128, 1, 1),
|
| 39 |
+
(256, 3, 1),
|
| 40 |
+
"S3",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
S=[13,26,52]
|
| 44 |
+
|
| 45 |
+
class CNNBlock(nn.Module):
|
| 46 |
+
def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
|
| 49 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 50 |
+
self.leaky = nn.LeakyReLU(0.1)
|
| 51 |
+
self.use_bn_act = bn_act
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
if self.use_bn_act:
|
| 55 |
+
return self.leaky(self.bn(self.conv(x)))
|
| 56 |
+
else:
|
| 57 |
+
return self.conv(x)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ResidualBlock(nn.Module):
|
| 61 |
+
def __init__(self, channels, use_residual=True, num_repeats=1):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.layers = nn.ModuleList()
|
| 64 |
+
for repeat in range(num_repeats):
|
| 65 |
+
self.layers += [
|
| 66 |
+
nn.Sequential(
|
| 67 |
+
CNNBlock(channels, channels // 2, kernel_size=1),
|
| 68 |
+
CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
|
| 69 |
+
)
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
self.use_residual = use_residual
|
| 73 |
+
self.num_repeats = num_repeats
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
for layer in self.layers:
|
| 77 |
+
if self.use_residual:
|
| 78 |
+
x = x + layer(x)
|
| 79 |
+
else:
|
| 80 |
+
x = layer(x)
|
| 81 |
+
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class SPPBlock(nn.Module):
|
| 86 |
+
def __init__(self, c1, c2, k=(5, 9, 13)):
|
| 87 |
+
super().__init__()
|
| 88 |
+
c_ = c1 // 2 # Intermediate channels
|
| 89 |
+
self.cv1 = nn.Conv2d(c1, c_, kernel_size=1, stride=1)
|
| 90 |
+
self.pool_layers = nn.ModuleList([
|
| 91 |
+
nn.MaxPool2d(kernel_size=size, stride=1, padding=size // 2) for size in k
|
| 92 |
+
])
|
| 93 |
+
self.cv2 = nn.Conv2d(c_ * (len(k) + 1), c2, kernel_size=1, stride=1)
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
x = self.cv1(x)
|
| 97 |
+
pool_outputs = [layer(x) for layer in self.pool_layers]
|
| 98 |
+
pool_outputs = [x] + pool_outputs
|
| 99 |
+
x = torch.cat(pool_outputs, dim=1)
|
| 100 |
+
x = self.cv2(x)
|
| 101 |
+
return x
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ScalePrediction(nn.Module):
|
| 105 |
+
def __init__(self, in_channels, num_classes, im_shape):
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.im_shape = im_shape
|
| 108 |
+
self.pred = nn.Sequential(
|
| 109 |
+
SPPBlock(in_channels,in_channels),
|
| 110 |
+
nn.AdaptiveMaxPool2d(self.im_shape),
|
| 111 |
+
CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
|
| 112 |
+
CNNBlock(
|
| 113 |
+
2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1
|
| 114 |
+
),
|
| 115 |
+
|
| 116 |
+
)
|
| 117 |
+
self.num_classes = num_classes
|
| 118 |
+
|
| 119 |
+
def forward(self, x):
|
| 120 |
+
|
| 121 |
+
x = self.pred(x)
|
| 122 |
+
return (
|
| 123 |
+
x
|
| 124 |
+
.reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
|
| 125 |
+
.permute(0, 1, 3, 4, 2)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class YOLOv3(nn.Module):
|
| 130 |
+
def __init__(self, in_channels=3, num_classes=80):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.num_classes = num_classes
|
| 133 |
+
self.in_channels = in_channels
|
| 134 |
+
self.layers = self._create_conv_layers()
|
| 135 |
+
|
| 136 |
+
def forward(self, x):
|
| 137 |
+
outputs = [] # for each scale
|
| 138 |
+
route_connections = []
|
| 139 |
+
for layer in self.layers:
|
| 140 |
+
if isinstance(layer, ScalePrediction):
|
| 141 |
+
outputs.append(layer(x))
|
| 142 |
+
continue
|
| 143 |
+
|
| 144 |
+
x = layer(x)
|
| 145 |
+
|
| 146 |
+
if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
|
| 147 |
+
route_connections.append(x)
|
| 148 |
+
|
| 149 |
+
elif isinstance(layer, nn.Upsample):
|
| 150 |
+
x = torch.cat([x, route_connections[-1]], dim=1)
|
| 151 |
+
route_connections.pop()
|
| 152 |
+
|
| 153 |
+
return outputs
|
| 154 |
+
|
| 155 |
+
def _create_conv_layers(self):
|
| 156 |
+
layers = nn.ModuleList()
|
| 157 |
+
in_channels = self.in_channels
|
| 158 |
+
|
| 159 |
+
for module in config:
|
| 160 |
+
if isinstance(module, tuple):
|
| 161 |
+
out_channels, kernel_size, stride = module
|
| 162 |
+
layers.append(
|
| 163 |
+
CNNBlock(
|
| 164 |
+
in_channels,
|
| 165 |
+
out_channels,
|
| 166 |
+
kernel_size=kernel_size,
|
| 167 |
+
stride=stride,
|
| 168 |
+
padding=1 if kernel_size == 3 else 0,
|
| 169 |
+
)
|
| 170 |
+
)
|
| 171 |
+
in_channels = out_channels
|
| 172 |
+
|
| 173 |
+
elif isinstance(module, list):
|
| 174 |
+
num_repeats = module[1]
|
| 175 |
+
layers.append(ResidualBlock(in_channels, num_repeats=num_repeats,))
|
| 176 |
+
|
| 177 |
+
elif isinstance(module, str):
|
| 178 |
+
if module == "S1":
|
| 179 |
+
layers += [
|
| 180 |
+
ResidualBlock(in_channels, use_residual=False, num_repeats=1),
|
| 181 |
+
CNNBlock(in_channels, in_channels // 2, kernel_size=1),
|
| 182 |
+
ScalePrediction(in_channels // 2, num_classes=self.num_classes, im_shape=S[0]),
|
| 183 |
+
]
|
| 184 |
+
in_channels = in_channels // 2
|
| 185 |
+
|
| 186 |
+
if module == "S2":
|
| 187 |
+
layers += [
|
| 188 |
+
ResidualBlock(in_channels, use_residual=False, num_repeats=1),
|
| 189 |
+
CNNBlock(in_channels, in_channels // 2, kernel_size=1),
|
| 190 |
+
ScalePrediction(in_channels // 2, num_classes=self.num_classes, im_shape=S[1]),
|
| 191 |
+
]
|
| 192 |
+
in_channels = in_channels // 2
|
| 193 |
+
|
| 194 |
+
if module == "S3":
|
| 195 |
+
layers += [
|
| 196 |
+
ResidualBlock(in_channels, use_residual=False, num_repeats=1),
|
| 197 |
+
CNNBlock(in_channels, in_channels // 2, kernel_size=1),
|
| 198 |
+
ScalePrediction(in_channels // 2, num_classes=self.num_classes, im_shape=S[2]),
|
| 199 |
+
]
|
| 200 |
+
in_channels = in_channels // 2
|
| 201 |
+
|
| 202 |
+
elif module == "U":
|
| 203 |
+
layers.append(nn.Upsample(scale_factor=2),)
|
| 204 |
+
in_channels = in_channels * 3
|
| 205 |
+
|
| 206 |
+
return layers
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
|
| 210 |
+
num_classes = 20
|
| 211 |
+
IMAGE_SIZE = 416
|
| 212 |
+
model = YOLOv3(num_classes=num_classes)
|
| 213 |
+
x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
|
| 214 |
+
out = model(x)
|
| 215 |
+
assert model(x)[0].shape == (2, 3, IMAGE_SIZE//32, IMAGE_SIZE//32, num_classes + 5)
|
| 216 |
+
assert model(x)[1].shape == (2, 3, IMAGE_SIZE//16, IMAGE_SIZE//16, num_classes + 5)
|
| 217 |
+
assert model(x)[2].shape == (2, 3, IMAGE_SIZE//8, IMAGE_SIZE//8, num_classes + 5)
|
| 218 |
+
print("Success!")
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
grad-cam==1.4.8
|
| 2 |
+
gradio==3.39.0
|
| 3 |
+
gradio_client==0.3.0
|
| 4 |
+
numpy==1.22.4
|
| 5 |
+
torchvision
|
| 6 |
+
Pillow==9.4.0
|
| 7 |
+
torch
|
| 8 |
+
pytorch-lightning==2.0.6
|
| 9 |
+
albumentations
|
train.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main file for training Yolo model on Pascal VOC and COCO dataset
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import config
|
| 6 |
+
import torch
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
|
| 9 |
+
from model import YOLOv3
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from utils import (
|
| 12 |
+
mean_average_precision,
|
| 13 |
+
cells_to_bboxes,
|
| 14 |
+
get_evaluation_bboxes,
|
| 15 |
+
save_checkpoint,
|
| 16 |
+
load_checkpoint,
|
| 17 |
+
check_class_accuracy,
|
| 18 |
+
get_loaders,
|
| 19 |
+
plot_couple_examples
|
| 20 |
+
)
|
| 21 |
+
from loss import YoloLoss
|
| 22 |
+
import warnings
|
| 23 |
+
warnings.filterwarnings("ignore")
|
| 24 |
+
|
| 25 |
+
import pytorch_lightning as pl
|
| 26 |
+
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
| 27 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class YOLOv3Lightning(pl.LightningModule):
|
| 31 |
+
def __init__(self, config):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.config = config
|
| 34 |
+
self.model = YOLOv3(num_classes=config.NUM_CLASSES)
|
| 35 |
+
self.loss_fn = YoloLoss()
|
| 36 |
+
self.scaled_anchors = (
|
| 37 |
+
torch.tensor(config.ANCHORS)
|
| 38 |
+
* torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
|
| 39 |
+
).to(config.DEVICE)
|
| 40 |
+
self.train_step_outputs = []
|
| 41 |
+
self.validation_step_outputs = []
|
| 42 |
+
|
| 43 |
+
def forward(self, x):
|
| 44 |
+
return self.model(x)
|
| 45 |
+
|
| 46 |
+
def get_loss(self, batch):
|
| 47 |
+
x, y = batch
|
| 48 |
+
y0, y1, y2 = (
|
| 49 |
+
y[0],
|
| 50 |
+
y[1],
|
| 51 |
+
y[2],
|
| 52 |
+
)
|
| 53 |
+
out = self(x)
|
| 54 |
+
loss = (
|
| 55 |
+
self.loss_fn(out[0], y0, self.scaled_anchors[0])
|
| 56 |
+
+ self.loss_fn(out[1], y1, self.scaled_anchors[1])
|
| 57 |
+
+ self.loss_fn(out[2], y2, self.scaled_anchors[2])
|
| 58 |
+
)
|
| 59 |
+
return loss
|
| 60 |
+
|
| 61 |
+
def training_step(self, batch, batch_idx):
|
| 62 |
+
loss = self.get_loss(batch)
|
| 63 |
+
self.log("train/loss", loss, on_epoch=True, prog_bar=True, logger=True) # Logging the training loss for visualization
|
| 64 |
+
self.train_step_outputs.append(loss)
|
| 65 |
+
return loss
|
| 66 |
+
|
| 67 |
+
def on_train_epoch_end(self):
|
| 68 |
+
print(f"\nCurrently epoch {self.current_epoch}")
|
| 69 |
+
train_epoch_average = torch.stack(self.train_step_outputs).mean()
|
| 70 |
+
self.train_step_outputs.clear()
|
| 71 |
+
print(f"Train loss {train_epoch_average}")
|
| 72 |
+
print("On Train loader:")
|
| 73 |
+
class_accuracy, no_obj_accuracy, obj_accuracy = check_class_accuracy(self.model, self.train_loader, threshold=config.CONF_THRESHOLD)
|
| 74 |
+
self.log("train/class_accuracy", class_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
| 75 |
+
self.log("train/no_obj_accuracy", no_obj_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
| 76 |
+
self.log("train/obj_accuracy", obj_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
| 77 |
+
|
| 78 |
+
val_epoch_average = torch.stack(self.validation_step_outputs).mean()
|
| 79 |
+
self.validation_step_outputs.clear()
|
| 80 |
+
print(f"Validation loss {val_epoch_average}")
|
| 81 |
+
print("On Train Eval loader:")
|
| 82 |
+
class_accuracy, no_obj_accuracy, obj_accuracy = check_class_accuracy(self.model, self.train_eval_loader, threshold=config.CONF_THRESHOLD)
|
| 83 |
+
self.log("val/class_accuracy", class_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
| 84 |
+
self.log("val/no_obj_accuracy", no_obj_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
| 85 |
+
self.log("val/obj_accuracy", obj_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
| 86 |
+
|
| 87 |
+
if (self.current_epoch>0) and ((self.current_epoch+1) % 10 == 0):
|
| 88 |
+
plot_couple_examples(self.model, self.test_loader, 0.6, 0.5, self.scaled_anchors)
|
| 89 |
+
|
| 90 |
+
if (self.current_epoch>0) and (self.current_epoch+1 == 40):
|
| 91 |
+
print("On Test loader:")
|
| 92 |
+
test_class_accuracy, test_no_obj_accuracy, test_obj_accuracy = check_class_accuracy(self.model, self.test_loader, threshold=config.CONF_THRESHOLD)
|
| 93 |
+
self.log("test/class_accuracy", test_class_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
| 94 |
+
self.log("test/no_obj_accuracy", test_no_obj_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
| 95 |
+
self.log("test/obj_accuracy", test_obj_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
| 96 |
+
pred_boxes, true_boxes = get_evaluation_bboxes(
|
| 97 |
+
self.test_loader,
|
| 98 |
+
self.model,
|
| 99 |
+
iou_threshold=config.NMS_IOU_THRESH,
|
| 100 |
+
anchors=config.ANCHORS,
|
| 101 |
+
threshold=config.CONF_THRESHOLD,
|
| 102 |
+
)
|
| 103 |
+
mapval = mean_average_precision(
|
| 104 |
+
pred_boxes,
|
| 105 |
+
true_boxes,
|
| 106 |
+
iou_threshold=config.MAP_IOU_THRESH,
|
| 107 |
+
box_format="midpoint",
|
| 108 |
+
num_classes=config.NUM_CLASSES,
|
| 109 |
+
)
|
| 110 |
+
print(f"MAP: {mapval.item()}")
|
| 111 |
+
|
| 112 |
+
self.log("MAP", mapval.item(), on_epoch=True, prog_bar=True, logger=True)
|
| 113 |
+
|
| 114 |
+
def validation_step(self, batch, batch_idx):
|
| 115 |
+
loss = self.get_loss(batch)
|
| 116 |
+
self.log("val/loss", loss, on_epoch=True, prog_bar=True, logger=True)
|
| 117 |
+
self.validation_step_outputs.append(loss)
|
| 118 |
+
return loss
|
| 119 |
+
|
| 120 |
+
def configure_optimizers(self):
|
| 121 |
+
optimizer = optim.Adam(
|
| 122 |
+
self.parameters(),
|
| 123 |
+
lr=self.config.LEARNING_RATE,
|
| 124 |
+
weight_decay=self.config.WEIGHT_DECAY,
|
| 125 |
+
)
|
| 126 |
+
self.trainer.fit_loop.setup_data()
|
| 127 |
+
dataloader = self.trainer.train_dataloader
|
| 128 |
+
|
| 129 |
+
EPOCHS = config.NUM_EPOCHS
|
| 130 |
+
lr_scheduler = OneCycleLR(
|
| 131 |
+
optimizer,
|
| 132 |
+
max_lr=1.0E-03,
|
| 133 |
+
steps_per_epoch=len(dataloader),
|
| 134 |
+
epochs=EPOCHS,
|
| 135 |
+
pct_start=5/EPOCHS,
|
| 136 |
+
div_factor=100,
|
| 137 |
+
three_phase=False,
|
| 138 |
+
final_div_factor=100,
|
| 139 |
+
anneal_strategy='linear'
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
scheduler = {"scheduler": lr_scheduler, "interval" : "step"}
|
| 143 |
+
|
| 144 |
+
return [optimizer], [scheduler]
|
| 145 |
+
|
| 146 |
+
def setup(self, stage=None):
|
| 147 |
+
self.train_loader, self.test_loader, self.train_eval_loader = get_loaders(
|
| 148 |
+
train_csv_path=self.config.DATASET + "/train.csv",
|
| 149 |
+
test_csv_path=self.config.DATASET + "/test.csv",
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
def train_dataloader(self):
|
| 153 |
+
return self.train_loader
|
| 154 |
+
|
| 155 |
+
def val_dataloader(self):
|
| 156 |
+
return self.train_eval_loader
|
| 157 |
+
|
| 158 |
+
def test_dataloader(self):
|
| 159 |
+
return self.test_loader
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
|
| 163 |
+
|
| 164 |
+
model = YOLOv3Lightning(config)
|
| 165 |
+
|
| 166 |
+
checkpoint = ModelCheckpoint(filename='last_epoch', save_last=True)
|
| 167 |
+
lr_rate_monitor = LearningRateMonitor(logging_interval="epoch")
|
| 168 |
+
trainer = pl.Trainer(
|
| 169 |
+
max_epochs=config.NUM_EPOCHS,
|
| 170 |
+
deterministic=False,
|
| 171 |
+
logger=True,
|
| 172 |
+
callbacks=[checkpoint, lr_rate_monitor],
|
| 173 |
+
enable_model_summary=False,
|
| 174 |
+
log_every_n_steps=1,
|
| 175 |
+
precision=16
|
| 176 |
+
)
|
| 177 |
+
print("Training Started by Selvaraj Sembulingam")
|
| 178 |
+
trainer.fit(model)
|
| 179 |
+
print("Training Completed by Selvaraj Sembulingam")
|
| 180 |
+
torch.save(model.state_dict(), 'YOLOv3.pth')
|
utils.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import config
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import matplotlib.patches as patches
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import torch
|
| 8 |
+
from batch_sampler import BatchSampler,RandomSampler,SequentialSampler
|
| 9 |
+
|
| 10 |
+
from collections import Counter
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def iou_width_height(boxes1, boxes2):
|
| 16 |
+
"""
|
| 17 |
+
Parameters:
|
| 18 |
+
boxes1 (tensor): width and height of the first bounding boxes
|
| 19 |
+
boxes2 (tensor): width and height of the second bounding boxes
|
| 20 |
+
Returns:
|
| 21 |
+
tensor: Intersection over union of the corresponding boxes
|
| 22 |
+
"""
|
| 23 |
+
intersection = torch.min(boxes1[..., 0], boxes2[..., 0]) * torch.min(
|
| 24 |
+
boxes1[..., 1], boxes2[..., 1]
|
| 25 |
+
)
|
| 26 |
+
union = (
|
| 27 |
+
boxes1[..., 0] * boxes1[..., 1] + boxes2[..., 0] * boxes2[..., 1] - intersection
|
| 28 |
+
)
|
| 29 |
+
return intersection / union
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
|
| 33 |
+
"""
|
| 34 |
+
Video explanation of this function:
|
| 35 |
+
https://youtu.be/XXYG5ZWtjj0
|
| 36 |
+
|
| 37 |
+
This function calculates intersection over union (iou) given pred boxes
|
| 38 |
+
and target boxes.
|
| 39 |
+
|
| 40 |
+
Parameters:
|
| 41 |
+
boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
|
| 42 |
+
boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
|
| 43 |
+
box_format (str): midpoint/corners, if boxes (x,y,w,h) or (x1,y1,x2,y2)
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
tensor: Intersection over union for all examples
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
if box_format == "midpoint":
|
| 50 |
+
box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
|
| 51 |
+
box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
|
| 52 |
+
box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
|
| 53 |
+
box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
|
| 54 |
+
box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
|
| 55 |
+
box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
|
| 56 |
+
box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
|
| 57 |
+
box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
|
| 58 |
+
|
| 59 |
+
if box_format == "corners":
|
| 60 |
+
box1_x1 = boxes_preds[..., 0:1]
|
| 61 |
+
box1_y1 = boxes_preds[..., 1:2]
|
| 62 |
+
box1_x2 = boxes_preds[..., 2:3]
|
| 63 |
+
box1_y2 = boxes_preds[..., 3:4]
|
| 64 |
+
box2_x1 = boxes_labels[..., 0:1]
|
| 65 |
+
box2_y1 = boxes_labels[..., 1:2]
|
| 66 |
+
box2_x2 = boxes_labels[..., 2:3]
|
| 67 |
+
box2_y2 = boxes_labels[..., 3:4]
|
| 68 |
+
|
| 69 |
+
x1 = torch.max(box1_x1, box2_x1)
|
| 70 |
+
y1 = torch.max(box1_y1, box2_y1)
|
| 71 |
+
x2 = torch.min(box1_x2, box2_x2)
|
| 72 |
+
y2 = torch.min(box1_y2, box2_y2)
|
| 73 |
+
|
| 74 |
+
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
|
| 75 |
+
box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
|
| 76 |
+
box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
|
| 77 |
+
|
| 78 |
+
return intersection / (box1_area + box2_area - intersection + 1e-6)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
|
| 82 |
+
"""
|
| 83 |
+
Video explanation of this function:
|
| 84 |
+
https://youtu.be/YDkjWEN8jNA
|
| 85 |
+
|
| 86 |
+
Does Non Max Suppression given bboxes
|
| 87 |
+
|
| 88 |
+
Parameters:
|
| 89 |
+
bboxes (list): list of lists containing all bboxes with each bboxes
|
| 90 |
+
specified as [class_pred, prob_score, x1, y1, x2, y2]
|
| 91 |
+
iou_threshold (float): threshold where predicted bboxes is correct
|
| 92 |
+
threshold (float): threshold to remove predicted bboxes (independent of IoU)
|
| 93 |
+
box_format (str): "midpoint" or "corners" used to specify bboxes
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
list: bboxes after performing NMS given a specific IoU threshold
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
assert type(bboxes) == list
|
| 100 |
+
|
| 101 |
+
bboxes = [box for box in bboxes if box[1] > threshold]
|
| 102 |
+
bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
|
| 103 |
+
bboxes_after_nms = []
|
| 104 |
+
|
| 105 |
+
while bboxes:
|
| 106 |
+
chosen_box = bboxes.pop(0)
|
| 107 |
+
|
| 108 |
+
bboxes = [
|
| 109 |
+
box
|
| 110 |
+
for box in bboxes
|
| 111 |
+
if box[0] != chosen_box[0]
|
| 112 |
+
or intersection_over_union(
|
| 113 |
+
torch.tensor(chosen_box[2:]),
|
| 114 |
+
torch.tensor(box[2:]),
|
| 115 |
+
box_format=box_format,
|
| 116 |
+
)
|
| 117 |
+
< iou_threshold
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
bboxes_after_nms.append(chosen_box)
|
| 121 |
+
|
| 122 |
+
return bboxes_after_nms
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def mean_average_precision(
|
| 126 |
+
pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
|
| 127 |
+
):
|
| 128 |
+
"""
|
| 129 |
+
Video explanation of this function:
|
| 130 |
+
https://youtu.be/FppOzcDvaDI
|
| 131 |
+
|
| 132 |
+
This function calculates mean average precision (mAP)
|
| 133 |
+
|
| 134 |
+
Parameters:
|
| 135 |
+
pred_boxes (list): list of lists containing all bboxes with each bboxes
|
| 136 |
+
specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
|
| 137 |
+
true_boxes (list): Similar as pred_boxes except all the correct ones
|
| 138 |
+
iou_threshold (float): threshold where predicted bboxes is correct
|
| 139 |
+
box_format (str): "midpoint" or "corners" used to specify bboxes
|
| 140 |
+
num_classes (int): number of classes
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
float: mAP value across all classes given a specific IoU threshold
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
# list storing all AP for respective classes
|
| 147 |
+
average_precisions = []
|
| 148 |
+
|
| 149 |
+
# used for numerical stability later on
|
| 150 |
+
epsilon = 1e-6
|
| 151 |
+
|
| 152 |
+
for c in range(num_classes):
|
| 153 |
+
detections = []
|
| 154 |
+
ground_truths = []
|
| 155 |
+
|
| 156 |
+
# Go through all predictions and targets,
|
| 157 |
+
# and only add the ones that belong to the
|
| 158 |
+
# current class c
|
| 159 |
+
for detection in pred_boxes:
|
| 160 |
+
if detection[1] == c:
|
| 161 |
+
detections.append(detection)
|
| 162 |
+
|
| 163 |
+
for true_box in true_boxes:
|
| 164 |
+
if true_box[1] == c:
|
| 165 |
+
ground_truths.append(true_box)
|
| 166 |
+
|
| 167 |
+
# find the amount of bboxes for each training example
|
| 168 |
+
# Counter here finds how many ground truth bboxes we get
|
| 169 |
+
# for each training example, so let's say img 0 has 3,
|
| 170 |
+
# img 1 has 5 then we will obtain a dictionary with:
|
| 171 |
+
# amount_bboxes = {0:3, 1:5}
|
| 172 |
+
amount_bboxes = Counter([gt[0] for gt in ground_truths])
|
| 173 |
+
|
| 174 |
+
# We then go through each key, val in this dictionary
|
| 175 |
+
# and convert to the following (w.r.t same example):
|
| 176 |
+
# ammount_bboxes = {0:torch.tensor[0,0,0], 1:torch.tensor[0,0,0,0,0]}
|
| 177 |
+
for key, val in amount_bboxes.items():
|
| 178 |
+
amount_bboxes[key] = torch.zeros(val)
|
| 179 |
+
|
| 180 |
+
# sort by box probabilities which is index 2
|
| 181 |
+
detections.sort(key=lambda x: x[2], reverse=True)
|
| 182 |
+
TP = torch.zeros((len(detections)))
|
| 183 |
+
FP = torch.zeros((len(detections)))
|
| 184 |
+
total_true_bboxes = len(ground_truths)
|
| 185 |
+
|
| 186 |
+
# If none exists for this class then we can safely skip
|
| 187 |
+
if total_true_bboxes == 0:
|
| 188 |
+
continue
|
| 189 |
+
|
| 190 |
+
for detection_idx, detection in enumerate(detections):
|
| 191 |
+
# Only take out the ground_truths that have the same
|
| 192 |
+
# training idx as detection
|
| 193 |
+
ground_truth_img = [
|
| 194 |
+
bbox for bbox in ground_truths if bbox[0] == detection[0]
|
| 195 |
+
]
|
| 196 |
+
|
| 197 |
+
num_gts = len(ground_truth_img)
|
| 198 |
+
best_iou = 0
|
| 199 |
+
|
| 200 |
+
for idx, gt in enumerate(ground_truth_img):
|
| 201 |
+
iou = intersection_over_union(
|
| 202 |
+
torch.tensor(detection[3:]),
|
| 203 |
+
torch.tensor(gt[3:]),
|
| 204 |
+
box_format=box_format,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
if iou > best_iou:
|
| 208 |
+
best_iou = iou
|
| 209 |
+
best_gt_idx = idx
|
| 210 |
+
|
| 211 |
+
if best_iou > iou_threshold:
|
| 212 |
+
# only detect ground truth detection once
|
| 213 |
+
if amount_bboxes[detection[0]][best_gt_idx] == 0:
|
| 214 |
+
# true positive and add this bounding box to seen
|
| 215 |
+
TP[detection_idx] = 1
|
| 216 |
+
amount_bboxes[detection[0]][best_gt_idx] = 1
|
| 217 |
+
else:
|
| 218 |
+
FP[detection_idx] = 1
|
| 219 |
+
|
| 220 |
+
# if IOU is lower then the detection is a false positive
|
| 221 |
+
else:
|
| 222 |
+
FP[detection_idx] = 1
|
| 223 |
+
|
| 224 |
+
TP_cumsum = torch.cumsum(TP, dim=0)
|
| 225 |
+
FP_cumsum = torch.cumsum(FP, dim=0)
|
| 226 |
+
recalls = TP_cumsum / (total_true_bboxes + epsilon)
|
| 227 |
+
precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
|
| 228 |
+
precisions = torch.cat((torch.tensor([1]), precisions))
|
| 229 |
+
recalls = torch.cat((torch.tensor([0]), recalls))
|
| 230 |
+
# torch.trapz for numerical integration
|
| 231 |
+
average_precisions.append(torch.trapz(precisions, recalls))
|
| 232 |
+
|
| 233 |
+
return sum(average_precisions) / len(average_precisions)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def plot_image(image, boxes):
|
| 237 |
+
"""Plots predicted bounding boxes on the image"""
|
| 238 |
+
cmap = plt.get_cmap("tab20b")
|
| 239 |
+
class_labels = config.COCO_LABELS if config.DATASET=='COCO' else config.PASCAL_CLASSES
|
| 240 |
+
colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
|
| 241 |
+
im = np.array(image)
|
| 242 |
+
height, width, _ = im.shape
|
| 243 |
+
|
| 244 |
+
# Create figure and axes
|
| 245 |
+
fig, ax = plt.subplots(1)
|
| 246 |
+
# Display the image
|
| 247 |
+
ax.imshow(im)
|
| 248 |
+
|
| 249 |
+
# box[0] is x midpoint, box[2] is width
|
| 250 |
+
# box[1] is y midpoint, box[3] is height
|
| 251 |
+
|
| 252 |
+
# Create a Rectangle patch
|
| 253 |
+
for box in boxes:
|
| 254 |
+
assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
|
| 255 |
+
class_pred = box[0]
|
| 256 |
+
box = box[2:]
|
| 257 |
+
upper_left_x = box[0] - box[2] / 2
|
| 258 |
+
upper_left_y = box[1] - box[3] / 2
|
| 259 |
+
rect = patches.Rectangle(
|
| 260 |
+
(upper_left_x * width, upper_left_y * height),
|
| 261 |
+
box[2] * width,
|
| 262 |
+
box[3] * height,
|
| 263 |
+
linewidth=2,
|
| 264 |
+
edgecolor=colors[int(class_pred)],
|
| 265 |
+
facecolor="none",
|
| 266 |
+
)
|
| 267 |
+
# Add the patch to the Axes
|
| 268 |
+
ax.add_patch(rect)
|
| 269 |
+
plt.text(
|
| 270 |
+
upper_left_x * width,
|
| 271 |
+
upper_left_y * height,
|
| 272 |
+
s=class_labels[int(class_pred)],
|
| 273 |
+
color="white",
|
| 274 |
+
verticalalignment="top",
|
| 275 |
+
bbox={"color": colors[int(class_pred)], "pad": 0},
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
plt.show()
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def get_evaluation_bboxes(
|
| 282 |
+
loader,
|
| 283 |
+
model,
|
| 284 |
+
iou_threshold,
|
| 285 |
+
anchors,
|
| 286 |
+
threshold,
|
| 287 |
+
box_format="midpoint",
|
| 288 |
+
device="cuda",
|
| 289 |
+
):
|
| 290 |
+
# make sure model is in eval before get bboxes
|
| 291 |
+
model.eval()
|
| 292 |
+
train_idx = 0
|
| 293 |
+
all_pred_boxes = []
|
| 294 |
+
all_true_boxes = []
|
| 295 |
+
for batch_idx, (x, labels) in enumerate(loader):
|
| 296 |
+
x = x.to(device)
|
| 297 |
+
|
| 298 |
+
with torch.no_grad():
|
| 299 |
+
predictions = model(x)
|
| 300 |
+
|
| 301 |
+
batch_size = x.shape[0]
|
| 302 |
+
bboxes = [[] for _ in range(batch_size)]
|
| 303 |
+
for i in range(3):
|
| 304 |
+
S = predictions[i].shape[2]
|
| 305 |
+
anchor = torch.tensor([*anchors[i]]).to(device) * S
|
| 306 |
+
boxes_scale_i = cells_to_bboxes(
|
| 307 |
+
predictions[i], anchor, S=S, is_preds=True
|
| 308 |
+
)
|
| 309 |
+
for idx, (box) in enumerate(boxes_scale_i):
|
| 310 |
+
bboxes[idx] += box
|
| 311 |
+
|
| 312 |
+
# we just want one bbox for each label, not one for each scale
|
| 313 |
+
true_bboxes = cells_to_bboxes(
|
| 314 |
+
labels[2], anchor, S=S, is_preds=False
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
for idx in range(batch_size):
|
| 318 |
+
nms_boxes = non_max_suppression(
|
| 319 |
+
bboxes[idx],
|
| 320 |
+
iou_threshold=iou_threshold,
|
| 321 |
+
threshold=threshold,
|
| 322 |
+
box_format=box_format,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
for nms_box in nms_boxes:
|
| 326 |
+
all_pred_boxes.append([train_idx] + nms_box)
|
| 327 |
+
|
| 328 |
+
for box in true_bboxes[idx]:
|
| 329 |
+
if box[1] > threshold:
|
| 330 |
+
all_true_boxes.append([train_idx] + box)
|
| 331 |
+
|
| 332 |
+
train_idx += 1
|
| 333 |
+
|
| 334 |
+
model.train()
|
| 335 |
+
return all_pred_boxes, all_true_boxes
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def cells_to_bboxes(predictions, anchors, S, is_preds=True):
|
| 339 |
+
"""
|
| 340 |
+
Scales the predictions coming from the model to
|
| 341 |
+
be relative to the entire image such that they for example later
|
| 342 |
+
can be plotted or.
|
| 343 |
+
INPUT:
|
| 344 |
+
predictions: tensor of size (N, 3, S, S, num_classes+5)
|
| 345 |
+
anchors: the anchors used for the predictions
|
| 346 |
+
S: the number of cells the image is divided in on the width (and height)
|
| 347 |
+
is_preds: whether the input is predictions or the true bounding boxes
|
| 348 |
+
OUTPUT:
|
| 349 |
+
converted_bboxes: the converted boxes of sizes (N, num_anchors, S, S, 1+5) with class index,
|
| 350 |
+
object score, bounding box coordinates
|
| 351 |
+
"""
|
| 352 |
+
BATCH_SIZE = predictions.shape[0]
|
| 353 |
+
num_anchors = len(anchors)
|
| 354 |
+
box_predictions = predictions[..., 1:5]
|
| 355 |
+
if is_preds:
|
| 356 |
+
anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
|
| 357 |
+
box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
|
| 358 |
+
box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
|
| 359 |
+
scores = torch.sigmoid(predictions[..., 0:1])
|
| 360 |
+
best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
|
| 361 |
+
else:
|
| 362 |
+
scores = predictions[..., 0:1]
|
| 363 |
+
best_class = predictions[..., 5:6]
|
| 364 |
+
|
| 365 |
+
cell_indices = (
|
| 366 |
+
torch.arange(S)
|
| 367 |
+
.repeat(predictions.shape[0], 3, S, 1)
|
| 368 |
+
.unsqueeze(-1)
|
| 369 |
+
.to(predictions.device)
|
| 370 |
+
)
|
| 371 |
+
x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
|
| 372 |
+
y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
|
| 373 |
+
w_h = 1 / S * box_predictions[..., 2:4]
|
| 374 |
+
converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(BATCH_SIZE, num_anchors * S * S, 6)
|
| 375 |
+
return converted_bboxes.tolist()
|
| 376 |
+
|
| 377 |
+
def check_class_accuracy(model, loader, threshold):
|
| 378 |
+
model.eval()
|
| 379 |
+
tot_class_preds, correct_class = 0, 0
|
| 380 |
+
tot_noobj, correct_noobj = 0, 0
|
| 381 |
+
tot_obj, correct_obj = 0, 0
|
| 382 |
+
|
| 383 |
+
for idx, (x, y) in enumerate(loader):
|
| 384 |
+
x = x.to(config.DEVICE)
|
| 385 |
+
with torch.no_grad():
|
| 386 |
+
out = model(x)
|
| 387 |
+
|
| 388 |
+
for i in range(3):
|
| 389 |
+
y[i] = y[i].to(config.DEVICE)
|
| 390 |
+
obj = y[i][..., 0] == 1 # in paper this is Iobj_i
|
| 391 |
+
noobj = y[i][..., 0] == 0 # in paper this is Iobj_i
|
| 392 |
+
|
| 393 |
+
correct_class += torch.sum(
|
| 394 |
+
torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
|
| 395 |
+
)
|
| 396 |
+
tot_class_preds += torch.sum(obj)
|
| 397 |
+
|
| 398 |
+
obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
|
| 399 |
+
correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
|
| 400 |
+
tot_obj += torch.sum(obj)
|
| 401 |
+
correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
|
| 402 |
+
tot_noobj += torch.sum(noobj)
|
| 403 |
+
|
| 404 |
+
print(f"Class accuracy is: {(correct_class/(tot_class_preds+1e-16))*100:2f}%")
|
| 405 |
+
print(f"No obj accuracy is: {(correct_noobj/(tot_noobj+1e-16))*100:2f}%")
|
| 406 |
+
print(f"Obj accuracy is: {(correct_obj/(tot_obj+1e-16))*100:2f}%")
|
| 407 |
+
model.train()
|
| 408 |
+
|
| 409 |
+
return (correct_class/(tot_class_preds+1e-16))*100, (correct_noobj/(tot_noobj+1e-16))*100, (correct_obj/(tot_obj+1e-16))*100
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def get_mean_std(loader):
|
| 413 |
+
# var[X] = E[X**2] - E[X]**2
|
| 414 |
+
channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0
|
| 415 |
+
|
| 416 |
+
for data, _ in loader:
|
| 417 |
+
channels_sum += torch.mean(data, dim=[0, 2, 3])
|
| 418 |
+
channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3])
|
| 419 |
+
num_batches += 1
|
| 420 |
+
|
| 421 |
+
mean = channels_sum / num_batches
|
| 422 |
+
std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5
|
| 423 |
+
|
| 424 |
+
return mean, std
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
|
| 428 |
+
print("=> Saving checkpoint")
|
| 429 |
+
checkpoint = {
|
| 430 |
+
"state_dict": model.state_dict(),
|
| 431 |
+
"optimizer": optimizer.state_dict(),
|
| 432 |
+
}
|
| 433 |
+
torch.save(checkpoint, filename)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def load_checkpoint(checkpoint_file, model, optimizer, lr):
|
| 437 |
+
print("=> Loading checkpoint")
|
| 438 |
+
checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
|
| 439 |
+
model.load_state_dict(checkpoint["state_dict"])
|
| 440 |
+
optimizer.load_state_dict(checkpoint["optimizer"])
|
| 441 |
+
|
| 442 |
+
# If we don't do this then it will just have learning rate of old checkpoint
|
| 443 |
+
# and it will lead to many hours of debugging \:
|
| 444 |
+
for param_group in optimizer.param_groups:
|
| 445 |
+
param_group["lr"] = lr
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def get_loaders(train_csv_path, test_csv_path):
|
| 449 |
+
from dataset import YOLODataset
|
| 450 |
+
|
| 451 |
+
IMAGE_SIZE = config.IMAGE_SIZE
|
| 452 |
+
train_dataset = YOLODataset(
|
| 453 |
+
train_csv_path,
|
| 454 |
+
transform=config.train_transforms,
|
| 455 |
+
S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
|
| 456 |
+
img_dir=config.IMG_DIR,
|
| 457 |
+
label_dir=config.LABEL_DIR,
|
| 458 |
+
anchors=config.ANCHORS,
|
| 459 |
+
)
|
| 460 |
+
test_dataset = YOLODataset(
|
| 461 |
+
test_csv_path,
|
| 462 |
+
transform=config.test_transforms,
|
| 463 |
+
S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
|
| 464 |
+
img_dir=config.IMG_DIR,
|
| 465 |
+
label_dir=config.LABEL_DIR,
|
| 466 |
+
anchors=config.ANCHORS,
|
| 467 |
+
)
|
| 468 |
+
train_loader = DataLoader(
|
| 469 |
+
dataset=train_dataset,
|
| 470 |
+
batch_sampler= BatchSampler(RandomSampler(train_dataset),
|
| 471 |
+
batch_size=config.BATCH_SIZE,
|
| 472 |
+
drop_last=False,
|
| 473 |
+
multiscale_step=1,
|
| 474 |
+
img_sizes=list(range(320, 608 + 1, 32))
|
| 475 |
+
),
|
| 476 |
+
num_workers=config.NUM_WORKERS,
|
| 477 |
+
pin_memory=config.PIN_MEMORY,
|
| 478 |
+
)
|
| 479 |
+
test_loader = DataLoader(
|
| 480 |
+
dataset=test_dataset,
|
| 481 |
+
batch_size=config.BATCH_SIZE,
|
| 482 |
+
num_workers=config.NUM_WORKERS,
|
| 483 |
+
pin_memory=config.PIN_MEMORY,
|
| 484 |
+
shuffle=False,
|
| 485 |
+
drop_last=False,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
train_eval_dataset = YOLODataset(
|
| 489 |
+
train_csv_path,
|
| 490 |
+
transform=config.test_transforms,
|
| 491 |
+
S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
|
| 492 |
+
img_dir=config.IMG_DIR,
|
| 493 |
+
label_dir=config.LABEL_DIR,
|
| 494 |
+
anchors=config.ANCHORS,
|
| 495 |
+
)
|
| 496 |
+
train_eval_loader = DataLoader(
|
| 497 |
+
dataset=train_eval_dataset,
|
| 498 |
+
batch_size=config.BATCH_SIZE,
|
| 499 |
+
num_workers=config.NUM_WORKERS,
|
| 500 |
+
pin_memory=config.PIN_MEMORY,
|
| 501 |
+
shuffle=False,
|
| 502 |
+
drop_last=False,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
return train_loader, test_loader, train_eval_loader
|
| 506 |
+
|
| 507 |
+
def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
|
| 508 |
+
model.eval()
|
| 509 |
+
x, y = next(iter(loader))
|
| 510 |
+
x = x.to("cuda")
|
| 511 |
+
with torch.no_grad():
|
| 512 |
+
out = model(x)
|
| 513 |
+
bboxes = [[] for _ in range(x.shape[0])]
|
| 514 |
+
for i in range(3):
|
| 515 |
+
batch_size, A, S, _, _ = out[i].shape
|
| 516 |
+
anchor = anchors[i]
|
| 517 |
+
boxes_scale_i = cells_to_bboxes(
|
| 518 |
+
out[i], anchor, S=S, is_preds=True
|
| 519 |
+
)
|
| 520 |
+
for idx, (box) in enumerate(boxes_scale_i):
|
| 521 |
+
bboxes[idx] += box
|
| 522 |
+
|
| 523 |
+
model.train()
|
| 524 |
+
|
| 525 |
+
for i in range(batch_size//4):
|
| 526 |
+
nms_boxes = non_max_suppression(
|
| 527 |
+
bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
|
| 528 |
+
)
|
| 529 |
+
plot_image(x[i].permute(1,2,0).detach().cpu(), nms_boxes)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def seed_everything(seed=42):
|
| 534 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 535 |
+
random.seed(seed)
|
| 536 |
+
np.random.seed(seed)
|
| 537 |
+
torch.manual_seed(seed)
|
| 538 |
+
torch.cuda.manual_seed(seed)
|
| 539 |
+
torch.cuda.manual_seed_all(seed)
|
| 540 |
+
torch.backends.cudnn.deterministic = True
|
| 541 |
+
torch.backends.cudnn.benchmark = False
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def clip_coords(boxes, img_shape):
|
| 545 |
+
# Clip bounding xyxy bounding boxes to image shape (height, width)
|
| 546 |
+
boxes[:, 0].clamp_(0, img_shape[1]) # x1
|
| 547 |
+
boxes[:, 1].clamp_(0, img_shape[0]) # y1
|
| 548 |
+
boxes[:, 2].clamp_(0, img_shape[1]) # x2
|
| 549 |
+
boxes[:, 3].clamp_(0, img_shape[0]) # y2
|
| 550 |
+
|
| 551 |
+
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
| 552 |
+
# Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
| 553 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 554 |
+
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
| 555 |
+
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
| 556 |
+
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
|
| 557 |
+
y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
|
| 558 |
+
return y
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
|
| 562 |
+
# Convert normalized segments into pixel segments, shape (n,2)
|
| 563 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 564 |
+
y[..., 0] = w * x[..., 0] + padw # top left x
|
| 565 |
+
y[..., 1] = h * x[..., 1] + padh # top left y
|
| 566 |
+
return y
|
| 567 |
+
|
| 568 |
+
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
| 569 |
+
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
|
| 570 |
+
if clip:
|
| 571 |
+
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
| 572 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 573 |
+
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
| 574 |
+
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
| 575 |
+
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
|
| 576 |
+
y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
|
| 577 |
+
return y
|
| 578 |
+
|
| 579 |
+
def clip_boxes(boxes, shape):
|
| 580 |
+
# Clip boxes (xyxy) to image shape (height, width)
|
| 581 |
+
if isinstance(boxes, torch.Tensor): # faster individually
|
| 582 |
+
boxes[..., 0].clamp_(0, shape[1]) # x1
|
| 583 |
+
boxes[..., 1].clamp_(0, shape[0]) # y1
|
| 584 |
+
boxes[..., 2].clamp_(0, shape[1]) # x2
|
| 585 |
+
boxes[..., 3].clamp_(0, shape[0]) # y2
|
| 586 |
+
else: # np.array (faster grouped)
|
| 587 |
+
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
|
| 588 |
+
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
|