Sarvamangalak commited on
Commit
7b7476e
·
verified ·
1 Parent(s): 8cc070d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -0
app.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import cv2
4
+ import gradio as gr
5
+ import matplotlib
6
+ matplotlib.use("Agg")
7
+
8
+ import matplotlib.pyplot as plt
9
+ import requests
10
+ import torch
11
+ import numpy as np
12
+ import sqlite3
13
+ import pandas as pd
14
+ import pytesseract
15
+
16
+ from urllib.parse import urlparse
17
+ from PIL import Image
18
+ from transformers import YolosImageProcessor, YolosForObjectDetection
19
+
20
+ # -------------------- CONFIG --------------------
21
+
22
+ MODEL_NAME = "nickmuchi/yolos-small-finetuned-license-plate-detection"
23
+ BASE_AMT = 100
24
+
25
+ # -------------------- DATABASE --------------------
26
+
27
+ conn = sqlite3.connect("vehicles.db", check_same_thread=False)
28
+ cursor = conn.cursor()
29
+ cursor.execute("""
30
+ CREATE TABLE IF NOT EXISTS vehicles (
31
+ plate TEXT,
32
+ type TEXT,
33
+ amount REAL,
34
+ time TEXT
35
+ )
36
+ """)
37
+ conn.commit()
38
+
39
+ # -------------------- MODEL (Lazy Load) --------------------
40
+
41
+ processor = None
42
+ model = None
43
+
44
+ def load_model():
45
+ global processor, model
46
+ if processor is None or model is None:
47
+ processor = YolosImageProcessor.from_pretrained(MODEL_NAME)
48
+ model = YolosForObjectDetection.from_pretrained(
49
+ MODEL_NAME,
50
+ torch_dtype=torch.float32
51
+ )
52
+ model.eval()
53
+ return processor, model
54
+
55
+ # -------------------- UTILITIES --------------------
56
+
57
+ def is_valid_url(url):
58
+ try:
59
+ r = urlparse(url)
60
+ return all([r.scheme, r.netloc])
61
+ except:
62
+ return False
63
+
64
+ def get_original_image(url):
65
+ return Image.open(requests.get(url, stream=True).raw).convert("RGB")
66
+
67
+ # -------------------- DISCOUNT LOGIC --------------------
68
+
69
+ def compute_discount(vehicle_type):
70
+ if vehicle_type == "EV":
71
+ return BASE_AMT * 0.9, "10% EV discount applied"
72
+ return BASE_AMT, "No discount"
73
+
74
+ # -------------------- PLATE COLOR CLASSIFICATION --------------------
75
+
76
+ def classify_plate_color(plate_img):
77
+ img = np.array(plate_img)
78
+ hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
79
+
80
+ green = np.sum(cv2.inRange(hsv, (35,40,40), (85,255,255)))
81
+ yellow = np.sum(cv2.inRange(hsv, (15,50,50), (35,255,255)))
82
+
83
+ if green > yellow:
84
+ return "EV"
85
+ elif yellow > green:
86
+ return "Commercial"
87
+ return "Personal"
88
+
89
+ # -------------------- OCR (LIGHTWEIGHT) --------------------
90
+
91
+ def read_plate(plate_img):
92
+ gray = cv2.cvtColor(np.array(plate_img), cv2.COLOR_RGB2GRAY)
93
+ gray = cv2.threshold(gray, 120, 255, cv2.THRESH_BINARY)[1]
94
+
95
+ text = pytesseract.image_to_string(
96
+ gray,
97
+ config="--psm 7 -c tessedit_char_whitelist=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
98
+ )
99
+ return text.strip() if text.strip() else "UNKNOWN"
100
+
101
+ # -------------------- YOLOS INFERENCE --------------------
102
+
103
+ def make_prediction(img):
104
+ processor, model = load_model()
105
+ inputs = processor(images=img, return_tensors="pt")
106
+ with torch.no_grad():
107
+ outputs = model(**inputs)
108
+
109
+ img_size = torch.tensor([tuple(reversed(img.size))])
110
+ results = processor.post_process_object_detection(
111
+ outputs, threshold=0.3, target_sizes=img_size
112
+ )
113
+ return results[0]
114
+
115
+ # -------------------- VISUALIZATION --------------------
116
+
117
+ def fig_to_img(fig):
118
+ buf = io.BytesIO()
119
+ fig.savefig(buf)
120
+ buf.seek(0)
121
+ img = Image.open(buf)
122
+ plt.close(fig)
123
+ return img
124
+
125
+ def visualize(img, output, threshold):
126
+ keep = output["scores"] > threshold
127
+ boxes = output["boxes"][keep]
128
+ labels = output["labels"][keep]
129
+
130
+ plt.figure(figsize=(10,10))
131
+ plt.imshow(img)
132
+ ax = plt.gca()
133
+
134
+ results = []
135
+
136
+ for box, label in zip(boxes, labels):
137
+ if "plate" not in load_model()[1].config.id2label[label.item()].lower():
138
+ continue
139
+
140
+ x1,y1,x2,y2 = map(int, box.tolist())
141
+ plate_img = img.crop((x1,y1,x2,y2))
142
+
143
+ plate = read_plate(plate_img)
144
+ vtype = classify_plate_color(plate_img)
145
+ toll, msg = compute_discount(vtype)
146
+
147
+ cursor.execute(
148
+ "INSERT INTO vehicles VALUES (?, ?, ?, datetime('now'))",
149
+ (plate, vtype, toll)
150
+ )
151
+ conn.commit()
152
+
153
+ results.append(f"{plate} | {vtype} | ₹{int(toll)}")
154
+
155
+ ax.add_patch(
156
+ plt.Rectangle((x1,y1), x2-x1, y2-y1, fill=False, color="red", linewidth=2)
157
+ )
158
+ ax.text(x1, y1-5, f"{plate} ({vtype})", color="yellow")
159
+
160
+ plt.axis("off")
161
+ return fig_to_img(plt.gcf()), "\n".join(results) if results else "No plate detected"
162
+
163
+ # -------------------- DASHBOARD --------------------
164
+
165
+ def get_dashboard():
166
+ df = pd.read_sql("SELECT * FROM vehicles", conn)
167
+ fig, ax = plt.subplots()
168
+
169
+ if df.empty:
170
+ ax.text(0.5,0.5,"No data yet",ha="center")
171
+ ax.axis("off")
172
+ return fig
173
+
174
+ df["type"].value_counts().plot(kind="bar", ax=ax)
175
+ ax.set_title("Vehicle Types")
176
+ return fig
177
+
178
+ # -------------------- MAIN CALLBACK --------------------
179
+
180
+ def detect(url, img, cam, threshold):
181
+ if url and is_valid_url(url):
182
+ image = get_original_image(url)
183
+ elif img is not None:
184
+ image = img
185
+ elif cam is not None:
186
+ image = cam
187
+ else:
188
+ return None, "No input"
189
+
190
+ output = make_prediction(image)
191
+ return visualize(image, output, threshold)
192
+
193
+ # -------------------- UI --------------------
194
+
195
+ with gr.Blocks() as demo:
196
+ gr.Markdown("## 🚦 Smart Vehicle Classification System")
197
+
198
+ result_box = gr.Textbox(label="Result", lines=4)
199
+ slider = gr.Slider(0.3,1.0,0.5,label="Confidence Threshold")
200
+
201
+ with gr.Tabs():
202
+ with gr.Tab("Image URL"):
203
+ url = gr.Textbox(label="Image URL")
204
+ out1 = gr.Image()
205
+ btn1 = gr.Button("Detect")
206
+
207
+ with gr.Tab("Upload"):
208
+ img = gr.Image(type="pil")
209
+ out2 = gr.Image()
210
+ btn2 = gr.Button("Detect")
211
+
212
+ with gr.Tab("Webcam"):
213
+ cam = gr.Image(source="webcam", type="pil")
214
+ out3 = gr.Image()
215
+ btn3 = gr.Button("Detect")
216
+
217
+ btn1.click(detect, [url, img, cam, slider], [out1, result_box])
218
+ btn2.click(detect, [url, img, cam, slider], [out2, result_box])
219
+ btn3.click(detect, [url, img, cam, slider], [out3, result_box])
220
+
221
+ gr.Markdown("### 📊 Dashboard")
222
+ gr.Plot(get_dashboard)
223
+
224
+ demo.launch()