sandbox338 commited on
Commit
d669457
·
verified ·
1 Parent(s): b0c3b32

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +171 -60
src/streamlit_app.py CHANGED
@@ -2,83 +2,194 @@ import streamlit as st
2
  import torch
3
  import numpy as np
4
  import cv2
 
 
5
  from PIL import Image
 
6
  from detectron2.engine import DefaultPredictor
7
  from detectron2.config import get_cfg
8
  from detectron2 import model_zoo
9
  from detectron2.utils.visualizer import Visualizer
10
  from detectron2.data import MetadataCatalog
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Load the Detectron2 model
13
  @st.cache_resource
14
  def load_model():
15
- # Set the configuration for the model (COCO pre-trained model for detection)
16
- cfg = get_cfg()
17
- cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
18
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Set the threshold for prediction
19
- cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml") # Model weights from the model zoo
20
- cfg.MODEL.DEVICE = "cpu" # Change to "cuda" if using GPU
21
-
22
- # Initialize the predictor
23
- predictor = DefaultPredictor(cfg)
24
- return predictor
25
-
26
- # Load the model
27
- predictor = load_model()
 
 
 
 
 
 
 
 
 
28
 
29
  # Function for image prediction
30
- def predict_fn(image):
31
- # Convert the PIL image to a format the model can use
32
- # Convert to RGB (Detectron2 uses BGR format internally)
33
- image = np.array(image.convert("RGB"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Make predictions
36
- outputs = predictor(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Get the predicted classes and bounding boxes
39
- instances = outputs["instances"].to("cpu")
40
- pred_classes = instances.pred_classes.numpy()
41
- pred_boxes = instances.pred_boxes.tensor.numpy()
42
-
43
- return pred_classes, pred_boxes, image
44
-
45
- # Function to display image with bounding boxes
46
- def visualize_predictions(image, pred_classes, pred_boxes):
47
- # Create a visualizer object
48
- v = Visualizer(image[:, :, ::-1], MetadataCatalog.get("coco_2017_val"), scale=1.2)
49
- v = v.draw_instance_predictions(pred_classes)
50
 
51
- # Draw the bounding boxes on the image
52
- for box in pred_boxes:
53
- start_point = tuple(map(int, box[:2]))
54
- end_point = tuple(map(int, box[2:]))
55
- color = (0, 255, 0) # Green color
56
- thickness = 2
57
- image = cv2.rectangle(image, start_point, end_point, color, thickness)
58
-
59
- return image
60
-
61
- # Streamlit UI
62
- st.title("Object Detection with Detectron2")
63
-
64
- # Upload image
65
- uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
66
-
67
- if uploaded_image is not None:
68
- # Open the uploaded image
69
- image = Image.open(uploaded_image)
70
 
71
- # Make predictions
72
- pred_classes, pred_boxes, image_array = predict_fn(image)
73
 
74
- # Visualize predictions on the image
75
- image_with_boxes = visualize_predictions(image_array, pred_classes, pred_boxes)
76
 
77
- # Convert the image back to RGB format for display in Streamlit
78
- image_with_boxes = Image.fromarray(image_with_boxes[:, :, ::-1]) # Convert BGR to RGB
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- # Display the processed image with bounding boxes
81
- st.image(image_with_boxes, caption="Processed Image", use_column_width=True)
 
82
 
83
- # Display the classes detected
84
- st.write("Predicted Classes:", pred_classes)
 
2
  import torch
3
  import numpy as np
4
  import cv2
5
+ import os
6
+ import tempfile
7
  from PIL import Image
8
+ import traceback
9
  from detectron2.engine import DefaultPredictor
10
  from detectron2.config import get_cfg
11
  from detectron2 import model_zoo
12
  from detectron2.utils.visualizer import Visualizer
13
  from detectron2.data import MetadataCatalog
14
 
15
+ # Setup page config
16
+ st.set_page_config(
17
+ page_title="Object Detection with Detectron2",
18
+ page_icon="🔍",
19
+ layout="wide"
20
+ )
21
+
22
+ # Print environment info for debugging
23
+ if "debug" in st.experimental_get_query_params():
24
+ st.write("Environment variables:", dict(os.environ))
25
+ st.write("Current working directory:", os.getcwd())
26
+ st.write("Directory contents:", os.listdir())
27
+ st.write("Temp directory:", tempfile.gettempdir())
28
+ st.write("PyTorch CUDA available:", torch.cuda.is_available())
29
+
30
  # Load the Detectron2 model
31
  @st.cache_resource
32
  def load_model():
33
+ try:
34
+ # Set the configuration for the model
35
+ cfg = get_cfg()
36
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
37
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Set threshold for detection confidence
38
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
39
+
40
+ # Use GPU if available, otherwise use CPU
41
+ if torch.cuda.is_available():
42
+ st.sidebar.success("GPU is available! Using CUDA.")
43
+ cfg.MODEL.DEVICE = "cuda"
44
+ else:
45
+ st.sidebar.info("GPU not available. Using CPU.")
46
+ cfg.MODEL.DEVICE = "cpu"
47
+
48
+ # Initialize the predictor
49
+ predictor = DefaultPredictor(cfg)
50
+ return predictor
51
+ except Exception as e:
52
+ st.error(f"Error loading model: {e}")
53
+ st.error(traceback.format_exc())
54
+ return None
55
 
56
  # Function for image prediction
57
+ def predict_fn(image, predictor):
58
+ try:
59
+ # Convert the PIL image to a format the model can use
60
+ image_array = np.array(image.convert("RGB"))
61
+
62
+ # Make predictions
63
+ outputs = predictor(image_array)
64
+
65
+ # Get the predicted classes and bounding boxes
66
+ instances = outputs["instances"].to("cpu")
67
+ pred_classes = instances.pred_classes.numpy()
68
+ pred_boxes = instances.pred_boxes.tensor.numpy()
69
+ scores = instances.scores.numpy()
70
+
71
+ # Get class names from metadata
72
+ metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "coco_2017_val")
73
+ class_names = metadata.thing_classes
74
+
75
+ return pred_classes, pred_boxes, scores, image_array, class_names
76
+ except Exception as e:
77
+ st.error(f"Error during prediction: {e}")
78
+ st.error(traceback.format_exc())
79
+ return None, None, None, np.array(image), None
80
 
81
+ # Function to visualize predictions
82
+ def visualize_predictions(image_array, instances, class_names):
83
+ try:
84
+ # Create a visualizer object
85
+ metadata = MetadataCatalog.get("coco_2017_val")
86
+ if class_names is not None:
87
+ metadata.thing_classes = class_names
88
+
89
+ v = Visualizer(image_array, metadata=metadata, scale=1.2)
90
+
91
+ # Draw the predictions on the image
92
+ vis_output = v.draw_instance_predictions(instances.to("cpu"))
93
+
94
+ # Return the visualized image
95
+ return vis_output.get_image()
96
+ except Exception as e:
97
+ st.error(f"Error during visualization: {e}")
98
+ st.error(traceback.format_exc())
99
+ return image_array
100
 
101
+ # Main application
102
+ def main():
103
+ st.title("🔍 Object Detection with Detectron2")
104
+ st.markdown("""
105
+ Upload an image to detect objects using Facebook AI Research's Detectron2 model.
106
+ This demo uses the Faster R-CNN model with ResNet-50-FPN backbone.
107
+ """)
 
 
 
 
 
108
 
109
+ # Create a sidebar
110
+ st.sidebar.title("Settings")
111
+ confidence_threshold = st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.5, 0.05)
112
+
113
+ # Load the model
114
+ with st.spinner("Loading model... This might take a minute on first run."):
115
+ predictor = load_model()
116
+
117
+ if predictor is None:
118
+ st.error("Failed to load model. Please check the error messages above.")
119
+ return
120
+
121
+ # Update model confidence threshold
122
+ predictor.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
 
 
 
 
 
123
 
124
+ # Upload image
125
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
126
 
127
+ col1, col2 = st.columns(2)
 
128
 
129
+ if uploaded_file is not None:
130
+ # Display original image
131
+ image = Image.open(uploaded_file)
132
+ with col1:
133
+ st.subheader("Original Image")
134
+ st.image(image, use_column_width=True)
135
+
136
+ # Process the image
137
+ with st.spinner("Detecting objects..."):
138
+ try:
139
+ # Make predictions
140
+ pred_classes, pred_boxes, scores, image_array, class_names = predict_fn(image, predictor)
141
+
142
+ if pred_classes is not None and len(pred_classes) > 0:
143
+ # Reconstruct instances for visualization
144
+ from detectron2.structures import Instances
145
+ h, w = image_array.shape[:2]
146
+ instances = Instances((h, w))
147
+ instances.pred_boxes = torch.tensor(pred_boxes)
148
+ instances.scores = torch.tensor(scores)
149
+ instances.pred_classes = torch.tensor(pred_classes)
150
+
151
+ # Visualize predictions
152
+ result_image = visualize_predictions(image_array, instances, class_names)
153
+
154
+ # Display results
155
+ with col2:
156
+ st.subheader("Detected Objects")
157
+ st.image(result_image, use_column_width=True)
158
+
159
+ # Display detection details
160
+ st.subheader("Detection Results")
161
+
162
+ # Create a table for results
163
+ results_data = []
164
+ for i, (cls, box, score) in enumerate(zip(pred_classes, pred_boxes, scores)):
165
+ class_name = class_names[cls] if class_names is not None else f"Class {cls}"
166
+ results_data.append({
167
+ "Object": class_name,
168
+ "Confidence": f"{score:.2f}",
169
+ "Bounding Box": f"[{int(box[0])}, {int(box[1])}, {int(box[2])}, {int(box[3])}]"
170
+ })
171
+
172
+ # Display results as table
173
+ if results_data:
174
+ st.table(results_data)
175
+ else:
176
+ st.info("No objects detected.")
177
+ else:
178
+ with col2:
179
+ st.info("No objects detected in the image.")
180
+ except Exception as e:
181
+ st.error(f"Error processing image: {e}")
182
+ st.error(traceback.format_exc())
183
+ else:
184
+ # Show a sample image option
185
+ if st.button("Use a sample image"):
186
+ # You could include a sample image in your repository
187
+ # and load it here to demonstrate the functionality
188
+ st.info("Sample image option selected - this would load a demo image if implemented")
189
 
190
+ # Footer
191
+ st.markdown("---")
192
+ st.markdown("Built with Streamlit and Facebook AI Research's Detectron2")
193
 
194
+ if __name__ == "__main__":
195
+ main()