import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Polygon, box as shapely_box
import gradio as gr
from PIL import Image
import tempfile
import spaces


def dummy():
    pass
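# Note: `import spaces` plus the no-op dummy() above is presumably a
# placeholder to satisfy Hugging Face ZeroGPU Space requirements; the
# pipelines below run on CPU (--device cpu) and never call into it.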

# Utility functions
def extract_class_0_coordinates(filename):
    class_0_coordinates = []
    with open(filename, 'r') as file:
        for line in file:
            parts = line.strip().split()
            if len(parts) == 0:
                continue
            if parts[0] == '0':
                coordinates = [float(x) for x in parts[1:]]
                class_0_coordinates.extend(coordinates)
    return class_0_coordinates
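# A YOLO segmentation label line is "<class_id> x1 y1 x2 y2 ..." with the
# polygon vertices normalized to [0, 1], e.g.:
#   0 0.12 0.95 0.48 0.40 0.55 0.41 0.90 0.97
# extract_class_0_coordinates() flattens every class-0 polygon into one
# flat list of alternating x/y values.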

# Parse a single "class x y w h" label line (currently unused by the app).
def parse_yolo_box(box_string):
    values = list(map(float, box_string.split()))
    if len(values) < 5:
        raise ValueError(f"Expected at least 5 values, got {len(values)}")
    return values[0], values[1], values[2], values[3], values[4]

def read_yolo_boxes(file_path):
    boxes = []
    with open(file_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if not parts:  # skip blank lines, as extract_class_0_coordinates does
                continue
            class_name = COCO_CLASSES[int(parts[0])]
            x, y, w, h = map(float, parts[1:5])
            boxes.append((class_name, x, y, w, h))
    return boxes

def yolo_to_pixel_coord(x, y, img_width, img_height):
    return int(x * img_width), int(y * img_height)


def yolo_to_pixel_coords(x_center, y_center, width, height, img_width, img_height):
    x1 = int((x_center - width / 2) * img_width)
    y1 = int((y_center - height / 2) * img_height)
    x2 = int((x_center + width / 2) * img_width)
    y2 = int((y_center + height / 2) * img_height)
    return x1, y1, x2, y2
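# Worked example: on a 640x480 image, a normalized box
# (x_center, y_center, w, h) = (0.5, 0.5, 0.5, 0.5) maps to pixel corners
# (x1, y1) = (160, 120) and (x2, y2) = (480, 360).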

def box_segment_relationship(yolo_box, segment, img_width, img_height, threshold):
    class_id, x_center, y_center, width, height = yolo_box
    x1, y1, x2, y2 = yolo_to_pixel_coords(x_center, y_center, width, height, img_width, img_height)
    pixel_segment = convert_segment_to_pixel(segment, img_width, img_height)
    segment_polygon = Polygon(zip(pixel_segment[::2], pixel_segment[1::2]))
    box_polygon = shapely_box(x1, y1, x2, y2)
    if box_polygon.intersects(segment_polygon):
        return "intersecting"
    elif box_polygon.distance(segment_polygon) <= threshold:
        return "obstructed"
    else:
        return "not touching"

def convert_segment_to_pixel(segment, img_width, img_height):
    pixel_segment = []
    for i in range(0, len(segment), 2):
        x, y = yolo_to_pixel_coord(segment[i], segment[i + 1], img_width, img_height)
        pixel_segment.extend([x, y])
    return pixel_segment

def plot_boxes_and_segment(image, yolo_boxes, segment, img_width, img_height, threshold):
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)
    pixel_segment = convert_segment_to_pixel(segment, img_width, img_height)
    # Append the first vertex so the plotted rail-zone outline closes.
    ax.plot(pixel_segment[::2] + [pixel_segment[0]], pixel_segment[1::2] + [pixel_segment[1]], 'g-', linewidth=2, label='Rail Zone')
    colors = {'intersecting': 'r', 'obstructed': 'y', 'not touching': 'b'}
    labels = {'intersecting': 'Intersecting Box', 'obstructed': 'Obstructed Box', 'not touching': 'Non-interacting Box'}
    for yolo_box in yolo_boxes:
        class_id, x_center, y_center, width, height = yolo_box
        x1, y1, x2, y2 = yolo_to_pixel_coords(x_center, y_center, width, height, img_width, img_height)
        relationship = box_segment_relationship(yolo_box, segment, img_width, img_height, threshold)
        color = colors[relationship]
        label = labels[relationship]
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=color, linewidth=2, label=label))
    # Deduplicate legend entries so repeated relationships appear only once.
    handles, legend_labels = ax.get_legend_handles_labels()
    by_label = dict(zip(legend_labels, handles))
    ax.legend(by_label.values(), by_label.keys())
    ax.axis('off')
    plt.tight_layout()
    return fig

# COCO classes
COCO_CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]
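# The index in this list matches the YOLO class id written to the label
# files, e.g. COCO_CLASSES[0] == 'person' and COCO_CLASSES[6] == 'train'.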

# Detection functions
def detect_rail(image):
    # Accept a PIL Image or numpy array (Gradio delivers numpy here)
    image = np.array(image)
    # Check if the image is RGB (3 channels)
    if len(image.shape) == 3 and image.shape[2] == 3:
        # Convert RGB to BGR (OpenCV format)
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        # If not RGB, use the image as is (assuming OpenCV can handle it)
        image_bgr = image
    temp_image_path = "temp_image_rail.jpg"
    cv2.imwrite(temp_image_path, image_bgr)
    os.system(f"python segment/predict.py --source {temp_image_path} --img 640 --device cpu --weights models/segment/best-2.pt --name yolov9_c_640_detect --exist-ok --save-txt")
    label_path = 'runs/predict-seg/yolov9_c_640_detect/labels/temp_image_rail.txt'
    segment = extract_class_0_coordinates(label_path)
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)  # Use the original image for display
    img_height, img_width = image.shape[:2]
    pixel_segment = convert_segment_to_pixel(segment, img_width, img_height)
    ax.plot(pixel_segment[::2] + [pixel_segment[0]], pixel_segment[1::2] + [pixel_segment[1]], 'g-', linewidth=2, label='Rail Zone')
    ax.legend()
    ax.axis('off')
    plt.tight_layout()
    os.remove(temp_image_path)
    os.remove(label_path)
    return fig, segment, "Rail detection completed. You can now upload an image or video for object detection."
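# The label path is tied to the temp image stem ("temp_image_rail" ->
# labels/temp_image_rail.txt): the YOLOv9 predict scripts, run with
# --save-txt, write one normalized-coordinate line per detection into a
# label file named after the source image.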

def detect_objects(image, rail_segment):
    # Accept a PIL Image or numpy array
    image = np.array(image)
    # Check if the image is RGB (3 channels)
    if len(image.shape) == 3 and image.shape[2] == 3:
        # Convert RGB to BGR (OpenCV format)
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        # If not RGB, use the image as is (assuming OpenCV can handle it)
        image_bgr = image
    img_height, img_width = image.shape[:2]
    temp_image_path = "temp_image_objects.jpg"
    cv2.imwrite(temp_image_path, image_bgr)
    os.system(f"python detect.py --source {temp_image_path} --img 640 --device cpu --weights models/detect/yolov9-s-converted.pt --name yolov9_c_640_detect --exist-ok --save-txt")
    label_path = 'runs/detect/yolov9_c_640_detect/labels/temp_image_objects.txt'
    yolo_boxes = read_yolo_boxes(label_path)
    threshold = 10  # Threshold (in pixels) for counting a box as obstructing
    fig = plot_boxes_and_segment(image, yolo_boxes, rail_segment, img_width, img_height, threshold)
    results = []
    for class_name, x, y, w, h in yolo_boxes:
        result = box_segment_relationship((0, x, y, w, h), rail_segment, img_width, img_height, threshold)
        results.append(f"{class_name} at ({x:.2f}, {y:.2f}): {result}.")
    os.remove(temp_image_path)
    os.remove(label_path)
    return fig, "\n".join(results), yolo_boxes
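# Standalone usage sketch (assumes rail detection has already produced a
# flat list of normalized segment coordinates):
#   _, segment, _ = detect_rail(Image.open("rail.jpg"))
#   fig, report, boxes = detect_objects(Image.open("scene.jpg"), segment)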

def process_video(video_path, rail_segment, frame_skip=15):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, "Error: Could not open video file."
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Guard against a zero output frame rate when frame_skip exceeds the source fps
    out = cv2.VideoWriter(temp_output.name, fourcc, max(1, fps // frame_skip), (width, height))
    frame_count = 0
    processed_count = 0
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    threshold = 10  # Threshold (in pixels) for obstruction detection
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        if frame_count % frame_skip != 0:
            continue
        processed_count += 1
        # Convert frame to PIL Image for compatibility with detect_objects
        pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # Detect objects in the frame
        _, _, yolo_boxes = detect_objects(pil_frame, rail_segment)
        # Draw rail segment
        pixel_segment = convert_segment_to_pixel(rail_segment, width, height)
        pts = np.array(list(zip(pixel_segment[::2], pixel_segment[1::2])), np.int32)
        pts = pts.reshape((-1, 1, 2))
        cv2.polylines(frame, [pts], True, (0, 0, 255), 2)
        # Check for obstructions and draw bounding boxes
        for box in yolo_boxes:
            class_name, x, y, w, h = box
            relationship = box_segment_relationship((0, x, y, w, h), rail_segment, width, height, threshold)
            x1, y1, x2, y2 = yolo_to_pixel_coords(x, y, w, h, width, height)
            if relationship == "intersecting":
                color = (0, 0, 255)  # Red for intersecting
            elif relationship == "obstructed":
                color = (0, 255, 255)  # Yellow for obstructed
            else:
                color = (0, 255, 0)  # Green for not touching
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            cv2.putText(frame, f"{class_name} ({relationship})", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        out.write(frame)
        print(f"Processed frame {frame_count}/{total_frames} (processed frame {processed_count})")
    cap.release()
    out.release()
    if processed_count == 0:
        return None, "Error: No frames were processed."
    return temp_output.name, f"Video processing completed. Processed {processed_count} out of {total_frames} frames."
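# Writing the output at fps // frame_skip roughly preserves wall-clock
# duration: with a 30 fps source and frame_skip=15, every 15th frame is
# kept and played back at 2 fps.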

# Gradio interface
class TwoStepDetection:
    def __init__(self):
        self.rail_segment = None

    def rail_detection(self, rail_input):
        if rail_input is None:
            return None, "Please upload an image for rail detection."
        rail_fig, self.rail_segment, message = detect_rail(rail_input)
        return rail_fig, message

    def object_detection(self, object_input, video_input, frame_skip=15):
        if self.rail_segment is None:
            return None, None, "Please complete rail detection first."
        if object_input is None and video_input is None:
            return None, None, "Please upload an image or video for object detection."
        if object_input is not None:  # Image input
            object_fig, object_results, _ = detect_objects(object_input, self.rail_segment)
            return object_fig, None, object_results
        elif video_input is not None:  # Video input
            video_output, processing_message = process_video(video_input, self.rail_segment, frame_skip)
            if video_output is None:
                return None, None, processing_message
            # Analyze the processed video for an obstruction summary
            cap = cv2.VideoCapture(video_output)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            obstructed_frames = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Convert frame to PIL Image for compatibility with detect_objects
                pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                # Detect objects in the frame
                _, _, yolo_boxes = detect_objects(pil_frame, self.rail_segment)
                # Check for obstructions
                for box in yolo_boxes:
                    _, x, y, w, h = box
                    relationship = box_segment_relationship((0, x, y, w, h), self.rail_segment, frame.shape[1], frame.shape[0], 10)
                    if relationship in ["intersecting", "obstructed"]:
                        obstructed_frames += 1
                        break  # Count the frame as obstructed if at least one object obstructs
            cap.release()
            # Guard against division by zero when the output video reports no frames
            obstruction_percentage = (obstructed_frames / total_frames) * 100 if total_frames else 0.0
            summary = f"{processing_message}\n\nObstruction Summary:\n"
            summary += f"Total frames: {total_frames}\n"
            summary += f"Frames with obstructions: {obstructed_frames}\n"
            summary += f"Percentage of frames with obstructions: {obstruction_percentage:.2f}%"
            return None, video_output, summary
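# Trade-off note: the summary pass decodes the annotated output video and
# re-runs detection on frames that already carry drawn boxes and the rail
# outline, which doubles the detection cost and can bias the second pass;
# caching the per-frame relationships computed in process_video would
# avoid both.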

# Create Gradio interface
detector = TwoStepDetection()

with gr.Blocks(title="Two-Step Train Obstruction Detection") as iface:
    gr.Markdown("# Two-Step Train Obstruction Detection")
    gr.Markdown("Step 1: Upload an image to detect the rail. Step 2: Upload an image or video with objects to detect obstructions.")
    with gr.Tab("Step 1: Rail Detection"):
        rail_input = gr.Image(type="numpy", label="Upload image for rail detection")
        rail_output = gr.Plot(label="Rail Detection Result")
        rail_message = gr.Textbox(label="Rail Detection Message")
        rail_button = gr.Button("Detect Rail")
    with gr.Tab("Step 2: Object Detection"):
        object_input = gr.Image(type="numpy", label="Upload image for object detection")
        video_input = gr.Video(label="Or upload video for object detection")
        frame_skip = gr.Slider(minimum=1, maximum=100, step=1, value=15, label="Frame Skip Rate (for video)")
        object_output = gr.Plot(label="Object Detection Result (Image)")
        video_output = gr.Video(label="Object Detection Result (Video)")
        object_message = gr.Textbox(label="Object Detection Results")
        object_button = gr.Button("Detect Objects")
    rail_button.click(detector.rail_detection, inputs=rail_input, outputs=[rail_output, rail_message])
    object_button.click(detector.object_detection, inputs=[object_input, video_input, frame_skip], outputs=[object_output, video_output, object_message])
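# With type="numpy" the gr.Image components deliver numpy arrays, so
# np.array(...) at the top of detect_rail/detect_objects uniformly accepts
# both Gradio's array inputs and the PIL frames fed in by process_video.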

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()