Abs6187 commited on
Commit
3705172
·
verified ·
1 Parent(s): b7bf529

Create detector.py

Browse files
Files changed (1) hide show
  1. detector.py +405 -0
detector.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ultralytics import YOLO
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import numpy as np
5
+ import pandas as pd
6
+ import os
7
+ import cv2
8
+ import time
9
+ import zipfile
10
+ import io
11
+ from datetime import datetime
12
+
13
+ # ===== Optional OCR imports =====
14
+ try:
15
+ from license_plate_ocr import extract_license_plate_text
16
+ OCR_AVAILABLE = True
17
+ print("Basic OCR module loaded successfully")
18
+ except ImportError as e:
19
+ print(f"Basic OCR module not available: {e}")
20
+ OCR_AVAILABLE = False
21
+
22
+ try:
23
+ from advanced_ocr import (
24
+ extract_license_plate_text_advanced,
25
+ get_available_models,
26
+ set_ocr_model,
27
+ )
28
+ ADVANCED_OCR_AVAILABLE = True
29
+ print("Advanced OCR module loaded successfully")
30
+ except ImportError as e:
31
+ print(f"Advanced OCR module not available: {e}")
32
+ ADVANCED_OCR_AVAILABLE = False
33
+
34
+ # ===== Model & class names =====
35
+ model = YOLO("best.pt") # make sure best.pt is present
36
+ class_names = {0: "With Helmet", 1: "Without Helmet", 2: "License Plate"}
37
+
38
+
39
+ def crop_license_plates(image, detections, extract_text=False, selected_ocr_model="auto"):
40
+ """Crop license plates and (optionally) run OCR on the crops."""
41
+ cropped_plates = []
42
+
43
+ try:
44
+ if isinstance(image, str):
45
+ if not os.path.exists(image):
46
+ print(f"Error: Image file not found: {image}")
47
+ return cropped_plates
48
+ image = Image.open(image)
49
+ elif isinstance(image, np.ndarray):
50
+ image = Image.fromarray(image)
51
+ elif not isinstance(image, Image.Image):
52
+ print(f"Error: Unsupported image type: {type(image)}")
53
+ return cropped_plates
54
+
55
+ if image.size[0] == 0 or image.size[1] == 0:
56
+ print("Error: Image has zero dimensions")
57
+ return cropped_plates
58
+ except Exception as e:
59
+ print(f"Error loading image: {e}")
60
+ return cropped_plates
61
+
62
+ for i, detection in enumerate(detections):
63
+ try:
64
+ if detection["Object"] != "License Plate":
65
+ continue
66
+
67
+ pos_str = detection["Position"].strip("()")
68
+ if "," not in pos_str:
69
+ print(
70
+ f"Error: Invalid position format for detection {i}: {detection['Position']}"
71
+ )
72
+ continue
73
+
74
+ x1, y1 = map(int, pos_str.split(", "))
75
+
76
+ dims_str = detection["Dimensions"]
77
+ if "x" not in dims_str:
78
+ print(
79
+ f"Error: Invalid dimensions format for detection {i}: {detection['Dimensions']}"
80
+ )
81
+ continue
82
+
83
+ width, height = map(int, dims_str.split("x"))
84
+
85
+ if width <= 0 or height <= 0:
86
+ print(f"Error: Invalid dimensions for detection {i}: {width}x{height}")
87
+ continue
88
+
89
+ x2, y2 = x1 + width, y1 + height
90
+
91
+ if x1 < 0 or y1 < 0 or x2 > image.width or y2 > image.height:
92
+ print(
93
+ f"Warning: Bounding box extends beyond image boundaries for detection {i}"
94
+ )
95
+ x1 = max(0, x1)
96
+ y1 = max(0, y1)
97
+ x2 = min(image.width, x2)
98
+ y2 = min(image.height, y2)
99
+
100
+ if x2 <= x1 or y2 <= y1:
101
+ print(
102
+ f"Error: Invalid crop coordinates for detection {i}: ({x1},{y1}) to ({x2},{y2})"
103
+ )
104
+ continue
105
+
106
+ cropped_plate = image.crop((x1, y1, x2, y2))
107
+
108
+ if cropped_plate.size[0] == 0 or cropped_plate.size[1] == 0:
109
+ print(
110
+ f"Error: Cropped image has zero dimensions for detection {i}"
111
+ )
112
+ continue
113
+
114
+ plate_data = {
115
+ "image": cropped_plate,
116
+ "confidence": detection["Confidence"],
117
+ "position": detection["Position"],
118
+ "crop_coords": f"({x1},{y1}) to ({x2},{y2})",
119
+ "text": "Processing...",
120
+ }
121
+
122
+ if extract_text and (OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE):
123
+ try:
124
+ print(
125
+ f"Extracting text from license plate {i+1} using {selected_ocr_model}..."
126
+ )
127
+
128
+ if ADVANCED_OCR_AVAILABLE and selected_ocr_model != "basic":
129
+ if selected_ocr_model != "auto":
130
+ set_ocr_model(selected_ocr_model)
131
+ plate_text = extract_license_plate_text_advanced(
132
+ cropped_plate,
133
+ None if selected_ocr_model == "auto" else selected_ocr_model,
134
+ )
135
+ else:
136
+ plate_text = extract_license_plate_text(cropped_plate)
137
+
138
+ if (
139
+ plate_text
140
+ and plate_text.strip()
141
+ and not plate_text.startswith("Error")
142
+ ):
143
+ plate_data["text"] = plate_text.strip()
144
+ print(f"Extracted text: {plate_text.strip()}")
145
+ else:
146
+ plate_data["text"] = "No text detected"
147
+ print(f"No text found in plate {i+1}")
148
+ except Exception as e:
149
+ print(f"OCR extraction failed for plate {i+1}: {e}")
150
+ plate_data["text"] = f"OCR Failed: {str(e)}"
151
+ elif extract_text and not (OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE):
152
+ plate_data["text"] = "OCR not available"
153
+ else:
154
+ plate_data["text"] = "OCR disabled"
155
+
156
+ cropped_plates.append(plate_data)
157
+
158
+ except ValueError as e:
159
+ print(f"Error parsing coordinates for detection {i}: {e}")
160
+ continue
161
+ except Exception as e:
162
+ print(f"Error cropping license plate {i}: {e}")
163
+ continue
164
+
165
+ return cropped_plates
166
+
167
+
168
+ def create_download_files(annotated_image, cropped_plates, detections):
169
+ try:
170
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
171
+ os.makedirs("temp", exist_ok=True)
172
+
173
+ annotated_path = f"temp/annotated_image_{timestamp}.jpg"
174
+ try:
175
+ annotated_image.save(annotated_path, quality=95)
176
+ except Exception as e:
177
+ print(f"Error saving annotated image: {e}")
178
+ return None, None, []
179
+
180
+ plate_paths = []
181
+ for i, plate_data in enumerate(cropped_plates):
182
+ try:
183
+ plate_path = f"temp/license_plate_{i+1}_{timestamp}.jpg"
184
+ plate_data["image"].save(plate_path, quality=95)
185
+ plate_paths.append(plate_path)
186
+ except Exception as e:
187
+ print(f"Error saving license plate {i+1}: {e}")
188
+ continue
189
+
190
+ report_data = []
191
+ for detection in detections:
192
+ report_data.append(detection)
193
+
194
+ for i, plate_data in enumerate(cropped_plates):
195
+ report_data.append(
196
+ {
197
+ "Object": f"License Plate {i+1} - Text",
198
+ "Confidence": plate_data["confidence"],
199
+ "Position": plate_data["position"],
200
+ "Dimensions": "Extracted Text",
201
+ "Text": plate_data.get("text", "N/A"),
202
+ }
203
+ )
204
+
205
+ report_path = f"temp/detection_report_{timestamp}.csv"
206
+ if report_data:
207
+ try:
208
+ df = pd.DataFrame(report_data)
209
+ df.to_csv(report_path, index=False)
210
+ except Exception as e:
211
+ print(f"Error creating detection report: {e}")
212
+ report_path = None
213
+
214
+ zip_path = f"temp/detection_results_{timestamp}.zip"
215
+ try:
216
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
217
+ if os.path.exists(annotated_path):
218
+ zipf.write(annotated_path, f"annotated_image_{timestamp}.jpg")
219
+ for plate_path in plate_paths:
220
+ if os.path.exists(plate_path):
221
+ zipf.write(plate_path, os.path.basename(plate_path))
222
+ if report_path and os.path.exists(report_path):
223
+ zipf.write(report_path, f"detection_report_{timestamp}.csv")
224
+ except Exception as e:
225
+ print(f"Error creating ZIP file: {e}")
226
+ return None, annotated_path, plate_paths
227
+
228
+ return zip_path, annotated_path, plate_paths
229
+
230
+ except Exception as e:
231
+ print(f"Error in create_download_files: {e}")
232
+ return None, None, []
233
+
234
+
235
+ def yolov8_detect(
236
+ image=None,
237
+ image_size=640,
238
+ conf_threshold=0.4,
239
+ iou_threshold=0.5,
240
+ show_stats=True,
241
+ show_confidence=True,
242
+ crop_plates=True,
243
+ extract_text=False,
244
+ ocr_on_no_helmet=False,
245
+ selected_ocr_model="auto",
246
+ ):
247
+ """Main detection function."""
248
+ if image_size is None:
249
+ image_size = 640
250
+ if not isinstance(image_size, int):
251
+ image_size = int(image_size)
252
+
253
+ imgsz = [image_size, image_size]
254
+ results = model.predict(image, conf=conf_threshold, iou=iou_threshold, imgsz=imgsz)
255
+
256
+ annotated_image = results[0].plot()
257
+ if isinstance(annotated_image, np.ndarray):
258
+ annotated_image = Image.fromarray(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))
259
+
260
+ boxes = results[0].boxes
261
+ detections = []
262
+ if boxes is not None and len(boxes) > 0:
263
+ for i, (box, cls, conf) in enumerate(zip(boxes.xyxy, boxes.cls, boxes.conf)):
264
+ x1, y1, x2, y2 = box.tolist()
265
+ class_id = int(cls)
266
+ confidence = float(conf)
267
+ label = class_names.get(class_id, f"Class {class_id}")
268
+ detections.append(
269
+ {
270
+ "Object": label,
271
+ "Confidence": f"{confidence:.2f}",
272
+ "Position": f"({int(x1)}, {int(y1)})",
273
+ "Dimensions": f"{int(x2 - x1)}x{int(y2 - y1)}",
274
+ }
275
+ )
276
+
277
+ cropped_plates = []
278
+ license_plate_gallery = []
279
+ plate_texts = []
280
+ download_files = None
281
+
282
+ has_no_helmet = any(d["Object"] == "Without Helmet" for d in detections)
283
+ should_extract_text = extract_text or (ocr_on_no_helmet and has_no_helmet)
284
+ ocr_available = OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE
285
+
286
+ if crop_plates and detections:
287
+ try:
288
+ license_plate_count = len([d for d in detections if d["Object"] == "License Plate"])
289
+ print(f"Processing {license_plate_count} license plates...")
290
+
291
+ if ocr_on_no_helmet and has_no_helmet:
292
+ print("⚠️ No helmet detected - OCR will be performed on license plates")
293
+
294
+ cropped_plates = crop_license_plates(
295
+ image, detections, should_extract_text, selected_ocr_model
296
+ )
297
+ print(f"Successfully cropped {len(cropped_plates)} license plates")
298
+
299
+ license_plate_gallery = [plate_data["image"] for plate_data in cropped_plates]
300
+
301
+ if should_extract_text and ocr_available:
302
+ print("Extracting text from license plates...")
303
+ plate_texts = []
304
+ for i, plate_data in enumerate(cropped_plates):
305
+ text = plate_data.get("text", "No text detected")
306
+ print(f"Plate {i+1} text: {text}")
307
+ if ocr_on_no_helmet and has_no_helmet:
308
+ plate_texts.append(f"🚨 No Helmet Violation - Plate {i+1}: {text}")
309
+ else:
310
+ plate_texts.append(f"Plate {i+1}: {text}")
311
+ elif should_extract_text and not ocr_available:
312
+ plate_texts = [
313
+ "OCR not available - install requirements: pip install transformers easyocr"
314
+ ]
315
+ elif not should_extract_text:
316
+ plate_texts = [
317
+ f"Plate {i+1}: Text extraction disabled" for i in range(len(cropped_plates))
318
+ ]
319
+
320
+ if cropped_plates or detections:
321
+ download_files, _, _ = create_download_files(
322
+ annotated_image, cropped_plates, detections
323
+ )
324
+ if download_files is None:
325
+ print("Warning: Could not create download files")
326
+ except Exception as e:
327
+ print(f"Error in license plate processing: {e}")
328
+ cropped_plates = []
329
+ license_plate_gallery = []
330
+ plate_texts = ["Error processing license plates"]
331
+ download_files = None
332
+
333
+ stats_text = ""
334
+ if show_stats and detections:
335
+ df = pd.DataFrame(detections)
336
+ counts = df["Object"].value_counts().to_dict()
337
+ stats_text = "Detection Summary:\n"
338
+ for obj, count in counts.items():
339
+ stats_text += f"- {obj}: {count}\n"
340
+
341
+ if cropped_plates:
342
+ stats_text += f"\nLicense Plates Cropped: {len(cropped_plates)}\n"
343
+ if has_no_helmet:
344
+ stats_text += "⚠️ HELMET VIOLATION DETECTED!\n"
345
+ if should_extract_text and (OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE):
346
+ stats_text += "Extracted Text:\n"
347
+ for i, plate_data in enumerate(cropped_plates):
348
+ text = plate_data.get("text", "No text")
349
+ if has_no_helmet and ocr_on_no_helmet:
350
+ stats_text += f"🚨 Violation - Plate {i+1}: {text}\n"
351
+ else:
352
+ stats_text += f"- Plate {i+1}: {text}\n"
353
+
354
+ detection_table = (
355
+ pd.DataFrame(detections)
356
+ if detections
357
+ else pd.DataFrame(columns=["Object", "Confidence", "Position", "Dimensions"])
358
+ )
359
+ plate_text_output = (
360
+ "\n".join(plate_texts)
361
+ if plate_texts
362
+ else "No license plates detected or OCR disabled"
363
+ )
364
+
365
+ return (
366
+ annotated_image,
367
+ detection_table,
368
+ stats_text,
369
+ license_plate_gallery,
370
+ download_files,
371
+ plate_text_output,
372
+ )
373
+
374
+
375
+ def download_sample_images():
376
+ """Download sample images for testing."""
377
+ torch.hub.download_url_to_file(
378
+ "https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-1.jpg?raw=true",
379
+ "sample_1.jpg",
380
+ )
381
+ torch.hub.download_url_to_file(
382
+ "https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-2.jpg?raw=true",
383
+ "sample_2.jpg",
384
+ )
385
+ torch.hub.download_url_to_file(
386
+ "https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-3.jpg?raw=true",
387
+ "sample_3.jpg",
388
+ )
389
+ torch.hub.download_url_to_file(
390
+ "https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-4.jpg?raw=true",
391
+ "sample_4.jpg",
392
+ )
393
+ torch.hub.download_url_to_file(
394
+ "https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-5.jpg?raw=true",
395
+ "sample_5.jpg",
396
+ )
397
+
398
+
399
+ def get_ocr_status():
400
+ """Return OCR availability status."""
401
+ return {
402
+ "basic_available": OCR_AVAILABLE,
403
+ "advanced_available": ADVANCED_OCR_AVAILABLE,
404
+ "any_available": OCR_AVAILABLE or ADVANCED_OCR_AVAILABLE
405
+ }