File size: 1,803 Bytes
8865f54
e92a879
3d2c66c
 
 
 
 
 
 
 
 
 
e92a879
3d2c66c
 
 
af9f2a8
3d2c66c
 
 
e92a879
3d2c66c
e92a879
3d2c66c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import streamlit as st
from PIL import Image
import numpy as np
import torch
import asyncio
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog

def _has_running_loop():
    """Return True if an asyncio event loop is currently running on this thread."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return False
    return True


# NOTE(review): workaround so that code expecting a current asyncio event loop
# finds one on the thread executing this script; when no loop is running, a
# brand-new loop is created and installed. Confirm which dependency needs this.
if not _has_running_loop():
    asyncio.set_event_loop(asyncio.new_event_loop())

# Page header, rendered on every run of the script.
st.title("Detectron2 Object Detection")
st.write("Upload an image to perform object detection")

# File picker limited to JPEG/PNG uploads; the script below only proceeds
# once this is no longer None (i.e. a file has been chosen).
uploaded_file = st.file_uploader(
    "Choose an image...",
    type=["jpg", "jpeg", "png"],
)

@st.cache_resource
def load_model(
    config_path: str = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
    score_thresh: float = 0.5,
):
    """Build a Detectron2 predictor for a model-zoo config, cached across reruns.

    Args:
        config_path: Model-zoo config path. Its matching pretrained checkpoint
            from the model zoo is used as the weights.
        score_thresh: Minimum ROI-head score for a detection to be kept at
            test time.

    Returns:
        A ``DefaultPredictor`` placed on CUDA when available, otherwise on CPU.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_path))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_path)
    # get_cfg() defaults MODEL.DEVICE to "cuda", which makes DefaultPredictor
    # fail on machines without a GPU; fall back to CPU in that case.
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    return DefaultPredictor(cfg)

def predict_fn(predictor, image):
    """Run ``predictor`` on ``image``; return its instances and the RGB pixels.

    Args:
        predictor: Callable such as detectron2's ``DefaultPredictor``. It is
            called with a C-contiguous BGR ``(H, W, 3)`` array and must return
            a mapping with an ``"instances"`` entry.
        image: A PIL image in any mode (RGB, RGBA, L, P, CMYK, ...) or an
            array-like of shape ``(H, W)`` or ``(H, W, C)``.

    Returns:
        ``(instances, rgb)`` where ``rgb`` is the ``(H, W, 3)`` RGB array whose
        channel-reversed copy was fed to the predictor.

    Raises:
        ValueError: If the image data is neither 2-D nor 3-D.
    """
    # Let PIL perform real colour conversion when it can: slicing raw channels
    # misreads palette indices ("P") or CMYK data and breaks on grayscale.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    rgb = np.array(image)
    if rgb.ndim == 2:
        rgb = rgb[:, :, np.newaxis]
    if rgb.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {rgb.shape}")
    if rgb.shape[2] < 3:
        # Grayscale (optionally with alpha): replicate the first channel.
        rgb = np.repeat(rgb[:, :, :1], 3, axis=2)
    # Drop any alpha channel.
    rgb = rgb[:, :, :3]
    # DefaultPredictor takes BGR input (it re-orders internally according to
    # cfg.INPUT.FORMAT), so reverse the channels before inference.
    outputs = predictor(np.ascontiguousarray(rgb[:, :, ::-1]))
    return outputs["instances"], rgb

def visualize_predictions(image, instances, metadata_name="coco_2017_val", scale=1.2):
    """Draw detected instances on top of ``image``.

    Args:
        image: ``(H, W, 3)`` RGB array (the second value returned by
            ``predict_fn``). ``Visualizer`` requires RGB input.
        instances: Detectron2 ``Instances`` produced by the predictor; they may
            live on a CUDA device.
        metadata_name: ``MetadataCatalog`` entry supplying the class names.
        scale: Output scaling factor passed to ``Visualizer``.

    Returns:
        The annotated image as an RGB uint8 array.
    """
    visualizer = Visualizer(image, MetadataCatalog.get(metadata_name), scale=scale)
    # Visualizer converts box tensors with .numpy(), which fails on CUDA
    # tensors, so move the predictions to the CPU first.
    drawn = visualizer.draw_instance_predictions(instances.to("cpu"))
    return drawn.get_image()

if uploaded_file is not None:
    # Decode eagerly so that non-image or truncated uploads are reported to the
    # user instead of surfacing as a traceback further down.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        st.error(f"Could not read the uploaded file as an image: {err}")
        st.stop()

    st.image(image, caption="Uploaded Image", use_column_width=True)

    # A spinner disappears once detection finishes, unlike a plain st.write.
    with st.spinner("Processing..."):
        predictor = load_model()  # cached across reruns by st.cache_resource
        instances, image_array = predict_fn(predictor, image)
        result_image = visualize_predictions(image_array, instances)

    st.image(result_image, caption="Detected Objects", use_column_width=True)