# One-time environment bootstrap for the hosted demo: install the two
# git-only dependencies and fetch the LSeg demo checkpoint.
import os
import subprocess
import sys

_CHECKPOINT_URL = "https://huggingface.co/isl-org/lang-seg/resolve/main/demo_e200.ckpt"
# Relative to the working directory, the same location load_model() reads from.
_CHECKPOINT_PATH = os.path.join("checkpoints", "demo_e200.ckpt")

for _requirement in (
    "git+https://github.com/zhanghang1989/PyTorch-Encoding/",
    "git+https://github.com/openai/CLIP.git",
):
    # Best-effort (exit status ignored): a failed install surfaces later as an
    # ImportError for `encoding` / `clip`. `sys.executable -m pip` guarantees the
    # packages land in the interpreter that is running this script.
    subprocess.run([sys.executable, "-m", "pip", "install", _requirement])

# exist_ok so that restarting the app in the same container does not fail.
os.makedirs(os.path.dirname(_CHECKPOINT_PATH), exist_ok=True)
if not os.path.isfile(_CHECKPOINT_PATH):
    # Download under a temporary name and rename only on success, so an
    # interrupted download is never mistaken for a complete checkpoint later.
    _partial = _CHECKPOINT_PATH + ".part"
    if subprocess.run(["wget", _CHECKPOINT_URL, "-O", _partial]).returncode == 0:
        os.replace(_partial, _CHECKPOINT_PATH)
| | import torch |
| | import argparse |
| | import numpy as np |
| | from tqdm import tqdm |
| | from collections import OrderedDict |
| | import torch.nn.functional as F |
| | from torch.utils import data |
| | import torchvision.transforms as transform |
| | from torch.nn.parallel.scatter_gather import gather |
| | from additional_utils.models import LSeg_MultiEvalModule |
| | from modules.lseg_module import LSegModule |
| | import cv2 |
| | import math |
| | import types |
| | import functools |
| | import torchvision.transforms as torch_transforms |
| | import copy |
| | import itertools |
| | from PIL import Image |
| | import matplotlib.pyplot as plt |
| | import clip |
| | from encoding.models.sseg import BaseNet |
| | import matplotlib as mpl |
| | import matplotlib.colors as mplc |
| | import matplotlib.figure as mplfigure |
| | import matplotlib.patches as mpatches |
| | from matplotlib.backends.backend_agg import FigureCanvasAgg |
| | from data import get_dataset |
| | import torchvision.transforms as transforms |
| |
|
| | import gradio as gr |
| |
|
model_name = "lseg_demo"

# Prefer the GPU whenever PyTorch can see one.
# NOTE(review): `device` does not appear to be read elsewhere in this file;
# confirm before relying on it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_new_pallete(num_cls):
    """Return a flat RGB palette ``[r0, g0, b0, r1, g1, b1, ...]`` for ``num_cls`` classes.

    The colour of class ``j`` is built by distributing the bits of ``j`` over the
    three channels: bits 0, 3, 6, ... go to red, bits 1, 4, 7, ... to green and
    bits 2, 5, 8, ... to blue, filling each channel from its most significant bit
    (bit 7) downward. Class 0 is therefore black.
    """
    palette = []
    for cls_id in range(num_cls):
        rgb = [0, 0, 0]
        remaining = cls_id
        shift = 7  # next bit position to fill in every channel
        while remaining > 0:
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            shift -= 1
            remaining >>= 3
        palette.extend(rgb)
    return palette
| |
|
def get_new_mask_pallete(npimg, new_palette, out_label_flag=False, labels=None):
    """Colourise a class-index map with ``new_palette``, optionally building legend patches.

    Args:
        npimg: array of integer class indices. Singleton dimensions are squeezed
            away and values are cast to uint8 before building the image, so
            indices above 255 would wrap.
        new_palette: flat ``[r0, g0, b0, r1, ...]`` list of 0-255 values, e.g. as
            returned by ``get_new_pallete``.
        out_label_flag: when True, also build one legend patch per class index
            present in ``npimg``.
        labels: class names indexed by class index; required when
            ``out_label_flag`` is True.

    Returns:
        ``(out_img, patches)``: a PIL image with ``new_palette`` attached, and a
        list of ``matplotlib.patches.Patch`` (empty when ``out_label_flag`` is False).

    Raises:
        ValueError: if ``out_label_flag`` is True and ``labels`` is None.
    """
    out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
    out_img.putpalette(new_palette)

    # Bound unconditionally so the default out_label_flag=False path can return it.
    patches = []
    if out_label_flag:
        if labels is None:
            raise ValueError("labels must be provided when out_label_flag is True")
        for index in np.unique(npimg):
            base = index * 3
            # matplotlib expects colour components in [0, 1].
            cur_color = [new_palette[base + channel] / 255.0 for channel in range(3)]
            patches.append(mpatches.Patch(color=cur_color, label=labels[index]))
    return out_img, patches
| |
|
def _build_arg_parser():
    """Return the argument parser describing the LSeg evaluation options.

    ``load_model`` parses it with an empty argv, so every option takes its
    default value; several of those defaults are then overridden in
    ``load_model``.
    """
    parser = argparse.ArgumentParser(description="PyTorch Segmentation")
    parser.add_argument("--model", type=str, default="encnet", help="model name (default: encnet)")
    parser.add_argument(
        "--backbone", type=str, default="clip_vitl16_384",
        help="backbone name (default: clip_vitl16_384)",
    )
    parser.add_argument("--dataset", type=str, default="ade20k", help="dataset name (default: ade20k)")
    parser.add_argument("--workers", type=int, default=1, metavar="N", help="dataloader threads")
    parser.add_argument("--base-size", type=int, default=256, help="base image size")
    parser.add_argument("--crop-size", type=int, default=128, help="crop image size")
    parser.add_argument(
        "--train-split", type=str, default="train", help="dataset train split (default: train)"
    )
    parser.add_argument("--aux", action="store_true", default=False, help="Auxiliary Loss")
    parser.add_argument(
        "--se-loss", action="store_true", default=False, help="Semantic Encoding Loss SE-loss"
    )
    parser.add_argument("--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)")
    parser.add_argument(
        "--batch-size", type=int, default=16, metavar="N",
        help="input batch size for training (default: auto)",
    )
    parser.add_argument(
        "--test-batch-size", type=int, default=16, metavar="N",
        help="input batch size for testing (default: same as batch size)",
    )
    # store_true with default=True: CUDA is always reported as disabled, which
    # matches load_model() keeping everything on CPU.
    parser.add_argument("--no-cuda", action="store_true", default=True, help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument("--weights", type=str, default="", help="checkpoint to test")
    parser.add_argument("--eval", action="store_true", default=False, help="evaluating mIoU")
    parser.add_argument(
        "--export", type=str, default=None, help="put the path to resuming file if needed"
    )
    parser.add_argument(
        "--acc-bn", action="store_true", default=False, help="Re-accumulate BN statistics"
    )
    parser.add_argument(
        "--test-val", action="store_true", default=False, help="generate masks on val set"
    )
    parser.add_argument(
        "--no-val", action="store_true", default=False, help="skip validation during training"
    )
    parser.add_argument("--module", default="lseg", help="select model definition")
    parser.add_argument(
        "--data-path", type=str, default="../datasets/", help="path to test image folder"
    )
    parser.add_argument(
        "--no-scaleinv", dest="scale_inv", default=True, action="store_false",
        help="turn off scaleinv layers",
    )
    parser.add_argument("--widehead", default=False, action="store_true", help="wider output head")
    parser.add_argument("--widehead_hr", default=False, action="store_true", help="wider output head")
    parser.add_argument(
        "--ignore_index", type=int, default=-1, help="numeric value of ignore label in gt"
    )
    parser.add_argument("--label_src", type=str, default="default", help="how to get the labels")
    parser.add_argument(
        "--arch_option", type=int, default=0, help="which kind of architecture to be used"
    )
    parser.add_argument(
        "--block_depth", type=int, default=0, help="how many blocks should be used"
    )
    parser.add_argument(
        "--activation", choices=["lrelu", "tanh"], default="lrelu",
        help="use which activation to activate the block",
    )
    return parser


def load_model():
    """Load the LSeg demo checkpoint and wrap it for multi-scale CPU inference.

    Returns:
        tuple: ``(evaluator, transform)``.

        * ``evaluator`` is an ``LSeg_MultiEvalModule`` in eval mode on CPU,
          built with ``flip=True`` and the scale list chosen below.
        * ``transform`` converts an image to a tensor, normalises it with
          mean/std 0.5 per channel and resizes it to 360x480.
    """
    args = _build_arg_parser().parse_args(args=[])
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)

    torch.manual_seed(args.seed)

    # Demo-specific settings that override the parser defaults.
    args.test_batch_size = 1
    args.scale_inv = False
    args.widehead = True
    args.dataset = 'ade20k'
    args.backbone = 'clip_vitl16_384'
    args.weights = 'checkpoints/demo_e200.ckpt'  # relative to the working directory
    args.ignore_index = 255

    module = LSegModule.load_from_checkpoint(
        checkpoint_path=args.weights,
        data_path=args.data_path,
        dataset=args.dataset,
        backbone=args.backbone,
        aux=args.aux,
        num_features=256,
        aux_weight=0,
        se_loss=False,
        se_weight=0,
        base_lr=0,
        batch_size=1,
        max_epochs=0,
        ignore_index=args.ignore_index,
        dropout=0.0,
        scale_inv=args.scale_inv,
        augment=False,
        no_batchnorm=False,
        widehead=args.widehead,
        widehead_hr=args.widehead_hr,
        map_location="cpu",
        arch_option=0,
        block_depth=0,
        activation='lrelu',
    )

    # Evaluate the wrapped network directly when it is a BaseNet.
    model = module.net if isinstance(module.net, BaseNet) else module
    model = model.eval().cpu()

    scales = (
        [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
        if args.dataset == "citys"
        else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    )

    # Same per-channel statistics as the Normalize() in the returned transform.
    model.mean = [0.5, 0.5, 0.5]
    model.std = [0.5, 0.5, 0.5]

    evaluator = LSeg_MultiEvalModule(model, scales=scales, flip=True).cpu()
    evaluator.eval()

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize([360, 480]),
        ]
    )
    return evaluator, transform
| |
|
# LSeg Demo: build the model and its input transform once, at import time, so
# every inference() call reuses the same objects.
lseg_model, lseg_transform = load_model()
| |
|
def inference(image, text):
    """Segment ``image`` into the comma-separated classes listed in ``text``.

    Args:
        image: PIL image from the Gradio image input.
        text: comma-separated label names, e.g. ``"dog, grass, sky"``. Surrounding
            whitespace is stripped and empty entries are ignored.

    Returns:
        str: path of a PNG (``out.png``) showing the resized input image next to
        the colour-coded segmentation, with a legend naming each class present.

    Raises:
        ValueError: if no image is given or ``text`` contains no non-empty label.
    """
    if image is None:
        raise ValueError("Please provide an input image.")
    labels = [label.strip() for label in (text or "").split(",") if label.strip()]
    if not labels:
        raise ValueError("Please provide at least one label (comma-separated).")

    pimage = lseg_transform(np.array(image)).unsqueeze(0)

    with torch.no_grad():
        outputs = lseg_model.parallel_forward(pimage, labels)

    # Only the first output is visualised; take the index of the max over dim 1
    # (presumably the label dimension) as the per-pixel class.
    pred = torch.max(outputs[0], 1)[1].cpu().numpy()

    # Undo the mean/std 0.5 normalisation applied by lseg_transform to get values
    # back into [0, 1]; clamp guards the uint8 cast against any overshoot.
    shown = (pimage[0].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)
    shown = Image.fromarray(np.uint8(255 * shown.numpy())).convert("RGBA")

    new_palette = get_new_pallete(len(labels))
    mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels)
    seg = mask.convert("RGBA")

    fig = plt.figure()
    try:
        plt.subplot(121)
        plt.imshow(shown)
        plt.axis('off')

        plt.subplot(122)
        plt.imshow(seg)
        plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5})
        plt.axis('off')

        plt.tight_layout()
        plt.savefig('out.png', bbox_inches='tight')
    finally:
        # pyplot keeps every open figure alive; close it so requests don't leak memory.
        plt.close(fig)
    return 'out.png'
| |
|
title = "Lang-Seg"
# The interface below defines no examples, so the description explains the
# expected text format instead of pointing users at examples.
description = (
    "Gradio demo for Language-driven Semantic Segmentation. To use it, upload an image "
    "and enter the labels to segment as comma-separated text, e.g. dog, grass, sky. "
    "Read more at the links below."
)

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.03546' target='_blank'>Language-driven Semantic Segmentation</a> | <a href='https://github.com/isl-org/lang-seg' target='_blank'>Github Repo</a></p>"

# Launch the web UI: PIL image + free text in, path to the rendered PNG out.
gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Image(type="pil"), "text"],
    outputs=gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
).launch(enable_queue=True)