Sarvamangalak's picture
Update app.py
d9bea70 verified
raw
history blame
7.44 kB
import io
import cv2
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
import numpy as np
import sqlite3
import pandas as pd
import pytesseract
from PIL import Image
from transformers import YolosImageProcessor, YolosForObjectDetection
# ---------------- CONFIG ----------------
# Hugging Face checkpoint id loaded by load_model(); per its name, a
# YOLOS-small model fine-tuned for license-plate detection.
MODEL_NAME: str = "nickmuchi/yolos-small-finetuned-license-plate-detection"
# Undiscounted toll per detected vehicle (shown in ₹ in the result text).
BASE_AMT: int = 100
# ---------------- DATABASE ----------------
# A single SQLite connection/cursor shared by every callback in the app.
# check_same_thread=False allows use from threads other than the creator.
# NOTE(review): there is no lock around this shared cursor; confirm the
# Gradio queue never runs two DB-writing callbacks concurrently.
conn = sqlite3.connect("vehicles.db", check_same_thread=False)
cursor = conn.cursor()

# vehicles: one row per detected plate with its type and toll; visualize()
# fills `time` with SQLite's datetime('now').
# feedback: the user's verdict ("Correct"/"Incorrect") on a result text.
_SCHEMA_DDL = (
    """
CREATE TABLE IF NOT EXISTS vehicles (
plate TEXT,
type TEXT,
amount REAL,
time TEXT
)
""",
    """
CREATE TABLE IF NOT EXISTS feedback (
result TEXT,
feedback TEXT
)
""",
)
for _ddl in _SCHEMA_DDL:
    cursor.execute(_ddl)
conn.commit()
# ---------------- MODEL (Lazy Load) ----------------
# Detector objects are created on first use by load_model(); until then both
# module-level cache slots stay empty.
processor = model = None
def load_model():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    The image processor and detection model for ``MODEL_NAME`` are stored in
    the module-level ``processor`` / ``model`` globals. Both are published
    only after *both* have loaded successfully. A failed model load therefore
    leaves the cache empty, and the next call retries. It never hands back a
    processor paired with ``model is None``.

    Returns:
        Tuple ``(processor, model)``, with the model switched to eval mode.
    """
    global processor, model
    if processor is None or model is None:
        new_processor = YolosImageProcessor.from_pretrained(MODEL_NAME)
        new_model = YolosForObjectDetection.from_pretrained(MODEL_NAME)
        new_model.eval()  # inference only
        processor, model = new_processor, new_model
    return processor, model
# ---------------- LOGIC ----------------
def compute_discount(vehicle_type, base_amount=None):
    """Return the toll payable for a vehicle of the given type.

    Electric vehicles pay 90% of the base toll; every other type pays it in
    full.

    Args:
        vehicle_type: Vehicle type label. Both the short form "EV" and the
            label produced by classify_plate_color(),
            "Electric Vehicle (EV)", are treated as electric.
        base_amount: Undiscounted toll. Defaults to the module-level
            ``BASE_AMT`` when omitted.

    Returns:
        ``base_amount * 0.9`` for electric vehicles, else ``base_amount``.
    """
    if base_amount is None:
        base_amount = BASE_AMT
    # classify_plate_color() emits the long label, so matching only "EV"
    # meant the discount was never applied to detected plates.
    if vehicle_type in ("EV", "Electric Vehicle (EV)"):
        return base_amount * 0.9
    return base_amount
def classify_plate_color(plate_img):
    """Infer the vehicle category from the dominant colour of a plate crop.

    Every pixel is tested against fixed HSV ranges for white, yellow, green,
    red and blue. The colour with the most matching pixels decides the
    category.

    Args:
        plate_img: The plate crop. A NumPy array is treated as BGR (OpenCV's
            channel order). Anything else is assumed to be a PIL image and is
            converted to RGB first, so RGBA or greyscale crops also work.

    Returns:
        "Private Vehicle" (white), "Commercial Vehicle" (yellow),
        "Electric Vehicle (EV)" (green), "Temporary Registration Vehicle"
        (red), "Diplomatic Vehicle" (blue), or "Unknown Vehicle Type" when
        the crop is empty or no pixel falls inside any colour range.
    """
    if isinstance(plate_img, np.ndarray):
        pixels, from_bgr = plate_img, True
    else:
        # PIL crops cannot be handed to cv2 directly, and their channel
        # order is RGB; treating them as BGR would swap red and blue.
        pixels, from_bgr = np.asarray(plate_img.convert("RGB")), False
    if pixels.size == 0:
        # Degenerate (zero-width/height) box: nothing to classify.
        return "Unknown Vehicle Type"
    hsv = cv2.cvtColor(pixels, cv2.COLOR_BGR2HSV if from_bgr else cv2.COLOR_RGB2HSV)

    def in_range(lower, upper):
        """Binary mask of pixels whose HSV lies within [lower, upper]."""
        return cv2.inRange(hsv, np.array(lower), np.array(upper))

    masks = {
        "white": in_range([0, 0, 180], [180, 60, 255]),
        "yellow": in_range([15, 80, 80], [40, 255, 255]),
        "green": in_range([35, 50, 50], [85, 255, 255]),
        # OpenCV's 8-bit hue spans 0-180, so red wraps around and needs two
        # ranges. They are merged before counting so that neither partial
        # mask can win the vote on its own and fall through to "Unknown".
        "red": cv2.bitwise_or(
            in_range([0, 70, 50], [10, 255, 255]),
            in_range([170, 70, 50], [180, 255, 255]),
        ),
        "blue": in_range([90, 50, 50], [130, 255, 255]),
    }

    color_counts = {color: np.count_nonzero(mask) for color, mask in masks.items()}
    dominant_color = max(color_counts, key=color_counts.get)
    if color_counts[dominant_color] == 0:
        # No pixel matched any range; don't default to the first colour.
        return "Unknown Vehicle Type"

    vehicle_types = {
        "white": "Private Vehicle",
        "yellow": "Commercial Vehicle",
        "green": "Electric Vehicle (EV)",
        "red": "Temporary Registration Vehicle",
        "blue": "Diplomatic Vehicle",
    }
    return vehicle_types.get(dominant_color, "Unknown Vehicle Type")
def read_plate(plate_img):
    """OCR the registration number from a cropped plate image.

    The crop is converted to greyscale and binarised at a fixed threshold of
    120. It is then passed to Tesseract in single-line mode (``--psm 7``),
    restricted to upper-case letters and digits.

    Args:
        plate_img: The plate crop, either a PIL image (any mode) or an RGB
            NumPy array.

    Returns:
        The recognised text with surrounding whitespace stripped, or
        "UNKNOWN" when nothing was recognised or OCR failed.
    """
    try:
        if isinstance(plate_img, np.ndarray):
            rgb = plate_img
        else:
            # Going through RGB keeps RGBA / greyscale crops from breaking
            # the RGB->GRAY conversion below.
            rgb = np.asarray(plate_img.convert("RGB"))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        _, binary = cv2.threshold(gray, 120, 255, cv2.THRESH_BINARY)
        text = pytesseract.image_to_string(
            binary,
            config="--psm 7 -c tessedit_char_whitelist=ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
        ).strip()
    except Exception:
        # Deliberate best effort: a missing Tesseract binary or an unreadable
        # crop must not abort the whole detection. Unlike the bare `except:`
        # used before, this still lets KeyboardInterrupt/SystemExit through.
        return "UNKNOWN"
    return text or "UNKNOWN"
def make_prediction(img):
    """Run the plate detector on one image.

    Args:
        img: Input image; ``img.size`` is read as ``(width, height)``, as for
            a PIL image.

    Returns:
        ``(detections, id2label)``. ``detections`` is the first result of
        the processor's object-detection post-processing. It is computed
        with a 0.3 score threshold and boxes rescaled to the image's
        height/width. ``id2label`` is the model's label-id -> name mapping.
    """
    proc, net = load_model()
    encoded = proc(images=img, return_tensors="pt")
    with torch.no_grad():
        raw = net(**encoded)

    width, height = img.size
    detections = proc.post_process_object_detection(
        raw, threshold=0.3, target_sizes=torch.tensor([(height, width)])
    )
    return detections[0], net.config.id2label
# ---------------- VISUALIZATION ----------------
def visualize(img, output, id2label, threshold):
    """Draw detected plates on an image, read and price them, and log them.

    For every detection scoring above ``threshold`` whose label name contains
    "plate", this function:

    - crops the plate region;
    - reads it with read_plate();
    - classifies it by colour with classify_plate_color();
    - prices it with compute_discount();
    - inserts it into the ``vehicles`` table;
    - draws it as a red rectangle with a yellow caption.

    Args:
        img: The input PIL image.
        output: One post-processed detection dict with "scores", "labels"
            and "boxes" tensors; boxes are unpacked as x1, y1, x2, y2.
        id2label: Mapping from label id to label name.
        threshold: Detections must score strictly above this to be kept.

    Returns:
        ``(figure, text)``. ``text`` has one "PLATE | TYPE | ₹TOLL" line per
        plate, or is "No plate detected". On any error the result is
        ``(None, "Error: <message>")``.
    """
    fig = None
    try:
        keep = output["scores"] > threshold
        boxes = output["boxes"][keep]
        labels = output["labels"][keep]

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(img)
        width, height = img.size
        results_text = []

        for box, label in zip(boxes, labels):
            if "plate" not in id2label[label.item()].lower():
                continue

            # Clamp to the image so crops never include padding, and skip
            # degenerate boxes that would yield an empty crop.
            x1, y1, x2, y2 = map(int, box.tolist())
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, width), min(y2, height)
            if x2 <= x1 or y2 <= y1:
                continue

            plate_img = img.crop((x1, y1, x2, y2))
            plate = read_plate(plate_img)
            # classify_plate_color() works on OpenCV-style BGR arrays; a PIL
            # crop is RGB and cannot be passed to cv2 directly.
            plate_bgr = cv2.cvtColor(
                np.asarray(plate_img.convert("RGB")), cv2.COLOR_RGB2BGR
            )
            vtype = classify_plate_color(plate_bgr)
            toll = compute_discount(vtype)

            cursor.execute(
                "INSERT INTO vehicles VALUES (?, ?, ?, datetime('now'))",
                (plate, vtype, toll)
            )
            conn.commit()
            results_text.append(f"{plate} | {vtype} | ₹{int(toll)}")

            ax.add_patch(
                plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                              fill=False, color="red", linewidth=2)
            )
            ax.text(x1, y1 - 5, f"{plate} ({vtype})",
                    color="yellow", fontsize=8)

        ax.axis("off")
        if not results_text:
            return fig, "No plate detected"
        return fig, "\n".join(results_text)
    except Exception as e:
        return None, f"Error: {str(e)}"
    finally:
        # Drop the figure from pyplot's registry so a long-running server
        # does not accumulate open figures. The Figure object itself can
        # still be rendered by the caller.
        if fig is not None:
            plt.close(fig)
# ---------------- DASHBOARD ----------------
def get_dashboard():
    """Plot how many logged vehicles there are of each type.

    Returns:
        A matplotlib Figure. It shows a bar chart of per-type counts from
        the ``vehicles`` table, or a "No data yet" placeholder when the
        table is empty.
    """
    df = pd.read_sql("SELECT type FROM vehicles", conn)
    fig, ax = plt.subplots()
    if df.empty:
        ax.text(0.5, 0.5, "No data yet", ha="center")
        ax.axis("off")
    else:
        df["type"].value_counts().plot(kind="bar", ax=ax)
        ax.set_title("Vehicle Types")
    # Each refresh creates a new figure; release it from pyplot's registry
    # so they don't pile up. The returned Figure can still be rendered.
    plt.close(fig)
    return fig
# ---------------- FEEDBACK ----------------
def submit_feedback(result_text, feedback_choice):
    """Store the user's verdict on a detection result in the ``feedback`` table.

    Args:
        result_text: The text currently shown in the detection-result box.
        feedback_choice: The radio selection ("Correct" or "Incorrect");
            falsy when nothing has been selected.

    Returns:
        A status message for the UI.
    """
    if not result_text:
        return "No result available."
    if not feedback_choice:
        # Without this guard a row with a NULL verdict would be stored,
        # distorting any accuracy computed over this table.
        return "Please select Correct or Incorrect."
    cursor.execute(
        "INSERT INTO feedback VALUES (?, ?)",
        (result_text, feedback_choice)
    )
    conn.commit()
    return "Feedback recorded!"
def show_accuracy():
    """Report the percentage of user feedback marking a prediction correct.

    Only rows whose verdict is exactly "Correct" or "Incorrect" are counted.
    Rows stored without a selection (NULL) or with any other value no longer
    count as wrong predictions.

    Returns:
        "Accuracy (User Feedback Based): NN.NN%", or "No feedback yet." when
        there is no usable verdict.
    """
    df = pd.read_sql("SELECT feedback FROM feedback", conn)
    verdicts = df["feedback"]
    verdicts = verdicts[verdicts.isin(["Correct", "Incorrect"])]
    if verdicts.empty:
        return "No feedback yet."
    accuracy = (verdicts == "Correct").mean() * 100
    return f"Accuracy (User Feedback Based): {accuracy:.2f}%"
# ---------------- CALLBACK ----------------
def detect_image(img, threshold):
    """Gradio callback for the "Detect" button.

    Runs the detector on ``img`` and hands the raw detections to
    visualize(), which applies ``threshold`` and renders the result.

    Returns:
        ``(figure_or_None, result_text)``; ``(None, "No image provided")``
        when no image was uploaded.
    """
    if img is not None:
        detections, label_names = make_prediction(img)
        return visualize(img, detections, label_names, threshold)
    return None, "No image provided"
# ---------------- UI ----------------
# Layout: detection controls and output, then feedback, then the accuracy
# readout, then the vehicle-type dashboard. Event listeners are registered
# inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("## Smart Vehicle Classification System")

    # --- Detection ---
    slider = gr.Slider(minimum=0.3, maximum=1.0, value=0.5, label="Confidence Threshold")
    with gr.Row():
        img_input = gr.Image(type="pil")
        img_output = gr.Plot()
    result_box = gr.Textbox(label="Detection Result", lines=4)
    detect_btn = gr.Button("Detect")
    detect_btn.click(
        fn=detect_image,
        inputs=[img_input, slider],
        outputs=[img_output, result_box],
    )

    # --- Feedback on the latest result text ---
    gr.Markdown("### Feedback")
    feedback_radio = gr.Radio(["Correct", "Incorrect"], label="Prediction correct?")
    feedback_btn = gr.Button("Submit Feedback")
    feedback_msg = gr.Textbox(label="Feedback Status")
    feedback_btn.click(
        fn=submit_feedback,
        inputs=[result_box, feedback_radio],
        outputs=feedback_msg,
    )

    # --- Accuracy derived from stored feedback ---
    gr.Markdown("### Model Accuracy")
    accuracy_btn = gr.Button("Show Accuracy")
    accuracy_box = gr.Textbox(label="Accuracy")
    accuracy_btn.click(fn=show_accuracy, outputs=accuracy_box)

    # --- Dashboard of logged vehicle types ---
    gr.Markdown("### Dashboard")
    dashboard_plot = gr.Plot()
    refresh_btn = gr.Button("Refresh Dashboard")
    refresh_btn.click(fn=get_dashboard, outputs=dashboard_plot)

demo.launch()