sandbox338 commited on
Commit
eaf6fcc
·
verified ·
1 Parent(s): 73c1074

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +91 -150
src/streamlit_app.py CHANGED
@@ -1,195 +1,136 @@
1
  import streamlit as st
2
- import torch
3
  import numpy as np
4
  import cv2
5
  import os
6
- import tempfile
7
  from PIL import Image
8
  import traceback
9
- from detectron2.engine import DefaultPredictor
10
- from detectron2.config import get_cfg
11
- from detectron2 import model_zoo
12
- from detectron2.utils.visualizer import Visualizer
13
- from detectron2.data import MetadataCatalog
14
 
15
- # Setup page config
16
  st.set_page_config(
17
- page_title="Object Detection with Detectron2",
18
  page_icon="🔍",
19
  layout="wide"
20
  )
21
 
22
- # Print environment info for debugging
23
  if "debug" in st.experimental_get_query_params():
 
24
  st.write("Environment variables:", dict(os.environ))
25
  st.write("Current working directory:", os.getcwd())
26
  st.write("Directory contents:", os.listdir())
27
- st.write("Temp directory:", tempfile.gettempdir())
28
- st.write("PyTorch CUDA available:", torch.cuda.is_available())
29
 
30
- # Load the Detectron2 model
31
- @st.cache_resource
32
- def load_model():
33
- try:
34
- # Set the configuration for the model
35
- cfg = get_cfg()
36
- cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
37
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Set threshold for detection confidence
38
- cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
39
-
40
- # Use GPU if available, otherwise use CPU
41
- if torch.cuda.is_available():
42
- st.sidebar.success("GPU is available! Using CUDA.")
43
- cfg.MODEL.DEVICE = "cuda"
44
- else:
45
- st.sidebar.info("GPU not available. Using CPU.")
46
- cfg.MODEL.DEVICE = "cpu"
47
-
48
- # Initialize the predictor
49
- predictor = DefaultPredictor(cfg)
50
- return predictor
51
- except Exception as e:
52
- st.error(f"Error loading model: {e}")
53
- st.error(traceback.format_exc())
54
- return None
55
 
56
- # Function for image prediction
57
- def predict_fn(image, predictor):
58
  try:
59
- # Convert the PIL image to a format the model can use
60
- image_array = np.array(image.convert("RGB"))
61
-
62
- # Make predictions
63
- outputs = predictor(image_array)
64
-
65
- # Get the predicted classes and bounding boxes
66
- instances = outputs["instances"].to("cpu")
67
- pred_classes = instances.pred_classes.numpy()
68
- pred_boxes = instances.pred_boxes.tensor.numpy()
69
- scores = instances.scores.numpy()
70
-
71
- # Get class names from metadata
72
- metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "coco_2017_val")
73
- class_names = metadata.thing_classes
74
 
75
- return pred_classes, pred_boxes, scores, image_array, class_names
76
  except Exception as e:
77
- st.error(f"Error during prediction: {e}")
78
  st.error(traceback.format_exc())
79
- return None, None, None, np.array(image), None
80
 
81
- # Function to visualize predictions
82
- def visualize_predictions(image_array, instances, class_names):
 
83
  try:
84
- # Create a visualizer object
85
- metadata = MetadataCatalog.get("coco_2017_val")
86
- if class_names is not None:
87
- metadata.thing_classes = class_names
88
-
89
- v = Visualizer(image_array, metadata=metadata, scale=1.2)
90
 
91
- # Draw the predictions on the image
92
- vis_output = v.draw_instance_predictions(instances.to("cpu"))
93
 
94
- # Return the visualized image
95
- return vis_output.get_image()
 
96
  except Exception as e:
97
- st.error(f"Error during visualization: {e}")
98
  st.error(traceback.format_exc())
99
- return image_array
100
 
101
- # Main application
102
  def main():
103
- st.title("🔍 Object Detection with Detectron2")
104
  st.markdown("""
105
- Upload an image to detect objects using Facebook AI Research's Detectron2 model.
106
- This demo uses the Faster R-CNN model with ResNet-50-FPN backbone.
107
  """)
108
 
109
- # Create a sidebar
110
- st.sidebar.title("Settings")
111
- confidence_threshold = st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.5, 0.05)
112
-
113
- # Load the model
114
- with st.spinner("Loading model... This might take a minute on first run."):
115
- predictor = load_model()
116
-
117
  if predictor is None:
118
- st.error("Failed to load model. Please check the error messages above.")
119
  return
120
 
121
- # Update model confidence threshold
122
- predictor.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
123
-
124
- # Upload image
125
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
126
-
127
- col1, col2 = st.columns(2)
128
 
129
  if uploaded_file is not None:
130
- # Display original image
131
- image = Image.open(uploaded_file)
132
- with col1:
133
- st.subheader("Original Image")
134
- st.image(image, use_column_width=True)
135
-
136
- # Process the image
137
- with st.spinner("Detecting objects..."):
138
- try:
139
- # Make predictions
140
- pred_classes, pred_boxes, scores, image_array, class_names = predict_fn(image, predictor)
141
 
142
- if pred_classes is not None and len(pred_classes) > 0:
143
- # Reconstruct instances for visualization
144
- from detectron2.structures import Instances
145
- h, w = image_array.shape[:2]
146
- instances = Instances((h, w))
147
- instances.pred_boxes = torch.tensor(pred_boxes)
148
- instances.scores = torch.tensor(scores)
149
- instances.pred_classes = torch.tensor(pred_classes)
150
-
151
- # Visualize predictions
152
- result_image = visualize_predictions(image_array, instances, class_names)
153
-
154
- # Display results
155
- with col2:
156
- st.subheader("Detected Objects")
157
- st.image(result_image, use_column_width=True)
158
-
159
- # Display detection details
160
- st.subheader("Detection Results")
161
 
162
- # Create a table for results
163
- results_data = []
164
- for i, (cls, box, score) in enumerate(zip(pred_classes, pred_boxes, scores)):
165
- class_name = class_names[cls] if class_names is not None else f"Class {cls}"
166
- results_data.append({
167
- "Object": class_name,
168
- "Confidence": f"{score:.2f}",
169
- "Bounding Box": f"[{int(box[0])}, {int(box[1])}, {int(box[2])}, {int(box[3])}]"
170
- })
171
 
172
- # Display results as table
173
- if results_data:
174
- st.table(results_data)
175
- else:
176
- st.info("No objects detected.")
 
 
 
177
  else:
178
- with col2:
179
- st.info("No objects detected in the image.")
180
- except Exception as e:
181
- st.error(f"Error processing image: {e}")
182
- st.error(traceback.format_exc())
183
- else:
184
- # Show a sample image option
185
- if st.button("Use a sample image"):
186
- # You could include a sample image in your repository
187
- # and load it here to demonstrate the functionality
188
- st.info("Sample image option selected - this would load a demo image if implemented")
189
-
190
- # Footer
191
- st.markdown("---")
192
- st.markdown("Built with Streamlit and Facebook AI Research's Detectron2")
193
 
194
  if __name__ == "__main__":
195
  main()
 
1
  import streamlit as st
 
2
  import numpy as np
3
  import cv2
4
  import os
5
+ import sys
6
  from PIL import Image
7
  import traceback
 
 
 
 
 
8
 
9
+ # Configure the app
10
  st.set_page_config(
11
+ page_title="Object Detection App",
12
  page_icon="🔍",
13
  layout="wide"
14
  )
15
 
16
+ # Display environment info if needed for debugging
17
  if "debug" in st.experimental_get_query_params():
18
+ st.write("Python version:", sys.version)
19
  st.write("Environment variables:", dict(os.environ))
20
  st.write("Current working directory:", os.getcwd())
21
  st.write("Directory contents:", os.listdir())
 
 
22
 
23
+ # Create a sidebar
24
+ st.sidebar.title("Object Detection App")
25
+ st.sidebar.markdown("""
26
+ This app uses Detectron2 to detect objects in images.
27
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ # Display loading message
30
+ with st.spinner("Loading dependencies..."):
31
  try:
32
+ import torch
33
+ from detectron2.engine import DefaultPredictor
34
+ from detectron2.config import get_cfg
35
+ from detectron2 import model_zoo
36
+ from detectron2.utils.visualizer import Visualizer
37
+ from detectron2.data import MetadataCatalog
 
 
 
 
 
 
 
 
 
38
 
39
+ st.sidebar.success("✅ Dependencies loaded successfully!")
40
  except Exception as e:
41
+ st.error(f"Failed to load dependencies: {e}")
42
  st.error(traceback.format_exc())
43
+ st.stop()
44
 
45
+ # Load the model
46
+ @st.cache_resource
47
+ def load_model():
48
  try:
49
+ # Configure the model
50
+ cfg = get_cfg()
51
+ cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
52
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
53
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
 
54
 
55
+ # Use CPU for inference (more reliable in container environment)
56
+ cfg.MODEL.DEVICE = "cpu"
57
 
58
+ # Initialize predictor
59
+ predictor = DefaultPredictor(cfg)
60
+ return predictor, cfg
61
  except Exception as e:
62
+ st.error(f"Error loading model: {e}")
63
  st.error(traceback.format_exc())
64
+ return None, None
65
 
66
+ # Main function
67
  def main():
68
+ st.title("🔍 Object Detection")
69
  st.markdown("""
70
+ Upload an image to detect objects using Detectron2's Faster R-CNN model.
 
71
  """)
72
 
73
+ # Load model
74
+ with st.spinner("Loading model..."):
75
+ predictor, cfg = load_model()
76
+
 
 
 
 
77
  if predictor is None:
78
+ st.error("Failed to load the model. Check the error messages.")
79
  return
80
 
81
+ # File uploader
82
+ uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
 
 
 
 
 
83
 
84
  if uploaded_file is not None:
85
+ try:
86
+ # Read and display the image
87
+ image = Image.open(uploaded_file)
88
+ st.image(image, caption="Uploaded Image", use_column_width=True)
89
+
90
+ # Convert to numpy array
91
+ image_array = np.array(image.convert("RGB"))
92
+
93
+ # Perform inference
94
+ with st.spinner("Detecting objects..."):
95
+ outputs = predictor(image_array)
96
 
97
+ # Get instances
98
+ instances = outputs["instances"].to("cpu")
99
+
100
+ # Create visualizer
101
+ v = Visualizer(image_array,
102
+ metadata=MetadataCatalog.get(cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "coco_2017_val"),
103
+ scale=1.2)
104
+
105
+ # Draw predictions
106
+ result = v.draw_instance_predictions(instances)
107
+ result_image = result.get_image()
108
+
109
+ # Display result
110
+ st.image(result_image, caption="Detection Result", use_column_width=True)
111
+
112
+ # Show detection information
113
+ if len(instances) > 0:
114
+ st.subheader(f"Detected {len(instances)} objects")
 
115
 
116
+ # Get class names
117
+ metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "coco_2017_val")
118
+ class_names = metadata.thing_classes
 
 
 
 
 
 
119
 
120
+ # Show detections
121
+ for i in range(len(instances)):
122
+ score = instances.scores[i].item()
123
+ class_id = instances.pred_classes[i].item()
124
+ class_name = class_names[class_id]
125
+ box = instances.pred_boxes[i].tensor.numpy()[0]
126
+
127
+ st.write(f"**{class_name}**: {score:.2f} confidence")
128
  else:
129
+ st.info("No objects detected in this image.")
130
+
131
+ except Exception as e:
132
+ st.error(f"Error processing image: {e}")
133
+ st.error(traceback.format_exc())
 
 
 
 
 
 
 
 
 
 
134
 
135
  if __name__ == "__main__":
136
  main()