Sarvamangalak committed on
Commit
bedbb75
·
verified ·
1 Parent(s): d17b045

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +306 -130
app.py CHANGED
@@ -1,155 +1,331 @@
 
 
 
1
  import cv2
2
- import numpy as np
3
- from ultralytics import YOLO
4
- import easyocr
5
  import gradio as gr
6
- import tempfile
7
- import os
8
-
9
- # Load YOLOv8 plate detection model
10
- model = YOLO("best.pt") # <-- your trained plate model
11
-
12
- # Initialize OCR
13
- reader = easyocr.Reader(['en'], gpu=False)
14
-
15
- def preprocess_plate(plate_img):
16
- gray = cv2.cvtColor(plate_img, cv2.COLOR_BGR2GRAY)
17
- blur = cv2.GaussianBlur(gray, (5, 5), 0)
18
- thresh = cv2.adaptiveThreshold(
19
- blur, 255,
20
- cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
21
- cv2.THRESH_BINARY, 11, 2
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
- return thresh
24
-
25
- def recognize_plate(plate_img):
26
- processed = preprocess_plate(plate_img)
27
- ocr_result = reader.readtext(processed)
28
-
29
- plate_text = ""
30
- for (bbox, text, prob) in ocr_result:
31
- if prob > 0.4:
32
- plate_text += text + " "
33
-
34
- return plate_text.strip()
35
-
36
- def process_frame(frame):
37
- detected_plates = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- results = model(frame)
 
 
 
40
 
41
- for r in results:
42
- if r.boxes is None:
43
- continue
44
-
45
- boxes = r.boxes.xyxy.cpu().numpy()
46
- confs = r.boxes.conf.cpu().numpy()
47
-
48
- for box, conf in zip(boxes, confs):
49
- x1, y1, x2, y2 = map(int, box)
50
-
51
- plate_img = frame[y1:y2, x1:x2]
52
- if plate_img.size == 0:
53
- continue
54
-
55
- plate_text = recognize_plate(plate_img)
56
-
57
- detected_plates.append({
58
- "plate_text": plate_text,
59
- "confidence": float(conf)
60
- })
61
 
62
- # Draw bounding box
63
- cv2.rectangle(frame, (x1, y1), (x2, y2),
64
- (0, 255, 0), 2)
65
 
66
- # Draw plate text
67
- label = plate_text if plate_text else "Plate"
68
- cv2.putText(frame, label,
69
- (x1, y1 - 10),
70
- cv2.FONT_HERSHEY_SIMPLEX,
71
- 0.8, (255, 0, 0), 2)
72
 
73
- return frame, detected_plates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- # =========================
76
- # IMAGE MODE
77
- # =========================
78
- def process_image(image):
79
- frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
80
- annotated_frame, plates = process_frame(frame)
81
- annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
82
 
83
- plate_texts = [p["plate_text"] for p in plates if p["plate_text"]]
84
- result_text = "\n".join(plate_texts) if plate_texts else "No plates detected."
85
 
86
- return annotated_frame, result_text
87
 
88
- # =========================
89
- # VIDEO MODE
90
- # =========================
91
- def process_video(video_file):
92
- cap = cv2.VideoCapture(video_file)
93
 
94
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
95
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
96
- fps = cap.get(cv2.CAP_PROP_FPS)
97
 
98
- temp_out = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
99
- out_path = temp_out.name
100
- temp_out.close()
 
 
 
 
 
101
 
102
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
103
- out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
 
 
104
 
105
- all_detected = set()
 
 
 
 
106
 
107
- while cap.isOpened():
108
- ret, frame = cap.read()
109
- if not ret:
110
- break
111
 
112
- annotated_frame, plates = process_frame(frame)
 
 
113
 
114
- for p in plates:
115
- if p["plate_text"]:
116
- all_detected.add(p["plate_text"])
117
 
118
- out.write(annotated_frame)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- cap.release()
121
- out.release()
 
 
 
 
122
 
123
- result_text = "\n".join(all_detected) if all_detected else "No plates detected."
 
 
 
 
 
124
 
125
- return out_path, result_text
 
 
 
 
 
126
 
127
- # =========================
128
- # GRADIO UI
129
- # =========================
130
- with gr.Blocks() as demo:
131
- gr.Markdown("## Smart Traffic & EV Analytics System")
132
- gr.Markdown("Upload an image or video to detect multiple vehicle number plates.")
133
 
134
- with gr.Tabs():
135
- with gr.Tab("Image"):
136
- image_input = gr.Image(type="numpy", label="Upload Image")
137
- image_output = gr.Image(label="Detected Plates")
138
- image_text = gr.Textbox(label="Recognized Plate Numbers")
139
-
140
- image_button = gr.Button("Detect Plates")
141
- image_button.click(process_image,
142
- inputs=image_input,
143
- outputs=[image_output, image_text])
144
-
145
- with gr.Tab("Video"):
146
- video_input = gr.Video(label="Upload Video")
147
- video_output = gr.Video(label="Processed Video")
148
- video_text = gr.Textbox(label="Recognized Plate Numbers")
149
-
150
- video_button = gr.Button("Detect Plates")
151
- video_button.click(process_video,
152
- inputs=video_input,
153
- outputs=[video_output, video_text])
154
-
155
- demo.launch()
 
1
# app_with_video.py
import io
import os
import pathlib
import tempfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import validators
from PIL import Image
from transformers import AutoImageProcessor, YolosForObjectDetection, DetrForObjectDetection

# Allow duplicate OpenMP runtimes (torch + OpenCV both bundle one) without aborting.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
14
+
15
+ COLORS = [
16
+ [0.000, 0.447, 0.741],
17
+ [0.850, 0.325, 0.098],
18
+ [0.929, 0.694, 0.125],
19
+ [0.494, 0.184, 0.556],
20
+ [0.466, 0.674, 0.188],
21
+ [0.301, 0.745, 0.933]
22
+ ]
23
+
24
+ # ---------- Core Inference ----------
25
+
26
def make_prediction(img, processor, model):
    """Run one detection forward pass on a PIL image.

    Returns the first (only) post-processed output dict with keys
    "scores", "labels", "boxes", keeping every detection (threshold 0.0)
    so callers can filter at their own confidence level.
    """
    encoded = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        raw_outputs = model(**encoded)
    # post-processing expects (height, width); PIL's .size is (width, height)
    target_size = torch.tensor([tuple(reversed(img.size))])
    detections = processor.post_process_object_detection(
        raw_outputs, threshold=0.0, target_sizes=target_size
    )
    return detections[0]
35
+
36
+
37
def fig2img(fig):
    """Render a matplotlib figure to a PIL image resized to 750px width.

    The aspect ratio is preserved; the figure is closed after rendering
    to avoid leaking matplotlib state across calls.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    rendered = Image.open(buffer)

    target_width = 750
    scale = target_width / float(rendered.size[0])
    target_height = int(float(rendered.size[1]) * scale)
    resized = rendered.resize((target_width, target_height), Image.Resampling.LANCZOS)

    plt.close(fig)
    return resized
48
+
49
+
50
def classify_plate_color(crop_img):
    """Heuristically classify an India-style license-plate crop by color.

    Args:
        crop_img: PIL image (RGB) of the cropped plate region.

    Returns:
        A human-readable category string, e.g. "White Plate (Private)".

    Note: requires `numpy` (imported at module level as `np`); the
    original file used `np` without importing it, which raised NameError.
    """
    # PIL (RGB) -> OpenCV BGR, then HSV for hue/saturation/value statistics.
    bgr = cv2.cvtColor(np.array(crop_img), cv2.COLOR_RGB2BGR)
    hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(hsv)
    avg_h, avg_s, avg_v = np.mean(h), np.mean(s), np.mean(v)

    # Heuristic thresholds (India-style plates). Order matters: darkest
    # (black) first, then low-saturation bright (white), then hue bands.
    if avg_v < 80:
        return "Black Plate (Commercial)"
    if avg_s < 40 and avg_v > 180:
        return "White Plate (Private)"
    if 15 < avg_h < 35 and avg_s > 80:
        return "Yellow Plate (Commercial)"
    if 80 < avg_h < 130:
        return "Blue Plate (Diplomatic)"
    if 35 < avg_h < 85:
        return "Green Plate (Electric)"

    return "Unknown Plate"
70
+
71
+
72
def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
    """Draw license-plate detections above `threshold` on `img`.

    Fixes a corrupted merge in the original file where two versions of this
    function were fused (`return fig2img(plt.gcf())(img, ...):` followed by a
    stale duplicate body), which was a syntax error. The richer version that
    labels each plate with its classified color is kept.

    Args:
        img: PIL image the model was run on.
        output_dict: post-processed detection dict with "scores", "boxes", "labels".
        threshold: minimum confidence score for a detection to be drawn.
        id2label: optional mapping from integer label id to class-name string.

    Returns:
        A PIL image with bounding boxes and plate-color captions rendered on it.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()

    if id2label is not None:
        labels = [id2label[x] for x in labels]

    plt.figure(figsize=(20, 20))
    plt.imshow(img)
    ax = plt.gca()
    colors = COLORS * 100  # repeat the palette so zip never runs short

    for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
        # Only the license-plate class is visualized; other classes are skipped.
        if label == 'license-plates':
            crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
            plate_type = classify_plate_color(crop)

            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin,
                    fill=False, color=color, linewidth=4
                )
            )
            ax.text(
                xmin, ymin - 10,
                f"{plate_type} | {score:0.2f}",
                fontsize=12,
                bbox=dict(facecolor="yellow", alpha=0.8)
            )

    plt.axis("off")
    return fig2img(plt.gcf())
136
+
137
+
138
+ # ---------- Utilities ----------
139
+
140
def get_original_image(url_input):
    """Download an image from `url_input` and return it as an RGB PIL image.

    Returns None when the input is not a syntactically valid URL.
    """
    if validators.url(url_input):
        # A timeout prevents a hung download from blocking the UI worker
        # indefinitely (the original call had no timeout at all).
        response = requests.get(url_input, stream=True, timeout=30)
        image = Image.open(response.raw).convert("RGB")
        return image
    return None
144
+
145
+
146
def load_model(model_name):
    """Build the (processor, model) pair for a detection checkpoint.

    Supports YOLOS- and DETR-family model names; anything else raises
    ValueError. The model is put in eval mode before being returned.
    """
    processor = AutoImageProcessor.from_pretrained(model_name)

    if "yolos" in model_name:
        detector = YolosForObjectDetection.from_pretrained(model_name)
    elif "detr" in model_name:
        detector = DetrForObjectDetection.from_pretrained(model_name)
    else:
        raise ValueError("Unsupported model")

    detector.eval()
    return processor, detector
158
+
159
+
160
+ # ---------- Image Detection ----------
161
+
162
def detect_objects_image(model_name, url_input, image_input, webcam_input, threshold):
    """Run plate detection on whichever image source is populated.

    Source priority: valid URL, then uploaded image, then webcam frame.
    Returns an annotated PIL image, or None when no source was given.
    """
    processor, model = load_model(model_name)

    if validators.url(url_input):
        source_img = get_original_image(url_input)
    elif image_input is not None:
        source_img = image_input
    elif webcam_input is not None:
        source_img = webcam_input
    else:
        return None

    detections = make_prediction(source_img, processor, model)
    return visualize_prediction(source_img, detections, threshold, model.config.id2label)
178
+
179
+
180
+ # ---------- Video Detection ----------
181
+
182
def detect_objects_video(model_name, video_input, threshold):
    """Run frame-by-frame license-plate detection over a video file.

    Args:
        model_name: checkpoint name understood by `load_model`.
        video_input: filesystem path of the uploaded video, or None.
        threshold: minimum confidence score for a box to be drawn.

    Returns:
        Path to an annotated MP4 file, or None when no video was supplied.
    """
    if video_input is None:
        return None

    processor, model = load_model(model_name)

    cap = cv2.VideoCapture(video_input)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    # Write to a unique temp file instead of the original hard-coded
    # "/mnt/data/output_detected.mp4": that path may not exist on the host,
    # and concurrent requests would clobber each other's output.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    output_path = tmp.name
    tmp.close()

    # Some containers report 0 fps, which makes VideoWriter produce an
    # unplayable file — fall back to a sane default.
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # The model expects RGB PIL input; OpenCV frames are BGR.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(rgb_frame)

        processed_outputs = make_prediction(pil_img, processor, model)

        keep = processed_outputs["scores"] > threshold
        boxes = processed_outputs["boxes"][keep].tolist()
        scores = processed_outputs["scores"][keep].tolist()
        labels = [model.config.id2label[x]
                  for x in processed_outputs["labels"][keep].tolist()]

        for score, (xmin, ymin, xmax, ymax), label in zip(scores, boxes, labels):
            # Only license plates are annotated onto the output frame.
            if label == 'license-plates':
                cv2.rectangle(
                    frame,
                    (int(xmin), int(ymin)),
                    (int(xmax), int(ymax)),
                    (0, 255, 0),
                    2
                )
                cv2.putText(
                    frame,
                    f"{label}: {score:.2f}",
                    (int(xmin), int(ymin) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6,
                    (0, 255, 0),
                    2
                )

        out.write(frame)

    cap.release()
    out.release()

    return output_path
 
240
 
 
241
 
242
# ---------- UI ----------
# Top-level Gradio wiring: four input tabs (URL / upload / webcam / video)
# sharing one model dropdown and one confidence-threshold slider.
# NOTE(review): `gr.Image(shape=...)`, `source='webcam'`, `streaming=True`
# and `demo.launch(enable_queue=...)` are Gradio 3.x APIs — presumably this
# Space pins gradio<4; confirm before upgrading.

title = """<h1 id="title">License Plate Detection (Image + Video)</h1>"""

description = """
Detect license plates using YOLOS or DETR.
Supports:
- Image URL
- Image Upload
- Webcam
- Video Upload
"""

# Hugging Face Hub checkpoints offered in the model dropdown.
models = [
    "nickmuchi/yolos-small-finetuned-license-plate-detection",
    "nickmuchi/detr-resnet50-license-plate-detection"
]

css = '''
h1#title {
  text-align: center;
}
'''

demo = gr.Blocks(css=css)

with demo:
    gr.Markdown(title)
    gr.Markdown(description)

    # Shared controls used by every tab's Detect button.
    options = gr.Dropdown(choices=models, label='Object Detection Model', value=models[0])
    slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.5, step=0.1, label='Prediction Threshold')

    with gr.Tabs():
        with gr.TabItem('Image URL'):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                original_image = gr.Image(shape=(750, 750))
                # Preview the fetched image as soon as the URL changes.
                url_input.change(get_original_image, url_input, original_image)
                img_output_from_url = gr.Image(shape=(750, 750))
            url_but = gr.Button('Detect')

        with gr.TabItem('Image Upload'):
            with gr.Row():
                img_input = gr.Image(type='pil', shape=(750, 750))
                img_output_from_upload = gr.Image(shape=(750, 750))
            img_but = gr.Button('Detect')

        with gr.TabItem('WebCam'):
            with gr.Row():
                web_input = gr.Image(source='webcam', type='pil', shape=(750, 750), streaming=True)
                img_output_from_webcam = gr.Image(shape=(750, 750))
            cam_but = gr.Button('Detect')

        with gr.TabItem('Video Upload'):
            with gr.Row():
                video_input = gr.Video(label="Upload Video")
                video_output = gr.Video(label="Detected Video")
            vid_but = gr.Button('Detect Video')

    # All three image buttons share detect_objects_image, which picks
    # the first populated source (URL > upload > webcam) itself.
    url_but.click(
        detect_objects_image,
        inputs=[options, url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_url],
        queue=True
    )

    img_but.click(
        detect_objects_image,
        inputs=[options, url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_upload],
        queue=True
    )

    cam_but.click(
        detect_objects_image,
        inputs=[options, url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_webcam],
        queue=True
    )

    vid_but.click(
        detect_objects_video,
        inputs=[options, video_input, slider_input],
        outputs=[video_output],
        queue=True
    )

demo.launch(debug=True, enable_queue=True)