File size: 4,962 Bytes
ab839c8
 
 
 
 
 
d21b5aa
ab839c8
 
 
9882922
ab839c8
 
 
 
 
1d4a16a
 
 
 
 
 
a02e08d
 
 
 
 
 
93021e6
 
d1760b8
93021e6
a02e08d
2f8146d
a24460f
4a8a8fc
a02e08d
 
 
2f8146d
a02e08d
efd7b65
bf7e36c
 
d973f96
bf7e36c
 
 
a02e08d
 
bf7e36c
a02e08d
 
 
2f8146d
 
8ab219f
0f8a0fd
a6374a2
 
0f8a0fd
 
 
8ab219f
06c7151
2f8146d
 
 
73fe025
a02e08d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b7875e
 
 
 
 
 
 
 
 
ab839c8
6b7875e
ab839c8
a02e08d
 
 
39936ff
a02e08d
 
39936ff
a02e08d
 
 
 
 
 
 
 
 
 
 
 
ab839c8
a02e08d
ab839c8
a02e08d
ab839c8
a02e08d
d21b5aa
a02e08d
d21b5aa
a02e08d
d21b5aa
a02e08d
ab839c8
a02e08d
ab839c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Install detectron2 on first run if it is missing.
try:
    import detectron2
except ImportError:
    # Only a missing package should trigger an install; other errors (or
    # Ctrl-C) must propagate instead of being swallowed by a bare except.
    import importlib
    import subprocess
    import sys

    # Use the running interpreter's pip so the package lands in the same
    # environment; check_call raises if the build fails, surfacing the real
    # cause instead of a confusing ImportError a few lines later.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )
    # Make the freshly installed package visible to the imports that follow.
    importlib.invalidate_caches()

import cv2
import gradio as gr
import requests
import numpy as np
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

from matplotlib.pyplot import axis
from torch import nn
import requests

import torch

# Detectron2 models offered in the UI, in display order.
# "config_file" is a model-zoo config path; the optional "model_path"
# replaces the zoo checkpoint with custom weights (see setup_model).
models = [
    dict(
        name="Instance Segmentation",
        config_file="COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
    ),
    dict(
        name="Panoptic Segmentation",
        config_file="COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml",
    ),
    dict(
        name="Custom Model",
        config_file="COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
        model_path="https://huggingface.co/spaces/eyepop-ai/segmentation/resolve/main/model_final.pth",
    ),
]

def setup_model(config_file, model_path=None):
    """Build a Detectron2 config for a model-zoo architecture.

    Args:
        config_file: Model-zoo-relative config path, e.g.
            ``"COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"``.
        model_path: Optional weights location (path or URL). When falsy, the
            model-zoo checkpoint URL for ``config_file`` is used instead.

    Returns:
        The merged config, with ``MODEL.WEIGHTS`` set and ``MODEL.DEVICE``
        forced to ``"cpu"`` when CUDA is not available.
    """
    config = get_cfg()
    config.merge_from_file(model_zoo.get_config_file(config_file))

    # Short-circuit: the zoo URL is only looked up when no custom weights given.
    config.MODEL.WEIGHTS = model_path or model_zoo.get_checkpoint_url(config_file)

    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        config.MODEL.DEVICE = "cpu"

    return config

# Attach a built config ("cfg") and visualization metadata ("metadata") to
# every entry of `models`, printing the metadata for each one.
for model in models:
    if model["name"] != "Custom Model":
        # Zoo models: reuse the metadata of the first training dataset named
        # in their config.
        model["cfg"] = setup_model(model["config_file"])
        model["metadata"] = MetadataCatalog.get(model["cfg"].DATASETS.TRAIN[0])
    else:
        # Custom weights: class names and colors are declared explicitly
        # under the "teng-train" metadata key; the object returned by .set()
        # is what gets stored.
        model["cfg"] = setup_model(model["config_file"], model["model_path"])
        model["metadata"] = MetadataCatalog.get("teng-train").set(
            evaluator_type="coco",
            thing_classes=['cell', 'object'],
            thing_colors=[(0, 255, 0), (0, 255, 0)],
        )
    print(model["metadata"])

def _load_bgr_image(image_url, image):
    """Return the request's input picture as a BGR uint8 array.

    A non-blank ``image_url`` takes precedence over ``image``.

    Raises:
        ValueError: If the download fails, the payload is not a decodable
            image, or neither source was provided.
    """
    url = (image_url or "").strip()
    if url:
        response = requests.get(url, timeout=30)
        # A Response is falsy for 4xx/5xx status codes.
        if not response:
            raise ValueError(
                f"Could not download image (HTTP {response.status_code}): {url}"
            )
        decoded = cv2.imdecode(
            np.frombuffer(response.content, dtype="uint8"), cv2.IMREAD_COLOR
        )
        # imdecode signals undecodable data by returning None.
        if decoded is None:
            raise ValueError(f"URL did not return a decodable image: {url}")
        return decoded  # imdecode already yields BGR

    if image is None:
        raise ValueError("Provide an image URL or upload an image.")
    # The model expects BGR; assumes Gradio's default RGB image_mode.
    return image[:, :, ::-1]


def inference(image_url, image, min_score, model_name):
    """Run the selected Detectron2 model on one image and return the rendering.

    Args:
        image_url: URL from the "From URL" tab; used when non-blank.
        image: Array from the "From Image" tab, used when no URL is given.
        min_score: Score threshold written to
            ``MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
        model_name: ``name`` of an entry in ``models``.

    Returns:
        The annotated image as an RGB uint8 array (``Visualizer.get_image``).

    Raises:
        ValueError: If ``model_name`` is unknown or no usable image is given.
    """
    model = next((m for m in models if m["name"] == model_name), None)
    if model is None:
        raise ValueError("Model not found")

    im = _load_bgr_image(image_url, image)

    # Per-request copy so concurrent requests with different thresholds do
    # not overwrite each other's setting in the shared config.
    cfg = model["cfg"].clone()
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = min_score
    # NOTE(review): this rebuilds the model and reloads weights on every
    # request, presumably because the threshold is read at build time;
    # consider caching predictors per (model, threshold).
    predictor = DefaultPredictor(cfg)
    outputs = predictor(im)

    # Visualizer requires RGB input and returns RGB, which is also what
    # Gradio displays, so convert once here and never flip the result back.
    v = Visualizer(im[:, :, ::-1], model["metadata"], scale=1.2)
    if model_name == "Panoptic Segmentation":
        panoptic_seg, segments_info = outputs["panoptic_seg"]
        out = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"), segments_info)
    else:
        out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

    return out.get_image()

# Markdown text shown above and below the UI. Keep the model list and config
# links in sync with the `models` table.
title = "# Segmentation Model Demo"
description = """
This demo is an interactive playground for pretrained Detectron2 models.
Currently, three models are supported, trained on COCO and on a custom dataset:
* [Instance Segmentation](https://github.com/facebookresearch/detectron2/blob/main/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml): Identifies and outlines individual object instances.
* [Panoptic Segmentation](https://github.com/facebookresearch/detectron2/blob/main/configs/COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml): Unifies instance and semantic segmentation.
* [Custom Model](https://huggingface.co/spaces/eyepop-ai/segmentation/blob/main/model_final.pth): Identifies and outlines rounded objects in petri dishes.
"""
footer = "Made by eyepop.ai with ❤️."

# Gradio UI. Components are created in display order inside the Blocks
# context; both input tabs feed the single Submit button, and inference()
# uses the URL when one is entered, otherwise the uploaded image.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    
    with gr.Tab("From URL"):
        # The URL is only a placeholder: the textbox starts empty, so nothing
        # is fetched unless the user types a URL.
        url_input = gr.Textbox(label="Image URL", placeholder="https://images.unsplash.com/photo-1701226362119-cc86312846af?q=80&w=1587&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D")
    
    with gr.Tab("From Image"):
        # Delivered to inference() as a numpy array.
        image_input = gr.Image(type="numpy", label="Input Image")

    # Passed to inference() as min_score (ROI-head score threshold).
    min_score = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Minimum score")

    # Choices mirror the names in `models`; the first one is preselected.
    model_name = gr.Radio(choices=[model["name"] for model in models], value=models[0]["name"], label="Select Detectron2 model")

    # Shows the annotated array returned by inference().
    output_image = gr.Image(type="pil", label="Output")

    inference_button = gr.Button("Submit")
    
    # Input order must match inference(image_url, image, min_score, model_name).
    inference_button.click(fn=inference, inputs=[url_input, image_input, min_score, model_name], outputs=output_image)

    gr.Markdown(footer)

# Module-level call: the app is launched whenever this file is executed or imported.
demo.launch()