Spaces:
Build error
Build error
| import numpy as np | |
| import torchvision.transforms.functional as F | |
| from torchvision import transforms | |
| from typing import DefaultDict | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import torch | |
| import logging | |
| from torchvision.utils import draw_bounding_boxes | |
# Import-time configuration: use the ggplot look for every plot produced here.
matplotlib.style.use('ggplot')
# Silence everything below CRITICAL from these two chatty libraries.
for _quiet_logger in ('matplotlib', 'PIL'):
    logging.getLogger(_quiet_logger).setLevel(logging.CRITICAL)
del _quiet_logger  # keep the module namespace unchanged
def save_plot(train_loss_list, label, output_dir):
    """
    Function to save the loss plot to disk.

    Plots ``train_loss_list`` against its index (x-axis labelled "Epochs",
    y-axis "Loss"), uses ``label`` as the legend entry, and writes the figure
    to ``f"{output_dir}/{label}.png"``.

    The figure is always closed afterwards, even if saving fails.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        ax.plot(train_loss_list, linestyle='-', label=label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.legend()
        fig.savefig(f"{output_dir}/{label}.png")
    finally:
        # pyplot holds a reference to every figure it creates until it is
        # closed. Without this, each call (one per loss key) leaks a figure.
        plt.close(fig)
def save_train_loss_plot(train_loss_dict: DefaultDict, output_dir):
    """
    Function to save the loss plots to disk.

    Produces one plot per entry of ``train_loss_dict`` by delegating to
    ``save_plot``. Each key serves as both the legend label and the file name
    (inside ``output_dir``), and each value is the sequence of losses to plot.
    """
    for loss_name, loss_values in train_loss_dict.items():
        save_plot(loss_values, loss_name, output_dir)
def show(imgs):
    """Display one or more image tensors stacked vertically in one figure.

    ``imgs`` may be a single tensor or a list of tensors. Each tensor is
    detached, converted with ``F.to_pil_image`` and drawn in its own row, with
    all axis ticks and tick labels removed. The window is opened with
    ``plt.show()`` once all rows are drawn.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, grid = plt.subplots(nrows=len(images), ncols=1,
                           figsize=(45, 21), squeeze=False)
    # squeeze=False guarantees a 2-D axes array, so column 0 holds one axis per row.
    for axis, tensor in zip(grid[:, 0], images):
        pixels = np.asarray(F.to_pil_image(tensor.detach()))
        axis.imshow(pixels)
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()
def plot_img_tensor(img_tensor):
    """Open ``img_tensor`` in an image viewer.

    The tensor is converted with ``transforms.ToPILImage`` and the resulting
    PIL image's ``show()`` method is called.
    """
    to_pil = transforms.ToPILImage()
    pil_image = to_pil(img_tensor)
    pil_image.show()
def show_img(data_loader, model, device, th=0.7):
    """Visualise a prediction next to its ground truth for one image.

    Takes only the first batch yielded by ``data_loader`` and its first image.
    It runs ``model`` on that image (moved to ``device``) and opens two image
    windows via ``plot_img_tensor``:

    * the predicted boxes, filtered by ``add_bbox`` with score threshold ``th``;
    * the ground-truth boxes from ``target[0]['boxes']``.

    Does nothing if ``data_loader`` yields no batches.

    ``model`` is presumably a torchvision-style detection model (list of image
    tensors in, list of dicts with "boxes"/"scores"/"labels" out) — TODO
    confirm. If it is an ``nn.Module``, it is put in eval mode for the forward
    pass. Each submodule's previous train/eval mode is restored afterwards, so
    this can be called mid-training.
    """
    batch = next(iter(data_loader), None)
    if batch is None:
        return  # empty loader: nothing to show
    imgs, target = batch

    # torchvision detection models raise if called without targets while in
    # training mode, and dropout/batch-norm should act as at inference time.
    # Record every submodule's mode so it can be restored exactly.
    saved_modes = (
        [(module, module.training) for module in model.modules()]
        if isinstance(model, torch.nn.Module)
        else []
    )
    if saved_modes:
        model.eval()
    try:
        with torch.no_grad():
            prediction = model([imgs[0].to(device)])[0]
    finally:
        for module, mode in saved_modes:
            module.training = mode

    plot_img_tensor(add_bbox(imgs[0], prediction, th))
    plot_img_tensor(add_bbox(imgs[0], target[0]['boxes']))
def add_bbox(img, output, th=None):
    """Return a uint8 copy of ``img`` with bounding boxes drawn on it.

    Args:
        img: Image tensor. It is multiplied by 255 and clipped to [0, 255]
            before the uint8 cast, so it presumably holds floats in [0, 1]
            (e.g. a (C, H, W) tensor) — TODO confirm against callers.
            ``img`` itself is not modified.
        output: If ``th`` is None, a boxes tensor passed straight to
            ``draw_bounding_boxes`` (no captions). Otherwise a prediction
            dict with "boxes", "scores" and "labels" tensors.
        th: Optional score threshold. When given, only predictions with
            ``score > th`` and ``label == 1`` are drawn. Each is captioned
            ``"person:<score with 3 decimals>"``.

    Returns:
        The uint8 image tensor returned by ``draw_bounding_boxes``.
    """
    # Clip before casting so out-of-range values saturate instead of wrapping.
    # torch.clip returns a new tensor, so no explicit clone is needed.
    img_canvas = torch.clip(img * 255, 0, 255).type(torch.uint8)

    if th is None:
        return draw_bounding_boxes(img_canvas, boxes=output, width=4)

    # Keep confident detections of class 1, the only class this helper names.
    mask = (output["scores"] > th) & (output["labels"] == 1)
    captions = [f"person:{score:.3f}" for score in output["scores"][mask].tolist()]
    return draw_bounding_boxes(
        img_canvas, boxes=output["boxes"][mask], labels=captions, width=4)