from models import IntuitionKillingMachine
from transforms import (undo_box_transforms_batch, ToTensor, Normalize,
                        SquarePad, Resize, NormalizeBoxCoords)
from torchvision.transforms import Compose
from encoders import get_tokenizer
from PIL import Image, ImageDraw
from zipfile import ZipFile
from copy import copy
import gradio as gr
import pandas as pd
import torch


def parse_model_args(model_path):
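    # The checkpoint's directory name encodes the training hyper-parameters,
    # separated by underscores.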
    (_, _, dataset, max_length, input_size, backbone, num_heads,
     num_layers, num_conv, _, _, mu, mask_pooling) = model_path.split('_')[:13]
    return {
        'dataset': dataset,
        'max_length': int(max_length),
        'input_size': int(input_size),
        'backbone': backbone,
        'num_heads': int(num_heads),
        'num_layers': int(num_layers),
        'num_conv': int(num_conv),
        'mu': float(mu),
        'mask_pooling': mask_pooling == '1',
    }


class Prober:
    def __init__(self, df_path=None, dataset_path=None, model_checkpoint=None):
        params = parse_model_args(model_checkpoint)
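        # ImageNet normalization statistics expected by the visual backbone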
        mean = [0.485, 0.456, 0.406]
        sdev = [0.229, 0.224, 0.225]
        self.tokenizer = get_tokenizer()
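        # Keep only the fields needed for probing; derive a numeric image id
        # from the file name and strip the common path prefix.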
        self.df = pd.read_json(df_path)[['sample_idx', 'bbox', 'file_path', 'sent']]
        self.df.loc[:, 'image_id'] = self.df.file_path.apply(
            lambda x: int(x.split('/')[-1][:-4]))
        self.df.file_path = self.df.file_path.apply(
            lambda x: x.replace('refer/data/images/', ''))
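        # Instantiate the model with the hyper-parameters recovered from the
        # checkpoint name; a segmentation head is added whenever mu > 0.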
        self.model = IntuitionKillingMachine(
            backbone=params['backbone'],
            pretrained=True,
            num_heads=params['num_heads'],
            num_layers=params['num_layers'],
            num_conv=params['num_conv'],
            segmentation_head=params['mu'] > 0.0,
            mask_pooling=params['mask_pooling']
        )
        self.load_model(model_checkpoint)
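        # Preprocessing pipeline: to tensor, normalize, pad to a square,
        # resize to the model's input size, normalize box coordinates.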
        self.transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(params['input_size'], params['input_size'])),
            NormalizeBoxCoords(),
        ])
        self.max_length = 30
        self.zipfile = ZipFile(dataset_path, 'r')

    def load_model(self, model_checkpoint):
        checkpoint = torch.load(
            model_checkpoint, map_location=lambda storage, loc: storage
        )
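        # Checkpoint parameter names carry a 'model.' prefix (as saved by the
        # training wrapper); strip it so they match the bare model.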
        state_dict = {
            k[len('model.'):]: v
            for k, v in checkpoint['state_dict'].items()
        }
        missing, _ = self.model.load_state_dict(state_dict, strict=False)
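        # Only segmentation-head weights are allowed to be missing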
        assert [k for k in missing if 'segm' not in k] == []
        self.model = self.model.eval()

    def preview_image(self, idx):
        img_path = self.df.loc[idx, 'file_path']
        img = Image.open(self.zipfile.open(img_path)).convert('RGB')
        return img

    @torch.no_grad()
    def probe(self, idx, expression, search_by_sample_id: bool = True):
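        """Ground `expression` in the image at `idx` and draw the predicted box."""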
        idx = int(idx)  # Gradio's 'number' input delivers a float
        if search_by_sample_id:
            img_path, target = self.df.loc[idx][['file_path', 'bbox']].values
        else:
            img_path, target = self.df[self.df.image_id == idx][['file_path', 'bbox']].values[0]
        img = Image.open(self.zipfile.open(img_path)).convert('RGB')
        if expression == '':
            return img
        W0, H0 = img.size
        sample = {
            'image': img,
            'image_size': (H0, W0),
            'bbox': torch.tensor([copy(target)]),
            'bbox_raw': torch.tensor([copy(target)]),
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),
            'mask_bbox': None,
        }
        sample = self.transform(sample)
        tok = self.tokenizer(expression,
                             max_length=self.max_length,
                             return_tensors='pt',
                             truncation=True)
        inn = {'image': torch.stack([sample['image']]),
               'mask': torch.stack([sample['mask']]),
               'tok': tok}
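        # Map the predicted box from network coordinates back to the
        # original image coordinates.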
        output = undo_box_transforms_batch(self.model(inn)[0],
                                           [sample['tr_param']]).numpy().tolist()[0]
        draw = ImageDraw.Draw(img)
        draw.rectangle(output, outline='#00FF00', width=3)
        return img


prober = Prober(
    df_path='data/val-sim_metric.json',
    dataset_path='data/saiapr_tc-12.zip',
    model_checkpoint='cache/20211220_191132_refclef_32_512_resnet50_8_6_8_0.1_0.1_0.1_0_0.0001_0.0_12_4_90_1_0_0_0/best.ckpt'
)
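
# Minimal Gradio UI: a sample index and a referring expression in, the
# image with the predicted bounding box out.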
demo = gr.Interface(fn=prober.probe, inputs=['number', 'text'], outputs='image')

demo.queue(concurrency_count=10)
demo.launch()