Spaces:
Runtime error
Runtime error
| from numpy import argsort, array as np_array | |
| from PIL import Image as PImage, ImageDraw as PImageDraw | |
| from sklearn.metrics.pairwise import euclidean_distances, cosine_distances | |
| from torch import no_grad | |
| from transformers import __version__ as t_version | |
def draw_results(img, objs):
    """Return a thumbnail-sized RGB copy of ``img`` with each object's box outlined.

    Args:
        img: PIL image to annotate; the input itself is not modified.
        objs: Iterable of mappings, each with a ``"box"`` mapping whose values
            are passed to ``ImageDraw.rectangle`` in iteration order
            (presumably xmin, ymin, xmax, ymax -- confirm against the caller).

    Returns:
        The annotated copy, shrunk in place to fit within 512x512.
    """
    annotated = img.convert("RGB").copy()
    draw = PImageDraw.Draw(annotated)
    # Pillow skips the outline entirely when width == 0, which is what
    # `min(size) // 128` yields for images under 128 px; clamp to 1 px.
    line_width = max(1, min(annotated.size) // 128)
    for o in objs:
        draw.rectangle(
            tuple(o["box"].values()),
            outline=(10, 220, 10),
            width=line_width,
        )
    annotated.thumbnail((512, 512))
    return annotated
def embed_word(word, processor, model, device):
    """Embed a single word with the model's text tower.

    The word is tokenized by ``processor`` (padded to ``max_length=64``),
    moved to ``device`` and passed to ``model.get_text_features`` with
    gradients disabled. On transformers 5.x the result is first indexed by
    ``"pooler_output"``; in both cases element 0 is returned as a NumPy
    array on the CPU.
    """
    inputs = processor(
        text=[word],
        padding="max_length",
        max_length=64,
        return_tensors="pt",
    ).to(device)
    with no_grad():
        features = model.get_text_features(**inputs)
        if t_version.startswith("5."):
            # transformers 5.x: the pooled features sit under "pooler_output"
            features = features["pooler_output"]
        vector = features[0]
    return vector.cpu().numpy()
def embed_image(image, processor, model, device):
    """Embed a single image with the model's vision tower.

    The image is preprocessed by ``processor``, moved to ``device`` and
    passed to ``model.get_image_features`` with gradients disabled. On
    transformers 5.x the result is first indexed by ``"pooler_output"``; in
    both cases element 0 is returned as a NumPy array on the CPU.
    """
    inputs = processor(images=[image], return_tensors="pt", padding=True).to(device)
    with no_grad():
        features = model.get_image_features(**inputs)
        if t_version.startswith("5."):
            # transformers 5.x: the pooled features sit under "pooler_output"
            features = features["pooler_output"]
        vector = features[0]
    return vector.cpu().numpy()
def idxs_by_dist(img_embeddings, txt_embedding, cos=True):
    """Return image indices sorted by ascending distance to ``txt_embedding``.

    Args:
        img_embeddings: Collection of image embedding vectors.
        txt_embedding: Single text embedding vector to compare against.
        cos: When truthy use cosine distance, otherwise Euclidean distance.
    """
    metric = cosine_distances if cos else euclidean_distances
    # The query is wrapped in a list so the metric returns a single row.
    distances = metric([txt_embedding], img_embeddings)
    return argsort(distances[0])
def make_image(imgs, order):
    """Concatenate images left to right, in the given order, into one strip.

    The canvas is as wide as all of ``imgs`` combined and as tall as the
    shortest one; taller tiles are clipped at the bottom edge.

    Args:
        imgs: Sequence of PIL images.
        order: Iterable of indices into ``imgs`` giving the left-to-right
            placement order.

    Returns:
        A new RGB image containing the pasted tiles.
    """
    total_width = sum(i.width for i in imgs)
    strip_height = min(i.height for i in imgs)
    strip = PImage.new("RGB", (total_width, strip_height))
    x_offset = 0
    for idx in order:
        tile = imgs[idx]
        # Paste by upper-left corner only: a 4-tuple box must match the tile
        # size exactly, so tiles taller than the strip would raise
        # "images do not match". With a 2-tuple Pillow clips to the canvas.
        strip.paste(tile, (x_offset, 0))
        x_offset += tile.width
    return strip
def idxs_along_axes(img_embeddings, txt_embeddings):
    """Order images by the ratio of their distances to two text embeddings.

    For every image, the Euclidean distance to ``txt_embeddings[0]`` is
    divided by the distance to ``txt_embeddings[1]``; the indices that sort
    this ratio in ascending order are returned.
    """
    dists = euclidean_distances(txt_embeddings, img_embeddings)
    to_first = dists[0]
    to_second = dists[1]
    return argsort(to_first / to_second)