sandbox338 committed on
Commit
3d2c66c
·
verified ·
1 Parent(s): 7d0b950

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +47 -137
src/streamlit_app.py CHANGED
@@ -1,143 +1,53 @@
1
  import streamlit as st
2
- import numpy as np
3
  from PIL import Image
4
- import cv2
5
- import os
6
- import sys
7
-
8
- # Set page configuration
9
- st.set_page_config(
10
- page_title="Object Detection",
11
- page_icon="πŸ”",
12
- layout="wide"
13
- )
14
-
15
- # Display app header
16
- st.title("πŸ” Object Detection with Detectron2")
17
- st.markdown("""
18
- Upload an image to detect objects using Facebook AI Research's Detectron2 model.
19
- """)
20
-
21
- # Setup sidebar
22
- with st.sidebar:
23
- st.header("About")
24
- st.markdown("This app uses Detectron2 to detect objects in images.")
25
-
26
- # Show environment info for debugging
27
- if st.checkbox("Show Environment Info", False):
28
- st.write("Python version:", sys.version)
29
- st.write("Working directory:", os.getcwd())
30
- st.write("Directory contents:", os.listdir())
31
-
32
- # Import Detectron2 with error handling
33
  try:
34
- import torch
35
- from detectron2 import model_zoo
36
- from detectron2.engine import DefaultPredictor
37
- from detectron2.config import get_cfg
38
- from detectron2.utils.visualizer import Visualizer
39
- from detectron2.data import MetadataCatalog
40
- except Exception as e:
41
- st.error(f"Failed to import required libraries: {str(e)}")
42
- st.error("Please check that all dependencies are correctly installed.")
43
- st.stop()
44
-
45
- # Load the model
46
- @st.cache_resource
47
- def load_detectron_model():
48
- """Load the Detectron2 model with caching"""
49
- try:
50
- # Set up configuration
51
- cfg = get_cfg()
52
- cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
53
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Set threshold for object detection
54
- cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
55
- cfg.MODEL.DEVICE = "cpu" # Use CPU
56
-
57
- # Create predictor
58
- predictor = DefaultPredictor(cfg)
59
- return predictor, cfg
60
- except Exception as e:
61
- st.error(f"Error loading model: {str(e)}")
62
- return None, None
63
 
64
- # Process the image
65
- def process_image(image, predictor):
66
- """Run object detection on the image"""
67
- # Convert PIL Image to OpenCV format (RGB to BGR)
68
- img = np.array(image.convert("RGB"))
69
-
70
- # Run inference
71
- outputs = predictor(img)
72
-
73
- # Get the instances
74
- instances = outputs["instances"].to("cpu")
75
-
76
- return img, instances
77
 
78
- # Visualize the results
79
- def visualize_results(img, instances, metadata):
80
- """Create visualization of detection results"""
81
- v = Visualizer(img, metadata=metadata, scale=1.2)
82
- out = v.draw_instance_predictions(instances)
83
- return out.get_image()
84
 
85
- # Main app logic
86
- def main():
87
- # Load model with a spinner
88
- with st.spinner("Loading model... (this may take a moment on first run)"):
89
- predictor, cfg = load_detectron_model()
90
-
91
- if predictor is None:
92
- st.error("Failed to load the detection model.")
93
- return
94
-
95
- # Get metadata
96
- metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "coco_2017_val")
97
-
98
- # Create file uploader
99
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
100
-
101
- # Process the uploaded image
102
- if uploaded_file is not None:
103
- try:
104
- # Display the original image
105
- image = Image.open(uploaded_file)
106
- st.image(image, caption="Uploaded Image", width=400)
107
-
108
- # Process the image
109
- with st.spinner("Detecting objects..."):
110
- img, instances = process_image(image, predictor)
111
-
112
- # Check if any objects were detected
113
- if len(instances) > 0:
114
- # Visualize the results
115
- result_img = visualize_results(img, instances, metadata)
116
-
117
- # Display the result
118
- st.image(result_img, caption="Detection Result", width=800)
119
-
120
- # Show detection details
121
- st.subheader(f"Detected {len(instances)} objects")
122
-
123
- class_names = metadata.get("thing_classes", None)
124
-
125
- # Display each detection
126
- for i in range(len(instances)):
127
- score = instances.scores[i].item()
128
- class_id = instances.pred_classes[i].item()
129
-
130
- if class_names:
131
- label = class_names[class_id]
132
- else:
133
- label = f"Class {class_id}"
134
-
135
- st.write(f"**{label}** (Confidence: {score:.2f})")
136
- else:
137
- st.info("No objects detected in this image.")
138
-
139
- except Exception as e:
140
- st.error(f"Error processing image: {str(e)}")
141
-
142
- if __name__ == "__main__":
143
- main()
 
1
  import streamlit as st
 
2
  from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ import asyncio
6
+ from detectron2.engine import DefaultPredictor
7
+ from detectron2.config import get_cfg
8
+ from detectron2 import model_zoo
9
+ from detectron2.utils.visualizer import Visualizer
10
+ from detectron2.data import MetadataCatalog
11
+
12
+ # Fix event loop issue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
# Streamlit runs the script in a worker thread that may have no asyncio event
# loop; some model-loading code expects one, so register a loop if missing.
try:
    asyncio.get_running_loop()
except RuntimeError:
    # No loop is running in this thread -- create one and make it current.
    fresh_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(fresh_loop)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
# --- Page header and input widget -------------------------------------------
st.title("Detectron2 Object Detection")
st.write("Upload an image to perform object detection")

# Returns None until the user uploads a file of an accepted type.
_accepted_types = ["jpg", "jpeg", "png"]
uploaded_file = st.file_uploader("Choose an image...", type=_accepted_types)
 
 
 
 
 
23
 
24
@st.cache_resource
def load_model():
    """Build and cache a Detectron2 Faster R-CNN (R50-FPN, 3x) predictor.

    Decorated with st.cache_resource so the model is constructed once per
    process instead of on every Streamlit script rerun.

    Returns:
        DefaultPredictor: predictor callable on an HxWx3 uint8 image array.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
    # Keep only detections with confidence >= 0.5.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
    # Detectron2 defaults to CUDA; fall back to CPU so predictor construction
    # does not crash on GPU-less hosts (the previous revision forced "cpu").
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    predictor = DefaultPredictor(cfg)
    return predictor
32
+
33
def predict_fn(predictor, image):
    """Run object detection on a PIL image.

    Args:
        predictor: callable taking an HxWx3 uint8 ndarray and returning a
            dict with an "instances" entry (e.g. a Detectron2 DefaultPredictor).
        image: PIL.Image in any mode; it is normalized to 3-channel RGB.

    Returns:
        tuple: (detected instances, the HxWx3 array fed to the predictor)
    """
    # Normalize to RGB first: np.array() of a grayscale/palette image is 2-D,
    # so the previous [:, :, :3] slice raised IndexError on such uploads.
    image_array = np.array(image.convert("RGB"))
    outputs = predictor(image_array)
    return outputs["instances"], image_array
37
+
38
def visualize_predictions(image, instances):
    """Render detected instances on top of the image.

    The array's channel order is reversed before drawing and the rendered
    result is reversed back, exactly as in the original implementation.

    Args:
        image: HxWx3 image array.
        instances: detection instances to draw.

    Returns:
        HxWx3 array with boxes/labels drawn, channels restored to input order.
    """
    flipped = image[:, :, ::-1]
    metadata = MetadataCatalog.get("coco_2017_val")
    drawer = Visualizer(flipped, metadata, scale=1.2)
    rendered = drawer.draw_instance_predictions(instances).get_image()
    return rendered[:, :, ::-1]
43
+
44
# Main flow: show the upload, run detection, show the annotated result.
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    # use_container_width replaces the deprecated use_column_width kwarg;
    # it is available in every Streamlit version that has st.cache_resource.
    st.image(image, caption="Uploaded Image", use_container_width=True)
    st.write("Processing...")

    predictor = load_model()
    instances, image_array = predict_fn(predictor, image)
    result_image = visualize_predictions(image_array, instances)

    st.image(result_image, caption="Detected Objects", use_container_width=True)