| |
| import io |
| import os |
| import cv2 |
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import requests |
| import torch |
| import pathlib |
| import numpy as np |
| import sqlite3 |
| import pandas as pd |
| from urllib.parse import urlparse |
| from PIL import Image |
| from transformers import YolosImageProcessor, YolosForObjectDetection |
|
|
| os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" |
|
|
# Base toll/parking charge (rupees) applied to every vehicle before discounts.
base_amt = 100


def compute_discount(vehicle_type):
    """Return ``(amount_due, message)`` for the given vehicle type.

    Args:
        vehicle_type: classification string produced by
            ``classify_plate_color`` ("EV", "Commercial", "Personal").

    Returns:
        Tuple of the charge after any discount and a human-readable note.
        EVs get 10% off ``base_amt``; every other type pays full price.
    """
    if vehicle_type == "EV":
        return base_amt * 0.9, "10% discount applied (EV)"
    # Bug fix: the original returned the undefined name ``toll_parking_amt``,
    # raising NameError for any non-EV vehicle. The full base amount is meant.
    return base_amt, "No discount"
| |
# RGB triples (0-1 range, matplotlib convention) used to colour detection
# bounding boxes; recycled via ``COLORS * 100`` in visualize_prediction.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933]
]
|
|
| |
|
|
def is_valid_url(url):
    """Return True when *url* parses with both a scheme and a network location."""
    try:
        parts = urlparse(url)
    except Exception:
        # Anything urlparse cannot digest is treated as not-a-URL.
        return False
    return bool(parts.scheme) and bool(parts.netloc)
|
|
|
|
def get_original_image(url_input):
    """Download *url_input* and return it as an RGB PIL image.

    Args:
        url_input: image URL string (may be empty/None).

    Returns:
        A ``PIL.Image`` in RGB mode, or ``None`` when the URL is missing or
        does not validate (made explicit — the original fell through and
        returned ``None`` implicitly).

    Raises:
        requests.HTTPError: when the server answers with an error status.
    """
    if url_input and is_valid_url(url_input):
        # timeout added so a dead/slow host cannot hang the UI indefinitely
        response = requests.get(url_input, stream=True, timeout=10)
        # Surface HTTP errors instead of handing PIL an error-page stream.
        response.raise_for_status()
        return Image.open(response.raw).convert("RGB")
    return None
|
|
|
|
| |
# --- Persistence: SQLite log of scanned vehicles -------------------------
# check_same_thread=False allows Gradio worker threads to share this
# connection. NOTE(review): sqlite3 connections are not fully thread-safe;
# confirm concurrent writes are serialized (e.g. by Gradio's queue).
conn = sqlite3.connect("vehicles.db", check_same_thread=False)
cursor = conn.cursor()
# Create the log table on first run; subsequent runs are no-ops.
cursor.execute("""
CREATE TABLE IF NOT EXISTS vehicles (
    plate TEXT,
    type TEXT,
    amount REAL,
    time TEXT
)
""")
conn.commit()


# Lazily-initialised model singletons; populated on first load_model() call.
processor = None
model = None
|
|
|
|
def load_model(model_name=None):
    """Load (once) and return the YOLOS processor/model pair.

    Args:
        model_name: optional Hugging Face hub id. Defaults to the
            license-plate detection checkpoint. Parameter added because the
            UI handler calls ``load_model(model_name)``, which raised
            TypeError against the original zero-argument signature;
            existing no-argument callers are unaffected.

    Returns:
        ``(processor, model)`` tuple; the model is put in eval mode.
    """
    global processor, model
    if model_name is None:
        model_name = "nickmuchi/yolos-small-finetuned-license-plate-detection"
    # Singleton pattern: only hit the hub / disk cache on the first call.
    if processor is None or model is None:
        processor = YolosImageProcessor.from_pretrained(model_name)
        model = YolosForObjectDetection.from_pretrained(
            model_name,
            use_safetensors=True,
            torch_dtype=torch.float32,
        )
        model.eval()
    return processor, model
|
|
|
|
| |
|
|
def classify_plate_color(plate_img):
    """Classify a vehicle from its plate's dominant colour.

    Green plates map to "EV", yellow to "Commercial", anything else
    (including ties) to "Personal".
    """
    rgb = np.array(plate_img)
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)

    # Pixel counts for each candidate colour band (OpenCV HSV, H in 0-180).
    green_px = np.sum(cv2.inRange(hsv, (35, 40, 40), (85, 255, 255)))
    yellow_px = np.sum(cv2.inRange(hsv, (15, 50, 50), (35, 255, 255)))
    white_px = np.sum(cv2.inRange(hsv, (0, 0, 200), (180, 30, 255)))

    if green_px > yellow_px and green_px > white_px:
        return "EV"
    if yellow_px > green_px and yellow_px > white_px:
        return "Commercial"
    return "Personal"
|
|
|
|
| |
|
|
def get_dashboard():
    """Build a matplotlib bar chart of scanned-vehicle counts by type.

    Reads every row from the ``vehicles`` table; when the table is empty a
    placeholder figure is returned instead of an empty chart.
    """
    records = pd.read_sql("SELECT * FROM vehicles", conn)

    fig, axes = plt.subplots(figsize=(7, 5))

    if records.empty:
        axes.text(0.5, 0.5, "No vehicles scanned yet",
                  ha="center", va="center", fontsize=12)
        axes.axis("off")
        return fig

    type_counts = records["type"].value_counts()
    type_counts.plot(kind="bar", ax=axes, color="steelblue")

    axes.set_title("Vehicle Classification Dashboard", fontsize=12)
    axes.set_xlabel("Vehicle Type", fontsize=10)
    axes.set_ylabel("Count", fontsize=10)

    # Keep the category labels horizontal and centred under their bars.
    axes.set_xticks(range(len(type_counts.index)))
    axes.set_xticklabels(type_counts.index, rotation=0, ha="center")

    axes.grid(axis="y", linestyle="--", alpha=0.6)

    # Annotate each bar with its count, nudged just above the bar top.
    for position, count in enumerate(type_counts.values):
        axes.text(position, count + 0.05, str(count),
                  ha="center", va="bottom", fontsize=10)

    plt.tight_layout()
    return fig
|
|
|
|
| |
|
|
def make_prediction(img, processor=None, model=None):
    """Run object detection on a PIL image and return raw detections.

    Args:
        img: PIL image (``img.size`` is read as ``(width, height)``).
        processor: optional pre-loaded image processor; fetched via
            ``load_model()`` when omitted. Parameter added because
            ``detect_objects_image`` passes the pair explicitly — the
            original one-argument signature raised TypeError there; the
            one-argument call form still works unchanged.
        model: optional pre-loaded detection model (see ``processor``).

    Returns:
        The single post-processed detection dict with ``scores``,
        ``labels`` and ``boxes`` mapped back to the original image size.
    """
    if processor is None or model is None:
        processor, model = load_model()
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL gives (width, height); post-processing expects (height, width).
    img_size = torch.tensor([tuple(reversed(img.size))])
    # threshold=0.0 keeps every candidate; visualize_prediction filters later.
    processed_outputs = processor.post_process_object_detection(
        outputs, threshold=0.0, target_sizes=img_size
    )
    return processed_outputs[0]
|
|
|
|
def fig2img(fig):
    """Rasterise *fig* into a 750px-wide PIL image and close the figure."""
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    rendered = Image.open(buffer)

    # Scale to a fixed 750px width while preserving the aspect ratio.
    target_width = 750
    scale = target_width / float(rendered.size[0])
    target_height = int(float(rendered.size[1]) * scale)
    resized = rendered.resize((target_width, target_height),
                              Image.Resampling.LANCZOS)

    # Release the matplotlib figure so repeated calls don't leak memory.
    plt.close(fig)
    return resized
|
|
|
|
| |
|
|
def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
    """Draw plate detections on *img* and summarise each detected plate.

    Args:
        img: PIL image the detections refer to.
        output_dict: detection dict with ``scores``/``boxes``/``labels``
            tensors (output of ``make_prediction``).
        threshold: minimum confidence for a detection to be processed.
        id2label: optional mapping from numeric label id to label string.

    Returns:
        ``(annotated_image, result_text)`` — the rendered figure as a PIL
        image and a newline-joined per-plate summary (or a "no plate" note).
    """
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()

    if id2label is not None:
        labels = [id2label[x] for x in labels]

    plt.figure(figsize=(20, 20))
    plt.imshow(img)
    ax = plt.gca()
    colors = COLORS * 100  # recycle the palette so we never run out

    result_lines = []

    for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
        # Bug fix: plate_text/vehicle_type/toll are only bound inside this
        # branch, so the box/label drawing below must also live inside it —
        # otherwise any non-plate detection raises NameError (or reuses
        # stale values from a previous plate).
        if "plate" in label.lower():
            crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))

            # NOTE(review): read_plate is not defined anywhere in this file —
            # confirm it is provided elsewhere in the project.
            plate_text = read_plate(crop)
            vehicle_type = classify_plate_color(crop)
            toll, discount_msg = compute_discount(vehicle_type)

            result_lines.append(
                f"License: {plate_text} | Type: {vehicle_type} | Toll: ₹{int(toll)} | {discount_msg}"
            )

            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin,
                    fill=False, color=color, linewidth=4
                )
            )

            ax.text(
                xmin, ymin - 10,
                f"{plate_text} | {vehicle_type} | ₹{int(toll)}",
                fontsize=12,
                bbox=dict(facecolor="yellow", alpha=0.8)
            )

    plt.axis("off")
    final_img = fig2img(plt.gcf())

    if result_lines:
        result_text = "\n".join(result_lines)
    else:
        result_text = "No license plate detected."

    return final_img, result_text
|
|
| |
|
|
def detect_objects_image(model_name, url_input, image_input, webcam_input, threshold):
    """Gradio handler: pick an image source, detect plates, return results.

    Args:
        model_name: dropdown value; currently unused because ``load_model``
            always loads the fixed license-plate checkpoint. Kept so the UI
            wiring (which passes the dropdown first) stays compatible.
        url_input: image URL string; used when valid.
        image_input: uploaded PIL image; used when no valid URL is given.
        webcam_input: webcam PIL frame; used as the last fallback.
        threshold: confidence threshold forwarded to visualization.

    Returns:
        ``(annotated_image, result_text)`` or ``(None, message)`` when no
        usable image is available.
    """
    # Bug fix: load_model() takes no positional model argument and
    # make_prediction() takes just the image — the original calls
    # (load_model(model_name), make_prediction(image, processor, model))
    # raised TypeError against the signatures defined in this file.
    _, detection_model = load_model()

    if url_input and is_valid_url(url_input):
        image = get_original_image(url_input)
        if image is None:
            # URL validated but nothing usable came back.
            return None, "Could not load image from URL."
    elif image_input is not None:
        image = image_input
    elif webcam_input is not None:
        image = webcam_input
    else:
        return None, "No image provided."

    processed_outputs = make_prediction(image)

    viz_img, result_text = visualize_prediction(
        image, processed_outputs, threshold, detection_model.config.id2label
    )

    return viz_img, result_text
|
|
|
|
| |
|
|
# HTML heading rendered at the top of the Gradio app.
title = """<h1 id="title">Smart Vehicle classification</h1>"""


# Markdown feature summary shown beneath the title.
description = """
Detect license plates using YOLOS.
Features:
- Image URL, Image Upload, Webcam,Vehicle type classification by plate color
- EV vehicles get 10% discount on Tolls, Tax, parking
"""

# NOTE(review): this textbox is created outside the gr.Blocks() context and
# is never wired to any event handler below — confirm whether it should be
# placed inside the layout (and used as an output) or removed.
result_box = gr.Textbox(
    label="Detection Result",
    lines=5,
    interactive=False
)
# --- Gradio UI -----------------------------------------------------------
# Checkpoints offered in the model dropdown. Bug fix: the original used
# ``choices=model`` / ``value=model[0]`` while the global ``model`` was still
# None, which crashed at startup (``None[0]``).
MODEL_NAMES = ["nickmuchi/yolos-small-finetuned-license-plate-detection"]

demo = gr.Blocks()

with demo:
    # Removed dead statements from the original (`debug=False,` etc. — stray
    # discarded tuples) and the duplicate component set that used the
    # Gradio-3 ``source=`` keyword, which conflicts with the Gradio-4
    # ``sources=`` form used in the WebCam tab below.
    gr.Markdown(title)
    gr.Markdown(description)

    options = gr.Dropdown(
        choices=MODEL_NAMES,
        label="Object Detection Model",
        value=MODEL_NAMES[0]
    )

    slider_input = gr.Slider(minimum=0.2, maximum=1, value=0.5, step=0.1,
                             label='Prediction Threshold')

    with gr.Tabs():
        with gr.TabItem('Image URL'):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                original_image = gr.Image(height=200)
                # Preview the fetched image as soon as the URL changes.
                url_input.change(get_original_image, url_input, original_image)
                img_output_from_url = gr.Image(height=200)
                dashboard_output_url = gr.Plot()
            url_result = gr.Textbox(label="Detection Result", lines=5,
                                    interactive=False)
            url_but = gr.Button('Detect')

        with gr.TabItem('Image Upload'):
            with gr.Row():
                img_input = gr.Image(type='pil', height=200)
                img_output_from_upload = gr.Image(height=200)
                dashboard_output_upload = gr.Plot()
            upload_result = gr.Textbox(label="Detection Result", lines=5,
                                       interactive=False)
            img_but = gr.Button('Detect')

        with gr.TabItem('WebCam'):
            with gr.Row():
                web_input = gr.Image(
                    sources=["webcam"],
                    type="pil",
                    height=200,
                    streaming=True
                )
                img_output_from_webcam = gr.Image(height=200)
                dashboard_output_webcam = gr.Plot()
            webcam_result = gr.Textbox(label="Detection Result", lines=5,
                                       interactive=False)
            cam_but = gr.Button('Detect')

    # Bug fix: detect_objects_image returns (image, text) but the original
    # wired only a single image output, mismatching the return values.
    url_but.click(
        detect_objects_image,
        inputs=[options, url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_url, url_result],
        queue=True
    )

    img_but.click(
        detect_objects_image,
        inputs=[options, url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_upload, upload_result],
        queue=True
    )

    cam_but.click(
        detect_objects_image,
        inputs=[options, url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_webcam, webcam_result],
        queue=True
    )
|
|
| |
| |
| |
| |
| |
| |
|
|
# Enable request queueing so long-running detections do not block the server.
demo.queue()

# NOTE(review): mid-file import — conventionally this belongs at the top of
# the file; kept here so the event-loop patch stays next to launch().
import asyncio


# Some runtimes (bare threads, certain notebook hosts) have no running event
# loop; Gradio's launch expects one, so create it when asyncio raises
# RuntimeError.
try:
    asyncio.get_running_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())
demo.launch(debug=True, ssr_mode=False)
|
|