| | import gradio as gr |
| | import torch |
| | import copy |
| | import time |
| | import requests |
| | import io |
| | import numpy as np |
| | import re |
| | from einops import rearrange |
| |
|
| | import ipdb |
| |
|
| | from PIL import Image |
| |
|
| | from vilt.config import ex |
| | from vilt.modules import ViLTransformerSS |
| |
|
| | from vilt.modules.objectives import cost_matrix_cosine, ipot |
| | from vilt.transforms import pixelbert_transform |
| | from vilt.datamodules.datamodule_base import get_pretrained_tokenizer |
| |
|
| |
|
@ex.automain
def main(_config):
    """Launch a Gradio demo that ranks a fixed image database against a caption.

    For the caption entered in the UI, every image in ``lst_imgs`` is scored
    with the ViLT rank head (``model.rank_output`` over ``cls_feats``); the
    best-scoring image is displayed.  If ``hidx`` selects a valid non-special
    text token, an IPOT transport-plan heatmap for that token is overlaid on
    the winning image as an alpha channel.

    Sacred injects ``_config``; the dict is deep-copied before mutation so the
    experiment's own config object is not modified.
    """
    _config = copy.deepcopy(_config)

    # Enable only the heads needed for image-text matching / retrieval.
    loss_names = {
        "itm": 1,
        "mlm": 0.5,
        "mpp": 0,
        "vqa": 0,
        "imgcls": 0,
        "nlvr2": 0,
        "irtr": 1,
        "arc": 0,
    }
    tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

    _config.update({"loss_names": loss_names})

    model = ViLTransformerSS(_config)
    model.setup("test")
    model.eval()

    device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
    model.to(device)

    # NOTE(review): hard-coded absolute Windows path — the demo only works on
    # this one machine; consider moving the database directory into _config.
    lst_imgs = [
        f"C:\\Users\\alimh\\PycharmProjects\\ViLT\\assets\\database\\{i}.jpg"
        for i in range(1, 10)
    ]

    def infer(mp_text, hidx=0):
        """Score every database image against *mp_text*.

        Returns ``[selected_image, hidx]`` where ``selected_image`` is either
        the raw best-matching image (as a numpy array) or, when ``hidx``
        addresses a valid token, an RGBA PIL image with the token's IPOT
        heatmap as alpha.

        NOTE(review): the second output is labelled "matching index" in the UI
        but echoes the slider value ``hidx``; it may have been intended to be
        the 1-based matched image index (``img_idx + 1``) — confirm with the
        author before changing.
        """

        def get_image(path):
            # pixelbert_transform resizes/normalizes to ViLT's 384-px input.
            image = Image.open(path).convert("RGB")
            return pixelbert_transform(size=384)(image).unsqueeze(0).to(device)

        # Tokenize the caption ONCE — the text is identical for every image,
        # so re-encoding it per image (as before) was redundant.
        encoded = tokenizer([mp_text])

        batch = []
        for pth in lst_imgs:
            batch.append(
                {
                    "text": [mp_text],
                    "image": [get_image(pth)],
                    "text_ids": torch.tensor(encoded["input_ids"]).to(device),
                    "text_labels": torch.tensor(encoded["input_ids"]).to(device),
                    "text_masks": torch.tensor(encoded["attention_mask"]).to(device),
                }
            )

        scores = []
        with torch.no_grad():
            for dic in batch:
                start = time.time()
                # Renamed from `infer` — the original rebinding shadowed this
                # very function's name inside its own body.
                out = model(dic)
                print("time ", round(time.time() - start, 2))

                score = model.rank_output(out["cls_feats"])
                scores.append(score.item())

        print(scores)
        img_idx = int(np.argmax(scores))
        print(img_idx + 1)

        selected_image = np.asarray(Image.open(lst_imgs[img_idx]).convert("RGB"))
        print(selected_image.shape)

        # Heatmap branch: hidx must address a real token, i.e. be inside the
        # encoded sequence (the [:-1] excludes the trailing special token).
        if 0 < hidx < len(encoded["input_ids"][0][:-1]):
            image = Image.open(lst_imgs[img_idx]).convert("RGB")
            with torch.no_grad():
                out = model(batch[img_idx])
                txt_emb, img_emb = out["text_feats"], out["image_feats"]
                txt_mask = out["text_masks"].bool()
                img_mask = out["image_masks"].bool()
                # Drop special positions: last text token per sequence, the
                # first text token, and the first image token.
                for i, _len in enumerate(txt_mask.sum(dim=1)):
                    txt_mask[i, _len - 1] = False
                txt_mask[:, 0] = False
                img_mask[:, 0] = False
                txt_pad, img_pad = ~txt_mask, ~img_mask

                cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
                joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
                cost.masked_fill_(joint_pad, 0)

                # Effective (unpadded) lengths, as floats for ipot.
                txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False)).to(
                    dtype=cost.dtype
                )
                img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False)).to(
                    dtype=cost.dtype
                )
                # IPOT hyper-parameters: beta=0.1, 1000 outer / 1 inner iters.
                T = ipot(
                    cost.detach(),
                    txt_len,
                    txt_pad,
                    img_len,
                    img_pad,
                    joint_pad,
                    0.1,
                    1000,
                    1,
                )

                # Transport plan for the selected token over image patches
                # (drop the image-CLS column with [1:]).
                plan_single = T[0] * len(txt_emb)
                cost_ = plan_single.t()[hidx][1:].cpu()

                patch_index, (H, W) = out["patch_index"]
                heatmap = torch.zeros(H, W)
                for i, pidx in enumerate(patch_index[0]):
                    heatmap[pidx[0].item(), pidx[1].item()] = cost_[i]

                # Standardize, keep the 1..3-sigma band, then min-max rescale
                # to [0, 1] for use as an 8-bit alpha mask.
                heatmap = (heatmap - heatmap.mean()) / heatmap.std()
                heatmap = np.clip(heatmap, 1.0, 3.0)
                heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

                _w, _h = image.size
                overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
                    (_w, _h), resample=Image.NEAREST
                )
                image_rgba = image.copy()
                image_rgba.putalpha(overlay)
                selected_image = image_rgba

        return [selected_image, hidx]

    # NOTE(review): gr.inputs / gr.outputs is the legacy Gradio (<3.0) API;
    # kept as-is to match the installed version this demo was written for.
    inputs = [
        gr.inputs.Textbox(label="Caption with [MASK] tokens to be filled.", lines=5),
        gr.inputs.Slider(
            minimum=0,
            maximum=38,
            step=1,
            label="Index of token for heatmap visualization (ignored if zero)",
        ),
    ]
    outputs = [
        gr.outputs.Image(label="Image"),
        gr.outputs.Textbox(label="matching index "),
    ]

    interface = gr.Interface(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
        server_name="localhost",
        server_port=8888,
    )

    interface.launch(debug=True, share=False)