# ABCDFSS/utils/segutils.py
import os
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image, ImageDraw
import torch.nn.functional as F
norm = lambda t: (t - t.min()) / (t.max() - t.min())
denorm = lambda t, min_, max_: t * (max_ - min_) + min_
percentilerange = lambda t, perc: t.min() + perc * (t.max() - t.min())
midrange = lambda t: percentilerange(t, .5)
downsample_mask = lambda mask, H, W: F.interpolate(mask.unsqueeze(1), size=(H, W), mode='bilinear',
align_corners=False).squeeze(1)
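
# Illustrative sketch (the `_demo_*` helper below is not part of the original code): how the
# range/resize helpers above compose. Shapes and values are assumptions chosen for demonstration.
def _demo_range_helpers():
    t = torch.randn(4, 32, 32)
    t01 = norm(t)                                               # rescaled to [0, 1]
    t_back = denorm(t01, t.min(), t.max())                      # inverse of norm, recovers the original range
    small = downsample_mask(torch.rand(2, 128, 128), 32, 32)    # [2, 32, 32] bilinear resize
    return t_back, small
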
# sfeat_volume: [bsz, c, vecs]  (vecs can be H*W, for example)
# downsampled_smask: [bsz, vecs]
# returns fg_proto [bsz, c], bg_proto [bsz, c]
def fg_bg_proto(sfeat_volume, downsampled_smask):
B, C, vecs = sfeat_volume.shape
reshaped_mask = downsampled_smask.expand(B, vecs).unsqueeze(1) # ->[B,1,vecs]
masked_fg = reshaped_mask * sfeat_volume
fg_proto = torch.sum(masked_fg, dim=-1) / (torch.sum(reshaped_mask, dim=-1) + 1e-8)
masked_bg = (1 - reshaped_mask) * sfeat_volume
bg_proto = torch.sum(masked_bg, dim=-1) / (torch.sum(1 - reshaped_mask, dim=-1) + 1e-8)
    assert fg_proto.shape == bg_proto.shape == (B, C), "prototypes should have shape (B, C)"
return fg_proto, bg_proto
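
# Illustrative sketch (the `_demo_*` helper below is not part of the original code): masked average
# pooling of a support feature volume into foreground/background prototypes. Shapes are assumptions.
def _demo_fg_bg_proto():
    sfeat = torch.randn(2, 256, 60 * 60)               # [B, C, vecs] support feature volume
    smask = (torch.rand(2, 60 * 60) > 0.5).float()     # [B, vecs] binary support mask
    fg_proto, bg_proto = fg_bg_proto(sfeat, smask)     # each [B, C]
    return fg_proto, bg_proto
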
# intersection = lambda pred, target: (pred * target).float().sum()
# union = lambda pred, target: (pred + target).clamp(0, 1).float().sum()
#
#
# def iou(pred, target): # binary only, input bsz,h,w
# i, u = intersection(pred, target), union(pred, target)
# iou = (i + 1e-8) / (u + 1e-8)
# return iou.item()
#
#
# class SimpleAvgMeter:
# def __init__(self, n_classes, device=torch.device('cuda')):
#         self.n_classes = n_classes
# self.intersection_buf = torch.zeros(n_classes).to(device)
# self.union_buf = torch.zeros(n_classes).to(device)
#
# def update(self, pred, target, class_id):
# self.intersection_buf[class_id] += intersection(pred, target)
# self.union_buf[class_id] += union(pred, target)
#
# def IoU(self, class_id):
# return self.intersection_buf[class_id] / self.union_buf[class_id] * 100
#
# def cls_mIoU(self, class_ids):
# return (self.intersection_buf[class_ids] / self.union_buf[class_ids]).mean() * 100
#
# def compute_mIoU(self):
# noentry = self.union_buf == 0
#         if noentry.sum() > 0: print("SimpleAvgMeter warning:", noentry.sum(), "of", self.n_classes,
#                                     "classes have no entries.")
# return self.cls_mIoU(~noentry)
# class KMeans():
# # expects input to be in shape [bsz, -1]
# def __init__(self, data, k=2, num_iterations=10):
# self.k = k
# self.device = data.device
# self.centroids = self._init_centroids(data)
#
# for _ in range(num_iterations):
# labels = self._assign_clusters(data)
# self._update_centroids(data, labels)
#
# self.labels = self._assign_clusters(data) # Final cluster assignment
#
# def _init_centroids(self, data):
# # Randomly initialize centroids
# centroids = []
# min_values = data.min(dim=1, keepdim=True).values
# range_values = (data.max(dim=1, keepdim=True).values - min_values)
#
# for _ in range(self.k):
# random_values = torch.rand((data.shape[0], 1)).to(self.device)
# centroids.append(min_values + random_values * range_values)
#
# return torch.cat(centroids, dim=1)
#
# def _assign_clusters(self, data):
# # Calculate distances between data points and centroids
# distances = torch.abs(data.unsqueeze(2) - self.centroids) # Expand data tensor to calculate distances
# # Determine the closest centroid for each data point
# labels = torch.argmin(distances, dim=2)
# # Sort labels so that the largest mean data point has the highest label
# cluster_means = [data[labels == k].mean() for k in range(self.k)]
# sorted_labels = {k: rank for rank, k in enumerate(sorted(range(self.k), key=lambda k: cluster_means[k]))}
# labels = torch.tensor([sorted_labels[label.item()] for label in labels.flatten()]).reshape_as(labels).to(
# self.device)
#
# return labels
#
# def _update_centroids(self, data, labels):
# # Calculate new centroids as the mean of the data points closest to each centroid
# mask = torch.nn.functional.one_hot(labels, num_classes=self.k).to(torch.float32)
# summed_data = torch.bmm(mask.transpose(1, 2), data.unsqueeze(2)) # Sum data points per centroid
# self.centroids = summed_data.squeeze() / mask.sum(dim=1, keepdim=True) # Normalize to get the mean
#
# def compute_thresholds(self):
# # Flatten the centroids along the middle dimension
# flat_centroids = self.centroids.view(self.centroids.size(0), -1)
#
# # Sort the flattened centroids
# sorted_centroids, _ = torch.sort(flat_centroids, dim=1)
#
# # Compute the midpoints between consecutive centroids
# thresholds = (sorted_centroids[:, :-1] + sorted_centroids[:, 1:]) / 2.0
#
# return thresholds
#
# def inference(self, data):
# # Assign data points to the nearest centroid
# return self._assign_clusters(data)
def iterative_triclass_thresholding(image, max_iterations=100, tolerance=25):
    # Used by otsus() when mode='via_triclass'.
    # Ensure image is grayscale
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Initialize iteration parameters
    TBD_region = image.copy()
    iteration = 0
    prev_threshold = 0

    while iteration < max_iterations:
        iteration += 1

        # Step 1: Apply Otsu's thresholding on the TBD region
        current_threshold, _ = cv2.threshold(TBD_region, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Check stopping criteria
        if abs(current_threshold - prev_threshold) < tolerance:
            break
        prev_threshold = current_threshold

        # Step 2: Calculate means for upper and lower regions
        upper_region = TBD_region[TBD_region > current_threshold]
        lower_region = TBD_region[TBD_region <= current_threshold]

        if len(upper_region) == 0 or len(lower_region) == 0:
            break  # No further division possible

        mean_upper = np.mean(upper_region)
        mean_lower = np.mean(lower_region)

        # Step 3: Update temporary foreground, background, and TBD regions
        TBD_region[(TBD_region > mean_upper)] = 255  # Temporary foreground F
        TBD_region[(TBD_region < mean_lower)] = 0    # Temporary background B

        # Extracting the new TBD region (between mean_lower and mean_upper)
        mask = (TBD_region > mean_lower) & (TBD_region < mean_upper)
        TBD_region = TBD_region[mask]  # Apply mask to extract region

    # Final classification after convergence or max iterations
    final_foreground = (image > current_threshold).astype(np.uint8) * 255
    final_background = (image <= current_threshold).astype(np.uint8) * 255

    return current_threshold, final_foreground
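
# Illustrative sketch (the `_demo_*` helper and `prob_map` below are assumptions, not original
# usage): triclass thresholding of a uint8 grayscale map rescaled to [0, 255].
def _demo_triclass():
    prob_map = np.random.rand(60, 60)                     # hypothetical H x W float map
    gray = (norm(prob_map) * 255).astype(np.uint8)
    thresh, fg = iterative_triclass_thresholding(gray)    # scalar threshold, binary fg in {0, 255}
    return thresh, fg
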
def otsus(batched_tensor_image, drop_least=0.05, mode='ordinary'):
bsz = batched_tensor_image.size(0)
binary_tensors = []
thresholds = []
for i in range(bsz):
# Convert the tensor to numpy array
numpy_image = batched_tensor_image[i].cpu().numpy()
# Rescale to [0, 255] and convert to uint8 type for OpenCV compatibility
npmin, npmax = numpy_image.min(), numpy_image.max()
numpy_image = (norm(numpy_image) * 255).astype(np.uint8)
        # Drop values below the lowest `drop_least` fraction of the rescaled [0, 255] range
truncated_vals = numpy_image[numpy_image >= int(255 * drop_least)]
# Apply Otsu's thresholding
if mode == 'via_triclass':
thresh_value, _ = iterative_triclass_thresholding(truncated_vals)
else:
thresh_value, _ = cv2.threshold(truncated_vals, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# Apply the computed threshold on the original image
binary_image = (numpy_image > thresh_value).astype(np.uint8) * 255
# Convert the result back to a tensor and append to the list
binary_tensors.append(torch.from_numpy(binary_image).float() / 255)
thresholds.append(torch.tensor(denorm(thresh_value / 255, npmin, npmax)) \
.to(batched_tensor_image.device, dtype=batched_tensor_image.dtype))
# Convert list of tensors back to a single batched tensor
binary_tensor_batch = torch.stack(binary_tensors, dim=0)
thresh_batch = torch.stack(thresholds, dim=0)
return thresh_batch, binary_tensor_batch
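
# Illustrative sketch (the `_demo_*` helper below is not part of the original code): batched Otsu
# thresholding of probability maps. The input is a synthetic placeholder.
def _demo_otsus():
    probs = torch.rand(2, 60, 60)        # [B, H, W] with values in [0, 1]
    thresh, binary = otsus(probs)        # thresh: [B] in the input's value range, binary: [B, H, W] in {0, 1}
    return thresh, binary
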
def iterative_otsus(probab_mask, s_mask, maxiters=5, mode='ordinary',
debug=False): # verify that it works correctly when batch_size >1
it = 1
otsuthresh = 0
    assert probab_mask.min() >= 0 and probab_mask.max() <= 1, 'you should pass probabilities'
while True:
clipped = torch.where(probab_mask < otsuthresh, 0, probab_mask)
otsuthresh, newmask = otsus(clipped.detach(), drop_least=.02, mode=mode)
if otsuthresh >= s_mask.mean():
return otsuthresh.to(probab_mask.device), newmask.to(probab_mask.device)
if it >= maxiters:
if debug:
                print('reached maxiter:', it, 'with thresh', otsuthresh.item(),
                      'removed', int(((clipped == 0).sum() / clipped.numel()).item() * 10000) / 100,
                      '% at lower end and new min,max is', clipped[clipped > 0].min().item(), clipped.max().item())
                # Requires IPython's display and the pilImageRow/to_pil helpers (commented out below)
                display(pilImageRow(norm(probab_mask[0]), s_mask[0], maxwidth=300))
return s_mask.mean(), (probab_mask > s_mask.mean()).float() # otsuthresh
it += 1
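
# Illustrative sketch (the `_demo_*` helper below is not part of the original code): iterating Otsu
# on a query probability map until the threshold reaches the support-mask mean. Inputs are synthetic
# placeholders with batch size 1.
def _demo_iterative_otsus():
    probs = torch.rand(1, 60, 60)                       # [1, H, W] query fg probabilities
    s_mask = (torch.rand(1, 60, 60) > 0.7).float()      # [1, H, W] binary support mask
    thresh, binary = iterative_otsus(probs, s_mask)
    return thresh, binary
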
# def upgrade_scipy():
#     os.system('pip install --upgrade scipy')
#
#
# def slicRGB(q_img, n_segments=50, compactness=10., sigma=1, mask=None, debug=False):
# import skimage.segmentation as skseg
#
# rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
# enforce_connectivity=True)
#
# if debug:
# plt.imshow(skseg.mark_boundaries(q_img, rgb_labels))
# plt.show()
#
# return rgb_labels
#
#
#
# def slicRGBP(q_img, fg_pred, n_segments=30, compactness=0.1, sigma=1, mask=None, debug=False):
# import skimage.segmentation as skseg
#
# def concat_rgb_pred(rgbimg, pred):
# h, w = rgbimg.shape[:2]
# return np.concatenate((rgbimg, pred.reshape(h, w, 1)), axis=-1)
#
# rgbp_img = concat_rgb_pred(q_img, fg_pred)
# rgbp_labels = skseg.slic(rgbp_img, n_segments=n_segments, compactness=compactness, mask=mask, sigma=sigma,
# enforce_connectivity=True)
#
# if debug:
# rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=10., sigma=sigma, mask=mask,
# enforce_connectivity=True)
# pred_labels = skseg.slic(fg_pred, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
# channel_axis=None, enforce_connectivity=True)
#
# rows, cols = 1, 3
# fig, ax = plt.subplots(rows, cols, figsize=(10, 10), sharex=True, sharey=True)
# ax[0].imshow(skseg.mark_boundaries(q_img, rgbp_labels))
# ax[1].imshow(skseg.mark_boundaries(q_img, rgb_labels))
# ax[2].imshow(skseg.mark_boundaries(q_img, pred_labels))
# plt.show()
#
# return rgbp_labels
#
#
# def calc_cluster_means(label_id_map, fg_prob):
# fg_pred_clustered = np.zeros_like(fg_prob)
# label_ids = np.unique(label_id_map)
# for lab_id in label_ids:
# cluster = fg_prob[label_id_map == lab_id]
# fg_pred_clustered[label_id_map == lab_id] = cluster.mean()
# return fg_pred_clustered
def install_pydensecrf():
os.system('pip install git+https://github.com/lucasb-eyer/pydensecrf.git')
class CRF:
def __init__(self, gaussian_stdxy=(3, 3), gaussian_compat=3,
bilateral_stdxy=(80, 80), bilateral_compat=10, stdrgb=(13, 13, 13)):
self.gaussian_stdxy = gaussian_stdxy
self.gaussian_compat = gaussian_compat
self.bilateral_stdxy = bilateral_stdxy
self.bilateral_compat = bilateral_compat
self.stdrgb = stdrgb
self.iters = 5
self.debug = False
def refine(self, image_tensor, fg_probs, soft_thresh=None, T=1):
"""
Refine segmentation using DenseCRF.
Args:
- image_tensor (tensor): Original image, shape [1, 3, H, W].
- fg_probs (tensor): Fg probabilities from the network, shape [1, H, W]
- soft_thresh: The preferred threshold for fg_probs for segmenting into binary prediction mask
- T: a temperature for softmax/sigmoid
Returns:
        - Refined class probabilities, shape [bsz, 2, H, W] (channel 1 is foreground).
"""
try:
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral
except ImportError as e:
print("pydensecrf not found. Installing...")
            install_pydensecrf()
            # Retry the import after installation; kept inside the except block so the imports
            # are only repeated when the first attempt failed.
try:
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral
except ImportError as e:
print("Failed to import after installation. Please check the installation of pydensecrf.")
raise # This will raise the last exception that was handled by the except block
# We find the segmentation threshold that splits fg-bg
if soft_thresh is None:
soft_thresh, _ = otsus(fg_probs)
image_tensor, fg_probs, soft_thresh = image_tensor.cpu(), fg_probs.cpu(), soft_thresh.cpu()
# Then we presume at this threshold the probability should be 0.5
# probability 0 should stay 0, 1 should stay 1
# sigmoid=lambda x: 1/(1 + np.exp(-x))
fg_probs = torch.sigmoid(T * (fg_probs - soft_thresh))
probs = torch.stack([1 - fg_probs, fg_probs], dim=1) # crf expects both classes as input
if self.debug:
print('softthresh', soft_thresh)
print('fg_probs min max', fg_probs.min(), fg_probs.max())
# C: Number of classes
bsz, C, H, W = probs.shape
refined_masks = []
image_numpy = np.ascontiguousarray( \
(255 * image_tensor.permute(0, 2, 3, 1)).numpy().astype(np.uint8))
probs_numpy = probs.numpy()
for (image, prob) in zip(image_numpy, probs_numpy):
# Unary potentials
unary = np.ascontiguousarray(unary_from_softmax(prob))
d = dcrf.DenseCRF2D(W, H, C)
d.setUnaryEnergy(unary)
# Add pairwise potentials
d.addPairwiseGaussian(sxy=self.gaussian_stdxy, compat=self.gaussian_compat)
d.addPairwiseBilateral(sxy=self.bilateral_stdxy, srgb=self.stdrgb,
rgbim=image, compat=self.bilateral_compat)
# Perform inference
Q = d.inference(self.iters)
if self.debug:
print('Q:', np.array(Q).shape, np.array(Q)[0].mean(), np.array(Q).mean())
result = np.reshape(Q, (2, H, W)) # np.argmax(Q, axis=0).reshape((H, W))
refined_masks.append(result)
return torch.from_numpy(np.stack(refined_masks, axis=0))
# def iterrefine(self, iters, image_tensor, fg_probs, soft_thresh=None, T=1):
# q1 = fg_probs
# for iter in range(iters):
# print(q1.shape)
# q1 = self.refine(image_tensor, q1, soft_thresh=None, T=1)[:,1]
# return q1
def iterrefine(self, iters, q_img, fg_probs, thresh_fn, debug=False):
pred = fg_probs.unsqueeze(1).expand(1, 2, *fg_probs.shape[-2:])
for it in range(iters):
thresh = thresh_fn(pred[:, 1])[0]
            if debug and it % 10 == 0:
print('thresh', thresh)
                # Requires IPython's display and the to_pil helper (commented out below)
                display(to_pil(pred[0, 1]))
pred = self.refine(q_img, pred[:, 1], soft_thresh=thresh)
return pred
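
# Illustrative sketch (the `_demo_*` helper below is not part of the original code): dense-CRF
# refinement of a coarse foreground map, plus the iterative variant re-thresholded with otsus.
# Tensors are synthetic placeholders; pydensecrf must be installed for this to run.
def _demo_crf():
    crf = CRF()
    image = torch.rand(1, 3, 120, 120)                   # [1, 3, H, W] RGB image in [0, 1]
    probs = torch.rand(1, 120, 120)                      # [1, H, W] coarse fg probabilities
    refined = crf.refine(image, probs)                   # [1, 2, H, W] refined class probabilities
    refined_iter = crf.iterrefine(3, image, probs, thresh_fn=otsus)
    return refined[:, 1], refined_iter[:, 1]             # foreground channels
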
#
# class Subplot:
# def __init__(self):
# self.vertical_lines = []
# self.histograms = []
# self.gaussian_curves = []
# self.colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# self.title = ''
#
# class Element:
# def __init__(self, x=None, y=None, label=''):
# if x is not None:
# self.x = Subplot.to_np(x)
# if y is not None:
# self.y = Subplot.to_np(y)
#
# self.label = label
#
# @staticmethod
# def to_np(t):
# return t.detach().cpu().numpy()
#
# def add_vertical(self, x, label=''):
# self.vertical_lines.append(Subplot.Element(x=x, label=label))
# return self
#
# def add_histogram(self, samples, label=''):
# self.histograms.append(Subplot.Element(x=samples, label=label))
# return self
#
# def add_gaussian(self, gaussian):
# samples, mu, var = gaussian.samples, gaussian.mean, gaussian.covs
# # Generate a range of x values
#         x_values = np.linspace(samples.min(), samples.max(), 100)
#
# # Compute Gaussian values for these x values
# gaussian1_values = gaussian.gaussian_pdf(x_values, mu[0].item(), var[0].item())
# gaussian2_values = gaussian.gaussian_pdf(x_values, mu[1].item(), var[1].item())
# self.gaussian_curves.append(Subplot.Element(x_values, gaussian1_values))
# self.gaussian_curves.append(Subplot.Element(x_values, gaussian2_values))
# return self
#
#
# class PredHistos2():
# def __init__(self, n_cols=1):
# self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
# self.n_cols = n_cols
# if n_cols == 1:
# self.builder = Subplot()
# self.subplots = [Subplot() for x in range(n_cols)]
# self.alpha = 0.5
# self.bins = 200
#
# def reload(self, n_cols=1):
# self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
#
# def aggr(self, ax, sub):
# for hist, col in zip(sub.histograms, sub.colors):
# ax.hist(hist.x, self.bins, density=True, color=col, alpha=self.alpha, label=hist.label)
# for vline, col in zip(sub.vertical_lines, sub.colors):
# ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--')
# for gaussian, col in zip(sub.gaussian_curves, sub.colors):
#             ax.plot(gaussian.x, gaussian.y, color=col, label=gaussian.label)
# ax.legend()
#
# def plot(self, name=''):
#
# if self.n_cols == 1:
# self.aggr(plt, self.builder)
# else:
# for ax, sub in zip(self.axes, self.subplots):
# self.aggr(ax, sub)
# ax.set_title(sub.title)
#
# plt.legend()
# plt.title(name)
# plt.show()
#
#
# from sklearn.mixture import GaussianMixture
# import scipy.optimize as opt
# from scipy.optimize import fsolve
#
#
# class GMM:
# def __init__(self, q_pred_coarse, name='gaussian', n_components=2):
# samples = q_pred_coarse.detach().cpu().numpy()
# self.samples = samples.reshape(-1, 1)
#
# # Fit a mixture of 2 Gaussians using EM
# gmm = GaussianMixture(n_components)
#         gmm.fit(self.samples)
# self.means = gmm.means_.flatten()
# self.covs = gmm.covariances_.flatten()
# self.weights = gmm.weights_
# self.label = name
#
# def intersect(self):
# # Use fsolve to find the intersection
#         # NOTE: assumes a `difference(x, mu1, var1, mu2, var2)` helper (difference of the two
#         # Gaussian pdfs), which is not defined in this file.
#         gaussian_intersect, = fsolve(difference, self.means.mean(), args=(
#             self.means[0].item(), self.covs[0].item(), self.means[1].item(), self.covs[1].item()))
# return gaussian_intersect
#
#
# class PredHistoSNS():
# def __init__(self, n_cols=1):
# import seaborn as sns
# sns.set_theme(style="whitegrid") # Set the Seaborn theme. You can change the style as needed.
# self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
# self.n_cols = n_cols
# if n_cols == 1:
# self.axes = [self.axes] # Wrap the single axis in a list to simplify the loop logic later.
# self.builder = Subplot() # This is assuming Subplot is a properly defined class.
# self.subplots = [Subplot() for _ in range(n_cols)] # Use underscore for unused loop variable.
# self.alpha = 0.5
# self.bins = 200
#
# def reload(self, n_cols=1):
# self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
#
# def aggr(self, ax, sub):
# import seaborn as sns
# for hist, col in zip(sub.histograms, sub.colors):
# sns.histplot(hist.x, bins=self.bins, kde=False, color=col, ax=ax, alpha=self.alpha, label=hist.label)
# for vline, col in zip(sub.vertical_lines, sub.colors):
# ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--')
# for gaussian, col in zip(sub.gaussian_curves, sub.colors):
# sns.lineplot(x=gaussian.x, y=gaussian.y, label=gaussian.label, color=col, ax=ax)
# ax.legend()
#
# def plot(self, name=''):
#
# if self.n_cols == 1:
# self.aggr(self.axes[0], self.builder)
# else:
# for ax, sub in zip(self.axes, self.subplots):
# self.aggr(ax, sub)
# ax.set_title(sub.title)
#
# plt.show()
#
#
# def overlay_mask(image, mask, color=[255, 0, 0], alpha=0.2):
# """
# Apply an overlay of a binary mask onto an image using a specified color.
#
# :param image: A PyTorch tensor of the image (C x H x W) with pixel values in [0, 1].
# :param mask: A PyTorch tensor of the mask (H x W) with binary values (0 or 1).
# :param color: A list of 3 elements representing the RGB values of the overlay color.
# :param alpha: A float representing the transparency of the overlay (0 to 1).
# :return: An overlayed image tensor.
# """
# # Ensure the mask is binary
# mask = (mask > 0).float()
#
# # Create an RGB version of the mask
# mask_rgb = torch.tensor(color).view(3, 1, 1) / 255.0 # Normalize the color vector
# mask_rgb = mask_rgb * mask
#
# # Overlay the mask onto the image
# overlayed_image = (1 - alpha) * image + alpha * mask_rgb
#
# # Ensure the resulting tensor values are between 0 and 1
# overlayed_image = torch.clamp(overlayed_image, 0, 1)
#
# return overlayed_image
#
#
# import pandas as pd
# to_pil = lambda t: transforms.ToPILImage()(t) if t.shape[-1] > 4 else transforms.ToPILImage()(t.permute(2, 0, 1))
#
#
# def pilImageRow(*imgs, maxwidth=800, bordercolor=0x000000):
# imgs = [to_pil(im.float()) for im in imgs]
# dst = Image.new('RGB', (sum(im.width for im in imgs), imgs[0].height))
# for i, im in enumerate(imgs):
# loc = [x0, y0, x1, y1] = [i * im.width, 0, (i + 1) * im.width, im.height]
# dst.paste(im, (x0, y0))
# ImageDraw.Draw(dst).rectangle(loc, width=2, outline=bordercolor)
# factorToBig = dst.width / maxwidth
# dst = dst.resize((int(dst.width / factorToBig), int(dst.height / factorToBig)))
# return dst
#
#
# def tensor_table(**kwargs):
# tensor_overview = {}
# for name, tensor in kwargs.items():
# if callable(tensor):
# print(name, [tensor(t) for _, t in kwargs.items() if isinstance(t, torch.Tensor)])
# else:
# tensor_overview[name] = {
# 'min': tensor.min().item(),
# 'max': tensor.max().item(),
# 'shape': tensor.shape,
# }
# return pd.DataFrame.from_dict(tensor_overview, orient='index')