DhominickJ commited on
Commit
b2dec6f
·
1 Parent(s): 64d2956

Initial implementation of MosqScope

Browse files
Files changed (1) hide show
  1. app.py +45 -30
app.py CHANGED
@@ -7,19 +7,27 @@ import streamlit as st
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
 
 
 
 
 
10
  # Define dataset classes
11
  classes = ['dengue-regions', 'wet_surface']
12
- num_classes = len(classes) + 1 # Including background
13
-
14
- # Load Model
15
- st.title("Real-Time SSD Object Detection")
16
- if 'model' not in st.session_state:
17
- model_path = hf_hub_download(repo_id="DhominickJ/MosqScope", filename="mosquito_model.pth")
18
- model = ssd300_vgg16(pretrained=True) # Multi-box Algorithm
19
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
20
- # model.load_state_dict(torch.load("./mosquito_model.pth", map_location=torch.device('cpu')))
21
- model.eval()
22
- st.session_state.model = model
 
 
 
 
23
 
24
  # Capture Image from Camera
25
  captured_image = st.camera_input("Take a picture")
@@ -30,26 +38,33 @@ if captured_image is not None:
30
 
31
  # Transform the image for SSD model
32
  transform = transforms.Compose([
33
- transforms.Resize((300, 300)),
34
  transforms.ToTensor()
35
  ])
36
  image_tensor = transform(image).unsqueeze(0)
37
 
38
- # Convert frame for model
39
- image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
40
- image_tensor = transform(image).unsqueeze(0)
41
-
42
- with torch.no_grad():
43
- output = st.session_state.model(image_tensor)[0]
44
-
45
- # Draw detections
46
- for box, label in zip(output["boxes"].cpu().numpy(), output["labels"].cpu().numpy()):
47
- x_min, y_min, x_max, y_max = map(int, box)
48
- cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
49
- cv2.putText(frame, classes[label - 1], (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
50
-
51
- # Display frame
52
- stframe.image(frame, channels="BGR")
53
-
54
- cap.release()
55
- cv2.destroyAllWindows()
 
 
 
 
 
 
 
 
7
  from PIL import Image
8
  from huggingface_hub import hf_hub_download
9
 
10
+ # Configure Streamlit UI
11
+ st.title("Mosquito Detection from Camera Capture")
12
+ st.write("Take a picture to detect mosquito breeding sites using SSD.")
13
+
14
  # Define dataset classes
15
  classes = ['dengue-regions', 'wet_surface']
16
+
17
+ # Load the SSD Model
18
+ @st.cache_resource
19
+ def load_model():
20
+ try:
21
+ model_path = hf_hub_download(repo_id="DhominickJ/MosqScope", filename="mosquito_model.pth")
22
+ model = ssd300_vgg16(pretrained=False)
23
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
24
+ model.eval()
25
+ return model
26
+ except Exception as e:
27
+ st.error(f"Error loading model: {str(e)}")
28
+ return None
29
+
30
+ model = load_model()
31
 
32
  # Capture Image from Camera
33
  captured_image = st.camera_input("Take a picture")
 
38
 
39
  # Transform the image for SSD model
40
  transform = transforms.Compose([
41
+ transforms.Resize((800, 800)),
42
  transforms.ToTensor()
43
  ])
44
  image_tensor = transform(image).unsqueeze(0)
45
 
46
+ if model is not None:
47
+ with torch.no_grad():
48
+ detections = model(image_tensor)
49
+
50
+ boxes = detections[0]['boxes'].cpu().numpy()
51
+ scores = detections[0]['scores'].cpu().numpy()
52
+ labels = detections[0]['labels'].cpu().numpy()
53
+
54
+ # Convert image to OpenCV format
55
+ image_np = np.array(image)
56
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
57
+
58
+ # Draw detections
59
+ for box, label, score in zip(boxes, labels, scores):
60
+ if score > 0.5: # Confidence threshold
61
+ x_min, y_min, x_max, y_max = map(int, box)
62
+ cv2.rectangle(image_np, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
63
+ label_name = classes[label - 1]
64
+ cv2.putText(image_np, f"{label_name} {score:.2f}", (x_min, y_min - 5),
65
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
66
+
67
+ # Convert image back to RGB for Streamlit display
68
+ st.image(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB), caption="Detected Objects", use_column_width=True)
69
+ else:
70
+ st.warning("Model not loaded. Unable to process image.")