Added YOLO model.
- .gitignore +3 -1
- app.py +300 -4
- requirements.txt +2 -1
- utils/helpers.py +49 -0
- utils/onnx_inference.py +47 -0
.gitignore CHANGED
@@ -1 +1,3 @@
-**/__pycache__
+**/__pycache__
+*.onnx
+*.pth
app.py CHANGED
@@ -17,6 +17,9 @@ from utils.helpers import calculate_deforestation_metrics, create_overlay
 from utils.audio_processing import preprocess_audio
 from utils.audio_model import load_audio_model, predict_audio, class_names
 
+# Import YOLO detection modules
+from utils.onnx_inference import YOLOv11
+
 # Ensure torch classes path is initialized to avoid warnings
 torch.classes.__path__ = []
 
@@ -31,12 +34,15 @@ st.set_page_config(
 # Constants
 DEFOREST_MODEL_INPUT_SIZE = 256
 AUDIO_MODEL_PATH = "models/best_model.pth"
+YOLO_MODEL_PATH = "models/best_model.onnx"
 
 # Initialize session state for navigation
 if 'current_service' not in st.session_state:
     st.session_state.current_service = 'deforestation'
 if 'audio_input_method' not in st.session_state:
     st.session_state.audio_input_method = 'upload'
+if 'detection_input_method' not in st.session_state:
+    st.session_state.detection_input_method = 'image'
 
 # Sidebar for navigation
 with st.sidebar:
@@ -45,9 +51,15 @@ with st.sidebar:
 
     selected_service = st.radio(
         "Select Service:",
-        ["Deforestation Detection", "Forest Audio Surveillance"]
+        ["Deforestation Detection", "Forest Audio Surveillance", "Object Detection"]
     )
-
+
+    if selected_service == "Deforestation Detection":
+        st.session_state.current_service = 'deforestation'
+    elif selected_service == "Forest Audio Surveillance":
+        st.session_state.current_service = 'audio'
+    else:
+        st.session_state.current_service = 'detection'
 
     st.markdown("---")
 
@@ -60,7 +72,7 @@ with st.sidebar:
             Upload satellite or aerial images to detect areas of deforestation.
             """
         )
-
+    elif st.session_state.current_service == 'audio':
         st.info(
             """
             **Forest Audio Surveillance**
@@ -92,6 +104,44 @@ with st.sidebar:
        st.markdown("🔨 **Tool Sounds:** " + ", ".join([s.capitalize() for s in tool_sounds]))
        st.markdown("🚗 **Vehicle Sounds:** " + ", ".join([s.capitalize() for s in vehicle_sounds]))
        st.markdown("💥 **Other Sounds:** " + ", ".join([s.capitalize() for s in other_sounds]))
+    else: # Object Detection
+        st.info(
+            """
+            **Object Detection**
+
+            Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
+            """
+        )
+
+        # Detection service specific controls
+        st.subheader("Detection Configuration")
+        detection_input_method = st.radio(
+            "Select Input Method:",
+            ("Image", "Video", "Camera"),
+            index=0 if st.session_state.detection_input_method == 'image' else
+                  (1 if st.session_state.detection_input_method == 'video' else 2)
+        )
+
+        if detection_input_method == "Image":
+            st.session_state.detection_input_method = 'image'
+        elif detection_input_method == "Video":
+            st.session_state.detection_input_method = 'video'
+        else:
+            st.session_state.detection_input_method = 'camera'
+
+        # Detection threshold controls
+        st.subheader("Detection Settings")
+        confidence = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)
+        iou_thres = st.slider("IoU Threshold", 0.0, 1.0, 0.5)
+
+        # Detection class information
+        st.markdown("**Detection Classes:**")
+        st.markdown("🚴 **Bike/Bicycle**")
+        st.markdown("🚚 **Bus/Truck**")
+        st.markdown("🚗 **Car**")
+        st.markdown("🔥 **Fire**")
+        st.markdown("👤 **Human**")
+        st.markdown("💨 **Smoke**")
 
 # Load deforestation model
 @st.cache_resource
@@ -104,6 +154,10 @@ def load_cached_deforestation_model():
 def load_cached_audio_model():
     return load_audio_model(AUDIO_MODEL_PATH)
 
+@st.cache_resource
+def load_cached_yolo_model():
+    return YOLOv11(YOLO_MODEL_PATH)
+
 # Process image for deforestation detection
 def process_image(model, image):
     """Process a single image and return results"""
@@ -379,13 +433,255 @@ def show_audio_classification():
     else:
         st.write("Waiting for recording...")
 
+# Object Detection UI
+def show_object_detection():
+    # App title and description
+    st.title("🔍 Forest Object Detection")
+    st.markdown(
+        """
+        Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
+        Choose an input method to begin detection.
+        """
+    )
+
+    # Model info
+    st.info("⚙️ Object detection model optimized with ONNX runtime for faster inference")
+
+    # Load model
+    try:
+        model = load_cached_yolo_model()
+        # Update model confidence and IoU thresholds from sidebar
+        confidence = st.session_state.get('confidence', 0.5)
+        iou_thres = st.session_state.get('iou_thres', 0.5)
+        model.conf_thres = confidence
+        model.iou_thres = iou_thres
+    except Exception as e:
+        st.error(f"Error loading model: {e}")
+        st.info(
+            "Make sure you have the YOLO ONNX model file available at models/best_model.onnx"
+        )
+        return
+
+    # Input method based selection
+    if st.session_state.detection_input_method == 'image':
+        # Image upload
+        img_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
+        if img_file is not None:
+            # Load image
+            file_bytes = np.asarray(bytearray(img_file.read()), dtype=np.uint8)
+            image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
+            if image is not None:
+                # Display original image
+                st.subheader("Original Image")
+                st.image(
+                    cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
+                    caption="Uploaded Image",
+                    use_container_width=True,
+                )
+
+                # Process with detection model
+                with st.spinner("Processing image..."):
+                    try:
+                        detections = model.detect(image)
+                        result_image = model.draw_detections(image.copy(), detections)
+
+                        # Display results
+                        st.subheader("Detection Results")
+                        st.image(
+                            cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
+                            caption="Detected Objects",
+                            use_container_width=True,
+                        )
+
+                        # Display detection statistics
+                        st.subheader("Detection Statistics")
+
+                        # Count detections by class
+                        class_counts = {}
+                        for det in detections:
+                            class_name = det['class']
+                            if class_name in class_counts:
+                                class_counts[class_name] += 1
+                            else:
+                                class_counts[class_name] = 1
+
+                        # Display counts with emojis
+                        cols = st.columns(3)
+                        col_idx = 0
+
+                        for class_name, count in class_counts.items():
+                            emoji = "👤" if class_name == "human" else (
+                                "🔥" if class_name == "fire" else (
+                                "💨" if class_name == "smoke" else (
+                                "🚗" if class_name == "car" else (
+                                "🚴" if class_name == "bike-bicycle" else "🚚"))))
+
+                            with cols[col_idx % 3]:
+                                st.metric(f"{emoji} {class_name.capitalize()}", count)
+                            col_idx += 1
+
+                        # Check for priority threats
+                        if "fire" in class_counts or "smoke" in class_counts:
+                            st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected. Immediate action required.")
+
+                        if "human" in class_counts or "car" in class_counts or "bike-bicycle" in class_counts or "bus-truck" in class_counts:
+                            st.warning("⚠️ **Trespassers Detected!** Unauthorized entry detected in monitored area.")
+
+                    except Exception as e:
+                        st.error(f"Error during detection: {e}")
+                        st.exception(e)
+
+    elif st.session_state.detection_input_method == 'video':
+        # Video upload
+        video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov"])
+        if video_file is not None:
+            # Save uploaded video to temp file
+            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
+                tfile.write(video_file.read())
+                temp_video_path = tfile.name
+
+            # Display video upload success
+            st.success("Video uploaded successfully!")
+
+            # Process video button
+            if st.button("Process Video"):
+                with st.spinner("Processing video... This may take a while."):
+                    try:
+                        # Open video file
+                        cap = cv2.VideoCapture(temp_video_path)
+
+                        # Create video writer for output
+                        output_path = "output_video.mp4"
+                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                        fps = cap.get(cv2.CAP_PROP_FPS)
+                        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+                        # Create placeholder for video frames
+                        video_placeholder = st.empty()
+                        status_text = st.empty()
+
+                        # Process frames
+                        frame_count = 0
+                        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+                        while cap.isOpened():
+                            ret, frame = cap.read()
+                            if not ret:
+                                break
+
+                            # Process every 5th frame for speed
+                            if frame_count % 5 == 0:
+                                detections = model.detect(frame)
+                                result_frame = model.draw_detections(frame.copy(), detections)
+
+                                # Update preview
+                                if frame_count % 15 == 0: # Update display less frequently
+                                    video_placeholder.image(
+                                        cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB),
+                                        caption="Processing Video",
+                                        use_container_width=True
+                                    )
+                                    progress = min(100, int((frame_count / total_frames) * 100))
+                                    status_text.text(f"Processing: {progress}% complete")
+                            else:
+                                result_frame = frame # Skip detection on some frames
+
+                            # Write frame to output video
+                            out.write(result_frame)
+                            frame_count += 1
+
+                        # Release resources
+                        cap.release()
+                        out.release()
+
+                        # Display completion message
+                        st.success("Video processing complete!")
+
+                        # Provide download button for processed video
+                        with open(output_path, "rb") as file:
+                            st.download_button(
+                                label="Download Processed Video",
+                                data=file,
+                                file_name="forest_surveillance_results.mp4",
+                                mime="video/mp4"
+                            )
+
+                    except Exception as e:
+                        st.error(f"Error processing video: {e}")
+                        st.exception(e)
+                    finally:
+                        # Clean up temp file
+                        try:
+                            os.unlink(temp_video_path)
+                        except:
+                            pass
+
+    else: # Camera mode
+        # Live camera feed
+        st.subheader("Live Camera Detection")
+        st.info("Use your webcam to detect objects in real-time")
+
+        cam = st.camera_input("Camera Feed")
+
+        if cam:
+            # Process camera input
+            with st.spinner("Processing image..."):
+                try:
+                    # Convert image
+                    image = cv2.imdecode(np.frombuffer(cam.getvalue(), np.uint8), cv2.IMREAD_COLOR)
+
+                    # Run detection
+                    detections = model.detect(image)
+                    result_image = model.draw_detections(image.copy(), detections)
+
+                    # Display results
+                    st.image(
+                        cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
+                        caption="Detection Results",
+                        use_container_width=True
+                    )
+
+                    # Show detection summary
+                    if detections:
+                        # Count detections by class
+                        class_counts = {}
+                        for det in detections:
+                            class_name = det['class']
+                            if class_name in class_counts:
+                                class_counts[class_name] += 1
+                            else:
+                                class_counts[class_name] = 1
+
+                        # Display as metrics
+                        st.subheader("Detection Summary")
+                        cols = st.columns(3)
+                        for i, (class_name, count) in enumerate(class_counts.items()):
+                            with cols[i % 3]:
+                                st.metric(class_name.capitalize(), count)
+
+                        # Check for priority threats
+                        if "fire" in class_counts or "smoke" in class_counts:
+                            st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected.")
+
+                        if "human" in class_counts:
+                            st.warning("⚠️ **Trespasser Detected!** Human presence detected.")
+                    else:
+                        st.info("No objects detected in frame")
+
+                except Exception as e:
+                    st.error(f"Error processing camera feed: {e}")
+
 # Main function
 def main():
     # Check which service is selected and render appropriate UI
     if st.session_state.current_service == 'deforestation':
         show_deforestation_detection()
-
+    elif st.session_state.current_service == 'audio':
         show_audio_classification()
+    else: # 'detection'
+        show_object_detection()
 
     # Footer
     st.markdown("---")
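Review note: the sidebar sliders above are created without a `key=` argument, while `show_object_detection()` reads `st.session_state.get('confidence', 0.5)` and `st.session_state.get('iou_thres', 0.5)`. Streamlit only mirrors a widget's value into `st.session_state` under an explicit key, so the detection thresholds will likely stay at their 0.5 defaults regardless of the slider positions. A minimal sketch of one way to wire this up (not part of this commit; key names chosen to match the lookups above):

```python
# Sketch only: explicit keys make the slider values visible under
# st.session_state['confidence'] / st.session_state['iou_thres'],
# which is what show_object_detection() reads.
import streamlit as st

with st.sidebar:
    st.slider("Confidence Threshold", 0.0, 1.0, 0.5, key="confidence")
    st.slider("IoU Threshold", 0.0, 1.0, 0.5, key="iou_thres")

# Elsewhere in the app, these now reflect the actual slider positions:
conf = st.session_state.get("confidence", 0.5)
iou = st.session_state.get("iou_thres", 0.5)
```

Separately, the manual `class_counts` loops in the image and camera branches could be collapsed to `collections.Counter(det['class'] for det in detections)`, though that is purely cosmetic.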
requirements.txt CHANGED
@@ -13,4 +13,5 @@ onnxruntime-gpu
 onnx
 librosa
 soundfile
-pydub
+pydub
+supervision
utils/helpers.py CHANGED
@@ -71,3 +71,52 @@ def create_overlay(original_image, mask, threshold=0.5, alpha=0.5):
     overlay = cv2.addWeighted(original_image, 1 - alpha, colored_mask, alpha, 0)
 
     return overlay
+
+
+CLASS_NAMES = ['bike-bicycle', 'bus-truck', 'car', 'fire', 'human', 'smoke']
+COLORS = np.random.uniform(0, 255, size=(len(CLASS_NAMES), 3))
+
+def preprocess(image, img_size=640):
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image = cv2.resize(image, (img_size, img_size))
+    image = image.transpose((2, 0, 1))  # HWC to CHW
+    image = np.ascontiguousarray(image, dtype=np.float32) / 255.0
+    return image[np.newaxis, ...]
+
+def postprocess(outputs, conf_thresh=0.5, iou_thresh=0.5):
+    outputs = outputs[0].transpose()
+    boxes, scores, class_ids = [], [], []
+
+    for row in outputs:
+        cls_scores = row[4:4+len(CLASS_NAMES)]
+        class_id = np.argmax(cls_scores)
+        max_score = cls_scores[class_id]
+
+        if max_score >= conf_thresh:
+            cx, cy, w, h = row[:4]
+            x = (cx - w/2).item()  # Convert to Python float
+            y = (cy - h/2).item()
+            width = w.item()
+            height = h.item()
+            boxes.append([x, y, width, height])
+            scores.append(float(max_score))
+            class_ids.append(int(class_id))
+
+    if len(boxes) > 0:
+        # Convert to list of lists with native Python floats
+        boxes = [[float(x) for x in box] for box in boxes]
+        scores = [float(score) for score in scores]
+
+        indices = cv2.dnn.NMSBoxes(
+            bboxes=boxes,
+            scores=scores,
+            score_threshold=conf_thresh,
+            nms_threshold=iou_thresh
+        )
+
+        if len(indices) > 0:
+            boxes = [boxes[i] for i in indices.flatten()]
+            scores = [scores[i] for i in indices.flatten()]
+            class_ids = [class_ids[i] for i in indices.flatten()]
+
+    return boxes, scores, class_ids
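`postprocess` assumes the ONNX export emits a `(1, 4 + num_classes, N)` tensor whose rows, after the transpose, are `[cx, cy, w, h, class scores...]` in the 640x640 input space (the usual YOLOv8/YOLOv11 layout). A quick synthetic sanity check of that assumption and of the center-to-corner conversion, using dummy data rather than a real model run:

```python
# Synthetic check of postprocess(): one fake detection, no model required.
# Assumes the (1, 4 + num_classes, N) output layout described above.
import numpy as np
from utils.helpers import postprocess

num_classes = 6  # matches CLASS_NAMES
dummy = np.zeros((1, 4 + num_classes, 10), dtype=np.float32)
dummy[0, :4, 0] = [320, 320, 100, 80]   # cx, cy, w, h at 640x640 scale
dummy[0, 4 + 4, 0] = 0.9                # class index 4 ('human'), score 0.9

boxes, scores, class_ids = postprocess([dummy])
print(boxes)      # [[270.0, 280.0, 100.0, 80.0]]  (x, y, w, h corner form)
print(class_ids)  # [4]
```

Worth noting: the plain `cv2.resize` in `preprocess` stretches the frame rather than letterboxing it, and `YOLOv11.detect` undoes that by rescaling each axis by `orig/640`, so the coordinates stay consistent, but accuracy on extreme aspect ratios may lag a letterboxed pipeline.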
utils/onnx_inference.py ADDED
@@ -0,0 +1,47 @@
+import cv2
+import numpy as np
+import onnxruntime as ort
+from .helpers import CLASS_NAMES, COLORS, preprocess, postprocess
+
+class YOLOv11:
+    def __init__(self, onnx_path, conf_thres=0.5, iou_thres=0.5):
+        self.session = ort.InferenceSession(onnx_path)
+        self.conf_thres = conf_thres
+        self.iou_thres = iou_thres
+        self.input_name = self.session.get_inputs()[0].name
+        self.output_name = self.session.get_outputs()[0].name
+
+        # Verify input type
+        input_type = self.session.get_inputs()[0].type
+        assert "float" in input_type, f"Model expects {input_type}"
+
+    def detect(self, image):
+        orig_h, orig_w = image.shape[:2]
+        blob = preprocess(image)
+        outputs = self.session.run([self.output_name], {self.input_name: blob})
+        boxes, scores, class_ids = postprocess(outputs, self.conf_thres, self.iou_thres)
+
+        results = []
+        for box, score, class_id in zip(boxes, scores, class_ids):
+            x, y, w, h = box
+            x1 = int(x * orig_w / 640)
+            y1 = int(y * orig_h / 640)
+            x2 = int((x + w) * orig_w / 640)
+            y2 = int((y + h) * orig_h / 640)
+
+            results.append({
+                'class': CLASS_NAMES[class_id],
+                'confidence': score,
+                'box': [x1, y1, x2, y2]
+            })
+        return results
+
+    def draw_detections(self, image, detections):
+        for det in detections:
+            x1, y1, x2, y2 = det['box']
+            color = COLORS[CLASS_NAMES.index(det['class'])]
+            cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
+            label = f"{det['class']}: {det['confidence']:.2f}"
+            cv2.putText(image, label, (x1, y1 - 10),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
+        return image
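For completeness, a hypothetical smoke test for the new wrapper (the image path is an assumption; the model path matches `YOLO_MODEL_PATH` above). Run it from the repo root so the relative import in `utils/onnx_inference.py` resolves:

```python
# Hypothetical smoke test; adjust paths to your checkout.
import cv2
from utils.onnx_inference import YOLOv11

model = YOLOv11("models/best_model.onnx", conf_thres=0.4, iou_thres=0.5)
image = cv2.imread("sample.jpg")  # BGR frame, as detect() expects
assert image is not None, "sample.jpg not found"

detections = model.detect(image)
for det in detections:
    print(det['class'], round(det['confidence'], 2), det['box'])

cv2.imwrite("sample_annotated.jpg", model.draw_detections(image.copy(), detections))
```

Incidentally, the newly added `supervision` dependency does not appear to be imported by any file in this diff; non-max suppression is handled by `cv2.dnn.NMSBoxes` in `utils/helpers.py`.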