File size: 1,981 Bytes
8358707
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from numpy import argsort, array as np_array
from PIL import Image as PImage, ImageDraw as PImageDraw
from sklearn.metrics.pairwise import euclidean_distances, cosine_distances
from torch import no_grad
from transformers import __version__ as t_version

def draw_results(img, objs):
  """Return a thumbnail of `img` with every detection box outlined in green.

  Args:
    img: PIL image. Drawing happens on an RGB copy, so the argument itself
      is not drawn on.
    objs: iterable of mappings, each holding a "box" mapping whose values,
      in iteration order, are the rectangle coordinates handed to
      `ImageDraw.rectangle` (presumably xmin, ymin, xmax, ymax -- confirm
      against whatever produces these detections).

  Returns:
    The annotated RGB copy after `thumbnail((512, 512))` has been applied.
  """
  canvas = img.convert("RGB").copy()
  pen = PImageDraw.Draw(canvas)
  # Scale the outline with the image, but never below 1px: Pillow skips the
  # outline entirely when width is 0, which silently hid every box on images
  # whose shorter side is under 128px.
  line_width = max(1, min(canvas.size) // 128)
  for o in objs:
    pen.rectangle(tuple(o["box"].values()),
                  outline=(10, 220, 10),
                  width=line_width)
  canvas.thumbnail((512, 512))
  return canvas

def embed_word(word, processor, model, device):
  """Embed a single text string with the model's text tower.

  Args:
    word: the text to embed; it is passed to `processor` as a one-item batch,
      padded to a fixed length of 64.
    processor: callable that builds "pt" tensors and whose result supports
      `.to(device)`.
    model: object exposing `get_text_features(**inputs)`.
    device: target passed to `.to(...)` for the processed inputs.

  Returns:
    The first entry of the text features, moved to CPU and converted to a
    NumPy array.
  """
  inputs = processor(text=[word], padding="max_length", max_length=64, return_tensors="pt").to(device)
  with no_grad():
    features = model.get_text_features(**inputs)
    # On transformers 5.x the result is indexed by "pooler_output" before
    # taking the first entry; on other versions it is indexed directly.
    if t_version.startswith("5."):
      features = features["pooler_output"]
    first = features[0]
  return first.cpu().numpy()

def embed_image(image, processor, model, device):
  """Embed a single image with the model's image tower.

  Args:
    image: the image to embed; it is passed to `processor` as a one-item
      batch.
    processor: callable that builds "pt" tensors and whose result supports
      `.to(device)`.
    model: object exposing `get_image_features(**inputs)`.
    device: target passed to `.to(...)` for the processed inputs.

  Returns:
    The first entry of the image features, moved to CPU and converted to a
    NumPy array.
  """
  inputs = processor(images=[image], return_tensors="pt", padding=True).to(device)
  with no_grad():
    features = model.get_image_features(**inputs)
    # On transformers 5.x the result is indexed by "pooler_output" before
    # taking the first entry; on other versions it is indexed directly.
    if t_version.startswith("5."):
      features = features["pooler_output"]
    first = features[0]
  return first.cpu().numpy()

def idxs_by_dist(img_embeddings, txt_embedding, cos=True):
  """Rank image embeddings by their distance to one text embedding.

  Args:
    img_embeddings: collection of image embedding vectors.
    txt_embedding: a single text embedding vector.
    cos: when truthy use cosine distance, otherwise Euclidean distance.

  Returns:
    Indices into `img_embeddings`, sorted from smallest to largest distance.
  """
  metric = cosine_distances if cos else euclidean_distances
  # The text embedding is wrapped in a one-row batch, so row 0 of the
  # pairwise matrix holds its distance to every image embedding.
  to_each_image = metric([txt_embedding], img_embeddings)[0]
  return argsort(to_each_image)

def make_image(imgs, order):
  """Concatenate images left to right, in the given order, into one strip.

  The strip is as wide as all of `imgs` combined and as tall as the shortest
  of them; taller images are cropped at the bottom to that height.

  Args:
    imgs: sequence of PIL images (must be non-empty).
    order: iterable of indices into `imgs`, giving the left-to-right order.

  Returns:
    A new RGB image holding the pasted images.
  """
  strip_w = sum(i.width for i in imgs)
  strip_h = min(i.height for i in imgs)
  strip = PImage.new("RGB", (strip_w, strip_h))
  x = 0
  for idx in order:
    tile = imgs[idx]
    # paste() with a 4-tuple box requires the source to match the box size
    # exactly, so any image taller than the strip used to raise ValueError.
    # Crop to the strip height first, then paste by top-left corner.
    if tile.height != strip_h:
      tile = tile.crop((0, 0, tile.width, strip_h))
    strip.paste(tile, (x, 0))
    x += tile.width
  return strip

def idxs_along_axes(img_embeddings, txt_embeddings):
  """Order images along the axis spanned by two text embeddings.

  Each image is scored by its Euclidean distance to the first text embedding
  divided by its distance to the second.

  Args:
    img_embeddings: collection of image embedding vectors.
    txt_embeddings: collection of text embedding vectors; only the first two
      are used.

  Returns:
    Indices into `img_embeddings`, sorted by ascending distance ratio
    (relatively closest to the first text embedding come first).
  """
  dists = euclidean_distances(txt_embeddings, img_embeddings)
  to_first = dists[0]
  to_second = dists[1]
  return argsort(to_first / to_second)