Spaces:
Runtime error
Runtime error
# imports
import io
import sys
import urllib
import urllib.request  # code calls urllib.request.urlopen; `import urllib` alone does not guarantee the submodule is loaded

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from scipy.ndimage import binary_closing, binary_opening
from sklearn.decomposition import PCA

#import matplotlib.pyplot as plt
def load_image_from_url(url):
    """Load an RGB PIL image from an HTTP(S) URL or a local file path.

    Parameters
    ----------
    url : str
        An 'http'/'https' URL, or anything else `PIL.Image.open` accepts
        (e.g. a local filesystem path).

    Returns
    -------
    PIL.Image.Image
        The image converted to RGB mode.
    """
    if url.startswith('http'):
        # Close the HTTP response promptly (the original leaked it) and buffer
        # the payload so PIL can seek freely in the data.
        with urllib.request.urlopen(url) as response:
            return Image.open(io.BytesIO(response.read())).convert("RGB")
    return Image.open(url).convert("RGB")
def make_transform(smaller_edge_size=448):
    """Build the preprocessing pipeline: bicubic resize, to-tensor, ImageNet normalization.

    `smaller_edge_size` may be an int (resize the smaller edge) or an (H, W) pair,
    exactly as accepted by `torchvision.transforms.Resize`.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    resize = transforms.Resize(
        size=smaller_edge_size,
        interpolation=transforms.InterpolationMode.BICUBIC,
        antialias=True,
    )
    normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    return transforms.Compose([resize, transforms.ToTensor(), normalize])
| #fixed_images | |
def prepare_image_dino(image, smaller_edge_size=448, square=True, patch_size=14):
    """Convert a PIL image into a normalized tensor whose H and W are multiples of patch_size.

    When `square` is True the image is resized to a fixed square; otherwise only
    the smaller edge is resized. The tensor is then cropped (bottom/right) so both
    spatial dims divide evenly by `patch_size`.

    Returns
    -------
    (image_tensor, grid_size) where grid_size = (H // patch_size, W // patch_size).
    """
    edge = int(smaller_edge_size)
    target_size = (edge, edge) if square else edge
    tensor = make_transform(target_size)(image)

    _, height, width = tensor.shape  # C x H x W
    cropped_h = height - height % patch_size
    cropped_w = width - width % patch_size
    tensor = tensor[:, :cropped_h, :cropped_w]

    grid_size = (cropped_h // patch_size, cropped_w // patch_size)  # (h, w) in patches
    return tensor, grid_size
def extract_dino_features(img_path_list, dino_model, smaller_edge_size=448):
    """Load, preprocess, and batch a list of images, then run the DINO backbone once.

    Returns
    -------
    tokens : intermediate-layer patch tokens for the whole batch (detached).
    batch : the stacked, preprocessed image batch tensor.
    grid_size : patch grid of the last image (identical for all, since the
        square transform forces a common size).
    """
    batch = None
    grid_size = None
    for idx, path in enumerate(img_path_list):
        image = load_image_from_url(path)
        tensor, grid_size = prepare_image_dino(
            image=image, smaller_edge_size=smaller_edge_size, square=True
        )
        if batch is None:
            # The first tensor fixes the batch shape; all images share it.
            batch = torch.zeros(len(img_path_list), *tensor.shape)
        batch[idx] = tensor

    with torch.inference_mode():
        tokens = dino_model.get_intermediate_layers(batch)[0]
    return tokens.detach(), batch, grid_size
def get_projections_and_standardarray(stackedtokens):
    """Fit a 1-component PCA on all tokens and project every token onto that axis.

    Parameters
    ----------
    stackedtokens : array-like, shape (..., N_features)
        Token features; all leading dimensions are flattened for the PCA fit.

    Returns
    -------
    projections : `stackedtokens @ standard_array`, leading shape preserved.
    standard_array : the first principal axis, shape (N_features,).

    Note: the sign of a PCA component is arbitrary, so projections may flip
    sign between fits on different data.
    """
    n_features = stackedtokens.shape[-1]
    pca = PCA(n_components=1)
    # fit() suffices: the original called fit_transform() and discarded the
    # transformed output, wasting a full projection pass.
    pca.fit(stackedtokens.reshape([-1, n_features]))
    standard_array = pca.components_.squeeze()
    projections = stackedtokens @ standard_array
    return projections, standard_array
def min_max(projected_features):
    """Min-max normalize features to [0, 1] independently per column (dim 0).

    Parameters
    ----------
    projected_features : numpy array or torch tensor, shape (N, K).

    Returns
    -------
    torch.Tensor of the same shape, each column rescaled to [0, 1].
    A constant column yields NaNs (max == min), unchanged from the original.
    """
    # as_tensor avoids the copy + UserWarning that torch.tensor() emits when
    # the input is already a tensor; the arithmetic below allocates new
    # tensors anyway, so the input is never mutated.
    b = torch.as_tensor(projected_features)
    b_min = b.min(dim=0, keepdim=True).values
    b_max = b.max(dim=0, keepdim=True).values
    return (b - b_min) / (b_max - b_min)
def get_masks(projections, grid_size, background_threshold=0.0,
              apply_opening=True, apply_closing=True):
    """Threshold per-image patch projections into cleaned foreground masks.

    Parameters
    ----------
    projections : shape (n_images, h*w) patch scores.
    grid_size : (h, w) patch grid used to reshape each row for morphology.
    background_threshold : scores strictly above this are foreground.
    apply_opening, apply_closing : optional binary morphology cleanup.

    Returns
    -------
    numpy array of shape (n_images, h*w), 1.0 for foreground patches.
    """
    rows, cols = grid_size
    masks = np.zeros([projections.shape[0], rows * cols])
    for idx, scores in enumerate(projections):
        grid_mask = (scores > background_threshold).reshape(rows, cols)
        if apply_opening:
            grid_mask = binary_opening(grid_mask)  # remove isolated foreground specks
        if apply_closing:
            grid_mask = binary_closing(grid_mask)  # fill small holes
        masks[idx, :] = grid_mask.flatten()
    return masks
def make_foreground_mask(tokens, standard_array, grid_size, background_threshold=0.0,
                         apply_opening=True, apply_closing=True):
    """Project tokens onto the standard axis and threshold into a flat foreground mask.

    Tokens scoring strictly above `background_threshold` are foreground; optional
    binary opening/closing (in that order) clean up the patch-grid mask.

    Returns the boolean mask flattened to shape (h*w,).
    """
    scores = tokens @ standard_array
    grid_mask = (scores > background_threshold).reshape(*grid_size)
    for enabled, morph_op in ((apply_opening, binary_opening),
                              (apply_closing, binary_closing)):
        if enabled:
            grid_mask = morph_op(grid_mask)
    return grid_mask.flatten()
def render_patch_pca3(tokens, standard_array, grid_size, smaller_edge_size=448,
                      background_threshold=0.0, apply_opening=True, apply_closing=True):
    """Render a 3-component PCA visualization of foreground patch tokens as an RGB image.

    The PCA is fit on foreground tokens only, then applied to all tokens;
    background patches are painted black.

    Bug fix: `background_threshold`, `apply_opening`, and `apply_closing` were
    previously ignored — the mask was always built with (0, True, True). They
    are now passed through, and the defaults are set to those values so calls
    that relied on the defaults behave exactly as before. `smaller_edge_size`
    is accepted for interface compatibility but unused.

    Returns a PIL.Image of shape (grid_h, grid_w) in RGB.
    """
    mask = make_foreground_mask(tokens, standard_array, grid_size,
                                background_threshold=background_threshold,
                                apply_opening=apply_opening,
                                apply_closing=apply_closing)
    pca = PCA(n_components=3)
    pca.fit(tokens[mask])  # fit on foreground tokens only
    projected_tokens = pca.transform(tokens)
    normalized = min_max(projected_tokens)
    rgb = (normalized * 255).byte().numpy()
    rgb[~mask] = 0  # background patches rendered black
    return Image.fromarray(rgb.reshape(*grid_size, 3))