import torch
from plaus_functs import get_center_coords, get_distance_grids, get_plaus_loss, get_bbox_map, normalize_batch
from plot_functs import imshow
from torchvision.transforms.functional import gaussian_blur
import argparse
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2


def subfigimshow(img, ax):
    """Draw an image (torch tensor or array-like) on a matplotlib Axes.

    Accepted layouts:
      * 2-D ``(H, W)``: shown in grayscale.
      * 3-D with a first dimension of 1 or 3: treated as ``(C, H, W)`` and
        transposed to ``(H, W, C)``.
      * any other 3-D array: treated as already ``(H, W, C)``.
    Single-channel 3-D images are squeezed and shown in grayscale.

    Args:
        img: ``torch.Tensor`` or array-like image.
        ax: matplotlib Axes to draw into.

    Raises:
        ValueError: if the image is neither 2-D nor 3-D.
    """
    print(f'img shape: {img.shape}')
    if isinstance(img, torch.Tensor):
        npimg = img.clone().detach().cpu().numpy()
    else:
        npimg = np.asarray(img)

    if npimg.ndim == 2:
        # 2-D array: most likely a grayscale image.
        ax.imshow(npimg, cmap='gray')
    elif npimg.ndim == 3:
        if npimg.shape[0] in (1, 3):
            # Channel-first (C, H, W) -> channel-last (H, W, C) for matplotlib.
            tpimg = np.transpose(npimg, (1, 2, 0))
        else:
            tpimg = npimg
        if tpimg.shape[2] == 1:
            ax.imshow(np.squeeze(tpimg), cmap='gray')
        else:
            ax.imshow(tpimg)
    else:
        raise ValueError(f"Unexpected image shape: {npimg.shape}")


def draw_bounding_boxes(image, boxes, color=(0, 255, 0), thickness=2):
    """Return a uint8 RGB copy of ``image`` with rectangles drawn on it.

    Args:
        image: ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)`` array.
            Non-uint8 input is assumed to be in ``[0, 1]`` and is scaled by 255.
            A 4th (alpha) channel, e.g. from ``plt.cm.viridis(...)``, is dropped.
        boxes: iterable of ``(x_center, y_center, width, height)`` expressed as
            fractions of the image width/height.
        color: RGB colour of the rectangle outline.
        thickness: outline thickness in pixels (passed to ``cv2.rectangle``).

    Returns:
        A new ``(H, W, 3)`` uint8 array; the input is not modified.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.ndim == 3 and image.shape[2] == 4:
        # With a 3-tuple colour, cv2 writes 0 into the alpha channel, which
        # would make the box pixels fully transparent when shown with imshow.
        image = image[..., :3]

    if image.dtype != np.uint8:
        image = (image * 255).clip(0, 255).astype(np.uint8)

    # ndarray.copy() defaults to C order, giving cv2 a contiguous buffer.
    image_with_boxes = image.copy()
    img_h, img_w = image_with_boxes.shape[:2]
    for x_center, y_center, width, height in boxes:
        x_min = int((x_center - width / 2) * img_w)
        y_min = int((y_center - height / 2) * img_h)
        x_max = int((x_center + width / 2) * img_w)
        y_max = int((y_center + height / 2) * img_h)
        cv2.rectangle(image_with_boxes, (x_min, y_min), (x_max, y_max), color, thickness)
    return image_with_boxes


def _show_heatmap(ax, tensor, bbox_coords=None):
    """Colour-map one tensor with viridis, optionally overlay boxes, and draw it on ``ax`` with axes hidden."""
    img_np = tensor.detach().cpu().numpy().squeeze()
    img = plt.cm.viridis(img_np)
    if bbox_coords is not None:
        img = draw_bounding_boxes(img, bbox_coords)
    subfigimshow(img, ax)
    ax.axis('off')


def toy_problem(pgt_coeff, focus_coeff, x_coord, y_coord, num_bb=0, alpha=200.0, scheduler=2.0,
                device="0", dist_coeff=0.5, dist_reg_only=True, iou_coeff=0.5, bbox_coeff=0.0,
                dist_x_bbox=False, iou_loss_only=False, show_dist_reg=True):
    """Run gradient descent on a random attribution map to minimise the plausibility loss.

    A random, Gaussian-blurred attribution map is updated for ``nsamples``
    steps with ``attr -= grad(plaus_loss, attr) * alpha``, where ``alpha``
    is multiplied by ``scheduler`` after every step. The maps are normalised
    with ``normalize_batch`` after each step. Four figures are written under
    ``figs/``.

    Args:
        pgt_coeff, focus_coeff, dist_coeff, dist_reg_only, iou_coeff,
        bbox_coeff, dist_x_bbox, num_bb: stored on an ``argparse.Namespace``
            and forwarded to ``get_plaus_loss`` as ``opt``.
        x_coord, y_coord: centre of the single target box (presumably
            normalised to [0, 1], since ``draw_bounding_boxes`` scales by the
            image size — confirm against ``get_plaus_loss``).
        alpha: initial gradient step size.
        scheduler: multiplicative factor applied to ``alpha`` after each step.
        device: CUDA device index; written to ``CUDA_VISIBLE_DEVICES``
            (must be parseable by ``int``).
        iou_loss_only: if True, the reported loss/score are recomputed as
            ``1 - sum(attr * bbox_map) / sum(attr)``, and the bbox map is the
            one plotted.
        show_dist_reg: if True, plot the distance-regularisation map of step 0.

    Returns:
        Tuple of the four saved figure paths.
    """
    # Hold parameters in a Namespace because get_plaus_loss expects an `opt` object.
    opt = argparse.Namespace(
        pgt_coeff=pgt_coeff,
        focus_coeff=focus_coeff,
        x_coord=x_coord,
        y_coord=y_coord,
        num_bb=num_bb,
        alpha=alpha,
        scheduler=scheduler,
        device=device,
        dist_coeff=dist_coeff,
        dist_reg_only=dist_reg_only,
        iou_coeff=iou_coeff,
        bbox_coeff=bbox_coeff,
        dist_x_bbox=dist_x_bbox,
        iou_loss_only=iou_loss_only,
        show_dist_reg=show_dist_reg,
    )

    # Only takes effect if CUDA has not been initialised yet in this process.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(int(opt.device))

    # TODO - Adjust this for the number of bounding boxes
    # Each row appears to be [image_idx, class, x_center, y_center, width, height].
    targets = torch.tensor([
        [0, 0, opt.x_coord, opt.y_coord, 0.05, 0.05],
        # [0, 1, 0.4, 0.6, 0.05, 0.07],
        # [1, 0, 0.25, 0.2, 0.04, 0.05],
        # [2, 0, 0.8, 0.76, 0.05, 0.05],
        # [2, 0, 0.8, 0.2, 0.05, 0.05],
        # [0, 0, 0.8, 0.76, 0.05, 0.05],
        # [1, 0, 0.8, 0.2, 0.05, 0.05],
    ])
    unique_classes = torch.unique(targets[:, 0])
    # One random, smoothed, sharpened 640x640 map per distinct value in column 0.
    attr = (gaussian_blur(torch.rand(len(unique_classes), 1, 640, 640) ** 2, 13) ** 4).requires_grad_(True)

    # Debug pass on the initial map; its result is not used afterwards.
    get_plaus_loss(targets, attribution_map=attr, opt=opt, debug=True, only_loss=True)

    # Plot params (adjust as necessary)
    nsamples = 10           # number of optimisation steps
    rows = len(attr)        # one row per attribution map
    cols = nsamples + 2
    size = 3

    fig1 = plt.figure(figsize=(cols * size, rows * size))  # distance map + step 1
    fig2 = plt.figure(figsize=(cols * size, rows * size))  # steps 2..nsamples-1
    fig3, ax3 = plt.subplots(figsize=(10, 6))              # loss curve
    fig4, ax4 = plt.subplots(figsize=(10, 6))              # score curve
    plaus_losses = []
    plaus_scores = []

    # Box coordinates never change, so convert them once.
    bbox_coords = targets[:, 2:6].detach().cpu().numpy()
    targets.requires_grad_(True)

    for step in range(nsamples):
        plaus_loss, (plaus_score, dist_reg, plaus_reg), distance_map = get_plaus_loss(
            targets, attribution_map=attr, opt=opt, debug=True)
        # Only first-order gradients are needed; attr is re-leafed below, so
        # no graph has to be kept across steps.
        delta_attr = torch.autograd.grad(plaus_loss, attr)[0]
        attr = attr - delta_attr * alpha
        alpha *= opt.scheduler

        # Re-evaluate on the updated (not yet normalised) map for reporting.
        plaus_loss, (plaus_score, dist_reg, plaus_reg), distance_map = get_plaus_loss(
            targets, attribution_map=attr, opt=opt, debug=True)
        if opt.iou_loss_only:
            bbox_map = get_bbox_map(targets, attr)
            plaus_score = torch.sum(attr * bbox_map) / torch.sum(attr)
            plaus_loss = 1.0 - plaus_score
            distance_map = bbox_map

        # Detach so each step's autograd graph can be freed.
        attr = normalize_batch(attr).detach().requires_grad_(True)

        plaus_losses.append(float(plaus_loss))
        plaus_scores.append(float(plaus_score))
        print(f'step: {step}, plaus_loss: {plaus_loss}, plaus_score: {plaus_score}, '
              f'dist_reg: {dist_reg}, plaus_reg: {plaus_reg}')

        for j in range(rows):
            if step == 0:
                # Step 0 only contributes the distance map, and only when requested.
                if opt.show_dist_reg:
                    ax = fig1.add_subplot(rows, cols, 1 + j * cols)
                    ax.set_title(f'Distance Regularization Map {j}')
                    _show_heatmap(ax, 1 - distance_map[j], bbox_coords)
            elif step == 1:
                ax = fig1.add_subplot(rows, cols, 2 + j * cols)
                ax.set_title(f'Attr Step {step}' if j == 0 else '')
                _show_heatmap(ax, attr[j], bbox_coords)
            else:
                ax = fig2.add_subplot(rows, cols, step + j * cols)
                ax.set_title(f'Attr Step {step}' if j == 0 else '')
                _show_heatmap(ax, attr[j])

    steps = range(len(plaus_losses))

    ax3.plot(steps, plaus_losses, marker='o', label='Plausibility Loss')
    ax3.set_title('Plausibility Losses Across Steps')
    ax3.set_xlabel('Step')
    ax3.set_ylabel('Plausibility Loss')
    ax3.grid(True)
    ax3.legend()

    ax4.plot(steps, plaus_scores, marker='o', label='Plausibility Scores')
    ax4.set_title('Plausibility Scores Across Steps')
    ax4.set_xlabel('Step')
    ax4.set_ylabel('Plausibility Score')
    ax4.grid(True)
    ax4.legend()

    paths = (
        'figs/distance_and_first_step.png',
        'figs/remaining_attr_steps.png',
        'figs/plausibility_losses.png',
        'figs/plausibility_scores.png',
    )
    os.makedirs('figs', exist_ok=True)
    for fig, path in zip((fig1, fig2, fig3, fig4), paths):
        fig.savefig(path, bbox_inches='tight')
        plt.close(fig)
    print('Figures saved: figs/distance_and_first_step.png, figs/remaining_attr_steps.png, and figs/plausibility_losses.png, figs/plausibility_scores.png')
    return paths


def _str2bool(value):
    """Parse a command-line boolean; argparse's ``type=bool`` treats any non-empty string (even "False") as True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # ##################### Standard Settings #####################
    parser.add_argument('--pgt_coeff', type=float, default=1.0, help='pgt_coeff')
    parser.add_argument('--focus_coeff', type=float, default=0.2, help='focus_coeff')
    parser.add_argument('--alpha', type=float, default=400.0, help='alpha')
    parser.add_argument('--num_bb', type=int, default=0, help='num_bb')
    parser.add_argument('--x_coord', type=float, default=0.2, help='x_coord')
    parser.add_argument('--y_coord', type=float, default=0.35, help='y_coord')
    ########################## Advanced #########################
    parser.add_argument('--scheduler', type=float, default=2.0, help='scheduler for alpha')
    #############################################################
    parser.add_argument('--device', type=str, default='0', help='device')
    parser.add_argument('--dist_coeff', type=float, default=0.5, help='dist_coeff')
    parser.add_argument('--dist_reg_only', type=_str2bool, default=True, help='dist_reg_only')
    parser.add_argument('--iou_coeff', type=float, default=0.5, help='iou_coeff')
    parser.add_argument('--bbox_coeff', type=float, default=0.0, help='bbox_coeff')
    parser.add_argument('--dist_x_bbox', type=_str2bool, default=False, help='dist_x_bbox')
    parser.add_argument('--iou_loss_only', type=_str2bool, default=False, help='iou_loss_only')
    parser.add_argument('--show_dist_reg', type=_str2bool, default=True,
                        help='show distance regularization map in figure')
    opt = parser.parse_args()

    # Keyword arguments: the previous positional call passed alpha and num_bb
    # in swapped order (alpha ended up as 0, so the map never changed).
    toy_problem(
        opt.pgt_coeff, opt.focus_coeff, opt.x_coord, opt.y_coord,
        num_bb=opt.num_bb,
        alpha=opt.alpha,
        scheduler=opt.scheduler,
        device=opt.device,
        dist_coeff=opt.dist_coeff,
        dist_reg_only=opt.dist_reg_only,
        iou_coeff=opt.iou_coeff,
        bbox_coeff=opt.bbox_coeff,
        dist_x_bbox=opt.dist_x_bbox,
        iou_loss_only=opt.iou_loss_only,
        show_dist_reg=opt.show_dist_reg,
    )