File size: 4,569 Bytes
26a054b
 
 
 
 
 
fc8170b
26a054b
 
 
 
 
 
 
 
 
 
ab7bcf7
26a054b
 
 
 
 
 
 
 
 
 
ab7bcf7
26a054b
 
 
 
 
 
 
 
 
 
 
 
 
63c8324
26a054b
 
 
15a8385
26a054b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292dd74
26a054b
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# imports
import io
import sys
import urllib
import urllib.request

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
#import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.ndimage import binary_closing, binary_opening

def load_image_from_url(url):
    """Load an RGB PIL image from an HTTP(S) URL or a local path.

    Args:
        url: An ``http``/``https`` URL, or anything ``PIL.Image.open``
            accepts directly (a local file path or file-like object).

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.
    """
    if url.startswith('http'):
        # `import urllib` alone does not expose the `request` submodule;
        # the explicit `import urllib.request` at the top of the file is
        # required for this call.
        with urllib.request.urlopen(url) as response:
            # Buffer the response: PIL needs a seekable stream, and the
            # HTTP response object is not seekable.
            data = io.BytesIO(response.read())
        return Image.open(data).convert("RGB")
    return Image.open(url).convert("RGB")

def make_transform(smaller_edge_size=448):
    """Build the torchvision preprocessing pipeline for model input.

    Resizes so the smaller edge matches `smaller_edge_size` (or both
    edges, if a (H, W) tuple is passed), converts to a tensor, and
    normalizes with the standard ImageNet statistics.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    steps = [
        transforms.Resize(
            size=smaller_edge_size,
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)

def prepare_image_dino(image, smaller_edge_size=448, square=True, patch_size=14):
  """Preprocess a PIL image for DINO and crop to a patch-size multiple.

  Args:
      image: input PIL image.
      smaller_edge_size: target size; used for both edges when `square`.
      square: resize to a fixed square instead of preserving aspect ratio.
      patch_size: ViT patch size; output dims are trimmed to multiples of it.

  Returns:
      (image_tensor, grid_size): the C x H x W tensor and the (rows, cols)
      patch grid it corresponds to.
  """
  edge = int(smaller_edge_size)
  transform = make_transform((edge, edge) if square else edge)
  image_tensor = transform(image)
  _, height, width = image_tensor.shape  # C x H x W
  rows, cols = height // patch_size, width // patch_size
  # Trim the bottom/right remainder so H and W divide evenly into patches.
  image_tensor = image_tensor[:, :rows * patch_size, :cols * patch_size]
  return image_tensor, (rows, cols)


def extract_dino_features(img_path_list, dino_model, smaller_edge_size=448):
  """Load, preprocess, and embed a list of images with a DINO model.

  Args:
      img_path_list: URLs or paths accepted by `load_image_from_url`.
      dino_model: model exposing `get_intermediate_layers`.
      smaller_edge_size: square resize size passed to preprocessing.

  Returns:
      (tokens, batch, grid_size): detached intermediate-layer tokens, the
      stacked preprocessed batch, and the patch grid of the last image
      (identical for all images since inputs are resized square).
  """
  batch = None
  grid_size = None
  for idx, path in enumerate(img_path_list):
    img = load_image_from_url(path)
    image_tensor, grid_size = prepare_image_dino(image=img,
                                                 smaller_edge_size=smaller_edge_size,
                                                 square=True)
    if batch is None:
      # Allocate the batch once the per-image tensor shape is known.
      batch = torch.zeros(len(img_path_list), *image_tensor.shape)
    batch[idx] = image_tensor
  with torch.inference_mode():
      tokens = dino_model.get_intermediate_layers(batch)[0].detach()
  return tokens, batch, grid_size

def get_projections_and_standardarray(stackedtokens):
  """Fit a 1-component PCA over all tokens and project onto its direction.

  Args:
      stackedtokens: token array/tensor of shape (..., N_features).

  Returns:
      (projections, standard_array): `standard_array` is the leading PCA
      component (length N_features); `projections` is `stackedtokens`
      projected onto it. NOTE: the projection here deliberately skips the
      mean-centering that PCA's own `transform` would apply, matching the
      original behavior.
  """
  n_features = stackedtokens.shape[-1]
  pca_t = PCA(n_components=1)
  # Only the fitted component is needed; the transformed output of the
  # previous `fit_transform` call was computed and never used.
  pca_t.fit(stackedtokens.reshape([-1, n_features]))
  standard_array = pca_t.components_.squeeze()
  projections = stackedtokens @ standard_array
  return projections, standard_array

def min_max(projected_features):
  """Min-max normalize each column of `projected_features` to [0, 1].

  Args:
      projected_features: array-like or tensor of shape (N, ...); the
          min/max statistics are taken over dim 0.

  Returns:
      torch.Tensor of the same shape, each column rescaled so its minimum
      maps to 0 and its maximum to 1 (NaN where a column is constant).
  """
  # as_tensor avoids the extra copy (and the UserWarning) that
  # torch.tensor() produces when the input is already a tensor.
  values = torch.as_tensor(projected_features)
  lo = values.min(dim=0, keepdim=True).values
  hi = values.max(dim=0, keepdim=True).values
  return (values - lo) / (hi - lo)

def get_masks(projections, grid_size, background_threshold = 0.0,
                         apply_opening = True, apply_closing = True):
    """Threshold per-image projections into flattened foreground masks.

    Args:
        projections: (num_images, rows*cols) projection values.
        grid_size: (rows, cols) patch grid shape.
        background_threshold: values above this count as foreground.
        apply_opening: remove isolated foreground specks via binary opening.
        apply_closing: fill small foreground holes via binary closing.

    Returns:
        (num_images, rows*cols) float array of 0.0/1.0 masks.
    """
    num_images = projections.shape[0]
    out = np.zeros([num_images, grid_size[0] * grid_size[1]])
    for idx in range(num_images):
      grid_mask = (projections[idx, :] > background_threshold).reshape(*grid_size)
      if apply_opening:
          grid_mask = binary_opening(grid_mask)
      if apply_closing:
          grid_mask = binary_closing(grid_mask)
      out[idx, :] = grid_mask.flatten()
    return out


def make_foreground_mask(tokens, standard_array, grid_size, background_threshold = 0.0,
                         apply_opening = True, apply_closing = True):
    """Build a single flattened foreground mask from patch tokens.

    Projects `tokens` onto `standard_array`, thresholds the projection,
    optionally cleans the resulting 2-D grid with binary opening and/or
    closing, and returns the mask flattened back to 1-D.
    """
    grid_mask = (tokens @ standard_array > background_threshold).reshape(*grid_size)
    if apply_opening:
        grid_mask = binary_opening(grid_mask)
    if apply_closing:
        grid_mask = binary_closing(grid_mask)
    return grid_mask.flatten()

def render_patch_pca3(tokens, standard_array,grid_size, smaller_edge_size = 448,background_threshold = 0.05,
                      apply_opening = False, apply_closing = False):
    """Visualize patch tokens as an RGB image via 3-component PCA.

    Fits the PCA on foreground tokens only, projects all tokens, rescales
    to [0, 255], zeroes out background patches, and reshapes to the patch
    grid as a PIL image.

    NOTE(review): `smaller_edge_size`, `background_threshold`,
    `apply_opening`, and `apply_closing` are currently ignored -- the
    foreground mask is always built with threshold 0 and both morphology
    steps enabled. Confirm whether these parameters should be threaded
    through to `make_foreground_mask`.
    """
    foreground = make_foreground_mask(tokens, standard_array, grid_size,
                                      background_threshold=0,
                                      apply_opening=True, apply_closing=True)
    rgb_pca = PCA(n_components=3)
    # Fit on foreground tokens only so background noise doesn't dominate
    # the color axes; still project every token afterwards.
    rgb_pca.fit(tokens[foreground])
    scaled = min_max(rgb_pca.transform(tokens))
    pixels = (scaled * 255).byte().numpy()
    pixels[~foreground] = 0  # black out background patches
    return Image.fromarray(pixels.reshape(*grid_size, 3))