Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import requests | |
| import json | |
| import base64 | |
| from PIL import Image | |
| import io | |
| import os | |
| from dotenv import load_dotenv | |
| from collections import defaultdict | |
| import time | |
# Load environment variables (python-dotenv) before any os.getenv lookups below
load_dotenv()
# Base URL of the detection API; the API_URL environment variable overrides
# the hard-coded fallback address.
API_URL = os.getenv("API_URL", "http://122.155.170.240:81")
print(f"Using API URL: {API_URL}")
# Default confidence passed to the processing functions; read from
# DEFAULT_CONFIDENCE_THRESHOLD and falling back to 0.25 when unset.
DEFAULT_CONFIDENCE = float(os.getenv("DEFAULT_CONFIDENCE_THRESHOLD", "0.25"))
def calculate_iou(box1, box2):
    """Return the Intersection over Union (IoU) of two bounding boxes.

    Both boxes are read as ``[x1, y1, x2, y2]`` corner coordinates.
    The result is ``0`` when the union area is not positive.
    """
    # Edges of the overlap rectangle; for disjoint boxes the "right" edge
    # ends up left of the "left" edge and the clamp below zeroes the area.
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    overlap_width = max(0, right - left)
    overlap_height = max(0, bottom - top)
    intersection = overlap_width * overlap_height

    first_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    second_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = first_area + second_area - intersection

    if union > 0:
        return intersection / union
    return 0
def calculate_bbox_similarity(bbox1, bbox2):
    """Score how alike two bounding boxes are.

    The score is ``0.7 * IoU`` plus ``0.3`` times a proximity term derived
    from the distance between the box centres, where the distance is scaled
    by the larger side of ``bbox1``. Returns ``0.0`` when a centre cannot be
    computed or when anything raises.
    """
    try:
        overlap = calculate_iou(bbox1, bbox2)

        first_center = get_box_center(bbox1)
        second_center = get_box_center(bbox2)
        if first_center is None or second_center is None:
            return 0.0

        offset_x = first_center[0] - second_center[0]
        offset_y = first_center[1] - second_center[1]
        distance = np.sqrt(offset_x**2 + offset_y**2)

        # Scale the centre distance by the larger side of the first box,
        # floored at 1 so the division is always defined.
        reference_size = max(bbox1[2] - bbox1[0], bbox1[3] - bbox1[1])
        normalized_distance = distance / max(reference_size, 1)

        proximity = max(0, 1 - normalized_distance * 0.3)
        return overlap * 0.7 + proximity * 0.3
    except Exception as e:
        return 0.0
def get_box_center(bbox):
    """Return the centre ``(x, y)`` of a 4-element bounding box.

    The box format is guessed from its values:

    * ``bbox[2] > bbox[0]`` and ``bbox[3] > bbox[1]``: corner format
      ``(x1, y1, x2, y2)``; the centre is the midpoint of the two corners.
    * otherwise: ``(x, y, w, h)``; the centre is the origin plus half the
      width/height.

    The previous test was inverted (``bbox[2] < bbox[0]``), which can never
    hold for a valid corner box, so corner boxes were always treated as
    ``(x, y, w, h)`` and got a wrong centre.

    NOTE(review): the guess is inherently ambiguous for ``(x, y, w, h)``
    boxes whose width and height both exceed their origin; confirm the
    format the detection API actually returns.

    Returns:
        A ``(x, y)`` tuple, or ``None`` when ``bbox`` does not have exactly
        four elements or cannot be processed.
    """
    try:
        if len(bbox) != 4:
            return None
        first, second, third, fourth = bbox[0], bbox[1], bbox[2], bbox[3]
        if third > first and fourth > second:
            # Corner format: midpoint between (x1, y1) and (x2, y2).
            return ((first + third) / 2, (second + fourth) / 2)
        # Origin + size format: shift the origin by half the size.
        return (first + third / 2, second + fourth / 2)
    except Exception:
        # Best effort: callers treat None as "no usable centre".
        return None
def calculate_movement(prev_center, curr_center, min_movement=10):
    """Tell whether a centre point moved more than ``min_movement``.

    Returns ``False`` when either point is ``None`` or when the distance
    cannot be computed; otherwise the result of comparing the Euclidean
    distance between the two points against ``min_movement``.
    """
    if prev_center is None or curr_center is None:
        return False
    try:
        delta_x = curr_center[0] - prev_center[0]
        delta_y = curr_center[1] - prev_center[1]
        return np.sqrt(delta_x * delta_x + delta_y * delta_y) > min_movement
    except Exception as e:
        return False
class TrackedObject:
    """One object followed across frames, with its red-zone bookkeeping."""

    def __init__(self, obj_id, obj_class, bbox):
        """Create the track and record ``bbox`` as its first detection."""
        self.id = obj_id
        self.class_name = obj_class
        self.trajectory = []  # centre points, oldest first
        self.bboxes = []  # bounding boxes, parallel to trajectory
        self.counted = False
        self.last_seen = 0  # frame number of the latest detection
        self.first_seen = 0  # frame number of the first detection
        self.frames_in_red_zone = 0  # consecutive frames spent in the red zone
        self.warning_triggered = False  # True once a warning was raised for this stay
        self.red_zone_entry_frame = None  # frame at which the current stay began
        self.similarity_scores = []  # similarity scores over time
        self.add_detection(bbox)

    def add_detection(self, bbox):
        """Append a detection; boxes without a computable centre are ignored."""
        try:
            center = get_box_center(bbox)
            if center is None:
                return
            self.trajectory.append(center)
            self.bboxes.append(bbox)
            # Bound the history: once past 50 entries keep only the last 25.
            if len(self.trajectory) > 50:
                self.trajectory = self.trajectory[-25:]
                self.bboxes = self.bboxes[-25:]
        except Exception as e:
            pass

    def has_movement(self, min_movement=10):
        """Return whether the last two centres are more than ``min_movement`` apart."""
        try:
            if len(self.trajectory) >= 2:
                previous, latest = self.trajectory[-2], self.trajectory[-1]
                return calculate_movement(previous, latest, min_movement)
            return False
        except Exception as e:
            return False

    def update_red_zone_status(self, is_in_red_zone, frame_number):
        """Update the red-zone counters for this frame.

        Returns ``True`` exactly once per stay: on the first frame where the
        object has been in the zone for more than 3 consecutive frames.
        Leaving the zone resets all counters.
        """
        if not is_in_red_zone:
            self.frames_in_red_zone = 0
            self.red_zone_entry_frame = None
            self.warning_triggered = False
            return False

        if self.red_zone_entry_frame is None:
            self.red_zone_entry_frame = frame_number
        self.frames_in_red_zone += 1

        if self.frames_in_red_zone > 3 and not self.warning_triggered:
            self.warning_triggered = True
            return True
        return False

    def get_similarity_with(self, other_bbox, similarity_threshold=0.5):
        """Return the similarity between the latest stored box and ``other_bbox``.

        ``similarity_threshold`` is accepted but not used by this method.
        Returns ``0.0`` when no box has been stored yet.
        """
        if not self.bboxes:
            return 0.0
        return calculate_bbox_similarity(self.bboxes[-1], other_bbox)
def _to_corner_box(box):
    """Return ``box`` as ``[x1, y1, x2, y2]``, converting from ``(x, y, w, h)`` if needed."""
    if box[2] > box[0] and box[3] > box[1]:
        # Third/fourth values lie beyond the first/second: already corners.
        return box
    return [box[0], box[1], box[0] + box[2], box[1] + box[3]]


def is_similar_object(obj1, obj2, similarity_threshold=0.6):
    """Check whether two detections look like the same object.

    Two detections are similar when they share the same ``'class'`` and the
    similarity of their ``'bbox'`` values exceeds ``similarity_threshold``.

    Boxes are normalised to corner format first. The previous format test
    was inverted (``box[2] < box[0]``), so valid corner boxes were treated
    as ``(x, y, w, h)`` and had their origin added to the far corner.

    NOTE(review): the format guess is ambiguous for ``(x, y, w, h)`` boxes
    whose size exceeds their origin; confirm the detector's box format.

    Args:
        obj1, obj2: mappings with ``'class'`` and 4-element ``'bbox'`` entries.
        similarity_threshold: minimum similarity (exclusive) to report a match.

    Returns:
        ``True`` for a match; ``False`` otherwise, including for malformed
        input or any error while comparing.
    """
    try:
        if obj1['class'] != obj2['class']:
            return False
        box1, box2 = obj1['bbox'], obj2['bbox']
        if len(box1) != 4 or len(box2) != 4:
            return False
        similarity = calculate_bbox_similarity(
            _to_corner_box(box1), _to_corner_box(box2)
        )
        return similarity > similarity_threshold
    except Exception:
        # Best effort: unusable detections are simply "not similar".
        return False
# Global state for protection area and previous detections
class State:
    """Mutable application state shared by the UI handlers and the tracker."""

    def __init__(self):
        """Start with no image, no protection area and empty tracking data."""
        # Protection area and rail-segment selection
        self.protection_points = []  # clicked polygon points
        self.detected_segments = []
        self.segment_image = None
        self.selected_segments = []
        self.previous_detections = None
        self.cached_protection_area = None
        # Image currently shown, plus its original and display sizes
        self.current_image = None
        self.original_dims = None
        self.display_dims = None
        # Tracking data starts out exactly as reset_tracking() leaves it
        self.reset_tracking()
        # Tunable matching parameters (not touched by reset_tracking)
        self.time_window = 10  # frames within which a track may still be matched
        self.similarity_threshold = 0.6

    def reset_tracking(self):
        """Reset all tracking data"""
        self.tracked_objects = {}  # object id -> tracked object
        self.next_obj_id = 0  # next id to hand out
        self.object_count = defaultdict(int)  # count by class
        self.frame_count = 0  # processed frames
        self.red_zone_passed_objects = defaultdict(int)  # red-zone passages by class
        self.red_zone_warnings = []  # stored warning records


state = State()
def image_to_bytes(image):
    """Serialize an image to PNG bytes for an API request.

    The image is saved as-is (no resizing); its dimensions are logged
    before and after encoding.
    """
    width, height = image.size
    print(f"Original image dimensions: {width}x{height}")

    buffer = io.BytesIO()
    image.save(buffer, format='PNG')

    print(f"Sending image with original dimensions: {width}x{height}")
    return buffer.getvalue()
def base64_to_image(base64_str):
    """Decode a base64-encoded image string into an OpenCV colour image."""
    encoded_pixels = np.frombuffer(base64.b64decode(base64_str), np.uint8)
    return cv2.imdecode(encoded_pixels, cv2.IMREAD_COLOR)
def opencv_to_pil(opencv_image):
    """Convert an OpenCV image to a PIL image, swapping BGR to RGB on the way."""
    return Image.fromarray(cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB))
def scale_point_to_original(x, y):
    """Map a display-space point back to original-image coordinates.

    Returns ``(x, y)`` unchanged while the original or display size is
    unknown; otherwise the scaled point truncated to integers.
    """
    original_size, display_size = state.original_dims, state.display_dims
    if original_size is None or display_size is None:
        return x, y

    orig_w, orig_h = original_size
    disp_w, disp_h = display_size
    factor_x = orig_w / disp_w
    factor_y = orig_h / disp_h
    return int(x * factor_x), int(y * factor_y)
def scale_points_to_display(points):
    """Map points from original-image coordinates to display coordinates.

    Returns ``points`` itself while the original or display size is unknown;
    otherwise a new list of ``[x, y]`` integer pairs.
    """
    if state.original_dims is None or state.display_dims is None:
        return points

    orig_w, orig_h = state.original_dims
    disp_w, disp_h = state.display_dims
    factor_x = disp_w / orig_w
    factor_y = disp_h / orig_h

    return [[int(pt[0] * factor_x), int(pt[1] * factor_y)] for pt in points]
def draw_protection_area(image):
    """Return a copy of ``image`` with the protection-area points drawn on it.

    Draws, in this order: the connecting lines (closed once there are exactly
    4 points), a numbered marker per point, and from 3 points on a
    semi-transparent fill of the polygon.
    """
    canvas = image.copy()
    points = state.protection_points
    if len(points) == 0:
        return canvas

    outline = np.array(points, dtype=np.int32)

    # Outline first so the point markers are painted on top of it
    if len(points) > 1:
        cv2.polylines(canvas, [outline], len(points) == 4, (0, 255, 0), 2)

    for label, point in enumerate(points, start=1):
        cv2.circle(canvas, tuple(point), 5, (0, 0, 255), -1)
        cv2.putText(canvas, str(label),
                    (point[0] + 10, point[1] + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

    # Blend a filled copy over the canvas (30% fill, 70% original)
    if len(points) >= 3:
        filled = canvas.copy()
        cv2.fillPoly(filled, [outline], (0, 255, 0))
        cv2.addWeighted(filled, 0.3, canvas, 0.7, 0, canvas)

    return canvas
def update_preview(video):
    """Show the first frame of ``video`` and reset the protection-area state.

    Every path returns the same two values, ``(image, choices_update)``.
    The ``video is None`` path used to return three values
    (``None, [], gr.update(visible=False)``), which did not match the other
    two return statements of this function.

    Args:
        video: path of the uploaded video, or ``None`` when nothing is loaded.

    Returns:
        ``(frame_rgb, update)`` where ``frame_rgb`` is the first frame
        converted to RGB (``None`` when no frame could be read) and
        ``update`` is a ``gr.update`` that clears and hides the choices.
    """
    if video is None:
        return None, gr.update(choices=[], value=[], visible=False)

    cap = cv2.VideoCapture(video)
    try:
        ret, frame = cap.read()
    finally:
        # Release the capture even if reading the frame raises.
        cap.release()

    if not ret:
        return None, gr.update(choices=[], value=[], visible=False)

    # Reset state tied to the previous video
    state.protection_points = []
    state.detected_segments = []
    state.segment_image = None
    state.selected_segments = []
    state.previous_detections = None
    state.cached_protection_area = None

    # Store the original frame and its dimensions as (width, height)
    state.current_image = frame.copy()
    state.original_dims = (frame.shape[1], frame.shape[0])
    state.display_dims = state.original_dims  # no resizing is applied

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return frame_rgb, gr.update(choices=[], value=[], visible=False)
def handle_image_click(evt: gr.SelectData, img):
    """Record a clicked point and redraw the protection area.

    Once 4 points exist, the next click starts a new selection. Returns the
    redrawn image (RGB) and a status text, or ``img`` and an error text when
    no image is loaded.
    """
    # A click after the 4th point starts over
    if len(state.protection_points) >= 4:
        state.protection_points = []

    if state.current_image is None:
        return img, "Error: No image loaded"

    # Click coordinates are used as-is (no scaling)
    click_x, click_y = evt.index[0], evt.index[1]
    state.protection_points.append([click_x, click_y])

    points = state.protection_points
    canvas = state.current_image.copy()

    # Numbered marker for every selected point
    for label, point in enumerate(points, start=1):
        cv2.circle(canvas, (point[0], point[1]), 5, (0, 0, 255), -1)
        cv2.putText(canvas, str(label),
                    (point[0] + 10, point[1] + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

    if len(points) > 1:
        outline = np.array(points, dtype=np.int32)
        # The outline is closed only once all 4 points are present
        cv2.polylines(canvas, [outline], len(points) == 4, (0, 255, 0), 2)
        # From 3 points on, blend a filled polygon over the canvas
        if len(points) >= 3:
            filled = canvas.copy()
            cv2.fillPoly(filled, [outline], (0, 255, 0))
            cv2.addWeighted(filled, 0.3, canvas, 0.7, 0, canvas)

    canvas_rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    return canvas_rgb, f"Selected {len(state.protection_points)} points\nCoordinates: {state.protection_points}"
def reset_points():
    """Clear the protection points and return the clean image with a status text."""
    state.protection_points = []
    if state.current_image is None:
        return None, "Points reset"
    # Show the untouched original frame, converted to RGB for display
    clean_rgb = cv2.cvtColor(state.current_image.copy(), cv2.COLOR_BGR2RGB)
    return clean_rgb, "Points reset"
def detect_rail_segments(image):
    """Ask the API for rail segments in ``image``.

    Returns ``(segments, annotated_image)`` on success, and ``([], None)``
    on a non-200 response, a response without ``"segments"``, or any error.
    """
    try:
        width, height = image.size
        print(f"Detecting rail segments on image with dimensions: {width}x{height}")

        payload = {"file": image_to_bytes(image)}
        response = requests.post(
            f"{API_URL}/detect/rail-segment",
            files=payload,
            timeout=60
        )

        if response.status_code != 200:
            print(f"API error: {response.status_code} - Image size was {width}x{height}")
            return [], None

        result = response.json()
        if "segments" not in result:
            return [], None
        return result["segments"], base64_to_image(result["image_base64"])
    except Exception as e:
        print(f"Error in detect_rail_segments: {str(e)}")
        return [], None
def extract_protection_area(first_frame):
    """Extract and cache protection area points using rail segment detection.

    Runs rail-segment detection on ``first_frame``. When segments come back
    they are stored on the global ``state`` together with the annotated
    image, every segment is selected by default, and the first segment's
    mask is cached as the protection polygon.

    Args:
        first_frame: image array; it is converted with ``COLOR_BGR2RGB``
            before upload, so presumably an OpenCV BGR frame - confirm
            against callers.

    Returns:
        ``(success, message, image)``. On success ``image`` is the annotated
        segment image already converted BGR->RGB; otherwise it is ``None``.
    """
    try:
        # Log original frame dimensions
        height, width = first_frame.shape[:2]
        print(f"Extracting protection area from frame with dimensions: {width}x{height}")
        # Convert frame to PIL Image without resizing
        first_frame_pil = Image.fromarray(cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB))
        # Verify PIL image dimensions
        pil_width, pil_height = first_frame_pil.size
        print(f"PIL Image dimensions before API call: {pil_width}x{pil_height}")
        # Detect rail segments
        segments, segment_img = detect_rail_segments(first_frame_pil)
        if segments and len(segments) > 0:
            # Verify segment image dimensions
            if segment_img is not None:
                seg_height, seg_width = segment_img.shape[:2]
                print(f"Received segment image dimensions: {seg_width}x{seg_height}")
                # Only resize if dimensions don't match
                if (seg_width, seg_height) != (width, height):
                    print(f"Resizing segment image from {seg_width}x{seg_height} to {width}x{height}")
                    segment_img = cv2.resize(segment_img, (width, height), interpolation=cv2.INTER_LANCZOS4)
            # Store segments and image
            state.detected_segments = segments
            state.segment_image = segment_img
            # Create segment choices with more detailed information
            segment_choices = []
            for i, segment in enumerate(segments):
                # Extract mask dimensions for verification (logging only)
                mask_points = segment.get('mask', [])
                if mask_points:
                    mask_x = [p[0] for p in mask_points]
                    mask_y = [p[1] for p in mask_points]
                    mask_width = max(mask_x) - min(mask_x)
                    mask_height = max(mask_y) - min(mask_y)
                    print(f"Segment {i+1} mask dimensions: {mask_width}x{mask_height}")
                # The 1-based number in this label is what get_segment_index parses back
                choice_text = f"Segment {i+1} (Confidence: {segment['confidence']:.2f})"
                segment_choices.append(choice_text)
            state.selected_segments = segment_choices  # Select all segments by default
            # Use the first segment's mask as protection area
            segment = segments[0]
            if 'mask' in segment and segment['mask']:
                mask_points = segment['mask']
                # Convert to list of [x,y] points and ensure integer values
                mask_points = [[int(float(x)), int(float(y))] for x, y in mask_points]
                if len(mask_points) >= 3:  # Need at least 3 points for a valid polygon
                    state.cached_protection_area = mask_points
                    # Convert segment image to RGB for display without resizing
                    if segment_img is not None:
                        display_img = cv2.cvtColor(segment_img, cv2.COLOR_BGR2RGB)
                        return True, "Protection area extracted successfully", display_img
            # NOTE(review): also reached when the mask is valid but no segment
            # image was returned, in which case the message is misleading.
            return False, "Invalid mask points in segment", None
        return False, "No valid rail segments detected", None
    except Exception as e:
        print(f"Error in extract_protection_area: {str(e)}")
        return False, f"Error extracting protection area: {str(e)}", None
def get_segment_index(choice_text):
    """Extract the zero-based segment index from a choice label.

    Labels look like ``"Segment X (Confidence: Y)"``; the second
    whitespace-separated token is the 1-based segment number.

    Only the errors that malformed input can produce are caught. The former
    bare ``except:`` also swallowed ``KeyboardInterrupt`` and ``SystemExit``.

    Returns:
        The zero-based index, or ``-1`` when ``choice_text`` is not a
        string in the expected format.
    """
    try:
        return int(choice_text.split()[1]) - 1
    except (AttributeError, IndexError, TypeError, ValueError):
        return -1
def update_object_tracking(objects_in_area):
    """Update object tracking with new detections.

    For every detection the best-matching existing track of the same class
    (seen within ``state.time_window`` frames, similarity above
    ``state.similarity_threshold``) is updated; detections without a match
    start a new track. Tracks not seen in this frame have their red-zone
    counters reset, and tracks unseen for more than twice the time window
    are dropped. All results are written to the global ``state``.

    Args:
        objects_in_area: iterable of detection dicts; entries lacking
            ``'bbox'`` or ``'class'`` are skipped.

    Returns:
        None. Errors are printed (outer handler) or skipped per detection
        (inner handler).
    """
    try:
        current_tracked = set()  # Keep track of objects seen in this frame
        current_warnings = []  # Collect warnings for this frame
        # Match new detections with existing tracked objects
        for obj in objects_in_area:
            try:
                if 'bbox' not in obj or 'class' not in obj:
                    continue
                bbox = obj['bbox']
                obj_class = obj['class']
                is_in_red_zone = obj.get('in_protection_area', False)
                matched = False
                best_match_id = None
                best_similarity = 0.0
                # Try to match with existing tracked objects using similarity method
                # NOTE(review): tracks already matched in this frame are not
                # excluded, so two detections can update the same track - confirm
                # whether that is intended.
                for obj_id, tracked in state.tracked_objects.items():
                    if tracked.class_name == obj_class:
                        # Check if object was seen recently (within time window)
                        if state.frame_count - tracked.last_seen <= state.time_window:
                            similarity = tracked.get_similarity_with(bbox)
                            # Use the best match above threshold
                            if similarity > state.similarity_threshold and similarity > best_similarity:
                                best_similarity = similarity
                                best_match_id = obj_id
                # If good match found, update existing object
                if best_match_id is not None:
                    tracked = state.tracked_objects[best_match_id]
                    tracked.add_detection(bbox)
                    tracked.last_seen = state.frame_count
                    current_tracked.add(best_match_id)
                    matched = True
                    # Check red zone status and warnings
                    warning_triggered = tracked.update_red_zone_status(is_in_red_zone, state.frame_count)
                    if warning_triggered:
                        warning_msg = f"β οΈ WARNING: {tracked.class_name} (ID: {tracked.id}) has been in red zone for {tracked.frames_in_red_zone} frames!"
                        current_warnings.append(warning_msg)
                        state.red_zone_warnings.append({
                            'frame': state.frame_count,
                            'object_id': tracked.id,
                            'class': tracked.class_name,
                            'frames_in_zone': tracked.frames_in_red_zone,
                            'message': warning_msg
                        })
                    # Check if object should be counted (only count objects that actually move through the zone)
                    if not tracked.counted and tracked.has_movement() and is_in_red_zone:
                        # Additional check: object should have been tracked for at least a few frames
                        if len(tracked.trajectory) >= 3:
                            tracked.counted = True
                            state.red_zone_passed_objects[obj_class] += 1
                # If no match found, create new tracked object
                if not matched:
                    new_obj = TrackedObject(state.next_obj_id, obj_class, bbox)
                    new_obj.last_seen = state.frame_count
                    new_obj.first_seen = state.frame_count
                    state.tracked_objects[state.next_obj_id] = new_obj
                    current_tracked.add(state.next_obj_id)
                    state.next_obj_id += 1
                    # Check red zone status for new object
                    new_obj.update_red_zone_status(is_in_red_zone, state.frame_count)
            except Exception as e:
                # A failing detection is skipped silently; the rest of the frame is still processed
                continue
        # Update objects not seen in current frame
        for obj_id, tracked in state.tracked_objects.items():
            if obj_id not in current_tracked:
                # Object not seen in current frame, update red zone status
                tracked.update_red_zone_status(False, state.frame_count)
        # Remove objects that haven't been seen for a while
        if state.frame_count > state.time_window:
            # Collect ids first so the dict is not modified while iterating it
            to_remove = []
            for obj_id, tracked in state.tracked_objects.items():
                if state.frame_count - tracked.last_seen > state.time_window * 2:  # Remove after 2x time window
                    to_remove.append(obj_id)
            for obj_id in to_remove:
                del state.tracked_objects[obj_id]
        # Log the warnings raised in this frame
        if current_warnings:
            print(f"Frame {state.frame_count} Warnings: {current_warnings}")
    except Exception as e:
        print(f"Error in update_object_tracking: {str(e)}")
def get_red_zone_summary():
    """Build the text summary of red-zone activity.

    Up to three sections are included: passage counts per class, objects
    currently in the zone, and the last three warnings raised within the
    past 5 frames. A fixed placeholder is returned when all are empty.
    """
    lines = []

    passed = state.red_zone_passed_objects
    if passed:
        lines.append("π΄ RED ZONE PASSAGE SUMMARY:")
        lines.append(f"Total objects passed: {sum(passed.values())}")
        lines.extend(f" β’ {obj_class}: {count}" for obj_class, count in sorted(passed.items()))

    # Objects whose consecutive in-zone counter is running
    in_zone_now = [
        f"{tracked.class_name} (ID: {tracked.id}, {tracked.frames_in_red_zone} frames)"
        for tracked in state.tracked_objects.values()
        if tracked.frames_in_red_zone > 0
    ]
    if in_zone_now:
        lines.append("\nπ¨ CURRENTLY IN RED ZONE:")
        lines.extend(f" β’ {entry}" for entry in in_zone_now)

    # Warnings raised within the last 5 frames; only the last 3 are shown
    recent = [w for w in state.red_zone_warnings if state.frame_count - w['frame'] <= 5]
    if recent:
        lines.append("\nβ οΈ RECENT WARNINGS:")
        lines.extend(f" β’ Frame {w['frame']}: {w['message']}" for w in recent[-3:])

    if lines:
        return "\n".join(lines)
    return "No objects detected in red zone yet."
def process_frame(frame, confidence):
    """Send one video frame to the detection API and update tracking.

    The protection area is the mask of the first selected segment that has
    one, or, when no segments are selected/detected, the manually clicked
    points (at least 3). ``state.cached_protection_area`` is not read here.

    Args:
        frame: image array to encode as PNG and upload.
        confidence: confidence threshold, sent to the API as a string.

    Returns:
        ``(image_rgb, status_text)`` on success, or ``(None, message)`` for
        any failure (missing protection area, invalid frame, API or decoding
        error). This function does not raise; every error becomes a message.
    """
    try:
        protection_area = []
        if state.selected_segments and state.detected_segments:
            for choice in state.selected_segments:
                idx = get_segment_index(choice)
                if 0 <= idx < len(state.detected_segments):
                    segment = state.detected_segments[idx]
                    if 'mask' in segment and segment['mask']:
                        protection_area = segment['mask']
                        break
        # NOTE(review): manual points are only considered when no segments are
        # selected/detected; selected segments without a mask do not fall back.
        elif len(state.protection_points) >= 3:
            protection_area = state.protection_points
        if not protection_area:
            return None, "Protection area not set. Please extract protection area first."
        # Ensure frame is valid
        if frame is None or frame.size == 0:
            return None, "Invalid frame"
        success, buffer = cv2.imencode('.png', frame)
        if not success:
            return None, "Failed to encode frame"
        files = {
            "file": ("frame.png", buffer.tobytes(), "image/png")
        }
        protection_area_json = json.dumps(protection_area)
        data = {
            "protection_area": protection_area_json,
            "confidence_threshold": str(confidence)
        }
        # Forward the previous frame's in-area detections, when there are any
        if state.previous_detections:
            data["previous_detections"] = json.dumps(state.previous_detections)
        try:
            response = requests.post(
                f"{API_URL}/detect/objects-and-redlight",
                files=files,
                data=data,
                timeout=60
            )
            if response.status_code == 200:
                result = response.json()
                if not result.get("success"):
                    return None, f"API Error: {result.get('detail', 'Unknown error')}"
                result_data = result.get("result", {})
                if not result_data:
                    return None, "No result data received"
                red_light_info = result_data.get("red_light", {})
                red_light_detected = red_light_info.get("detected", False)
                red_light_prob = red_light_info.get("probability", 0)
                img_base64 = result_data.get("image_base64")
                if not img_base64:
                    return None, "No image data received from API"
                try:
                    # Keep only the part after the first comma (data-URI style prefix)
                    if ',' in img_base64:
                        img_base64 = img_base64.split(',')[1]
                    img_data = base64.b64decode(img_base64)
                    nparr = np.frombuffer(img_data, np.uint8)
                    processed_img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                    if processed_img is None or processed_img.size == 0:
                        return None, "Failed to decode image from API response"
                    # Only detections flagged as inside the protection area are tracked
                    objects_in_area = [obj for obj in result_data.get("objects", [])
                                       if obj.get("in_protection_area", False) and
                                       'bbox' in obj and 'class' in obj]
                    # Update object tracking
                    state.frame_count += 1
                    update_object_tracking(objects_in_area)
                    # Cache detections for next frame
                    state.previous_detections = objects_in_area
                    processed_img_rgb = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)
                    status = []
                    status.append(f"Red Light: {'YES' if red_light_detected else 'NO'} ({red_light_prob:.2f})")
                    # Add enhanced red zone summary
                    red_zone_summary = get_red_zone_summary()
                    status.append(f"\n{red_zone_summary}")
                    if objects_in_area:
                        status.append("\nπ CURRENT FRAME DETECTIONS:")
                        for obj in objects_in_area:
                            status.append(f" β’ {obj['class']} (confidence: {obj['confidence']:.2f})")
                    # Add tracking statistics ("active" = seen within the last 3 frames)
                    active_objects = len([obj for obj in state.tracked_objects.values()
                                          if state.frame_count - obj.last_seen <= 3])
                    status.append(f"\nπ TRACKING STATS:")
                    status.append(f" β’ Active tracked objects: {active_objects}")
                    status.append(f" β’ Frame: {state.frame_count}")
                    status.append(f" β’ Time window: {state.time_window} frames")
                    status.append(f" β’ Similarity threshold: {state.similarity_threshold:.2f}")
                    return processed_img_rgb, "\n".join(status)
                except Exception as e:
                    return None, f"Error processing detection results: {str(e)}"
            else:
                # Non-200: report the status code plus the API's detail or the raw body
                error_detail = f"API Error: {response.status_code}"
                try:
                    error_json = response.json()
                    if 'detail' in error_json:
                        error_detail += f" - {error_json['detail']}"
                # NOTE(review): bare except also catches KeyboardInterrupt/SystemExit
                except:
                    error_detail += f" - {response.text}"
                return None, error_detail
        except requests.exceptions.Timeout:
            return None, "API request timed out"
        except requests.exceptions.ConnectionError:
            return None, "Could not connect to API server"
        except Exception as e:
            return None, f"API request failed: {str(e)}"
    except Exception as e:
        return None, f"Error processing frame: {str(e)}"
def process_video(video, confidence=DEFAULT_CONFIDENCE, target_fps=1):
    """Stream processed frames of ``video`` as a generator.

    Roughly ``target_fps`` frames per second of video are sent through
    ``process_frame``; the other frames are skipped. After the last frame a
    final summary is yielded.

    Changes from the previous version:
    * the frame-interval computation now runs inside the ``try`` and is
      guarded, so ``target_fps == 0`` (or a non-positive/NaN reported FPS)
      processes every frame and no longer raises before the capture can be
      released;
    * the capture is released on every path, including when the file could
      not be opened;
    * the unused ``detection_results`` local was removed.

    Args:
        video: path of the video file.
        confidence: confidence threshold forwarded to ``process_frame``.
        target_fps: desired number of processed frames per second of video.

    Yields:
        ``(frame_rgb_or_None, status_text)`` tuples.
    """
    cap = cv2.VideoCapture(video)
    try:
        if not cap.isOpened():
            yield None, "Error: Could not open video file"
            return

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Process every N-th frame; fall back to every frame when the rate is unusable
        if fps > 0 and target_fps > 0:
            frame_interval = max(1, int(fps / target_fps))
        else:
            frame_interval = 1

        frame_number = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_number += 1
            if frame_number % frame_interval != 0:
                continue

            # process_frame already returns the image in RGB
            processed_frame, result = process_frame(frame, confidence)
            if processed_frame is not None:
                current_status = f"Processing frame {frame_number}/{total_frames}\n{result}"
                yield processed_frame, current_status
            else:
                current_status = f"Frame {frame_number}: {result}"
                yield None, current_status

        # Release early: the generator may stay suspended on the final yield.
        # The ``finally`` below is a safety net for every other path.
        cap.release()

        yield None, generate_final_summary()
    except Exception as e:
        yield None, f"Error processing video: {str(e)}"
    finally:
        cap.release()
def generate_final_summary():
    """Build the end-of-video report from the global tracking state.

    Sections: processing statistics, red-zone passages per class, warnings
    (totals per class plus the last five), and the tracked objects per
    class. The report is returned as one newline-joined string.
    """
    report = [
        "π¬ VIDEO PROCESSING COMPLETE",
        "=" * 50,
        "π PROCESSING STATISTICS:",
        f" β’ Total frames processed: {state.frame_count}",
        f" β’ Time window used: {state.time_window} frames",
        f" β’ Similarity threshold: {state.similarity_threshold:.2f}",
    ]

    # Red zone passages (the header is the same whether or not any were counted)
    report.append("\nπ΄ RED ZONE PASSAGE SUMMARY:")
    passed = state.red_zone_passed_objects
    if passed:
        report.append(f" β’ Total objects passed through red zone: {sum(passed.values())}")
        report.extend(f" - {obj_class}: {count}" for obj_class, count in sorted(passed.items()))
    else:
        report.append(" β’ No objects detected passing through red zone")

    # Warnings
    report.append("\nβ οΈ WARNING SUMMARY:")
    warnings_log = state.red_zone_warnings
    if warnings_log:
        report.append(f" β’ Total warnings generated: {len(warnings_log)}")
        per_class = defaultdict(int)
        for entry in warnings_log:
            per_class[entry['class']] += 1
        report.extend(f" - {obj_class}: {count} warnings" for obj_class, count in sorted(per_class.items()))
        # Last five warnings
        report.append("\n π Recent warnings:")
        report.extend(
            f" - Frame {entry['frame']}: {entry['class']} (ID: {entry['object_id']}) - {entry['frames_in_zone']} frames in zone"
            for entry in warnings_log[-5:]
        )
    else:
        report.append(" β’ No warnings generated (no objects stayed in red zone > 3 frames)")

    # Objects still held by the tracker
    tracked_total = len(state.tracked_objects)
    if tracked_total > 0:
        report.append("\nπ OBJECT TRACKING SUMMARY:")
        report.append(f" β’ Total unique objects tracked: {tracked_total}")
        tracked_per_class = defaultdict(int)
        for tracked in state.tracked_objects.values():
            tracked_per_class[tracked.class_name] += 1
        report.extend(f" - {obj_class}: {count}" for obj_class, count in sorted(tracked_per_class.items()))

    report.append("\nβ Processing completed successfully!")
    return "\n".join(report)
def extract_area_from_video(video):
    """Detect rail segments on the first frame of ``video``.

    ``extract_protection_area`` already returns its image converted to RGB.
    The previous version applied ``COLOR_BGR2RGB`` a second time here, which
    swapped the red and blue channels back and displayed wrong colours. The
    image is now passed through unchanged.

    Args:
        video: path of the uploaded video, or ``None``.

    Returns:
        ``(image, message, choices_update)``: the annotated RGB image (or
        ``None`` on failure), a status message, and a ``gr.update`` for the
        segment choices (all segments selected on success, hidden otherwise).
    """
    if video is None:
        return None, "Please upload a video", gr.update(choices=[], value=[], visible=False)

    cap = cv2.VideoCapture(video)
    try:
        ret, frame = cap.read()
    finally:
        # Release the capture even if reading the frame raises.
        cap.release()

    if not ret:
        return None, "Could not read video frame", gr.update(choices=[], value=[], visible=False)

    success, message, segment_img = extract_protection_area(frame)
    if not success or segment_img is None:
        return None, message, gr.update(choices=[], value=[], visible=False)

    # Labels must stay in the "Segment N (...)" form parsed by get_segment_index
    segment_choices = [f"Segment {i+1} (Confidence: {segment['confidence']:.2f})"
                       for i, segment in enumerate(state.detected_segments)]
    # segment_img is already RGB; do not convert it again
    return segment_img, message, gr.update(choices=segment_choices, value=segment_choices, visible=True)
def update_selected_segments(selected):
    """Record the dropdown's current selection on the shared ``state``.

    A ``None`` selection is stored as an empty list. Returns an argument-less
    ``gr.update()`` for the dropdown output.
    """
    state.selected_segments = [] if selected is None else selected
    return gr.update()
def _pick_protection_area():
    """Return the configured protection area from ``state``, or ``[]`` if none.

    When segments are both selected and detected, the first selected segment
    whose index is valid and whose ``'mask'`` entry is non-empty wins.
    Otherwise the manually collected points are used, provided there are at
    least three of them.
    """
    if state.selected_segments and state.detected_segments:
        for choice in state.selected_segments:
            idx = get_segment_index(choice)
            if not (0 <= idx < len(state.detected_segments)):
                continue
            segment = state.detected_segments[idx]
            if 'mask' in segment and segment['mask']:
                return segment['mask']
        # Segments were selected but none carried a usable mask.
        return []
    if len(state.protection_points) >= 3:
        return state.protection_points
    return []


def process_video_wrapper(video, confidence=DEFAULT_CONFIDENCE, target_fps=1, time_window=10, similarity_threshold=0.6):
    """Validate inputs, reset tracking state, then stream ``process_video`` results.

    This is a generator: every item it yields is a ``(frame, status)`` pair.
    Validation failures and exceptions raised during processing are reported
    as a single ``(None, message)`` pair instead of propagating.
    """
    if video is None:
        yield None, "Please upload a video"
        return

    # Start from a clean tracker and apply the tracking parameters for this run.
    state.reset_tracking()
    state.time_window = time_window
    state.similarity_threshold = similarity_threshold

    # The area is only checked for presence here; it is not passed on to
    # process_video.
    if not _pick_protection_area():
        yield None, "Please extract protection area first"
        return

    try:
        start_message = f"π Starting video processing...\nβοΈ Time window: {time_window} frames\nβοΈ Similarity threshold: {similarity_threshold:.2f}"
        yield None, start_message
        for update in process_video(video, confidence, target_fps):
            frame, status = update
            yield frame, status
    except Exception as exc:
        yield None, f"Error processing video: {str(exc)}"
# Build the Gradio interface: input controls and the clickable preview first,
# then the process button, then the live output image and the results text box.
# NOTE(review): the nesting of the rows/columns below was reconstructed from a
# copy of this file that had lost its indentation -- confirm it against the
# intended layout (in particular where the process button sits).
with gr.Blocks(title="Enhanced Rail Traffic Monitor") as demo:
    # Static help text shown at the top of the page.
    gr.Markdown("""
    # Enhanced Rail Traffic Monitoring System
    ## Features:
    - **Smart Object Tracking**: Uses similarity method to track objects across frames
    - **Red Zone Monitoring**: Counts objects passing through the red zone
    - **Warning System**: Alerts when objects stay in red zone for more than 3 frames
    - **Configurable Parameters**: Adjust time window and similarity threshold
    ## Setup Instructions:
    **Method 1 (Manual Protection Area):**
    1. Click 4 points on the image to define protection area
    2. Click "Reset Points" to start over
    **Method 2 (Automatic Detection):**
    1. Click "Extract Protection Area" to automatically detect rail segments
    **Processing:**
    3. Adjust detection confidence, processing frame rate, time window, and similarity threshold
    4. Click "Process Video" to analyze
    The system will show real-time results including:
    - Objects currently in red zone
    - Total count of objects that passed through
    - Warnings for objects staying too long in red zone
    - Tracking statistics
    """)
    with gr.Row():
        # Video upload plus the four tuning sliders.
        with gr.Column():
            video_input = gr.Video(
                label="Input Video"
            )
            with gr.Row():
                # Initial value comes from the module-level DEFAULT_CONFIDENCE.
                confidence = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_CONFIDENCE,
                    label="Detection Confidence Threshold",
                    info="Minimum confidence for object detection"
                )
                fps_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=1,
                    step=1,
                    label="Processing Frame Rate (FPS)",
                    info="Frames per second to process"
                )
            with gr.Row():
                # These two defaults (10 and 0.6) match the defaults of
                # process_video_wrapper's time_window / similarity_threshold.
                time_window_slider = gr.Slider(
                    minimum=5,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Time Window (frames)",
                    info="Number of frames to consider for object similarity"
                )
                similarity_threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.6,
                    step=0.05,
                    label="Similarity Threshold",
                    info="Threshold for considering objects as the same (higher = stricter)"
                )
        # Clickable preview image and the protection-area controls.
        with gr.Column():
            preview_image = gr.Image(
                label="Click to Select Protection Area (Original Size)",
                interactive=True,
                show_label=True
            )
            # Multi-select list of detected segments; created hidden and empty.
            # extract_area_from_video returns the update that fills and shows it.
            segment_dropdown = gr.Dropdown(
                label="Selected Segments",
                choices=[],
                multiselect=True,
                interactive=True,
                visible=False,
                value=[]
            )
            with gr.Row():
                reset_btn = gr.Button("Reset Points")
                extract_btn = gr.Button("Extract Protection Area")
    process_btn = gr.Button("π Process Video")
    with gr.Row():
        # Receives the frame half of each (frame, status) pair yielded by
        # process_video_wrapper.
        video_output = gr.Image(
            label="Live Processing Output",
            streaming=True,
            interactive=False,
            show_label=True,
            container=True,
            show_download_button=True
        )
        # Receives status/summary text; also reused for messages from the
        # extract, click and reset handlers below.
        text_output = gr.Textbox(
            label="Detection Results & Red Zone Summary",
            lines=15,
            max_lines=20,
            show_copy_button=True
        )
    # When the uploaded video changes, update_preview supplies new values for
    # the preview image and the segment dropdown.
    video_input.change(
        fn=update_preview,
        inputs=[video_input],
        outputs=[preview_image, segment_dropdown]
    )
    # extract_area_from_video returns (image, message, dropdown update), one
    # value per output listed here.
    extract_btn.click(
        fn=extract_area_from_video,
        inputs=[video_input],
        outputs=[preview_image, text_output, segment_dropdown]
    )
    # Keep state.selected_segments in sync with the dropdown; the handler
    # returns a bare gr.update() for the dropdown itself.
    segment_dropdown.change(
        fn=update_selected_segments,
        inputs=[segment_dropdown],
        outputs=[segment_dropdown]
    )
    # Input order follows process_video_wrapper's positional parameters:
    # video, confidence, target_fps, time_window, similarity_threshold.
    # The handler is a generator that yields (frame, status) pairs.
    process_btn.click(
        fn=process_video_wrapper,
        inputs=[video_input, confidence, fps_slider, time_window_slider, similarity_threshold_slider],
        outputs=[video_output, text_output]
    )
    # Clicking on the preview image is routed to handle_image_click, which
    # returns a new preview image and a status message.
    preview_image.select(
        fn=handle_image_click,
        inputs=[preview_image],
        outputs=[preview_image, text_output]
    )
    # The reset button calls reset_points with no inputs and writes its result
    # to the preview image and the text box.
    reset_btn.click(
        fn=reset_points,
        inputs=[],
        outputs=[preview_image, text_output]
    )
# Launch the app only when run as a script; the queue is enabled first.
if __name__ == "__main__":
    demo.queue().launch()