|
|
|
|
|
|
|
|
import io
import os
import tempfile
from urllib.parse import urlparse

import cv2
import easyocr
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from transformers import YolosImageProcessor, YolosForObjectDetection
|
|
|
|
|
# Prevent OpenMP "duplicate runtime" aborts when several native libraries
# (torch, cv2, easyocr) each ship their own libomp.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


# Lazily-initialized YOLOS processor/model singletons (see load_model()).
processor = None
model = None
# EasyOCR English reader, loaded eagerly at import time (CPU only) —
# NOTE(review): this download/initialization runs on every import.
reader = easyocr.Reader(["en"], gpu=False)

# Matplotlib RGB colors cycled through when drawing detection boxes.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
|
|
|
|
|
|
|
|
|
|
|
def is_valid_url(url):
    """Return True when *url* parses with both a scheme and a network location."""
    try:
        parts = urlparse(url)
    except Exception:
        return False
    # A usable image URL needs both pieces, e.g. "https://" + "host.tld".
    return bool(parts.scheme) and bool(parts.netloc)
|
|
|
|
|
|
|
|
def get_original_image(url_input):
    """Download and return an RGB PIL image from *url_input*, or None.

    Returns None when the URL is empty/invalid, the request fails or times
    out, or the payload is not a decodable image — instead of raising an
    unhandled exception into the Gradio handler.
    """
    if not (url_input and is_valid_url(url_input)):
        return None
    try:
        # Stream and bound the request so a dead host cannot hang the app.
        response = requests.get(url_input, stream=True, timeout=10)
        response.raise_for_status()
        return Image.open(response.raw).convert("RGB")
    except (requests.RequestException, OSError):
        # OSError covers PIL's UnidentifiedImageError / truncated downloads.
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Load (once) and return the cached YOLOS processor/model pair.

    The Hugging Face weights are fetched only on the first call; later
    calls return the module-level singletons unchanged.
    """
    global processor, model
    checkpoint = "nickmuchi/yolos-small-finetuned-license-plate-detection"
    if processor is None or model is None:
        processor = YolosImageProcessor.from_pretrained(checkpoint)
        model = YolosForObjectDetection.from_pretrained(
            checkpoint,
            use_safetensors=True,
            torch_dtype=torch.float32,
        )
        # Inference only — disable dropout/batch-norm training behavior.
        model.eval()
    return processor, model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_plate_color(plate_img):
    """Classify a plate crop as "EV", "Commercial", or "Personal" by color.

    Counts green / yellow / white pixels in HSV space; the strictly dominant
    band decides the category (green -> EV, yellow -> Commercial, otherwise
    Personal). Assumes *plate_img* is an RGB image — TODO confirm for all
    callers.
    """
    hsv = cv2.cvtColor(np.array(plate_img), cv2.COLOR_RGB2HSV)

    # Pixel-mass per color band (OpenCV HSV: hue range is [0, 180]).
    counts = {
        "green": int(np.sum(cv2.inRange(hsv, (35, 40, 40), (85, 255, 255)))),
        "yellow": int(np.sum(cv2.inRange(hsv, (15, 50, 50), (35, 255, 255)))),
        "white": int(np.sum(cv2.inRange(hsv, (0, 0, 200), (180, 30, 255)))),
    }

    if counts["green"] > counts["yellow"] and counts["green"] > counts["white"]:
        return "EV"
    if counts["yellow"] > counts["green"] and counts["yellow"] > counts["white"]:
        return "Commercial"
    return "Personal"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_plate(plate_img):
    """OCR a plate crop with EasyOCR; return the first text hit or "UNKNOWN".

    NOTE(review): this helper is not called anywhere in the drawing code, so
    the recognized plate text never reaches the UI — confirm intent.
    """
    detections = reader.readtext(np.array(plate_img))
    if not detections:
        return "UNKNOWN"
    # Each detection is (bbox, text, confidence); keep the first text field.
    return detections[0][1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_prediction(img):
    """Run YOLOS on a PIL image and return post-processed detections.

    Returns a dict with "scores", "labels" and "boxes" (pixel coordinates).
    threshold=0.0 keeps every candidate so callers can filter by their own
    confidence threshold later.
    """
    proc, net = load_model()
    inputs = proc(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = net(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    target_size = torch.tensor([img.size[::-1]])
    detections = proc.post_process_object_detection(
        outputs, threshold=0.0, target_sizes=target_size
    )
    return detections[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fig2img(fig):
    """Render a Matplotlib figure to a 750px-wide PIL image, then close it."""
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    rendered = Image.open(buffer)

    # Scale to a fixed width while preserving the aspect ratio.
    target_w = 750
    scale = target_w / float(rendered.size[0])
    target_h = int(float(rendered.size[1]) * scale)
    resized = rendered.resize((target_w, target_h), Image.Resampling.LANCZOS)

    # Close the figure so repeated calls do not leak pyplot memory.
    plt.close(fig)
    return resized
|
|
|
|
|
|
|
|
def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
    """Draw license-plate detections (with toll pricing) onto *img*.

    Args:
        img: PIL image the detections refer to.
        output_dict: make_prediction() output ("scores"/"boxes"/"labels").
        threshold: minimum confidence for a detection to be drawn.
        id2label: optional mapping from numeric label id to label name.

    Returns:
        A PIL image of the annotated figure (via fig2img).
    """
    BASE_TOLL = 100
    EV_DISCOUNT = 0.9  # EVs pay 90% of the base toll (10% off)

    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()

    if id2label is not None:
        labels = [id2label[x] for x in labels]

    # Use an explicit figure/axes pair instead of pyplot's implicit current
    # figure (plt.gca()/plt.gcf()) — concurrent Gradio requests would
    # otherwise risk drawing into each other's figures.
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(img)
    colors = COLORS * 100

    for score, (xmin, ymin, xmax, ymax), label, color in zip(
        scores, boxes, labels, colors
    ):
        # Only license-plate detections are priced and drawn.
        if "plate" not in label.lower():
            continue

        crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
        plate_type = classify_plate_color(crop)

        if plate_type == "EV":
            discounted_amount = BASE_TOLL * EV_DISCOUNT
            price_text = f"EV | ₹{discounted_amount:.0f} (10% off)"
        else:
            price_text = f"{plate_type} | ₹{BASE_TOLL}"

        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin,
                fill=False, color=color, linewidth=4,
            )
        )
        ax.text(
            xmin, ymin - 10,
            f"{price_text} | {score:0.2f}",
            fontsize=12,
            bbox=dict(facecolor="yellow", alpha=0.8),
        )

    ax.axis("off")
    return fig2img(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_objects_image(url_input, image_input, webcam_input, threshold):
    """Gradio handler: detect plates on an image from URL, upload, or webcam.

    Input priority: URL first, then uploaded image, then webcam frame.
    Returns the annotated PIL image, or None when no usable image is given.
    """
    if url_input and is_valid_url(url_input):
        image = get_original_image(url_input)
    elif image_input is not None:
        image = image_input
    elif webcam_input is not None:
        image = webcam_input
    else:
        return None

    # The URL fetch can fail and yield None — bail out instead of crashing
    # inside the model pipeline.
    if image is None:
        return None

    # load_model() caches, so one call suffices for both prediction and
    # the id2label mapping (the original called it twice).
    _, model = load_model()
    processed_outputs = make_prediction(image)
    return visualize_prediction(
        image, processed_outputs, threshold, model.config.id2label
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_objects_video(video_input, threshold):
    """Gradio handler: run plate detection on every frame of a video.

    Writes an annotated mp4 to the system temp directory and returns its
    path, or None when no video was supplied.
    """
    if video_input is None:
        return None

    processor, model = load_model()

    cap = cv2.VideoCapture(video_input)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    # Portable temp path (the previous hard-coded "/tmp/..." breaks on
    # Windows).
    output_path = os.path.join(tempfile.gettempdir(), "output_detected.mp4")
    # Some containers report fps=0; fall back so the writer stays valid.
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV frames are BGR; the model/PIL pipeline expects RGB.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(rgb_frame)

            processed_outputs = make_prediction(pil_img)

            keep = processed_outputs["scores"] > threshold
            boxes = processed_outputs["boxes"][keep].tolist()
            scores = processed_outputs["scores"][keep].tolist()
            labels = processed_outputs["labels"][keep].tolist()
            labels = [model.config.id2label[x] for x in labels]

            for score, (xmin, ymin, xmax, ymax), label in zip(
                scores, boxes, labels
            ):
                if "plate" not in label.lower():
                    continue

                crop = pil_img.crop(
                    (int(xmin), int(ymin), int(xmax), int(ymax))
                )
                plate_type = classify_plate_color(crop)

                cv2.rectangle(
                    frame,
                    (int(xmin), int(ymin)),
                    (int(xmax), int(ymax)),
                    (0, 255, 0),
                    2,
                )
                cv2.putText(
                    frame,
                    f"{plate_type} | {score:.2f}",
                    (int(xmin), int(ymin) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6,
                    (0, 255, 0),
                    2,
                )

            out.write(frame)
    finally:
        # Always release capture and writer, even if a frame raises —
        # otherwise the partial mp4 stays locked/corrupt.
        cap.release()
        out.release()

    return output_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: one threshold slider shared by four tabs (URL / upload / webcam
# images and video upload), each wired to the detection handlers above.
# ---------------------------------------------------------------------------

title = """<h1 id="title">Smart Vehicle Classification (Image + Video)</h1>"""

description = """
Smart Vehicle Classification system to Promote EV by applying discount on Toll,
Tax, parking.
Supports:Image URL, Image Upload, Webcam, Video Upload,Vehicle type classification by plate color
"""

css = """
h1#title { text-align: center; }
"""

demo = gr.Blocks()

with demo:
    gr.Markdown(title)
    gr.Markdown(description)

    # Confidence threshold shared by every detection tab.
    slider_input = gr.Slider(
        minimum=0.2, maximum=1, value=0.5, step=0.1, label="Prediction Threshold"
    )

    with gr.Tabs():
        with gr.TabItem("Image URL"):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label="Enter valid image URL here..")
                original_image = gr.Image(height=750, width=750)
                # Preview the fetched image as soon as the URL changes.
                url_input.change(get_original_image, url_input, original_image)
            img_output_from_url = gr.Image(height=750, width=750)
            url_but = gr.Button("Detect")

        with gr.TabItem("Image Upload"):
            with gr.Row():
                img_input = gr.Image(type="pil", height=750, width=750)
                img_output_from_upload = gr.Image(height=750, width=750)
            img_but = gr.Button("Detect")

        with gr.TabItem("WebCam"):
            with gr.Row():
                web_input = gr.Image(
                    sources=["webcam"], type="pil", height=750, width=750, streaming=True
                )
                img_output_from_webcam = gr.Image(height=750, width=750)
            cam_but = gr.Button("Detect")

        with gr.TabItem("Video Upload"):
            with gr.Row():
                video_input = gr.Video(label="Upload Video")
                video_output = gr.Video(label="Detected Video")
            vid_but = gr.Button("Detect Video")

    # All three image tabs share one handler; it picks the first non-empty
    # input in priority order (URL, then upload, then webcam).
    url_but.click(
        detect_objects_image,
        inputs=[url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_url],
        queue=True,
    )

    img_but.click(
        detect_objects_image,
        inputs=[url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_upload],
        queue=True,
    )

    cam_but.click(
        detect_objects_image,
        inputs=[url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_webcam],
        queue=True,
    )

    vid_but.click(
        detect_objects_video,
        inputs=[video_input, slider_input],
        outputs=[video_output],
        queue=True,
    )


# Enable request queueing, then start the server (blocking in debug mode).
demo.queue()
demo.launch(debug=True, ssr_mode=False)
|
|
|