# objectdetection/src/streamlit_app.py
# Streamlit demo: object detection with Detectron2 (Faster R-CNN R50-FPN).
import streamlit as st
import torch
import numpy as np
import cv2
import os
import tempfile
from PIL import Image
import traceback
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
# Configure the Streamlit page before any other st.* call.
st.set_page_config(
    page_title="Object Detection with Detectron2",
    page_icon="πŸ”",
    layout="wide",
)

# Opt-in environment dump for debugging: append ?debug to the app URL.
if "debug" in st.experimental_get_query_params():
    for label, value in (
        ("Environment variables:", dict(os.environ)),
        ("Current working directory:", os.getcwd()),
        ("Directory contents:", os.listdir()),
        ("Temp directory:", tempfile.gettempdir()),
        ("PyTorch CUDA available:", torch.cuda.is_available()),
    ):
        st.write(label, value)
# Build the Detectron2 predictor once per process; st.cache_resource reuses it
# across reruns (note: the sidebar messages therefore only render on the first,
# uncached call).
@st.cache_resource
def load_model():
    """Create a cached DefaultPredictor for COCO Faster R-CNN R50-FPN.

    Returns:
        DefaultPredictor on success, or None if configuration/weights
        loading failed (errors are reported through st.error).
    """
    try:
        config_path = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(config_path))
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_path)
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # default confidence cut-off

        # Prefer CUDA when present; fall back to CPU.
        use_gpu = torch.cuda.is_available()
        cfg.MODEL.DEVICE = "cuda" if use_gpu else "cpu"
        if use_gpu:
            st.sidebar.success("GPU is available! Using CUDA.")
        else:
            st.sidebar.info("GPU not available. Using CPU.")

        return DefaultPredictor(cfg)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.error(traceback.format_exc())
        return None
# Function for image prediction
def predict_fn(image, predictor):
    """Run the detector on a PIL image.

    Args:
        image: PIL.Image to analyze (converted to RGB internally).
        predictor: a detectron2 DefaultPredictor returned by load_model().

    Returns:
        (pred_classes, pred_boxes, scores, image_array, class_names) as numpy
        arrays / list, or (None, None, None, np.array(image), None) on error.
    """
    try:
        # Convert the PIL image to a format the model can use
        image_array = np.array(image.convert("RGB"))
        # Make predictions
        outputs = predictor(image_array)
        # Get the predicted classes and bounding boxes
        instances = outputs["instances"].to("cpu")
        pred_classes = instances.pred_classes.numpy()
        pred_boxes = instances.pred_boxes.tensor.numpy()
        scores = instances.scores.numpy()
        # BUG FIX: the original read a module-level `cfg` that does not exist
        # (cfg is local to load_model), raising NameError on every call.
        # DefaultPredictor keeps a clone of its config as predictor.cfg.
        cfg = predictor.cfg
        metadata = MetadataCatalog.get(
            cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "coco_2017_val"
        )
        class_names = metadata.thing_classes
        return pred_classes, pred_boxes, scores, image_array, class_names
    except Exception as e:
        st.error(f"Error during prediction: {e}")
        st.error(traceback.format_exc())
        return None, None, None, np.array(image), None
# Function to visualize predictions
def visualize_predictions(image_array, instances, class_names):
    """Draw detection results onto a copy of the image.

    Args:
        image_array: HxWx3 RGB numpy image the predictions refer to.
        instances: detectron2 Instances with pred_boxes/scores/pred_classes.
        class_names: unused here; labels are taken from the COCO metadata
            (kept in the signature for backward compatibility with callers).

    Returns:
        The rendered numpy image, or the unmodified input on error.
    """
    try:
        # BUG FIX: the original assigned metadata.thing_classes on the shared
        # "coco_2017_val" catalog entry. Detectron2's Metadata raises
        # AttributeError when an existing key is reassigned a different value,
        # and a successful assignment would mutate process-global state.
        # class_names already comes from the same catalog in predict_fn, so no
        # mutation is needed.
        metadata = MetadataCatalog.get("coco_2017_val")
        v = Visualizer(image_array, metadata=metadata, scale=1.2)
        # Draw the predictions on the image
        vis_output = v.draw_instance_predictions(instances.to("cpu"))
        # Return the visualized image
        return vis_output.get_image()
    except Exception as e:
        st.error(f"Error during visualization: {e}")
        st.error(traceback.format_exc())
        return image_array
# Main application
def main():
    """Render the Streamlit UI: upload an image, run detection, show results."""
    st.title("πŸ” Object Detection with Detectron2")
    st.markdown("""
    Upload an image to detect objects using Facebook AI Research's Detectron2 model.
    This demo uses the Faster R-CNN model with ResNet-50-FPN backbone.
    """)
    # Create a sidebar
    st.sidebar.title("Settings")
    confidence_threshold = st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.5, 0.05)

    # Load the model
    with st.spinner("Loading model... This might take a minute on first run."):
        predictor = load_model()
    if predictor is None:
        st.error("Failed to load model. Please check the error messages above.")
        return

    # BUG FIX: the original mutated predictor.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST
    # here, but DefaultPredictor clones its config at construction time, so that
    # assignment never reached the built model and the slider was a no-op.
    # The slider is applied instead by filtering detections by score below.

    # Upload image
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    col1, col2 = st.columns(2)
    if uploaded_file is not None:
        # Display original image
        image = Image.open(uploaded_file)
        with col1:
            st.subheader("Original Image")
            st.image(image, use_column_width=True)
        # Process the image
        with st.spinner("Detecting objects..."):
            try:
                # Make predictions
                pred_classes, pred_boxes, scores, image_array, class_names = predict_fn(image, predictor)
                if pred_classes is not None and len(pred_classes) > 0:
                    # Apply the sidebar confidence threshold (see note above).
                    keep = scores >= confidence_threshold
                    pred_classes = pred_classes[keep]
                    pred_boxes = pred_boxes[keep]
                    scores = scores[keep]
                if pred_classes is not None and len(pred_classes) > 0:
                    # Reconstruct instances for visualization
                    from detectron2.structures import Instances
                    h, w = image_array.shape[:2]
                    instances = Instances((h, w))
                    instances.pred_boxes = torch.tensor(pred_boxes)
                    instances.scores = torch.tensor(scores)
                    instances.pred_classes = torch.tensor(pred_classes)
                    # Visualize predictions
                    result_image = visualize_predictions(image_array, instances, class_names)
                    # Display results
                    with col2:
                        st.subheader("Detected Objects")
                        st.image(result_image, use_column_width=True)
                    # Display detection details as a table
                    st.subheader("Detection Results")
                    results_data = []
                    for cls, box, score in zip(pred_classes, pred_boxes, scores):
                        class_name = class_names[cls] if class_names is not None else f"Class {cls}"
                        results_data.append({
                            "Object": class_name,
                            "Confidence": f"{score:.2f}",
                            "Bounding Box": f"[{int(box[0])}, {int(box[1])}, {int(box[2])}, {int(box[3])}]"
                        })
                    if results_data:
                        st.table(results_data)
                    else:
                        st.info("No objects detected.")
                else:
                    with col2:
                        st.info("No objects detected in the image.")
            except Exception as e:
                st.error(f"Error processing image: {e}")
                st.error(traceback.format_exc())
    else:
        # Show a sample image option
        if st.button("Use a sample image"):
            # You could include a sample image in your repository
            # and load it here to demonstrate the functionality
            st.info("Sample image option selected - this would load a demo image if implemented")
    # Footer
    st.markdown("---")
    st.markdown("Built with Streamlit and Facebook AI Research's Detectron2")
# Run the app only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()