Sarvamangalak committed on
Commit
a0f7c5a
·
verified ·
1 Parent(s): 0bfeae4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -128
app.py CHANGED
@@ -1,4 +1,5 @@
1
- # app.py
 
2
  import io
3
  import os
4
  import cv2
@@ -6,28 +7,26 @@ import gradio as gr
6
  import matplotlib.pyplot as plt
7
  import requests
8
  import torch
9
- import pathlib
10
  import numpy as np
11
  from urllib.parse import urlparse
12
- from transformers import AutoImageProcessor, YolosForObjectDetection, DetrForObjectDetection
13
- import sqlite3
14
- import pandas as pd
15
- import matplotlib.pyplot as plt
16
- from PIL import Image, ImageDraw
17
  from transformers import YolosImageProcessor, YolosForObjectDetection
18
  import easyocr
19
- from datetime import datetime
20
-
21
 
22
  os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
23
 
 
 
 
 
 
24
  COLORS = [
25
  [0.000, 0.447, 0.741],
26
  [0.850, 0.325, 0.098],
27
  [0.929, 0.694, 0.125],
28
  [0.494, 0.184, 0.556],
29
  [0.466, 0.674, 0.188],
30
- [0.301, 0.745, 0.933]
31
  ]
32
 
33
  # ---------------- Utilities ----------------
@@ -44,34 +43,28 @@ def get_original_image(url_input):
44
  if url_input and is_valid_url(url_input):
45
  image = Image.open(requests.get(url_input, stream=True).raw).convert("RGB")
46
  return image
 
47
 
48
 
 
49
 
50
- # -------------------- Database --------------------
51
- conn = sqlite3.connect("vehicles.db", check_same_thread=False)
52
- cursor = conn.cursor()
53
- cursor.execute("""
54
- CREATE TABLE IF NOT EXISTS vehicles (
55
- plate TEXT,
56
- type TEXT,
57
- time TEXT
58
- )
59
- """)
60
- conn.commit()
61
-
62
- # -------------------- Models --------------------
63
- processor = YolosImageProcessor.from_pretrained(
64
- "nickmuchi/yolos-small-finetuned-license-plate-detection"
65
- )
66
- model = YolosForObjectDetection.from_pretrained(
67
- "nickmuchi/yolos-small-finetuned-license-plate-detection"
68
- )
69
- model.eval()
70
 
71
- reader = easyocr.Reader(['en'], gpu=False)
72
 
 
73
 
74
- # -------------------- Plate Color Classifier --------------------
75
  def classify_plate_color(plate_img):
76
  img = np.array(plate_img)
77
  hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
@@ -87,39 +80,20 @@ def classify_plate_color(plate_img):
87
  else:
88
  return "Personal"
89
 
90
- # -------------------- OCR --------------------
 
 
91
  def read_plate(plate_img):
92
  results = reader.readtext(np.array(plate_img))
93
  if results:
94
  return results[0][1]
95
  return "UNKNOWN"
96
 
97
- # -------------------- Dashboard --------------------
98
- def get_dashboard():
99
- df = pd.read_sql("SELECT * FROM vehicles", conn)
100
-
101
- fig, ax = plt.subplots(figsize=(8, 5))
102
-
103
- if len(df) == 0:
104
- ax.text(0.5, 0.5, "No vehicles scanned yet",
105
- ha="center", va="center", fontsize=10)
106
- ax.axis("off")
107
- return fig
108
-
109
- counts = df["type"].value_counts()
110
- counts.plot(kind="bar", ax=ax)
111
-
112
- ax.set_title("Vehicle Classification Dashboard")
113
- ax.set_xlabel("Vehicle Type")
114
- ax.set_ylabel("Count")
115
- ax.grid(axis="y")
116
-
117
- return fig
118
-
119
 
120
  # ---------------- Core Inference ----------------
121
 
122
- def make_prediction(img, processor, model):
 
123
  inputs = processor(images=img, return_tensors="pt")
124
  with torch.no_grad():
125
  outputs = model(**inputs)
@@ -131,6 +105,8 @@ def make_prediction(img, processor, model):
131
  return processed_outputs[0]
132
 
133
 
 
 
134
  def fig2img(fig):
135
  buf = io.BytesIO()
136
  fig.savefig(buf)
@@ -138,33 +114,17 @@ def fig2img(fig):
138
  pil_img = Image.open(buf)
139
 
140
  basewidth = 750
141
- wpercent = (basewidth / float(pil_img.size[0]))
142
- hsize = int((float(pil_img.size[1]) * float(wpercent)))
143
  img = pil_img.resize((basewidth, hsize), Image.Resampling.LANCZOS)
144
 
145
  plt.close(fig)
146
  return img
147
 
148
 
149
- def classify_plate_color(plate_img):
150
- img = np.array(plate_img)
151
- hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
152
-
153
- green = np.sum(cv2.inRange(hsv, (35, 40, 40), (85, 255, 255)))
154
- yellow = np.sum(cv2.inRange(hsv, (15, 50, 50), (35, 255, 255)))
155
- white = np.sum(cv2.inRange(hsv, (0, 0, 200), (180, 30, 255)))
156
-
157
- if green > yellow and green > white:
158
- return "EV"
159
- elif yellow > green and yellow > white:
160
- return "Commercial"
161
- else:
162
- return "Personal"
163
-
164
-
165
- # ---------------- Visualization ----------------
166
-
167
  def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
 
 
168
  keep = output_dict["scores"] > threshold
169
  boxes = output_dict["boxes"][keep].tolist()
170
  scores = output_dict["scores"][keep].tolist()
@@ -178,10 +138,18 @@ def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
178
  ax = plt.gca()
179
  colors = COLORS * 100
180
 
181
- for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
182
- if label == 'license-plates':
 
 
183
  crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
184
  plate_type = classify_plate_color(crop)
 
 
 
 
 
 
185
 
186
  ax.add_patch(
187
  plt.Rectangle(
@@ -191,9 +159,9 @@ def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
191
  )
192
  ax.text(
193
  xmin, ymin - 10,
194
- f"{plate_type} | {score:0.2f}",
195
  fontsize=12,
196
- bbox=dict(facecolor="yellow", alpha=0.8)
197
  )
198
 
199
  plt.axis("off")
@@ -202,9 +170,7 @@ def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
202
 
203
  # ---------------- Image Detection ----------------
204
 
205
- def detect_objects_image(model_name, url_input, image_input, webcam_input, threshold):
206
- processor, model = load_model(model_name)
207
-
208
  if url_input and is_valid_url(url_input):
209
  image = get_original_image(url_input)
210
  elif image_input is not None:
@@ -214,24 +180,26 @@ def detect_objects_image(model_name, url_input, image_input, webcam_input, thres
214
  else:
215
  return None
216
 
217
- processed_outputs = make_prediction(image, processor, model)
218
- viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
 
 
219
 
220
  return viz_img
221
 
222
 
223
  # ---------------- Video Detection ----------------
224
 
225
- def detect_objects_video(model_name, video_input, threshold):
226
  if video_input is None:
227
  return None
228
 
229
- processor, model = load_model(model_name)
230
 
231
  cap = cv2.VideoCapture(video_input)
232
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
233
 
234
- output_path = "/mnt/data/output_detected.mp4"
235
  fps = cap.get(cv2.CAP_PROP_FPS)
236
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
237
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -246,7 +214,7 @@ def detect_objects_video(model_name, video_input, threshold):
246
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
247
  pil_img = Image.fromarray(rgb_frame)
248
 
249
- processed_outputs = make_prediction(pil_img, processor, model)
250
 
251
  keep = processed_outputs["scores"] > threshold
252
  boxes = processed_outputs["boxes"][keep].tolist()
@@ -265,7 +233,7 @@ def detect_objects_video(model_name, video_input, threshold):
265
  (int(xmin), int(ymin)),
266
  (int(xmax), int(ymax)),
267
  (0, 255, 0),
268
- 2
269
  )
270
  cv2.putText(
271
  frame,
@@ -274,7 +242,7 @@ def detect_objects_video(model_name, video_input, threshold):
274
  cv2.FONT_HERSHEY_SIMPLEX,
275
  0.6,
276
  (0, 255, 0),
277
- 2
278
  )
279
 
280
  out.write(frame)
@@ -287,21 +255,17 @@ def detect_objects_video(model_name, video_input, threshold):
287
 
288
  # ---------------- UI ----------------
289
 
290
- title = """<h1 id="title">Smart Vehicle Clssification (Image + Video)</h1>"""
291
 
292
  description = """
293
- Detect license plates using YOLOS or DETR and Vehicle classification.
294
- Supports:Image URL, Image Upload, Webcam, Video Upload
 
295
  """
296
 
297
- #models = [
298
- # "nickmuchi/yolos-small-finetuned-license-plate-detection"
299
- #]
300
- css = '''
301
- h1#title {
302
- text-align: center;
303
- }
304
- '''
305
 
306
  demo = gr.Blocks()
307
 
@@ -309,70 +273,67 @@ with demo:
309
  gr.Markdown(title)
310
  gr.Markdown(description)
311
 
312
- options = gr.Dropdown(choices=models, label='Object Detection Model', value=models[0])
313
- slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.5, step=0.1, label='Prediction Threshold')
 
314
 
315
  with gr.Tabs():
316
- with gr.TabItem('Image URL'):
317
  with gr.Row():
318
- url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
319
  original_image = gr.Image(height=750, width=750)
320
  url_input.change(get_original_image, url_input, original_image)
321
  img_output_from_url = gr.Image(height=750, width=750)
322
- url_but = gr.Button('Detect')
323
 
324
- with gr.TabItem('Image Upload'):
325
  with gr.Row():
326
- img_input = gr.Image(type='pil', height=750, width=750)
327
  img_output_from_upload = gr.Image(height=750, width=750)
328
- img_but = gr.Button('Detect')
329
 
330
- with gr.TabItem('WebCam'):
331
  with gr.Row():
332
  web_input = gr.Image(
333
- sources=["webcam"],
334
- type="pil",
335
- height=750,
336
- width=750,
337
- streaming=True
338
  )
339
  img_output_from_webcam = gr.Image(height=750, width=750)
340
- cam_but = gr.Button('Detect')
341
 
342
- with gr.TabItem('Video Upload'):
343
  with gr.Row():
344
  video_input = gr.Video(label="Upload Video")
345
  video_output = gr.Video(label="Detected Video")
346
- vid_but = gr.Button('Detect Video')
347
 
348
  url_but.click(
349
  detect_objects_image,
350
- inputs=[options, url_input, img_input, web_input, slider_input],
351
  outputs=[img_output_from_url],
352
- queue=True
353
  )
354
 
355
  img_but.click(
356
  detect_objects_image,
357
- inputs=[options, url_input, img_input, web_input, slider_input],
358
  outputs=[img_output_from_upload],
359
- queue=True
360
  )
361
 
362
  cam_but.click(
363
  detect_objects_image,
364
- inputs=[options, url_input, img_input, web_input, slider_input],
365
  outputs=[img_output_from_webcam],
366
- queue=True
367
  )
368
 
369
  vid_but.click(
370
  detect_objects_video,
371
- inputs=[options, video_input, slider_input],
372
  outputs=[video_output],
373
- queue=True
374
  )
375
 
376
 
377
  demo.queue()
378
- demo.launch(debug=True)
 
1
+ # app.py (Clean Final Version for HF Spaces)
2
+
3
  import io
4
  import os
5
  import cv2
 
7
  import matplotlib.pyplot as plt
8
  import requests
9
  import torch
 
10
  import numpy as np
11
  from urllib.parse import urlparse
12
+ from PIL import Image
 
 
 
 
13
  from transformers import YolosImageProcessor, YolosForObjectDetection
14
  import easyocr
 
 
15
 
16
  os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
17
 
18
+ # ---------------- Globals (lazy loaded) ----------------
19
+ processor = None
20
+ model = None
21
+ reader = easyocr.Reader(["en"], gpu=False)
22
+
23
  COLORS = [
24
  [0.000, 0.447, 0.741],
25
  [0.850, 0.325, 0.098],
26
  [0.929, 0.694, 0.125],
27
  [0.494, 0.184, 0.556],
28
  [0.466, 0.674, 0.188],
29
+ [0.301, 0.745, 0.933],
30
  ]
31
 
32
  # ---------------- Utilities ----------------
 
43
  if url_input and is_valid_url(url_input):
44
  image = Image.open(requests.get(url_input, stream=True).raw).convert("RGB")
45
  return image
46
+ return None
47
 
48
 
49
+ # ---------------- Model Loader ----------------
50
 
51
+ def load_model():
52
+ global processor, model
53
+ if processor is None or model is None:
54
+ processor = YolosImageProcessor.from_pretrained(
55
+ "nickmuchi/yolos-small-finetuned-license-plate-detection"
56
+ )
57
+ model = YolosForObjectDetection.from_pretrained(
58
+ "nickmuchi/yolos-small-finetuned-license-plate-detection",
59
+ use_safetensors=True,
60
+ torch_dtype=torch.float32,
61
+ )
62
+ model.eval()
63
+ return processor, model
 
 
 
 
 
 
 
64
 
 
65
 
66
+ # ---------------- Plate Color Classifier ----------------
67
 
 
68
  def classify_plate_color(plate_img):
69
  img = np.array(plate_img)
70
  hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
 
80
  else:
81
  return "Personal"
82
 
83
+
84
+ # ---------------- OCR ----------------
85
+
86
  def read_plate(plate_img):
87
  results = reader.readtext(np.array(plate_img))
88
  if results:
89
  return results[0][1]
90
  return "UNKNOWN"
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # ---------------- Core Inference ----------------
94
 
95
+ def make_prediction(img):
96
+ processor, model = load_model()
97
  inputs = processor(images=img, return_tensors="pt")
98
  with torch.no_grad():
99
  outputs = model(**inputs)
 
105
  return processed_outputs[0]
106
 
107
 
108
+ # ---------------- Visualization ----------------
109
+
110
  def fig2img(fig):
111
  buf = io.BytesIO()
112
  fig.savefig(buf)
 
114
  pil_img = Image.open(buf)
115
 
116
  basewidth = 750
117
+ wpercent = basewidth / float(pil_img.size[0])
118
+ hsize = int(float(pil_img.size[1]) * float(wpercent))
119
  img = pil_img.resize((basewidth, hsize), Image.Resampling.LANCZOS)
120
 
121
  plt.close(fig)
122
  return img
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
126
+ BASE_TOLL = 100 # base amount for all vehicles
127
+
128
  keep = output_dict["scores"] > threshold
129
  boxes = output_dict["boxes"][keep].tolist()
130
  scores = output_dict["scores"][keep].tolist()
 
138
  ax = plt.gca()
139
  colors = COLORS * 100
140
 
141
+ for score, (xmin, ymin, xmax, ymax), label, color in zip(
142
+ scores, boxes, labels, colors
143
+ ):
144
+ if "plate" in label.lower():
145
  crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
146
  plate_type = classify_plate_color(crop)
147
+ # Apply 10% discount for EV vehicles
148
+ if plate_type == "EV":
149
+ discounted_amount = BASE_TOLL * 0.9
150
+ price_text = f"EV | ₹{discounted_amount:.0f} (10% off)"
151
+ else:
152
+ price_text = f"{plate_type} | ₹{BASE_TOLL}"
153
 
154
  ax.add_patch(
155
  plt.Rectangle(
 
159
  )
160
  ax.text(
161
  xmin, ymin - 10,
162
+ f"{price_text} | {score:0.2f}",
163
  fontsize=12,
164
+ bbox=dict(facecolor="yellow", alpha=0.8),
165
  )
166
 
167
  plt.axis("off")
 
170
 
171
  # ---------------- Image Detection ----------------
172
 
173
+ def detect_objects_image(url_input, image_input, webcam_input, threshold):
 
 
174
  if url_input and is_valid_url(url_input):
175
  image = get_original_image(url_input)
176
  elif image_input is not None:
 
180
  else:
181
  return None
182
 
183
+ processed_outputs = make_prediction(image)
184
+ viz_img = visualize_prediction(
185
+ image, processed_outputs, threshold, load_model()[1].config.id2label
186
+ )
187
 
188
  return viz_img
189
 
190
 
191
  # ---------------- Video Detection ----------------
192
 
193
+ def detect_objects_video(video_input, threshold):
194
  if video_input is None:
195
  return None
196
 
197
+ processor, model = load_model()
198
 
199
  cap = cv2.VideoCapture(video_input)
200
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
201
 
202
+ output_path = "/tmp/output_detected.mp4"
203
  fps = cap.get(cv2.CAP_PROP_FPS)
204
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
205
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
214
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
215
  pil_img = Image.fromarray(rgb_frame)
216
 
217
+ processed_outputs = make_prediction(pil_img)
218
 
219
  keep = processed_outputs["scores"] > threshold
220
  boxes = processed_outputs["boxes"][keep].tolist()
 
233
  (int(xmin), int(ymin)),
234
  (int(xmax), int(ymax)),
235
  (0, 255, 0),
236
+ 2,
237
  )
238
  cv2.putText(
239
  frame,
 
242
  cv2.FONT_HERSHEY_SIMPLEX,
243
  0.6,
244
  (0, 255, 0),
245
+ 2,
246
  )
247
 
248
  out.write(frame)
 
255
 
256
  # ---------------- UI ----------------
257
 
258
+ title = """<h1 id="title">Smart Vehicle Classification (Image + Video)</h1>"""
259
 
260
  description = """
261
+ Smart Vehicle Classification system to Promote EV by applying discount on Toll,
262
+ Tax, parking.
263
+ Supports:Image URL, Image Upload, Webcam, Video Upload,Vehicle type classification by plate color
264
  """
265
 
266
+ css = """
267
+ h1#title { text-align: center; }
268
+ """
 
 
 
 
 
269
 
270
  demo = gr.Blocks()
271
 
 
273
  gr.Markdown(title)
274
  gr.Markdown(description)
275
 
276
+ slider_input = gr.Slider(
277
+ minimum=0.2, maximum=1, value=0.5, step=0.1, label="Prediction Threshold"
278
+ )
279
 
280
  with gr.Tabs():
281
+ with gr.TabItem("Image URL"):
282
  with gr.Row():
283
+ url_input = gr.Textbox(lines=2, label="Enter valid image URL here..")
284
  original_image = gr.Image(height=750, width=750)
285
  url_input.change(get_original_image, url_input, original_image)
286
  img_output_from_url = gr.Image(height=750, width=750)
287
+ url_but = gr.Button("Detect")
288
 
289
+ with gr.TabItem("Image Upload"):
290
  with gr.Row():
291
+ img_input = gr.Image(type="pil", height=750, width=750)
292
  img_output_from_upload = gr.Image(height=750, width=750)
293
+ img_but = gr.Button("Detect")
294
 
295
+ with gr.TabItem("WebCam"):
296
  with gr.Row():
297
  web_input = gr.Image(
298
+ sources=["webcam"], type="pil", height=750, width=750, streaming=True
 
 
 
 
299
  )
300
  img_output_from_webcam = gr.Image(height=750, width=750)
301
+ cam_but = gr.Button("Detect")
302
 
303
+ with gr.TabItem("Video Upload"):
304
  with gr.Row():
305
  video_input = gr.Video(label="Upload Video")
306
  video_output = gr.Video(label="Detected Video")
307
+ vid_but = gr.Button("Detect Video")
308
 
309
  url_but.click(
310
  detect_objects_image,
311
+ inputs=[url_input, img_input, web_input, slider_input],
312
  outputs=[img_output_from_url],
313
+ queue=True,
314
  )
315
 
316
  img_but.click(
317
  detect_objects_image,
318
+ inputs=[url_input, img_input, web_input, slider_input],
319
  outputs=[img_output_from_upload],
320
+ queue=True,
321
  )
322
 
323
  cam_but.click(
324
  detect_objects_image,
325
+ inputs=[url_input, img_input, web_input, slider_input],
326
  outputs=[img_output_from_webcam],
327
+ queue=True,
328
  )
329
 
330
  vid_but.click(
331
  detect_objects_video,
332
+ inputs=[video_input, slider_input],
333
  outputs=[video_output],
334
+ queue=True,
335
  )
336
 
337
 
338
  demo.queue()
339
+ demo.launch(debug=True, ssr_mode=False)