Sarvamangalak committed on
Commit
e95a235
·
verified ·
1 Parent(s): 03f5e41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -96
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py
2
  import io
3
  import os
4
  import cv2
@@ -6,7 +5,6 @@ import gradio as gr
6
  import matplotlib.pyplot as plt
7
  import requests
8
  import torch
9
- import pathlib
10
  import numpy as np
11
  import sqlite3
12
  import pandas as pd
@@ -16,23 +14,17 @@ from transformers import YolosImageProcessor, YolosForObjectDetection
16
 
17
  os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
18
 
19
- base_amt = 100
20
- #--- calculate discount
 
 
 
21
  def compute_discount(vehicle_type):
22
  if vehicle_type == "EV":
23
- return base_amt * 0.9, "10% discount applied (EV)"
24
- return toll_parking_amt, "No discount"
25
- #------------
26
- COLORS = [
27
- [0.000, 0.447, 0.741],
28
- [0.850, 0.325, 0.098],
29
- [0.929, 0.694, 0.125],
30
- [0.494, 0.184, 0.556],
31
- [0.466, 0.674, 0.188],
32
- [0.301, 0.745, 0.933]
33
- ]
34
-
35
- # ---------------- Utilities ----------------
36
 
37
  def is_valid_url(url):
38
  try:
@@ -49,6 +41,7 @@ def get_original_image(url_input):
49
 
50
 
51
  # -------------------- Database --------------------
 
52
  conn = sqlite3.connect("vehicles.db", check_same_thread=False)
53
  cursor = conn.cursor()
54
  cursor.execute("""
@@ -61,7 +54,9 @@ CREATE TABLE IF NOT EXISTS vehicles (
61
  """)
62
  conn.commit()
63
 
 
64
  # -------------------- Lazy Model --------------------
 
65
  processor = None
66
  model = None
67
 
@@ -69,11 +64,9 @@ model = None
69
  def load_model():
70
  global processor, model
71
  if processor is None or model is None:
72
- processor = YolosImageProcessor.from_pretrained(
73
- "nickmuchi/yolos-small-finetuned-license-plate-detection"
74
- )
75
  model = YolosForObjectDetection.from_pretrained(
76
- "nickmuchi/yolos-small-finetuned-license-plate-detection",
77
  use_safetensors=True,
78
  torch_dtype=torch.float32
79
  )
@@ -99,11 +92,10 @@ def classify_plate_color(plate_img):
99
  return "Personal"
100
 
101
 
102
- # ---------------- Dashboard ----------------
103
 
104
  def get_dashboard():
105
  df = pd.read_sql("SELECT * FROM vehicles", conn)
106
-
107
  fig, ax = plt.subplots(figsize=(7, 5))
108
 
109
  if len(df) == 0:
@@ -113,21 +105,16 @@ def get_dashboard():
113
  return fig
114
 
115
  counts = df["type"].value_counts()
116
-
117
- # Use bar chart instead of line for categorical data
118
  counts.plot(kind="bar", ax=ax, color="steelblue")
119
 
120
  ax.set_title("Vehicle Classification Dashboard", fontsize=12)
121
  ax.set_xlabel("Vehicle Type", fontsize=10)
122
  ax.set_ylabel("Count", fontsize=10)
123
 
124
- # Ensure labels are fully visible
125
  ax.set_xticks(range(len(counts.index)))
126
  ax.set_xticklabels(counts.index, rotation=0, ha="center")
127
-
128
  ax.grid(axis="y", linestyle="--", alpha=0.6)
129
 
130
- # Add value labels on top of bars
131
  for i, v in enumerate(counts.values):
132
  ax.text(i, v + 0.05, str(v), ha="center", va="bottom", fontsize=10)
133
 
@@ -135,7 +122,7 @@ def get_dashboard():
135
  return fig
136
 
137
 
138
- # ---------------- Core Inference ----------------
139
 
140
  def make_prediction(img):
141
  processor, model = load_model()
@@ -145,7 +132,7 @@ def make_prediction(img):
145
 
146
  img_size = torch.tensor([tuple(reversed(img.size))])
147
  processed_outputs = processor.post_process_object_detection(
148
- outputs, threshold=0.0, target_sizes=img_size
149
  )
150
  return processed_outputs[0]
151
 
@@ -165,7 +152,14 @@ def fig2img(fig):
165
  return img
166
 
167
 
168
- # ---------------- Visualization ----------------
 
 
 
 
 
 
 
169
 
170
  def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
171
  keep = output_dict["scores"] > threshold
@@ -179,11 +173,10 @@ def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
179
  plt.figure(figsize=(20, 20))
180
  plt.imshow(img)
181
  ax = plt.gca()
182
- colors = COLORS * 100
183
 
184
  result_lines = []
185
 
186
- for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
187
  if "plate" in label.lower():
188
  crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
189
 
@@ -191,6 +184,12 @@ def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
191
  vehicle_type = classify_plate_color(crop)
192
  toll, discount_msg = compute_discount(vehicle_type)
193
 
 
 
 
 
 
 
194
  result_lines.append(
195
  f"License: {plate_text} | Type: {vehicle_type} | Toll: ₹{int(toll)} | {discount_msg}"
196
  )
@@ -198,7 +197,7 @@ def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
198
  ax.add_patch(
199
  plt.Rectangle(
200
  (xmin, ymin), xmax - xmin, ymax - ymin,
201
- fill=False, color=color, linewidth=4
202
  )
203
  )
204
 
@@ -219,10 +218,10 @@ def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
219
 
220
  return final_img, result_text
221
 
222
- # ---------------- Image Detection ----------------
223
 
224
- def detect_objects_image(model_name, url_input, image_input, webcam_input, threshold):
225
- processor, model = load_model(model_name)
 
226
 
227
  if url_input and is_valid_url(url_input):
228
  image = get_original_image(url_input)
@@ -233,73 +232,58 @@ def detect_objects_image(model_name, url_input, image_input, webcam_input, thres
233
  else:
234
  return None, "No image provided."
235
 
236
- processed_outputs = make_prediction(image, processor, model)
237
 
238
  viz_img, result_text = visualize_prediction(
239
- image, processed_outputs, threshold, model.config.id2label
 
 
 
240
  )
241
 
242
  return viz_img, result_text
243
 
244
 
245
- # ---------------- UI ----------------
246
-
247
- title = """<h1 id="title">Smart Vehicle classification</h1>"""
248
 
 
249
  description = """
250
  Detect license plates using YOLOS.
251
  Features:
252
- - Image URL, Image Upload, Webcam,Vehicle type classification by plate color
253
- - EV vehicles get 10% discount on Tolls, Tax, parking
 
254
  """
255
- result_box = gr.Textbox(
256
- label="Detection Result",
257
- lines=5,
258
- interactive=False
259
- )
260
- demo = gr.Blocks()
261
-
262
- with demo:
263
- debug=False,
264
- share=False,
265
- ssr_mode=False
266
- gr.Markdown(title)
267
- gr.Markdown(description)
268
- options = gr.Dropdown(
269
- choices=model,
270
- label="Object Detection Model",
271
- value=model[0]
272
- )
273
-
274
- url_input = gr.Textbox(label="Image URL")
275
- img_input = gr.Image(type="pil", label="Upload Image")
276
- web_input = gr.Image(source="webcam", type="pil", label="Webcam Input")
277
- slider_input = gr.Slider(0, 1, value=0.5, step=0.05, label="Confidence Threshold")
278
 
279
- img_output_from_url = gr.Image(label="Detection Output")
280
 
281
- detect_btn = gr.Button("Detect")
 
282
 
283
- slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.5, step=0.1, label='Prediction Threshold')
284
 
285
  with gr.Tabs():
286
- with gr.TabItem('Image URL'):
 
287
  with gr.Row():
288
- url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
289
  original_image = gr.Image(height=200)
290
  url_input.change(get_original_image, url_input, original_image)
 
291
  img_output_from_url = gr.Image(height=200)
 
292
  dashboard_output_url = gr.Plot()
293
- url_but = gr.Button('Detect')
294
 
295
- with gr.TabItem('Image Upload'):
296
  with gr.Row():
297
- img_input = gr.Image(type='pil', height=200)
298
  img_output_from_upload = gr.Image(height=200)
 
299
  dashboard_output_upload = gr.Plot()
300
- img_but = gr.Button('Detect')
301
 
302
- with gr.TabItem('WebCam'):
303
  with gr.Row():
304
  web_input = gr.Image(
305
  sources=["webcam"],
@@ -308,42 +292,36 @@ with demo:
308
  streaming=True
309
  )
310
  img_output_from_webcam = gr.Image(height=200)
 
311
  dashboard_output_webcam = gr.Plot()
312
- cam_but = gr.Button('Detect')
 
 
313
 
314
  url_but.click(
315
  detect_objects_image,
316
- inputs=[options, url_input, img_input, web_input, slider_input],
317
- outputs=[img_output_from_url],
318
  queue=True
319
  )
320
 
321
  img_but.click(
322
  detect_objects_image,
323
- inputs=[options, url_input, img_input, web_input, slider_input],
324
- outputs=[img_output_from_upload],
325
  queue=True
326
  )
327
 
328
  cam_but.click(
329
  detect_objects_image,
330
- inputs=[options, url_input, img_input, web_input, slider_input],
331
- outputs=[img_output_from_webcam],
332
  queue=True
333
  )
334
 
335
- # vid_but.click(
336
- # detect_objects_video,
337
- # inputs=[video_input, slider_input],
338
- # outputs=[video_output],
339
- # queue=True
340
- # )
341
 
342
  demo.queue()
343
- import asyncio
344
-
345
- try:
346
- asyncio.get_running_loop()
347
- except RuntimeError:
348
- asyncio.set_event_loop(asyncio.new_event_loop())
349
  demo.launch(debug=True, ssr_mode=False)
 
 
1
  import io
2
  import os
3
  import cv2
 
5
  import matplotlib.pyplot as plt
6
  import requests
7
  import torch
 
8
  import numpy as np
9
  import sqlite3
10
  import pandas as pd
 
14
 
15
  os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
16
 
17
+ MODEL_NAME = "nickmuchi/yolos-small-finetuned-license-plate-detection"
18
+ BASE_AMT = 100
19
+
20
+ # -------------------- Discount --------------------
21
+
22
  def compute_discount(vehicle_type):
23
  if vehicle_type == "EV":
24
+ return BASE_AMT * 0.9, "10% discount applied (EV)"
25
+ return BASE_AMT, "No discount"
26
+
27
+ # -------------------- Utilities --------------------
 
 
 
 
 
 
 
 
 
28
 
29
  def is_valid_url(url):
30
  try:
 
41
 
42
 
43
  # -------------------- Database --------------------
44
+
45
  conn = sqlite3.connect("vehicles.db", check_same_thread=False)
46
  cursor = conn.cursor()
47
  cursor.execute("""
 
54
  """)
55
  conn.commit()
56
 
57
+
58
  # -------------------- Lazy Model --------------------
59
+
60
  processor = None
61
  model = None
62
 
 
64
  def load_model():
65
  global processor, model
66
  if processor is None or model is None:
67
+ processor = YolosImageProcessor.from_pretrained(MODEL_NAME)
 
 
68
  model = YolosForObjectDetection.from_pretrained(
69
+ MODEL_NAME,
70
  use_safetensors=True,
71
  torch_dtype=torch.float32
72
  )
 
92
  return "Personal"
93
 
94
 
95
+ # -------------------- Dashboard --------------------
96
 
97
  def get_dashboard():
98
  df = pd.read_sql("SELECT * FROM vehicles", conn)
 
99
  fig, ax = plt.subplots(figsize=(7, 5))
100
 
101
  if len(df) == 0:
 
105
  return fig
106
 
107
  counts = df["type"].value_counts()
 
 
108
  counts.plot(kind="bar", ax=ax, color="steelblue")
109
 
110
  ax.set_title("Vehicle Classification Dashboard", fontsize=12)
111
  ax.set_xlabel("Vehicle Type", fontsize=10)
112
  ax.set_ylabel("Count", fontsize=10)
113
 
 
114
  ax.set_xticks(range(len(counts.index)))
115
  ax.set_xticklabels(counts.index, rotation=0, ha="center")
 
116
  ax.grid(axis="y", linestyle="--", alpha=0.6)
117
 
 
118
  for i, v in enumerate(counts.values):
119
  ax.text(i, v + 0.05, str(v), ha="center", va="bottom", fontsize=10)
120
 
 
122
  return fig
123
 
124
 
125
+ # -------------------- YOLOS Inference --------------------
126
 
127
  def make_prediction(img):
128
  processor, model = load_model()
 
132
 
133
  img_size = torch.tensor([tuple(reversed(img.size))])
134
  processed_outputs = processor.post_process_object_detection(
135
+ outputs, threshold=0.3, target_sizes=img_size
136
  )
137
  return processed_outputs[0]
138
 
 
152
  return img
153
 
154
 
155
+ # -------------------- OCR Stub --------------------
156
+
157
+ def read_plate(crop):
158
+ # Placeholder OCR logic
159
+ return "KA01AB1234"
160
+
161
+
162
+ # -------------------- Visualization --------------------
163
 
164
  def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
165
  keep = output_dict["scores"] > threshold
 
173
  plt.figure(figsize=(20, 20))
174
  plt.imshow(img)
175
  ax = plt.gca()
 
176
 
177
  result_lines = []
178
 
179
+ for score, (xmin, ymin, xmax, ymax), label in zip(scores, boxes, labels):
180
  if "plate" in label.lower():
181
  crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
182
 
 
184
  vehicle_type = classify_plate_color(crop)
185
  toll, discount_msg = compute_discount(vehicle_type)
186
 
187
+ cursor.execute(
188
+ "INSERT INTO vehicles VALUES (?, ?, ?, datetime('now'))",
189
+ (plate_text, vehicle_type, toll)
190
+ )
191
+ conn.commit()
192
+
193
  result_lines.append(
194
  f"License: {plate_text} | Type: {vehicle_type} | Toll: ₹{int(toll)} | {discount_msg}"
195
  )
 
197
  ax.add_patch(
198
  plt.Rectangle(
199
  (xmin, ymin), xmax - xmin, ymax - ymin,
200
+ fill=False, color="red", linewidth=3
201
  )
202
  )
203
 
 
218
 
219
  return final_img, result_text
220
 
 
221
 
222
+ # -------------------- Gradio Callback --------------------
223
+
224
+ def detect_objects_image(url_input, image_input, webcam_input, threshold):
225
 
226
  if url_input and is_valid_url(url_input):
227
  image = get_original_image(url_input)
 
232
  else:
233
  return None, "No image provided."
234
 
235
+ processed_outputs = make_prediction(image)
236
 
237
  viz_img, result_text = visualize_prediction(
238
+ image,
239
+ processed_outputs,
240
+ threshold,
241
+ load_model()[1].config.id2label
242
  )
243
 
244
  return viz_img, result_text
245
 
246
 
247
+ # -------------------- UI --------------------
 
 
248
 
249
+ title = "<h1>🚦 Smart Vehicle Classification</h1>"
250
  description = """
251
  Detect license plates using YOLOS.
252
  Features:
253
+ - Image URL, Image Upload, Webcam
254
+ - Vehicle type classification by plate color
255
+ - EV vehicles get 10% discount on Toll / Parking
256
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
+ with gr.Blocks() as demo:
259
 
260
+ gr.Markdown(title)
261
+ gr.Markdown(description)
262
 
263
+ result_box = gr.Textbox(label="Detection Result", lines=5)
264
 
265
  with gr.Tabs():
266
+
267
+ with gr.TabItem("Image URL"):
268
  with gr.Row():
269
+ url_input = gr.Textbox(lines=2, label="Enter Image URL")
270
  original_image = gr.Image(height=200)
271
  url_input.change(get_original_image, url_input, original_image)
272
+
273
  img_output_from_url = gr.Image(height=200)
274
+
275
  dashboard_output_url = gr.Plot()
276
+ url_but = gr.Button("Detect")
277
 
278
+ with gr.TabItem("Image Upload"):
279
  with gr.Row():
280
+ img_input = gr.Image(type="pil", height=200)
281
  img_output_from_upload = gr.Image(height=200)
282
+
283
  dashboard_output_upload = gr.Plot()
284
+ img_but = gr.Button("Detect")
285
 
286
+ with gr.TabItem("Webcam"):
287
  with gr.Row():
288
  web_input = gr.Image(
289
  sources=["webcam"],
 
292
  streaming=True
293
  )
294
  img_output_from_webcam = gr.Image(height=200)
295
+
296
  dashboard_output_webcam = gr.Plot()
297
+ cam_but = gr.Button("Detect")
298
+
299
+ slider_input = gr.Slider(0.2, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
300
 
301
  url_but.click(
302
  detect_objects_image,
303
+ inputs=[url_input, img_input, web_input, slider_input],
304
+ outputs=[img_output_from_url, result_box],
305
  queue=True
306
  )
307
 
308
  img_but.click(
309
  detect_objects_image,
310
+ inputs=[url_input, img_input, web_input, slider_input],
311
+ outputs=[img_output_from_upload, result_box],
312
  queue=True
313
  )
314
 
315
  cam_but.click(
316
  detect_objects_image,
317
+ inputs=[url_input, img_input, web_input, slider_input],
318
+ outputs=[img_output_from_webcam, result_box],
319
  queue=True
320
  )
321
 
322
+ url_but.click(get_dashboard, outputs=dashboard_output_url)
323
+ img_but.click(get_dashboard, outputs=dashboard_output_upload)
324
+ cam_but.click(get_dashboard, outputs=dashboard_output_webcam)
 
 
 
325
 
326
  demo.queue()
 
 
 
 
 
 
327
  demo.launch(debug=True, ssr_mode=False)