File size: 11,283 Bytes
32938bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import torch
from plaus_functs import get_center_coords, get_distance_grids, get_plaus_loss, get_bbox_map, normalize_batch
from plot_functs import imshow
from torchvision.transforms.functional import gaussian_blur
import argparse
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2

def subfigimshow(img, ax):
    """Draw ``img`` on the matplotlib Axes ``ax``, guessing its channel layout.

    Args:
        img: A torch tensor (converted to numpy via clone/detach/cpu) or an
            object used as-is (e.g. a numpy array). 2-D input is drawn as
            grayscale. For 3-D input, a first dimension of 1 or 3 is treated
            as (C, H, W) and transposed to (H, W, C); otherwise the array is
            assumed to already be (H, W, C). Single-channel results are
            squeezed and drawn with the gray colormap.
        ax: The matplotlib Axes to draw on (only ``ax.imshow`` is used).

    Raises:
        ValueError: If the image is neither 2-D nor 3-D.
    """
    print(f'img shape: {img.shape}')
    # Tensors expose clone()/detach()/cpu(); anything without those methods
    # (e.g. a numpy array) is used unchanged. Only AttributeError means
    # "not a tensor" -- other failures (device/dtype errors) must propagate
    # instead of silently falling back to the unconverted object.
    try:
        npimg = img.clone().detach().cpu().numpy()
    except AttributeError:
        npimg = img

    if len(npimg.shape) == 2:
        # A 2-D array is drawn as a grayscale image.
        ax.imshow(npimg, cmap='gray')
    elif len(npimg.shape) == 3:
        if npimg.shape[0] == 3 or npimg.shape[0] == 1:
            # Leading dimension of 3 or 1: assume (C, H, W) and move channels last.
            tpimg = np.transpose(npimg, (1, 2, 0))
        else:
            # Otherwise assume it is already (H, W, C).
            tpimg = npimg

        if tpimg.shape[2] == 1:
            # Single channel: drop the channel axis and draw as grayscale.
            ax.imshow(np.squeeze(tpimg), cmap='gray')
        else:
            ax.imshow(tpimg)
    else:
        raise ValueError(f"Unexpected image shape: {npimg.shape}")

def draw_bounding_boxes(image, boxes, color=(0, 255, 0), thickness=2):
    """Return a uint8 copy of ``image`` with one rectangle drawn per box.

    Args:
        image: numpy image, (H, W) or (H, W, C). 2-D and single-channel input
            is expanded to 3 channels. Non-uint8 input is multiplied by 255,
            clipped to [0, 255] and cast to uint8, i.e. it is treated as if its
            values lie in [0, 1].
        boxes: Iterable of (x_center, y_center, width, height), each given as a
            fraction of the image width/height.
        color: Outline colour passed to ``cv2.rectangle``. On 4-channel (RGBA)
            images a 3-element colour is extended with alpha=255 so the outline
            stays opaque.
        thickness: Outline thickness in pixels, passed to ``cv2.rectangle``.

    Returns:
        A new uint8 array with the boxes drawn; ``image`` itself is not modified.
    """
    # Ensure a channel axis with 3 channels for grayscale input.
    if len(image.shape) == 2:
        image = np.stack([image] * 3, axis=-1)
    elif len(image.shape) == 3 and image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)

    # Ensure image is uint8 and in range [0, 255].
    if image.dtype != np.uint8:
        image = (image * 255).clip(0, 255).astype(np.uint8)

    # Draw on a copy so the caller's array is never mutated.
    image_with_boxes = image.copy()

    # cv2 pads a short colour tuple with zeros, so a 3-element colour on an
    # RGBA image (such as matplotlib colormap output) would write alpha=0 and
    # make the outline transparent. Use an opaque alpha instead.
    if image_with_boxes.shape[2] == 4 and len(color) == 3:
        color = (*color, 255)

    img_h, img_w = image_with_boxes.shape[:2]
    for x_center, y_center, width, height in boxes:
        # Convert normalized centre/size to pixel corner coordinates.
        x_min = int((x_center - width / 2) * img_w)
        y_min = int((y_center - height / 2) * img_h)
        x_max = int((x_center + width / 2) * img_w)
        y_max = int((y_center + height / 2) * img_h)
        cv2.rectangle(image_with_boxes, (x_min, y_min), (x_max, y_max), color, thickness)

    return image_with_boxes

def _show_panel(fig, rows, cols, index, title, img_tensor, boxes=None):
    """Add subplot ``index`` of a ``rows`` x ``cols`` grid to ``fig`` showing ``img_tensor``.

    The tensor is moved to numpy, squeezed, colour-mapped with viridis,
    optionally annotated with ``boxes`` via ``draw_bounding_boxes``, and drawn
    with the axes hidden.
    """
    ax = fig.add_subplot(rows, cols, index)
    ax.set_title(title)
    img_colored = plt.cm.viridis(img_tensor.detach().cpu().numpy().squeeze())
    if boxes is not None:
        img_colored = draw_bounding_boxes(img_colored, boxes)
    subfigimshow(img_colored, ax)
    ax.axis('off')


def toy_problem(pgt_coeff, focus_coeff, x_coord, y_coord, num_bb=0, alpha=200.0, scheduler=2.0, device="0", dist_coeff=0.5, dist_reg_only=True, iou_coeff=0.5, 
                bbox_coeff=0.0, dist_x_bbox=False, iou_loss_only=False, show_dist_reg=True):
    """Optimise a random attribution map against the plausibility loss and plot it.

    A random, Gaussian-blurred attribution map is created and updated for
    ``nsamples`` gradient-descent steps on the loss returned by
    ``plaus_functs.get_plaus_loss``. The maps, per-step losses and per-step
    scores are plotted and saved as PNG files under ``figs/``.

    Args:
        pgt_coeff, focus_coeff, num_bb, dist_coeff, dist_reg_only, iou_coeff,
        bbox_coeff, dist_x_bbox: Stored on ``opt`` and forwarded to
            ``get_plaus_loss``; their meaning is defined in plaus_functs.
        x_coord, y_coord: Normalized centre of the single 0.05 x 0.05 target box.
        alpha: Initial gradient step size.
        scheduler: Factor ``alpha`` is multiplied by after every step.
        device: GPU index (must be int-convertible), written to
            ``CUDA_VISIBLE_DEVICES``.
        iou_loss_only: If True, the logged loss/score are replaced by
            ``score = sum(attr * bbox_map) / sum(attr)`` and ``loss = 1 - score``,
            and the bbox map is shown instead of the distance map.
        show_dist_reg: If True, the step-0 distance map is drawn in the first
            column of the first figure; otherwise that column stays empty.

    Returns:
        Tuple of the four saved figure paths.
    """
    # Bundle the parameters into an argparse-style namespace; get_plaus_loss
    # receives it as ``opt``.
    opt = argparse.Namespace(
        pgt_coeff=pgt_coeff,
        focus_coeff=focus_coeff,
        x_coord=x_coord,
        y_coord=y_coord,
        num_bb=num_bb,
        alpha=alpha,
        scheduler=scheduler,
        device=device,
        dist_coeff=dist_coeff,
        dist_reg_only=dist_reg_only,
        iou_coeff=iou_coeff,
        bbox_coeff=bbox_coeff,
        dist_x_bbox=dist_x_bbox,
        iou_loss_only=iou_loss_only,
        show_dist_reg=show_dist_reg,
    )

    # Select the GPU. NOTE: this has no effect if CUDA was already
    # initialized in this process.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(int(opt.device))

    # TODO - Adjust this for the number of bounding boxes
    # Columns 2:6 hold the normalized (x_center, y_center, width, height) box.
    targets = torch.tensor([
        [0, 0, opt.x_coord, opt.y_coord, 0.05, 0.05],
        # [0, 1, 0.4, 0.6, 0.05, 0.07],
        # [1, 0, 0.25, 0.2, 0.04, 0.05],
        # [2, 0, 0.8, 0.76, 0.05, 0.05],
        # [2, 0, 0.8, 0.2, 0.05, 0.05],
        # [0, 0, 0.8, 0.76, 0.05, 0.05],
        # [1, 0, 0.8, 0.2, 0.05, 0.05],
    ])

    # One random, blurred map per unique value of targets[:, 0];
    # shape (N, 1, 640, 640).
    unique_classes = torch.unique(targets[:, 0])
    attr = (gaussian_blur(torch.rand(len(unique_classes), 1, 640, 640) ** 2, 13) ** 4).requires_grad_(True)

    # Plot params (adjust as necessary)
    nsamples = 10  # number of optimisation steps, also the number of plotted points
    rows = len(attr)  # one subplot row per attribution map
    cols = nsamples + 2
    size = 3

    # fig1: distance map (column 1) and first attr step (column 2).
    # fig2: the remaining attr steps.
    fig1 = plt.figure(figsize=(cols * size, rows * size))
    fig2 = plt.figure(figsize=(cols * size, rows * size))
    fig3, ax3 = plt.subplots(figsize=(10, 6))  # plausibility losses
    fig4, ax4 = plt.subplots(figsize=(10, 6))  # plausibility scores
    plaus_losses = []
    plaus_scores = []

    for i in range(nsamples):
        # Gradient step on attr with respect to the plausibility loss.
        plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map = get_plaus_loss(
            targets.requires_grad_(True), attribution_map=attr, opt=opt, debug=True)
        delta_attr = torch.autograd.grad(plaus_loss, attr, create_graph=True, retain_graph=True)[0]
        attr = attr - (delta_attr * alpha)
        alpha *= opt.scheduler

        # Re-evaluate on the updated map for logging and plotting.
        plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map = get_plaus_loss(
            targets, attribution_map=attr, opt=opt, debug=True)
        if opt.iou_loss_only:
            bbox_map = get_bbox_map(targets, attr)
            plaus_score = torch.sum(attr * bbox_map) / torch.sum(attr)
            plaus_loss = 1.0 - plaus_score
            distance_map = bbox_map

        attr = normalize_batch(attr)
        plaus_losses.append(float(plaus_loss))
        plaus_scores.append(float(plaus_score))
        print(f'step: {i}, plaus_loss: {plaus_loss}, plaus_score: {plaus_score}, dist_reg: {dist_reg}, plaus_reg: {plaus_reg}')

        bbox_coords = targets[:, 2:6].detach().cpu().numpy()  # all boxes for now
        for j in range(len(attr)):
            step_title = f'Attr Step {i}' if j == 0 else ''
            if i == 0:
                # Step 0 only contributes the optional distance map; it must
                # never fall through to fig2 (that would request subplot index 0).
                if opt.show_dist_reg:
                    _show_panel(fig1, rows, cols, 1 + j * cols,
                                f'Distance Regularization Map {j}',
                                1 - distance_map[j], boxes=bbox_coords)
            elif i == 1:
                _show_panel(fig1, rows, cols, 2 + j * cols, step_title,
                            attr[j], boxes=bbox_coords)
            else:
                _show_panel(fig2, rows, cols, i + j * cols, step_title, attr[j])

    # Plot plausibility losses
    ax3.plot(range(nsamples), plaus_losses, marker='o', label='Plausibility Loss')
    ax3.set_title('Plausibility Losses Across Steps')
    ax3.set_xlabel('Step')
    ax3.set_ylabel('Plausibility Loss')
    ax3.grid(True)
    ax3.legend()

    # Plot plausibility scores
    ax4.plot(range(nsamples), plaus_scores, marker='o', label='Plausibility Scores')
    ax4.set_title('Plausibility Scores Across Steps')
    ax4.set_xlabel('Step')
    ax4.set_ylabel('Plausibility Score')
    ax4.grid(True)
    ax4.legend()

    # savefig does not create missing directories.
    os.makedirs('figs', exist_ok=True)
    fig_paths = (
        'figs/distance_and_first_step.png',
        'figs/remaining_attr_steps.png',
        'figs/plausibility_losses.png',
        'figs/plausibility_scores.png',
    )
    for fig, path in zip((fig1, fig2, fig3, fig4), fig_paths):
        fig.savefig(path, bbox_inches='tight')
        plt.close(fig)  # close every figure so none is leaked

    print('Figures saved: figs/distance_and_first_step.png, figs/remaining_attr_steps.png, and figs/plausibility_losses.png, figs/plausibility_scores.png')
    return fig_paths

def str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``argparse`` with ``type=bool`` maps every non-empty string (including
    'False') to True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    # ##################### Standard Settings #####################
    parser.add_argument('--pgt_coeff', type=float, default=1.0, help='pgt_coeff')
    parser.add_argument('--focus_coeff', type=float, default=0.2, help='focus_coeff')
    parser.add_argument('--alpha', type=float, default=400.0, help='alpha')
    parser.add_argument('--num_bb', type=int, default=0, help='num_bb')
    parser.add_argument('--x_coord', type=float, default=0.2, help='x_coord')
    parser.add_argument('--y_coord', type=float, default=0.35, help='y_coord')
    ########################## Advanced #########################
    parser.add_argument('--scheduler', type=float, default=2.0, help='scheduler for alpha')
    #############################################################
    parser.add_argument('--device', type=str, default='0', help='device')
    parser.add_argument('--dist_coeff', type=float, default=0.5, help='dist_coeff')
    parser.add_argument('--dist_reg_only', type=str2bool, default=True, help='dist_reg_only')
    parser.add_argument('--iou_coeff', type=float, default=0.5, help='iou_coeff')
    parser.add_argument('--bbox_coeff', type=float, default=0.0, help='bbox_coeff')
    parser.add_argument('--dist_x_bbox', type=str2bool, default=False, help='dist_x_bbox')
    parser.add_argument('--iou_loss_only', type=str2bool, default=False, help='iou_loss_only')
    parser.add_argument('--show_dist_reg', type=str2bool, default=True, help='show distance regularization map in figure')
    opt = parser.parse_args()

    # Pass everything by keyword so each option reaches the matching
    # parameter; in toy_problem's signature num_bb precedes alpha, and a
    # positional call that swaps them silently sets the step size to num_bb.
    toy_problem(
        pgt_coeff=opt.pgt_coeff,
        focus_coeff=opt.focus_coeff,
        x_coord=opt.x_coord,
        y_coord=opt.y_coord,
        num_bb=opt.num_bb,
        alpha=opt.alpha,
        scheduler=opt.scheduler,
        device=opt.device,
        dist_coeff=opt.dist_coeff,
        dist_reg_only=opt.dist_reg_only,
        iou_coeff=opt.iou_coeff,
        bbox_coeff=opt.bbox_coeff,
        dist_x_bbox=opt.dist_x_bbox,
        iou_loss_only=opt.iou_loss_only,
        show_dist_reg=opt.show_dist_reg,
    )