Sarvamangalak committed on
Commit
2fdc11a
·
verified ·
1 Parent(s): 7b7476e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -52
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import io
2
- import os
3
  import cv2
4
  import gradio as gr
5
  import matplotlib
@@ -45,10 +44,7 @@ def load_model():
45
  global processor, model
46
  if processor is None or model is None:
47
  processor = YolosImageProcessor.from_pretrained(MODEL_NAME)
48
- model = YolosForObjectDetection.from_pretrained(
49
- MODEL_NAME,
50
- torch_dtype=torch.float32
51
- )
52
  model.eval()
53
  return processor, model
54
 
@@ -62,14 +58,15 @@ def is_valid_url(url):
62
  return False
63
 
64
  def get_original_image(url):
65
- return Image.open(requests.get(url, stream=True).raw).convert("RGB")
 
66
 
67
  # -------------------- DISCOUNT LOGIC --------------------
68
 
69
  def compute_discount(vehicle_type):
70
  if vehicle_type == "EV":
71
- return BASE_AMT * 0.9, "10% EV discount applied"
72
- return BASE_AMT, "No discount"
73
 
74
  # -------------------- PLATE COLOR CLASSIFICATION --------------------
75
 
@@ -86,7 +83,7 @@ def classify_plate_color(plate_img):
86
  return "Commercial"
87
  return "Personal"
88
 
89
- # -------------------- OCR (LIGHTWEIGHT) --------------------
90
 
91
  def read_plate(plate_img):
92
  gray = cv2.cvtColor(np.array(plate_img), cv2.COLOR_RGB2GRAY)
@@ -103,38 +100,40 @@ def read_plate(plate_img):
103
  def make_prediction(img):
104
  processor, model = load_model()
105
  inputs = processor(images=img, return_tensors="pt")
 
106
  with torch.no_grad():
107
  outputs = model(**inputs)
108
 
109
- img_size = torch.tensor([tuple(reversed(img.size))])
110
  results = processor.post_process_object_detection(
111
  outputs, threshold=0.3, target_sizes=img_size
112
  )
113
- return results[0]
 
114
 
115
  # -------------------- VISUALIZATION --------------------
116
 
117
  def fig_to_img(fig):
118
  buf = io.BytesIO()
119
- fig.savefig(buf)
120
  buf.seek(0)
121
  img = Image.open(buf)
122
  plt.close(fig)
123
  return img
124
 
125
- def visualize(img, output, threshold):
126
  keep = output["scores"] > threshold
127
  boxes = output["boxes"][keep]
128
  labels = output["labels"][keep]
129
 
130
- plt.figure(figsize=(10,10))
131
- plt.imshow(img)
132
- ax = plt.gca()
133
 
134
- results = []
135
 
136
  for box, label in zip(boxes, labels):
137
- if "plate" not in load_model()[1].config.id2label[label.item()].lower():
 
138
  continue
139
 
140
  x1,y1,x2,y2 = map(int, box.tolist())
@@ -142,7 +141,7 @@ def visualize(img, output, threshold):
142
 
143
  plate = read_plate(plate_img)
144
  vtype = classify_plate_color(plate_img)
145
- toll, msg = compute_discount(vtype)
146
 
147
  cursor.execute(
148
  "INSERT INTO vehicles VALUES (?, ?, ?, datetime('now'))",
@@ -150,15 +149,21 @@ def visualize(img, output, threshold):
150
  )
151
  conn.commit()
152
 
153
- results.append(f"{plate} | {vtype} | ₹{int(toll)}")
154
 
155
  ax.add_patch(
156
- plt.Rectangle((x1,y1), x2-x1, y2-y1, fill=False, color="red", linewidth=2)
 
157
  )
158
- ax.text(x1, y1-5, f"{plate} ({vtype})", color="yellow")
 
 
 
159
 
160
- plt.axis("off")
161
- return fig_to_img(plt.gcf()), "\n".join(results) if results else "No plate detected"
 
 
162
 
163
  # -------------------- DASHBOARD --------------------
164
 
@@ -177,48 +182,57 @@ def get_dashboard():
177
 
178
  # -------------------- MAIN CALLBACK --------------------
179
 
180
- def detect(url, img, cam, threshold):
181
- if url and is_valid_url(url):
182
- image = get_original_image(url)
183
- elif img is not None:
184
- image = img
185
- elif cam is not None:
186
- image = cam
187
- else:
188
- return None, "No input"
189
 
190
- output = make_prediction(image)
191
- return visualize(image, output, threshold)
 
 
 
192
 
193
  # -------------------- UI --------------------
194
 
195
  with gr.Blocks() as demo:
196
- gr.Markdown("## 🚦 Smart Vehicle Classification System")
197
 
 
198
  result_box = gr.Textbox(label="Result", lines=4)
199
- slider = gr.Slider(0.3,1.0,0.5,label="Confidence Threshold")
200
 
201
  with gr.Tabs():
 
202
  with gr.Tab("Image URL"):
203
- url = gr.Textbox(label="Image URL")
204
- out1 = gr.Image()
205
- btn1 = gr.Button("Detect")
 
 
 
206
 
207
  with gr.Tab("Upload"):
208
- img = gr.Image(type="pil")
209
- out2 = gr.Image()
210
- btn2 = gr.Button("Detect")
 
 
 
211
 
212
  with gr.Tab("Webcam"):
213
- cam = gr.Image(source="webcam", type="pil")
214
- out3 = gr.Image()
215
- btn3 = gr.Button("Detect")
216
-
217
- btn1.click(detect, [url, img, cam, slider], [out1, result_box])
218
- btn2.click(detect, [url, img, cam, slider], [out2, result_box])
219
- btn3.click(detect, [url, img, cam, slider], [out3, result_box])
220
-
221
- gr.Markdown("### 📊 Dashboard")
 
 
222
  gr.Plot(get_dashboard)
223
 
224
  demo.launch()
 
1
  import io
 
2
  import cv2
3
  import gradio as gr
4
  import matplotlib
 
44
  global processor, model
45
  if processor is None or model is None:
46
  processor = YolosImageProcessor.from_pretrained(MODEL_NAME)
47
+ model = YolosForObjectDetection.from_pretrained(MODEL_NAME)
 
 
 
48
  model.eval()
49
  return processor, model
50
 
 
58
  return False
59
 
60
  def get_original_image(url):
61
+ response = requests.get(url, stream=True)
62
+ return Image.open(response.raw).convert("RGB")
63
 
64
  # -------------------- DISCOUNT LOGIC --------------------
65
 
66
  def compute_discount(vehicle_type):
67
  if vehicle_type == "EV":
68
+ return BASE_AMT * 0.9
69
+ return BASE_AMT
70
 
71
  # -------------------- PLATE COLOR CLASSIFICATION --------------------
72
 
 
83
  return "Commercial"
84
  return "Personal"
85
 
86
+ # -------------------- OCR --------------------
87
 
88
  def read_plate(plate_img):
89
  gray = cv2.cvtColor(np.array(plate_img), cv2.COLOR_RGB2GRAY)
 
100
  def make_prediction(img):
101
  processor, model = load_model()
102
  inputs = processor(images=img, return_tensors="pt")
103
+
104
  with torch.no_grad():
105
  outputs = model(**inputs)
106
 
107
+ img_size = torch.tensor([img.size[::-1]])
108
  results = processor.post_process_object_detection(
109
  outputs, threshold=0.3, target_sizes=img_size
110
  )
111
+
112
+ return results[0], model.config.id2label
113
 
114
  # -------------------- VISUALIZATION --------------------
115
 
116
  def fig_to_img(fig):
117
  buf = io.BytesIO()
118
+ fig.savefig(buf, bbox_inches="tight")
119
  buf.seek(0)
120
  img = Image.open(buf)
121
  plt.close(fig)
122
  return img
123
 
124
+ def visualize(img, output, id2label, threshold):
125
  keep = output["scores"] > threshold
126
  boxes = output["boxes"][keep]
127
  labels = output["labels"][keep]
128
 
129
+ fig, ax = plt.subplots(figsize=(8,8))
130
+ ax.imshow(img)
 
131
 
132
+ results_text = []
133
 
134
  for box, label in zip(boxes, labels):
135
+ label_name = id2label[label.item()].lower()
136
+ if "plate" not in label_name:
137
  continue
138
 
139
  x1,y1,x2,y2 = map(int, box.tolist())
 
141
 
142
  plate = read_plate(plate_img)
143
  vtype = classify_plate_color(plate_img)
144
+ toll = compute_discount(vtype)
145
 
146
  cursor.execute(
147
  "INSERT INTO vehicles VALUES (?, ?, ?, datetime('now'))",
 
149
  )
150
  conn.commit()
151
 
152
+ results_text.append(f"{plate} | {vtype} | ₹{int(toll)}")
153
 
154
  ax.add_patch(
155
+ plt.Rectangle((x1,y1), x2-x1, y2-y1,
156
+ fill=False, color="red", linewidth=2)
157
  )
158
+ ax.text(x1, y1-5, f"{plate} ({vtype})",
159
+ color="yellow", fontsize=10)
160
+
161
+ ax.axis("off")
162
 
163
+ if not results_text:
164
+ return fig_to_img(fig), "No plate detected"
165
+
166
+ return fig_to_img(fig), "\n".join(results_text)
167
 
168
  # -------------------- DASHBOARD --------------------
169
 
 
182
 
183
  # -------------------- MAIN CALLBACK --------------------
184
 
185
+ def detect_from_url(url, threshold):
186
+ if not url or not is_valid_url(url):
187
+ return None, "Invalid URL"
188
+ img = get_original_image(url)
189
+ output, id2label = make_prediction(img)
190
+ return visualize(img, output, id2label, threshold)
 
 
 
191
 
192
+ def detect_from_image(img, threshold):
193
+ if img is None:
194
+ return None, "No image provided"
195
+ output, id2label = make_prediction(img)
196
+ return visualize(img, output, id2label, threshold)
197
 
198
  # -------------------- UI --------------------
199
 
200
  with gr.Blocks() as demo:
201
+ gr.Markdown("## Smart Vehicle Classification System")
202
 
203
+ slider = gr.Slider(0.3, 1.0, 0.5, label="Confidence Threshold")
204
  result_box = gr.Textbox(label="Result", lines=4)
 
205
 
206
  with gr.Tabs():
207
+
208
  with gr.Tab("Image URL"):
209
+ url_input = gr.Textbox(label="Image URL")
210
+ url_output = gr.Image()
211
+ url_btn = gr.Button("Detect")
212
+ url_btn.click(detect_from_url,
213
+ inputs=[url_input, slider],
214
+ outputs=[url_output, result_box])
215
 
216
  with gr.Tab("Upload"):
217
+ img_input = gr.Image(type="pil")
218
+ img_output = gr.Image()
219
+ img_btn = gr.Button("Detect")
220
+ img_btn.click(detect_from_image,
221
+ inputs=[img_input, slider],
222
+ outputs=[img_output, result_box])
223
 
224
  with gr.Tab("Webcam"):
225
+ cam_input = gr.Image(
226
+ sources=["webcam"], # using web camera
227
+ type="pil"
228
+ )
229
+ cam_output = gr.Image()
230
+ cam_btn = gr.Button("Detect")
231
+ cam_btn.click(detect_from_image,
232
+ inputs=[cam_input, slider],
233
+ outputs=[cam_output, result_box])
234
+
235
+ gr.Markdown("###Dashboard")
236
  gr.Plot(get_dashboard)
237
 
238
  demo.launch()