| import io |
| import os |
| import cv2 |
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import requests |
| import torch |
| import numpy as np |
| import sqlite3 |
| import pandas as pd |
| from urllib.parse import urlparse |
| from PIL import Image |
| from transformers import YolosImageProcessor, YolosForObjectDetection |
|
|
# Allow more than one OpenMP runtime to be loaded in this process.
# NOTE(review): presumably set to work around the Intel OpenMP
# "OMP: Error #15" duplicate-runtime abort (torch / numpy / cv2 may each ship
# their own libiomp). It masks a library conflict rather than fixing it;
# confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

# Pretrained YOLOS checkpoint passed to ``from_pretrained`` by load_model().
MODEL_NAME = "nickmuchi/yolos-small-finetuned-license-plate-detection"
# Base toll / parking charge (displayed in ₹) before any discount.
BASE_AMT = 100
|
|
| |
|
|
def compute_discount(vehicle_type):
    """Return the charge for a vehicle together with a short discount note.

    Vehicles classified as ``"EV"`` pay 90% of ``BASE_AMT``; every other
    type pays ``BASE_AMT`` unchanged.

    Args:
        vehicle_type: Vehicle category string, e.g. ``"EV"``,
            ``"Commercial"`` or ``"Personal"``.

    Returns:
        ``(amount, message)``. ``amount`` is ``BASE_AMT * 0.9`` (a float) for
        EVs and ``BASE_AMT`` itself otherwise.
    """
    if vehicle_type != "EV":
        return BASE_AMT, "No discount"
    return BASE_AMT * 0.9, "10% discount applied (EV)"
|
|
| |
|
|
def is_valid_url(url, schemes=("http", "https")):
    """Return True if *url* is an absolute URL with an accepted scheme.

    The URL must parse with both a scheme and a network location, and the
    scheme must be one of *schemes*. Restricting the scheme matters because
    accepted URLs are handed to ``requests.get``, which only has adapters for
    http(s) and raises ``InvalidSchema`` for anything else (e.g. ``ftp://``).

    Args:
        url: Candidate URL. ``None``, empty and non-string values are rejected.
        schemes: Accepted URL schemes, lower-case (``urlparse`` lower-cases
            the parsed scheme).

    Returns:
        ``True`` if the URL is acceptable, ``False`` otherwise.
    """
    if not url or not isinstance(url, str):
        return False
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse raises ValueError for malformed input such as an
        # unterminated IPv6 literal ("http://[::1").
        return False
    return bool(parsed.netloc) and parsed.scheme in schemes
|
|
|
|
def get_original_image(url_input, timeout=10):
    """Download the image at *url_input* and return it as an RGB PIL image.

    Args:
        url_input: Image URL typed by the user.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block the worker indefinitely.

    Returns:
        The decoded image converted to RGB, or ``None`` when *url_input* is
        empty or fails ``is_valid_url``.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            HTTP error status (4xx/5xx).
        PIL.UnidentifiedImageError: If the response body is not an image.
    """
    if not (url_input and is_valid_url(url_input)):
        return None
    with requests.get(url_input, timeout=timeout) as response:
        # Fail loudly on 4xx/5xx instead of trying to decode an HTML error page.
        response.raise_for_status()
        # ``.content`` (unlike ``.raw``) is already decoded from gzip/deflate
        # content-encoding, so PIL receives the actual image bytes.
        payload = response.content
    return Image.open(io.BytesIO(payload)).convert("RGB")
|
|
|
|
| |
|
|
# SQLite log of scanned vehicles. check_same_thread=False allows the
# connection to be used from threads other than the one that created it
# (Gradio handlers run in worker threads).
# NOTE(review): the module-level `cursor` is shared by those threads as well;
# sqlite3 cursors are not designed for concurrent use — confirm this is safe.
conn = sqlite3.connect("vehicles.db", check_same_thread=False)
cursor = conn.cursor()

# Create the table on first run: plate text, vehicle type, charged amount and
# a timestamp column. The connection context manager commits on success.
with conn:
    cursor.execute("""
CREATE TABLE IF NOT EXISTS vehicles (
plate TEXT,
type TEXT,
amount REAL,
time TEXT
)
""")
|
|
|
|
| |
|
|
# Lazily-initialised YOLOS image processor and model; both stay None until
# load_model() populates them on first use.
processor = model = None
|
|
|
|
def load_model():
    """Return the YOLOS ``(processor, model)`` pair, loading it on first use.

    The pair is cached in the module globals ``processor`` and ``model``; once
    both are set, later calls return them without reloading. The model is
    loaded from safetensors weights as float32 and switched to eval mode.

    NOTE(review): there is no lock, so two concurrent first calls may both
    load the model.

    Returns:
        ``(YolosImageProcessor, YolosForObjectDetection)``.
    """
    global processor, model
    if processor is not None and model is not None:
        return processor, model

    processor = YolosImageProcessor.from_pretrained(MODEL_NAME)
    model = YolosForObjectDetection.from_pretrained(
        MODEL_NAME,
        use_safetensors=True,
        torch_dtype=torch.float32,
    )
    model.eval()
    return processor, model
|
|
|
|
| |
|
|
def classify_plate_color(plate_img):
    """Guess the vehicle category from the dominant colour of a plate crop.

    Pixels are counted in three HSV bands (OpenCV 8-bit hue scale, 0-180):
    mostly green -> ``"EV"``, mostly yellow -> ``"Commercial"``. Anything
    else falls back to ``"Personal"``: mostly white, ties, and empty crops.

    Args:
        plate_img: Plate region as a PIL image (any mode) or an RGB
            ``uint8`` array.

    Returns:
        One of ``"EV"``, ``"Commercial"`` or ``"Personal"``.
    """
    # Normalise PIL images to 3-channel RGB. RGBA / L / P modes would
    # otherwise make cvtColor(COLOR_RGB2HSV) reject the array.
    if hasattr(plate_img, "convert"):
        plate_img = plate_img.convert("RGB")
    img = np.asarray(plate_img)

    # A degenerate bounding box yields a 0-pixel crop, which cvtColor rejects.
    if img.size == 0:
        return "Personal"

    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)

    # Count matching pixels per colour band (inRange yields a 0/255 mask).
    green = cv2.countNonZero(cv2.inRange(hsv, (35, 40, 40), (85, 255, 255)))
    yellow = cv2.countNonZero(cv2.inRange(hsv, (15, 50, 50), (35, 255, 255)))
    white = cv2.countNonZero(cv2.inRange(hsv, (0, 0, 200), (180, 30, 255)))

    if green > yellow and green > white:
        return "EV"
    if yellow > green and yellow > white:
        return "Commercial"
    return "Personal"
|
|
|
|
| |
|
|
def get_dashboard():
    """Build a bar chart of how many logged scans fall into each vehicle type.

    Reads every row of the ``vehicles`` table. If the table is empty, it
    returns a placeholder figure reading "No vehicles scanned yet".

    Returns:
        A matplotlib ``Figure`` suitable for ``gr.Plot``. It is removed from
        pyplot's figure registry before returning, so repeated refreshes do
        not accumulate open figures; it can still be rendered with
        ``Figure.savefig``.
    """
    df = pd.read_sql("SELECT * FROM vehicles", conn)
    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        if df.empty:
            ax.text(0.5, 0.5, "No vehicles scanned yet",
                    ha="center", va="center", fontsize=12)
            ax.axis("off")
            return fig

        counts = df["type"].value_counts()
        counts.plot(kind="bar", ax=ax, color="steelblue")

        ax.set_title("Vehicle Classification Dashboard", fontsize=12)
        ax.set_xlabel("Vehicle Type", fontsize=10)
        ax.set_ylabel("Count", fontsize=10)

        ax.set_xticks(range(len(counts.index)))
        ax.set_xticklabels(counts.index, rotation=0, ha="center")
        ax.grid(axis="y", linestyle="--", alpha=0.6)

        # Write each bar's count just above the bar top.
        for i, v in enumerate(counts.values):
            ax.text(i, v + 0.05, str(v), ha="center", va="bottom", fontsize=10)

        # fig.tight_layout() rather than plt.tight_layout(): the pyplot form
        # acts on the *current* figure, which a concurrent request thread may
        # have switched to its own figure in the meantime.
        fig.tight_layout()
        return fig
    finally:
        # Runs after the return value is computed; the Figure object remains
        # usable by the caller, only pyplot stops tracking it.
        plt.close(fig)
|
|
|
|
| |
|
|
def make_prediction(img, threshold=0.2):
    """Run the YOLOS detector on *img* and return the post-processed detections.

    Args:
        img: PIL image to analyse; ``img.size`` is read as ``(width, height)``.
        threshold: Minimum score kept by post-processing. The user's
            confidence threshold is applied afterwards when the detections are
            visualised. This value therefore has to stay at or below the
            smallest value the UI slider allows (0.2); otherwise low slider
            settings would have no effect.

    Returns:
        The post-processed result for this single image: a dict with
        ``"scores"``, ``"labels"`` and ``"boxes"`` tensors, with boxes rescaled
        to the image's pixel size.
    """
    processor, model = load_model()
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing expects (height, width); PIL reports (width, height).
    target_sizes = torch.tensor([img.size[::-1]])
    processed_outputs = processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )
    return processed_outputs[0]
|
|
|
|
def fig2img(fig, basewidth=750):
    """Render a matplotlib figure to a PIL image *basewidth* pixels wide.

    The height is scaled to keep the figure's aspect ratio. The figure is
    always closed afterwards (even if rendering fails), which removes it
    from pyplot's figure registry.

    Args:
        fig: Figure to render.
        basewidth: Target width of the returned image, in pixels.

    Returns:
        A new PIL image of size ``(basewidth, scaled_height)``.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf)
            buf.seek(0)
            with Image.open(buf) as rendered:
                width, height = rendered.size
                new_height = int(height * (basewidth / float(width)))
                # resize() decodes the source fully and returns an independent
                # image, so the result stays valid after buf/rendered close.
                return rendered.resize(
                    (basewidth, new_height), Image.Resampling.LANCZOS
                )
    finally:
        plt.close(fig)
|
|
|
|
| |
def read_plate(plate_img):
    """OCR the text on a cropped license plate.

    Uses a module-level ``reader`` object that exposes ``readtext(array)``
    and returns the text field (index 1) of its first result.
    NOTE(review): this matches EasyOCR's ``(bbox, text, confidence)`` result
    tuples, but no ``reader`` is assigned anywhere in this file as seen.
    Confirm where the OCR reader is meant to be created.

    Args:
        plate_img: Plate crop (PIL image or array); converted with
            ``np.array`` before OCR.

    Returns:
        The recognised text, or ``"UNKNOWN"`` when nothing was read or no OCR
        reader is configured.
    """
    # Look the reader up defensively: referencing an unassigned global would
    # raise NameError and abort the whole detection request.
    ocr = globals().get("reader")
    if ocr is None:
        return "UNKNOWN"
    results = ocr.readtext(np.array(plate_img))
    if results:
        return results[0][1]
    return "UNKNOWN"
|
|
|
|
| |
|
|
def _clamp_box(box, width, height):
    """Truncate a float (xmin, ymin, xmax, ymax) box to int pixel bounds inside the image."""
    xmin, ymin, xmax, ymax = box
    left = min(max(int(xmin), 0), width)
    upper = min(max(int(ymin), 0), height)
    right = min(max(int(xmax), 0), width)
    lower = min(max(int(ymax), 0), height)
    return left, upper, right, lower


def _record_plate(plate_text, vehicle_type, toll):
    """Insert one scan into the ``vehicles`` table and commit it."""
    # conn.execute() creates a fresh cursor per call instead of reusing the
    # module-level cursor, which concurrent requests would otherwise share.
    conn.execute(
        "INSERT INTO vehicles VALUES (?, ?, ?, datetime('now'))",
        (plate_text, vehicle_type, toll),
    )
    conn.commit()


def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
    """Draw plate detections on *img*, log each plate, and summarise the results.

    For every detection scoring above *threshold* whose label contains
    "plate", this function:

    - crops the plate;
    - OCRs it and classifies its colour;
    - computes the toll;
    - inserts a row into the ``vehicles`` table;
    - draws a labelled red box on the image.

    Args:
        img: Source PIL image.
        output_dict: Detections with ``"scores"``, ``"labels"`` and
            ``"boxes"`` tensors; boxes are ``(xmin, ymin, xmax, ymax)`` in
            pixels.
        threshold: Minimum score for a detection to be used.
        id2label: Optional mapping from label id to label name.

    Returns:
        ``(annotated_image, result_text)``. ``result_text`` has one line per
        plate, or reads "No license plate detected." when no plate qualified.
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()

    if id2label is not None:
        labels = [id2label[x] for x in labels]

    # Use an explicit figure/axes instead of pyplot's implicit "current"
    # figure: Gradio serves requests from several threads, and plt.gca() /
    # plt.gcf() could otherwise pick up a concurrent request's figure.
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(img)
    ax.axis("off")

    result_lines = []
    for (xmin, ymin, xmax, ymax), label in zip(boxes, labels):
        # str(): labels remain ints when id2label is None.
        if "plate" not in str(label).lower():
            continue

        left, upper, right, lower = _clamp_box(
            (xmin, ymin, xmax, ymax), img.width, img.height
        )
        if right <= left or lower <= upper:
            # Degenerate or fully off-image box: nothing to OCR or classify.
            continue
        plate_img = img.crop((left, upper, right, lower))

        plate_text = read_plate(plate_img)
        vehicle_type = classify_plate_color(plate_img)
        toll, discount_msg = compute_discount(vehicle_type)
        _record_plate(plate_text, vehicle_type, toll)

        result_lines.append(
            f"License: {plate_text} | Type: {vehicle_type} | Toll: ₹{int(toll)} | {discount_msg}"
        )

        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin,
                fill=False, color="red", linewidth=3,
            )
        )
        ax.text(
            xmin, ymin - 10,
            f"{plate_text} | {vehicle_type} | ₹{int(toll)}",
            fontsize=12,
            bbox=dict(facecolor="yellow", alpha=0.8),
        )

    final_img = fig2img(fig)

    if result_lines:
        result_text = "\n".join(result_lines)
    else:
        result_text = "No license plate detected."

    return final_img, result_text
|
|
|
|
| |
|
|
def detect_objects_image(url_input, image_input, webcam_input, threshold):
    """Detect license plates in the first available input and summarise them.

    Sources are tried in this order: a valid ``url_input``, then
    ``image_input``, then ``webcam_input``.

    Args:
        url_input: Image URL (may be empty).
        image_input: Uploaded PIL image, or ``None``.
        webcam_input: Webcam PIL frame, or ``None``.
        threshold: Minimum detection score to display (from the UI slider).

    Returns:
        ``(annotated_image, result_text)``. ``annotated_image`` is ``None``
        when no image could be obtained; ``result_text`` then says why.
    """
    if url_input and is_valid_url(url_input):
        try:
            image = get_original_image(url_input)
        except (requests.RequestException, OSError) as err:
            # Network/HTTP failures and undecodable payloads (PIL raises
            # UnidentifiedImageError, an OSError subclass) become a message
            # in the UI instead of an unhandled exception.
            return None, f"Could not load image from URL: {err}"
    elif image_input is not None:
        image = image_input
    elif webcam_input is not None:
        image = webcam_input
    else:
        image = None

    if image is None:
        return None, "No image provided."

    processed_outputs = make_prediction(image)
    _, detector = load_model()

    return visualize_prediction(
        image,
        processed_outputs,
        threshold,
        detector.config.id2label,
    )
|
|
|
|
| |
|
|
title = "<h1>🚦 Smart Vehicle Classification</h1>"
description = """
Detect license plates using YOLOS.
Features:
- Image URL, Image Upload, Webcam
- Vehicle type classification by plate color
- EV vehicles get 10% discount on Toll / Parking
"""


def _preview_url_image(url):
    """Best-effort preview for the URL tab: the image, or None on any fetch/decode error."""
    # Fires on every edit of the URL box, so half-typed URLs that fail to
    # resolve must not surface as error pop-ups.
    # NOTE(review): consider .blur/.submit to avoid one download per keystroke.
    try:
        return get_original_image(url)
    except (requests.RequestException, OSError):
        return None


# Each tab's Detect button must only use that tab's input. Passing all three
# sources let a leftover URL override an uploaded image or a webcam frame,
# because detect_objects_image prefers the URL.
def _detect_from_url(url, threshold):
    """Run detection on the URL tab's input only."""
    return detect_objects_image(url, None, None, threshold)


def _detect_from_upload(image, threshold):
    """Run detection on the uploaded image only."""
    return detect_objects_image(None, image, None, threshold)


def _detect_from_webcam(frame, threshold):
    """Run detection on the webcam frame only."""
    return detect_objects_image(None, None, frame, threshold)


with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    result_box = gr.Textbox(label="Detection Result", lines=5)

    with gr.Tabs():
        with gr.TabItem("Image URL"):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label="Enter Image URL")
                original_image = gr.Image(height=200)
            url_input.change(_preview_url_image, url_input, original_image)

            img_output_from_url = gr.Image(height=200)
            dashboard_output_url = gr.Plot()
            url_but = gr.Button("Detect")

        with gr.TabItem("Image Upload"):
            with gr.Row():
                img_input = gr.Image(type="pil", height=200)
                img_output_from_upload = gr.Image(height=200)

            dashboard_output_upload = gr.Plot()
            img_but = gr.Button("Detect")

        with gr.TabItem("Webcam"):
            with gr.Row():
                # NOTE(review): streaming=True is designed for .stream()
                # listeners; confirm a button click receives the latest frame.
                web_input = gr.Image(
                    sources=["webcam"],
                    type="pil",
                    height=200,
                    streaming=True,
                )
                img_output_from_webcam = gr.Image(height=200)

            dashboard_output_webcam = gr.Plot()
            cam_but = gr.Button("Detect")

    slider_input = gr.Slider(0.2, 1.0, value=0.5, step=0.05, label="Confidence Threshold")

    # Refresh each dashboard with .then() so it runs *after* detection has
    # inserted the new rows; separate click listeners ran concurrently and
    # usually showed stale counts.
    url_but.click(
        _detect_from_url,
        inputs=[url_input, slider_input],
        outputs=[img_output_from_url, result_box],
        queue=True,
    ).then(get_dashboard, outputs=dashboard_output_url)

    img_but.click(
        _detect_from_upload,
        inputs=[img_input, slider_input],
        outputs=[img_output_from_upload, result_box],
        queue=True,
    ).then(get_dashboard, outputs=dashboard_output_upload)

    cam_but.click(
        _detect_from_webcam,
        inputs=[web_input, slider_input],
        outputs=[img_output_from_webcam, result_box],
        queue=True,
    ).then(get_dashboard, outputs=dashboard_output_webcam)


demo.queue()
demo.launch(debug=True, ssr_mode=False)
|
|