Spaces:
Sleeping
Sleeping
# ═══════════════════════════════════════════════════════════════════
# model_handler.py - Model Loading, Inference, and Tracking
# ═══════════════════════════════════════════════════════════════════
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from ultralytics import YOLO | |
| from pathlib import Path | |
| import tempfile | |
| import os | |
| from datetime import timedelta | |
| from collections import defaultdict | |
| import pandas as pd | |
# ═══════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════
# Default minimum detection confidence. NOTE(review): no code in this module
# section reads it — detect_image/detect_video default to a literal 0.5.
CONFIDENCE_THRESHOLD: float = 0.5
# Frame rate substituted when a video file reports no usable FPS.
VIDEO_FPS: int = 30
# ═══════════════════════════════════════════════════════════════════
# MODEL LOADER
# ═══════════════════════════════════════════════════════════════════
class ModelLoader:
    """Handle model loading with fallback options."""

    # Official segmentation checkpoints tried, in order, when the custom
    # weights are absent or fail to load: (weights file, display name, short name).
    # Ultralytics names its YOLO11 weights "yolo11*.pt" (no "v"); the previous
    # "yolov11n-seg.pt" is not a published asset, so that fallback always failed.
    _FALLBACK_MODELS = (
        ("yolo11n-seg.pt", "YOLOv11n-seg", "YOLOv11"),
        ("yolov8n-seg.pt", "YOLOv8n-seg", "YOLOv8"),
    )

    @staticmethod
    def load_model(custom_weights="best.pt"):
        """Load a YOLO segmentation model, falling back to official weights.

        Tries ``custom_weights`` first (only if that file exists), then each
        entry of ``_FALLBACK_MODELS`` in order. Declared ``@staticmethod`` so it
        works both as ``ModelLoader.load_model()`` and on an instance.

        Args:
            custom_weights: Path to locally trained weights. Defaults to
                ``"best.pt"`` relative to the working directory.

        Returns:
            tuple: ``(model, model_path)`` — the loaded ``YOLO`` object and the
            weights path/name it was built from.

        Raises:
            RuntimeError: If no candidate could be loaded; chained to the last
                underlying error.
        """
        print("π Loading pothole detection model...")

        # Try custom model first
        if Path(custom_weights).exists():
            try:
                print(f" Attempting to load custom model: {custom_weights}")
                model = YOLO(custom_weights)
                print("β Custom model loaded successfully!")
                return model, custom_weights
            except Exception as e:
                # Best effort: a broken custom checkpoint falls through to the
                # official weights instead of aborting.
                print(f" β οΈ Failed to load {custom_weights}: {e}")

        last_error = None
        for weights, full_name, short_name in ModelLoader._FALLBACK_MODELS:
            try:
                print(f" Downloading official {full_name} model...")
                model = YOLO(weights)
                print(f"β Official {full_name} model loaded!")
                return model, weights
            except Exception as e:
                last_error = e
                print(f" β οΈ Failed to load {short_name}: {e}")

        raise RuntimeError(f"β Could not load any model: {last_error}") from last_error
# ═══════════════════════════════════════════════════════════════════
# POTHOLE TRACKER
# ═══════════════════════════════════════════════════════════════════
class PotholeTracker:
    """Track potholes across video frames by centroid proximity.

    Each detection is a dict carrying a ``'centroid'`` (x, y) pair; ``update``
    writes a ``'track_id'`` into it. ``get_statistics`` additionally reads
    ``'max_depth_cm'``, ``'area_m2'``, ``'volume_liters'`` and ``'severity'``
    from every recorded detection.

    NOTE: tracks are never retired, so a new detection can inherit the ID of a
    pothole last seen arbitrarily many frames ago if it lies within
    ``max_distance`` of that pothole's last centroid.
    """

    # Rank used to pick a track's worst severity; unknown labels rank like 'LOW'.
    _SEVERITY_ORDER = {'LOW': 0, 'MEDIUM': 1, 'HIGH': 2, 'CRITICAL': 3}

    def __init__(self, max_distance=100):
        """
        Args:
            max_distance: Centroid distance (presumably pixels) strictly below
                which a detection may continue an existing track.
        """
        self.tracked_potholes = {}                # track_id -> last known centroid
        self.next_id = 1                          # next unused track ID
        self.max_distance = max_distance
        self.pothole_history = defaultdict(list)  # track_id -> list of observations

    def calculate_distance(self, centroid1, centroid2):
        """Return the Euclidean distance between two (x, y) centroids."""
        return np.hypot(centroid1[0] - centroid2[0], centroid1[1] - centroid2[1])

    def _record(self, track_id, detection, frame_num, timestamp):
        """Tag ``detection`` with ``track_id``, move the track, and log the observation."""
        detection['track_id'] = track_id
        self.tracked_potholes[track_id] = detection['centroid']
        self.pothole_history[track_id].append({
            'frame': frame_num,
            'timestamp': timestamp,
            'measurements': detection,
        })

    def update(self, detections, frame_num, timestamp):
        """Match this frame's detections to existing tracks and start new ones.

        Matching is greedy on global distance: the closest (detection, track)
        pair under ``max_distance`` is linked first, then the closest pair among
        what remains, and so on. (Matching detections one by one in list order
        let an earlier detection claim a track that a later one was far closer
        to.) Each track is matched at most once per frame; leftover detections
        receive fresh IDs in list order.

        Args:
            detections: List of detection dicts (each with ``'centroid'``);
                mutated in place to add ``'track_id'``.
            frame_num: Frame index stored in the history.
            timestamp: Timestamp stored in the history (opaque to the tracker).

        Returns:
            The same ``detections`` list, or ``[]`` if it was empty.
        """
        if not detections:
            return []

        # (distance, detection index, track id), closest first; ties resolve by
        # detection order, then by track creation order.
        candidate_pairs = sorted(
            (self.calculate_distance(det['centroid'], centroid), det_idx, track_id)
            for det_idx, det in enumerate(detections)
            for track_id, centroid in self.tracked_potholes.items()
        )

        matched_dets = set()
        used_tracks = set()
        for dist, det_idx, track_id in candidate_pairs:
            if dist >= self.max_distance:
                break  # sorted ascending: every remaining pair is too far apart
            if det_idx in matched_dets or track_id in used_tracks:
                continue
            self._record(track_id, detections[det_idx], frame_num, timestamp)
            matched_dets.add(det_idx)
            used_tracks.add(track_id)

        # Assign new IDs to unmatched detections
        for det_idx, det in enumerate(detections):
            if det_idx not in matched_dets:
                self._record(self.next_id, det, frame_num, timestamp)
                self.next_id += 1

        return detections

    def get_statistics(self):
        """Summarize every track recorded so far.

        Returns:
            dict: ``{'total_potholes': int, 'potholes': [...]}`` where each entry
            has the first/last frame and timestamp, max and mean depth/area, max
            volume, the most severe classification, and the raw ``history``.
        """
        summaries = []
        for track_id, history in self.pothole_history.items():
            observations = [entry['measurements'] for entry in history]
            depths = [m['max_depth_cm'] for m in observations]
            areas = [m['area_m2'] for m in observations]
            first, last = history[0], history[-1]
            summaries.append({
                'track_id': track_id,
                'frames_detected': len(history),
                'first_frame': first['frame'],
                'last_frame': last['frame'],
                'first_timestamp': first['timestamp'],
                'last_timestamp': last['timestamp'],
                'max_depth_cm': max(depths),
                'avg_depth_cm': np.mean(depths),
                'max_area_m2': max(areas),
                'avg_area_m2': np.mean(areas),
                'max_volume_liters': max(m['volume_liters'] for m in observations),
                'severity': max(
                    (m['severity'] for m in observations),
                    key=lambda s: self._SEVERITY_ORDER.get(s, 0),
                ),
                'history': history,
            })
        return {
            'total_potholes': len(self.pothole_history),
            'potholes': summaries,
        }
# ═══════════════════════════════════════════════════════════════════
# INFERENCE HANDLER
# ═══════════════════════════════════════════════════════════════════
class InferenceHandler:
    """Run pothole segmentation on images and videos and annotate the results.

    ``model`` is called like an Ultralytics ``YOLO`` model. ``measurer`` must
    provide ``calculate_measurements(mask_binary)`` returning a dict, or a falsy
    value to skip that detection.
    """

    def __init__(self, model, measurer):
        self.model = model
        self.measurer = measurer

    # ── internal helpers ────────────────────────────────────────────────

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` (PIL image or ndarray) as a 3-channel RGB ndarray."""
        if isinstance(image, Image.Image):
            # PIL converts L, LA, P (palette), RGBA, CMYK, ... uniformly.
            return np.array(image.convert("RGB"))
        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        return image

    def _predict(self, bgr_image, confidence_threshold):
        """Run the model on one BGR image.

        Returns:
            ``(boxes_xyxy, confidences, masks_or_None)`` as numpy arrays, or
            ``None`` when nothing was detected.
        """
        # Ultralytics accepts in-memory BGR arrays, so no temporary JPEG round
        # trip (which also re-compressed every input) is needed.
        # retina_masks=True returns masks at the source-image resolution rather
        # than the letterboxed inference resolution, so they align with the image.
        results = self.model(
            bgr_image, conf=confidence_threshold, verbose=False, retina_masks=True
        )[0]
        if results.boxes is None or len(results.boxes) == 0:
            return None
        boxes = results.boxes.xyxy.cpu().numpy()
        confidences = results.boxes.conf.cpu().numpy()
        masks = None
        if results.masks is not None:
            # float32 keeps cv2.resize working even for half-precision models.
            masks = results.masks.data.cpu().numpy().astype(np.float32)
        return boxes, confidences, masks

    @staticmethod
    def _binary_mask(mask, width, height):
        """Resize a soft mask to (width, height) if needed; threshold at 0.5 to {0, 255} uint8."""
        if mask.shape[:2] != (height, width):
            mask = cv2.resize(mask, (width, height))
        return (mask > 0.5).astype(np.uint8) * 255

    @staticmethod
    def _draw_label(img, text, x, y, scale):
        """Draw white ``text`` on a filled black box sitting just above point (x, y)."""
        (text_w, text_h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, scale, 2)
        # Keep the label inside the image when the box touches the top edge.
        x, y = int(x), max(int(y), text_h + 10)
        cv2.rectangle(img, (x, y - text_h - 10), (x + text_w + 10, y), (0, 0, 0), -1)
        cv2.putText(img, text, (x + 5, y - 5), cv2.FONT_HERSHEY_SIMPLEX,
                    scale, (255, 255, 255), 2)

    @staticmethod
    def _progress_fraction(done, total):
        """Return ``done / total`` clamped to [0, 1]; 0.0 when the total is unknown (<= 0)."""
        if total <= 0:
            return 0.0
        return min(1.0, done / total)

    def _segment_frame(self, frame, width, height, confidence_threshold):
        """Detect and measure potholes in one BGR frame and overlay their masks.

        Returns:
            ``(annotated_frame, detections)`` where ``detections`` are the
            measurer dicts extended with ``'confidence'``.
        """
        detections = []
        prediction = self._predict(frame, confidence_threshold)
        if prediction is None:
            return frame, detections
        _, confidences, masks = prediction
        if masks is None:
            return frame, detections

        for conf, mask in zip(confidences, masks):
            mask_binary = self._binary_mask(mask, width, height)
            # Semi-transparent red fill (BGR) plus a green outline.
            overlay = frame.copy()
            overlay[mask_binary > 0] = [50, 50, 255]
            frame = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)
            contours, _ = cv2.findContours(
                mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            cv2.drawContours(frame, contours, -1, (0, 255, 0), 2)

            measurements = self.measurer.calculate_measurements(mask_binary)
            if measurements:
                measurements['confidence'] = float(conf)
                detections.append(measurements)
        return frame, detections

    # ── public API ──────────────────────────────────────────────────────

    def detect_image(self, image, confidence_threshold=0.5):
        """Detect, measure, and annotate potholes in a single image.

        Args:
            image: PIL image or ndarray (grayscale, RGB, or RGBA).
            confidence_threshold: Minimum detection confidence passed to the model.

        Returns:
            tuple: ``(annotated_rgb, measurements)``. With no detections this is
            the RGB-converted input and ``[]``. Each measurement dict (from
            ``measurer``) is extended with ``'pothole_id'`` (1-based detection
            index) and ``'confidence'``.
        """
        image = self._to_rgb(image)
        h, w = image.shape[:2]

        prediction = self._predict(cv2.cvtColor(image, cv2.COLOR_RGB2BGR), confidence_threshold)
        if prediction is None:
            return image, []
        boxes, confidences, masks = prediction

        annotated_img = image.copy()
        all_measurements = []
        for idx, (box, conf) in enumerate(zip(boxes, confidences)):
            x1, y1, x2, y2 = box.astype(int)
            cv2.rectangle(annotated_img, (x1, y1), (x2, y2), (255, 0, 0), 3)
            if masks is None or idx >= len(masks):
                continue

            mask_binary = self._binary_mask(masks[idx], w, h)
            # Semi-transparent red fill (RGB) plus a green outline.
            overlay = annotated_img.copy()
            overlay[mask_binary > 0] = [255, 50, 50]
            annotated_img = cv2.addWeighted(annotated_img, 0.6, overlay, 0.4, 0)
            contours, _ = cv2.findContours(
                mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            cv2.drawContours(annotated_img, contours, -1, (0, 255, 0), 2)

            measurements = self.measurer.calculate_measurements(mask_binary)
            if not measurements:
                continue
            measurements['pothole_id'] = idx + 1
            measurements['confidence'] = float(conf)
            all_measurements.append(measurements)
            self._draw_label(
                annotated_img,
                f"#{idx+1} {measurements['severity_color']} {measurements['severity']}",
                x1, y1, 0.7,
            )
        return annotated_img, all_measurements

    def detect_video(self, video_path, confidence_threshold=0.5, progress_callback=None):
        """Detect, track, and annotate potholes in every frame of a video.

        Args:
            video_path: Path to the input video, or ``None``.
            confidence_threshold: Minimum detection confidence passed to the model.
            progress_callback: Optional ``callable(fraction, desc=str)``.

        Returns:
            tuple: ``(output_path, stats, total_frames, fps, csv_path)`` — the
            annotated MP4 path, ``PotholeTracker.get_statistics()`` output, the
            frame count reported by the container, the integer frame rate, and
            the per-detection CSV path (``None`` if nothing was detected).
            All five are ``None`` when the video is missing or cannot be opened.
        """
        if video_path is None:
            return None, None, None, None, None

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            cap.release()
            return None, None, None, None, None

        out = None
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            # Keep fractional rates (e.g. 29.97) for timestamps and the writer;
            # truncating to int made timestamps drift. Rates below 1 (including
            # the "unknown" 0) or NaN fall back to VIDEO_FPS.
            if not np.isfinite(fps) or fps < 1:
                fps = float(VIDEO_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Create the output file atomically (tempfile.mktemp is race-prone).
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
                output_path = tmp.name
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            tracker = PotholeTracker(max_distance=150)
            csv_data = []
            frame_num = 0

            if progress_callback:
                progress_callback(0, desc="Starting video processing...")

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                timestamp_str = str(timedelta(seconds=int(frame_num / fps)))
                frame, detections = self._segment_frame(
                    frame, width, height, confidence_threshold
                )
                tracked_detections = tracker.update(detections, frame_num, timestamp_str)

                for det in tracked_detections:
                    x, y, box_w, box_h = det['bbox']
                    cx, cy = det['centroid']
                    track_id = det['track_id']
                    cv2.rectangle(frame, (x, y), (x + box_w, y + box_h), (255, 0, 0), 2)
                    cv2.circle(frame, (cx, cy), 5, (0, 255, 255), -1)
                    self._draw_label(frame, f"ID:{track_id} {det['severity']}", x, y, 0.6)
                    csv_data.append({
                        'Frame': frame_num,
                        'Timestamp': timestamp_str,
                        'Track_ID': track_id,
                        'Centroid_X': cx,
                        'Centroid_Y': cy,
                        'BBox_X': x,
                        'BBox_Y': y,
                        'BBox_Width': box_w,
                        'BBox_Height': box_h,
                        'Depth_cm': det['max_depth_cm'],
                        'Area_m2': det['area_m2'],
                        'Volume_L': det['volume_liters'],
                        'Severity': det['severity'],
                        'Confidence': det['confidence'],
                    })

                # Frame info: thick white pass then thin black pass for an outlined look.
                info_text = f"Frame: {frame_num}/{total_frames} | Time: {timestamp_str} | Potholes: {len(tracked_detections)}"
                cv2.putText(frame, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
                cv2.putText(frame, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1)
                out.write(frame)

                frame_num += 1
                if frame_num % 10 == 0 and progress_callback:
                    # Guarded: some containers report 0 frames, which used to divide by zero.
                    progress_callback(self._progress_fraction(frame_num, total_frames),
                                      desc=f"Processing frame {frame_num}/{total_frames}")
        finally:
            # Release native handles even if inference or measurement raises.
            cap.release()
            if out is not None:
                out.release()

        stats = tracker.get_statistics()

        csv_path = None
        if csv_data:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as tmp:
                csv_path = tmp.name
            pd.DataFrame(csv_data).to_csv(csv_path, index=False)

        if progress_callback:
            progress_callback(1.0, desc="Video processing complete!")

        # Integer rate, as before (int() of the native rate, or VIDEO_FPS).
        return output_path, stats, total_frames, int(fps), csv_path