import os
import time

import cv2
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Polygon, box as shapely_box
import gradio as gr
from PIL import Image
import spaces


# Utility functions
def extract_class_0_coordinates(filename):
    """Collect the flat [x1, y1, x2, y2, ...] polygon coordinates of every class-0 (rail) line."""
    class_0_coordinates = []
    if not os.path.exists(filename):  # YOLO writes no label file when nothing is detected
        return class_0_coordinates
    with open(filename, 'r') as file:
        for line in file:
            parts = line.strip().split()
            if len(parts) == 0:
                continue
            if parts[0] == '0':
                coordinates = [float(x) for x in parts[1:]]
                class_0_coordinates.extend(coordinates)
    return class_0_coordinates


def read_yolo_boxes(file_path):
    """Read YOLO-format detections (class x_center y_center width height, normalized)."""
    boxes = []
    if not os.path.exists(file_path):  # no label file means no detections for that frame
        return boxes
    with open(file_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            class_name = COCO_CLASSES[int(parts[0])]
            x, y, w, h = map(float, parts[1:5])
            boxes.append((class_name, x, y, w, h))
    return boxes


def yolo_to_pixel_coords(x_center, y_center, width, height, img_width, img_height):
    """Convert a normalized YOLO box to pixel corner coordinates (x1, y1, x2, y2)."""
    x1 = int((x_center - width / 2) * img_width)
    y1 = int((y_center - height / 2) * img_height)
    x2 = int((x_center + width / 2) * img_width)
    y2 = int((y_center + height / 2) * img_height)
    return x1, y1, x2, y2


def convert_segment_to_pixel(segment, img_width, img_height):
    """Convert a flat list of normalized [x, y, x, y, ...] values to pixel (x, y) tuples."""
    return [(int(x * img_width), int(y * img_height)) for x, y in zip(segment[::2], segment[1::2])]


def box_segment_relationship(yolo_box, segment, img_width, img_height, threshold):
    """Classify a box as intersecting, obstructed (within threshold pixels), or not touching the rail polygon."""
    class_id, x_center, y_center, width, height = yolo_box
    x1, y1, x2, y2 = yolo_to_pixel_coords(x_center, y_center, width, height, img_width, img_height)
    pixel_segment = convert_segment_to_pixel(segment, img_width, img_height)
    if len(pixel_segment) < 3:  # a valid polygon needs at least three points
        return "not touching"
    segment_polygon = Polygon(pixel_segment)
    box_polygon = shapely_box(x1, y1, x2, y2)
    if box_polygon.intersects(segment_polygon):
        return "intersecting"
    elif box_polygon.distance(segment_polygon) <= threshold:
        return "obstructed"
    else:
        return "not touching"


# COCO classes
COCO_CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
    'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
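
# Illustrative check of the coordinate helpers above (not part of the app; the
# frame size, box, and triangle below are made-up example values). A box centred
# at (0.5, 0.5) covering 20% of a 640x480 frame maps to the pixel corners
# (256, 192)-(384, 288), which touch the example rail polygon.
def _geometry_example():
    img_w, img_h = 640, 480
    print(yolo_to_pixel_coords(0.5, 0.5, 0.2, 0.2, img_w, img_h))  # (256, 192, 384, 288)
    rail = [0.4, 0.4, 0.6, 0.4, 0.5, 0.6]  # flat normalized [x, y, ...] triangle
    print(box_segment_relationship((0, 0.5, 0.5, 0.2, 0.2), rail, img_w, img_h, threshold=10))  # "intersecting"
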
# Detection functions
@spaces.GPU  # request a GPU when running on Hugging Face Spaces
def detect_rail(image):
    # Ensure we have a NumPy array (Gradio supplies one; a PIL image also works)
    image = np.array(image)

    # Check if the image is RGB (3 channels)
    if len(image.shape) == 3 and image.shape[2] == 3:
        # Convert RGB to BGR (OpenCV format)
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        # If not RGB, use the image as is (assuming OpenCV can already handle it)
        image_bgr = image

    temp_image_path = "temp_image_rail.jpg"
    cv2.imwrite(temp_image_path, image_bgr)

    os.system(f"python segment/predict.py --source {temp_image_path} --img 640 --device cpu "
              f"--weights models/segment/best-2.pt --name yolov9_c_640_detect --exist-ok --save-txt")

    label_path = 'runs/predict-seg/yolov9_c_640_detect/labels/temp_image_rail.txt'
    segment = extract_class_0_coordinates(label_path)

    if not segment:  # nothing detected: avoid indexing an empty polygon below
        os.remove(temp_image_path)
        return None, None, "No rail was detected in the image. Please try another image."

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)  # Use the original image for display
    img_height, img_width = image.shape[:2]
    pixel_segment = convert_segment_to_pixel(segment, img_width, img_height)
    ax.plot([x for x, _ in pixel_segment] + [pixel_segment[0][0]],
            [y for _, y in pixel_segment] + [pixel_segment[0][1]],
            'g-', linewidth=2, label='Rail Zone')
    ax.legend()
    ax.axis('off')
    plt.tight_layout()

    os.remove(temp_image_path)
    if os.path.exists(label_path):
        os.remove(label_path)

    return fig, segment, "Rail detection completed. You can now upload a video for object detection."


def create_sample_video(output_path, duration=10, fps=30, width=640, height=480):
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    for _ in range(duration * fps):
        frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
        out.write(frame)
    out.release()
    return output_path


def process_video(video_path, rail_segment, frame_skip=15):
    if not os.path.exists(video_path):
        return None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, "Error: Could not open video file."

    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Create output directory if it doesn't exist
    output_dir = 'output_videos'
    os.makedirs(output_dir, exist_ok=True)

    # Generate a unique filename based on timestamp
    timestamp = int(time.time())
    output_filename = f'processed_video_{timestamp}.mp4'
    output_path = os.path.join(output_dir, output_filename)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Guard against a zero frame rate when frame_skip exceeds the source fps
    out = cv2.VideoWriter(output_path, fourcc, max(1, fps // frame_skip), (width, height))

    frame_count = 0
    processed_count = 0
    threshold = 10  # Threshold (in pixels) for obstruction detection
    obstructed_frames = 0
    all_detections = []

    # First pass: detect objects in the sampled frames
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        if frame_count % frame_skip != 0:
            continue
        processed_count += 1

        # Save the frame as a temporary image
        temp_frame_path = f"temp_frame_{processed_count:04d}.jpg"
        cv2.imwrite(temp_frame_path, frame)

        # Run object detection on the frame
        os.system(f"python detect.py --source {temp_frame_path} --img 640 --device cpu "
                  f"--weights models/detect/yolov9-s-converted.pt --name yolov9_c_640_detect --exist-ok --save-txt")

        # Read detection results
        label_path = f'runs/detect/yolov9_c_640_detect/labels/temp_frame_{processed_count:04d}.txt'
        yolo_boxes = read_yolo_boxes(label_path)
        all_detections.append(yolo_boxes)

        os.remove(temp_frame_path)
        if os.path.exists(label_path):
            os.remove(label_path)

        print(f"Processed frame {frame_count}/{total_frames} (Frame {processed_count})")

    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    frame_count = 0
    processed_count = 0

    # Second pass: check for obstructions and create the output video
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        if frame_count % frame_skip != 0:
            continue
        processed_count += 1

        # Draw the rail segment (polylines requires int32 points)
        pixel_segment = convert_segment_to_pixel(rail_segment, width, height)
        cv2.polylines(frame, [np.array(pixel_segment, dtype=np.int32)], True, (0, 255, 0), 2)

        # Check for obstructions and draw bounding boxes
        frame_obstructed = False
        for box in all_detections[processed_count - 1]:
            class_name, x, y, w, h = box
            relationship = box_segment_relationship((0, x, y, w, h), rail_segment, width, height, threshold)
            x1, y1, x2, y2 = yolo_to_pixel_coords(x, y, w, h, width, height)

            if relationship == "intersecting":
                color = (0, 0, 255)  # Red for intersecting
                frame_obstructed = True
            elif relationship == "obstructed":
                color = (0, 255, 255)  # Yellow for obstructed
                frame_obstructed = True
            else:
                color = (255, 0, 0)  # Blue for not touching

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, f"{class_name} ({relationship})", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

        if frame_obstructed:
            obstructed_frames += 1

        out.write(frame)
        print(f"Processed frame {frame_count}/{total_frames} (Frame {processed_count})")

    cap.release()
    out.release()

    if processed_count == 0:
        return None, "Error: No frames were processed."

    obstruction_percentage = (obstructed_frames / processed_count) * 100
    summary = f"Video processing completed. Processed {processed_count} out of {total_frames} frames.\n\n"
    summary += "Obstruction Summary:\n"
    summary += f"Total processed frames: {processed_count}\n"
    summary += f"Frames with obstructions: {obstructed_frames}\n"
    summary += f"Percentage of frames with obstructions: {obstruction_percentage:.2f}%\n"
    summary += f"Output video saved as: {output_path}"

    return output_path, summary
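
# Illustrative headless usage of the two steps above (assumes "rail.jpg" and
# "crossing.mp4" exist next to this script, along with the YOLOv9 scripts and
# weights referenced in the commands; the file names are placeholders):
#
#   fig, segment, msg = detect_rail(Image.open("rail.jpg"))
#   out_path, summary = process_video("crossing.mp4", segment, frame_skip=15)
#   print(summary)
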
# Gradio interface
class TwoStepDetection:
    def __init__(self):
        self.rail_segment = None

    def rail_detection(self, rail_input):
        if rail_input is None:
            return None, "Please upload an image for rail detection."
        rail_fig, self.rail_segment, message = detect_rail(rail_input)
        return rail_fig, message

    def object_detection(self, video_input, frame_skip=15):
        if self.rail_segment is None:
            return None, "Please complete rail detection first."
        if video_input is None:
            # Create a sample video if none is provided
            sample_video_path = "sample_video.mp4"
            create_sample_video(sample_video_path)
            video_input = sample_video_path
        # The slider may deliver a float, so coerce to int before using it as a modulus
        video_output, processing_message = process_video(video_input, self.rail_segment, int(frame_skip))
        if video_output is None:
            return None, processing_message
        return video_output, processing_message


# Create Gradio interface
detector = TwoStepDetection()

with gr.Blocks(title="Two-Step Train Obstruction Detection") as iface:
    gr.Markdown("# Two-Step Train Obstruction Detection")
    gr.Markdown("Step 1: Upload an image to detect the rail. Step 2: Upload a video to detect obstructions.")

    with gr.Tab("Step 1: Rail Detection"):
        rail_input = gr.Image(type="numpy", label="Upload image for rail detection")
        rail_output = gr.Plot(label="Rail Detection Result")
        rail_message = gr.Textbox(label="Rail Detection Message")
        rail_button = gr.Button("Detect Rail")

    with gr.Tab("Step 2: Object Detection"):
        video_input = gr.Video(label="Upload video for object detection")
        frame_skip = gr.Slider(minimum=1, maximum=100, step=1, value=15, label="Frame Skip Rate")
        video_output = gr.Video(label="Object Detection Result")
        object_message = gr.Textbox(label="Object Detection Results")
        object_button = gr.Button("Detect Objects")

    rail_button.click(detector.rail_detection, inputs=rail_input, outputs=[rail_output, rail_message])
    object_button.click(detector.object_detection, inputs=[video_input, frame_skip],
                        outputs=[video_output, object_message])

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()