# imports
import sys
import io
import urllib.request  # fix: was `import urllib`; urllib.request is a submodule and must be imported explicitly
import numpy as np
from PIL import Image
#import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.ndimage import binary_closing, binary_opening
import torch
import torchvision.transforms as transforms


def load_image_from_url(url):
    """Open *url* as an RGB PIL image; accepts an http(s) URL or a local file path."""
    if url.startswith('http'):
        url = urllib.request.urlopen(url)
    return Image.open(url).convert("RGB")


def make_transform(smaller_edge_size=448):
    """Return the DINO preprocessing pipeline: resize -> tensor -> ImageNet-normalize.

    Parameters
    ----------
    smaller_edge_size : int or (int, int)
        Forwarded to ``transforms.Resize``: an int resizes the smaller edge,
        a pair resizes to that exact (height, width).
    """
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
    interpolation_mode = transforms.InterpolationMode.BICUBIC
    return transforms.Compose([
        transforms.Resize(size=smaller_edge_size, interpolation=interpolation_mode, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])


#fixed_images
def prepare_image_dino(image, smaller_edge_size=448, square=True, patch_size=14):
    """Preprocess *image* and crop it so height and width are multiples of *patch_size*.

    Returns
    -------
    (image_tensor, grid_size)
        image_tensor : torch.Tensor, C x H x W after cropping;
        grid_size : (H // patch_size, W // patch_size).
    """
    if square:
        # An exact square resize lets differently-sized inputs be stacked into one batch.
        transform = make_transform((int(smaller_edge_size), int(smaller_edge_size)))
    else:
        transform = make_transform(int(smaller_edge_size))
    image_tensor = transform(image)

    # Crop image to dimensions that are a multiple of the patch size
    height, width = image_tensor.shape[1:]  # C x H x W
    cropped_width, cropped_height = width - width % patch_size, height - height % patch_size
    image_tensor = image_tensor[:, :cropped_height, :cropped_width]

    grid_size = (cropped_height // patch_size, cropped_width // patch_size)  # h x w
    return image_tensor, grid_size


def extract_dino_features(img_path_list, dino_model, smaller_edge_size=448):
    """Load, preprocess and batch the images, then extract DINO patch tokens.

    Parameters
    ----------
    img_path_list : sequence of str — file paths or http(s) URLs.
    dino_model : model exposing ``get_intermediate_layers`` (DINOv2-style API).
    smaller_edge_size : int — square side length used for preprocessing.

    Returns
    -------
    (tokens, batch, grid_size)
        tokens : detached tensor of intermediate-layer patch tokens for the batch;
        batch : stacked preprocessed images, N x C x H x W;
        grid_size : token grid of the last image (identical for all, since square=True).

    Raises
    ------
    ValueError
        If *img_path_list* is empty.
    """
    no_images = len(img_path_list)
    if no_images == 0:
        # fix: the original hit a NameError on stack_image_batch for an empty list
        raise ValueError("img_path_list must contain at least one image path/URL")
    for i in range(no_images):
        img = load_image_from_url(img_path_list[i])
        image_tensor, grid_size = prepare_image_dino(image=img,
                                                     smaller_edge_size=smaller_edge_size,
                                                     square=True)
        if i == 0:
            # Allocate the batch lazily, once the preprocessed shape is known.
            stack_image_batch = torch.zeros(no_images, *image_tensor.shape)
        stack_image_batch[i] = image_tensor
    with torch.inference_mode():
        stackedtokens = dino_model.get_intermediate_layers(stack_image_batch)[0]
    return stackedtokens.detach(), stack_image_batch, grid_size


def get_projections_and_standardarray(stackedtokens):
    """Fit a 1-component PCA on all patch tokens pooled together and project onto it.

    Returns
    -------
    (projections, standard_array)
        standard_array : first principal axis, length = feature dimension;
        projections : per-image, per-patch projection of the tokens on that axis.
    """
    N_features = stackedtokens.shape[-1]
    pca_t = PCA(n_components=1)
    # Fit on all tokens of all images pooled into one (N*patches, features) matrix.
    # (The original called fit_transform and discarded the transform output.)
    pca_t.fit(stackedtokens.reshape([-1, N_features]))
    standard_array = pca_t.components_.squeeze()
    # NOTE(review): stackedtokens looks like a torch tensor while standard_array
    # is a numpy array — confirm `@` mixes the two as intended at the call sites.
    projections = stackedtokens @ standard_array
    return projections, standard_array


def min_max(projected_features):
    """Min-max normalize each column of *projected_features* to [0, 1].

    Accepts anything ``torch.as_tensor`` handles (numpy array or tensor).
    NOTE(review): a column with no spread divides by zero — confirm the
    inputs always vary along dim 0.
    """
    # as_tensor avoids the copy/warning torch.tensor emits on an existing tensor;
    # all following ops are out-of-place, so sharing memory is safe.
    b = torch.as_tensor(projected_features)
    b_min = b.min(dim=0, keepdim=True).values
    b_max = b.max(dim=0, keepdim=True).values
    return (b - b_min) / (b_max - b_min)


def get_masks(projections, grid_size, background_threshold=0.0, apply_opening=True, apply_closing=True):
    """Threshold per-image PCA projections into flattened foreground masks.

    Parameters
    ----------
    projections : array-like, shape (no_images, h*w).
    grid_size : (h, w) token grid.
    background_threshold : float — patches with projection > threshold are foreground.
    apply_opening, apply_closing : bool — optional binary morphology to denoise the mask.

    Returns
    -------
    numpy array of shape (no_images, h*w); float 0.0/1.0 values
    (dtype kept as in the original for caller compatibility).
    """
    no_images = projections.shape[0]
    masks = np.zeros([no_images, grid_size[0] * grid_size[1]])
    for i in range(no_images):
        mask_i = projections[i, :] > background_threshold
        mask_i = mask_i.reshape(*grid_size)
        if apply_opening:
            mask_i = binary_opening(mask_i)   # drop isolated foreground specks
        if apply_closing:
            mask_i = binary_closing(mask_i)   # fill small holes
        masks[i, :] = mask_i.flatten()
    return masks


def make_foreground_mask(tokens, standard_array, grid_size, background_threshold=0.0,
                         apply_opening=True, apply_closing=True):
    """Return a single flattened boolean foreground mask for one image's tokens.

    Same thresholding/morphology as ``get_masks`` but for one token matrix,
    projecting with *standard_array* first.
    """
    projection = tokens @ standard_array
    mask = projection > background_threshold
    mask = mask.reshape(*grid_size)
    if apply_opening:
        mask = binary_opening(mask)
    if apply_closing:
        mask = binary_closing(mask)
    return mask.flatten()


def render_patch_pca3(tokens, standard_array, grid_size, smaller_edge_size=448,
                      background_threshold=0.0, apply_opening=True, apply_closing=True):
    """Render one image's patch tokens as an RGB image via a 3-component PCA.

    The color PCA is fitted on foreground patches only; background patches are black.

    fix: the original accepted background_threshold / apply_opening / apply_closing
    but ignored them, always calling make_foreground_mask(..., 0, True, True).
    They are now honored; the declared defaults were set to those previously
    hard-coded values so default calls behave exactly as before.
    ``smaller_edge_size`` is unused and kept only for interface compatibility.
    """
    mask = make_foreground_mask(tokens, standard_array, grid_size,
                                background_threshold=background_threshold,
                                apply_opening=apply_opening,
                                apply_closing=apply_closing)
    pca = PCA(n_components=3)
    pca.fit(tokens[mask])                  # fit colors on foreground tokens only
    projected_tokens = pca.transform(tokens)
    normalized_t = min_max(projected_tokens)
    array = (normalized_t * 255).byte().numpy()
    array[~mask] = 0                       # paint background patches black
    array = array.reshape(*grid_size, 3)
    return Image.fromarray(array)