| |
|
|
| import io |
| import os |
| import cv2 |
| import gradio as gr |
| import matplotlib.pyplot as plt |
| import requests |
| import torch |
| import numpy as np |
| from urllib.parse import urlparse |
| from PIL import Image |
| from transformers import YolosImageProcessor, YolosForObjectDetection |
| import easyocr |
|
|
# NOTE(review): commonly needed because PyTorch and OpenCV can each bundle an
# OpenMP runtime, which otherwise aborts the process on some platforms — confirm
# this is still required in the deployment environment.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


# YOLOS processor/model are loaded lazily by load_model() on first use.
processor = None
model = None
# EasyOCR English reader, CPU-only; created eagerly at import time (downloads
# OCR weights on the first run).
reader = easyocr.Reader(["en"], gpu=False)


# Bounding-box outline palette (RGB floats in 0-1), cycled across detections.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
|
|
| |
|
|
def is_valid_url(url):
    """Return True when *url* parses with both a scheme and a network location."""
    try:
        parts = urlparse(url)
    except Exception:
        # Non-string / unparsable input counts as invalid.
        return False
    return bool(parts.scheme) and bool(parts.netloc)
|
|
|
|
def get_original_image(url_input):
    """Fetch the image at *url_input* and return it as an RGB PIL image.

    Returns None when the URL is empty/invalid or when the download or
    decode fails (the Gradio preview simply stays empty).
    """
    if not (url_input and is_valid_url(url_input)):
        return None
    try:
        # Timeout prevents the UI from hanging forever on unresponsive hosts.
        response = requests.get(url_input, stream=True, timeout=10)
        response.raise_for_status()
        # response.raw bypasses transparent decompression; enabling
        # decode_content makes gzip/deflate-encoded responses readable by PIL.
        response.raw.decode_content = True
        return Image.open(response.raw).convert("RGB")
    except Exception:
        # Best-effort fetch: any network/decode error yields no preview
        # instead of crashing the event handler.
        return None
|
|
|
|
| |
|
|
def load_model():
    """Lazily initialise the YOLOS processor/model pair and cache them.

    The first call downloads and builds both objects; every later call
    returns the cached module-level pair, so weights load only once.
    """
    global processor, model
    if processor is None or model is None:
        checkpoint = "nickmuchi/yolos-small-finetuned-license-plate-detection"
        processor = YolosImageProcessor.from_pretrained(checkpoint)
        model = YolosForObjectDetection.from_pretrained(
            checkpoint,
            use_safetensors=True,
            torch_dtype=torch.float32,
        )
        # Inference-only: disable dropout/batch-norm training behaviour.
        model.eval()
    return processor, model
|
|
|
|
| |
|
|
def classify_plate_color(plate_img):
    """Classify a cropped plate image by its dominant colour band.

    Returns "EV" for predominantly green plates, "Commercial" for yellow,
    and "Personal" otherwise (white or no clear winner).
    """
    hsv = cv2.cvtColor(np.array(plate_img), cv2.COLOR_RGB2HSV)

    # Count pixels falling inside each HSV colour range.
    green_px = int(np.sum(cv2.inRange(hsv, (35, 40, 40), (85, 255, 255))))
    yellow_px = int(np.sum(cv2.inRange(hsv, (15, 50, 50), (35, 255, 255))))
    white_px = int(np.sum(cv2.inRange(hsv, (0, 0, 200), (180, 30, 255))))

    if green_px > yellow_px and green_px > white_px:
        return "EV"
    if yellow_px > green_px and yellow_px > white_px:
        return "Commercial"
    return "Personal"
|
|
|
|
| |
|
|
def read_plate(plate_img):
    """OCR a plate crop with EasyOCR; return the first detected text or "UNKNOWN"."""
    detections = reader.readtext(np.array(plate_img))
    # Each detection is (bbox, text, confidence); keep the first hit's text.
    return detections[0][1] if detections else "UNKNOWN"
|
|
|
|
| |
|
|
def make_prediction(img):
    """Run license-plate detection on a PIL image.

    Returns the post-processed detection dict ("scores", "labels", "boxes")
    at threshold 0 — callers apply their own confidence cut-off.
    """
    proc, detector = load_model()
    inputs = proc(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = detector(**inputs)

    # PIL reports (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([img.size[::-1]])
    results = proc.post_process_object_detection(
        outputs, threshold=0.0, target_sizes=target_sizes
    )
    return results[0]
|
|
|
|
| |
|
|
def fig2img(fig):
    """Render a Matplotlib figure into a PIL image resized to 750 px wide."""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    rendered = Image.open(buf)

    # Scale to a fixed display width while preserving the aspect ratio.
    target_w = 750
    scale = target_w / float(rendered.size[0])
    target_h = int(float(rendered.size[1]) * float(scale))
    resized = rendered.resize((target_w, target_h), Image.Resampling.LANCZOS)

    # Free the pyplot-managed figure so repeated calls don't accumulate state.
    plt.close(fig)
    return resized
|
|
|
|
def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
    """Draw detected license plates on *img* annotated with a toll price.

    Parameters
    ----------
    img : PIL.Image
        Source image; plate regions are cropped from it for colour analysis.
    output_dict : dict
        Post-processed detections with "scores", "boxes" and "labels" tensors.
    threshold : float
        Minimum confidence score for a detection to be drawn.
    id2label : dict | None
        Optional mapping from integer label id to label name.

    Returns the annotated figure as a PIL image (via fig2img).
    """
    BASE_TOLL = 100  # flat toll in rupees; EVs get a 10% discount below

    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()

    if id2label is not None:
        labels = [id2label[x] for x in labels]

    plt.figure(figsize=(20, 20))
    plt.imshow(img)
    ax = plt.gca()
    colors = COLORS * 100  # recycle the palette for any number of boxes

    for score, (xmin, ymin, xmax, ymax), label, color in zip(
        scores, boxes, labels, colors
    ):
        # str() guards against raw integer labels when id2label is None —
        # ints have no .lower(), which previously raised AttributeError.
        if "plate" in str(label).lower():
            crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
            plate_type = classify_plate_color(crop)

            if plate_type == "EV":
                discounted_amount = BASE_TOLL * 0.9
                price_text = f"EV | ₹{discounted_amount:.0f} (10% off)"
            else:
                price_text = f"{plate_type} | ₹{BASE_TOLL}"

            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin,
                    fill=False, color=color, linewidth=4
                )
            )
            ax.text(
                xmin, ymin - 10,
                f"{price_text} | {score:0.2f}",
                fontsize=12,
                bbox=dict(facecolor="yellow", alpha=0.8),
            )

    plt.axis("off")
    return fig2img(plt.gcf())
|
|
|
|
| |
|
|
def detect_objects_image(url_input, image_input, webcam_input, threshold):
    """Pick the active image source, run detection and return the annotated image.

    Source priority is URL > uploaded image > webcam frame; returns None
    when no source is provided.
    """
    if url_input and is_valid_url(url_input):
        image = get_original_image(url_input)
    elif image_input is not None:
        image = image_input
    elif webcam_input is not None:
        image = webcam_input
    else:
        return None

    detections = make_prediction(image)
    # load_model() is cached, so this only fetches the already-built model.
    _, detector = load_model()
    return visualize_prediction(image, detections, threshold, detector.config.id2label)
|
|
|
|
| |
|
|
def detect_objects_video(video_input, threshold):
    """Run plate detection frame-by-frame over a video and write an annotated copy.

    Parameters
    ----------
    video_input : str | None
        Path of the uploaded video file.
    threshold : float
        Minimum confidence score for a detection to be drawn.

    Returns the path of the annotated MP4, or None when no video is given
    or the file cannot be opened.
    """
    if video_input is None:
        return None

    _, detector = load_model()

    cap = cv2.VideoCapture(video_input)
    if not cap.isOpened():
        # Corrupt/unreadable upload: bail out instead of emitting an empty file.
        cap.release()
        return None

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    output_path = "/tmp/output_detected.mp4"
    # Some containers report 0 FPS; fall back so VideoWriter stays valid.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV frames are BGR; the model/PIL pipeline expects RGB.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(rgb_frame)

            processed_outputs = make_prediction(pil_img)

            keep = processed_outputs["scores"] > threshold
            boxes = processed_outputs["boxes"][keep].tolist()
            scores = processed_outputs["scores"][keep].tolist()
            labels = processed_outputs["labels"][keep].tolist()
            labels = [detector.config.id2label[x] for x in labels]

            for score, (xmin, ymin, xmax, ymax), label in zip(scores, boxes, labels):
                if "plate" in label.lower():
                    crop = pil_img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
                    plate_type = classify_plate_color(crop)

                    cv2.rectangle(
                        frame,
                        (int(xmin), int(ymin)),
                        (int(xmax), int(ymax)),
                        (0, 255, 0),
                        2,
                    )
                    cv2.putText(
                        frame,
                        f"{plate_type} | {score:.2f}",
                        (int(xmin), int(ymin) - 10),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.6,
                        (0, 255, 0),
                        2,
                    )

            out.write(frame)
    finally:
        # Always release handles so a mid-stream failure doesn't leak them
        # or leave the output file locked/half-written.
        cap.release()
        out.release()

    return output_path
|
|
|
|
| |
|
|
# UI copy and app shell -------------------------------------------------------

title = """<h1 id="title">Smart Vehicle Classification (Image + Video)</h1>"""

# Fixed user-facing spacing/typos ("Supports:Image", "Upload,Vehicle").
description = """
Smart Vehicle Classification system to promote EVs by applying a discount on
toll, tax and parking.
Supports: Image URL, Image Upload, Webcam, Video Upload, and vehicle type
classification by plate color.
"""

css = """
h1#title { text-align: center; }
"""

# Pass css to Blocks — previously it was defined but never used, so the
# title centering was never applied.
demo = gr.Blocks(css=css)
|
|
# ---------------------------------------------------------------------------
# Gradio UI: four input tabs (URL / upload / webcam / video) sharing one
# confidence-threshold slider. All three image buttons route through the same
# handler, which picks the first non-empty source (URL > upload > webcam).
# ---------------------------------------------------------------------------
with demo:
    gr.Markdown(title)
    gr.Markdown(description)

    # Shared confidence threshold read by every detect handler.
    slider_input = gr.Slider(
        minimum=0.2, maximum=1, value=0.5, step=0.1, label="Prediction Threshold"
    )

    with gr.Tabs():
        with gr.TabItem("Image URL"):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label="Enter valid image URL here..")
                original_image = gr.Image(height=750, width=750)
            # Live preview: fetch and show the image as soon as the URL changes.
            url_input.change(get_original_image, url_input, original_image)
            img_output_from_url = gr.Image(height=750, width=750)
            url_but = gr.Button("Detect")

        with gr.TabItem("Image Upload"):
            with gr.Row():
                img_input = gr.Image(type="pil", height=750, width=750)
                img_output_from_upload = gr.Image(height=750, width=750)
            img_but = gr.Button("Detect")

        with gr.TabItem("WebCam"):
            with gr.Row():
                web_input = gr.Image(
                    sources=["webcam"], type="pil", height=750, width=750, streaming=True
                )
                img_output_from_webcam = gr.Image(height=750, width=750)
            cam_but = gr.Button("Detect")

        with gr.TabItem("Video Upload"):
            with gr.Row():
                video_input = gr.Video(label="Upload Video")
                video_output = gr.Video(label="Detected Video")
            vid_but = gr.Button("Detect Video")

    # Each image-tab button shares the same handler and inputs; only the
    # output component differs per tab.
    url_but.click(
        detect_objects_image,
        inputs=[url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_url],
        queue=True,
    )

    img_but.click(
        detect_objects_image,
        inputs=[url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_upload],
        queue=True,
    )

    cam_but.click(
        detect_objects_image,
        inputs=[url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_webcam],
        queue=True,
    )

    vid_but.click(
        detect_objects_video,
        inputs=[video_input, slider_input],
        outputs=[video_output],
        queue=True,
    )


# Serialise event handling through a queue, then start the server.
demo.queue()
# NOTE(review): debug=True streams server logs; ssr_mode=False presumably
# disables the server-side-rendering preview path — confirm against the
# installed Gradio version's launch() docs.
demo.launch(debug=True, ssr_mode=False)
|
|