Sarvamangalak commited on
Commit
3720108
·
verified ·
1 Parent(s): fa051fe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +293 -0
app.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import io
3
+ import os
4
+ import cv2
5
+ import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+ import requests
8
+ import torch
9
+ import pathlib
10
+ import numpy as np
11
+ import sqlite3
12
+ import pandas as pd
13
+ from urllib.parse import urlparse
14
+ from PIL import Image
15
+ from transformers import YolosImageProcessor, YolosForObjectDetection
16
+
17
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
18
+
19
+ BASE_TOLL = 100
20
+
21
+ COLORS = [
22
+ [0.000, 0.447, 0.741],
23
+ [0.850, 0.325, 0.098],
24
+ [0.929, 0.694, 0.125],
25
+ [0.494, 0.184, 0.556],
26
+ [0.466, 0.674, 0.188],
27
+ [0.301, 0.745, 0.933]
28
+ ]
29
+
30
+ # ---------------- Utilities ----------------
31
+
32
+ def is_valid_url(url):
33
+ try:
34
+ result = urlparse(url)
35
+ return all([result.scheme, result.netloc])
36
+ except Exception:
37
+ return False
38
+
39
+
40
+ def get_original_image(url_input):
41
+ if url_input and is_valid_url(url_input):
42
+ image = Image.open(requests.get(url_input, stream=True).raw).convert("RGB")
43
+ return image
44
+
45
+
46
+ # -------------------- Database --------------------
47
+ conn = sqlite3.connect("vehicles.db", check_same_thread=False)
48
+ cursor = conn.cursor()
49
+ cursor.execute("""
50
+ CREATE TABLE IF NOT EXISTS vehicles (
51
+ plate TEXT,
52
+ type TEXT,
53
+ amount REAL,
54
+ time TEXT
55
+ )
56
+ """)
57
+ conn.commit()
58
+
59
+ # -------------------- Lazy Model --------------------
60
+ processor = None
61
+ model = None
62
+
63
+
64
+ def load_model():
65
+ global processor, model
66
+ if processor is None or model is None:
67
+ processor = YolosImageProcessor.from_pretrained(
68
+ "nickmuchi/yolos-small-finetuned-license-plate-detection"
69
+ )
70
+ model = YolosForObjectDetection.from_pretrained(
71
+ "nickmuchi/yolos-small-finetuned-license-plate-detection",
72
+ use_safetensors=True,
73
+ torch_dtype=torch.float32
74
+ )
75
+ model.eval()
76
+ return processor, model
77
+
78
+
79
+ # -------------------- Plate Color Classifier --------------------
80
+
81
+ def classify_plate_color(plate_img):
82
+ img = np.array(plate_img)
83
+ hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
84
+
85
+ green = np.sum(cv2.inRange(hsv, (35, 40, 40), (85, 255, 255)))
86
+ yellow = np.sum(cv2.inRange(hsv, (15, 50, 50), (35, 255, 255)))
87
+ white = np.sum(cv2.inRange(hsv, (0, 0, 200), (180, 30, 255)))
88
+
89
+ if green > yellow and green > white:
90
+ return "EV"
91
+ elif yellow > green and yellow > white:
92
+ return "Commercial"
93
+ else:
94
+ return "Personal"
95
+
96
+
97
+ # ---------------- Dashboard ----------------
98
+
99
+ def get_dashboard():
100
+ df = pd.read_sql("SELECT * FROM vehicles", conn)
101
+
102
+ fig, ax = plt.subplots(figsize=(6, 4))
103
+
104
+ if len(df) == 0:
105
+ ax.text(0.5, 0.5, "No vehicles scanned yet",
106
+ ha="center", va="center", fontsize=12)
107
+ ax.axis("off")
108
+ return fig
109
+
110
+ counts = df["type"].value_counts()
111
+ counts.plot(kind="bar", ax=ax)
112
+
113
+ ax.set_title("Vehicle Classification Dashboard")
114
+ ax.set_xlabel("Vehicle Type")
115
+ ax.set_ylabel("Count")
116
+ ax.grid(axis="y")
117
+
118
+ return fig
119
+
120
+
121
+ # ---------------- Core Inference ----------------
122
+
123
+ def make_prediction(img):
124
+ processor, model = load_model()
125
+ inputs = processor(images=img, return_tensors="pt")
126
+ with torch.no_grad():
127
+ outputs = model(**inputs)
128
+
129
+ img_size = torch.tensor([tuple(reversed(img.size))])
130
+ processed_outputs = processor.post_process_object_detection(
131
+ outputs, threshold=0.0, target_sizes=img_size
132
+ )
133
+ return processed_outputs[0]
134
+
135
+
136
+ def fig2img(fig):
137
+ buf = io.BytesIO()
138
+ fig.savefig(buf)
139
+ buf.seek(0)
140
+ pil_img = Image.open(buf)
141
+
142
+ basewidth = 750
143
+ wpercent = (basewidth / float(pil_img.size[0]))
144
+ hsize = int((float(pil_img.size[1]) * float(wpercent)))
145
+ img = pil_img.resize((basewidth, hsize), Image.Resampling.LANCZOS)
146
+
147
+ plt.close(fig)
148
+ return img
149
+
150
+
151
+ # ---------------- Visualization ----------------
152
+
153
+ def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
154
+ keep = output_dict["scores"] > threshold
155
+ boxes = output_dict["boxes"][keep].tolist()
156
+ scores = output_dict["scores"][keep].tolist()
157
+ labels = output_dict["labels"][keep].tolist()
158
+
159
+ if id2label is not None:
160
+ labels = [id2label[x] for x in labels]
161
+
162
+ plt.figure(figsize=(20, 20))
163
+ plt.imshow(img)
164
+ ax = plt.gca()
165
+ colors = COLORS * 100
166
+
167
+ for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
168
+ if "plate" in label.lower():
169
+ crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
170
+ plate_type = classify_plate_color(crop)
171
+
172
+ if plate_type == "EV":
173
+ amount = BASE_TOLL * 0.9
174
+ price_text = f"EV | ₹{amount:.0f} (10% off)"
175
+ else:
176
+ amount = BASE_TOLL
177
+ price_text = f"{plate_type} | ₹{amount:.0f}"
178
+
179
+ cursor.execute(
180
+ "INSERT INTO vehicles VALUES (?, ?, ?, datetime('now'))",
181
+ ("UNKNOWN", plate_type, amount)
182
+ )
183
+ conn.commit()
184
+
185
+ ax.add_patch(
186
+ plt.Rectangle(
187
+ (xmin, ymin), xmax - xmin, ymax - ymin,
188
+ fill=False, color=color, linewidth=4
189
+ )
190
+ )
191
+ ax.text(
192
+ xmin, ymin - 10,
193
+ f"{price_text} | {score:0.2f}",
194
+ fontsize=12,
195
+ bbox=dict(facecolor="yellow", alpha=0.8)
196
+ )
197
+
198
+ plt.axis("off")
199
+ return fig2img(plt.gcf())
200
+
201
+
202
+ # ---------------- Image Detection ----------------
203
+
204
+ def detect_objects_image(url_input, image_input, webcam_input, threshold):
205
+ if url_input and is_valid_url(url_input):
206
+ image = get_original_image(url_input)
207
+ elif image_input is not None:
208
+ image = image_input
209
+ elif webcam_input is not None:
210
+ image = webcam_input
211
+ else:
212
+ return None, None
213
+
214
+ processed_outputs = make_prediction(image)
215
+ viz_img = visualize_prediction(image, processed_outputs, threshold, load_model()[1].config.id2label)
216
+ dashboard_fig = get_dashboard()
217
+
218
+ return viz_img, dashboard_fig
219
+
220
+
221
+ # ---------------- UI ----------------
222
+
223
+ title = """<h1 id="title">License Plate Detection + Toll Billing</h1>"""
224
+
225
+ description = """
226
+ Detect license plates using YOLOS.
227
+ Features:
228
+ - Image URL
229
+ - Image Upload
230
+ - Webcam
231
+ - Vehicle type classification by plate color
232
+ - EV vehicles get 10% discount
233
+ - Billing dashboard
234
+ """
235
+
236
+ demo = gr.Blocks()
237
+
238
+ with demo:
239
+ gr.Markdown(title)
240
+ gr.Markdown(description)
241
+
242
+ slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.5, step=0.1, label='Prediction Threshold')
243
+
244
+ with gr.Tabs():
245
+ with gr.TabItem('Image URL'):
246
+ with gr.Row():
247
+ url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
248
+ original_image = gr.Image(height=400)
249
+ url_input.change(get_original_image, url_input, original_image)
250
+ img_output_from_url = gr.Image(height=400)
251
+ dashboard_output_url = gr.Plot()
252
+ url_but = gr.Button('Detect')
253
+
254
+ with gr.TabItem('Image Upload'):
255
+ with gr.Row():
256
+ img_input = gr.Image(type='pil', height=400)
257
+ img_output_from_upload = gr.Image(height=400)
258
+ dashboard_output_upload = gr.Plot()
259
+ img_but = gr.Button('Detect')
260
+
261
+ with gr.TabItem('WebCam'):
262
+ with gr.Row():
263
+ web_input = gr.Image(
264
+ sources=["webcam"],
265
+ type="pil",
266
+ height=400,
267
+ streaming=True
268
+ )
269
+ img_output_from_webcam = gr.Image(height=400)
270
+ dashboard_output_webcam = gr.Plot()
271
+ cam_but = gr.Button('Detect')
272
+
273
+ url_but.click(
274
+ detect_objects_image,
275
+ inputs=[url_input, img_input, web_input, slider_input],
276
+ outputs=[img_output_from_url, dashboard_output_url]
277
+ )
278
+
279
+ img_but.click(
280
+ detect_objects_image,
281
+ inputs=[url_input, img_input, web_input, slider_input],
282
+ outputs=[img_output_from_upload, dashboard_output_upload]
283
+ )
284
+
285
+ cam_but.click(
286
+ detect_objects_image,
287
+ inputs=[url_input, img_input, web_input, slider_input],
288
+ outputs=[img_output_from_webcam, dashboard_output_webcam]
289
+ )
290
+
291
+
292
+ demo.queue()
293
+ demo.launch(debug=True, ssr_mode=False)