sandbox338 commited on
Commit
af9f2a8
·
verified ·
1 Parent(s): ba14d78

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +57 -43
src/streamlit_app.py CHANGED
@@ -6,65 +6,79 @@ from PIL import Image
6
  from detectron2.engine import DefaultPredictor
7
  from detectron2.config import get_cfg
8
  from detectron2 import model_zoo
 
 
9
 
10
- # Function to load the model
11
- def load_model(model_path):
12
- # Initialize the configuration
 
13
  cfg = get_cfg()
14
  cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
15
- cfg.MODEL.WEIGHTS = model_path # Path to your trained model
16
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Threshold for prediction
 
 
 
17
  predictor = DefaultPredictor(cfg)
18
  return predictor
19
 
20
- # Function to make predictions
21
- def predict_fn(image, model_path):
22
- # Convert image to numpy array
23
- image = np.array(image)
24
-
25
- # Load model
26
- predictor = load_model(model_path)
27
-
 
28
  # Make predictions
29
  outputs = predictor(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # Extract predicted bounding boxes and class labels
32
- instances = outputs["instances"]
33
- pred_classes = instances.pred_classes.to("cpu").numpy()
34
- pred_boxes = instances.pred_boxes.tensor.to("cpu").numpy()
35
-
36
- return pred_classes, pred_boxes
 
 
 
37
 
38
- # Streamlit UI elements
39
- st.title("Wildlife Detection with Your Model")
40
 
41
- # Upload an image for testing
42
- uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
43
 
44
  if uploaded_image is not None:
45
- # Open the uploaded image using PIL
46
  image = Image.open(uploaded_image)
47
 
48
- # Display the uploaded image
49
- st.image(image, caption="Uploaded Image", use_column_width=True)
50
 
51
- # Set model path
52
- model_path = "sandbox338/wild-life-model" # Your model path on Hugging Face
53
 
54
- # Make predictions
55
- pred_classes, pred_boxes = predict_fn(image, model_path)
56
 
57
- # Display predictions
58
- st.write(f"Predicted classes: {pred_classes}")
59
- st.write(f"Predicted bounding boxes: {pred_boxes}")
60
 
61
- # Optionally, visualize predictions on the image
62
- img_array = np.array(image)
63
- for i, box in enumerate(pred_boxes):
64
- start_point = (int(box[0]), int(box[1]))
65
- end_point = (int(box[2]), int(box[3]))
66
- img_array = cv2.rectangle(img_array, start_point, end_point, (0, 255, 0), 2)
67
-
68
- st.image(img_array, caption="Predicted Image with Bounding Boxes", use_column_width=True)
69
- else:
70
- st.write("Please upload an image to detect wildlife!")
 
6
  from detectron2.engine import DefaultPredictor
7
  from detectron2.config import get_cfg
8
  from detectron2 import model_zoo
9
+ from detectron2.utils.visualizer import Visualizer
10
+ from detectron2.data import MetadataCatalog
11
 
12
+ # Load the Detectron2 model
13
+ @st.cache_resource
14
+ def load_model():
15
+ # Set the configuration for the model (COCO pre-trained model for detection)
16
  cfg = get_cfg()
17
  cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"))
18
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Set the threshold for prediction
19
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml") # Model weights from the model zoo
20
+ cfg.MODEL.DEVICE = "cpu" # Change to "cuda" if using GPU
21
+
22
+ # Initialize the predictor
23
  predictor = DefaultPredictor(cfg)
24
  return predictor
25
 
26
+ # Load the model
27
+ predictor = load_model()
28
+
29
+ # Function for image prediction
30
+ def predict_fn(image):
31
+ # Convert the PIL image to a format the model can use
32
+ # Convert to RGB (Detectron2 uses BGR format internally)
33
+ image = np.array(image.convert("RGB"))
34
+
35
  # Make predictions
36
  outputs = predictor(image)
37
+
38
+ # Get the predicted classes and bounding boxes
39
+ instances = outputs["instances"].to("cpu")
40
+ pred_classes = instances.pred_classes.numpy()
41
+ pred_boxes = instances.pred_boxes.tensor.numpy()
42
+
43
+ return pred_classes, pred_boxes, image
44
+
45
+ # Function to display image with bounding boxes
46
+ def visualize_predictions(image, pred_classes, pred_boxes):
47
+ # Create a visualizer object
48
+ v = Visualizer(image[:, :, ::-1], MetadataCatalog.get("coco_2017_val"), scale=1.2)
49
+ v = v.draw_instance_predictions(pred_classes)
50
 
51
+ # Draw the bounding boxes on the image
52
+ for box in pred_boxes:
53
+ start_point = tuple(map(int, box[:2]))
54
+ end_point = tuple(map(int, box[2:]))
55
+ color = (0, 255, 0) # Green color
56
+ thickness = 2
57
+ image = cv2.rectangle(image, start_point, end_point, color, thickness)
58
+
59
+ return image
60
 
61
+ # Streamlit UI
62
+ st.title("Object Detection with Detectron2")
63
 
64
+ # Upload image
65
+ uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
66
 
67
  if uploaded_image is not None:
68
+ # Open the uploaded image
69
  image = Image.open(uploaded_image)
70
 
71
+ # Make predictions
72
+ pred_classes, pred_boxes, image_array = predict_fn(image)
73
 
74
+ # Visualize predictions on the image
75
+ image_with_boxes = visualize_predictions(image_array, pred_classes, pred_boxes)
76
 
77
+ # Convert the image back to RGB format for display in Streamlit
78
+ image_with_boxes = Image.fromarray(image_with_boxes[:, :, ::-1]) # Convert BGR to RGB
79
 
80
+ # Display the processed image with bounding boxes
81
+ st.image(image_with_boxes, caption="Processed Image", use_column_width=True)
 
82
 
83
+ # Display the classes detected
84
+ st.write("Predicted Classes:", pred_classes)