# Streamlit app: YOLOv8 video object detection with per-object direction tracking.
import os
import tempfile
import time
import warnings

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image
from ultralytics import YOLO

warnings.filterwarnings('ignore')
def get_direction(old_center, new_center, min_movement=10):
    """Classify the motion between two (x, y) center points.

    Args:
        old_center: previous-frame center, or None if the object is new.
        new_center: current-frame center, or None.
        min_movement: pixel threshold below which (on both axes) the
            object is considered stationary.

    Returns:
        One of "left", "right", "up", "down", or "stationary".
        Ties (|dx| == |dy|) resolve to the vertical axis.
    """
    if old_center is None or new_center is None:
        return "stationary"
    delta_x = new_center[0] - old_center[0]
    delta_y = new_center[1] - old_center[1]
    # Both axes under the threshold -> no meaningful movement.
    if max(abs(delta_x), abs(delta_y)) < min_movement:
        return "stationary"
    # Dominant axis decides the label.
    if abs(delta_x) > abs(delta_y):
        return "right" if delta_x > 0 else "left"
    return "down" if delta_y > 0 else "up"
class ObjectTracker:
    """Greedy nearest-neighbour tracker keyed by YOLO class id.

    Track ids have the form "<class_id>_<n>". Each detection is matched
    to the nearest previous-frame track of the same class within 100
    pixels; unmatched detections start a new track. Only the most recent
    frame's tracks are kept for matching.
    """

    def __init__(self):
        # track_id -> {'center', 'direction', 'detection'} from the last frame.
        self.tracked_objects = {}
        # class_id -> all track ids ever created for that class (drives numbering).
        self.object_count = {}

    def update(self, detections):
        """Match this frame's detections against the previous frame.

        Args:
            detections: iterable of [x1, y1, x2, y2, conf, cls] lists.

        Returns:
            List of (detection, track_id, direction) tuples.
        """
        frame_tracks = {}
        output = []
        for det in detections:
            x1, y1, x2, y2 = det[0:4]
            cls_id = det[5]
            center = ((x1 + x2) // 2, (y1 + y2) // 2)
            # Candidate fresh id, used only if no existing track is close enough.
            track_id = f"{cls_id}_{len(self.object_count.get(cls_id, []))}"
            best_id = None
            best_dist = float('inf')
            for old_id, old_data in self.tracked_objects.items():
                # Ids are "<class>_<n>"; only match within the same class.
                if old_id.split('_')[0] != str(cls_id):
                    continue
                dist = np.hypot(center[0] - old_data['center'][0],
                                center[1] - old_data['center'][1])
                if dist < best_dist and dist < 100:
                    best_dist = dist
                    best_id = old_id
            if best_id:
                track_id = best_id
            else:
                self.object_count.setdefault(cls_id, []).append(track_id)
            prev_center = self.tracked_objects.get(track_id, {}).get('center', None)
            direction = get_direction(prev_center, center)
            frame_tracks[track_id] = {
                'center': center,
                'direction': direction,
                'detection': det,
            }
            output.append((det, track_id, direction))
        self.tracked_objects = frame_tracks
        return output
def main():
    """Streamlit entry point.

    Lets the user upload a video, runs YOLOv8 on every frame, draws
    direction-coloured bounding boxes, and streams the annotated frames
    plus live per-class detection counts back to the page.

    Fixes vs. original: the temp file is closed (flushed) before OpenCV
    opens it, the capture is released and the temp file deleted even if
    the loop raises.
    """
    st.title("Real-time Object Detection with Direction")
    uploaded_file = st.file_uploader("Choose a video file", type=['mp4', 'avi', 'mov'])
    start_detection = st.button("Start Detection")
    stop_detection = st.button("Stop Detection")

    if uploaded_file is not None and start_detection:
        # Session flag so the frame loop can be stopped across reruns.
        if 'running' not in st.session_state:
            st.session_state.running = True

        # Persist the upload to disk. Close immediately so the bytes are
        # flushed before cv2.VideoCapture opens the path — an open,
        # unflushed handle can make OpenCV read a truncated file
        # (and on Windows the file cannot be reopened while held open).
        tfile = tempfile.NamedTemporaryFile(delete=False)
        try:
            tfile.write(uploaded_file.read())
            tfile.close()

            with st.spinner('Loading model...'):
                model = YOLO('yolov8x.pt', verbose=False)
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                model.to(device)

            tracker = ObjectTracker()
            cap = cv2.VideoCapture(tfile.name)

            # BGR colours keyed by movement direction.
            direction_colors = {
                "left": (255, 0, 0),
                "right": (0, 255, 0),
                "up": (0, 255, 255),
                "down": (0, 0, 255),
                "stationary": (128, 128, 128),
            }

            frame_placeholder = st.empty()  # live annotated frame
            info_placeholder = st.empty()   # per-class counts
            st.success("Detection Started!")

            try:
                while cap.isOpened() and st.session_state.running:
                    success, frame = cap.read()
                    if not success:
                        break

                    # Run detection on the raw BGR frame.
                    results = model(frame,
                                    conf=0.25,
                                    iou=0.45,
                                    max_det=20,
                                    verbose=False)[0]
                    detections = []
                    for box in results.boxes.data:
                        x1, y1, x2, y2, conf, cls = box.tolist()
                        detections.append([int(x1), int(y1), int(x2), int(y2),
                                           float(conf), int(cls)])

                    tracked_objects = tracker.update(detections)

                    detection_counts = {}
                    for detection, obj_id, direction in tracked_objects:
                        x1, y1, x2, y2, conf, cls = detection
                        color = direction_colors.get(direction, (128, 128, 128))
                        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

                        label = f"{model.names[int(cls)]} {direction} {conf:.2f}"
                        font_scale = 1.2
                        thickness = 3
                        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX,
                                                    font_scale, thickness)[0]
                        # Filled background behind the label for readability.
                        padding_y = 15
                        cv2.rectangle(frame,
                                      (int(x1), int(y1) - text_size[1] - padding_y),
                                      (int(x1) + text_size[0], int(y1)),
                                      color, -1)
                        cv2.putText(frame, label,
                                    (int(x1), int(y1) - 5),
                                    cv2.FONT_HERSHEY_SIMPLEX,
                                    font_scale,
                                    (255, 255, 255),
                                    thickness)

                        class_name = model.names[int(cls)]
                        detection_counts[class_name] = detection_counts.get(class_name, 0) + 1

                    # OpenCV frames are BGR; Streamlit expects RGB.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame_placeholder.image(frame_rgb, channels="RGB", use_column_width=True)

                    info_text = "Detected Objects:\n"
                    for class_name, count in detection_counts.items():
                        info_text += f"{class_name}: {count}\n"
                    info_placeholder.text(info_text)

                    # NOTE(review): `stop_detection` is captured once per script
                    # run; pressing the button triggers a Streamlit rerun, and it
                    # is the rerun (not this in-loop check) that actually stops
                    # playback. Confirm the intended stop semantics.
                    if stop_detection:
                        st.session_state.running = False
                        break
            finally:
                cap.release()
        finally:
            # delete=False leaves the file behind; remove it ourselves.
            tfile.close()
            os.unlink(tfile.name)

        st.session_state.running = False
        st.warning("Detection Stopped")
    elif uploaded_file is None and start_detection:
        st.error("Please upload a video file first!")


if __name__ == "__main__":
    main()