hardiksharma6555 commited on
Commit
4dd531f
·
verified ·
1 Parent(s): 1c2f1f1

Update model_handler.py

Browse files
Files changed (1) hide show
  1. model_handler.py +472 -0
model_handler.py CHANGED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ═══════════════════════════════════════════════════════════════════
2
+ # model_handler.py - Model Loading, Inference, and Tracking
3
+ # ═══════════════════════════════════════════════════════════════════
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from ultralytics import YOLO
10
+ from pathlib import Path
11
+ import tempfile
12
+ import os
13
+ from datetime import timedelta
14
+ from collections import defaultdict
15
+ import pandas as pd
16
+
17
+ # ═══════════════════════════════════════════════════════════════════
18
+ # CONFIGURATION
19
+ # ═══════════════════════════════════════════════════════════════════
20
+
21
+ CONFIDENCE_THRESHOLD = 0.5
22
+ VIDEO_FPS = 30
23
+
24
+ # ═══════════════════════════════════════════════════════════════════
25
+ # MODEL LOADER
26
+ # ═══════════════════════════════════════════════════════════════════
27
+
28
+ class ModelLoader:
29
+ """Handle model loading with fallback options"""
30
+
31
+ @staticmethod
32
+ def load_model():
33
+ """Try to load model with fallback options"""
34
+ print("🔄 Loading pothole detection model...")
35
+
36
+ model = None
37
+ model_path = None
38
+
39
+ # Try custom model first
40
+ if Path("best.pt").exists():
41
+ try:
42
+ print(" Attempting to load custom model: best.pt")
43
+ model = YOLO("best.pt")
44
+ model_path = "best.pt"
45
+ print("✅ Custom model loaded successfully!")
46
+ return model, model_path
47
+ except Exception as e:
48
+ print(f" ⚠️ Failed to load best.pt: {e}")
49
+
50
+ # Fallback to official YOLOv11
51
+ try:
52
+ print(" Downloading official YOLOv11n-seg model...")
53
+ model = YOLO("yolov11n-seg.pt")
54
+ model_path = "yolov11n-seg.pt"
55
+ print("✅ Official YOLOv11n-seg model loaded!")
56
+ return model, model_path
57
+ except Exception as e:
58
+ print(f" ⚠️ Failed to load YOLOv11: {e}")
59
+
60
+ # Last resort: YOLOv8
61
+ try:
62
+ print(" Downloading official YOLOv8n-seg model...")
63
+ model = YOLO("yolov8n-seg.pt")
64
+ model_path = "yolov8n-seg.pt"
65
+ print("✅ Official YOLOv8n-seg model loaded!")
66
+ return model, model_path
67
+ except Exception as e:
68
+ raise RuntimeError(f"❌ Could not load any model: {e}")
69
+
70
+ if model is None:
71
+ raise RuntimeError("❌ No model could be loaded!")
72
+
73
+ # ═══════════════════════════════════════════════════════════════════
74
+ # POTHOLE TRACKER
75
+ # ═══════════════════════════════════════════════════════════════════
76
+
77
+ class PotholeTracker:
78
+ """Track potholes across video frames"""
79
+
80
+ def __init__(self, max_distance=100):
81
+ self.tracked_potholes = {}
82
+ self.next_id = 1
83
+ self.max_distance = max_distance
84
+ self.pothole_history = defaultdict(list)
85
+
86
+ def calculate_distance(self, centroid1, centroid2):
87
+ """Calculate Euclidean distance between two centroids"""
88
+ return np.sqrt((centroid1[0] - centroid2[0])**2 + (centroid1[1] - centroid2[1])**2)
89
+
90
+ def update(self, detections, frame_num, timestamp):
91
+ """Update tracker with new detections"""
92
+ if not detections:
93
+ return []
94
+
95
+ # If no tracked potholes yet, assign new IDs
96
+ if not self.tracked_potholes:
97
+ for det in detections:
98
+ det['track_id'] = self.next_id
99
+ self.tracked_potholes[self.next_id] = det['centroid']
100
+ self.pothole_history[self.next_id].append({
101
+ 'frame': frame_num,
102
+ 'timestamp': timestamp,
103
+ 'measurements': det
104
+ })
105
+ self.next_id += 1
106
+ return detections
107
+
108
+ # Match detections to tracked potholes
109
+ current_centroids = [det['centroid'] for det in detections]
110
+ tracked_ids = list(self.tracked_potholes.keys())
111
+ tracked_centroids = [self.tracked_potholes[tid] for tid in tracked_ids]
112
+
113
+ unmatched_detections = list(range(len(detections)))
114
+ unmatched_tracks = list(range(len(tracked_ids)))
115
+
116
+ # Simple nearest neighbor matching
117
+ for det_idx in range(len(detections)):
118
+ min_dist = float('inf')
119
+ min_track_idx = -1
120
+
121
+ for track_idx in unmatched_tracks:
122
+ dist = self.calculate_distance(
123
+ current_centroids[det_idx],
124
+ tracked_centroids[track_idx]
125
+ )
126
+
127
+ if dist < min_dist and dist < self.max_distance:
128
+ min_dist = dist
129
+ min_track_idx = track_idx
130
+
131
+ if min_track_idx != -1:
132
+ # Match found
133
+ track_id = tracked_ids[min_track_idx]
134
+ detections[det_idx]['track_id'] = track_id
135
+ self.tracked_potholes[track_id] = current_centroids[det_idx]
136
+ self.pothole_history[track_id].append({
137
+ 'frame': frame_num,
138
+ 'timestamp': timestamp,
139
+ 'measurements': detections[det_idx]
140
+ })
141
+ unmatched_detections.remove(det_idx)
142
+ unmatched_tracks.remove(min_track_idx)
143
+
144
+ # Assign new IDs to unmatched detections
145
+ for det_idx in unmatched_detections:
146
+ detections[det_idx]['track_id'] = self.next_id
147
+ self.tracked_potholes[self.next_id] = current_centroids[det_idx]
148
+ self.pothole_history[self.next_id].append({
149
+ 'frame': frame_num,
150
+ 'timestamp': timestamp,
151
+ 'measurements': detections[det_idx]
152
+ })
153
+ self.next_id += 1
154
+
155
+ return detections
156
+
157
+ def get_statistics(self):
158
+ """Get comprehensive statistics for all tracked potholes"""
159
+ stats = {
160
+ 'total_potholes': len(self.pothole_history),
161
+ 'potholes': []
162
+ }
163
+
164
+ for track_id, history in self.pothole_history.items():
165
+ # Get max values across all frames for this pothole
166
+ max_depth = max(h['measurements']['max_depth_cm'] for h in history)
167
+ max_area = max(h['measurements']['area_m2'] for h in history)
168
+ max_volume = max(h['measurements']['volume_liters'] for h in history)
169
+
170
+ # Average measurements
171
+ avg_depth = np.mean([h['measurements']['max_depth_cm'] for h in history])
172
+ avg_area = np.mean([h['measurements']['area_m2'] for h in history])
173
+
174
+ # First and last appearance
175
+ first_frame = history[0]['frame']
176
+ last_frame = history[-1]['frame']
177
+ first_timestamp = history[0]['timestamp']
178
+ last_timestamp = history[-1]['timestamp']
179
+
180
+ # Most severe classification
181
+ severities = [h['measurements']['severity'] for h in history]
182
+ severity_order = {'LOW': 0, 'MEDIUM': 1, 'HIGH': 2, 'CRITICAL': 3}
183
+ max_severity = max(severities, key=lambda s: severity_order.get(s, 0))
184
+
185
+ stats['potholes'].append({
186
+ 'track_id': track_id,
187
+ 'frames_detected': len(history),
188
+ 'first_frame': first_frame,
189
+ 'last_frame': last_frame,
190
+ 'first_timestamp': first_timestamp,
191
+ 'last_timestamp': last_timestamp,
192
+ 'max_depth_cm': max_depth,
193
+ 'avg_depth_cm': avg_depth,
194
+ 'max_area_m2': max_area,
195
+ 'avg_area_m2': avg_area,
196
+ 'max_volume_liters': max_volume,
197
+ 'severity': max_severity,
198
+ 'history': history
199
+ })
200
+
201
+ return stats
202
+
203
+ # ═══════════════════════════════════════════════════════════════════
204
+ # INFERENCE HANDLER
205
+ # ═══════════════════════════════════════════════════════════════════
206
+
207
+ class InferenceHandler:
208
+ """Handle image and video inference"""
209
+
210
+ def __init__(self, model, measurer):
211
+ self.model = model
212
+ self.measurer = measurer
213
+
214
+ def detect_image(self, image, confidence_threshold=0.5):
215
+ """Run detection on a single image"""
216
+ # Convert PIL to numpy if needed
217
+ if isinstance(image, Image.Image):
218
+ image = np.array(image)
219
+
220
+ # Ensure RGB format
221
+ if len(image.shape) == 2:
222
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
223
+ elif image.shape[2] == 4:
224
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
225
+
226
+ h, w = image.shape[:2]
227
+
228
+ # Save to temporary file
229
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
230
+ tmp_path = tmp_file.name
231
+ cv2.imwrite(tmp_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
232
+
233
+ try:
234
+ # Run prediction
235
+ results = self.model(tmp_path, conf=confidence_threshold, verbose=False)[0]
236
+
237
+ # Check if any detections
238
+ if results.boxes is None or len(results.boxes) == 0:
239
+ return image, []
240
+
241
+ # Extract results
242
+ boxes = results.boxes.xyxy.cpu().numpy()
243
+ confidences = results.boxes.conf.cpu().numpy()
244
+ masks = results.masks.data.cpu().numpy() if results.masks is not None else None
245
+
246
+ # Create annotated image
247
+ annotated_img = image.copy()
248
+ all_measurements = []
249
+
250
+ # Process each detection
251
+ for idx, (box, conf) in enumerate(zip(boxes, confidences)):
252
+ x1, y1, x2, y2 = box.astype(int)
253
+
254
+ # Draw bounding box
255
+ cv2.rectangle(annotated_img, (x1, y1), (x2, y2), (255, 0, 0), 3)
256
+
257
+ # Process mask if available
258
+ if masks is not None and idx < len(masks):
259
+ mask = masks[idx]
260
+ mask_resized = cv2.resize(mask, (w, h))
261
+ mask_binary = (mask_resized > 0.5).astype(np.uint8) * 255
262
+
263
+ # Create colored overlay
264
+ overlay = annotated_img.copy()
265
+ overlay[mask_binary > 0] = [255, 50, 50]
266
+ annotated_img = cv2.addWeighted(annotated_img, 0.6, overlay, 0.4, 0)
267
+
268
+ # Draw contour
269
+ contours, _ = cv2.findContours(
270
+ mask_binary,
271
+ cv2.RETR_EXTERNAL,
272
+ cv2.CHAIN_APPROX_SIMPLE
273
+ )
274
+ cv2.drawContours(annotated_img, contours, -1, (0, 255, 0), 2)
275
+
276
+ # Calculate measurements
277
+ measurements = self.measurer.calculate_measurements(mask_binary)
278
+
279
+ if measurements:
280
+ measurements['pothole_id'] = idx + 1
281
+ measurements['confidence'] = float(conf)
282
+ all_measurements.append(measurements)
283
+
284
+ # Add text annotation
285
+ text = f"#{idx+1} {measurements['severity_color']} {measurements['severity']}"
286
+ text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
287
+
288
+ cv2.rectangle(
289
+ annotated_img,
290
+ (x1, y1 - text_size[1] - 10),
291
+ (x1 + text_size[0] + 10, y1),
292
+ (0, 0, 0),
293
+ -1
294
+ )
295
+
296
+ cv2.putText(
297
+ annotated_img,
298
+ text,
299
+ (x1 + 5, y1 - 5),
300
+ cv2.FONT_HERSHEY_SIMPLEX,
301
+ 0.7,
302
+ (255, 255, 255),
303
+ 2
304
+ )
305
+
306
+ return annotated_img, all_measurements
307
+
308
+ finally:
309
+ if os.path.exists(tmp_path):
310
+ os.unlink(tmp_path)
311
+
312
+ def detect_video(self, video_path, confidence_threshold=0.5, progress_callback=None):
313
+ """Run detection on video"""
314
+ if video_path is None:
315
+ return None, None, None, None
316
+
317
+ # Open video
318
+ cap = cv2.VideoCapture(video_path)
319
+
320
+ if not cap.isOpened():
321
+ return None, None, None, None
322
+
323
+ # Get video properties
324
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
325
+ if fps == 0:
326
+ fps = VIDEO_FPS
327
+
328
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
329
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
330
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
331
+
332
+ # Create output video
333
+ output_path = tempfile.mktemp(suffix='.mp4')
334
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
335
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
336
+
337
+ # Initialize tracker
338
+ tracker = PotholeTracker(max_distance=150)
339
+ csv_data = []
340
+ frame_num = 0
341
+
342
+ if progress_callback:
343
+ progress_callback(0, desc="Starting video processing...")
344
+
345
+ while True:
346
+ ret, frame = cap.read()
347
+ if not ret:
348
+ break
349
+
350
+ # Calculate timestamp
351
+ timestamp = frame_num / fps
352
+ timestamp_str = str(timedelta(seconds=int(timestamp)))
353
+
354
+ # Save frame temporarily
355
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
356
+ tmp_path = tmp_file.name
357
+ cv2.imwrite(tmp_path, frame)
358
+
359
+ try:
360
+ # Run prediction
361
+ results = self.model(tmp_path, conf=confidence_threshold, verbose=False)[0]
362
+ detections = []
363
+
364
+ # Process detections
365
+ if results.boxes is not None and len(results.boxes) > 0:
366
+ boxes = results.boxes.xyxy.cpu().numpy()
367
+ confidences = results.boxes.conf.cpu().numpy()
368
+ masks = results.masks.data.cpu().numpy() if results.masks is not None else None
369
+
370
+ for idx, (box, conf) in enumerate(zip(boxes, confidences)):
371
+ if masks is not None and idx < len(masks):
372
+ mask = masks[idx]
373
+ mask_resized = cv2.resize(mask, (width, height))
374
+ mask_binary = (mask_resized > 0.5).astype(np.uint8) * 255
375
+
376
+ measurements = self.measurer.calculate_measurements(mask_binary)
377
+
378
+ if measurements:
379
+ measurements['confidence'] = float(conf)
380
+ detections.append(measurements)
381
+
382
+ # Draw on frame
383
+ overlay = frame.copy()
384
+ overlay[mask_binary > 0] = [50, 50, 255]
385
+ frame = cv2.addWeighted(frame, 0.6, overlay, 0.4, 0)
386
+
387
+ contours, _ = cv2.findContours(
388
+ mask_binary,
389
+ cv2.RETR_EXTERNAL,
390
+ cv2.CHAIN_APPROX_SIMPLE
391
+ )
392
+ cv2.drawContours(frame, contours, -1, (0, 255, 0), 2)
393
+
394
+ # Update tracker
395
+ tracked_detections = tracker.update(detections, frame_num, timestamp_str)
396
+
397
+ # Annotate frame
398
+ for det in tracked_detections:
399
+ x, y, w, h = det['bbox']
400
+ cx, cy = det['centroid']
401
+ track_id = det['track_id']
402
+
403
+ cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
404
+ cv2.circle(frame, (cx, cy), 5, (0, 255, 255), -1)
405
+
406
+ label = f"ID:{track_id} {det['severity']}"
407
+ text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
408
+ cv2.rectangle(
409
+ frame,
410
+ (x, y - text_size[1] - 10),
411
+ (x + text_size[0] + 10, y),
412
+ (0, 0, 0),
413
+ -1
414
+ )
415
+
416
+ cv2.putText(frame, label, (x + 5, y - 5),
417
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
418
+
419
+ # Store CSV data
420
+ csv_data.append({
421
+ 'Frame': frame_num,
422
+ 'Timestamp': timestamp_str,
423
+ 'Track_ID': track_id,
424
+ 'Centroid_X': cx,
425
+ 'Centroid_Y': cy,
426
+ 'BBox_X': x,
427
+ 'BBox_Y': y,
428
+ 'BBox_Width': w,
429
+ 'BBox_Height': h,
430
+ 'Depth_cm': det['max_depth_cm'],
431
+ 'Area_m2': det['area_m2'],
432
+ 'Volume_L': det['volume_liters'],
433
+ 'Severity': det['severity'],
434
+ 'Confidence': det['confidence']
435
+ })
436
+
437
+ # Add frame info
438
+ info_text = f"Frame: {frame_num}/{total_frames} | Time: {timestamp_str} | Potholes: {len(tracked_detections)}"
439
+ cv2.putText(frame, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
440
+ cv2.putText(frame, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 1)
441
+
442
+ out.write(frame)
443
+
444
+ finally:
445
+ if os.path.exists(tmp_path):
446
+ os.unlink(tmp_path)
447
+
448
+ frame_num += 1
449
+
450
+ # Update progress
451
+ if frame_num % 10 == 0 and progress_callback:
452
+ progress_callback(frame_num / total_frames,
453
+ desc=f"Processing frame {frame_num}/{total_frames}")
454
+
455
+ cap.release()
456
+ out.release()
457
+
458
+ # Get statistics
459
+ stats = tracker.get_statistics()
460
+
461
+ # Save CSV
462
+ csv_path = tempfile.mktemp(suffix='.csv')
463
+ if csv_data:
464
+ df = pd.DataFrame(csv_data)
465
+ df.to_csv(csv_path, index=False)
466
+ else:
467
+ csv_path = None
468
+
469
+ if progress_callback:
470
+ progress_callback(1.0, desc="Video processing complete!")
471
+
472
+ return output_path, stats, total_frames, fps, csv_path