| |
| |
| import argparse |
| import glob |
| import multiprocessing as mp |
| import os |
| |
| os.system('pip install git+https://github.com/facebookresearch/detectron2.git') |
|
|
| |
| import sys |
| sys.path.insert(1, os.path.join(sys.path[0], '..')) |
| |
|
|
| import tempfile |
| import time |
| import warnings |
|
|
| import cv2 |
| import numpy as np |
| import tqdm |
|
|
| from detectron2.config import get_cfg |
| from detectron2.data.detection_utils import read_image |
| from detectron2.projects.deeplab import add_deeplab_config |
| from detectron2.utils.logger import setup_logger |
|
|
| from cat_seg import add_cat_seg_config |
| from demo.predictor import VisualizationDemo |
| import gradio as gr |
| import torch |
| from matplotlib.backends.backend_agg import FigureCanvasAgg as fc |
|
|
| |
| WINDOW_NAME = "MaskFormer demo" |
|
|
|
|
| def setup_cfg(args): |
| |
| cfg = get_cfg() |
| add_deeplab_config(cfg) |
| add_cat_seg_config(cfg) |
| cfg.merge_from_file(args.config_file) |
| cfg.merge_from_list(args.opts) |
| if torch.cuda.is_available(): |
| cfg.MODEL.DEVICE = "cuda" |
| cfg.freeze() |
| return cfg |
|
|
|
|
| def get_parser(): |
| parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") |
| parser.add_argument( |
| "--config-file", |
| default="configs/vitl_swinb_384.yaml", |
| metavar="FILE", |
| help="path to config file", |
| ) |
| parser.add_argument( |
| "--input", |
| nargs="+", |
| help="A list of space separated input images; " |
| "or a single glob pattern such as 'directory/*.jpg'", |
| ) |
| parser.add_argument( |
| "--opts", |
| help="Modify config options using the command-line 'KEY VALUE' pairs", |
| default=( |
| ["MODEL.WEIGHTS", "model_final.pth", |
| "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json", |
| "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json", |
| "TEST.SLIDING_WINDOW", "True", |
| "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]", |
| "MODEL.PROMPT_ENSEMBLE_TYPE", "single", |
| "MODEL.DEVICE", "cpu"]), |
| nargs=argparse.REMAINDER, |
| ) |
| return parser |
|
|
| def save_masks(preds, text): |
| preds = preds['sem_seg'].argmax(dim=0).cpu().numpy() |
| for i, t in enumerate(text): |
| dir = f"mask_{t}.png" |
| mask = preds == i |
| cv2.imwrite(dir, mask * 255) |
|
|
| def predict(image, text): |
| predictions, visualized_output = demo.run_on_image(image, text) |
| |
| canvas = fc(visualized_output.fig) |
| canvas.draw() |
| out = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(canvas.get_width_height()[::-1] + (3,)) |
|
|
| return out[..., ::-1] |
|
|
| if __name__ == "__main__": |
| args = get_parser().parse_args() |
| cfg = setup_cfg(args) |
| global demo |
| demo = VisualizationDemo(cfg) |
|
|
| iface = gr.Interface( |
| fn=predict, |
| inputs=[gr.Image(), gr.Textbox(placeholder='cat, person, background')], |
| outputs="image", |
| examples=[['assets/nyancat.png', 'cat, pop tart, rainbow, background']], |
| description="""## CAT-Seg Demo |
| Welcome to the CAT-Seg Demo! Here, we present the CAT-Seg with ViT-L model for open-vocabulary semantic segmentation. |
| |
| Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model. |
| |
| To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""") |
| iface.launch() |
|
|