# Provenance (non-code residue from the Hugging Face Space page, commented out
# so the file remains valid Python):
#   Sarvamangalak — "Rename app.py to app_video.py" — commit fa051fe (verified)
# app.py (Clean Final Version for HF Spaces)
import io
import os
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
import numpy as np
from urllib.parse import urlparse
from PIL import Image
from transformers import YolosImageProcessor, YolosForObjectDetection
import easyocr
# Allow duplicate OpenMP runtimes to coexist; torch and cv2 can each bundle
# their own libomp, which otherwise aborts the process at import time.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# ---------------- Globals (lazy loaded) ----------------
processor = None  # YolosImageProcessor; populated on first call to load_model()
model = None  # YolosForObjectDetection; populated on first call to load_model()
# OCR engine is loaded eagerly at import time (English only, CPU inference).
reader = easyocr.Reader(["en"], gpu=False)
# RGB palette (0-1 floats) cycled through when drawing detection boxes.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
# ---------------- Utilities ----------------
def is_valid_url(url):
    """Return True when *url* parses with both a scheme and a network location."""
    try:
        parts = urlparse(url)
    except Exception:
        # urlparse can raise on pathological inputs; treat those as invalid.
        return False
    return bool(parts.scheme and parts.netloc)
def get_original_image(url_input):
    """Fetch an image from *url_input* and return it as an RGB PIL image.

    Returns None when the URL is empty/invalid or the download fails, so the
    Gradio callbacks simply display nothing instead of crashing.
    """
    if not (url_input and is_valid_url(url_input)):
        return None
    try:
        # A timeout keeps a dead or slow server from hanging the UI callback
        # forever (the original call had no timeout and could block indefinitely).
        resp = requests.get(url_input, stream=True, timeout=10)
        resp.raise_for_status()
        return Image.open(resp.raw).convert("RGB")
    except Exception:
        # Best-effort fetch: any network/HTTP/decode failure degrades to "no image".
        return None
# ---------------- Model Loader ----------------
def load_model():
    """Lazily instantiate and cache the YOLOS plate-detection model.

    Populates the module-level ``processor`` and ``model`` globals on first
    use; later calls return the cached pair without re-downloading weights.
    """
    global processor, model
    checkpoint = "nickmuchi/yolos-small-finetuned-license-plate-detection"
    if processor is None or model is None:
        processor = YolosImageProcessor.from_pretrained(checkpoint)
        model = YolosForObjectDetection.from_pretrained(
            checkpoint,
            use_safetensors=True,
            torch_dtype=torch.float32,
        )
        model.eval()  # inference only — disable dropout/batch-norm updates
    return processor, model
# ---------------- Plate Color Classifier ----------------
def classify_plate_color(plate_img):
    """Heuristically bucket a plate crop into "EV", "Commercial" or "Personal".

    Counts pixels inside green, yellow and white HSV ranges; the strictly
    dominant colour decides the class (presumably following the Indian plate
    colour convention: green=EV, yellow=commercial — TODO confirm). Ties and
    everything else fall through to "Personal".
    """
    hsv = cv2.cvtColor(np.array(plate_img), cv2.COLOR_RGB2HSV)
    hsv_ranges = {
        "green": ((35, 40, 40), (85, 255, 255)),
        "yellow": ((15, 50, 50), (35, 255, 255)),
        "white": ((0, 0, 200), (180, 30, 255)),
    }
    counts = {
        name: int(np.sum(cv2.inRange(hsv, lo, hi)))
        for name, (lo, hi) in hsv_ranges.items()
    }
    if counts["green"] > counts["yellow"] and counts["green"] > counts["white"]:
        return "EV"
    if counts["yellow"] > counts["green"] and counts["yellow"] > counts["white"]:
        return "Commercial"
    return "Personal"
# ---------------- OCR ----------------
def read_plate(plate_img):
    """OCR a plate crop; return the first detected text, or "UNKNOWN"."""
    detections = reader.readtext(np.array(plate_img))
    # easyocr returns (bbox, text, confidence) tuples; keep the first text hit.
    return detections[0][1] if detections else "UNKNOWN"
# ---------------- Core Inference ----------------
def make_prediction(img):
    """Run plate detection on a PIL image and return the raw detection dict.

    Post-processing uses threshold=0.0 so every candidate box is kept;
    callers apply their own confidence cut-off.
    """
    proc, mdl = load_model()
    inputs = proc(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = mdl(**inputs)
    # PIL's .size is (width, height); the post-processor expects (height, width).
    target_size = torch.tensor([img.size[::-1]])
    detections = proc.post_process_object_detection(
        outputs, threshold=0.0, target_sizes=target_size
    )
    return detections[0]
# ---------------- Visualization ----------------
def fig2img(fig):
    """Render a Matplotlib figure into a PIL image resized to 750 px wide."""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    rendered = Image.open(buf)
    target_w = 750
    # Preserve aspect ratio while normalizing the width for the Gradio widget.
    scale = target_w / float(rendered.size[0])
    target_h = int(float(rendered.size[1]) * scale)
    resized = rendered.resize((target_w, target_h), Image.Resampling.LANCZOS)
    plt.close(fig)  # release the figure so repeated calls don't leak memory
    return resized
def visualize_prediction(img, output_dict, threshold=0.5, id2label=None):
    """Draw detected license-plate boxes, with toll pricing, onto *img*.

    Detections below *threshold* confidence are dropped and only labels
    containing "plate" are annotated. EV plates show a 10% toll discount.
    Returns the rendered figure as a PIL image (via fig2img).

    Uses the object-oriented Matplotlib API (explicit Figure/Axes) instead of
    the pyplot global "current figure" — pyplot state is process-global and
    unsafe when Gradio serves concurrent requests.
    """
    BASE_TOLL = 100  # flat toll amount (₹) charged before any discount
    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()
    labels = output_dict["labels"][keep].tolist()
    if id2label is not None:
        labels = [id2label[x] for x in labels]
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(img)
    palette = COLORS * 100  # recycle the small palette for many detections
    for score, (xmin, ymin, xmax, ymax), label, color in zip(
        scores, boxes, labels, palette
    ):
        if "plate" not in label.lower():
            continue
        crop = img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
        plate_type = classify_plate_color(crop)
        # Apply a 10% discount for EV vehicles to promote adoption.
        if plate_type == "EV":
            price_text = f"EV | ₹{BASE_TOLL * 0.9:.0f} (10% off)"
        else:
            price_text = f"{plate_type} | ₹{BASE_TOLL}"
        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin,
                fill=False, color=color, linewidth=4
            )
        )
        ax.text(
            xmin, ymin - 10,
            f"{price_text} | {score:0.2f}",
            fontsize=12,
            bbox=dict(facecolor="yellow", alpha=0.8),
        )
    ax.axis("off")
    return fig2img(fig)
# ---------------- Image Detection ----------------
def detect_objects_image(url_input, image_input, webcam_input, threshold):
    """Run plate detection on the first available image source.

    Source priority: URL, then uploaded image, then webcam frame. Returns the
    annotated PIL image, or None when no source is provided.
    """
    if url_input and is_valid_url(url_input):
        source = get_original_image(url_input)
    elif image_input is not None:
        source = image_input
    elif webcam_input is not None:
        source = webcam_input
    else:
        return None
    detections = make_prediction(source)
    _, mdl = load_model()  # cached after first use; only the label map is needed
    return visualize_prediction(source, detections, threshold, mdl.config.id2label)
# ---------------- Video Detection ----------------
def detect_objects_video(video_input, threshold):
    """Annotate every frame of *video_input* with detected plates and types.

    Writes an mp4 to /tmp and returns its path. Returns None when no video is
    supplied or the file cannot be opened.
    """
    if video_input is None:
        return None
    _, mdl = load_model()
    cap = cv2.VideoCapture(video_input)
    if not cap.isOpened():
        # Unreadable/corrupt upload — bail out instead of writing an empty mp4.
        return None
    output_path = "/tmp/output_detected.mp4"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # Some containers report fps=0; fall back to a sane default so the
    # writer still produces a playable file.
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV frames are BGR; the detector expects an RGB PIL image.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(rgb_frame)
            detections = make_prediction(pil_img)
            keep = detections["scores"] > threshold
            boxes = detections["boxes"][keep].tolist()
            scores = detections["scores"][keep].tolist()
            labels = [
                mdl.config.id2label[x] for x in detections["labels"][keep].tolist()
            ]
            for score, (xmin, ymin, xmax, ymax), label in zip(scores, boxes, labels):
                if "plate" not in label.lower():
                    continue
                crop = pil_img.crop((int(xmin), int(ymin), int(xmax), int(ymax)))
                plate_type = classify_plate_color(crop)
                cv2.rectangle(
                    frame,
                    (int(xmin), int(ymin)),
                    (int(xmax), int(ymax)),
                    (0, 255, 0),
                    2,
                )
                cv2.putText(
                    frame,
                    f"{plate_type} | {score:.2f}",
                    (int(xmin), int(ymin) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6,
                    (0, 255, 0),
                    2,
                )
            out.write(frame)
    finally:
        # Always release handles so a mid-stream failure doesn't leak the
        # capture or leave a truncated, locked output file (the original
        # skipped release on any exception).
        cap.release()
        out.release()
    return output_path
# ---------------- UI ----------------
title = """<h1 id="title">Smart Vehicle Classification (Image + Video)</h1>"""
description = """
Smart Vehicle Classification system to Promote EV by applying discount on Toll,
Tax, parking.
Supports:Image URL, Image Upload, Webcam, Video Upload,Vehicle type classification by plate color
"""
css = """
h1#title { text-align: center; }
"""
# Fix: css was previously defined but never handed to Blocks, so the title
# styling was dead code.
demo = gr.Blocks(css=css)
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    # Shared confidence cut-off used by every tab's Detect button.
    slider_input = gr.Slider(
        minimum=0.2, maximum=1, value=0.5, step=0.1, label="Prediction Threshold"
    )
    with gr.Tabs():
        with gr.TabItem("Image URL"):
            with gr.Row():
                url_input = gr.Textbox(lines=2, label="Enter valid image URL here..")
                original_image = gr.Image(height=750, width=750)
                # Preview the fetched image as soon as the URL changes.
                url_input.change(get_original_image, url_input, original_image)
            img_output_from_url = gr.Image(height=750, width=750)
            url_but = gr.Button("Detect")
        with gr.TabItem("Image Upload"):
            with gr.Row():
                img_input = gr.Image(type="pil", height=750, width=750)
                img_output_from_upload = gr.Image(height=750, width=750)
            img_but = gr.Button("Detect")
        with gr.TabItem("WebCam"):
            with gr.Row():
                web_input = gr.Image(
                    sources=["webcam"], type="pil", height=750, width=750, streaming=True
                )
                img_output_from_webcam = gr.Image(height=750, width=750)
            cam_but = gr.Button("Detect")
        with gr.TabItem("Video Upload"):
            with gr.Row():
                video_input = gr.Video(label="Upload Video")
                video_output = gr.Video(label="Detected Video")
            vid_but = gr.Button("Detect Video")
    # NOTE(review): all three image buttons share the same handler and the
    # same input set, so a non-empty URL box takes priority over an uploaded
    # or webcam image regardless of which tab's button is clicked (see
    # detect_objects_image's source-selection order).
    url_but.click(
        detect_objects_image,
        inputs=[url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_url],
        queue=True,
    )
    img_but.click(
        detect_objects_image,
        inputs=[url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_upload],
        queue=True,
    )
    cam_but.click(
        detect_objects_image,
        inputs=[url_input, img_input, web_input, slider_input],
        outputs=[img_output_from_webcam],
        queue=True,
    )
    vid_but.click(
        detect_objects_video,
        inputs=[video_input, slider_input],
        outputs=[video_output],
        queue=True,
    )
demo.queue()
demo.launch(debug=True, ssr_mode=False)