Spaces:
Sleeping
Sleeping
| import warnings | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from models.grid_proto_fewshot import FewShotSeg | |
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor | |
| from models.SamWrapper import SamWrapper | |
| from util.utils import cca, get_connected_components, rotate_tensor_no_crop, reverse_tensor, get_confidence_from_logits | |
| from util.lora import inject_trainable_lora | |
| from models.segment_anything.utils.transforms import ResizeLongestSide | |
| import cv2 | |
| import time | |
| from abc import ABC, abstractmethod | |
| CONF_MODE="conf" | |
| CENTROID_MODE="centroid" | |
| BOTH_MODE="both" | |
| POINT_MODES=(CONF_MODE, CENTROID_MODE, BOTH_MODE) | |
| TYPE_ALPNET="alpnet" | |
| TYPE_SAM="sam" | |
| def plot_connected_components(cca_output, original_image, confidences:dict=None, title="debug/connected_components.png"): | |
| num_labels, labels, stats, centroids = cca_output | |
| # Create an output image with random colors for each component | |
| output_image = np.zeros((labels.shape[0], labels.shape[1], 3), np.uint8) | |
| for label in range(1, num_labels): # Start from 1 to skip the background | |
| mask = labels == label | |
| output_image[mask] = np.random.randint(0, 255, size=3) | |
| # Plotting the original and the colored components image | |
| plt.figure(figsize=(10, 5)) | |
| plt.subplot(121), plt.imshow(original_image), plt.title('Original Image') | |
| plt.subplot(122), plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)), plt.title('Connected Components') | |
| if confidences is not None: | |
| # Plot the axes color chart with the confidences, use the same colors as the connected components | |
| plt.subplot(122) | |
| scatter = plt.scatter(centroids[:, 0], centroids[:, 1], c=list(confidences.values()), cmap='jet') | |
| plt.colorbar(scatter) | |
| plt.savefig(title) | |
| plt.close() | |
| class SegmentationInput(ABC): | |
| def set_query_images(self, query_images): | |
| pass | |
| def to(self, device): | |
| pass | |
| class SegmentationOutput(ABC): | |
| def get_prediction(self): | |
| pass | |
| class ALPNetInput(SegmentationInput): # for alpnet | |
| def __init__(self, support_images:list, support_labels:list, query_images:torch.Tensor, isval, val_wsize, show_viz=False, supp_fts=None): | |
| self.supp_imgs = [support_images] | |
| self.fore_mask = [support_labels] | |
| self.back_mask = [[1 - sup_labels for sup_labels in support_labels]] | |
| self.qry_imgs = [query_images] | |
| self.isval = isval | |
| self.val_wsize = val_wsize | |
| self.show_viz = show_viz | |
| self.supp_fts = supp_fts | |
| def set_query_images(self, query_images): | |
| self.qry_imgs = [query_images] | |
| def to(self, device): | |
| self.supp_imgs = [[supp_img.to(device) for way in self.supp_imgs for supp_img in way]] | |
| self.fore_mask = [[fore_mask.to(device) for way in self.fore_mask for fore_mask in way]] | |
| self.back_mask = [[back_mask.to(device) for way in self.back_mask for back_mask in way]] | |
| self.qry_imgs = [qry_img.to(device) for qry_img in self.qry_imgs] | |
| if self.supp_fts is not None: | |
| self.supp_fts = self.supp_fts.to(device) | |
| class ALPNetOutput(SegmentationOutput): | |
| def __init__(self, pred, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts): | |
| self.pred = pred | |
| self.align_loss = align_loss | |
| self.sim_maps = sim_maps | |
| self.assign_maps = assign_maps | |
| self.proto_grid = proto_grid | |
| self.supp_fts = supp_fts | |
| self.qry_fts = qry_fts | |
| def get_prediction(self): | |
| return self.pred | |
| class SAMWrapperInput(SegmentationInput): | |
| def __init__(self, image, image_labels): | |
| self.image = image | |
| self.image_labels = image_labels | |
| def set_query_images(self, query_images): | |
| B, C, H, W = query_images.shape | |
| if isinstance(query_images, torch.Tensor): | |
| query_images = query_images.cpu().detach().numpy() | |
| assert B == 1, "batch size must be 1" | |
| query_images = (query_images - query_images.min()) / (query_images.max() - query_images.min()) * 255 | |
| query_images = query_images.astype(np.uint8) | |
| self.image = np.transpose(query_images[0], (1, 2, 0)) | |
| def to(self, device): | |
| pass | |
| class InputFactory(ABC): | |
| def create_input(input_type, query_image, support_images=None, support_labels=None, isval=False, val_wsize=None, show_viz=False, supp_fts=None, original_sz=None, img_sz=None, gts=None): | |
| if input_type == TYPE_ALPNET: | |
| return ALPNetInput(support_images, support_labels, query_image, isval, val_wsize, show_viz, supp_fts) | |
| elif input_type == TYPE_SAM: | |
| qimg = np.array(query_image.detach().cpu()) | |
| B,C,H,W = qimg.shape | |
| assert B == 1, "batch size must be 1" | |
| gts = np.array(gts.detach().cpu()).astype(np.uint8).reshape(H,W) | |
| assert np.unique(gts).shape[0] <= 2, "support labels must be binary" | |
| gts[gts > 0] = 1 | |
| qimg = qimg.reshape(H,W,C) | |
| qimg = (qimg - qimg.min()) / (qimg.max() - qimg.min()) * 255 | |
| qimg = qimg.astype(np.uint8) | |
| return SAMWrapperInput(qimg, gts) | |
| else: | |
| raise ValueError(f"input_type not supported") | |
| class ModelWrapper(ABC): | |
| def __init__(self, model): | |
| self.model = model | |
| def __call__(self, input_data: SegmentationInput)->SegmentationOutput: | |
| pass | |
| def state_dict(self): | |
| return self.model.state_dict() | |
| def load_state_dict(self, state_dict): | |
| self.model.load_state_dict(state_dict) | |
| def eval(self): | |
| self.model.eval() | |
| def train(self): | |
| self.model.train() | |
| def parameters(self): | |
| pass | |
| class ALPNetWrapper(ModelWrapper): | |
| def __init__(self, model: FewShotSeg): | |
| super().__init__(model) | |
| def __call__(self, input_data: ALPNetInput): | |
| output = self.model(**input_data.__dict__) | |
| output = ALPNetOutput(*output) | |
| return output.pred | |
| def parameters(self): | |
| return self.model.encoder.parameters() | |
| def train(self): | |
| self.model.encoder.train() | |
| class SamWrapperWrapper(ModelWrapper): | |
| def __init__(self, model:SamWrapper): | |
| super().__init__(model) | |
| def __call__(self, input_data: SAMWrapperInput): | |
| pred = self.model(**input_data.__dict__) | |
| # make pred look like logits | |
| pred = torch.tensor(pred).float()[None, None, ...] | |
| pred = torch.cat([1-pred, pred], dim=1) | |
| return pred | |
| def to(self, device): | |
| self.model.sam.to(device) | |
| class ProtoSAM(nn.Module): | |
| def __init__(self, image_size, coarse_segmentation_model:ModelWrapper, sam_pretrained_path="pretrained_model/sam_default.pth", num_points_for_sam=1, use_points=True, use_bbox=False, use_mask=False, debug=False, use_cca=False, point_mode=CONF_MODE, use_sam_trans=True, coarse_pred_only=False, alpnet_image_size=None, use_neg_points=False, ): | |
| super().__init__() | |
| if isinstance(image_size, int): | |
| image_size = (image_size, image_size) | |
| self.image_size = image_size | |
| self.coarse_segmentation_model = coarse_segmentation_model | |
| self.get_sam(sam_pretrained_path, use_sam_trans) | |
| self.num_points_for_sam = num_points_for_sam | |
| self.use_points = use_points | |
| self.use_bbox = use_bbox # if False then uses points | |
| self.use_mask = use_mask | |
| self.use_neg_points = use_neg_points | |
| assert self.use_bbox or self.use_points or self.use_mask, "must use at least one of bbox, points, or mask" | |
| self.use_cca = use_cca | |
| self.point_mode = point_mode | |
| if self.point_mode not in POINT_MODES: | |
| raise ValueError(f"point mode must be one of {POINT_MODES}") | |
| self.debug=debug | |
| self.coarse_pred_only = coarse_pred_only | |
| def get_sam(self, checkpoint_path, use_sam_trans): | |
| model_type="vit_b" # TODO make generic? | |
| if 'vit_h' in checkpoint_path: | |
| model_type = "vit_h" | |
| self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval() | |
| self.predictor = SamPredictor(self.sam) | |
| self.sam.requires_grad_(False) | |
| if use_sam_trans: | |
| # sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size, pixel_mean=[0], pixel_std=[1]) | |
| sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size) | |
| sam_trans.pixel_mean = torch.tensor([0, 0, 0]).view(3, 1, 1) | |
| sam_trans.pixel_std = torch.tensor([1, 1, 1]).view(3, 1, 1) | |
| else: | |
| sam_trans = None | |
| self.sam_trans = sam_trans | |
| def get_bbox(self, pred): | |
| ''' | |
| pred tensor of shape (H, W) where 1 represents foreground and 0 represents background | |
| returns a list of 2d points representing the bbox | |
| ''' | |
| if isinstance(pred, np.ndarray): | |
| pred = torch.tensor(pred) | |
| # get the indices of the foreground points | |
| indices = torch.nonzero(pred) | |
| # get the min and max of the indices | |
| min_x = indices[:, 1].min() | |
| max_x = indices[:, 1].max() | |
| min_y = indices[:, 0].min() | |
| max_y = indices[:, 0].max() | |
| # get the bbox | |
| bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]] | |
| return bbox | |
| def get_bbox_per_cc(self, conn_components): | |
| """ | |
| conn_components: output of cca function | |
| return list of bboxes per connected component, each bbox is a list of 2d points | |
| """ | |
| bboxes = [] | |
| for i in range(1, conn_components[0]): | |
| # get the indices of the foreground points | |
| indices = torch.nonzero(torch.tensor(conn_components[1] == i)) | |
| # get the min and max of the indices | |
| min_x = indices[:, 1].min() | |
| max_x = indices[:, 1].max() | |
| min_y = indices[:, 0].min() | |
| max_y = indices[:, 0].max() | |
| # get the bbox | |
| # bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]] | |
| # bbox = [[min_x, min_y], [max_x, min_y], [max_x, max_y], [min_x, max_y]] | |
| # bbox should be in a XYXY format | |
| bbox = [min_x, min_y, max_x, max_y] | |
| bboxes.append(bbox) | |
| bboxes = np.array(bboxes) | |
| return bboxes | |
| def get_most_conf_points(self, output_p_fg, pred, k): | |
| ''' | |
| get the k most confident points from pred | |
| output_p: 3d tensor of shape (H, W) | |
| pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background | |
| ''' | |
| # Create a mask where pred is 1 | |
| mask = pred.bool() | |
| # Apply the mask to output_p_fg | |
| masked_output_p_fg = output_p_fg[mask] | |
| if masked_output_p_fg.numel() == 0: | |
| return None, None | |
| # Get the top k probabilities and their indices | |
| confidences, indices = torch.topk(masked_output_p_fg, k) | |
| # Get the locations of the top k points in xy format | |
| locations = torch.nonzero(mask)[indices] | |
| # convert locations to xy format | |
| locations = locations[:, [1, 0]] | |
| # convert locations to list of lists | |
| # points = [loc.tolist() for loc in locations] | |
| return locations.numpy(), [float(conf.item()) for conf in confidences] | |
| def plot_most_conf_points(self, points, confidences, pred, image, bboxes=None, title=None): | |
| ''' | |
| points: np array of shape (N, 2) where each row is a point in xy format | |
| pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background | |
| image: 2d tensor of shape (H,W) representing the image | |
| bbox: list or np array of shape (N, 4) where each row is a bbox in xyxy format | |
| ''' | |
| warnings.filterwarnings('ignore', category=UserWarning) | |
| if isinstance(pred, torch.Tensor): | |
| pred = pred.cpu().detach().numpy() | |
| if len(image.shape) == 3 and image.shape[0] == 3: | |
| image = image.permute(1, 2, 0) | |
| if title is None: | |
| title="debug/most_conf_points.png" | |
| fig = plt.figure() | |
| image = (image - image.min()) / (image.max() - image.min()) | |
| plt.imshow(image) | |
| plt.imshow(pred, alpha=0.5) | |
| for i, point in enumerate(points): | |
| plt.scatter(point[0][0], point[0][1], cmap='viridis', marker='*', c='red') | |
| if confidences is not None: | |
| plt.text(point[0], point[1], f"{confidences[i]:.3f}", fontsize=12, color='red') | |
| # assume points is a list of lists | |
| if bboxes is not None: | |
| for bbox in bboxes: | |
| if bbox is None: | |
| continue | |
| bbox = np.array(bbox) | |
| # plt.scatter(bbox[:, 1], bbox[:, 0], c='red') | |
| # plot a line connecting the points | |
| box = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]]) | |
| box = np.vstack([box, box[0]]) | |
| plt.plot(box[:, 0], box[:, 1], c='green') | |
| plt.colorbar() | |
| fig.savefig(title) | |
| plt.close(fig) | |
| def plot_sam_preds(self, masks, scores, image, input_point, input_label, input_box=None): | |
| if len(image.shape) == 3: | |
| image = image.permute(1, 2, 0) | |
| image = (image - image.min()) / (image.max() - image.min()) | |
| for i, (mask, score) in enumerate(zip(masks, scores)): | |
| plt.figure(figsize=(10,10)) | |
| plt.imshow(image) | |
| show_mask(mask, plt.gca()) | |
| if input_point is not None: | |
| show_points(input_point, input_label, plt.gca()) | |
| if input_box is not None: | |
| show_box(input_box, plt.gca()) | |
| plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) | |
| # plt.axis('off') | |
| plt.savefig(f'debug/sam_mask_{i+1}.png') | |
| plt.close() | |
| if i > 5: | |
| break | |
| def get_sam_input_points(self, conn_components, output_p, get_neg_points=False, l=1): | |
| """ | |
| args: | |
| conn_components: output of cca function | |
| output_p: 3d tensor of shape (1, 2, H, W) | |
| get_neg_points: bool, if True then return the negative points | |
| l: int, number of negative points to get | |
| """ | |
| sam_input_points = [] | |
| sam_neg_points = [] | |
| fg_p = output_p[0, 1].detach().cpu() | |
| if get_neg_points: | |
| # get global negative points | |
| bg_p = output_p[0, 0].detach().cpu() | |
| bg_p[bg_p < 0.95] = 0 | |
| bg_pred = torch.where(bg_p > 0, 1, 0) | |
| glob_neg_points, _ = self.get_most_conf_points(bg_p, bg_pred, 1) | |
| if self.debug: | |
| # plot the bg_p as a heatmap | |
| plt.figure() | |
| plt.imshow(bg_p) | |
| plt.colorbar() | |
| plt.savefig('debug/bg_p_heatmap.png') | |
| plt.close() | |
| for i, cc_id in enumerate(np.unique(conn_components[1])): | |
| # get self.num_points_for_sam most confident points from pred | |
| if cc_id == 0: | |
| continue # skip background | |
| pred = torch.tensor(conn_components[1] == cc_id).float() | |
| if self.point_mode == CONF_MODE: | |
| points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam) # (N, 2) | |
| elif self.point_mode == CENTROID_MODE: | |
| points = conn_components[3][cc_id][None, :] # (1, 2) | |
| confidences = [1 for _ in range(len(points))] | |
| elif self.point_mode == BOTH_MODE: | |
| points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam) | |
| point = conn_components[3][cc_id][None, :] | |
| points = np.vstack([points, point]) # (N+1, 2) | |
| confidences.append(1) | |
| else: | |
| raise NotImplementedError(f"point mode {self.point_mode} not implemented") | |
| sam_input_points.append(np.array(points)) | |
| if get_neg_points: | |
| pred_uint8 = (pred.numpy() * 255).astype(np.uint8) | |
| # Dilate the mask to expand it | |
| kernel_size = 3 # Size of the dilation kernel, adjust accordingly | |
| kernel = np.ones((kernel_size, kernel_size), np.uint8) | |
| dilation_iterations = 10 # Number of times dilation is applied, adjust as needed | |
| dilated_mask = cv2.dilate(pred_uint8, kernel, iterations=dilation_iterations) | |
| # Subtract the original mask from the dilated mask | |
| # This will give a boundary that is only outside the original mask | |
| outside_boundary = dilated_mask - pred_uint8 | |
| # Convert back to torch tensor and normalize | |
| boundary = torch.tensor(outside_boundary).float() / 255 | |
| try: | |
| bg_p = output_p[0, 0].detach().cpu() | |
| neg_points, neg_confidences = self.get_most_conf_points(bg_p, boundary, l) | |
| except RuntimeError as e: | |
| # make each point (None, None) | |
| neg_points = None | |
| # append global negative points to the negative points | |
| if neg_points is not None and glob_neg_points is not None: | |
| neg_points = np.vstack([neg_points, glob_neg_points]) | |
| else: | |
| neg_points = glob_neg_points if neg_points is None else neg_points | |
| if self.debug and neg_points is not None: | |
| # draw an image with 2 subplots, one is the pred and the other is the boundary | |
| plt.figure() | |
| plt.subplot(121) | |
| plt.imshow(pred) | |
| plt.imshow(boundary, alpha=0.5) | |
| # plot the neg points | |
| plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red') | |
| plt.subplot(122) | |
| plt.imshow(pred) | |
| plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red') | |
| plt.savefig('debug/pred_and_boundary.png') | |
| plt.close() | |
| sam_neg_points.append(neg_points) | |
| else: | |
| # create a list of None same shape as points | |
| sam_neg_points = [None for _ in range(len(sam_input_points))] | |
| sam_input_labels = np.array([l+1 for l, cc_points in enumerate(sam_input_points) for _ in range(len(cc_points))]) | |
| sam_input_points = np.stack(sam_input_points) # should be of shape (num_connected_components, num_points_for_sam, 2) | |
| # if get_neg_points: | |
| sam_neg_input_points = np.stack(sam_neg_points) if sam_neg_points is not None else None | |
| if sam_neg_input_points is not None: | |
| sam_neg_input_points = sam_neg_points | |
| sam_neg_input_labels = np.array([0] * len(sam_neg_input_points) ) | |
| else: | |
| sam_neg_input_points = None | |
| sam_neg_input_labels = None | |
| return sam_input_points, sam_input_labels, sam_neg_input_points, sam_neg_input_labels | |
| def get_sam_input_mask(self, conn_components): | |
| sam_input_masks = [] | |
| sam_input_mask_lables = [] | |
| for i, cc_id in enumerate(np.unique(conn_components[1])): | |
| # get self.num_points_for_sam most confident points from pred | |
| if cc_id == 0: | |
| continue | |
| pred = torch.tensor(conn_components[1] == cc_id).float() | |
| sam_input_masks.append(pred) | |
| sam_input_mask_lables.append(cc_id) | |
| sam_input_masks = np.stack(sam_input_masks) | |
| sam_input_mask_lables = np.array(sam_input_mask_lables) | |
| return sam_input_masks, sam_input_mask_lables | |
| def predict_w_masks(self, sam_input_masks, qry_img, original_size): | |
| masks = [] | |
| scores = [] | |
| for in_mask in sam_input_masks: | |
| in_mask = cv2.resize(in_mask, (256, 256), interpolation=cv2.INTER_NEAREST) | |
| in_mask[in_mask == 1] = 10 | |
| in_mask[in_mask == 0] = -8 | |
| assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8 | |
| self.predictor.set_image(qry_img) | |
| mask, score, _ = self.predictor.predict( | |
| mask_input=in_mask[None, ...].astype(np.uint8), | |
| multimask_output=True) | |
| # get max index from score | |
| if self.debug: | |
| # plot each channel of mask | |
| fig, ax = plt.subplots(1, 4, figsize=(15, 5)) | |
| for i in range(mask.shape[0]): | |
| ax[i].imshow(qry_img) | |
| ax[i].imshow(mask[i], alpha=0.5) | |
| ax[i].set_title(f"Mask {i+1}, Score: {score[i]:.3f}", fontsize=18) | |
| # ax[i].axis('off') | |
| ax[-1].imshow(cv2.resize(in_mask, original_size, interpolation=cv2.INTER_NEAREST)) | |
| fig.savefig(f'debug/sam_mask_from_mask_prompts.png') | |
| plt.close(fig) | |
| max_index = score.argmax() | |
| masks.append(mask[max_index]) | |
| scores.append(score[max_index]) | |
| return masks, scores | |
| def predict_w_points_bbox(self, sam_input_points, bboxes, sam_neg_input_points, qry_img, pred, return_logits=False): | |
| masks, scores = [], [] | |
| self.predictor.set_image(qry_img) | |
| # if sam_input_points is None: | |
| # sam_input_points = [None for _ in range(len(bboxes))] | |
| for point, bbox_xyxy, neg_point in zip(sam_input_points, bboxes, sam_neg_input_points): | |
| assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8 | |
| points = point | |
| point_labels = np.array([1] * len(point)) if point is not None else None | |
| if self.use_neg_points: | |
| neg_points = [npoint for npoint in neg_point if None not in npoint] | |
| points = np.vstack([point, *neg_points]) | |
| point_labels = np.array([1] * len(point) + [0] * len(neg_points)) | |
| if self.debug: | |
| self.plot_most_conf_points(points[:, None, ...], None, pred, qry_img, bboxes=bbox_xyxy[None,...] if bbox_xyxy is not None else None, title="debug/pos_neg_points.png") # TODO add plots for all points not just the first set of points | |
| mask, score, _ = self.predictor.predict( | |
| point_coords=points, | |
| point_labels=point_labels, | |
| # box=bbox_xyxy[None, :] if bbox_xyxy is not None else None, | |
| box = bbox_xyxy if bbox_xyxy is not None else None, | |
| # mask_input=sam_mask_input, | |
| return_logits=return_logits, | |
| multimask_output=False if self.use_cca else True | |
| ) | |
| # best_pred_idx = np.argmax(score) | |
| best_pred_idx = 0 | |
| masks.append(mask[best_pred_idx]) | |
| scores.append(score[best_pred_idx]) | |
| if self.debug: | |
| # pass | |
| self.plot_sam_preds(mask, score, qry_img[...,0], points.reshape(-1,2) if sam_input_points is not None else None, point_labels, input_box=bbox_xyxy if bbox_xyxy is not None else None) | |
| return masks, scores | |
| def forward(self, query_image, coarse_model_input, degrees_rotate=0): | |
| """ | |
| query_image: 3d tensor of shape (1, 3, H, W) | |
| images should be normalized with mean and std but not to [0, 1]? | |
| """ | |
| original_size = query_image.shape[-2] | |
| # rotate query_image by degrees_rotate | |
| start_time = time.time() | |
| rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate) | |
| # print(f"rotating query image took {time.time() - start_time} seconds") | |
| start_time = time.time() | |
| coarse_model_input.set_query_images(rotated_img) | |
| output_logits_rot = self.coarse_segmentation_model(coarse_model_input) | |
| # print(f"ALPNet took {time.time() - start_time} seconds") | |
| if degrees_rotate != 0: | |
| start_time = time.time() | |
| output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate) | |
| # print(f"reversing rotated output_logits took {time.time() - start_time} seconds") | |
| else: | |
| output_logits = output_logits_rot | |
| # check if softmax is needed | |
| output_p = output_logits.softmax(dim=1) | |
| # output_p = output_logits | |
| pred = output_logits.argmax(dim=1)[0] | |
| if self.debug: | |
| _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu()) | |
| plt.subplot(132) | |
| plt.imshow(query_image[0,0].detach().cpu()) | |
| plt.imshow(_pred, alpha=0.5) | |
| plt.subplot(131) | |
| # plot heatmap of prob of being fg | |
| plt.imshow(output_p[0, 1].detach().cpu()) | |
| # plot rotated query image and rotated pred | |
| output_p_rot = output_logits_rot.softmax(dim=1) | |
| _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu()) | |
| _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0] | |
| plt.subplot(133) | |
| plt.imshow(rotated_img[0, 0].detach().cpu()) | |
| plt.imshow(_pred_rot, alpha=0.5) | |
| plt.savefig('debug/coarse_pred.png') | |
| plt.close() | |
| if self.coarse_pred_only: | |
| output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits | |
| pred = output_logits.argmax(dim=1)[0] | |
| conf = get_confidence_from_logits(output_logits) | |
| if self.use_cca: | |
| _pred = np.array(pred.detach().cpu()) | |
| _pred, conf = cca(_pred, output_logits, return_conf=True) | |
| pred = torch.from_numpy(_pred) | |
| if self.training: | |
| return output_logits, [conf] | |
| # Ensure pred is a float tensor for consistent visualization | |
| return pred.float(), [conf] | |
| if query_image.shape[-2:] != self.image_size: | |
| query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear') | |
| output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear') | |
| # if need_softmax(output_logits): | |
| # output_logits = output_logits.softmax(dim=1) | |
| # output_p = output_logits | |
| output_p = output_logits.softmax(dim=1) | |
| pred = output_p.argmax(dim=1)[0] | |
| _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu()) | |
| start_time = time.time() | |
| if self.use_cca: | |
| conn_components = cca(_pred, output_logits, return_cc=True) | |
| conf=None | |
| else: | |
| conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True) | |
| if self.debug: | |
| plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf) | |
| # print(f"connected components took {time.time() - start_time} seconds") | |
| if _pred.max() == 0: | |
| return output_p.argmax(dim=1)[0], [0] | |
| # get bbox from pred | |
| if self.use_bbox: | |
| start_time = time.time() | |
| try: | |
| bboxes = self.get_bbox_per_cc(conn_components) | |
| except: | |
| bboxes = [None] * conn_components[0] | |
| else: | |
| bboxes = [None] * conn_components[0] | |
| # print(f"getting bboxes took {time.time() - start_time} seconds") | |
| start_time = time.time() | |
| if self.use_points: | |
| sam_input_points, sam_input_point_labels, sam_neg_input_points, sam_neg_input_labels = self.get_sam_input_points(conn_components, output_p, get_neg_points=self.use_neg_points, l=1) | |
| else: | |
| sam_input_points = [None] * conn_components[0] | |
| sam_input_point_labels = [None] * conn_components[0] | |
| sam_neg_input_points = [None] * conn_components[0] | |
| sam_neg_input_labels = [None] * conn_components[0] | |
| # print(f"getting sam input points took {time.time() - start_time} seconds") | |
| if self.use_mask: | |
| sam_input_masks, sam_input_mask_labels = self.get_sam_input_mask(conn_components) | |
| else: | |
| sam_input_masks = None | |
| sam_input_mask_labels = None | |
| if self.debug and sam_input_points is not None: | |
| title = f'debug/most_conf_points.png' | |
| if self.use_cca: | |
| title = f'debug/most_conf_points_cca.png' | |
| # convert points to a list where each item is a list of 2 elements in xy format | |
| self.plot_most_conf_points(sam_input_points, None, _pred, query_image[0, 0].detach().cpu(), bboxes=bboxes, title=title) # TODO add plots for all points not just the first set of points | |
| # self.sam_trans = None | |
| if self.sam_trans is None: | |
| query_image = query_image.permute(1, 2, 0).detach().cpu().numpy() | |
| else: | |
| query_image = self.sam_trans.apply_image_torch(query_image[0]) | |
| query_image = self.sam_trans.preprocess(query_image) | |
| query_image = query_image.permute(1, 2, 0).detach().cpu().numpy() | |
| # mask = self.sam_trans.preprocess(mask) | |
| query_image = ((query_image - query_image.min()) / (query_image.max() - query_image.min()) * 255).astype(np.uint8) | |
| if self.use_mask: | |
| masks, scores = self.predict_w_masks(sam_input_masks, query_image, original_size) | |
| start_time = time.time() | |
| if self.use_points or self.use_bbox: | |
| masks, scores = self.predict_w_points_bbox(sam_input_points, bboxes, sam_neg_input_points, query_image, pred, return_logits=True if self.training else False) | |
| # print(f"predicting w points/bbox took {time.time() - start_time} seconds") | |
| pred = sum(masks) | |
| if not self.training: | |
| pred = pred > 0 | |
| pred = torch.tensor(pred).float().to(output_p.device) | |
| # pred = torch.tensor(masks[0]).float().cuda() | |
| # resize pred to the size of the input | |
| pred = F.interpolate(pred.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0] | |
| return pred, scores | |
| def show_mask(mask, ax, random_color=False): | |
| if random_color: | |
| color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) | |
| else: | |
| color = np.array([30/255, 144/255, 255/255, 0.6]) | |
| h, w = mask.shape[-2:] | |
| mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) | |
| ax.imshow(mask_image) | |
| def show_points(coords, labels, ax, marker_size=375): | |
| pos_points = coords[labels==1] | |
| neg_points = coords[labels==0] | |
| ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) | |
| ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) | |
| def show_box(box, ax): | |
| x0, y0 = box[0], box[1] | |
| w, h = box[2] - box[0], box[3] - box[1] | |
| ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) | |
| def need_softmax(tensor, dim=1): | |
| return not torch.all(torch.isclose(tensor.sum(dim=dim), torch.ones_like(tensor.sum(dim=dim))) & (tensor >= 0)) | |