Sarvamangalak commited on
Commit
9917a7b
·
verified ·
1 Parent(s): 0fd01f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -67
app.py CHANGED
@@ -1,13 +1,14 @@
1
- # app_with_video.py
2
  import io
3
  import os
4
  import cv2
5
- import numpy as np
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
8
- import requests, validators
9
  import torch
10
  import pathlib
 
 
11
  from PIL import Image
12
  from transformers import AutoImageProcessor, YolosForObjectDetection, DetrForObjectDetection
13
 
@@ -22,12 +23,45 @@ COLORS = [
22
  [0.301, 0.745, 0.933]
23
  ]
24
 
25
- # ---------- Core Inference ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def make_prediction(img, processor, model):
28
  inputs = processor(images=img, return_tensors="pt")
29
  with torch.no_grad():
30
  outputs = model(**inputs)
 
31
  img_size = torch.tensor([tuple(reversed(img.size))])
32
  processed_outputs = processor.post_process_object_detection(
33
  outputs, threshold=0.0, target_sizes=img_size
@@ -40,36 +74,42 @@ def fig2img(fig):
40
  fig.savefig(buf)
41
  buf.seek(0)
42
  pil_img = Image.open(buf)
 
43
  basewidth = 750
44
  wpercent = (basewidth / float(pil_img.size[0]))
45
  hsize = int((float(pil_img.size[1]) * float(wpercent)))
46
  img = pil_img.resize((basewidth, hsize), Image.Resampling.LANCZOS)
 
47
  plt.close(fig)
48
  return img
49
 
50
 
 
 
51
  def classify_plate_color(crop_img):
52
- # Convert PIL to OpenCV BGR
53
  img = cv2.cvtColor(np.array(crop_img), cv2.COLOR_RGB2BGR)
54
  hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
55
  h, s, v = cv2.split(hsv)
 
56
  avg_h, avg_s, avg_v = np.mean(h), np.mean(s), np.mean(v)
57
 
58
- # Heuristic thresholds (India-style plates)
59
  if avg_v < 80:
60
  return "Black Plate (Commercial)"
61
  if avg_s < 40 and avg_v > 180:
62
  return "White Plate (Private)"
63
  if 15 < avg_h < 35 and avg_s > 80:
64
  return "Yellow Plate (Commercial)"
65
- if avg_h > 80 and avg_h < 130:
66
  return "Blue Plate (Diplomatic)"
67
- if avg_h > 35 and avg_h < 85:
68
  return "Green Plate (Electric)"
69
 
70
  return "Unknown Plate"
71
 
72
 
 
 
73
  def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
74
  keep = output_dict["scores"] > threshold
75
  boxes = output_dict["boxes"][keep].tolist()
@@ -104,66 +144,14 @@ def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
104
 
105
  plt.axis("off")
106
  return fig2img(plt.gcf())
107
- keep = output_dict["scores"] > threshold
108
- boxes = output_dict["boxes"][keep].tolist()
109
- scores = output_dict["scores"][keep].tolist()
110
- labels = output_dict["labels"][keep].tolist()
111
-
112
- if id2label is not None:
113
- labels = [id2label[x] for x in labels]
114
-
115
- plt.figure(figsize=(20, 20))
116
- plt.imshow(img)
117
- ax = plt.gca()
118
- colors = COLORS * 100
119
-
120
- for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
121
- if label == 'license-plates':
122
- ax.add_patch(
123
- plt.Rectangle(
124
- (xmin, ymin), xmax - xmin, ymax - ymin,
125
- fill=False, color=color, linewidth=4
126
- )
127
- )
128
- ax.text(
129
- xmin, ymin,
130
- f"{label}: {score:0.2f}",
131
- fontsize=12,
132
- bbox=dict(facecolor="yellow", alpha=0.8)
133
- )
134
-
135
- plt.axis("off")
136
- return fig2img(plt.gcf())
137
-
138
 
139
- # ---------- Utilities ----------
140
-
141
- def get_original_image(url_input):
142
- if validators.url(url_input):
143
- image = Image.open(requests.get(url_input, stream=True).raw).convert("RGB")
144
- return image
145
-
146
-
147
- def load_model(model_name):
148
- processor = AutoImageProcessor.from_pretrained(model_name)
149
-
150
- if "yolos" in model_name:
151
- model = YolosForObjectDetection.from_pretrained(model_name)
152
- elif "detr" in model_name:
153
- model = DetrForObjectDetection.from_pretrained(model_name)
154
- else:
155
- raise ValueError("Unsupported model")
156
-
157
- model.eval()
158
- return processor, model
159
 
160
-
161
- # ---------- Image Detection ----------
162
 
163
  def detect_objects_image(model_name, url_input, image_input, webcam_input, threshold):
164
  processor, model = load_model(model_name)
165
 
166
- if validators.url(url_input):
167
  image = get_original_image(url_input)
168
  elif image_input is not None:
169
  image = image_input
@@ -178,7 +166,7 @@ def detect_objects_image(model_name, url_input, image_input, webcam_input, thres
178
  return viz_img
179
 
180
 
181
- # ---------- Video Detection ----------
182
 
183
  def detect_objects_video(model_name, video_input, threshold):
184
  if video_input is None:
@@ -215,6 +203,9 @@ def detect_objects_video(model_name, video_input, threshold):
215
 
216
  for score, (xmin, ymin, xmax, ymax), label in zip(scores, boxes, labels):
217
  if label == 'license-plates':
 
 
 
218
  cv2.rectangle(
219
  frame,
220
  (int(xmin), int(ymin)),
@@ -224,7 +215,7 @@ def detect_objects_video(model_name, video_input, threshold):
224
  )
225
  cv2.putText(
226
  frame,
227
- f"{label}: {score:.2f}",
228
  (int(xmin), int(ymin) - 10),
229
  cv2.FONT_HERSHEY_SIMPLEX,
230
  0.6,
@@ -240,7 +231,7 @@ def detect_objects_video(model_name, video_input, threshold):
240
  return output_path
241
 
242
 
243
- # ---------- UI ----------
244
 
245
  title = """<h1 id="title">License Plate Detection (Image + Video)</h1>"""
246
 
@@ -251,6 +242,7 @@ Supports:
251
  - Image Upload
252
  - Webcam
253
  - Video Upload
 
254
  """
255
 
256
  models = [
@@ -264,7 +256,7 @@ h1#title {
264
  }
265
  '''
266
 
267
- demo = gr.Blocks(css=css)
268
 
269
  with demo:
270
  gr.Markdown(title)
@@ -277,7 +269,7 @@ with demo:
277
  with gr.TabItem('Image URL'):
278
  with gr.Row():
279
  url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
280
- original_image = gr.Image(shape=(750, 750))
281
  url_input.change(get_original_image, url_input, original_image)
282
  img_output_from_url = gr.Image(shape=(750, 750))
283
  url_but = gr.Button('Detect')
 
1
+ # app.py (FINAL CLEAN VERSION)
2
  import io
3
  import os
4
  import cv2
 
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
+ import requests
8
  import torch
9
  import pathlib
10
+ import numpy as np
11
+ from urllib.parse import urlparse
12
  from PIL import Image
13
  from transformers import AutoImageProcessor, YolosForObjectDetection, DetrForObjectDetection
14
 
 
23
  [0.301, 0.745, 0.933]
24
  ]
25
 
26
+ # ---------------- Utilities ----------------
27
+
28
+ def is_valid_url(url):
29
+ try:
30
+ result = urlparse(url)
31
+ return all([result.scheme, result.netloc])
32
+ except Exception:
33
+ return False
34
+
35
+
36
+ def get_original_image(url_input):
37
+ if url_input and is_valid_url(url_input):
38
+ image = Image.open(requests.get(url_input, stream=True).raw).convert("RGB")
39
+ return image
40
+
41
+
42
+ # ---------------- Model Loading ----------------
43
+
44
+ def load_model(model_name):
45
+ processor = AutoImageProcessor.from_pretrained(model_name)
46
+
47
+ if "yolos" in model_name:
48
+ model = YolosForObjectDetection.from_pretrained(model_name)
49
+ elif "detr" in model_name:
50
+ model = DetrForObjectDetection.from_pretrained(model_name)
51
+ else:
52
+ raise ValueError("Unsupported model")
53
+
54
+ model.eval()
55
+ return processor, model
56
+
57
+
58
+ # ---------------- Core Inference ----------------
59
 
60
  def make_prediction(img, processor, model):
61
  inputs = processor(images=img, return_tensors="pt")
62
  with torch.no_grad():
63
  outputs = model(**inputs)
64
+
65
  img_size = torch.tensor([tuple(reversed(img.size))])
66
  processed_outputs = processor.post_process_object_detection(
67
  outputs, threshold=0.0, target_sizes=img_size
 
74
  fig.savefig(buf)
75
  buf.seek(0)
76
  pil_img = Image.open(buf)
77
+
78
  basewidth = 750
79
  wpercent = (basewidth / float(pil_img.size[0]))
80
  hsize = int((float(pil_img.size[1]) * float(wpercent)))
81
  img = pil_img.resize((basewidth, hsize), Image.Resampling.LANCZOS)
82
+
83
  plt.close(fig)
84
  return img
85
 
86
 
87
+ # ---------------- Plate Color Classification ----------------
88
+
89
  def classify_plate_color(crop_img):
 
90
  img = cv2.cvtColor(np.array(crop_img), cv2.COLOR_RGB2BGR)
91
  hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
92
  h, s, v = cv2.split(hsv)
93
+
94
  avg_h, avg_s, avg_v = np.mean(h), np.mean(s), np.mean(v)
95
 
96
+ # Heuristic thresholds (India-style)
97
  if avg_v < 80:
98
  return "Black Plate (Commercial)"
99
  if avg_s < 40 and avg_v > 180:
100
  return "White Plate (Private)"
101
  if 15 < avg_h < 35 and avg_s > 80:
102
  return "Yellow Plate (Commercial)"
103
+ if 80 < avg_h < 130:
104
  return "Blue Plate (Diplomatic)"
105
+ if 35 < avg_h < 85:
106
  return "Green Plate (Electric)"
107
 
108
  return "Unknown Plate"
109
 
110
 
111
+ # ---------------- Visualization ----------------
112
+
113
  def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
114
  keep = output_dict["scores"] > threshold
115
  boxes = output_dict["boxes"][keep].tolist()
 
144
 
145
  plt.axis("off")
146
  return fig2img(plt.gcf())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ # ---------------- Image Detection ----------------
 
150
 
151
  def detect_objects_image(model_name, url_input, image_input, webcam_input, threshold):
152
  processor, model = load_model(model_name)
153
 
154
+ if url_input and is_valid_url(url_input):
155
  image = get_original_image(url_input)
156
  elif image_input is not None:
157
  image = image_input
 
166
  return viz_img
167
 
168
 
169
+ # ---------------- Video Detection ----------------
170
 
171
  def detect_objects_video(model_name, video_input, threshold):
172
  if video_input is None:
 
203
 
204
  for score, (xmin, ymin, xmax, ymax), label in zip(scores, boxes, labels):
205
  if label == 'license-plates':
206
+ crop = pil_img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
207
+ plate_type = classify_plate_color(crop)
208
+
209
  cv2.rectangle(
210
  frame,
211
  (int(xmin), int(ymin)),
 
215
  )
216
  cv2.putText(
217
  frame,
218
+ f"{plate_type} | {score:.2f}",
219
  (int(xmin), int(ymin) - 10),
220
  cv2.FONT_HERSHEY_SIMPLEX,
221
  0.6,
 
231
  return output_path
232
 
233
 
234
+ # ---------------- UI ----------------
235
 
236
  title = """<h1 id="title">License Plate Detection (Image + Video)</h1>"""
237
 
 
242
  - Image Upload
243
  - Webcam
244
  - Video Upload
245
+ - Vehicle type classification by plate color
246
  """
247
 
248
  models = [
 
256
  }
257
  '''
258
 
259
+ demo = gr.Blocks()
260
 
261
  with demo:
262
  gr.Markdown(title)
 
269
  with gr.TabItem('Image URL'):
270
  with gr.Row():
271
  url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
272
+ original_image = gr.Image(height=750, width=750)
273
  url_input.change(get_original_image, url_input, original_image)
274
  img_output_from_url = gr.Image(shape=(750, 750))
275
  url_but = gr.Button('Detect')