DhominickJ committed on
Commit
e6a1f8f
·
1 Parent(s): b8c6c95

Initial implementation of MosqScope

Browse files
Files changed (1) hide show
  1. app.py +79 -30
app.py CHANGED
@@ -5,8 +5,12 @@ import av
5
  import numpy as np
6
  import cv2
7
  import streamlit as st
8
- from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode
9
  from huggingface_hub import hf_hub_download
 
 
 
 
10
 
11
  # Define dataset classes
12
  classes = ['dengue-regions', 'wet_surface']
@@ -15,46 +19,91 @@ num_classes = len(classes) + 1 # Including background
15
  # Load the SSD Model
16
  @st.cache_resource
17
  def load_model():
18
- if 'model' not in st.session_state:
19
- model_path = hf_hub_download(repo_id="DhominickJ/MosqScope", filename="mosquito_model.pth")
20
- model = ssd300_vgg16(pretrained=True) # Multi-box Algorithm
21
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
22
- # model.load_state_dict(torch.load("./mosquito_model.pth", map_location=torch.device('cpu')))
23
- model.eval()
24
- st.session_state.model = model
25
 
26
- model = load_model()
 
 
 
 
27
 
28
  # Define Video Processor for WebRTC
29
  class SSDVideoProcessor(VideoProcessorBase):
30
  def __init__(self):
31
  self.model = model
32
  self.transform = transforms.Compose([
 
33
  transforms.Resize((300, 300)),
34
- transforms.ToTensor()
 
35
  ])
36
 
37
  def recv(self, frame):
 
 
 
 
38
  img = frame.to_ndarray(format="bgr24")
39
- image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
40
- image = self.transform(image).unsqueeze(0)
41
-
42
- with torch.no_grad():
43
- output = self.model(image)[0]
44
-
45
- # Draw detections
46
- for box, label in zip(output["boxes"].cpu().numpy(), output["labels"].cpu().numpy()):
47
- x_min, y_min, x_max, y_max = map(int, box)
48
- cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
49
- cv2.putText(img, classes[label - 1], (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
50
-
51
- return av.VideoFrame.from_ndarray(img, format="bgr24")
52
-
53
- # Start WebRTC Streaming
54
- st.title("Real-Time SSD Object Detection with WebRTC")
55
- webrtc_streamer(
56
- key="ssd-detection",
57
- mode=WebRtcMode.SENDRECV,
58
- video_processor_factory=SSDVideoProcessor,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  media_stream_constraints={"video": True, "audio": False},
60
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import numpy as np
6
  import cv2
7
  import streamlit as st
8
+ from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, WebRtcMode, RTCConfiguration
9
  from huggingface_hub import hf_hub_download
10
+ import logging
11
+
12
+ # Set up logging
13
+ logging.basicConfig(level=logging.DEBUG)
14
 
15
  # Define dataset classes
16
  classes = ['dengue-regions', 'wet_surface']
 
# Load the SSD Model
@st.cache_resource
def load_model():
    """Download the fine-tuned SSD300-VGG16 checkpoint and return it ready for inference.

    The weights file ``mosquito_model.pth`` is fetched from the
    ``DhominickJ/MosqScope`` Hugging Face repo and loaded onto the CPU.
    ``st.cache_resource`` caches the returned model so it is built only once.

    Returns:
        The SSD model with the checkpoint loaded, switched to eval mode.

    Raises:
        Any exception raised by the download, by ``torch.load`` or by
        ``load_state_dict`` (e.g. on a checkpoint/architecture mismatch);
        nothing is caught here.
    """
    model_path = hf_hub_download(repo_id="DhominickJ/MosqScope", filename="mosquito_model.pth")

    # torchvision's SSD has no ``roi_heads`` attribute (that layout belongs to
    # the Faster R-CNN family), so the detection head is sized through
    # ``num_classes`` instead of being patched after construction.
    # ``pretrained_backbone=False`` avoids downloading ImageNet VGG16 weights
    # that the checkpoint would overwrite anyway.
    # NOTE(review): assumes the checkpoint was trained with
    # len(classes) + 1 outputs (background included) - confirm.
    model = ssd300_vgg16(
        pretrained=False,
        pretrained_backbone=False,
        num_classes=num_classes,
    )

    # NOTE(review): torch.load unpickles the file; consider
    # ``weights_only=True`` if the installed torch version supports it.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model
29
 
# Load the model once at import time. On failure the app keeps running with
# ``model = None`` so the page still renders and frames pass through untouched.
try:
    model = load_model()
except Exception as e:
    # Top-level boundary: keep the UI alive, but record the full traceback so
    # the root cause is not reduced to a one-line message on the page.
    logging.exception("Error loading model: %s", e)
    st.error(f"Error loading model: {e}")
    model = None
35
 
# Define Video Processor for WebRTC
class SSDVideoProcessor(VideoProcessorBase):
    """Run the SSD detector on each incoming WebRTC frame and draw the detections.

    Attributes:
        input_size: Side length (pixels) the frame is resized to before inference.
        score_threshold: Detections scoring at or below this value are not drawn.
    """

    input_size = 300
    score_threshold = 0.5

    def __init__(self):
        # Module-level model; None when loading failed.
        self.model = model
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.input_size, self.input_size)),
            transforms.ToTensor(),
            # NOTE(review): torchvision's SSD also normalises inputs
            # internally; keep this only if training used the same extra
            # Normalize step - confirm against the training pipeline.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def recv(self, frame):
        """Annotate one video frame with detections and return it.

        Args:
            frame: Incoming ``av.VideoFrame``.

        Returns:
            The unchanged ``frame`` when no model is loaded; otherwise a new
            BGR ``av.VideoFrame`` with boxes and labels drawn, or with an
            error message drawn if inference failed.
        """
        if self.model is None:
            # Just return the frame if model isn't loaded
            return frame

        img = frame.to_ndarray(format="bgr24")
        # Draw on a copy so the pixels fed to the model stay untouched.
        display_img = img.copy()

        # ToPILImage interprets an HxWx3 array as RGB, while the frame is BGR;
        # convert first so the channels are not swapped for the model.
        rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        image_tensor = self.transform(rgb_img).unsqueeze(0)

        try:
            with torch.no_grad():
                output = self.model(image_tensor)[0]

            # Boxes are in resized-input coordinates; map them back to the frame.
            h, w = img.shape[:2]
            scale_x, scale_y = w / self.input_size, h / self.input_size

            for box, label, score in zip(output["boxes"], output["labels"], output["scores"]):
                score_value = float(score)
                if score_value <= self.score_threshold:
                    continue  # Only show confident detections

                # Scale in floating point first, then round once, so no
                # precision is lost by truncating before the multiplication.
                bx_min, by_min, bx_max, by_max = box.tolist()
                x_min, x_max = int(round(bx_min * scale_x)), int(round(bx_max * scale_x))
                y_min, y_max = int(round(by_min * scale_y)), int(round(by_max * scale_y))

                # Label 0 is background; guard against indices outside
                # ``classes`` instead of wrapping around or raising.
                class_idx = int(label) - 1
                if 0 <= class_idx < len(classes):
                    label_name = classes[class_idx]
                else:
                    label_name = f"class {int(label)}"

                cv2.rectangle(display_img, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                # Keep the caption inside the frame for boxes touching the top edge.
                text_y = max(y_min - 5, 12)
                cv2.putText(display_img, f"{label_name} {score_value:.2f}",
                            (x_min, text_y), cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, (0, 0, 255), 2)
        except Exception as e:
            # Per-frame boundary: never break the stream; log with traceback
            # and surface the message on the frame itself.
            logging.exception("Error in inference: %s", e)
            cv2.putText(display_img, f"Error: {str(e)}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

        return av.VideoFrame.from_ndarray(display_img, format="bgr24")
87
+
# Streamlit UI
st.title("Mosquito Detection with WebRTC")
st.write("This app uses a SSD model to detect mosquito breeding sites in real-time.")

# ICE configuration: a public STUN server only (no TURN relay is configured).
rtc_config = RTCConfiguration(
    {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
)

# Start WebRTC streaming. ``media_stream_constraints`` is an argument of
# ``webrtc_streamer`` (not part of the RTC configuration), so it is passed
# here to request video only.
try:
    webrtc_ctx = webrtc_streamer(
        key="ssd-detection",
        mode=WebRtcMode.SENDRECV,
        rtc_configuration=rtc_config,
        media_stream_constraints={"video": True, "audio": False},
        video_processor_factory=SSDVideoProcessor,
        async_processing=True,
    )
except Exception as e:
    # Top-level boundary: show the failure in the page instead of crashing it.
    st.error(f"WebRTC Error: {e}")
    st.info("Please try refreshing the page or using a different browser.")