# NOTE: the original capture of this file included web-viewer page chrome
# ("Spaces", runtime status, file size, commit hashes, and a line-number
# gutter) that was not part of the source; it has been reduced to this comment.
# imports
import io
import sys
import urllib
import urllib.request

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
#import matplotlib.pyplot as plt
from scipy.ndimage import binary_closing, binary_opening
from sklearn.decomposition import PCA
def load_image_from_url(url):
    """Load an RGB PIL image from an HTTP(S) URL or a local file path.

    Parameters
    ----------
    url : str
        An ``http``/``https`` URL, or anything ``PIL.Image.open`` accepts
        (e.g. a local file path).

    Returns
    -------
    PIL.Image.Image
        The decoded image converted to RGB.
    """
    if url.startswith('http'):
        # Buffer the whole response: PIL may seek() while decoding, which
        # HTTPResponse objects do not support. Also close the connection
        # deterministically via the context manager.
        with urllib.request.urlopen(url) as response:
            return Image.open(io.BytesIO(response.read())).convert("RGB")
    return Image.open(url).convert("RGB")
def make_transform(smaller_edge_size=448):
    """Build the preprocessing pipeline for DINO inputs.

    Resizes the image (smaller edge to ``smaller_edge_size``, or to an exact
    size when a tuple is given), converts it to a tensor, and normalizes with
    the standard ImageNet statistics.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    steps = [
        transforms.Resize(
            size=smaller_edge_size,
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)
#fixed_images
def prepare_image_dino(image, smaller_edge_size=448, square=True, patch_size=14):
    """Transform ``image`` and crop it to whole-patch dimensions.

    When ``square`` is true, the image is resized to an exact
    ``smaller_edge_size`` x ``smaller_edge_size`` square; otherwise only the
    smaller edge is fixed. The result is cropped so height and width are
    exact multiples of ``patch_size``.

    Returns the ``(C, H, W)`` tensor and the ``(rows, cols)`` patch grid size.
    """
    edge = int(smaller_edge_size)
    size = (edge, edge) if square else edge
    tensor = make_transform(size)(image)

    # Trim the bottom/right edges so both dimensions divide evenly into patches.
    _, height, width = tensor.shape
    cropped_h = height - height % patch_size
    cropped_w = width - width % patch_size
    tensor = tensor[:, :cropped_h, :cropped_w]

    return tensor, (cropped_h // patch_size, cropped_w // patch_size)
def extract_dino_features(img_path_list, dino_model, smaller_edge_size=448):
    """Run a DINO model over a batch of images and return patch tokens.

    Parameters
    ----------
    img_path_list : sequence of str
        URLs or paths accepted by ``load_image_from_url``.
    dino_model : torch.nn.Module
        Model exposing ``get_intermediate_layers``.
    smaller_edge_size : int
        Edge size passed to ``prepare_image_dino``. Images are squared, so
        every image yields the same tensor shape and the same grid size.

    Returns
    -------
    (tokens, batch, grid_size)
        ``tokens``: detached patch-token tensor from the first intermediate
        layer; ``batch``: the stacked preprocessed image tensors;
        ``grid_size``: the ``(rows, cols)`` patch grid.

    Raises
    ------
    ValueError
        If ``img_path_list`` is empty (the previous code hit a ``NameError``).
    """
    no_images = len(img_path_list)
    if no_images == 0:
        raise ValueError("img_path_list must contain at least one image")

    stack_image_batch = None
    grid_size = None
    for i, path in enumerate(img_path_list):
        img = load_image_from_url(path)
        image_tensor, grid_size = prepare_image_dino(
            image=img, smaller_edge_size=smaller_edge_size, square=True
        )
        if stack_image_batch is None:
            # Allocate once the (C, H, W) shape of a processed image is known.
            stack_image_batch = torch.zeros(no_images, *image_tensor.shape)
        stack_image_batch[i] = image_tensor

    with torch.inference_mode():
        stackedtokens = dino_model.get_intermediate_layers(stack_image_batch)[0]
    return stackedtokens.detach(), stack_image_batch, grid_size
def get_projections_and_standardarray(stackedtokens):
    """Fit a one-component PCA on all tokens and project onto its axis.

    Tokens are flattened to ``(n_tokens, n_features)`` for fitting; the
    returned projections keep the original leading dimensions of
    ``stackedtokens``, and the principal axis itself is returned as the
    "standard array" for reuse on other token sets.
    """
    feature_dim = stackedtokens.shape[-1]
    flat_tokens = stackedtokens.reshape([-1, feature_dim])
    pca = PCA(n_components=1)
    pca.fit(flat_tokens)
    axis = pca.components_.squeeze()
    return stackedtokens @ axis, axis
def min_max(projected_features):
    """Min-max normalize each column of ``projected_features`` into [0, 1].

    Parameters
    ----------
    projected_features : array-like
        Anything ``torch.as_tensor`` accepts (numpy array, nested list, or
        an existing tensor, which is reused without a copy — the previous
        ``torch.tensor(...)`` call copied and warned on tensor input).

    Returns
    -------
    torch.Tensor
        Same shape as the input, normalized per column (dim 0).

    Notes
    -----
    A column whose values are all equal produces NaN (0/0), as before.
    """
    b = torch.as_tensor(projected_features)
    b_min = b.min(dim=0, keepdim=True).values
    b_max = b.max(dim=0, keepdim=True).values
    return (b - b_min) / (b_max - b_min)
def get_masks(projections, grid_size, background_threshold = 0.0,
              apply_opening = True, apply_closing = True):
    """Threshold per-image projections into flattened foreground masks.

    Each row of ``projections`` is compared against ``background_threshold``,
    reshaped to ``grid_size``, optionally cleaned with morphological
    opening/closing, and flattened back into a row of the returned
    ``(n_images, rows*cols)`` float array.
    """
    n_images = projections.shape[0]
    n_patches = grid_size[0] * grid_size[1]
    masks = np.zeros([n_images, n_patches])

    for idx in range(n_images):
        grid_mask = (projections[idx, :] > background_threshold).reshape(*grid_size)
        if apply_opening:
            grid_mask = binary_opening(grid_mask)   # drop isolated foreground specks
        if apply_closing:
            grid_mask = binary_closing(grid_mask)   # fill small interior holes
        masks[idx, :] = grid_mask.flatten()

    return masks
def make_foreground_mask(tokens, standard_array, grid_size, background_threshold = 0.0,
                          apply_opening = True, apply_closing = True):
    """Build one flattened foreground mask from patch tokens.

    Projects ``tokens`` onto ``standard_array``, thresholds the projection,
    reshapes it to ``grid_size``, optionally applies morphological
    opening/closing, and returns the mask flattened to 1-D.
    """
    foreground = (tokens @ standard_array) > background_threshold
    foreground = foreground.reshape(*grid_size)
    if apply_opening:
        foreground = binary_opening(foreground)
    if apply_closing:
        foreground = binary_closing(foreground)
    return foreground.flatten()
def render_patch_pca3(tokens, standard_array, grid_size, smaller_edge_size = 448,
                      background_threshold = 0.0, apply_opening = True, apply_closing = True):
    """Render a 3-component PCA of foreground patch tokens as an RGB image.

    A foreground mask is computed with ``make_foreground_mask``; the PCA is
    fitted on foreground tokens only, all tokens are projected, min-max
    normalized to [0, 255], background patches are blacked out, and the
    result is reshaped to ``grid_size`` and returned as a ``PIL.Image``.

    Notes
    -----
    * BUG FIX: ``background_threshold``/``apply_opening``/``apply_closing``
      were previously ignored — the mask call hard-coded 0/True/True. They
      are now honored; the defaults were aligned with the old hard-coded
      values, so calls relying on defaults behave exactly as before.
    * ``smaller_edge_size`` is unused here; kept for interface compatibility.
    """
    mask = make_foreground_mask(tokens, standard_array, grid_size,
                                background_threshold=background_threshold,
                                apply_opening=apply_opening,
                                apply_closing=apply_closing)
    pca = PCA(n_components=3)
    pca.fit(tokens[mask])                    # fit on foreground tokens only
    projected_tokens = pca.transform(tokens)  # but project every token
    normalized = min_max(projected_tokens)
    array = (normalized * 255).byte().numpy()
    array[~mask] = 0                          # black out background patches
    array = array.reshape(*grid_size, 3)
    return Image.fromarray(array)