Spaces:
Sleeping
Sleeping
| """Util functions | |
| Extended from original PANet code | |
| TODO: move part of dataset configurations to data_utils | |
| """ | |
import operator
import random
import urllib
import urllib.request

import cv2
import kneed
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as F
from sklearn.decomposition import PCA
from tqdm.auto import tqdm
def plot_connected_components(cca_output, original_image, confidences:dict=None, title="debug/connected_components.png"):
    """Visualize the output of cv2.connectedComponentsWithStats next to the original image.

    Args:
        cca_output: tuple (num_labels, labels, stats, centroids) as returned by
            cv2.connectedComponentsWithStats.
        original_image: image shown on the left subplot.
        confidences: optional dict of per-component confidences; when given, the
            component centroids are overlaid as a scatter colored by confidence.
        title: path the figure is saved to.
    """
    num_labels, labels, stats, centroids = cca_output
    # Give every component a random color; label 0 is background and stays black
    colored = np.zeros((labels.shape[0], labels.shape[1], 3), np.uint8)
    for component_id in range(1, num_labels):
        colored[labels == component_id] = np.random.randint(0, 255, size=3)
    plt.figure(figsize=(10, 5))
    plt.subplot(121), plt.imshow(original_image), plt.title('Original Image')
    plt.subplot(122), plt.imshow(cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)), plt.title('Connected Components')
    if confidences is not None:
        # Overlay centroids colored by confidence on the components subplot
        plt.subplot(122)
        scatter = plt.scatter(centroids[:, 0], centroids[:, 1], c=list(confidences.values()), cmap='jet')
        plt.colorbar(scatter)
    plt.savefig(title)
    plt.close()
def reverse_tensor(tensor, original_h, original_w, degrees):
    """
    Undo a rotate-then-resize augmentation step.

    tensor: tensor of shape (B, C, H, W) to be rotated back
    original_h: int - original height of the tensor (after it was rotated)
    original_w: int - original width of the tensor (after it was rotated)
    degrees: int or float - angle in degrees counterclockwise
    """
    _, _, target_h, target_w = tensor.shape  # this is the shape that we want to return to
    # First bring the tensor back to the size it had right after rotation
    if tensor.shape[-2:] != (original_h, original_w):
        tensor = F.resize(tensor, (original_h, original_w), interpolation=F.InterpolationMode.BILINEAR, antialias=True)
    rotated = F.rotate(tensor, degrees, expand=False)
    # Crop away the black border that the expanded rotation introduced
    crop_h = abs(target_h - original_h) // 2
    crop_w = abs(target_w - original_w) // 2
    if crop_h > 0 and crop_w > 0:
        rotated = rotated[:, :, crop_h:-crop_h, crop_w:-crop_w]
    return rotated
def need_softmax(tensor, dim=1):
    """Return True unless `tensor` already looks like a probability map along `dim`
    (non-negative entries whose sums over `dim` are close to 1)."""
    sums = tensor.sum(dim=dim)
    sums_close_to_one = torch.isclose(sums, torch.ones_like(sums))
    return not torch.all(sums_close_to_one & (tensor >= 0))
def rotate_tensor_no_crop(image_tensor, degrees):
    """
    Rotate a batch of images without cropping, then resize back to the input size.

    image_tensor: tensor of shape (B, C, H, W)
    degrees: int or float - angle in degrees counterclockwise
    returns: (rotated tensor resized back to (B, C, H, W),
              spatial shape of the expanded rotation before resizing)
    """
    if degrees == 0:
        return image_tensor, image_tensor.shape[-2:]
    _, channels, height, width = image_tensor.shape
    expanded = F.rotate(image_tensor, degrees, expand=True)
    # Single-channel inputs use nearest neighbour (keeps mask-like values discrete)
    mode = F.InterpolationMode.NEAREST if channels == 1 else F.InterpolationMode.BILINEAR
    resized = F.resize(expanded, (height, width), interpolation=mode, antialias=True)
    return resized, expanded.shape[-2:]
def plot_dinov2_fts(img_fts, title="debug/img_fts.png"):
    """
    Project feature maps down to one channel with PCA and save them as images.

    Args:
        img_fts: (B, C, H, W) tensor or ndarray
        title: path the figure is saved to
    """
    if isinstance(img_fts, torch.Tensor):
        img_fts = img_fts.cpu().detach().numpy()
    B, C, H, W = img_fts.shape
    # (B, C, H, W) -> (B*H*W, C): every spatial location becomes one PCA sample
    flat_fts = img_fts.transpose(0, 2, 3, 1).reshape(-1, C)
    reduced = PCA(n_components=1).fit_transform(flat_fts)
    # Reshape back to (B, 1, H, W)
    reduced = reduced.reshape(B, H, W, 1).transpose(0, 3, 1, 2)
    # One subplot per batch element
    if B == 1:
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(reduced[0, 0])
    else:
        fig, axes = plt.subplots(1, B, figsize=(B * 5, 5))
        for i, ax in enumerate(axes.flat):
            ax.imshow(reduced[i, 0])
    plt.tight_layout()
    plt.savefig(title)
    plt.close(fig)
def move_to_device(dict_obj, device='cuda'):
    """Move every tensor value of `dict_obj` to `device`, in place.

    Tensors nested one level deep inside list values are moved too;
    all other values are left untouched.
    """
    for key, value in dict_obj.items():
        if isinstance(value, torch.Tensor):
            dict_obj[key] = value.to(device)
        elif isinstance(value, list):
            for position, element in enumerate(value):
                if isinstance(element, torch.Tensor):
                    value[position] = element.to(device)
def validation_single_slice(model, support_images, support_fg_mask, support_bg_mask, query_images, _config, q_part=0):
    """
    Run few-shot inference for a single query slice.

    Args:
        model: few-shot segmentation model, called as
            model(support, fg_masks, bg_masks, query, isval=True, val_wsize=...)
        support_images / support_fg_mask / support_bg_mask: indexed as
            [way][part][shot]; only part `q_part` of way 0 is used
        query_images: list containing one query tensor
        _config: dict; reads 'val_wsize', 'do_cca' and 'debug'
        q_part: which support part (chunk) to condition on
    Returns:
        (query_pred, query_pred_logits): numpy binary prediction (argmax over
        the channel dim) and the raw logits tensor.
    Side effects: toggles model eval/train; may write debug/cca_before_after.png.
    """
    model.eval()
    # Add a batch dim to each support shot: way(1) x shot x [B(1) x C x H x W]
    sup_img_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_images[0][q_part]]] # way(1) x shot x [B(1) x C x H x W]
    sup_fgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_fg_mask[0][q_part]]]
    sup_bgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_bg_mask[0][q_part]]]
    with torch.no_grad():
        query_pred_logits, _, _, assign_mats, _, _ = model( sup_img_part , sup_fgm_part, sup_bgm_part, query_images, isval = True, val_wsize = _config["val_wsize"] )
        query_pred = np.array(query_pred_logits.argmax(dim=1)[0].cpu().detach())
        if _config['do_cca']:
            # Keep only the most confident connected component of the prediction
            query_pred = cca(query_pred, query_pred_logits)
    if _config["debug"]:
        # plot the support images, support fg mask, query image, query pred before cca and query pred after cca
        fig, ax = plt.subplots(3, 2, figsize=(15, 10))
        ax[0,0].imshow(support_images[0][q_part][0,0].cpu().numpy(), cmap='gray')
        ax[0,1].imshow(support_fg_mask[0][q_part][0].cpu().numpy(), cmap='gray')
        ax[1,0].imshow(query_images[0][0][0].cpu().numpy(), cmap='gray')
        ax[1,1].imshow(query_pred_logits.argmax(dim=1)[0].cpu().detach().numpy(), cmap='gray')
        ax[2,0].imshow(query_pred, cmap='gray')
        ax[2,1].imshow(query_pred_logits.argmax(dim=1)[0].cpu().detach().numpy(), cmap='gray')
        # remove all ticks
        for axi in ax.flat:
            axi.set_xticks([])
            axi.set_yticks([])
        fig.savefig("debug/cca_before_after.png")
        plt.close(fig)
    model.train()
    return query_pred, query_pred_logits
def validation_on_scans(model, curr_lb, support_images, support_fg_mask, support_bg_mask, testloader, te_parent, te_dataset, _config, sup_img_indx=1, save_pred_buffer=None):
    """
    Run slice-by-slice inference over every scan in `testloader` for label `curr_lb`.

    Scans listed in te_parent.potential_support_sid are skipped (they supply the
    support set). Each z-slice is predicted with `validation_single_slice`;
    when _config["use_3_slices"] is set, the previous/next slices are stacked
    as extra channels (zero tensors at the volume boundaries). Per-slice
    confidences are collected per scan.

    Args:
        sup_img_indx: support part index, forwarded as q_part
        save_pred_buffer: optional dict to accumulate predictions into
    Returns:
        (save_pred_buffer, conf_buffer): predictions keyed by str(curr_lb) then
        scan id — transposed to (Z, H, W) except for the 'C0' dataset, which
        stays (H, W, Z) — and the per-scan lists of slice confidences.
    """
    if save_pred_buffer is None:
        save_pred_buffer = {}
    lb_buffer = {}
    conf_buffer = {}
    # sup_img_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_images[0][sup_img_indx]]] # way(1) x shot x [B(1) x C x H x W]
    # sup_fgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_fg_mask[0][sup_img_indx]]]
    # sup_bgm_part = [[shot_tensor.unsqueeze(0) for shot_tensor in support_bg_mask[0][sup_img_indx]]]
    for scan_idx, sample_batched in enumerate(testloader):
        print(f"Processing scan: {scan_idx + 1} / {len(testloader)}")
        _scan_id = sample_batched["scan_id"][0]
        if _scan_id in te_parent.potential_support_sid: # skip the support scan, don't include that to query
            print(f"Skipping support scan: {_scan_id}") # TODO delete
            continue
        outsize = te_dataset.dataset.info_by_scan[_scan_id]["array_size"]
        outsize = (_config['input_size'][0], _config['input_size'][1], outsize[0]) # original image read by itk: Z, H, W, in prediction we use H, W, Z
        # Per-scan prediction volume, NaN-filled so unwritten slices are detectable
        _pred = np.zeros( outsize )
        _pred.fill(np.nan)
        conf_buffer[_scan_id] = []
        query_images = sample_batched['image'].cuda()
        z_min = sample_batched['z_min'][0]
        z_max = sample_batched['z_max'][0]
        # create an index list that starts with s_idx goes down to 0, then concat the indices from s_idx + 1 to the end
        # this is to make sure that the most similar slice is the first one to be processed
        # NOTE(review): `indices` is currently a plain ascending range — the
        # reordering described above is not implemented; confirm intent.
        indices = list(range(len(query_images[0])))
        qpart = sup_img_indx
        for idx, i in enumerate(tqdm(indices)):
            if _config["use_3_slices"]:
                # change the query to 3 slices (-1, 0, 1)
                if i == 0:
                    prev_q = torch.zeros_like(query_images[0, i]).unsqueeze(0)
                else:
                    prev_q = query_images[0, i - 1].unsqueeze(0)
                if i == len(query_images[0]) - 1:
                    next_q = torch.zeros_like(query_images[0, i]).unsqueeze(0)
                else:
                    next_q = query_images[0, i + 1].unsqueeze(0)
                query = torch.cat([prev_q, query_images[0, i].unsqueeze(0), next_q], dim=1)
            else:
                query = query_images[0, i].unsqueeze(0)
            query_pred, query_pred_logits = validation_single_slice(model, support_images, support_fg_mask, support_bg_mask, [query], _config, q_part=qpart)
            query_conf = get_confidence_from_logits(query_pred_logits, query_pred)
            conf_buffer[_scan_id].append(query_conf)
            _pred[..., i] = query_pred.copy()
        if _config['dataset'] != 'C0':
            lb_buffer[_scan_id] = _pred.transpose(2,0,1)
        else:
            lb_buffer[_scan_id] = _pred
    save_pred_buffer[str(curr_lb)] = lb_buffer
    return save_pred_buffer, conf_buffer
def validation(model, curr_lb, testloader, te_parent, te_dataset, _config, support_images, support_fg_mask, support_bg_mask, mar_val_metric_node=None, save_pred_buffer=None, do_validation=False, get_confidence=False):
    """
    Slice-wise validation loop over `testloader` for label `curr_lb`.

    Scans in te_parent.potential_support_sid are skipped (they supply the
    support set). Each scan's slice predictions are accumulated into a
    NaN-initialized (H, W, Z) volume and stored into
    `save_pred_buffer[str(curr_lb)]` — transposed to (Z, H, W) except for the
    'C0' dataset. With `do_validation=True` only the first scan is processed.
    Slices within _config['z_margin'] of the labeled z-range are optionally
    recorded into `mar_val_metric_node`; per-slice confidences are collected
    when `get_confidence` is True.

    Returns:
        (save_pred_buffer, _conf_buffer, _has_label_buffer)
    """
    model.eval()
    with torch.no_grad():
        curr_scan_count = -1 # counting for current scan
        _lb_buffer = {} # indexed by scan
        _conf_buffer = {} # indexed by scan
        _has_label_buffer = {} # indexed by scan
        last_qpart = 0 # used as indicator for adding result to buffer
        for idx, sample_batched in enumerate(tqdm(testloader)):
            _scan_id = sample_batched["scan_id"][0] # we assume batch size for query is 1
            if _scan_id in te_parent.potential_support_sid: # skip the support scan, don't include that to query
                continue
            if sample_batched["is_start"]:
                # first slice of a new scan: reset slice counter, allocate buffers
                ii = 0
                curr_scan_count += 1
                if do_validation:
                    if curr_scan_count > 0:
                        break
                print(f"Processing scan {curr_scan_count + 1} / {len(te_dataset.dataset.pid_curr_load)}")
                _scan_id = sample_batched["scan_id"][0]
                outsize = te_dataset.dataset.info_by_scan[_scan_id]["array_size"]
                outsize = (te_dataset.dataset.image_size, te_dataset.dataset.image_size, outsize[0]) # original image read by itk: Z, H, W, in prediction we use H, W, Z
                _pred = np.zeros( outsize )
                _pred.fill(np.nan)
                _conf_buffer[_scan_id] = []
                _has_label_buffer[_scan_id] = []
            q_part = sample_batched["part_assign"] # the chunk of query, for assignment with support
            query_images = [sample_batched['image'].cuda()]
            query_labels = torch.cat([ sample_batched['label'].cuda()], dim=0)
            # if not 1 in query_labels:
            #     continue
            # [way, [part, [shot x C x H x W]]] ->
            query_pred, query_pred_logits = validation_single_slice(model, support_images, support_fg_mask, support_bg_mask, query_images, _config, q_part=q_part)
            _pred[..., ii] = query_pred.copy()
            # track which slices actually contain the foreground label
            if 1 in query_labels:
                _has_label_buffer[_scan_id].append(True)
            else:
                _has_label_buffer[_scan_id].append(False)
            if get_confidence:
                # calc confidence from logits and log it in the _conf_buffer
                query_conf = get_confidence_from_logits(query_pred_logits, query_pred)
                _conf_buffer[_scan_id].append(query_conf)
            # only score slices within z_margin of the labeled z-range
            if mar_val_metric_node is not None and ((sample_batched["z_id"] - sample_batched["z_max"] <= _config['z_margin']) and (sample_batched["z_id"] - sample_batched["z_min"] >= -1 * _config['z_margin'])):
                mar_val_metric_node.record(query_pred, np.array(query_labels[0].cpu()), labels=[curr_lb], n_scan=curr_scan_count)
            else:
                pass
            ii += 1
            # now check data format
            if sample_batched["is_end"]:
                if _config['dataset'] != 'C0':
                    _lb_buffer[_scan_id] = _pred.transpose(2,0,1) # H, W, Z -> to Z H W
                else:
                    _lb_buffer[_scan_id] = _pred
                save_pred_buffer[str(curr_lb)] = _lb_buffer
    model.train()
    return save_pred_buffer, _conf_buffer, _has_label_buffer
def load_config_from_url(url: str) -> str:
    """Download a text config from `url` and return it decoded as a string."""
    with urllib.request.urlopen(url) as response:
        return response.read().decode()
def save_pred_gt_fig(query_images, query_pred, query_labels, support_images=None, support_labels=None, path="debug/gt_vs_pred.png"):
    """Save a figure comparing ground truth and prediction overlays
    (plus the support slice when support data is given)."""
    has_support = support_images is not None
    n_rows = 2 if has_support else 1
    fig = plt.figure(figsize=(10, 10 if has_support else 5))
    gt_ax = fig.add_subplot(n_rows, 2, 1)
    gt_ax.imshow(query_images[0][0, 1].cpu().numpy())
    gt_ax.imshow(query_labels[0].cpu().numpy(), alpha=0.5)
    gt_ax.set_title("Ground Truth")
    pred_ax = fig.add_subplot(n_rows, 2, 2)
    pred_ax.imshow(query_images[0][0, 1].cpu().numpy())
    pred_ax.imshow(query_pred, alpha=0.5)
    pred_ax.set_title("Prediction")
    if has_support:
        sup_ax = fig.add_subplot(2, 2, 3)
        sup_ax.imshow(support_images[0][0, 1].cpu().numpy())
        sup_ax.imshow(support_labels[0].cpu().numpy(), alpha=0.5)
        sup_ax.set_title("Support")
    plt.savefig(path)
    plt.close('all')
def plot_heatmap_of_probs(probs, image, path=None):
    """Overlay a probability heatmap on a grayscale image; save to `path` or show."""
    # Min-max normalize the image to [0, 1] — no fixed intensity range is assumed
    image = (image - image.min()) / (image.max() - image.min())
    rgb = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.imshow(rgb)
    ax.imshow(probs, alpha=0.5)
    if path is None:
        plt.show()
    else:
        fig.savefig(path)
    plt.close(fig)
def plot_3d_bar_probabilities(probabilities, labels, image, path=None):
    """3D bar plot of per-pixel probabilities over the image plane.

    Bars are green where labels == 1 and red elsewhere; the (normalized) image
    is painted on the z=0 plane. Saves to `path` when given, otherwise shows.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # Grid of pixel coordinates matching the probability map
    x, y = np.meshgrid(np.arange(probabilities.shape[1]), np.arange(probabilities.shape[0]))
    z = probabilities.flatten()
    bar_colors = np.where(labels.flatten() == 1, 'g', 'r')
    # Min-max normalize the image to [0, 1] — no fixed intensity range is assumed
    image = (image - image.min()) / (image.max() - image.min())
    rgb = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    # Paint the image onto the z=0 plane, then draw the probability bars on top
    ax.plot_surface(x, y, np.zeros_like(x), facecolors=rgb)
    ax.bar3d(x.ravel(), y.ravel(), np.zeros_like(z), 1, 1, z, color=bar_colors, alpha=0.3)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Probability')
    if path is not None:
        fig.savefig(path)
    else:
        plt.show()
    plt.close(fig)
| # def plot_3d_bar_probabilities(probabilities, labels, path=None): | |
| # # Create a 3D figure | |
| # fig = plt.figure() | |
| # ax = fig.add_subplot(111, projection='3d') | |
| # # Create a meshgrid of the x and y coordinates | |
| # x, y = np.meshgrid(np.arange(probabilities.shape[1]), np.arange(probabilities.shape[0])) | |
| # # Flatten the probabilities and labels data and convert them to 1D arrays | |
| # z = probabilities.flatten() | |
| # c = np.where(labels.flatten() == 1, 'g', 'r') | |
| # # Create the 3D bar plot | |
| # ax.bar3d(x.ravel(), y.ravel(), np.zeros_like(z), 1, 1, z, color=c) | |
| # # Set the axis labels | |
| # ax.set_xlabel('X') | |
| # ax.set_ylabel('Y') | |
| # ax.set_zlabel('Probability') | |
| # # Show the plot | |
| # if path is not None: | |
| # fig.savefig(path) | |
| # else: | |
| # plt.show() | |
| # plt.close(fig) | |
| # def sliding_window_confidence_segmentation(query_pred_conf:np.array, window_size=3, threshold=0.5): | |
| # """ | |
| # query_pred_conf: np.array, shape (B, H, W) | |
| # """ | |
| # # slice window across the query_pred_conf, if the window has a mean confidence > 0.5, the center pixel is 1, otherwise 0 | |
| # pred = np.zeros_like(query_pred_conf) | |
| # # slice the window | |
| # for i in range(query_pred_conf.shape[-1] - window_size + 1): | |
| # for j in range(query_pred_conf.shape[-2] - window_size + 1): | |
| # window = query_pred_conf[:, i:i+window_size, j:j+window_size] | |
| # if np.mean(window) > threshold: | |
| # pred[:, i+window_size//2, j+window_size//2] = 1 | |
| # return pred | |
def sliding_window_confidence_segmentation(query_pred_conf: np.array, window_size=3, threshold=0.5):
    """
    Binarize a confidence map by thresholding the local mean confidence.

    A pixel is set to 1 when the mean confidence of the square window centered
    on it (zero-padded at the borders) exceeds `threshold`.

    Args:
        query_pred_conf: np.array, shape (B, H, W)
        window_size: side length of the averaging window
        threshold: mean-confidence cutoff
    Returns:
        int array of shape (B, H, W)
    """
    B, H, W = query_pred_conf.shape
    pad = window_size // 2
    padded_conf = np.pad(query_pred_conf, ((0, 0), (pad, pad), (pad, pad)), mode='constant')
    # Window spans a single batch element. Previously the window covered the
    # whole batch axis (B, ws, ws) and `pred[..., 0]` then discarded every
    # batch element except the first; (1, ws, ws) keeps all B maps.
    window_view = np.lib.stride_tricks.sliding_window_view(padded_conf, (1, window_size, window_size))
    mean_values = np.mean(window_view, axis=(-1, -2))  # (B, H, W, 1)
    pred = (mean_values > threshold).astype(int)
    return pred[..., 0]
def get_confidence_from_logits(query_pred_logits: torch.Tensor, query_pred=None):
    """
    Mean foreground probability over the pixels predicted as foreground.

    Fixes a signature mismatch: call sites in this module pass two arguments
    (logits, prediction), which previously raised TypeError.

    Args:
        query_pred_logits: (B, 2, H, W) raw logits; channel 1 is foreground.
        query_pred: optional precomputed prediction, accepted for compatibility
            with callers that pass the argmax map. The mask is recomputed from
            the logits so the confidence definition is identical either way.
    Returns:
        float confidence in [0, 1]; ~0 when no pixel is predicted foreground.
    """
    query_probs = query_pred_logits.softmax(1)[:, 1].flatten(1)
    # Binary foreground mask at the 0.5 probability threshold
    fg_mask = (query_probs >= 0.5).float()
    # Average probability over predicted-foreground pixels (eps avoids 0/0)
    return ((query_probs * fg_mask).sum() / (fg_mask.sum() + 1e-6)).item()
def choose_threshold_kneedle(p):
    '''
    Choose a binarization threshold as the knee point of the CDF of `p`.

    p - probabilities of prediction (1D array-like)

    Builds a histogram-based CDF of the probabilities, saves debug plots of
    the CDF/PDF, and returns the bin edge at the knee located by
    kneed.KneeLocator.
    '''
    # use kneed to choose the threshold
    # create pdf from p (at most 100 bins, fewer for short inputs)
    n_bins = min(100, len(p))
    hist, bin_edges = np.histogram(p, bins=n_bins)
    pdf = hist / hist.sum()
    cdf = np.cumsum(pdf)
    x = np.linspace(0, 1, n_bins)
    y = cdf
    # plot x, y in a fig and save the fig (debug output)
    plt.figure()
    plt.plot(x, y)
    plt.savefig(f'debug/cdf.png')
    plt.figure()
    plt.plot(x, pdf)
    plt.savefig(f'debug/pdf.png')
    plt.close('all')
    kneedle = kneed.KneeLocator(x, y, curve='convex', direction='increasing')
    # get the value at the knee from the bin_edges
    # NOTE(review): kneedle.knee can be None when no knee is found, which would
    # raise here — confirm callers guarantee a knee exists
    threshold = bin_edges[int(kneedle.knee * n_bins)]
    return threshold
def plot_cca_output(cca_output):
    """Save one debug image per connected component (background label 0 is skipped)."""
    num_labels, labels = cca_output[0], cca_output[1]
    for component_id in range(1, num_labels):
        plt.figure()
        plt.imshow(labels == component_id)
        plt.savefig(f'debug/cca_{component_id}.png')
    plt.close('all')
def get_connected_components(query_pred_original, query_pred_logits, return_conf=False):
    """
    get all connected components of a binary prediction mask.

    Args:
        query_pred_original: (H, W) binary numpy mask
        query_pred_logits: (B, 2, H, W) logits used for the confidence scores
        return_conf: also compute a confidence score per component
    Returns:
        (cca_output, conf_or_None): cca_output is the tuple returned by
        cv2.connectedComponentsWithStats — (num_labels, labels, stats,
        centroids); conf is a dict {component_id: confidence} when requested.
    """
    cca_output = cv2.connectedComponentsWithStats(query_pred_original.astype(np.uint8), connectivity=8) # TODO try 8
    # plot_cca_output(cca_output)
    if return_conf:
        # calc confidence for each connected component
        cca_conf = {} # conf by id
        # foreground probability map from the logits
        query_probs = query_pred_logits.softmax(1)[:,1].cpu().detach().numpy()
        for j in range(cca_output[0]):
            if j == 0:
                cca_conf[0] = 0 # background
                continue
            # Sum of foreground probabilities inside the component, normalized by
            # the TOTAL predicted foreground area (not the component's own area),
            # so larger components get proportionally higher confidence.
            cca_conf[j] = ((query_probs.flatten() * (cca_output[1] == j).flatten()).sum() / ((query_pred_original.flatten().sum() + 1e-6))) # take into account the area of the connected component
        return cca_output, cca_conf
    return cca_output, None
def cca(query_pred_original, query_pred_logits, return_conf=False, return_cc=False):
    '''
    Performs connected component analysis on the query_pred and returns the most confident connected component

    Args:
        query_pred_original: (H, W) binary numpy prediction mask
        query_pred_logits: (B, 2, H, W) logits used to score the components
        return_conf: also return the confidence of the kept component
        return_cc: return the (rewritten) connected-component tuple instead
            of the filtered mask
    Returns:
        the input mask with every component except the most confident one
        zeroed out (all-zero when no component has confidence > 0);
        optionally (mask, max_conf); or the component tuple when return_cc.
    '''
    # cca_output = cv2.connectedComponentsWithStats(query_pred_original.astype(np.uint8), connectivity=8) # TODO try 8
    # # calc confidence for each connected component
    # cca_conf = []
    # for j in range(cca_output[0]):
    #     if j == 0:
    #         cca_conf.append(0) # background
    #         continue
    #     cca_conf.append((query_pred_logits.softmax(1)[:,1].flatten(1).cpu().detach().numpy() * (cca_output[1] == j).flatten()).sum() / ((cca_output[1] == j).flatten().sum() + 1e-6) * ((cca_output[1] == j).flatten().sum() / (query_pred_original.flatten().sum() + 1e-6))) # take into account the area of the connected component
    cca_output, cca_conf = get_connected_components(query_pred_original, query_pred_logits, return_conf=True)
    # find the most confident connected component, find max conf and its key
    max_conf = cca_conf[0]  # background entry, always 0
    for k,v in cca_conf.items():
        if v > max_conf:
            max_conf = v
            max_key = k
    if max_conf == 0:
        # no connected component found, use zeros
        query_pred = np.zeros_like(query_pred_original)
    else:
        # zero out all other connected components; max_key is guaranteed to be
        # bound here because max_conf > 0 implies the loop assigned it
        new_cca_output = list(cca_output)
        new_cca_output[0] = 2 # bg + fg
        new_cca_output[1] = np.where(cca_output[1] != max_key, 0, 1) # binarize the max_key
        new_cca_output[2] = cca_output[2][[0, max_key]]  # keep stats for bg + winner
        new_cca_output[3] = cca_output[3][[0, max_key]]  # keep centroids for bg + winner
        cca_output = tuple(new_cca_output)
        query_pred = (cca_output[1] == 1).astype(np.uint8)
    # convert to binary mask
    query_pred = (query_pred > 0).astype(np.uint8)
    if return_cc:
        return cca_output
    # restrict the original prediction to the winning component
    query_pred_original = query_pred_original * query_pred
    if return_conf:
        return query_pred_original, max_conf
    return query_pred_original
def set_seed(seed):
    """
    Set the random seed for every RNG used in this module.

    Seeds python `random`, numpy and torch (CPU + all CUDA devices). numpy is
    included because helpers in this file draw from np.random (e.g. the
    component colors in plot_connected_components) — previously it was left
    unseeded, making those draws non-reproducible.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Per-dataset class-label splits. 'pa_all' is the full set of labels used;
# integer keys are fold ids mapping to that fold's label subset (per the
# inline notes, the complement within pa_all is held out for testing).
CLASS_LABELS = {
    'SABS': {
        'pa_all': set( [1,2,3,6] ),
        0: set([1,6] ), # upper_abdomen: spleen + liver as training, kidneys are testing
        1: set( [2,3] ), # lower_abdomen
    },
    'C0': {
        'pa_all': set(range(1, 4)),
        0: set([2,3]),
        1: set([1,3]),
        2: set([1,2]),
    },
    'CHAOST2': {
        'pa_all': set(range(1, 5)),
        0: set([1, 4]), # upper_abdomen, leaving kidneys as testing classes
        1: set([2, 3]), # lower_abdomen
    },
}
def get_bbox(fg_mask, inst_mask):
    """
    Get the ground truth bounding boxes.

    Args:
        fg_mask: (1, H, W) binary foreground mask tensor
        inst_mask: (1, H, W) integer class/instance map; NOTE(review): it is
            modified in place (zeroed outside fg_mask) — confirm callers do
            not reuse it afterwards
    Returns:
        fg_bbox: tensor with 1 inside the bounding box of the largest class
        bg_bbox: tensor with 0 inside the bounding box of every class, 1 elsewhere

    NOTE(review): np.unique/np.where are applied directly to the tensors, so
    this assumes CPU tensors — a CUDA inst_mask would fail here; confirm.
    """
    fg_bbox = torch.zeros_like(fg_mask, device=fg_mask.device)
    bg_bbox = torch.ones_like(fg_mask, device=fg_mask.device)
    # restrict the instance map to the foreground region
    inst_mask[fg_mask == 0] = 0
    area = torch.bincount(inst_mask.view(-1))
    # class id with the largest area (index 0 = background is excluded)
    cls_id = area[1:].argmax() + 1
    cls_ids = np.unique(inst_mask)[1:]
    mask_idx = np.where(inst_mask[0] == cls_id)
    y_min = mask_idx[0].min()
    y_max = mask_idx[0].max()
    x_min = mask_idx[1].min()
    x_max = mask_idx[1].max()
    # foreground bbox covers only the largest class
    fg_bbox[0, y_min:y_max+1, x_min:x_max+1] = 1
    # carve the bbox of every class out of the background region
    for i in cls_ids:
        mask_idx = np.where(inst_mask[0] == i)
        y_min = max(mask_idx[0].min(), 0)
        y_max = min(mask_idx[0].max(), fg_mask.shape[1] - 1)
        x_min = max(mask_idx[1].min(), 0)
        x_max = min(mask_idx[1].max(), fg_mask.shape[2] - 1)
        bg_bbox[0, y_min:y_max+1, x_min:x_max+1] = 0
    return fg_bbox, bg_bbox
def t2n(img_t):
    """
    torch to numpy regardless of whether the tensor lives on GPU or in host memory
    """
    data = img_t.data
    if data.is_cuda:
        data = data.cpu()
    return data.numpy()
def to01(x_np):
    """
    normalize a numpy array into [0, 1] for visualization
    """
    lo, hi = x_np.min(), x_np.max()
    # small epsilon in the denominator guards against a constant array
    return (x_np - lo) / (hi - lo + 1e-5)
def compose_wt_simple(is_wce, data_name):
    """
    Weights for cross-entropy loss.

    Returns fixed (background, foreground) weights [0.05, 1.0] on the GPU.
    `is_wce` and `data_name` are currently ignored — the previous per-dataset
    branching always produced these weights for the supported datasets.
    """
    weights = torch.FloatTensor([0.05, 1.0])
    return weights.cuda()
class CircularList(list):
    """
    List whose indices wrap around modulo its length.

    Helper for splitting training and validation scans.
    Originally: https://stackoverflow.com/questions/8951020/pythonic-circular-list/8951224
    """
    def __getitem__(self, x):
        if isinstance(x, slice):
            # materialize the slice by resolving each wrapped index
            return [self[i] for i in self._rangeify(x)]
        index = operator.index(x)
        try:
            return super().__getitem__(index % len(self))
        except ZeroDivisionError:
            # empty list: len(self) == 0 makes the modulo fail
            raise IndexError('list index out of range')
    def _rangeify(self, sl):
        # substitute defaults for any missing slice fields
        start = 0 if sl.start is None else sl.start
        stop = len(self) if sl.stop is None else sl.stop
        step = 1 if sl.step is None else sl.step
        return range(start, stop, step)