yolov3 / utils.py
reputation's picture
Update utils.py
d39299a verified
import config
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import os
import random
import torch
from collections import Counter
from torch.utils.data import DataLoader
from tqdm import tqdm
def iou_width_height(boxes1, boxes2):
    """Return the IoU of boxes compared by width/height only.

    Both boxes are treated as if they shared the same top-left corner, so only
    the last dimension's ``[..., 0]`` (width) and ``[..., 1]`` (height) entries
    matter. Inputs must broadcast against each other (e.g. one box against a
    stack of anchors).
    """
    w1, h1 = boxes1[..., 0], boxes1[..., 1]
    w2, h2 = boxes2[..., 0], boxes2[..., 1]
    overlap = torch.min(w1, w2) * torch.min(h1, h2)
    combined = w1 * h1 + w2 * h2 - overlap
    return overlap / combined
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Compute the intersection-over-union of two broadcastable sets of boxes.

    Args:
        boxes_preds: tensor whose last dimension holds 4 box coordinates.
        boxes_labels: tensor in the same layout, broadcastable against
            ``boxes_preds``.
        box_format: ``"midpoint"`` for ``(x, y, w, h)`` with ``(x, y)`` the box
            centre, or ``"corners"`` for ``(x1, y1, x2, y2)``.

    Returns:
        Tensor of IoU values of shape ``(..., 1)`` (the coordinate dimension is
        kept with size 1 because slices like ``[..., 0:1]`` are used).

    Raises:
        ValueError: if ``box_format`` is neither ``"midpoint"`` nor ``"corners"``.
    """
    if box_format == "midpoint":
        # Convert centre/size to corner coordinates.
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        # Without this branch the corner variables below would be unbound.
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) makes disjoint boxes contribute zero intersection.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
    # Small epsilon guards against division by zero for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """Greedy per-class non-maximum suppression.

    Args:
        bboxes: list of boxes, each ``[class, score, *coords]`` where ``coords``
            are the 4 box coordinates in ``box_format``.
        iou_threshold: a box is suppressed when its IoU with a higher-scoring
            box of the same class is ``>=`` this value.
        threshold: boxes with ``score <= threshold`` are discarded up front.
        box_format: passed through to ``intersection_over_union``.

    Returns:
        List of the kept boxes, in descending score order.

    Raises:
        TypeError: if ``bboxes`` is not a list.
    """
    if not isinstance(bboxes, list):
        raise TypeError(f"bboxes must be a list, got {type(bboxes).__name__}")

    # sorted() is stable, so equal scores keep their input order.
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    kept = []
    while candidates:
        chosen_box = candidates.pop(0)
        # Build the chosen box's tensor once instead of once per comparison.
        chosen_coords = torch.tensor(chosen_box[2:])
        candidates = [
            box
            for box in candidates
            # Boxes of another class are never suppressed by this one.
            if box[0] != chosen_box[0]
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]
        kept.append(chosen_box)
    return kept
def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """Compute mean average precision over classes.

    Args:
        pred_boxes: list of predictions, each
            ``[img_idx, class, score, c0, c1, c2, c3]`` with the four
            coordinates in ``box_format``.
        true_boxes: list of ground truths in the same layout.
        iou_threshold: a detection counts as a match only if its best IoU with
            an unmatched ground truth in the same image is strictly greater
            than this.
        box_format: passed through to ``intersection_over_union``.
        num_classes: classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        0-dim tensor: AP averaged over classes that have at least one ground
        truth (AP is the trapezoidal area under the precision/recall curve).
        Returns ``tensor(0.)`` if no class has any ground truth.
    """
    average_precisions = []
    epsilon = 1e-6

    for c in range(num_classes):
        detections = [det for det in pred_boxes if det[1] == c]
        ground_truths = [gt for gt in true_boxes if gt[1] == c]
        total_true_bboxes = len(ground_truths)
        if total_true_bboxes == 0:
            continue

        # Group ground truths by image once, preserving their order, instead of
        # re-scanning every ground truth for every detection.
        gts_by_img = {}
        for gt in ground_truths:
            gts_by_img.setdefault(gt[0], []).append(gt)
        # One "already matched" flag per ground-truth box, so a single GT can
        # produce at most one true positive.
        matched = {img: torch.zeros(len(gts)) for img, gts in gts_by_img.items()}

        # Highest-confidence detections get first pick of the ground truths.
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            img_gts = gts_by_img.get(detection[0], [])
            det_coords = torch.tensor(detection[3:])
            best_iou = 0
            best_gt_idx = None
            for idx, gt in enumerate(img_gts):
                iou = intersection_over_union(
                    det_coords,
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if (
                best_gt_idx is not None
                and best_iou > iou_threshold
                and matched[detection[0]][best_gt_idx] == 0
            ):
                TP[detection_idx] = 1
                matched[detection[0]][best_gt_idx] = 1
            else:
                # No overlap, IoU too low, or the best GT was already claimed.
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        # Prepend the (recall=0, precision=1) start point of the PR curve.
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))
        # torch.trapz(y, x) integrates precision over recall.
        average_precisions.append(torch.trapz(precisions, recalls))

    if not average_precisions:
        return torch.tensor(0.0)
    return sum(average_precisions) / len(average_precisions)
def plot_image(image, boxes):
    """Show *image* with the given boxes and class labels drawn on it.

    Args:
        image: array-like convertible with ``np.array`` to shape
            ``(height, width, channels)``.
        boxes: iterable of 6-element boxes
            ``[class_pred, confidence, x, y, width, height]`` where ``(x, y)``
            is the box centre and all four geometry values are fractions of
            the image size (they are multiplied by the pixel width/height).

    The label set is ``config.COCO_LABELS`` when ``config.DATASET == 'COCO'``,
    otherwise ``config.PASCAL_CLASSES``; one ``tab20b`` colour per class.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = (
        config.COCO_LABELS if config.DATASET == 'COCO' else config.PASCAL_CLASSES
    )
    colors = [cmap(v) for v in np.linspace(0, 1, len(class_labels))]

    pixels = np.array(image)
    img_h, img_w, _ = pixels.shape
    _, ax = plt.subplots(1)
    ax.imshow(pixels)

    for entry in boxes:
        assert len(entry) == 6, "box should contain class pred, confidence, x, y, width, height"
        cls_idx = int(entry[0])
        cx, cy, bw, bh = entry[2:]
        # Rectangle anchors at the top-left corner, so shift from the centre.
        left = cx - bw / 2
        top = cy - bh / 2
        color = colors[cls_idx]
        ax.add_patch(
            patches.Rectangle(
                (left * img_w, top * img_h),
                bw * img_w,
                bh * img_h,
                linewidth=2,
                edgecolor=color,
                facecolor="none",
            )
        )
        plt.text(
            left * img_w,
            top * img_h,
            s=class_labels[cls_idx],
            color="white",
            verticalalignment="top",
            bbox={"color": color, "pad": 0},
        )

    plt.show()
def get_evaluation_bboxes(
    loader,
    model,
    iou_threshold,
    anchors,
    threshold,
    box_format="midpoint",
    device="cuda",
):
    """Run *model* over *loader* and collect predicted and ground-truth boxes.

    Predictions from the three output scales are decoded with
    ``cells_to_bboxes`` and filtered per image with ``non_max_suppression``.
    Ground truth is decoded from ``labels[2]`` only.

    Args:
        loader: iterable of ``(images, labels)`` batches; ``labels`` is indexed
            per scale.
        model: returns a sequence of per-scale prediction tensors.
        iou_threshold: IoU threshold for NMS.
        anchors: per-scale anchor ``(w, h)`` pairs; each is multiplied by that
            scale's grid size ``S`` before decoding.
        threshold: score threshold for NMS, also used to keep only ground-truth
            cells whose objectness exceeds it.
        box_format: box format passed to NMS.
        device: device the images and anchors are moved to.

    Returns:
        ``(all_pred_boxes, all_true_boxes)`` — lists of
        ``[img_idx, class, score, x, y, w, h]`` where ``img_idx`` is a running
        index over every image seen in the loader.

    The model is switched to eval mode during the pass and always set back to
    train mode afterwards, even if an exception is raised.
    """
    model.eval()
    train_idx = 0
    all_pred_boxes = []
    all_true_boxes = []
    try:
        for x, labels in tqdm(loader):
            x = x.to(device)
            with torch.no_grad():
                predictions = model(x)

            batch_size = x.shape[0]
            bboxes = [[] for _ in range(batch_size)]
            for i in range(3):
                S = predictions[i].shape[2]
                anchor = torch.tensor([*anchors[i]]).to(device) * S
                boxes_scale_i = cells_to_bboxes(
                    predictions[i], anchor, S=S, is_preds=True
                )
                for idx, box in enumerate(boxes_scale_i):
                    bboxes[idx] += box

            # Decode ground truth from labels[2] with its own grid size and
            # anchors, instead of relying on S/anchor leaking out of the loop.
            gt_S = labels[2].shape[2]
            gt_anchor = torch.tensor([*anchors[2]]).to(device) * gt_S
            true_bboxes = cells_to_bboxes(
                labels[2], gt_anchor, S=gt_S, is_preds=False
            )

            for idx in range(batch_size):
                nms_boxes = non_max_suppression(
                    bboxes[idx],
                    iou_threshold=iou_threshold,
                    threshold=threshold,
                    box_format=box_format,
                )
                for nms_box in nms_boxes:
                    all_pred_boxes.append([train_idx] + nms_box)
                for box in true_bboxes[idx]:
                    if box[1] > threshold:
                        all_true_boxes.append([train_idx] + box)
                train_idx += 1
    finally:
        model.train()
    return all_pred_boxes, all_true_boxes
def cells_to_bboxes(predictions, anchors, S, is_preds=True):
    """Convert a per-cell YOLO tensor into a flat list of boxes per image.

    Args:
        predictions: tensor of shape ``(batch, num_anchors, S, S, C)``.
            Channel 0 is the objectness (a logit when ``is_preds``), channels
            1:5 are ``x, y, w, h`` relative to the cell, and the rest are
            class logits (``is_preds=True``) or a single class index at
            channel 5 (``is_preds=False``). The tensor is not modified.
        anchors: tensor reshapeable to ``(num_anchors, 2)`` holding anchor
            width/height in grid-cell units; only used when ``is_preds``.
        S: grid size of this scale.
        is_preds: ``True`` to decode raw network output (sigmoid/exp/argmax),
            ``False`` to read target values as-is.

    Returns:
        Nested list ``[batch][num_anchors * S * S]`` of
        ``[class, score, x, y, w, h]`` with ``(x, y)`` the box centre and all
        geometry as fractions of the image size.
    """
    batch_size = predictions.shape[0]
    num_anchors = predictions.shape[1]
    # Clone: predictions[..., 1:5] is a view, and the in-place sigmoid/exp
    # below must not modify the caller's tensor.
    box_predictions = predictions[..., 1:5].clone()
    if is_preds:
        anchors = anchors.reshape(1, num_anchors, 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    # cell_indices[..., r, c, 0] == c (column); permuting the two grid axes
    # gives the row index for the y offset.
    cell_indices = (
        torch.arange(S, device=predictions.device)
        .repeat(batch_size, num_anchors, S, 1)
        .unsqueeze(-1)
    )
    x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / S * box_predictions[..., 2:4]
    converted_bboxes = torch.cat(
        (best_class.to(scores.dtype), scores, x, y, w_h), dim=-1
    ).reshape(batch_size, num_anchors * S * S, 6)
    return converted_bboxes.tolist()
def check_class_accuracy(model, loader, threshold):
    """Print class, no-object and object accuracy of *model* over *loader*.

    For each of the three scales:
    - class accuracy: argmax of class logits vs target class, over cells whose
      target objectness is 1;
    - no-object accuracy: fraction of target-0 cells whose
      ``sigmoid(objectness) > threshold`` is False;
    - object accuracy: fraction of target-1 cells whose prediction is True.
    Cells whose target objectness is neither 0 nor 1 are not counted.

    The model is put in eval mode for the pass and always set back to train
    mode afterwards, even if an exception is raised.
    """
    model.eval()
    tot_class_preds, correct_class = 0, 0
    tot_noobj, correct_noobj = 0, 0
    tot_obj, correct_obj = 0, 0
    try:
        for x, y in tqdm(loader):
            x = x.to(config.DEVICE)
            with torch.no_grad():
                out = model(x)

            for i in range(3):
                y[i] = y[i].to(config.DEVICE)
                obj = y[i][..., 0] == 1
                noobj = y[i][..., 0] == 0

                correct_class += torch.sum(
                    torch.argmax(out[i][..., 5:][obj], dim=-1) == y[i][..., 5][obj]
                )
                tot_class_preds += torch.sum(obj)

                obj_preds = torch.sigmoid(out[i][..., 0]) > threshold
                correct_obj += torch.sum(obj_preds[obj] == y[i][..., 0][obj])
                tot_obj += torch.sum(obj)
                correct_noobj += torch.sum(obj_preds[noobj] == y[i][..., 0][noobj])
                tot_noobj += torch.sum(noobj)

        # `.2f` = two decimals (the previous `2f` meant width 2, six decimals).
        print(f"Class accuracy is: {(correct_class / (tot_class_preds + 1e-16)) * 100:.2f}%")
        print(f"No obj accuracy is: {(correct_noobj / (tot_noobj + 1e-16)) * 100:.2f}%")
        print(f"Obj accuracy is: {(correct_obj / (tot_obj + 1e-16)) * 100:.2f}%")
    finally:
        model.train()
def get_mean_std(loader):
    """Compute per-channel mean and standard deviation over all data in *loader*.

    ``loader`` yields ``(data, target)`` pairs where ``data`` is a 4-D tensor;
    statistics are reduced over dims 0, 2 and 3, i.e. everything but dim 1
    (the channel dim for an NCHW batch).

    Every value is weighted equally, so a short final batch no longer skews
    the result as averaging per-batch means did. Sums are accumulated in
    float64 and the results are cast back to the input's floating dtype.

    Returns:
        ``(mean, std)`` — 1-D tensors with one entry per channel.

    Raises:
        ValueError: if the loader yields no data.
    """
    channels_sum, channels_sqrd_sum, num_values = 0, 0, 0
    result_dtype = None
    for data, _ in tqdm(loader):
        result_dtype = (
            data.dtype if data.is_floating_point() else torch.get_default_dtype()
        )
        channels_sum += torch.sum(data, dim=[0, 2, 3], dtype=torch.float64)
        channels_sqrd_sum += torch.sum(data ** 2, dim=[0, 2, 3], dtype=torch.float64)
        # Number of values that contributed to each channel's sum.
        num_values += data.numel() // data.shape[1]

    if num_values == 0:
        raise ValueError("get_mean_std: loader produced no data")

    mean = channels_sum / num_values
    # Var = E[X^2] - E[X]^2; clamp tiny negative round-off before the sqrt.
    std = (channels_sqrd_sum / num_values - mean ** 2).clamp(min=0) ** 0.5
    return mean.to(result_dtype), std.to(result_dtype)
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write model and optimizer state to *filename* with ``torch.save``.

    The saved object is a dict with keys ``"state_dict"`` (the model's
    ``state_dict()``) and ``"optimizer"`` (the optimizer's ``state_dict()``).
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer from *checkpoint_file*, then set every param group's lr to *lr*.

    Expects the dict layout written by ``save_checkpoint``
    (``"state_dict"`` and ``"optimizer"`` keys). Tensors are mapped onto
    ``config.DEVICE``.

    NOTE(review): ``torch.load`` is pickle-based; depending on the installed
    PyTorch version it may execute code embedded in the file. Only load
    checkpoints from trusted sources.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])
    # The restored optimizer state carries its own learning rate; override it
    # so the caller's lr wins.
    for group in optimizer.param_groups:
        group["lr"] = lr
def get_loaders(train_csv_path, test_csv_path):
    """Build the train, test and train-evaluation DataLoaders.

    Returns:
        ``(train_loader, test_loader, train_eval_loader)`` where
        - ``train_loader`` reads ``train_csv_path`` with
          ``config.train_transforms`` and shuffles;
        - ``test_loader`` reads ``test_csv_path`` with
          ``config.test_transforms`` and does not shuffle;
        - ``train_eval_loader`` reads ``train_csv_path`` with
          ``config.test_transforms`` and does not shuffle.
        All three take image/label dirs, anchors, batch size, worker count and
        pin_memory from ``config`` and keep the last partial batch.
    """
    # Kept as a local import like the original (presumably to avoid an import
    # cycle with the dataset module — TODO confirm).
    from dataset import YOLODataset

    IMAGE_SIZE = config.IMAGE_SIZE

    def build_dataset(csv_path, transform):
        # Grid sizes IMAGE_SIZE/32, /16 and /8, one per output scale.
        return YOLODataset(
            csv_path,
            transform=transform,
            S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
            img_dir=config.IMG_DIR,
            label_dir=config.LABEL_DIR,
            anchors=config.ANCHORS,
        )

    def build_loader(dataset, shuffle):
        return DataLoader(
            dataset=dataset,
            batch_size=config.BATCH_SIZE,
            num_workers=config.NUM_WORKERS,
            pin_memory=config.PIN_MEMORY,
            shuffle=shuffle,
            drop_last=False,
        )

    train_dataset = build_dataset(train_csv_path, config.train_transforms)
    test_dataset = build_dataset(test_csv_path, config.test_transforms)
    train_loader = build_loader(train_dataset, shuffle=True)
    test_loader = build_loader(test_dataset, shuffle=False)
    train_eval_dataset = build_dataset(train_csv_path, config.test_transforms)
    train_eval_loader = build_loader(train_eval_dataset, shuffle=False)
    return train_loader, test_loader, train_eval_loader
def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
    """Plot NMS-filtered predictions for every image in the first batch of *loader*.

    Args:
        model: returns a sequence of 5-D per-scale prediction tensors.
        loader: iterable of ``(images, targets)`` batches; only the first
            batch is used and the targets are ignored.
        thresh: score threshold for NMS.
        iou_thresh: IoU threshold for NMS.
        anchors: per-scale anchors passed unchanged to ``cells_to_bboxes``
            (so each ``anchors[i]`` presumably must already be a tensor in
            grid-cell units — TODO confirm against callers).

    The model is in eval mode for the forward pass and is switched back to
    train mode before plotting.
    """
    model.eval()
    images, _ = next(iter(loader))
    images = images.to(config.DEVICE)
    with torch.no_grad():
        outputs = model(images)

    per_image = [[] for _ in range(images.shape[0])]
    for scale in range(3):
        batch_size, _, grid, _, _ = outputs[scale].shape
        decoded = cells_to_bboxes(
            outputs[scale], anchors[scale], S=grid, is_preds=True
        )
        for img_idx, scale_boxes in enumerate(decoded):
            per_image[img_idx] += scale_boxes

    model.train()

    for img_idx in range(batch_size):
        kept = non_max_suppression(
            per_image[img_idx],
            iou_threshold=iou_thresh,
            threshold=thresh,
            box_format="midpoint",
        )
        # CHW tensor -> HWC on CPU for matplotlib.
        plot_image(images[img_idx].permute(1, 2, 0).detach().cpu(), kept)
def seed_everything(seed=42):
    """Seed every RNG this project relies on and make cuDNN deterministic.

    Sets ``PYTHONHASHSEED``, then seeds Python's ``random``, NumPy, PyTorch's
    CPU generator and the CUDA generators (current device and all devices).
    Finally forces deterministic cuDNN kernels and disables cuDNN autotuning.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False