import os
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image, ImageDraw
import torch.nn.functional as F


def norm(t):
    """Min-max rescale ``t`` to the range [0, 1].

    ``t`` only needs ``min()``/``max()`` and arithmetic, so both torch tensors
    and numpy arrays work (this module uses it on both).

    A constant input has a zero value range; dividing by it would fill the
    result with NaN, so that case is mapped to all zeros instead.
    """
    lo = t.min()
    span = t.max() - lo
    if span == 0:
        # Multiplying by 0.0 keeps shape (and device) and gives a float result,
        # like the division on the regular path does.
        return (t - lo) * 0.0
    return (t - lo) / span


def denorm(t, min_, max_):
    """Map ``t`` from [0, 1] back to the range [``min_``, ``max_``] (inverse of ``norm``)."""
    return t * (max_ - min_) + min_


def percentilerange(t, perc):
    """Return the value lying the fraction ``perc`` of the way from ``t.min()`` to ``t.max()``.

    Note that this interpolates over the value *range*; it is not a
    statistical percentile of the samples.
    """
    lo = t.min()
    return lo + perc * (t.max() - lo)


def midrange(t):
    """Return the midpoint between ``t.min()`` and ``t.max()``."""
    return percentilerange(t, .5)


def downsample_mask(mask, H, W):
    """Bilinearly resize ``mask`` to the spatial size ``(H, W)``.

    A singleton dimension is inserted at position 1 for ``F.interpolate`` and
    removed again afterwards, so the output rank equals the input rank.
    """
    resized = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='bilinear', align_corners=False)
    return resized.squeeze(1)


def fg_bg_proto(sfeat_volume, downsampled_smask):
    """Compute mask-weighted average foreground and background prototypes.

    Args:
        sfeat_volume: feature volume of shape [B, C, vecs] (``vecs`` can be
            H*W for example).
        downsampled_smask: mask that is expandable to [B, vecs]; its values
            are used directly as foreground weights and ``1 - mask`` as
            background weights.

    Returns:
        ``(fg_proto, bg_proto)``, both of shape [B, C]: the feature vectors
        averaged over the foreground resp. background positions.  A small
        epsilon in the denominator keeps an empty region from dividing by
        zero (its prototype is then ~0).
    """
    B, C, vecs = sfeat_volume.shape
    fg_weights = downsampled_smask.expand(B, vecs).unsqueeze(1)  # -> [B, 1, vecs]
    bg_weights = 1 - fg_weights

    fg_proto = torch.sum(fg_weights * sfeat_volume, dim=-1) / (torch.sum(fg_weights, dim=-1) + 1e-8)
    bg_proto = torch.sum(bg_weights * sfeat_volume, dim=-1) / (torch.sum(bg_weights, dim=-1) + 1e-8)

    assert fg_proto.shape == (B, C), ":o"
    return fg_proto, bg_proto


# intersection = lambda pred, target: (pred * target).float().sum()
# union = lambda pred, target: (pred + target).clamp(0, 1).float().sum()
#
#
# def iou(pred, target):  # binary only, input bsz,h,w
#     i, u = intersection(pred, target), union(pred, target)
#     iou = (i + 1e-8) / (u + 1e-8)
#     return iou.item()
#
#
# class SimpleAvgMeter:
#     def __init__(self, n_classes, device=torch.device('cuda')):
#         self.n_lasses = n_classes
#         self.intersection_buf = torch.zeros(n_classes).to(device)
#         self.union_buf = torch.zeros(n_classes).to(device)
#
#     def update(self, pred, target, class_id):
#         self.intersection_buf[class_id] += intersection(pred, target)
#         self.union_buf[class_id] += union(pred, target)
#
#     def IoU(self, class_id):
#         return self.intersection_buf[class_id] / self.union_buf[class_id] * 100
#
#     def cls_mIoU(self, class_ids):
#         return (self.intersection_buf[class_ids] /
self.union_buf[class_ids]).mean() * 100 # # def compute_mIoU(self): # noentry = self.union_buf == 0 # if noentry.sum() > 0: print("SimpleAvgMeter warning: ", noentry.sum(), "elements of", self.nclasses, # "have no empty.") # return self.cls_mIoU(~noentry) # class KMeans(): # # expects input to be in shape [bsz, -1] # def __init__(self, data, k=2, num_iterations=10): # self.k = k # self.device = data.device # self.centroids = self._init_centroids(data) # # for _ in range(num_iterations): # labels = self._assign_clusters(data) # self._update_centroids(data, labels) # # self.labels = self._assign_clusters(data) # Final cluster assignment # # def _init_centroids(self, data): # # Randomly initialize centroids # centroids = [] # min_values = data.min(dim=1, keepdim=True).values # range_values = (data.max(dim=1, keepdim=True).values - min_values) # # for _ in range(self.k): # random_values = torch.rand((data.shape[0], 1)).to(self.device) # centroids.append(min_values + random_values * range_values) # # return torch.cat(centroids, dim=1) # # def _assign_clusters(self, data): # # Calculate distances between data points and centroids # distances = torch.abs(data.unsqueeze(2) - self.centroids) # Expand data tensor to calculate distances # # Determine the closest centroid for each data point # labels = torch.argmin(distances, dim=2) # # Sort labels so that the largest mean data point has the highest label # cluster_means = [data[labels == k].mean() for k in range(self.k)] # sorted_labels = {k: rank for rank, k in enumerate(sorted(range(self.k), key=lambda k: cluster_means[k]))} # labels = torch.tensor([sorted_labels[label.item()] for label in labels.flatten()]).reshape_as(labels).to( # self.device) # # return labels # # def _update_centroids(self, data, labels): # # Calculate new centroids as the mean of the data points closest to each centroid # mask = torch.nn.functional.one_hot(labels, num_classes=self.k).to(torch.float32) # summed_data = torch.bmm(mask.transpose(1, 2), 
data.unsqueeze(2)) # Sum data points per centroid # self.centroids = summed_data.squeeze() / mask.sum(dim=1, keepdim=True) # Normalize to get the mean # # def compute_thresholds(self): # # Flatten the centroids along the middle dimension # flat_centroids = self.centroids.view(self.centroids.size(0), -1) # # # Sort the flattened centroids # sorted_centroids, _ = torch.sort(flat_centroids, dim=1) # # # Compute the midpoints between consecutive centroids # thresholds = (sorted_centroids[:, :-1] + sorted_centroids[:, 1:]) / 2.0 # # return thresholds # # def inference(self, data): # # Assign data points to the nearest centroid # return self._assign_clusters(data) # def iterative_triclass_thresholding(image, max_iterations=100, tolerance=25): # # Ensure image is grayscale # if len(image.shape) == 3: # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # # # Initialize iteration parameters # TBD_region = image.copy() # iteration = 0 # prev_threshold = 0 # # while iteration < max_iterations: # iteration += 1 # # # Step 1: Apply Otsu's thresholding on the TBD region # current_threshold, _ = cv2.threshold(TBD_region, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) # # # Check stopping criteria # if abs(current_threshold - prev_threshold) < tolerance: # break # prev_threshold = current_threshold # # # Step 2: Calculate means for upper and lower regions # upper_region = TBD_region[TBD_region > current_threshold] # lower_region = TBD_region[TBD_region <= current_threshold] # # if len(upper_region) == 0 or len(lower_region) == 0: # break # No further division possible # # mean_upper = np.mean(upper_region) # mean_lower = np.mean(lower_region) # # # Step 3: Update temporary foreground, background, and TBD regions # TBD_region[(TBD_region > mean_upper)] = 255 # Temporary foreground F # TBD_region[(TBD_region < mean_lower)] = 0 # Temporary background B # # # Extracting the new TBD region (between mean_lower and mean_upper) # mask = (TBD_region > mean_lower) & (TBD_region < mean_upper) # 
#         TBD_region = TBD_region[mask]  # Apply mask to extract region
#
#     # Final classification after convergence or max iterations
#     final_foreground = (image > current_threshold).astype(np.uint8) * 255
#     final_background = (image <= current_threshold).astype(np.uint8) * 255
#
#     return current_threshold, final_foreground


def otsus(batched_tensor_image, drop_least=0.05, mode='ordinary'):
    """Binarise every image of a batch with Otsu's method.

    Each image is rescaled to uint8 [0, 255] for OpenCV, the threshold is
    estimated on the pixels that survive the ``drop_least`` cut-off, and that
    threshold is then applied to the whole (rescaled) image.

    Args:
        batched_tensor_image: tensor whose first dimension is the batch.
        drop_least: fraction of the 0..255 range; pixels below
            ``int(255 * drop_least)`` are ignored when estimating the
            threshold.
        mode: ``'via_triclass'`` uses ``iterative_triclass_thresholding``;
            any other value uses plain OpenCV Otsu.

    Returns:
        ``(thresh_batch, binary_tensor_batch)``:
        the per-image thresholds mapped back to each image's original value
        range (stacked, on the input's device and dtype), and the stacked
        float 0/1 masks (created from numpy, i.e. on the CPU).

    Raises:
        NotImplementedError: if ``mode == 'via_triclass'`` while the triclass
            helper is not defined in this module.
    """
    use_triclass = mode == 'via_triclass'
    if use_triclass and 'iterative_triclass_thresholding' not in globals():
        # The helper is currently commented out in this file; fail with a
        # clear message instead of a NameError halfway through the batch.
        raise NotImplementedError(
            "mode='via_triclass' needs iterative_triclass_thresholding, which is not defined in this module")

    binary_tensors = []
    thresholds = []
    for tensor_image in batched_tensor_image:
        # detach() so inputs that still carry autograd history can be converted
        numpy_image = tensor_image.detach().cpu().numpy()

        # Rescale to [0, 255] and convert to uint8 type for OpenCV compatibility
        npmin, npmax = numpy_image.min(), numpy_image.max()
        numpy_image = (norm(numpy_image) * 255).astype(np.uint8)

        # Drop values that are in the lowest percentiles
        truncated_vals = numpy_image[numpy_image >= int(255 * drop_least)]
        if truncated_vals.size == 0:
            # Happens for a constant image (nothing reaches the cut-off):
            # estimate the threshold on the full image rather than on nothing.
            truncated_vals = numpy_image.reshape(-1)

        # Apply Otsu's thresholding
        if use_triclass:
            thresh_value, _ = iterative_triclass_thresholding(truncated_vals)
        else:
            thresh_value, _ = cv2.threshold(truncated_vals, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Apply the computed threshold on the original image
        binary_image = (numpy_image > thresh_value).astype(np.uint8) * 255

        # Convert the result back to a tensor and append to the list
        binary_tensors.append(torch.from_numpy(binary_image).float() / 255)
        # Express the threshold in the image's original value range again
        original_scale_thresh = torch.tensor(denorm(thresh_value / 255, npmin, npmax))
        thresholds.append(
            original_scale_thresh.to(batched_tensor_image.device, dtype=batched_tensor_image.dtype))

    # Convert list of tensors back to a single batched tensor
    binary_tensor_batch = torch.stack(binary_tensors, dim=0)
    thresh_batch = torch.stack(thresholds, dim=0)

    return thresh_batch, binary_tensor_batch


def iterative_otsus(probab_mask, s_mask, maxiters=5, mode='ordinary', debug=False):
    """Raise an Otsu threshold round by round until it reaches ``s_mask.mean()``.

    Every round zeroes all probabilities below the previous threshold and runs
    ``otsus`` again on the result.  The first threshold that is at least
    ``s_mask.mean()`` is returned together with its mask.  If that has not
    happened after ``maxiters`` rounds, ``s_mask.mean()`` itself is used as the
    threshold.

    Args:
        probab_mask: probabilities in [0, 1] (asserted).
        s_mask: tensor whose mean is the lower bound for an acceptable threshold.
        maxiters: number of Otsu rounds before falling back (at least one
            round is always run).
        mode: forwarded to ``otsus``.
        debug: print diagnostics (and try to show a preview) on fallback.

    Returns:
        ``(threshold, mask)`` with ``mask`` a float 0/1 tensor.

    NOTE(review): the original author's remark still applies - verify that
    this works correctly when batch_size > 1 (the threshold comparison and
    its broadcast against ``probab_mask`` look written for a batch of one).
    """
    assert probab_mask.min() >= 0 and probab_mask.max() <= 1, 'you should pass probabilites'

    s_mean = s_mask.mean()  # loop-invariant
    it = 1
    otsuthresh = 0
    while True:
        clipped = torch.where(probab_mask < otsuthresh, 0, probab_mask)
        otsuthresh, newmask = otsus(clipped.detach(), drop_least=.02, mode=mode)
        if otsuthresh >= s_mean:
            return otsuthresh.to(probab_mask.device), newmask.to(probab_mask.device)
        if it >= maxiters:
            if debug:
                print('reached maxiter:', it, 'with thresh', otsuthresh.item(),
                      'removed', int(((clipped == 0).sum() / clipped.numel()).item() * 10000) / 100,
                      '% at lower and and new min,max is', clipped[clipped > 0].min().item(), clipped.max().item())
                try:
                    display(pilImageRow(norm(probab_mask[0]), s_mask[0], maxwidth=300))
                except NameError:
                    # `display` only exists inside IPython and `pilImageRow` is
                    # currently commented out in this module; the preview is
                    # best-effort and must not break the computation.
                    print('iterative_otsus: display/pilImageRow unavailable, skipping debug preview')
            return s_mean, (probab_mask > s_mean).float()  # otsuthresh
        it += 1


# def upgrade_scipy():
#     os.system('!pip install - -upgrade scipy')
#
#
# def slicRGB(q_img, n_segments=50, compactness=10., sigma=1, mask=None, debug=False):
#     import skimage.segmentation as skseg
#
#     rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
#                             enforce_connectivity=True)
#
#     if debug:
#         plt.imshow(skseg.mark_boundaries(q_img, rgb_labels))
#         plt.show()
#
#     return rgb_labels
#
#
#
# def slicRGBP(q_img, fg_pred, n_segments=30, compactness=0.1, sigma=1, mask=None, debug=False):
#     import skimage.segmentation as skseg
#
#     def concat_rgb_pred(rgbimg, pred):
#         h, w = rgbimg.shape[:2]
#         return np.concatenate((rgbimg, pred.reshape(h, w, 1)), axis=-1)
#
#     rgbp_img = concat_rgb_pred(q_img, fg_pred)
#     rgbp_labels = skseg.slic(rgbp_img, n_segments=n_segments, compactness=compactness, mask=mask, sigma=sigma,
#                              enforce_connectivity=True)
#
#     if debug:
#         rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=10., sigma=sigma, mask=mask,
#                                 enforce_connectivity=True)
#         pred_labels = skseg.slic(fg_pred, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
#                                  channel_axis=None, enforce_connectivity=True)
#
#         rows, cols = 1, 3
#         fig, ax = plt.subplots(rows, cols, figsize=(10, 10), sharex=True, sharey=True)
#         ax[0].imshow(skseg.mark_boundaries(q_img, rgbp_labels))
#         ax[1].imshow(skseg.mark_boundaries(q_img,
#                                           rgb_labels))
#         ax[2].imshow(skseg.mark_boundaries(q_img, pred_labels))
#         plt.show()
#
#     return rgbp_labels
#
#
# def calc_cluster_means(label_id_map, fg_prob):
#     fg_pred_clustered = np.zeros_like(fg_prob)
#     label_ids = np.unique(label_id_map)
#     for lab_id in label_ids:
#         cluster = fg_prob[label_id_map == lab_id]
#         fg_pred_clustered[label_id_map == lab_id] = cluster.mean()
#     return fg_pred_clustered


def install_pydensecrf():
    """Install pydensecrf from its GitHub repository with pip.

    pip is run as ``<current interpreter> -m pip`` so the package lands in the
    environment that is actually executing this code (a bare ``pip`` found on
    PATH may belong to another Python), and the command is passed as an
    argument list, so no shell is involved.  A failing installation does not
    raise here; the caller notices it when the import is retried.
    """
    import subprocess
    import sys

    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', 'git+https://github.com/lucasb-eyer/pydensecrf.git'],
        check=False,
    )


class CRF:
    """DenseCRF post-processing for two-class (background/foreground) predictions.

    The constructor only stores the pairwise-potential parameters that
    ``refine`` hands to pydensecrf; ``iters`` (inference steps) and ``debug``
    are plain attributes that may be changed after construction.
    """

    def __init__(self, gaussian_stdxy=(3, 3), gaussian_compat=3,
                 bilateral_stdxy=(80, 80), bilateral_compat=10, stdrgb=(13, 13, 13)):
        self.gaussian_stdxy = gaussian_stdxy      # sxy of the Gaussian (position-only) pairwise term
        self.gaussian_compat = gaussian_compat    # compat of the Gaussian pairwise term
        self.bilateral_stdxy = bilateral_stdxy    # sxy of the bilateral (position + colour) pairwise term
        self.bilateral_compat = bilateral_compat  # compat of the bilateral pairwise term
        self.stdrgb = stdrgb                      # srgb of the bilateral pairwise term
        self.iters = 5                            # number of DenseCRF inference steps
        self.debug = False                        # print intermediate statistics in refine()

    def refine(self, image_tensor, fg_probs, soft_thresh=None, T=1):
        """
        Refine segmentation using DenseCRF.

        If pydensecrf cannot be imported, an installation is attempted once
        via ``install_pydensecrf`` and the import is retried.

        Args:
        - image_tensor (tensor): Original image, shape [bsz, 3, H, W]. It is multiplied by 255 and cast to
          uint8, so values are presumably expected in [0, 1] - TODO confirm with callers.
        - fg_probs (tensor): Fg probabilities from the network, shape [bsz, H, W]
        - soft_thresh: The preferred threshold for fg_probs for segmenting into binary prediction mask
          (tensor or plain number). If None it is estimated with ``otsus``.
        - T: a temperature for softmax/sigmoid

        Returns:
        - Tensor of shape [bsz, 2, H, W] holding the DenseCRF output per class
          (index 0: background, index 1: foreground).

        Raises:
        - ImportError: if pydensecrf is still unavailable after the installation attempt.
        """
        try:
            import pydensecrf.densecrf as dcrf
            from pydensecrf.utils import unary_from_softmax
        except ImportError:
            print("pydensecrf not found. Installing...")
            install_pydensecrf()
            # After installation, retry importing. This is placed inside the except block to avoid
            # paying for the installation attempt when the first import already succeeds.
            try:
                import pydensecrf.densecrf as dcrf
                from pydensecrf.utils import unary_from_softmax
            except ImportError:
                print("Failed to import after installation. Please check the installation of pydensecrf.")
                raise  # This will raise the last exception that was handled by the except block

        # We find the segmentation threshold that splits fg-bg
        if soft_thresh is None:
            soft_thresh, _ = otsus(fg_probs)
        # as_tensor lets callers pass a plain number; detach() is required because
        # the data is converted to numpy below.
        soft_thresh = torch.as_tensor(soft_thresh).detach().cpu()
        image_tensor, fg_probs = image_tensor.detach().cpu(), fg_probs.detach().cpu()

        # Re-centre the probabilities so that the threshold maps to 0.5; T controls
        # how sharply values on either side of it are pushed apart.
        fg_probs = torch.sigmoid(T * (fg_probs - soft_thresh))
        probs = torch.stack([1 - fg_probs, fg_probs], dim=1)  # crf expects both classes as input

        if self.debug:
            print('softthresh', soft_thresh)
            print('fg_probs min max', fg_probs.min(), fg_probs.max())

        # C: Number of classes
        bsz, C, H, W = probs.shape

        refined_masks = []
        # channels-last uint8 images, one per batch entry
        image_numpy = np.ascontiguousarray(
            (255 * image_tensor.permute(0, 2, 3, 1)).numpy().astype(np.uint8))
        probs_numpy = probs.numpy()
        for image, prob in zip(image_numpy, probs_numpy):
            # Unary potentials
            unary = np.ascontiguousarray(unary_from_softmax(prob))
            d = dcrf.DenseCRF2D(W, H, C)
            d.setUnaryEnergy(unary)

            # Add pairwise potentials
            d.addPairwiseGaussian(sxy=self.gaussian_stdxy, compat=self.gaussian_compat)
            d.addPairwiseBilateral(sxy=self.bilateral_stdxy, srgb=self.stdrgb, rgbim=image,
                                   compat=self.bilateral_compat)

            # Perform inference
            Q = d.inference(self.iters)
            if self.debug:
                print('Q:', np.array(Q).shape, np.array(Q)[0].mean(), np.array(Q).mean())
            result = np.reshape(Q, (C, H, W))  # np.argmax(Q, axis=0).reshape((H, W))
            refined_masks.append(result)

        return torch.from_numpy(np.stack(refined_masks, axis=0))

    # def iterrefine(self, iters, image_tensor, fg_probs, soft_thresh=None, T=1):
    #     q1 = fg_probs
    #     for iter in range(iters):
    #         print(q1.shape)
    #         q1 = self.refine(image_tensor, q1, soft_thresh=None, T=1)[:,1]
    #     return q1

    def iterrefine(self, iters, q_img, fg_probs, thresh_fn, debug=False):
        """Apply ``refine`` repeatedly, re-estimating the threshold before every pass.

        Args:
        - iters: number of refinement passes; with 0 the expanded input is returned unchanged.
        - q_img: image tensor forwarded to ``refine``.
        - fg_probs: foreground probabilities; they are expanded to [1, 2, H, W], i.e. a batch of one
          is assumed.
        - thresh_fn: callable taking the current foreground map and returning a sequence whose first
          element is the threshold (e.g. ``otsus``).
        - debug: every 10th pass print the threshold and try to show the current foreground map.

        Returns:
        - Tensor of shape [1, 2, H, W]; channel 1 is the foreground map.
        """
        pred = fg_probs.unsqueeze(1).expand(1, 2, *fg_probs.shape[-2:])
        for it in range(iters):
            thresh = thresh_fn(pred[:, 1])[0]
            if debug and it % 10 == 0:
                print('thresh', thresh)
                try:
                    display(to_pil(pred[0, 1]))
                except NameError:
                    # `display` only exists inside IPython and `to_pil` is currently
                    # commented out in this module; the preview is best-effort.
                    print('iterrefine: display/to_pil unavailable, skipping debug image')
            pred = self.refine(q_img, pred[:, 1], soft_thresh=thresh)
        return pred


#
# class Subplot:
#     def __init__(self):
#         self.vertical_lines = []
#         self.histograms = []
#         self.gaussian_curves = []
#         self.colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
#         self.title = ''
#
#     class Element:
#         def __init__(self, x=None, y=None, label=''):
#             if x is not None:
#                 self.x = Subplot.to_np(x)
#             if y is not None:
#                 self.y = Subplot.to_np(y)
#
#             self.label = label
#
#     @staticmethod
#     def to_np(t):
#         return t.detach().cpu().numpy()
#
#     def add_vertical(self, x, label=''):
#         self.vertical_lines.append(Subplot.Element(x=x, label=label))
#         return self
#
#     def add_histogram(self, samples, label=''):
#         self.histograms.append(Subplot.Element(x=samples, label=label))
#         return self
#
#     def add_gaussian(self, gaussian):
#         samples, mu, var = gaussian.samples, gaussian.mean, gaussian.covs
#         # Generate a range of x values
#         x_values = np.linspace(samples.min(), samples.max(), 100)
#         x_values = np.linspace(samples.min(), samples.max(), 100)
#
#         # Compute Gaussian values for these x values
#         gaussian1_values = gaussian.gaussian_pdf(x_values, mu[0].item(), var[0].item())
#         gaussian2_values = gaussian.gaussian_pdf(x_values, mu[1].item(), var[1].item())
#         self.gaussian_curves.append(Subplot.Element(x_values, gaussian1_values))
#         self.gaussian_curves.append(Subplot.Element(x_values, gaussian2_values))
#         return self
#
#
# class PredHistos2():
#     def __init__(self, n_cols=1):
#         self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
#         self.n_cols = n_cols
#         if n_cols == 1:
#             self.builder = Subplot()
#         self.subplots = [Subplot() for x in range(n_cols)]
#         self.alpha = 0.5
#         self.bins = 200
#
#     def reload(self, n_cols=1):
#         self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
#
#     def aggr(self, ax, sub):
#         for hist, col in zip(sub.histograms,
sub.colors): # ax.hist(hist.x, self.bins, density=True, color=col, alpha=self.alpha, label=hist.label) # for vline, col in zip(sub.vertical_lines, sub.colors): # ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--') # for gaussian, col in zip(sub.gaussian_curves, sub.colors): # ax.plot(gaussian.x, gaussian.y, gaussian.label, col) # ax.legend() # # def plot(self, name=''): # # if self.n_cols == 1: # self.aggr(plt, self.builder) # else: # for ax, sub in zip(self.axes, self.subplots): # self.aggr(ax, sub) # ax.set_title(sub.title) # # plt.legend() # plt.title(name) # plt.show() # # # from sklearn.mixture import GaussianMixture # import scipy.optimize as opt # from scipy.optimize import fsolve # # # class GMM: # def __init__(self, q_pred_coarse, name='gaussian', n_components=2): # samples = q_pred_coarse.detach().cpu().numpy() # self.samples = samples.reshape(-1, 1) # # # Fit a mixture of 2 Gaussians using EM # gmm = GaussianMixture(n_components) # gmm.fit(samples) # self.means = gmm.means_.flatten() # self.covs = gmm.covariances_.flatten() # self.weights = gmm.weights_ # self.label = name # # def intersect(self): # # Use fsolve to find the intersection # gaussian_intersect, = fsolve(difference, self.means.mean(), args=( # self.means[0].item(), self.covs[0].item(), self.means[1].item(), self.means[1].item())) # return gaussian_intersect # # # class PredHistoSNS(): # def __init__(self, n_cols=1): # import seaborn as sns # sns.set_theme(style="whitegrid") # Set the Seaborn theme. You can change the style as needed. # self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4)) # self.n_cols = n_cols # if n_cols == 1: # self.axes = [self.axes] # Wrap the single axis in a list to simplify the loop logic later. # self.builder = Subplot() # This is assuming Subplot is a properly defined class. # self.subplots = [Subplot() for _ in range(n_cols)] # Use underscore for unused loop variable. 
# self.alpha = 0.5 # self.bins = 200 # # def reload(self, n_cols=1): # self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4)) # # def aggr(self, ax, sub): # import seaborn as sns # for hist, col in zip(sub.histograms, sub.colors): # sns.histplot(hist.x, bins=self.bins, kde=False, color=col, ax=ax, alpha=self.alpha, label=hist.label) # for vline, col in zip(sub.vertical_lines, sub.colors): # ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--') # for gaussian, col in zip(sub.gaussian_curves, sub.colors): # sns.lineplot(x=gaussian.x, y=gaussian.y, label=gaussian.label, color=col, ax=ax) # ax.legend() # # def plot(self, name=''): # # if self.n_cols == 1: # self.aggr(self.axes[0], self.builder) # else: # for ax, sub in zip(self.axes, self.subplots): # self.aggr(ax, sub) # ax.set_title(sub.title) # # plt.show() # # # def overlay_mask(image, mask, color=[255, 0, 0], alpha=0.2): # """ # Apply an overlay of a binary mask onto an image using a specified color. # # :param image: A PyTorch tensor of the image (C x H x W) with pixel values in [0, 1]. # :param mask: A PyTorch tensor of the mask (H x W) with binary values (0 or 1). # :param color: A list of 3 elements representing the RGB values of the overlay color. # :param alpha: A float representing the transparency of the overlay (0 to 1). # :return: An overlayed image tensor. 
# """ # # Ensure the mask is binary # mask = (mask > 0).float() # # # Create an RGB version of the mask # mask_rgb = torch.tensor(color).view(3, 1, 1) / 255.0 # Normalize the color vector # mask_rgb = mask_rgb * mask # # # Overlay the mask onto the image # overlayed_image = (1 - alpha) * image + alpha * mask_rgb # # # Ensure the resulting tensor values are between 0 and 1 # overlayed_image = torch.clamp(overlayed_image, 0, 1) # # return overlayed_image # # # import pandas as pd # to_pil = lambda t: transforms.ToPILImage()(t) if t.shape[-1] > 4 else transforms.ToPILImage()(t.permute(2, 0, 1)) # # # def pilImageRow(*imgs, maxwidth=800, bordercolor=0x000000): # imgs = [to_pil(im.float()) for im in imgs] # dst = Image.new('RGB', (sum(im.width for im in imgs), imgs[0].height)) # for i, im in enumerate(imgs): # loc = [x0, y0, x1, y1] = [i * im.width, 0, (i + 1) * im.width, im.height] # dst.paste(im, (x0, y0)) # ImageDraw.Draw(dst).rectangle(loc, width=2, outline=bordercolor) # factorToBig = dst.width / maxwidth # dst = dst.resize((int(dst.width / factorToBig), int(dst.height / factorToBig))) # return dst # # # def tensor_table(**kwargs): # tensor_overview = {} # for name, tensor in kwargs.items(): # if callable(tensor): # print(name, [tensor(t) for _, t in kwargs.items() if isinstance(t, torch.Tensor)]) # else: # tensor_overview[name] = { # 'min': tensor.min().item(), # 'max': tensor.max().item(), # 'shape': tensor.shape, # } # return pd.DataFrame.from_dict(tensor_overview, orient='index')