Spaces:
Runtime error
Runtime error
| from collections import namedtuple | |
| import altair as alt | |
| import math | |
| import pandas as pd | |
| import streamlit as st | |
| st.set_page_config(layout="wide") | |
| from PIL import Image | |
| import os | |
| import torch | |
| import os | |
| import argparse | |
| import numpy as np | |
| from tqdm import tqdm | |
| from collections import OrderedDict | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils import data | |
| import torchvision.transforms as transform | |
| from torch.nn.parallel.scatter_gather import gather | |
| from additional_utils.models import LSeg_MultiEvalModule | |
| from modules.lseg_module import LSegModule | |
| import cv2 | |
| import math | |
| import types | |
| import functools | |
| import torchvision.transforms as torch_transforms | |
| import copy | |
| import itertools | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import clip | |
| from encoding.models.sseg import BaseNet | |
| import matplotlib as mpl | |
| import matplotlib.colors as mplc | |
| import matplotlib.figure as mplfigure | |
| import matplotlib.patches as mpatches | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg | |
| from data import get_dataset | |
| import torchvision.transforms as transforms | |
def get_new_pallete(num_cls):
    """Build a flat RGB colour palette for ``num_cls`` class indices.

    The result has ``3 * num_cls`` entries laid out as
    ``[r0, g0, b0, r1, g1, b1, ...]``.

    Each class index is consumed three bits at a time. Bit 0 goes to red,
    bit 1 to green and bit 2 to blue. The first triple lands on bit 7 of
    each channel, the next on bit 6, and so on. Index 0 is therefore black.
    """
    palette = []
    for cls_idx in range(num_cls):
        rgb = [0, 0, 0]
        remaining = cls_idx
        shift = 7
        while remaining > 0:
            # Spread the low three bits of `remaining` across R, G and B.
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            shift -= 1
            remaining >>= 3
        palette.extend(rgb)
    return palette
def get_new_mask_pallete(npimg, new_palette, out_label_flag=False, labels=None):
    """Colorize a class-index mask and optionally build legend patches.

    Args:
        npimg: numpy array of per-pixel class indices. It is squeezed and cast
            to uint8 before being turned into a PIL image.
        new_palette: flat ``[r0, g0, b0, r1, g1, b1, ...]`` list with
            0..255 values (e.g. the output of ``get_new_pallete``). It is
            attached to the image with ``putpalette``.
        out_label_flag: when True, also build one legend patch for every
            distinct class index present in ``npimg``.
        labels: class names indexed by class index. Required when
            ``out_label_flag`` is True.

    Returns:
        ``(out_img, patches)``: the palettized PIL image and a list of
        ``matplotlib.patches.Patch`` objects. The list is empty when
        ``out_label_flag`` is False.

    Raises:
        ValueError: if ``out_label_flag`` is True and ``labels`` is None.
    """
    out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
    out_img.putpalette(new_palette)

    # Bind `patches` unconditionally. It used to be assigned only inside the
    # branch, so out_label_flag=False raised UnboundLocalError at the return.
    patches = []
    if out_label_flag:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if labels is None:
            raise ValueError("labels must be provided when out_label_flag is True")
        for index in np.unique(npimg):
            # Matplotlib expects colour components in [0, 1].
            cur_color = [new_palette[index * 3 + c] / 255.0 for c in range(3)]
            patches.append(mpatches.Patch(color=cur_color, label=labels[index]))
    return out_img, patches
def _build_option_parser():
    """Return the argparse parser holding the LSeg evaluation/demo options."""
    parser = argparse.ArgumentParser(description="PyTorch Segmentation")
    # model and dataset
    parser.add_argument(
        "--model", type=str, default="encnet", help="model name (default: encnet)"
    )
    parser.add_argument(
        "--backbone",
        type=str,
        default="clip_vitl16_384",
        help="backbone name (default: resnet50)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ade20k",
        help="dataset name (default: pascal12)",
    )
    parser.add_argument(
        "--workers", type=int, default=16, metavar="N", help="dataloader threads"
    )
    parser.add_argument(
        "--base-size", type=int, default=520, help="base image size"
    )
    parser.add_argument(
        "--crop-size", type=int, default=480, help="crop image size"
    )
    parser.add_argument(
        "--train-split",
        type=str,
        default="train",
        help="dataset train split (default: train)",
    )
    parser.add_argument(
        "--aux", action="store_true", default=False, help="Auxilary Loss"
    )
    parser.add_argument(
        "--se-loss",
        action="store_true",
        default=False,
        help="Semantic Encoding Loss SE-loss",
    )
    parser.add_argument(
        "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        metavar="N",
        help="input batch size for training (default: auto)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=16,
        metavar="N",
        help="input batch size for testing (default: same as batch size)",
    )
    # cuda, seed and logging
    parser.add_argument(
        "--no-cuda",
        action="store_true",
        default=False,
        help="disables CUDA training",
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    # checking point
    parser.add_argument(
        "--weights", type=str, default='', help="checkpoint to test"
    )
    # evaluation option
    parser.add_argument(
        "--eval", action="store_true", default=False, help="evaluating mIoU"
    )
    parser.add_argument(
        "--export",
        type=str,
        default=None,
        help="put the path to resuming file if needed",
    )
    parser.add_argument(
        "--acc-bn",
        action="store_true",
        default=False,
        help="Re-accumulate BN statistics",
    )
    parser.add_argument(
        "--test-val",
        action="store_true",
        default=False,
        help="generate masks on val set",
    )
    parser.add_argument(
        "--no-val",
        action="store_true",
        default=False,
        help="skip validation during training",
    )
    parser.add_argument(
        "--module",
        default='lseg',
        help="select model definition",
    )
    # test option
    parser.add_argument(
        "--data-path", type=str, default='../datasets/', help="path to test image folder"
    )
    parser.add_argument(
        "--no-scaleinv",
        dest="scale_inv",
        default=True,
        action="store_false",
        help="turn off scaleinv layers",
    )
    parser.add_argument(
        "--widehead", default=False, action="store_true", help="wider output head"
    )
    parser.add_argument(
        "--widehead_hr",
        default=False,
        action="store_true",
        help="wider output head",
    )
    parser.add_argument(
        "--ignore_index",
        type=int,
        default=-1,
        help="numeric value of ignore label in gt",
    )
    parser.add_argument(
        "--label_src",
        type=str,
        default="default",
        help="how to get the labels",
    )
    parser.add_argument(
        "--arch_option",
        type=int,
        default=0,
        help="which kind of architecture to be used",
    )
    parser.add_argument(
        "--block_depth",
        type=int,
        default=0,
        help="how many blocks should be used",
    )
    parser.add_argument(
        "--activation",
        choices=['lrelu', 'tanh'],
        default="lrelu",
        help="use which activation to activate the block",
    )
    return parser


def _parse_options():
    """Parse the option defaults (never the real command line) and add ``cuda``."""
    # args=[] ignores sys.argv. Presumably this keeps `streamlit run`'s own
    # arguments from being fed to this parser.
    args = _build_option_parser().parse_args(args=[])
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)
    return args


def load_model():
    """Load the LSeg demo checkpoint and wrap it in a multi-scale evaluator.

    Returns:
        ``(evaluator, transform)``:

        - ``evaluator`` is an ``LSeg_MultiEvalModule`` in eval mode with
          horizontal flip enabled. It is moved to the GPU only when CUDA is
          available.
        - ``transform`` is a torchvision pipeline: ToTensor, then Normalize
          with mean=std=0.5, then Resize to 360x480.
    """
    args = _parse_options()
    torch.manual_seed(args.seed)

    # Demo-specific overrides of the parser defaults.
    args.test_batch_size = 1
    args.scale_inv = False
    args.widehead = True
    args.dataset = 'ade20k'
    args.backbone = 'clip_vitl16_384'
    args.weights = 'checkpoints/demo_e200.ckpt'
    args.ignore_index = 255

    module = LSegModule.load_from_checkpoint(
        checkpoint_path=args.weights,
        # This was misspelled "map_locatin", so Lightning never received it.
        # Tensors were then restored onto whatever device the checkpoint was
        # saved from, which fails on a CPU-only host.
        map_location="cpu",
        data_path=args.data_path,
        dataset=args.dataset,
        backbone=args.backbone,
        aux=args.aux,
        num_features=256,
        aux_weight=0,
        se_loss=False,
        se_weight=0,
        base_lr=0,
        batch_size=1,
        max_epochs=0,
        ignore_index=args.ignore_index,
        dropout=0.0,
        scale_inv=args.scale_inv,
        augment=False,
        no_batchnorm=False,
        widehead=args.widehead,
        widehead_hr=args.widehead_hr,
        arch_option=0,
        block_depth=0,
        activation='lrelu',
    )

    # Evaluate the bare network when the Lightning module wraps a BaseNet.
    model = module.net if isinstance(module.net, BaseNet) else module
    model = model.eval().cpu()

    scales = (
        [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
        if args.dataset == "citys"
        else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    )
    # Must agree with the Normalize step of the returned transform below.
    model.mean = [0.5, 0.5, 0.5]
    model.std = [0.5, 0.5, 0.5]

    evaluator = LSeg_MultiEvalModule(model, scales=scales, flip=True)
    # The original called .cuda() unconditionally, which raises on CPU-only
    # hosts. args.cuda already records whether a GPU is usable.
    # NOTE(review): LSeg_MultiEvalModule.parallel_forward may still assume CUDA
    # devices; confirm the CPU inference path before relying on it.
    if args.cuda:
        evaluator = evaluator.cuda()
    evaluator.eval()

    # Named to avoid shadowing the module-level `transform` import alias.
    image_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize([360, 480]),
        ]
    )
    return evaluator, image_transform
# Explicit Markdown title. The original relied on Streamlit "magic" rendering
# a bare string literal.
st.markdown("# LSeg Demo")

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the checkpoint is reloaded each time. Consider caching load_model()
# (e.g. st.cache_resource on recent Streamlit versions).
lseg_model, lseg_transform = load_model()

uploaded_file = st.file_uploader("Choose an image...")
input_labels = st.text_input("Input labels", value="dog, grass, other")
st.write("The labels are", input_labels)

if uploaded_file is not None:
    # Force 3 channels. RGBA, grayscale or palette uploads would otherwise
    # break the 3-channel Normalize inside lseg_transform.
    image = Image.open(uploaded_file).convert("RGB")
    pimage = lseg_transform(np.array(image)).unsqueeze(0)

    # Comma-separated class names. Blank entries (e.g. "dog,,grass" or a
    # trailing comma) are dropped instead of being sent to the model as "".
    labels = [name.strip() for name in input_labels.split(",") if name.strip()]

    if not labels:
        st.warning("Please enter at least one label.")
    else:
        with torch.no_grad():
            outputs = lseg_model.parallel_forward(pimage, labels)
            # Only the first output is displayed. Taking the argmax over dim 1
            # (presumably the label dimension) gives per-pixel label indices.
            pred = torch.max(outputs[0], 1)[1].cpu().numpy()

        # Undo Normalize(mean=0.5, std=0.5) to get back to [0, 1] for display.
        # The clamp keeps the uint8 conversion safe from float overshoot.
        display = (pimage[0].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)
        image = Image.fromarray((255 * display.numpy()).astype(np.uint8)).convert("RGBA")

        new_palette = get_new_pallete(len(labels))
        mask, patches = get_new_mask_pallete(
            pred, new_palette, out_label_flag=True, labels=labels
        )
        seg = mask.convert("RGBA")

        fig, (ax_input, ax_seg) = plt.subplots(1, 2)
        ax_input.imshow(image)
        ax_input.axis("off")
        ax_seg.imshow(seg)
        ax_seg.legend(
            handles=patches, loc="upper right", bbox_to_anchor=(1.3, 1), prop={"size": 5}
        )
        ax_seg.axis("off")
        fig.tight_layout()
        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed. Close it so that
        # repeated reruns do not accumulate figures.
        plt.close(fig)