Spaces:
Sleeping
Sleeping
| import torch | |
| from plaus_functs import get_center_coords, get_distance_grids, get_plaus_loss, get_bbox_map, normalize_batch | |
| from plot_functs import imshow | |
| from torchvision.transforms.functional import gaussian_blur | |
| import argparse | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
| import cv2 | |
def subfigimshow(img, ax):
    """Draw ``img`` on the matplotlib Axes ``ax``, inferring the channel layout from its shape.

    Args:
        img: a torch tensor or a numpy array. Accepted shapes are (H, W),
            (C, H, W) with C in {1, 3}, and (H, W, C) for any other leading
            dimension. Single-channel images are drawn with the ``gray`` colormap.
        ax: a matplotlib Axes (anything exposing an ``imshow`` method).

    Raises:
        ValueError: if ``img`` is neither 2-D nor 3-D.
    """
    print(f'img shape: {img.shape}')
    try:
        # torch.Tensor path: copy it off the autograd graph and onto the host.
        npimg = img.clone().detach().cpu().numpy()
    except AttributeError:
        # No tensor API (e.g. already a numpy array): use the input as-is.
        npimg = img
    if len(npimg.shape) == 2:
        # A 2D array is treated as a grayscale image.
        ax.imshow(npimg, cmap='gray')
    elif len(npimg.shape) == 3:
        if npimg.shape[0] == 3 or npimg.shape[0] == 1:
            # Leading dimension of 3 or 1 is taken to mean (C, H, W); matplotlib wants (H, W, C).
            tpimg = np.transpose(npimg, (1, 2, 0))
        else:
            # Otherwise assume the array is already (H, W, C).
            tpimg = npimg
        if tpimg.shape[2] == 1:
            # matplotlib cannot show an (H, W, 1) array directly; drop the channel axis.
            ax.imshow(np.squeeze(tpimg), cmap='gray')
        else:
            ax.imshow(tpimg)
    else:
        raise ValueError(f"Unexpected image shape: {npimg.shape}")
def draw_bounding_boxes(image, boxes, color=(0, 255, 0), thickness=2):
    """Return a uint8 copy of ``image`` with rectangular box outlines drawn on it.

    Args:
        image: array of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4). Non-uint8
            input is treated as values in [0, 1]: it is scaled by 255 and clipped.
        boxes: iterable of (x_center, y_center, width, height) tuples given as
            fractions of the image width/height.
        color: RGB colour of the outline. For 4-channel (RGBA) images an opaque
            alpha of 255 is appended automatically.
        thickness: outline thickness in pixels, forwarded to ``cv2.rectangle``.

    Returns:
        A new uint8 array. ``image`` itself is not modified.
    """
    # Promote grayscale input to 3 channels so a coloured outline can be drawn.
    if len(image.shape) == 2:
        image = np.stack([image] * 3, axis=-1)
    elif len(image.shape) == 3 and image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    # cv2 drawing expects uint8 pixel values in [0, 255].
    if image.dtype != np.uint8:
        image = (image * 255).clip(0, 255).astype(np.uint8)
    image_with_boxes = image.copy()
    if image_with_boxes.shape[2] == 4 and len(color) == 3:
        # cv2 fills a missing 4th colour component with 0, which would give the
        # outline alpha=0 (invisible) on RGBA images such as plt.cm.* outputs.
        color = (*color, 255)
    height_px, width_px = image_with_boxes.shape[:2]
    for box in boxes:
        x_center, y_center, width, height = box
        # Convert normalised centre/size to pixel corner coordinates.
        x_min = int((x_center - width / 2) * width_px)
        y_min = int((y_center - height / 2) * height_px)
        x_max = int((x_center + width / 2) * width_px)
        y_max = int((y_center + height / 2) * height_px)
        cv2.rectangle(image_with_boxes, (x_min, y_min), (x_max, y_max), color, thickness)
    return image_with_boxes
def _colormap_panel(img_tensor):
    """Convert one image tensor (singleton dims squeezed away) to an RGBA array via viridis."""
    img_np = img_tensor.detach().cpu().numpy().squeeze()
    return plt.cm.viridis(img_np)


def _show_panel(ax, img_tensor, title, boxes=None):
    """Render ``img_tensor`` on ``ax`` with ``title``, overlaying normalised ``boxes`` when given."""
    ax.set_title(title)
    img_colored = _colormap_panel(img_tensor)
    if boxes is not None:
        img_colored = draw_bounding_boxes(img_colored, boxes)
    subfigimshow(img_colored, ax)
    ax.axis('off')


def toy_problem(pgt_coeff, focus_coeff, x_coord, y_coord, num_bb=0, alpha=200.0, scheduler=2.0, device="0", dist_coeff=0.5, dist_reg_only=True, iou_coeff=0.5,
                bbox_coeff=0.0, dist_x_bbox=False, iou_loss_only=False, show_dist_reg=True):
    """Optimise a random attribution map against the plausibility loss and plot the result.

    A smooth random 640x640 attribution map is created for each image in ``targets``.
    It is then updated by ``nsamples`` gradient-descent steps on ``get_plaus_loss``.
    The step size starts at ``alpha`` and is multiplied by ``scheduler`` after every step.
    After each step the map is passed through ``normalize_batch``.

    Args:
        pgt_coeff, focus_coeff, num_bb, dist_coeff, dist_reg_only, iou_coeff,
        bbox_coeff, dist_x_bbox: stored on the ``opt`` Namespace passed to
            ``get_plaus_loss``; their meaning is defined there. ``num_bb`` is not
            otherwise used here yet (see TODO below).
        x_coord, y_coord: centre of the single target box, as fractions of the
            image width/height (the box is 0.05 x 0.05).
        alpha: initial gradient step size.
        scheduler: factor applied to the step size after each step.
        device: GPU index written to ``CUDA_VISIBLE_DEVICES``.
        iou_loss_only: if True, the logged loss/score are replaced by
            ``1 - sum(attr * bbox_map) / sum(attr)`` and that score, using
            ``get_bbox_map``. The gradient step itself still uses ``get_plaus_loss``.
        show_dist_reg: if True, draw the (inverted) distance-regularisation map of
            step 0 in the first column of the first figure.

    Returns:
        Tuple of the four saved figure paths (under ``figs/``).
    """
    # Bundle the parameters into a Namespace; get_plaus_loss receives it as ``opt``.
    opt = argparse.Namespace(
        pgt_coeff=pgt_coeff, focus_coeff=focus_coeff, x_coord=x_coord, y_coord=y_coord,
        num_bb=num_bb, alpha=alpha, scheduler=scheduler, device=device, dist_coeff=dist_coeff,
        dist_reg_only=dist_reg_only, iou_coeff=iou_coeff, bbox_coeff=bbox_coeff,
        dist_x_bbox=dist_x_bbox, iou_loss_only=iou_loss_only, show_dist_reg=show_dist_reg,
    )
    # Select the GPU. NOTE(review): env-based device selection generally only has an
    # effect if CUDA has not yet been initialised in this process.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(int(opt.device))

    #TODO - Adjust this for the number of bounding boxes
    # Rows: [image_index, class?, x_center, y_center, width, height]; column 1 is
    # presumably the class id - confirm against get_plaus_loss.
    targets = torch.tensor([
        [0, 0, opt.x_coord, opt.y_coord, 0.05, 0.05],
        # [0, 1, 0.4, 0.6, 0.05, 0.07],
        # [1, 0, 0.25, 0.2, 0.04, 0.05],
        # [2, 0, 0.8, 0.76, 0.05, 0.05],
        # [2, 0, 0.8, 0.2, 0.05, 0.05],
        # [0, 0, 0.8, 0.76, 0.05, 0.05],
        # [1, 0, 0.8, 0.2, 0.05, 0.05],
    ])
    # One attribution map per distinct image index in column 0.
    unique_classes = torch.unique(targets[:, 0])
    attr = (gaussian_blur(torch.rand(len(unique_classes), 1, 640, 640) ** 2, 13) ** 4).requires_grad_(True)

    # Plot params (adjust as necessary)
    nsamples = 10  # number of optimisation steps, and points on the loss/score curves
    rows = len(attr)  # one subplot row per image
    cols = nsamples + 2
    size = 3
    # fig1: distance map (step 0) + first attribution step; fig2: remaining steps.
    fig1 = plt.figure(figsize=(cols * size, rows * size))
    fig2 = plt.figure(figsize=(cols * size, rows * size))
    fig3, ax3 = plt.subplots(figsize=(10, 6))
    fig4, ax4 = plt.subplots(figsize=(10, 6))
    plaus_losses = []
    plaus_scores = []

    for i in range(nsamples):
        # Gradient-descent step on the attribution map. Only first-order gradients
        # are needed, so no higher-order graph is built (create_graph=False).
        plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map = get_plaus_loss(
            targets.requires_grad_(True), attribution_map=attr, opt=opt, debug=True)
        delta_attr = torch.autograd.grad(plaus_loss, attr)[0]
        attr = attr - (delta_attr * alpha)
        alpha *= opt.scheduler

        # Re-evaluate on the updated (pre-normalisation) map for logging/plotting.
        plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map = get_plaus_loss(
            targets, attribution_map=attr, opt=opt, debug=True)
        if opt.iou_loss_only:
            bbox_map = get_bbox_map(targets, attr)
            plaus_score = ((torch.sum((attr * bbox_map))) / (torch.sum(attr)))
            plaus_loss = (1.0 - plaus_score)
            distance_map = bbox_map
        # Detach and re-leaf so the autograd graph does not grow across steps;
        # the tensor values are unchanged.
        attr = normalize_batch(attr).detach().requires_grad_(True)

        plaus_losses.append(float(plaus_loss))
        plaus_scores.append(float(plaus_score))
        print(f'step: {i}, plaus_loss: {plaus_loss}, plaus_score: {plaus_score}, dist_reg: {dist_reg}, plaus_reg: {plaus_reg}')

        # [x_center, y_center, width, height] for every target box.
        boxes = targets[:, 2:6].detach().cpu().numpy()
        for j in range(len(attr)):
            if i == 0:
                # Step 0 contributes only the distance map, and only when requested
                # (otherwise there is no valid subplot slot for it).
                if opt.show_dist_reg:
                    ax = fig1.add_subplot(rows, cols, 1 + (j * cols))
                    _show_panel(ax, 1 - distance_map[j], f'Distance Regularization Map {j}', boxes)
            elif i == 1:
                # The first attribution step goes next to the distance map in fig1.
                ax = fig1.add_subplot(rows, cols, 2 + (j * cols))
                _show_panel(ax, attr[j], f'Attr Step {i}' if j == 0 else '', boxes)
            else:
                # Subsequent steps go to fig2, column i.
                ax = fig2.add_subplot(rows, cols, i + (j * cols))
                _show_panel(ax, attr[j], f'Attr Step {i}' if j == 0 else '')

    # Plot plausibility losses
    ax3.plot(range(nsamples), plaus_losses, marker='o', label='Plausibility Loss')
    ax3.set_title('Plausibility Losses Across Steps')
    ax3.set_xlabel('Step')
    ax3.set_ylabel('Plausibility Loss')
    ax3.grid(True)
    ax3.legend()
    # Plot plausibility scores
    ax4.plot(range(nsamples), plaus_scores, marker='o', label='Plausibility Scores')
    ax4.set_title('Plausibility Scores Across Steps')
    ax4.set_xlabel('Step')
    ax4.set_ylabel('Plausibility Score')
    ax4.grid(True)
    ax4.legend()

    # Save and close every figure; create the output directory if it is missing.
    out_paths = ('figs/distance_and_first_step.png', 'figs/remaining_attr_steps.png',
                 'figs/plausibility_losses.png', 'figs/plausibility_scores.png')
    os.makedirs('figs', exist_ok=True)
    for fig, path in zip((fig1, fig2, fig3, fig4), out_paths):
        fig.savefig(path, bbox_inches='tight')
        plt.close(fig)
    print('Figures saved: figs/distance_and_first_step.png, figs/remaining_attr_steps.png, and figs/plausibility_losses.png, figs/plausibility_scores.png')
    return out_paths
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` is wrong for argparse: it turns any non-empty string,
        including 'False', into True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    # ##################### Standard Settings #####################
    parser.add_argument('--pgt_coeff', type=float, default=1.0, help='pgt_coeff')
    parser.add_argument('--focus_coeff', type=float, default=0.2, help='focus_coeff')
    parser.add_argument('--alpha', type=float, default=400.0, help='alpha')
    parser.add_argument('--num_bb', type=int, default=0, help='num_bb')
    parser.add_argument('--x_coord', type=float, default=0.2, help='x_coord')
    parser.add_argument('--y_coord', type=float, default=0.35, help='y_coord')
    ########################## Advanced #########################
    parser.add_argument('--scheduler', type=float, default=2.0, help='scheduler for alpha')
    #############################################################
    parser.add_argument('--device', type=str, default='0', help='device')
    parser.add_argument('--dist_coeff', type=float, default=0.5, help='dist_coeff')
    parser.add_argument('--dist_reg_only', type=_str2bool, default=True, help='dist_reg_only')
    parser.add_argument('--iou_coeff', type=float, default=0.5, help='iou_coeff')
    parser.add_argument('--bbox_coeff', type=float, default=0.0, help='bbox_coeff')
    parser.add_argument('--dist_x_bbox', type=_str2bool, default=False, help='dist_x_bbox')
    parser.add_argument('--iou_loss_only', type=_str2bool, default=False, help='iou_loss_only')
    parser.add_argument('--show_dist_reg', type=_str2bool, default=True, help='show distance regularization map in figure')
    opt = parser.parse_args()
    # Keyword arguments: the previous positional call swapped alpha and num_bb.
    toy_problem(
        pgt_coeff=opt.pgt_coeff,
        focus_coeff=opt.focus_coeff,
        x_coord=opt.x_coord,
        y_coord=opt.y_coord,
        num_bb=opt.num_bb,
        alpha=opt.alpha,
        scheduler=opt.scheduler,
        device=opt.device,
        dist_coeff=opt.dist_coeff,
        dist_reg_only=opt.dist_reg_only,
        iou_coeff=opt.iou_coeff,
        bbox_coeff=opt.bbox_coeff,
        dist_x_bbox=opt.dist_x_bbox,
        iou_loss_only=opt.iou_loss_only,
        show_dist_reg=opt.show_dist_reg,
    )