# Source: "Update codes.py" by Amould (commit 292dd74) — GitHub page header
# preserved as a comment so the file remains valid Python.
# imports
import io
import sys
import urllib
import urllib.request  # explicit submodule import; `import urllib` alone does not provide it

import numpy as np
from PIL import Image
#import matplotlib.pyplot as plt
from scipy.ndimage import binary_closing, binary_opening
from sklearn.decomposition import PCA
import torch
import torchvision.transforms as transforms
def load_image_from_url(url):
    """Load an RGB PIL image from an HTTP(S) URL or any source Image.open accepts.

    Parameters
    ----------
    url : str
        An ``http``/``https`` URL, or a local file path (anything not
        starting with ``http`` is passed straight to ``Image.open``).

    Returns
    -------
    PIL.Image.Image
        The image converted to RGB mode.
    """
    if url.startswith('http'):
        # Read the whole response and close the connection. The original
        # handed the live response object to Image.open, which decodes
        # lazily and leaves the socket open (and shadowed `url` with it).
        with urllib.request.urlopen(url) as response:
            buffer = io.BytesIO(response.read())
        return Image.open(buffer).convert("RGB")
    return Image.open(url).convert("RGB")
def make_transform(smaller_edge_size=448):
    """Build the standard preprocessing pipeline for the DINO backbone.

    Resizes so the smaller edge equals `smaller_edge_size` (or to an exact
    (H, W) when a tuple is given), converts to a float tensor, and
    normalizes with the ImageNet mean/std.
    """
    mean = (0.485, 0.456, 0.406)  # IMAGENET_DEFAULT_MEAN
    std = (0.229, 0.224, 0.225)   # IMAGENET_DEFAULT_STD
    steps = [
        transforms.Resize(
            size=smaller_edge_size,
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)
#fixed_images
def prepare_image_dino(image, smaller_edge_size=448, square=True, patch_size=14):
    """Transform a PIL image into a tensor whose H and W divide evenly by patch_size.

    Parameters
    ----------
    image : PIL.Image.Image
    smaller_edge_size : int — target edge length for the resize.
    square : bool — if True, resize to an exact square; otherwise only the
        smaller edge is fixed and aspect ratio is preserved.
    patch_size : int — ViT patch side length.

    Returns
    -------
    (image_tensor, grid_size) where grid_size = (H // patch_size, W // patch_size).
    """
    edge = int(smaller_edge_size)
    target = (edge, edge) if square else edge
    tensor = make_transform(target)(image)

    # Trim the bottom/right edges so both spatial dims are patch multiples.
    _, height, width = tensor.shape  # C x H x W
    height -= height % patch_size
    width -= width % patch_size
    tensor = tensor[:, :height, :width]

    return tensor, (height // patch_size, width // patch_size)
def extract_dino_features(img_path_list, dino_model, smaller_edge_size=448):
    """Load a list of images, batch them, and extract DINO patch tokens.

    Parameters
    ----------
    img_path_list : sequence of str
        URLs or paths accepted by `load_image_from_url`.
    dino_model : model exposing `get_intermediate_layers` (e.g. DINOv2).
    smaller_edge_size : int
        Side length of the square each image is resized to.

    Returns
    -------
    tokens : torch.Tensor — (batch, n_patches, feature_dim) detached tokens.
    batch : torch.Tensor — (batch, C, H, W) stacked preprocessed images.
    grid_size : (h, w) patch grid; identical for every image because all
        are resized to the same square, so the last one is representative.
    """
    tensors = []
    grid_size = None
    for path in img_path_list:
        image = load_image_from_url(path)
        image_tensor, grid_size = prepare_image_dino(
            image=image,
            smaller_edge_size=smaller_edge_size,
            square=True,
        )
        tensors.append(image_tensor)
    # torch.stack replaces the original lazily pre-allocated zeros buffer
    # that was filled index by index.
    batch = torch.stack(tensors)

    # Single forward pass after batching; inference_mode skips autograd
    # bookkeeping entirely.
    with torch.inference_mode():
        tokens = dino_model.get_intermediate_layers(batch)[0]
    return tokens.detach(), batch, grid_size
def get_projections_and_standardarray(stackedtokens):
    """Fit a 1-component PCA over all patch tokens and project tokens onto it.

    Parameters
    ----------
    stackedtokens : array/tensor of shape (..., n_features)
        Patch tokens; all leading dims are flattened for the PCA fit.

    Returns
    -------
    projections : per-token dot product with the first principal component,
        same leading shape as `stackedtokens`.
    standard_array : (n_features,) first principal component.
    """
    n_features = stackedtokens.shape[-1]
    pca = PCA(n_components=1)
    # fit() suffices: the original called fit_transform and discarded the
    # transformed output, paying for an unused projection.
    pca.fit(stackedtokens.reshape([-1, n_features]))
    standard_array = pca.components_.squeeze()
    # NOTE(review): this is an *uncentered* projection (no mean subtraction),
    # so it differs from pca.transform by a constant offset — presumably
    # intentional so thresholding at 0 works downstream; confirm.
    projections = stackedtokens @ standard_array
    return projections, standard_array
def min_max(projected_features):
    """Min-max normalize each column of a 2-D array to the range [0, 1].

    Parameters
    ----------
    projected_features : array-like or torch.Tensor, shape (n_rows, n_cols)

    Returns
    -------
    torch.Tensor of the same shape: each column rescaled so its minimum
    maps to 0 and its maximum to 1. A constant column maps to all zeros
    (the original divided 0/0 there, yielding NaN that corrupted the
    downstream ``.byte()`` image conversion).
    """
    # as_tensor avoids the copy (and UserWarning) torch.tensor emits when
    # the input is already a tensor.
    values = torch.as_tensor(projected_features)
    col_min = values.min(dim=0, keepdim=True).values
    col_max = values.max(dim=0, keepdim=True).values
    span = col_max - col_min
    # Guard degenerate (constant) columns against division by zero.
    span = torch.where(span == 0, torch.ones_like(span), span)
    return (values - col_min) / span
def get_masks(projections, grid_size, background_threshold = 0.0,
              apply_opening = True, apply_closing = True):
    """Threshold per-image projection scores into binary foreground masks.

    Parameters
    ----------
    projections : (n_images, h*w) per-patch scores, one row per image.
    grid_size : (h, w) patch grid used to reshape before morphology.
    background_threshold : float — scores above this are foreground.
    apply_opening, apply_closing : bool — optional morphological cleanup.

    Returns
    -------
    (n_images, h*w) float array of flattened 0/1 masks.
    """
    n_images = projections.shape[0]
    out = np.zeros([n_images, grid_size[0] * grid_size[1]])
    for idx, row in enumerate(projections):
        mask = (row > background_threshold).reshape(*grid_size)
        if apply_opening:
            mask = binary_opening(mask)  # drop isolated foreground specks
        if apply_closing:
            mask = binary_closing(mask)  # fill small holes in the foreground
        out[idx, :] = mask.flatten()
    return out
def make_foreground_mask(tokens, standard_array, grid_size, background_threshold = 0.0,
                         apply_opening = True, apply_closing = True):
    """Project one image's tokens onto the standard axis and threshold into a mask.

    Parameters
    ----------
    tokens : (h*w, n_features) patch tokens of a single image.
    standard_array : (n_features,) projection direction.
    grid_size : (h, w) patch grid used to reshape before morphology.
    background_threshold : float — scores above this count as foreground.
    apply_opening, apply_closing : bool — optional morphological cleanup.

    Returns
    -------
    Flattened boolean array of length h*w.
    """
    scores = tokens @ standard_array
    foreground = (scores > background_threshold).reshape(*grid_size)
    if apply_opening:
        foreground = binary_opening(foreground)  # remove speckle foreground
    if apply_closing:
        foreground = binary_closing(foreground)  # close pinholes
    return foreground.flatten()
def render_patch_pca3(tokens, standard_array, grid_size, smaller_edge_size = 448, background_threshold = 0.05,
                      apply_opening = False, apply_closing = False):
    """Render a 3-component PCA of one image's foreground tokens as an RGB image.

    Bug fix: `background_threshold`, `apply_opening`, and `apply_closing`
    were accepted but ignored — the call to make_foreground_mask hard-coded
    0 / True / True. They are now forwarded, so callers' arguments take effect.

    Parameters
    ----------
    tokens : (h*w, n_features) patch tokens of a single image.
    standard_array : (n_features,) foreground projection direction.
    grid_size : (h, w) patch grid.
    smaller_edge_size : unused; kept only for interface compatibility.
    background_threshold, apply_opening, apply_closing : forwarded to
        make_foreground_mask.

    Returns
    -------
    PIL.Image.Image of shape (h, w) with RGB values from the PCA projection;
    background patches are black.
    """
    mask = make_foreground_mask(
        tokens, standard_array, grid_size,
        background_threshold=background_threshold,
        apply_opening=apply_opening,
        apply_closing=apply_closing,
    )
    # Fit the 3-component PCA on foreground tokens only, but transform all
    # tokens so the rendered image covers the full grid.
    pca = PCA(n_components=3)
    pca.fit(tokens[mask])
    projected = pca.transform(tokens)
    rgb = (min_max(projected) * 255).byte().numpy()
    rgb[~mask] = 0  # black out background patches
    return Image.fromarray(rgb.reshape(*grid_size, 3))