from argparse import ArgumentParser, Namespace
from typing import Tuple
import codecs
import os

# allow duplicate OpenMP runtimes (set before torch/numpy initialise OpenMP)
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import yaml
import numpy as np
import cv2
from PIL import Image
import torch
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr
from utils import get_network, colourise_mask

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# pretrained NamedMask weights for PASCAL VOC 2012, released by the authors
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
    map_location=device
)["model"]

parser = ArgumentParser("NamedMask demo")
parser.add_argument("--config", type=str, default="voc_val_n500_cp2_ex.yaml")

# merge the CLI arguments with the base configuration from the YAML file
args: Namespace = parser.parse_args()
with open(args.config, "r") as f:
    base_args: dict = yaml.safe_load(f)
base_args.pop("dataset_name")
args = Namespace(**{**vars(args), **base_args})

model = get_network().to(device)
model.load_state_dict(state_dict)
model.eval()

# preprocessing constants (ImageNet mean/std)
size: int = 384
max_size: int = 512
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)


@torch.no_grad()
def main(image: Image.Image) -> np.ndarray:
    # resize so the shorter side is `size`, capped at `max_size` on the longer side
    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    image_t: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W

    # logits: b (=1) x n_categories x H x W, torch.float32
    logits: torch.Tensor = model(image_t[None].to(device))

    # pred: H x W, per-pixel category indices
    # note: upsampling the logits and cutting the padded region before the argmax
    # can give a better result at the original resolution (not done here)
    pred: np.ndarray = logits.squeeze(dim=0).argmax(dim=0).cpu().numpy()

    # colour-code the predicted mask and overlay it on the (resized) input image
    coloured_pred: np.ndarray = colourise_mask(mask=pred)
    super_imposed_img = cv2.addWeighted(coloured_pred, 0.5, np.array(pil_image), 0.5, 0)
    return super_imposed_img


demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="prediction"),
    examples=[f"images/{fname}.jpg" for fname in [
        "2007_002260", "2008_002536", "2008_003499", "2008_007814", "2010_001079", "2010_005063"
    ]],
    examples_per_page=10,
    description=codecs.open("description.html", "r", "utf-8").read(),
    title="NamedMask: Distilling Segmenters from Complementary Foundation Models",
    allow_flagging="never",
    analytics_enabled=False
)

demo.launch(
    # share=True
)
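
# Usage sketch (assumption: this script is saved as app.py alongside utils.py,
# description.html, the YAML config, and the images/ folder from the repo):
#
#   python app.py --config voc_val_n500_cp2_ex.yaml
#
# Gradio prints a local URL once the interface is up; passing share=True to
# demo.launch() above additionally creates a temporary public link.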