objectdetection / src /streamlit_app.py
sandbox338's picture
Update src/streamlit_app.py
3d2c66c verified
import streamlit as st
from PIL import Image
import numpy as np
import torch
import asyncio
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
# Ensure the thread executing this script has an asyncio event loop set.
# NOTE(review): presumably a workaround for a dependency that expects a
# current event loop in Streamlit's script thread — confirm which one.
def _ensure_event_loop():
    """Install a fresh event loop as current unless one is already running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises RuntimeError when no loop is running in
        # the current thread; give the thread a new one.
        asyncio.set_event_loop(asyncio.new_event_loop())


_ensure_event_loop()
# --- Page header and image upload widget ---
_ACCEPTED_EXTENSIONS = ["jpg", "jpeg", "png"]

st.title("Detectron2 Object Detection")
st.write("Upload an image to perform object detection")
# Holds None until the user uploads a file (checked by the main flow).
uploaded_file = st.file_uploader(
    "Choose an image...",
    type=_ACCEPTED_EXTENSIONS,
)
@st.cache_resource
def load_model():
    """Build a COCO-pretrained Faster R-CNN (R50-FPN, 3x) predictor.

    Decorated with ``st.cache_resource`` so Streamlit constructs the predictor
    once and reuses it across reruns instead of reloading the weights every
    time the script runs.

    Returns:
        A detectron2 ``DefaultPredictor`` with a 0.5 test-time score
        threshold, placed on CUDA when a GPU is available and on CPU
        otherwise.
    """
    config_name = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    # Discard detections whose confidence score is below 0.5.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    # detectron2's default MODEL.DEVICE is "cuda", which makes model
    # construction fail on CPU-only hosts; fall back to CPU explicitly.
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    return DefaultPredictor(cfg)
def predict_fn(predictor, image):
    """Run *predictor* on a PIL image and return its detections.

    Args:
        predictor: callable such as detectron2's ``DefaultPredictor``. It is
            called with one ``(H, W, 3)`` uint8 array in BGR channel order
            (detectron2's default ``INPUT.FORMAT``) and must return a dict
            containing an ``"instances"`` entry that supports ``.to(device)``.
        image: PIL image of any mode (RGB, RGBA, L, P, LA, ...).

    Returns:
        ``(instances, image_array)``: the predicted instances moved to the
        CPU, and the ``(H, W, 3)`` uint8 RGB pixel array of *image*.
    """
    # convert("RGB") normalises grayscale (L), palette (P) and alpha
    # (RGBA/LA) images to exactly three channels; plain channel slicing
    # fails on 2-D arrays from single-channel images.
    rgb = np.array(image.convert("RGB"))
    # PIL yields RGB, but the detectron2 COCO models expect BGR input.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    outputs = predictor(bgr)
    # Move predictions off the GPU so they can be drawn/converted to NumPy.
    instances = outputs["instances"].to("cpu")
    return instances, rgb
def visualize_predictions(image, instances):
    """Draw predicted boxes and COCO class labels over *image*.

    Args:
        image: ``(H, W, 3)`` uint8 array. Its pixels are passed through to
            the output in the same channel order; ``Visualizer`` treats the
            array as RGB (in this app it comes from a PIL image, i.e. RGB).
        instances: detectron2 ``Instances`` from the predictor, on any device.

    Returns:
        The annotated uint8 image produced by ``Visualizer.get_image()``,
        rendered at a 1.2x scale factor.
    """
    # Visualizer converts prediction tensors to NumPy, which fails for CUDA
    # tensors, so make sure everything is on the CPU first (no-op if it is).
    cpu_instances = instances.to("cpu")
    # "coco_2017_val" metadata supplies the COCO class names for the labels.
    metadata = MetadataCatalog.get("coco_2017_val")
    visualizer = Visualizer(image, metadata, scale=1.2)
    drawn = visualizer.draw_instance_predictions(cpu_instances)
    return drawn.get_image()
# Main flow: executed on each script run once the user has uploaded a file.
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Image.open() is lazy; load() forces a full decode so a corrupt or
        # non-image upload is reported here instead of during inference.
        image.load()
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers PIL's UnidentifiedImageError and truncated files;
        # DecompressionBombError is raised for absurdly large images.
        st.error(f"Could not read the uploaded file as an image: {err}")
    else:
        st.image(image, caption="Uploaded Image", use_column_width=True)
        # A spinner (unlike static text) disappears once detection finishes.
        with st.spinner("Processing..."):
            predictor = load_model()
            instances, image_array = predict_fn(predictor, image)
            result_image = visualize_predictions(image_array, instances)
        st.image(result_image, caption="Detected Objects", use_column_width=True)