DhominickJ commited on
Commit
64d2956
·
1 Parent(s): 98757fa

Initial implementation of MosqScope

Browse files
Files changed (1) hide show
  1. app.py +29 -44
app.py CHANGED
@@ -7,27 +7,19 @@ import streamlit as st
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
 
10
- # Configure Streamlit UI
11
- st.title("Mosquito Detection from Camera Capture")
12
- st.write("Take a picture to detect mosquito breeding sites using SSD.")
13
-
14
  # Define dataset classes
15
  classes = ['dengue-regions', 'wet_surface']
16
-
17
- # Load the SSD Model
18
- @st.cache_resource
19
- def load_model():
20
- try:
21
- model_path = hf_hub_download(repo_id="DhominickJ/MosqScope", filename="mosquito_model.pth")
22
- model = ssd300_vgg16(pretrained=False)
23
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
24
- model.eval()
25
- return model
26
- except Exception as e:
27
- st.error(f"Error loading model: {str(e)}")
28
- return None
29
-
30
- model = load_model()
31
 
32
  # Capture Image from Camera
33
  captured_image = st.camera_input("Take a picture")
@@ -43,28 +35,21 @@ if captured_image is not None:
43
  ])
44
  image_tensor = transform(image).unsqueeze(0)
45
 
46
- if model is not None:
47
- with torch.no_grad():
48
- detections = model(image_tensor)
49
-
50
- boxes = detections[0]['boxes'].cpu().numpy()
51
- scores = detections[0]['scores'].cpu().numpy()
52
- labels = detections[0]['labels'].cpu().numpy()
53
-
54
- # Convert image to OpenCV format
55
- image_np = np.array(image)
56
- image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
57
-
58
- # Draw detections
59
- for box, label, score in zip(boxes, labels, scores):
60
- if score > 0.5: # Confidence threshold
61
- x_min, y_min, x_max, y_max = map(int, box)
62
- cv2.rectangle(image_np, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
63
- label_name = classes[label - 1]
64
- cv2.putText(image_np, f"{label_name} {score:.2f}", (x_min, y_min - 5),
65
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
66
-
67
- # Convert image back to RGB for Streamlit display
68
- st.image(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB), caption="Detected Objects", use_column_width=True)
69
- else:
70
- st.warning("Model not loaded. Unable to process image.")
 
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
 
 
 
 
 
10
  # Define dataset classes
11
  classes = ['dengue-regions', 'wet_surface']
12
+ num_classes = len(classes) + 1 # Including background
13
+
14
+ # Load Model
15
+ st.title("Real-Time SSD Object Detection")
16
+ if 'model' not in st.session_state:
17
+ model_path = hf_hub_download(repo_id="DhominickJ/MosqScope", filename="mosquito_model.pth")
18
+ model = ssd300_vgg16(pretrained=True) # Multi-box Algorithm
19
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
20
+ # model.load_state_dict(torch.load("./mosquito_model.pth", map_location=torch.device('cpu')))
21
+ model.eval()
22
+ st.session_state.model = model
 
 
 
 
23
 
24
  # Capture Image from Camera
25
  captured_image = st.camera_input("Take a picture")
 
35
  ])
36
  image_tensor = transform(image).unsqueeze(0)
37
 
38
+ # Convert frame for model
39
+ image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
40
+ image_tensor = transform(image).unsqueeze(0)
41
+
42
+ with torch.no_grad():
43
+ output = st.session_state.model(image_tensor)[0]
44
+
45
+ # Draw detections
46
+ for box, label in zip(output["boxes"].cpu().numpy(), output["labels"].cpu().numpy()):
47
+ x_min, y_min, x_max, y_max = map(int, box)
48
+ cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
49
+ cv2.putText(frame, classes[label - 1], (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
50
+
51
+ # Display frame
52
+ stframe.image(frame, channels="BGR")
53
+
54
+ cap.release()
55
+ cv2.destroyAllWindows()