sandbox338 committed on
Commit
e92a879
·
verified ·
1 Parent(s): 52b9e23

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +90 -83
src/streamlit_app.py CHANGED
@@ -1,136 +1,143 @@
1
# --- Imports ---------------------------------------------------------------
import streamlit as st
import numpy as np
import cv2
import os
import sys
from PIL import Image
import traceback

# Configure the app (must run before any other Streamlit call)
st.set_page_config(
    page_title="Object Detection App",
    page_icon="🔍",
    layout="wide"
)

# Display environment info if needed for debugging.
# Triggered by adding ?debug to the URL query string.
# NOTE(review): dumping os.environ can expose secrets (tokens, keys) in the UI —
# confirm this is acceptable for the deployment environment.
if "debug" in st.experimental_get_query_params():
    st.write("Python version:", sys.version)
    st.write("Environment variables:", dict(os.environ))
    st.write("Current working directory:", os.getcwd())
    st.write("Directory contents:", os.listdir())

# Create a sidebar with a short app description
st.sidebar.title("Object Detection App")
st.sidebar.markdown("""
This app uses Detectron2 to detect objects in images.
""")

# Import the heavy ML dependencies lazily, with a spinner and a hard stop on
# failure so the rest of the script never runs with missing libraries.
with st.spinner("Loading dependencies..."):
    try:
        import torch
        from detectron2.engine import DefaultPredictor
        from detectron2.config import get_cfg
        from detectron2 import model_zoo
        from detectron2.utils.visualizer import Visualizer
        from detectron2.data import MetadataCatalog

        st.sidebar.success("✅ Dependencies loaded successfully!")
    except Exception as e:
        # Show the error plus full traceback in the UI, then halt the script.
        st.error(f"Failed to load dependencies: {e}")
        st.error(traceback.format_exc())
        st.stop()
 
 
 
 
 
 
 
 
44
 
45
# Load the model (cached so the weights are only downloaded/initialized once
# per Streamlit server process, not on every rerun)
@st.cache_resource
def load_model():
    """Build a Detectron2 Faster R-CNN predictor on CPU.

    Returns:
        (predictor, cfg): the initialized ``DefaultPredictor`` and its config,
        or ``(None, None)`` if anything fails (errors are shown in the UI).
    """
    try:
        # Configure the model: base config first, then overrides.
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # drop detections below 50% confidence
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")

        # Use CPU for inference (more reliable in container environment)
        cfg.MODEL.DEVICE = "cpu"

        # Initialize predictor (downloads weights on first call)
        predictor = DefaultPredictor(cfg)
        return predictor, cfg
    except Exception as e:
        # Surface the failure in the UI; caller checks for None.
        st.error(f"Error loading model: {e}")
        st.error(traceback.format_exc())
        return None, None
65
 
66
# Main function
def main():
    """Render the page: load the model, accept an image upload, run detection,
    and display the annotated result plus per-object details."""
    st.title("🔍 Object Detection")
    st.markdown("""
    Upload an image to detect objects using Detectron2's Faster R-CNN model.
    """)

    # Load model (cached; spinner only shows work on the first run)
    with st.spinner("Loading model..."):
        predictor, cfg = load_model()

    # load_model() already rendered the error details; just bail out.
    if predictor is None:
        st.error("Failed to load the model. Check the error messages.")
        return

    # File uploader
    uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        try:
            # Read and display the image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)

            # Convert to numpy array (forced to 3-channel RGB)
            # NOTE(review): DefaultPredictor's default input format is BGR —
            # feeding RGB here likely swaps channels; verify against cfg.INPUT.FORMAT.
            image_array = np.array(image.convert("RGB"))

            # Perform inference
            with st.spinner("Detecting objects..."):
                outputs = predictor(image_array)

            # Get instances (move to CPU for visualization / numpy access)
            instances = outputs["instances"].to("cpu")

            # Create visualizer; fall back to COCO val metadata when the config
            # lists no training datasets (the model-zoo config case).
            v = Visualizer(image_array,
                metadata=MetadataCatalog.get(cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "coco_2017_val"),
                scale=1.2)

            # Draw predictions
            result = v.draw_instance_predictions(instances)
            result_image = result.get_image()

            # Display result
            st.image(result_image, caption="Detection Result", use_column_width=True)

            # Show detection information
            if len(instances) > 0:
                st.subheader(f"Detected {len(instances)} objects")

                # Get class names (same metadata lookup as for the Visualizer)
                metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "coco_2017_val")
                class_names = metadata.thing_classes

                # Show detections
                for i in range(len(instances)):
                    score = instances.scores[i].item()
                    class_id = instances.pred_classes[i].item()
                    class_name = class_names[class_id]
                    # NOTE(review): `box` is computed but never used below.
                    box = instances.pred_boxes[i].tensor.numpy()[0]

                    st.write(f"**{class_name}**: {score:.2f} confidence")
            else:
                st.info("No objects detected in this image.")

        except Exception as e:
            st.error(f"Error processing image: {e}")
            st.error(traceback.format_exc())

if __name__ == "__main__":
    main()
 
1
# --- Imports ---------------------------------------------------------------
import streamlit as st
import numpy as np
from PIL import Image
import cv2
import os
import sys

# Set page configuration (must be the first Streamlit call)
st.set_page_config(
    page_title="Object Detection",
    page_icon="🔍",
    layout="wide"
)

# Display app header
st.title("🔍 Object Detection with Detectron2")
st.markdown("""
Upload an image to detect objects using Facebook AI Research's Detectron2 model.
""")

# Setup sidebar: app description plus an opt-in environment dump for debugging
with st.sidebar:
    st.header("About")
    st.markdown("This app uses Detectron2 to detect objects in images.")

    # Show environment info for debugging (off by default)
    if st.checkbox("Show Environment Info", False):
        st.write("Python version:", sys.version)
        st.write("Working directory:", os.getcwd())
        st.write("Directory contents:", os.listdir())

# Import Detectron2 with error handling: show a friendly message and halt the
# script instead of crashing with an ImportError traceback.
try:
    import torch
    from detectron2 import model_zoo
    from detectron2.engine import DefaultPredictor
    from detectron2.config import get_cfg
    from detectron2.utils.visualizer import Visualizer
    from detectron2.data import MetadataCatalog
except Exception as e:
    st.error(f"Failed to import required libraries: {str(e)}")
    st.error("Please check that all dependencies are correctly installed.")
    st.stop()
44
 
45
# Load the model (cached so initialization and weight download happen once
# per server process rather than on every Streamlit rerun)
@st.cache_resource
def load_detectron_model():
    """Load the Detectron2 model with caching.

    Returns:
        (predictor, cfg): an initialized ``DefaultPredictor`` plus its config,
        or ``(None, None)`` on failure (the error is rendered in the UI).
    """
    try:
        # Set up configuration: base model-zoo config first, then overrides.
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # Set threshold for object detection
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
        cfg.MODEL.DEVICE = "cpu"  # Use CPU

        # Create predictor (downloads weights on first call)
        predictor = DefaultPredictor(cfg)
        return predictor, cfg
    except Exception as e:
        # Surface the failure; main() checks for the None sentinel.
        st.error(f"Error loading model: {str(e)}")
        return None, None
63
 
64
# Process the image
def process_image(image, predictor):
    """Run object detection on the image.

    Args:
        image: PIL image (any mode; converted to RGB internally).
        predictor: a Detectron2 ``DefaultPredictor``-like callable. With the
            model-zoo config used here, ``cfg.INPUT.FORMAT`` defaults to
            "BGR", so the predictor must be fed BGR pixel data.

    Returns:
        (img, instances): the RGB numpy array of the image (unchanged contract,
        suitable for detectron2's Visualizer) and the detected instances moved
        to CPU.
    """
    # Keep an RGB copy for visualization.
    img = np.array(image.convert("RGB"))

    # BUG FIX: the previous code claimed an RGB->BGR conversion in a comment
    # but never performed it, feeding RGB to a BGR-format predictor and
    # swapping color channels at inference time. Reverse the channel axis so
    # the model actually sees BGR.
    bgr = img[:, :, ::-1]

    # Run inference
    outputs = predictor(bgr)

    # Move the instances to CPU for downstream visualization / .item() access
    instances = outputs["instances"].to("cpu")

    return img, instances
77
+
78
# Visualize the results
def visualize_results(img, instances, metadata):
    """Draw *instances* onto *img* and return the rendered image.

    Builds a detectron2 Visualizer over the image at 1.2x scale and returns
    the annotated frame as a numpy array.
    """
    drawer = Visualizer(img, metadata=metadata, scale=1.2)
    return drawer.draw_instance_predictions(instances).get_image()
84
+
85
# Main app logic
def main():
    """Drive the page: load the cached model, accept an image upload, run
    detection, and render the annotated image plus per-object details."""
    # Load model with a spinner (cached, so only slow on first run)
    with st.spinner("Loading model... (this may take a moment on first run)"):
        predictor, cfg = load_detectron_model()

    # load_detectron_model() already rendered error details; just bail out.
    if predictor is None:
        st.error("Failed to load the detection model.")
        return

    # Get metadata; fall back to COCO val metadata when the config lists no
    # training datasets (the model-zoo config case).
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0] if len(cfg.DATASETS.TRAIN) else "coco_2017_val")

    # Create file uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Process the uploaded image
    if uploaded_file is not None:
        try:
            # Display the original image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", width=400)

            # Process the image (inference happens inside process_image)
            with st.spinner("Detecting objects..."):
                img, instances = process_image(image, predictor)

            # Check if any objects were detected
            if len(instances) > 0:
                # Visualize the results
                result_img = visualize_results(img, instances, metadata)

                # Display the result
                st.image(result_img, caption="Detection Result", width=800)

                # Show detection details
                st.subheader(f"Detected {len(instances)} objects")

                # May be None if the metadata has no registered class names.
                class_names = metadata.get("thing_classes", None)

                # Display each detection as "name (confidence)"
                for i in range(len(instances)):
                    score = instances.scores[i].item()
                    class_id = instances.pred_classes[i].item()

                    if class_names:
                        label = class_names[class_id]
                    else:
                        label = f"Class {class_id}"

                    st.write(f"**{label}** (Confidence: {score:.2f})")
            else:
                st.info("No objects detected in this image.")

        except Exception as e:
            # Catch-all boundary so a bad upload never crashes the app.
            st.error(f"Error processing image: {str(e)}")
 
141
 
142
# Script entry point: run the app when executed directly (Streamlit runs the
# module top-to-bottom, so this fires on every rerun).
if __name__ == "__main__":
    main()