# segmentation / app.py
# mangaruu's picture
# Update app.py
# a6374a2
# Detectron2 is not distributed on PyPI, so install it from GitHub on first
# run.  Catch only ImportError: a bare `except` would also hide unrelated
# failures (e.g. a broken install raising something other than ImportError).
try:
    import detectron2
except ImportError:
    import subprocess
    import sys

    # Install into the *current* interpreter's environment; a list argv with
    # shell=False avoids the shell-string pitfalls of os.system.  check=False
    # preserves the original best-effort behavior (os.system ignored failure
    # too); the `from detectron2 import ...` lines below will surface a
    # failed install as a normal ImportError.
    subprocess.run(
        [sys.executable, "-m", "pip", "install",
         "git+https://github.com/facebookresearch/detectron2.git"],
        check=False,
    )
import cv2
import gradio as gr
import requests
import numpy as np
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from matplotlib.pyplot import axis
from torch import nn
import requests
import torch
# Predefined Detectron2 models.
# Each entry carries: "name" (label shown in the Gradio radio widget),
# "config_file" (a model-zoo config path), and — for the custom model only —
# "model_path", a URL to externally hosted trained weights.  setup_model()
# below turns these into full Detectron2 configs at startup.
models = [
    {
        "name": "Instance Segmentation",
        "config_file": "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
    },
    {
        "name": "Panoptic Segmentation",
        "config_file": "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml",
    },
    {
        "name": "Custom Model",
        # Zoo config reused for architecture; weights are custom (below).
        "config_file": "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
        "model_path": "https://huggingface.co/spaces/eyepop-ai/segmentation/resolve/main/model_final.pth",
    },
]
def setup_model(config_file, model_path=None):
    """Build a Detectron2 config for ``config_file``.

    Weights come from ``model_path`` when given, otherwise from the model-zoo
    checkpoint that matches the config.  Falls back to CPU inference when no
    CUDA device is available.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_file))
    # Custom weights (if any) take precedence over the zoo checkpoint.
    cfg.MODEL.WEIGHTS = (
        model_path if model_path else model_zoo.get_checkpoint_url(config_file)
    )
    if not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"
    return cfg
# Resolve each model's config and Visualizer metadata once at startup, so
# inference() only has to build a predictor per request.
for model in models:
    if model["name"] == "Custom Model":
        model["cfg"] = setup_model(model["config_file"], model["model_path"])
        # The custom dataset is not shipped with the Space, so its metadata
        # is registered by hand: only the class names and colors that the
        # Visualizer needs are supplied here.
        # model["metadata"] = MetadataCatalog.get("teng-valid")
        model["metadata"] = MetadataCatalog.get("teng-train").set(
            # json_file="./train/_annotations.coco.json",
            # image_root="./train",
            evaluator_type="coco",
            thing_classes=['cell', 'object'],
            thing_colors=[(0, 255, 0), (0, 255, 0)]
        )
        print(model["metadata"])
    else:
        # Zoo models: reuse the metadata of the dataset they were trained on.
        model["cfg"] = setup_model(model["config_file"])
        model["metadata"] = MetadataCatalog.get(model["cfg"].DATASETS.TRAIN[0])
        print(model["metadata"])
def inference(image_url, image, min_score, model_name):
    """Run the selected Detectron2 model on an image and return the
    visualized predictions as an RGB numpy array (what gr.Image expects).

    Parameters
    ----------
    image_url : str
        Optional URL to fetch the image from; takes precedence over *image*.
    image : numpy.ndarray or None
        RGB image from the Gradio upload widget.
    min_score : float
        Confidence threshold applied to the ROI heads.
    model_name : str
        Display name matching one of the entries in ``models``.

    Raises
    ------
    ValueError
        If *model_name* is unknown, or no usable image was supplied.
    """
    model = next((m for m in models if m["name"] == model_name), None)
    if not model:
        raise ValueError("Model not found")

    # Build a BGR frame for the model, whichever input path was used.
    im = None
    if image_url:
        r = requests.get(image_url)
        if r:  # Response is truthy for non-error status codes.
            buf = np.frombuffer(r.content, dtype="uint8")
            im = cv2.imdecode(buf, cv2.IMREAD_COLOR)  # decodes as BGR
    elif image is not None:
        # Gradio delivers RGB; the model expects BGR.
        im = image[:, :, ::-1]
    if im is None:
        # Original code crashed with NameError/TypeError here when the URL
        # fetch failed or both inputs were empty.
        raise ValueError("No input image provided")

    model["cfg"].MODEL.ROI_HEADS.SCORE_THRESH_TEST = min_score
    predictor = DefaultPredictor(model["cfg"])
    outputs = predictor(im)

    # The Visualizer expects RGB, so flip the BGR frame before drawing, and
    # return RGB without flipping back.  (The original returned BGR from the
    # panoptic branch and visualized raw BGR in the instance branch, so the
    # Gradio output showed swapped colors.)
    v = Visualizer(im[:, :, ::-1], model["metadata"], scale=1.2)
    if model_name == "Panoptic Segmentation":
        panoptic_seg, segments_info = outputs["panoptic_seg"]
        out = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"), segments_info)
    else:
        out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    return out.get_image()
# Static UI copy rendered as Markdown by the Gradio layout below.
title = "# Segmentation Model Demo"
description = """
This demo introduces an interactive playground for pretrained Detectron2 model.
Currently, two models are supported that were trained on COCO and custom datasets:
* [Instance Segmentation](https://github.com/facebookresearch/detectron2/blob/main/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml): Identifies, outlines individual object instances.
* [Panoptic Segmentation](https://github.com/facebookresearch/detectron2/blob/main/configs/COCO-PanopticSegmentation/panoptic_fpn_R_50_1x.yaml): Unifies instance and semantic segmentation.
* [Custom Model](https://huggingface.co/spaces/eyepop-ai/segmentation/blob/main/model_final.pth): Identifies, outlines rounded objects in petri dishes.
"""
footer = "Made by eyepop.ai with ❤️."
# Gradio layout: two input tabs (URL vs. upload), a confidence slider, a
# model picker, and a single output image.  Both tab inputs are passed to
# inference(); the URL takes precedence when filled in.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Tab("From URL"):
        url_input = gr.Textbox(label="Image URL", placeholder="https://images.unsplash.com/photo-1701226362119-cc86312846af?q=80&w=1587&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D")
    with gr.Tab("From Image"):
        image_input = gr.Image(type="numpy", label="Input Image")
    min_score = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Minimum score")
    model_name = gr.Radio(choices=[model["name"] for model in models], value=models[0]["name"], label="Select Detectron2 model")
    output_image = gr.Image(type="pil", label="Output")
    inference_button = gr.Button("Submit")
    inference_button.click(fn=inference, inputs=[url_input, image_input, min_score, model_name], outputs=output_image)
    gr.Markdown(footer)
demo.launch()