|
|
import io |
|
|
import os |
|
|
import cv2 |
|
|
import gradio as gr |
|
|
import matplotlib.pyplot as plt |
|
|
import requests |
|
|
import torch |
|
|
import numpy as np |
|
|
import sqlite3 |
|
|
import pandas as pd |
|
|
from urllib.parse import urlparse |
|
|
from PIL import Image |
|
|
from transformers import YolosImageProcessor, YolosForObjectDetection |
|
|
|
|
|
# NOTE(review): presumably set to stop Intel OpenMP aborting when more than one
# OpenMP runtime gets loaded (e.g. via torch and cv2); confirm it is still needed.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

# Hugging Face Hub id of the YOLOS checkpoint used for licence-plate detection.
MODEL_NAME = "nickmuchi/yolos-small-finetuned-license-plate-detection"

# Base toll/parking charge before any discount (results display it with a ₹ sign).
BASE_AMT = 100
|
|
|
|
|
|
|
|
|
|
|
def compute_discount(vehicle_type):
    """Return ``(amount, message)`` describing the charge for *vehicle_type*.

    Only the exact string "EV" earns a discount: it pays 90% of ``BASE_AMT``.
    Every other vehicle type pays ``BASE_AMT`` in full.
    """
    if vehicle_type != "EV":
        return BASE_AMT, "No discount"
    discounted = BASE_AMT * 0.9
    return discounted, "10% discount applied (EV)"
|
|
|
|
|
|
|
|
|
|
|
def is_valid_url(url, schemes=("http", "https")):
    """Return True if *url* is an absolute URL that ``requests`` can fetch.

    A URL qualifies when it is a string with a network location and a scheme
    listed in *schemes*. The default is http/https, the only schemes
    ``requests`` supports out of the box; ``urlparse`` lower-cases the scheme,
    so "HTTP://..." also matches.

    Args:
        url: candidate URL text.
        schemes: accepted lower-case schemes.

    Returns:
        False for non-strings (including None), for relative or scheme-less
        text, for other schemes, and for strings ``urlparse`` rejects (e.g. an
        unterminated IPv6 host "http://[::1").
    """
    if not isinstance(url, str):
        return False
    try:
        parsed = urlparse(url)
    except ValueError:
        return False
    return bool(parsed.netloc) and parsed.scheme in schemes
|
|
|
|
|
|
|
|
def get_original_image(url_input, timeout=10):
    """Download the image at *url_input* and return it as an RGB PIL image.

    This backs the live URL preview (the textbox's change handler), so it is
    called with half-typed or unreachable URLs. Every failure therefore
    returns None (and is logged) instead of raising.

    Args:
        url_input: URL text from the textbox. Surrounding whitespace, such as
            a trailing newline from the multi-line box, is stripped before
            the request.
        timeout: seconds passed to ``requests.get``, so an unresponsive host
            cannot block the worker indefinitely.

    Returns:
        RGB PIL image, or None when the URL is empty or invalid, or when the
        download, HTTP status check or image decoding fails. Callers that
        need an image must check for None.
    """
    import logging

    if not (url_input and is_valid_url(url_input)):
        return None
    try:
        response = requests.get(url_input.strip(), timeout=timeout)
        response.raise_for_status()
        # Decode from the full body. Response.raw would skip requests'
        # handling of gzip/deflate Content-Encoding.
        with Image.open(io.BytesIO(response.content)) as downloaded:
            return downloaded.convert("RGB")
    except (requests.RequestException, OSError) as err:
        # RequestException covers connection errors, timeouts and 4xx/5xx.
        # OSError covers PIL's UnidentifiedImageError for non-image bodies.
        logging.getLogger(__name__).warning(
            "Could not load image from %s: %s", url_input, err
        )
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# SQLite log of every scanned plate. check_same_thread=False lets callbacks
# that presumably run on Gradio worker threads reuse this single connection.
conn = sqlite3.connect("vehicles.db", check_same_thread=False)

# Leaving the "with" block commits, as an explicit conn.commit() would.
with conn:
    # One row per detected plate: OCR text, vehicle category, amount charged
    # and the scan time (stored as TEXT).
    conn.execute("""
    CREATE TABLE IF NOT EXISTS vehicles (
        plate TEXT,
        type TEXT,
        amount REAL,
        time TEXT
    )
    """)

# Module-level cursor, kept available to other code in this module.
cursor = conn.cursor()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cached YOLOS processor and model. Both stay None until load_model() first
# populates them.
processor = model = None
|
|
|
|
|
|
|
|
def load_model():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    If either global is still None, both are (re)created from ``MODEL_NAME``.
    The model is loaded from safetensors weights in float32 and switched to
    eval mode. Later calls return the cached objects unchanged.
    """
    global processor, model
    needs_loading = processor is None or model is None
    if needs_loading:
        processor = YolosImageProcessor.from_pretrained(MODEL_NAME)
        model = YolosForObjectDetection.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            use_safetensors=True,
        )
        model.eval()
    return processor, model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_plate_color(plate_img):
    """Guess the vehicle category from the dominant colour of a plate crop.

    Three colour masks are built in OpenCV's 8-bit HSV space (hue 0-180):
    green (H 35-85), yellow (H 15-35) and near-white (S <= 30, V >= 200).
    The mask sums, which are proportional to pixel counts, are compared.
    Green must strictly beat both of the others to give "EV", and yellow must
    strictly beat both to give "Commercial". Anything else, including ties,
    gives "Personal".

    Args:
        plate_img: plate crop as a PIL image or array in RGB channel order.
            Single-channel (grey) and 4-channel (RGBA) input is converted to
            RGB first.

    Returns:
        "EV", "Commercial" or "Personal". An empty crop (zero width or
        height, e.g. from a degenerate box) returns "Personal" instead of
        making cv2 raise.
    """
    img = np.array(plate_img)
    if img.size == 0:
        return "Personal"
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)

    green = np.sum(cv2.inRange(hsv, (35, 40, 40), (85, 255, 255)))
    yellow = np.sum(cv2.inRange(hsv, (15, 50, 50), (35, 255, 255)))
    white = np.sum(cv2.inRange(hsv, (0, 0, 200), (180, 30, 255)))

    if green > yellow and green > white:
        return "EV"
    if yellow > green and yellow > white:
        return "Commercial"
    return "Personal"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_dashboard():
    """Plot how many logged vehicles fall into each vehicle type.

    Reads the ``type`` column of the ``vehicles`` table through the
    module-level ``conn``. When the table has no rows, a placeholder message
    is drawn instead of a chart.

    Returns:
        A matplotlib Figure for a ``gr.Plot`` output. The figure is closed in
        pyplot before it is returned, so repeated clicks do not pile up open
        figures. ``Figure.savefig`` still works on a closed figure.
    """
    df = pd.read_sql("SELECT type FROM vehicles", conn)
    fig, ax = plt.subplots(figsize=(7, 5))

    if df.empty:
        ax.text(0.5, 0.5, "No vehicles scanned yet",
                ha="center", va="center", fontsize=12)
        ax.axis("off")
    else:
        counts = df["type"].value_counts()
        counts.plot(kind="bar", ax=ax, color="steelblue")

        ax.set_title("Vehicle Classification Dashboard", fontsize=12)
        ax.set_xlabel("Vehicle Type", fontsize=10)
        ax.set_ylabel("Count", fontsize=10)
        ax.set_xticks(range(len(counts.index)))
        ax.set_xticklabels(counts.index, rotation=0, ha="center")
        ax.grid(axis="y", linestyle="--", alpha=0.6)

        # Write each bar's count just above it.
        for i, v in enumerate(counts.values):
            ax.text(i, v + 0.05, str(v), ha="center", va="bottom", fontsize=10)

        # Lay out this figure explicitly. plt.tight_layout() would act on
        # pyplot's "current" figure, which may be a different one.
        fig.tight_layout()

    # Drop pyplot's registry reference; only the Figure object is returned.
    plt.close(fig)
    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_prediction(img, threshold=0.3):
    """Run the YOLOS plate detector on *img* and return its detections.

    Args:
        img: PIL image; ``img.size`` is read as (width, height).
        threshold: score cut-off applied by the processor's post-processing.
            It defaults to 0.3, the value previously hard-coded here.
            Detections dropped here cannot be recovered by a later, looser
            filter, so pass a value no higher than any downstream threshold.

    Returns:
        The post-processed result for the single image: a dict with
        "scores", "labels" and "boxes", with boxes scaled to *img*'s size.
    """
    processor, model = load_model()
    inputs = processor(images=img, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing wants (height, width); PIL's size is (width, height).
    target_sizes = torch.tensor([img.size[::-1]])
    processed_outputs = processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )
    return processed_outputs[0]
|
|
|
|
|
|
|
|
def fig2img(fig, basewidth=750):
    """Render a Matplotlib figure to a PIL image *basewidth* pixels wide.

    The aspect ratio is preserved, with the height rounded down and kept at
    least 1 px. The figure is always closed afterwards, even if rendering
    fails, so pyplot does not keep it alive.

    Args:
        fig: Matplotlib figure to render (saved with savefig defaults).
        basewidth: target width in pixels of the returned image.

    Returns:
        Resized PIL image.
    """
    try:
        buf = io.BytesIO()
        fig.savefig(buf)
        buf.seek(0)
        pil_img = Image.open(buf)

        # Scale the height by the same factor as the width.
        scale = basewidth / float(pil_img.size[0])
        hsize = max(1, int(float(pil_img.size[1]) * scale))
        return pil_img.resize((basewidth, hsize), Image.Resampling.LANCZOS)
    finally:
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
def read_plate(plate_img, ocr_reader=None):
    """Return the text of the first OCR result on a plate crop, or "UNKNOWN".

    Args:
        plate_img: plate crop (PIL image or array). It is converted with
            ``np.array`` before OCR.
        ocr_reader: object with a ``readtext(array)`` method that returns a
            sequence of results, where element ``[1]`` of each result is the
            recognised text. If omitted, a module-level ``reader`` is used
            when one exists.

    Returns:
        The first result's text. "UNKNOWN" when OCR finds nothing or no OCR
        reader is available.
    """
    if ocr_reader is None:
        # NOTE(review): no module-level `reader` is visibly defined in this
        # file, so a bare `reader.readtext(...)` raises NameError. Presumably
        # an easyocr.Reader was intended. Until one is created at module level
        # (or passed in), report "UNKNOWN" instead of crashing.
        ocr_reader = globals().get("reader")
    if ocr_reader is None:
        return "UNKNOWN"

    results = ocr_reader.readtext(np.array(plate_img))
    return results[0][1] if results else "UNKNOWN"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
    """Draw plate detections on *img*, log each plate, and summarise them.

    Detections scoring strictly above *threshold* are kept. For each kept
    detection whose label contains "plate" (case-insensitive), the function:
    - crops the plate from *img*;
    - reads it with ``read_plate``;
    - classifies it with ``classify_plate_color``;
    - prices it with ``compute_discount``;
    - inserts a row into the ``vehicles`` table, timestamped with SQLite's
      ``datetime('now')``.
    A red box and a caption are drawn over each plate.

    Args:
        img: PIL image the detections refer to.
        output_dict: detections with "scores", "labels" and "boxes" tensors.
            Boxes are (xmin, ymin, xmax, ymax) in *img* pixel coordinates.
        threshold: minimum score (exclusive) for a detection to be kept.
        id2label: optional mapping from label id to label name. Without it,
            raw integer ids are compared, so nothing matches "plate".

    Returns:
        ``(annotated PIL image, result text)``. The text has one line per
        plate, or reads "No license plate detected." when none was found.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]

    # Use an explicit figure rather than pyplot's implicit "current figure",
    # which a concurrent request could replace between calls.
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        ax.imshow(img)
        result_lines = []

        for (xmin, ymin, xmax, ymax), label in zip(boxes, labels):
            # str(): raw integer ids (id2label=None) have no .lower().
            if "plate" not in str(label).lower():
                continue

            plate_img = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
            plate_text = read_plate(plate_img)
            vehicle_type = classify_plate_color(plate_img)
            toll, discount_msg = compute_discount(vehicle_type)

            # conn.execute opens a fresh cursor instead of sharing the
            # module-level one across requests. "with conn" commits on
            # success and rolls back if the insert raises.
            with conn:
                conn.execute(
                    "INSERT INTO vehicles VALUES (?, ?, ?, datetime('now'))",
                    (plate_text, vehicle_type, toll),
                )

            result_lines.append(
                f"License: {plate_text} | Type: {vehicle_type} | Toll: ₹{int(toll)} | {discount_msg}"
            )

            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin,
                    fill=False, color="red", linewidth=3,
                )
            )
            ax.text(
                xmin, ymin - 10,
                f"{plate_text} | {vehicle_type} | ₹{int(toll)}",
                fontsize=12,
                bbox=dict(facecolor="yellow", alpha=0.8),
            )

        ax.axis("off")
        final_img = fig2img(fig)
    finally:
        # fig2img closes the figure itself. This also releases it when OCR,
        # classification or the DB insert raised part-way through.
        plt.close(fig)

    if result_lines:
        result_text = "\n".join(result_lines)
    else:
        result_text = "No license plate detected."
    return final_img, result_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_objects_image(url_input, image_input, webcam_input, threshold):
    """Pick an input image, detect licence plates in it, and render the result.

    The first available source wins, in this order: a valid URL, an uploaded
    image, a webcam frame.

    Args:
        url_input: image URL text (may be empty or None).
        image_input: uploaded PIL image, or None.
        webcam_input: webcam PIL frame, or None.
        threshold: confidence cut-off forwarded to ``visualize_prediction``.
            ``make_prediction`` is called with its default 0.3 score
            threshold, so values at or below 0.3 filter nothing extra.

    Returns:
        ``(annotated image, result text)``, or ``(None, status message)``
        when there is no usable input.
    """
    if url_input and is_valid_url(url_input):
        image = get_original_image(url_input)
        if image is None:
            # Guard against the fetch returning None instead of an image.
            return None, "Could not load image from URL."
    elif image_input is not None:
        image = image_input
    elif webcam_input is not None:
        image = webcam_input
    else:
        return None, "No image provided."

    processed_outputs = make_prediction(image)

    # The model is cached by load_model(); only its label names are needed.
    _, detector = load_model()
    return visualize_prediction(
        image,
        processed_outputs,
        threshold,
        detector.config.id2label,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page heading, passed to gr.Markdown(title) in the UI.
title = "<h1>🚦 Smart Vehicle Classification</h1>"

# Feature summary shown under the heading via gr.Markdown(description).
description = """
Detect license plates using YOLOS.
Features:
- Image URL, Image Upload, Webcam
- Vehicle type classification by plate color
- EV vehicles get 10% discount on Toll / Parking
"""
|
|
|
|
|
def _detect_from_url(url, threshold):
    """Run detection on the URL tab's input only."""
    return detect_objects_image(url, None, None, threshold)


def _detect_from_upload(image, threshold):
    """Run detection on the Upload tab's image only."""
    return detect_objects_image(None, image, None, threshold)


def _detect_from_webcam(frame, threshold):
    """Run detection on the Webcam tab's current frame only."""
    return detect_objects_image(None, None, frame, threshold)


with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    # Shared text output for all three tabs.
    result_box = gr.Textbox(label="Detection Result", lines=5)

    with gr.Tabs():
        with gr.TabItem("Image URL"):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label="Enter Image URL")
                original_image = gr.Image(height=200)
            # Live preview: re-fetched whenever the URL text changes.
            url_input.change(get_original_image, url_input, original_image)

            img_output_from_url = gr.Image(height=200)
            dashboard_output_url = gr.Plot()
            url_but = gr.Button("Detect")

        with gr.TabItem("Image Upload"):
            with gr.Row():
                img_input = gr.Image(type="pil", height=200)
                img_output_from_upload = gr.Image(height=200)
            dashboard_output_upload = gr.Plot()
            img_but = gr.Button("Detect")

        with gr.TabItem("Webcam"):
            with gr.Row():
                web_input = gr.Image(
                    sources=["webcam"],
                    type="pil",
                    height=200,
                    streaming=True,
                )
                img_output_from_webcam = gr.Image(height=200)
            dashboard_output_webcam = gr.Plot()
            cam_but = gr.Button("Detect")

    # Detection runs make_prediction with a 0.3 score threshold, so a lower
    # slider value could not reveal any extra detections.
    slider_input = gr.Slider(0.3, 1.0, value=0.5, step=0.05, label="Confidence Threshold")

    # Each tab's button passes only its own input. Otherwise a stale URL
    # would take priority over a freshly uploaded image or webcam frame.
    # The dashboard refresh is chained with .then() so it runs after the
    # detection has written its rows, not concurrently with it.
    url_but.click(
        _detect_from_url,
        inputs=[url_input, slider_input],
        outputs=[img_output_from_url, result_box],
        queue=True,
    ).then(get_dashboard, outputs=dashboard_output_url)

    img_but.click(
        _detect_from_upload,
        inputs=[img_input, slider_input],
        outputs=[img_output_from_upload, result_box],
        queue=True,
    ).then(get_dashboard, outputs=dashboard_output_upload)

    cam_but.click(
        _detect_from_webcam,
        inputs=[web_input, slider_input],
        outputs=[img_output_from_webcam, result_box],
        queue=True,
    ).then(get_dashboard, outputs=dashboard_output_webcam)

demo.queue()
demo.launch(debug=True, ssr_mode=False)
|
|
|