# yolo1_from_scratch / utils.py
# Uploaded by nathbns ("Upload 6 files", commit 6698bc8, verified), 11.5 kB
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from collections import Counter
def _box_corners(boxes, box_format):
    """Return (x1, y1, x2, y2) slices of `boxes`, each keeping a trailing dim of 1."""
    if box_format == "midpoint":
        # Boxes are stored as (x_center, y_center, width, height).
        x1 = boxes[..., 0:1] - boxes[..., 2:3] / 2
        y1 = boxes[..., 1:2] - boxes[..., 3:4] / 2
        x2 = boxes[..., 0:1] + boxes[..., 2:3] / 2
        y2 = boxes[..., 1:2] + boxes[..., 3:4] / 2
        return x1, y1, x2, y2
    if box_format == "corners":
        # Boxes are stored as (x1, y1, x2, y2) already.
        return (
            boxes[..., 0:1],
            boxes[..., 1:2],
            boxes[..., 2:3],
            boxes[..., 3:4],
        )
    raise ValueError(
        f"Unsupported box_format {box_format!r}; expected 'midpoint' or 'corners'"
    )


def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Calculates intersection over union between two sets of boxes

    Parameters:
        boxes_preds (tensor): predicted boxes; the first four entries of the
            last dimension are the box coordinates
        boxes_labels (tensor): target boxes, laid out like boxes_preds
        box_format (str): "midpoint" for (x, y, w, h) or "corners" for
            (x1, y1, x2, y2)

    Returns:
        tensor: IoU values with a trailing dimension of size 1

    Raises:
        ValueError: if box_format is neither "midpoint" nor "corners"
    """
    box1_x1, box1_y1, box1_x2, box1_y2 = _box_corners(boxes_preds, box_format)
    box2_x1, box2_y1, box2_x2, box2_y2 = _box_corners(boxes_labels, box_format)

    # Corners of the overlapping rectangle (if any).
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # .clamp(0) is for the case when they do not intersect
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # The small constant keeps the division defined when both areas are zero.
    return intersection / (box1_area + box2_area - intersection + 1e-6)
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Does Non Max Suppression given bboxes

    Parameters:
        bboxes (list): list of lists containing all bboxes with each bboxes
            specified as [class_pred, prob_score, x1, y1, x2, y2]
        iou_threshold (float): a lower-scoring box of the same class is dropped
            when its IoU with an already kept box is not below this value
        threshold (float): boxes whose prob_score is not above this value are
            removed up front (independent of IoU)
        box_format (str): "midpoint" or "corners" used to specify bboxes

    Returns:
        list: bboxes after performing NMS given a specific IoU threshold,
            ordered by decreasing prob_score

    Raises:
        TypeError: if bboxes is not a list
    """
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if not isinstance(bboxes, list):
        raise TypeError(f"bboxes must be a list, got {type(bboxes).__name__}")

    candidates = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    bboxes_after_nms = []

    while candidates:
        # The highest-scoring remaining box is always kept.
        chosen_box, *remaining = candidates
        chosen_coords = torch.tensor(chosen_box[2:])

        # Keep boxes of other classes, and same-class boxes that overlap the
        # chosen one by less than iou_threshold.
        candidates = [
            box
            for box in remaining
            if box[0] != chosen_box[0]
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]
        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms
def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """
    Calculates mean average precision

    Parameters:
        pred_boxes (list): list of lists containing all bboxes with each bboxes
            specified as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
        true_boxes (list): Similar as pred_boxes except all the correct ones
        iou_threshold (float): a prediction counts as correct when its best
            IoU with a ground truth box of the same image is above this value
        box_format (str): "midpoint" or "corners" used to specify bboxes
        num_classes (int): number of classes

    Returns:
        tensor: mAP value across all classes that have at least one ground
            truth box, given a specific IoU threshold. A zero tensor is
            returned when no class has any ground truth box.
    """
    # list storing all AP for respective classes
    average_precisions = []

    # used for numerical stability later on
    epsilon = 1e-6

    for c in range(num_classes):
        # Only the ground truths and predictions of the current class matter.
        ground_truths = [box for box in true_boxes if box[1] == c]

        # If none exists for this class then we can safely skip
        if not ground_truths:
            continue
        total_true_bboxes = len(ground_truths)

        # sort by box probabilities which is index 2
        detections = sorted(
            (box for box in pred_boxes if box[1] == c),
            key=lambda box: box[2],
            reverse=True,
        )

        # Group the ground truths by training example once, and keep one
        # "already matched" flag per ground truth box, e.g. for img 0 with
        # 3 boxes and img 1 with 5: {0: [False] * 3, 1: [False] * 5}
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)
        already_matched = {
            image: [False] * len(gts) for image, gts in gts_by_image.items()
        }

        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            # Only compare against ground truths of the same training idx
            image_gts = gts_by_image.get(detection[0], [])

            best_iou = 0
            # Reset per detection so an index from an earlier one is never reused.
            best_gt_idx = None

            for idx, gt in enumerate(image_gts):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            # True positive: overlaps enough with a ground truth box that no
            # higher-scoring detection has claimed yet. Everything else
            # (low IoU, duplicate match, no ground truth in the image) is a
            # false positive.
            if (
                best_gt_idx is not None
                and best_iou > iou_threshold
                and not already_matched[detection[0]][best_gt_idx]
            ):
                TP[detection_idx] = 1
                already_matched[detection[0]][best_gt_idx] = True
            else:
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = torch.divide(TP_cumsum, (TP_cumsum + FP_cumsum + epsilon))

        # Prepend the (recall=0, precision=1) starting point of the curve.
        precisions = torch.cat((torch.ones(1), precisions))
        recalls = torch.cat((torch.zeros(1), recalls))

        # torch.trapz for numerical integration
        average_precisions.append(torch.trapz(precisions, recalls))

    # Without any ground truth there is nothing to average over.
    if not average_precisions:
        return torch.tensor(0.0)

    return sum(average_precisions) / len(average_precisions)
def plot_image(image, boxes):
    """
    Draws bounding boxes on top of an image and shows the figure

    Parameters:
        image: anything np.array() turns into a (height, width, channels) array
        boxes (iterable): each entry has two leading values followed by
            x midpoint, y midpoint, width and height; the box values are
            multiplied by the image width/height before drawing
    """
    pixels = np.array(image)
    img_height, img_width, _ = pixels.shape

    # One figure with a single axes showing the image
    _, axis = plt.subplots(1)
    axis.imshow(pixels)

    for entry in boxes:
        # Skip the two leading values; the rest must be x, y, w, h
        coords = entry[2:]
        assert len(coords) == 4, "Got more values than in x, y, w, h, in a box!"
        x_mid, y_mid, box_w, box_h = coords

        # Rectangle patches are anchored at a corner, not the midpoint
        left = x_mid - box_w / 2
        top = y_mid - box_h / 2

        outline = patches.Rectangle(
            (left * img_width, top * img_height),
            box_w * img_width,
            box_h * img_height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        axis.add_patch(outline)

    plt.show()
def get_bboxes(
    loader,
    model,
    iou_threshold,
    threshold,
    pred_format="cells",
    box_format="midpoint",
    device="cuda",
):
    """
    Runs the model over a loader and collects predicted and target boxes

    Parameters:
        loader (iterable): yields (x, labels) batches
        model: called as model(x); switched to eval mode while collecting
        iou_threshold (float): IoU threshold forwarded to non_max_suppression
        threshold (float): score threshold forwarded to non_max_suppression
            and also used to keep only target boxes whose score exceeds it
        pred_format (str): accepted for compatibility; not used by this function
        box_format (str): "midpoint" or "corners", forwarded to
            non_max_suppression
        device (str): device the batches are moved to

    Returns:
        tuple: (all_pred_boxes, all_true_boxes), each a list of boxes prefixed
            with a running image index:
            [train_idx, class_pred, prob_score, x, y, w, h]
    """
    all_pred_boxes = []
    all_true_boxes = []

    # Remember the caller's mode so it can be restored afterwards instead of
    # unconditionally forcing train mode.
    was_training = getattr(model, "training", True)

    # make sure model is in eval before get bboxes
    model.eval()
    train_idx = 0

    try:
        for x, labels in loader:
            x = x.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                predictions = model(x)

            true_bboxes = cellboxes_to_boxes(labels)
            bboxes = cellboxes_to_boxes(predictions)

            for image_preds, image_truths in zip(bboxes, true_bboxes):
                nms_boxes = non_max_suppression(
                    image_preds,
                    iou_threshold=iou_threshold,
                    threshold=threshold,
                    box_format=box_format,
                )

                for nms_box in nms_boxes:
                    all_pred_boxes.append([train_idx] + nms_box)

                for box in image_truths:
                    # Cells without an object have a score of 0 and are
                    # filtered out here.
                    if box[1] > threshold:
                        all_true_boxes.append([train_idx] + box)

                # One index per image, shared by its predictions and targets.
                train_idx += 1
    finally:
        # Restore train mode only if the model was training on entry.
        if was_training:
            model.train()

    return all_pred_boxes, all_true_boxes
def convert_cellboxes(predictions, S=7):
    """
    Converts bounding boxes output from Yolo with an image split size of S
    into entire image ratios rather than relative to cell ratios.

    The last dimension is read as 30 values per cell: indices 0-19 are class
    scores, 20 and 25 are the confidences of the two candidate boxes, and
    21-24 / 26-29 are their (x, y, w, h).

    Parameters:
        predictions (tensor): model output that can be reshaped to
            (batch_size, S, S, 30)
        S (int): number of grid cells per image side

    Returns:
        tensor: shape (batch_size, S, S, 6) holding
            [class index, confidence, x, y, w, h] for the higher-confidence
            box of every cell, with x, y, w, h divided by S
    """
    predictions = predictions.to("cpu")
    batch_size = predictions.shape[0]
    predictions = predictions.reshape(batch_size, S, S, 30)

    bboxes1 = predictions[..., 21:25]
    bboxes2 = predictions[..., 26:30]
    scores = torch.cat(
        (predictions[..., 20].unsqueeze(0), predictions[..., 25].unsqueeze(0)), dim=0
    )

    # 0 where the first box is more confident, 1 where the second one is;
    # used as a mask to pick one box per cell without indexing.
    best_box = scores.argmax(0).unsqueeze(-1)
    best_boxes = bboxes1 * (1 - best_box) + best_box * bboxes2

    # Column index of every cell; permuting the two grid axes gives the row index.
    cell_indices = torch.arange(S).repeat(batch_size, S, 1).unsqueeze(-1)
    x = 1 / S * (best_boxes[..., :1] + cell_indices)
    y = 1 / S * (best_boxes[..., 1:2] + cell_indices.permute(0, 2, 1, 3))
    w_h = 1 / S * best_boxes[..., 2:4]
    converted_bboxes = torch.cat((x, y, w_h), dim=-1)

    predicted_class = predictions[..., :20].argmax(-1).unsqueeze(-1)
    best_confidence = torch.max(predictions[..., 20], predictions[..., 25]).unsqueeze(
        -1
    )
    converted_preds = torch.cat(
        (predicted_class, best_confidence, converted_bboxes), dim=-1
    )

    return converted_preds
def cellboxes_to_boxes(out, S=7):
    """
    Converts a batch of cell-based outputs into plain Python box lists

    Parameters:
        out (tensor): batch of cell outputs; out.shape[0] is the batch size
        S (int): number of grid cells per image side, forwarded to
            convert_cellboxes

    Returns:
        list: one entry per example, each a list of S * S boxes, each box a
            list of Python floats as produced by convert_cellboxes
    """
    batch_size = out.shape[0]

    # Flatten the S x S grid so every example is a flat list of boxes.
    converted_pred = convert_cellboxes(out, S=S).reshape(batch_size, S * S, -1)

    # The class column comes from an argmax, so it already holds whole
    # numbers; no extra integer round-trip is needed before exporting.
    return converted_pred.tolist()
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """
    Announces the save on stdout and writes `state` with torch.save

    Parameters:
        state: object to serialize
        filename: destination handed to torch.save
    """
    announcement = "=> Saving checkpoint"
    print(announcement)
    torch.save(state, filename)
def load_checkpoint(checkpoint, model, optimizer):
    """
    Announces the load on stdout and restores model and optimizer state

    Parameters:
        checkpoint (dict): holds the model state under "state_dict" and the
            optimizer state under "optimizer"
        model: object whose load_state_dict receives checkpoint["state_dict"]
        optimizer: object whose load_state_dict receives checkpoint["optimizer"]
    """
    print("=> Loading checkpoint")
    # Model first, then optimizer.
    restore_plan = ((model, "state_dict"), (optimizer, "optimizer"))
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])