import argparse

import numpy as np
import torch
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image

import streamlit as st

from encoding.models.sseg import BaseNet
from additional_utils.models import LSeg_MultiEvalModule
from modules.lseg_module import LSegModule

st.set_page_config(layout="wide")


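# Build a flat [R0, G0, B0, R1, G1, B1, ...] palette with one color per class.
# Colors come from spreading the bits of each class index across the R, G and B
# channels (the scheme commonly used for PASCAL VOC-style colormaps).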
def get_new_pallete(num_cls):
    n = num_cls
    pallete = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        pallete[j * 3 + 0] = 0
        pallete[j * 3 + 1] = 0
        pallete[j * 3 + 2] = 0
        i = 0
        while lab > 0:
            pallete[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
            pallete[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
            pallete[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
            i = i + 1
            lab >>= 3
    return pallete


def get_new_mask_pallete(npimg, new_palette, out_label_flag=False, labels=None):
    """Get image color palette for visualizing masks."""
    # Store the mask as a palettized ("P" mode) image so each class index maps
    # to one palette color.
    out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
    out_img.putpalette(new_palette)

    patches = []
    if out_label_flag:
        assert labels is not None
        # One legend patch per class index that actually occurs in the mask.
        u_index = np.unique(npimg)
        for index in u_index:
            label = labels[index]
            cur_color = [
                new_palette[index * 3] / 255.0,
                new_palette[index * 3 + 1] / 255.0,
                new_palette[index * 3 + 2] / 255.0,
            ]
            patches.append(mpatches.Patch(color=cur_color, label=label))
    return out_img, patches


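# Cache the loaded model so the checkpoint is read only once per Streamlit
# session (on newer Streamlit versions, st.cache_resource replaces st.cache).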
@st.cache(allow_output_mutation=True)
def load_model():
    class Options:
        def __init__(self):
            parser = argparse.ArgumentParser(description="PyTorch Segmentation")

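            # `Options` mirrors the training/evaluation CLI so the demo can
            # reuse the same defaults; parse() below is called with an empty
            # argument list, so only the defaults given here are used.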
            parser.add_argument(
                "--model", type=str, default="encnet", help="model name (default: encnet)"
            )
            parser.add_argument(
                "--backbone",
                type=str,
                default="clip_vitl16_384",
                help="backbone name (default: clip_vitl16_384)",
            )
            parser.add_argument(
                "--dataset",
                type=str,
                default="ade20k",
                help="dataset name (default: ade20k)",
            )
            parser.add_argument(
                "--workers", type=int, default=16, metavar="N", help="dataloader threads"
            )
            parser.add_argument(
                "--base-size", type=int, default=520, help="base image size"
            )
            parser.add_argument(
                "--crop-size", type=int, default=480, help="crop image size"
            )
            parser.add_argument(
                "--train-split",
                type=str,
                default="train",
                help="dataset train split (default: train)",
            )
            parser.add_argument(
                "--aux", action="store_true", default=False, help="auxiliary loss"
            )
            parser.add_argument(
                "--se-loss",
                action="store_true",
                default=False,
                help="Semantic Encoding (SE) loss",
            )
            parser.add_argument(
                "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
            )
            parser.add_argument(
                "--batch-size",
                type=int,
                default=16,
                metavar="N",
                help="input batch size for training (default: auto)",
            )
            parser.add_argument(
                "--test-batch-size",
                type=int,
                default=16,
                metavar="N",
                help="input batch size for testing (default: same as batch size)",
            )
            parser.add_argument(
                "--no-cuda",
                action="store_true",
                default=False,
                help="disables CUDA training",
            )
            parser.add_argument(
                "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
            )
            parser.add_argument(
                "--weights", type=str, default='', help="checkpoint to test"
            )
            parser.add_argument(
                "--eval", action="store_true", default=False, help="evaluating mIoU"
            )
            parser.add_argument(
                "--export",
                type=str,
                default=None,
                help="put the path to resuming file if needed",
            )
            parser.add_argument(
                "--acc-bn",
                action="store_true",
                default=False,
                help="re-accumulate BN statistics",
            )
            parser.add_argument(
                "--test-val",
                action="store_true",
                default=False,
                help="generate masks on val set",
            )
            parser.add_argument(
                "--no-val",
                action="store_true",
                default=False,
                help="skip validation during training",
            )
            parser.add_argument(
                "--module",
                default='lseg',
                help="select model definition",
            )
            parser.add_argument(
                "--data-path", type=str, default='../datasets/', help="path to test image folder"
            )
            parser.add_argument(
                "--no-scaleinv",
                dest="scale_inv",
                default=True,
                action="store_false",
                help="turn off scaleinv layers",
            )
            parser.add_argument(
                "--widehead", default=False, action="store_true", help="wider output head"
            )
            parser.add_argument(
                "--widehead_hr",
                default=False,
                action="store_true",
                help="wider output head",
            )
            parser.add_argument(
                "--ignore_index",
                type=int,
                default=-1,
                help="numeric value of ignore label in gt",
            )
            parser.add_argument(
                "--label_src",
                type=str,
                default="default",
                help="how to get the labels",
            )
            parser.add_argument(
                "--arch_option",
                type=int,
                default=0,
                help="which kind of architecture to be used",
            )
            parser.add_argument(
                "--block_depth",
                type=int,
                default=0,
                help="how many blocks should be used",
            )
            parser.add_argument(
                "--activation",
                choices=['lrelu', 'tanh'],
                default="lrelu",
                help="which activation to use in the block",
            )

            self.parser = parser

        def parse(self):
            args = self.parser.parse_args(args=[])
            args.cuda = not args.no_cuda and torch.cuda.is_available()
            print(args)
            return args

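    # Build the default arguments, then override the handful of settings this
    # demo needs (ADE20K label set, ViT-L/16 CLIP backbone, demo checkpoint).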
    args = Options().parse()

    torch.manual_seed(args.seed)
    args.test_batch_size = 1
    alpha = 0.5

    args.scale_inv = False
    args.widehead = True
    args.dataset = 'ade20k'
    args.backbone = 'clip_vitl16_384'
    args.weights = 'checkpoints/demo_e200.ckpt'
    args.ignore_index = 255

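    # Re-create the LSeg model from the checkpoint on the CPU; training-only
    # options (aux/SE losses, learning rate, epochs) are disabled for inference.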
    module = LSegModule.load_from_checkpoint(
        checkpoint_path=args.weights,
        data_path=args.data_path,
        dataset=args.dataset,
        backbone=args.backbone,
        aux=args.aux,
        num_features=256,
        aux_weight=0,
        se_loss=False,
        se_weight=0,
        base_lr=0,
        batch_size=1,
        max_epochs=0,
        ignore_index=args.ignore_index,
        dropout=0.0,
        scale_inv=args.scale_inv,
        augment=False,
        no_batchnorm=False,
        widehead=args.widehead,
        widehead_hr=args.widehead_hr,
        map_location="cpu",
        arch_option=0,
        block_depth=0,
        activation='lrelu',
    )

    # Note: `input_transform` and `loader_kwargs` are not used further in this
    # demo, which builds its own transform below.
    input_transform = module.val_transform
    loader_kwargs = (
        {"num_workers": args.workers, "pin_memory": True} if args.cuda else {}
    )

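    # LSegModule wraps the segmentation network; when the wrapped net is an
    # encoding BaseNet, evaluate it directly.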
    if isinstance(module.net, BaseNet):
        model = module.net
    else:
        model = module

    model = model.eval()
    model = model.cpu()
    # Scales for multi-scale inference ("citys" = Cityscapes uses a larger range).
    scales = (
        [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
        if args.dataset == "citys"
        else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    )

    model.mean = [0.5, 0.5, 0.5]
    model.std = [0.5, 0.5, 0.5]
    # Wrap the model for multi-scale, flipped inference. The evaluator is moved
    # to the GPU, so this demo assumes a CUDA device is available.
    evaluator = LSeg_MultiEvalModule(
        model, scales=scales, flip=True
    ).cuda()
    evaluator.eval()

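    # Preprocessing for uploaded images: normalize to [-1, 1] (matching the
    # model's mean/std of 0.5) and resize to 360x480.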
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize([360, 480]),
        ]
    )

    return evaluator, transform


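# Bare string literals are rendered as Markdown by Streamlit ("magic"), so the
# block below becomes the page title.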
| """ |
| # LSeg Demo |
| """ |
| lseg_model, lseg_transform = load_model() |
| uploaded_file = st.file_uploader("Choose an image...") |
| input_labels = st.text_input("Input labels", value="dog, grass, other") |
| st.write("The labels are", input_labels) |
|
|
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    # Preprocess the image and add a batch dimension.
    pimage = lseg_transform(np.array(image)).unsqueeze(0)

    # Parse the comma-separated label set typed by the user.
    labels = []
    for label in input_labels.split(","):
        labels.append(label.strip())

    # Score every pixel against the text labels and take the argmax over the
    # label dimension to get a per-pixel class index.
    with torch.no_grad():
        outputs = lseg_model.parallel_forward(pimage, labels)

        predicts = [
            torch.max(output, 1)[1].cpu().numpy()
            for output in outputs
        ]

    # Undo the [-1, 1] normalization for display.
    image = pimage[0].permute(1, 2, 0)
    image = image * 0.5 + 0.5
    image = Image.fromarray(np.uint8(255 * image)).convert("RGBA")

    # Color the predicted mask and build matching legend patches.
    pred = predicts[0]
    new_palette = get_new_pallete(len(labels))
    mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels)
    seg = mask.convert("RGBA")

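    # Show the input image and the segmentation side by side, with one legend
    # entry per label.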
    fig = plt.figure()
    plt.subplot(121)
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(122)
    plt.imshow(seg)
    plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5})
    plt.axis('off')

    plt.tight_layout()

    st.pyplot(fig)