import os
import shutil
import subprocess
import tempfile

import gradio as gr
from ultralytics import YOLO
import cv2
from PIL import Image
import numpy as np
|
|
| |
# Load the detector once at import time. If loading fails, the error is
# printed and ``model`` stays None; the prediction handlers test for None and
# return an explanatory message instead of running inference.
model = None
try:
    model = YOLO("model.pt")
except Exception as e:
    print(f"Error loading model: {e}")
|
|
|
|
def predict_image(image, conf_threshold):
    """Detect civic issues in a single image.

    Args:
        image: The uploaded image (the UI wires a ``gr.Image(type="pil")``
            here), or ``None`` when nothing was uploaded.
        conf_threshold: Minimum detection confidence forwarded to the model.

    Returns:
        ``(annotated_image, summary)``: a PIL image with the detections drawn,
        and a per-class count summary. If the image or model is missing, or
        any exception is raised, the first element is ``None`` and *summary*
        explains why (including the traceback for exceptions).
    """
    try:
        # Keep this operand order: ``image`` is checked first.
        if image is None or model is None:
            return None, "Model not loaded or invalid image."

        prediction = model(image, imgsz=768, conf=conf_threshold)[0]
        # The drawn overlay is treated as BGR (OpenCV channel order); PIL needs RGB.
        rgb_overlay = cv2.cvtColor(prediction.plot(), cv2.COLOR_BGR2RGB)

        label_lookup = prediction.names
        found = prediction.boxes
        if len(found) > 0:
            per_class = {}
            for det in found:
                raw_cls = det.cls
                # Tensor-like class ids expose .item(); otherwise take the first element.
                label_id = int(raw_cls.item() if hasattr(raw_cls, "item") else raw_cls[0])
                label = label_lookup.get(label_id, f"Class {label_id}")
                per_class[label] = per_class.get(label, 0) + 1
            bullet_lines = [f"- {n} {label}(s)" for label, n in per_class.items()]
            summary = "\n".join(["**Detections:**", *bullet_lines])
        else:
            summary = "No civic issues detected in this image."

        return Image.fromarray(rgb_overlay), summary

    except Exception as exc:
        import traceback
        return None, (
            f"ERROR during prediction: {str(exc)}\n\nTraceback:\n{traceback.format_exc()}"
        )
|
|
|
|
def _new_temp_video_path():
    """Create an empty, uniquely named ``.mp4`` file in the temp dir and return its path.

    ``tempfile.mkstemp`` creates the file atomically. The deprecated
    ``tempfile.mktemp`` only picks a name, leaving a race before the file
    is created.
    """
    fd, path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)  # only the path is needed; OpenCV / ffmpeg reopen it themselves
    return path


def _discard_file(path):
    """Best-effort delete of *path*; ``None``, a missing file, or an OS error are ignored."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def _annotate_video(cap, writer, conf_threshold, total_frames, progress):
    """Run the module-level ``model`` on every frame read from *cap*.

    Each annotated frame (``result.plot()``) is written to *writer*. Progress
    is reported only when the container advertised a positive frame count.

    Returns:
        ``(frames_processed, counts)``, where *counts* maps class name to the
        number of boxes of that class summed over all frames.
    """
    counts = {}
    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        if total_frames > 0:
            progress(frame_idx / total_frames, desc=f"Processing frame {frame_idx}/{total_frames}")

        result = model(frame, imgsz=768, conf=conf_threshold, verbose=False)[0]
        writer.write(result.plot())

        for box in result.boxes:
            cls_id = int(box.cls.item() if hasattr(box.cls, "item") else box.cls[0])
            cls_name = result.names.get(cls_id, f"Class {cls_id}")
            counts[cls_name] = counts.get(cls_name, 0) + 1

        frame_idx += 1
    return frame_idx, counts


def _transcode_to_h264(src_path):
    """Re-encode *src_path* to H.264 with ffmpeg (best-effort).

    On success, deletes *src_path* and returns the re-encoded file's path.
    If ffmpeg is not on ``PATH``, cannot be started, or exits with an error,
    any partial output is removed and *src_path* is returned unchanged.
    """
    ffmpeg = shutil.which("ffmpeg")
    if ffmpeg is None:
        return src_path

    dst_path = _new_temp_video_path()
    # Argument list with no shell, so file paths are never parsed by a shell.
    cmd = [
        ffmpeg, "-y", "-loglevel", "quiet",
        "-i", src_path,
        "-vcodec", "libx264", "-crf", "23", "-preset", "fast",
        dst_path,
    ]
    try:
        # stdin=DEVNULL: ffmpeg otherwise reads stdin for interactive commands.
        proc = subprocess.run(cmd, stdin=subprocess.DEVNULL, check=False)
    except OSError:
        _discard_file(dst_path)
        return src_path

    if proc.returncode == 0 and os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
        _discard_file(src_path)
        return dst_path

    _discard_file(dst_path)
    return src_path


def _format_video_summary(frame_count, counts):
    """Build the text summary; classes are listed by descending total count."""
    if not counts:
        return f"Processed {frame_count} frames.\nNo civic issues detected in this video."
    summary_lines = [f"Processed {frame_count} frames.\n\n**Total Detections Across All Frames:**"]
    for cls_name, count in sorted(counts.items(), key=lambda x: -x[1]):
        summary_lines.append(f"- {count} {cls_name}(s)")
    return "\n".join(summary_lines)


def predict_video(video_path, conf_threshold, progress=gr.Progress()):
    """Detect civic issues frame by frame and return an annotated video.

    Args:
        video_path: Path of the uploaded video, or ``None``.
        conf_threshold: Minimum detection confidence forwarded to the model.
        progress: Gradio progress tracker; Gradio injects it because the
            default is ``gr.Progress()``.

    Returns:
        ``(output_path, summary)``. *output_path* is an annotated ``.mp4``.
        It is H.264 when ffmpeg is available and succeeds, otherwise OpenCV's
        ``mp4v`` output. On failure the first element is ``None`` and
        *summary* explains why.
    """
    try:
        if video_path is None or model is None:
            return None, "Model not loaded or no video provided."

        cap = cv2.VideoCapture(video_path)
        writer = None
        out_path = None
        completed = False
        try:
            if not cap.isOpened():
                return None, "Could not open video file."

            fps = cap.get(cv2.CAP_PROP_FPS) or 25  # some containers report 0 fps
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            out_path = _new_temp_video_path()
            writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
            # Without this check, a writer that failed to open silently drops every frame.
            if not writer.isOpened():
                return None, "Could not create output video file."

            frame_count, counts = _annotate_video(cap, writer, conf_threshold, total_frames, progress)
            completed = True
        finally:
            # Release on every path, including early returns and exceptions.
            cap.release()
            if writer is not None:
                writer.release()
            if not completed:
                _discard_file(out_path)

        # The writer is released above, so the file is finalized before ffmpeg reads it.
        out_path = _transcode_to_h264(out_path)
        return out_path, _format_video_summary(frame_count, counts)

    except Exception as e:
        import traceback
        error_msg = f"ERROR during video prediction: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        return None, error_msg
|
|
|
|
| |
# Gradio UI: one tab for single-image inference, one for frame-by-frame video.
# NOTE(review): the "π" in the window title and top heading is a mis-decoded
# emoji whose original bytes were lost in an encoding round-trip; restore the
# intended emoji here.
with gr.Blocks(title="PotholeNet-YOLO11m-v1 π") as interface:
    gr.Markdown("# π PotholeNet-YOLO11m-v1")
    gr.Markdown(
        "**Aamchi City AI Civic System** — Real-time pothole, road damage, and garbage detection for Indian urban roads."
    )
    gr.Markdown(
        "Upload an image **or video** of a road to detect infrastructure issues. "
        "The model was trained on 23,000+ street-level images."
    )

    with gr.Tabs():
        # Single-image detection: PIL in, annotated PIL + text summary out.
        with gr.TabItem("🖼️ Image Detection"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Upload Street Image")
                    img_conf_slider = gr.Slider(
                        minimum=0.01, maximum=1.0, value=0.25, step=0.01,
                        label="Confidence Threshold"
                    )
                    img_submit_btn = gr.Button("Detect Civic Issues", variant="primary")

                with gr.Column():
                    output_image = gr.Image(type="pil", label="Detection Results")
                    img_detection_text = gr.Textbox(
                        label="Detection Summary", interactive=False, lines=4
                    )

            img_submit_btn.click(
                fn=predict_image,
                inputs=[input_image, img_conf_slider],
                outputs=[output_image, img_detection_text],
            )

        # Video detection: uploaded file path in, annotated video + summary out.
        with gr.TabItem("🎬 Video Detection"):
            gr.Markdown(
                "> ⚠️ **Note:** Video processing is frame-by-frame and may take a while depending on length and hardware."
            )
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(
                        label="Upload Street Video",
                        sources=["upload"],
                        format="mp4",
                    )
                    vid_conf_slider = gr.Slider(
                        minimum=0.01, maximum=1.0, value=0.25, step=0.01,
                        label="Confidence Threshold"
                    )
                    vid_submit_btn = gr.Button("Detect Civic Issues in Video", variant="primary")

                with gr.Column():
                    output_video = gr.Video(label="Annotated Video")
                    vid_detection_text = gr.Textbox(
                        label="Detection Summary", interactive=False, lines=6
                    )

            vid_submit_btn.click(
                fn=predict_video,
                inputs=[input_video, vid_conf_slider],
                outputs=[output_video, vid_detection_text],
            )

    gr.Markdown("### Intended Use")
    gr.Markdown(
        "Real-time pothole detection, Automated civic issue reporting, Infrastructure health monitoring."
    )
    gr.Markdown("**Developer:** Vansh Momaya")
|
if __name__ == "__main__":
    # Serve on every network interface (not only localhost), port 7860.
    interface.launch(server_port=7860, server_name="0.0.0.0")