Anvit25's picture
Add Image Color Classifier Gradio app
9ff9b70
raw
history blame
3.73 kB
# gradio_app.py
import cv2 as cv
import numpy as np
from collections import Counter
from typing import List, Dict, Any
import gradio as gr
# --- Helper Functions ---
# Color space conversions
def bgr_to_rgb(bgr):
    """Convert an image from BGR channel order to RGB using OpenCV."""
    rgb = cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
    return rgb
def bgr_to_lab(bgr):
    """Convert a BGR image to OpenCV's Lab representation."""
    lab = cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
    return lab
# Dominant color extraction using KMeans
def dominant_colors_kmeans(bgr, k=3, max_iter=10):
    """Extract the dominant colors of a BGR image with k-means clustering.

    Args:
        bgr: Image in BGR channel order; it is flattened into a list of
            3-component pixels before clustering.
        k: Requested number of clusters. It is rounded to an integer and
            clamped to ``[1, number_of_pixels]`` because ``cv.kmeans`` rejects
            a non-integer ``K`` and a ``K`` larger than the sample count.
        max_iter: Iteration cap used in the k-means termination criteria.

    Returns:
        A list of ``{"share": float, "RGB": [r, g, b]}`` dicts ordered from
        the most to the least populated cluster (ties keep ascending cluster
        index). An empty image yields an empty list.
    """
    pixels = bgr.reshape((-1, 3)).astype(np.float32)
    n_pixels = pixels.shape[0]
    if n_pixels == 0:
        # Nothing to cluster; cv.kmeans would raise on an empty sample set.
        return []
    n_clusters = max(1, min(int(round(k)), n_pixels))

    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
    # 3 attempts with k-means++ seeding; OpenCV keeps the best attempt.
    _, labels, centers = cv.kmeans(
        pixels, n_clusters, None, criteria, 3, cv.KMEANS_PP_CENTERS
    )
    centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)

    # bincount tallies labels in native code instead of a per-pixel Python loop.
    counts = np.bincount(labels.ravel(), minlength=n_clusters)
    order = np.argsort(-counts, kind="stable")

    palette = []
    for idx in order:
        if counts[idx] == 0:
            continue  # skip clusters that ended up with no pixels
        blue, green, red = (int(c) for c in centers_u8[idx])
        palette.append(
            {"share": float(counts[idx] / n_pixels), "RGB": [red, green, blue]}
        )
    return palette
# Heuristic calculation for rust/zinc
def rust_zinc_indicators(bgr, delta=6.0):
    """Compute the rust/zinc heuristic ratios for a BGR image.

    The image is converted to Lab; for the second ("a") and third ("b")
    channels, the fraction of pixels whose value exceeds that channel's
    median by more than ``delta`` is measured.

    Args:
        bgr: Image in BGR channel order.
        delta: Offset added to each channel's median to form its threshold.

    Returns:
        ``{"rustish_ratio": float, "zincish_ratio": float}`` where the first
        ratio comes from the "a" channel and the second from the "b" channel.
    """
    lab_image = bgr_to_lab(bgr)

    def _share_above_median(channel):
        # Fraction of pixels strictly above (median + delta) for this channel.
        threshold = np.median(channel) + delta
        above = channel.astype(np.float32) > threshold
        return float(above.mean())

    return {
        "rustish_ratio": _share_above_median(lab_image[..., 1]),
        "zincish_ratio": _share_above_median(lab_image[..., 2]),
    }
# Classification logic
def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02):
    """Map the two heuristic ratios to a label.

    The zinc test is checked first, so it wins when both ratios exceed their
    thresholds. Comparisons are strict: a ratio equal to its threshold does
    not trigger the label.

    Returns:
        ``"zinc"``, ``"rust"`` or ``"normal"``.
    """
    if zincish_ratio > zinc_thr:
        return "zinc"
    return "rust" if rustish_ratio > rust_thr else "normal"
# --- Gradio Prediction Function ---
def classify_image_gradio(
    image: np.ndarray,
    k: int = 3,
    rust_thr: float = 0.01,
    zinc_thr: float = 0.02,
    lab_delta: float = 6.0
) -> Dict[str, Any]:
    """Classify an uploaded image and summarize its dominant colors.

    Args:
        image: Image array from the Gradio image input (RGB channel order),
            or ``None`` when nothing was uploaded. 2-D grayscale and
            4-channel RGBA arrays are also accepted.
        k: Number of dominant colors to extract. Rounded to an integer and
            raised to at least 1, since UI sliders may deliver floats.
        rust_thr: Threshold passed to ``classify_from_ratios`` for "rust".
        zinc_thr: Threshold passed to ``classify_from_ratios`` for "zinc".
        lab_delta: Offset passed to ``rust_zinc_indicators``.

    Returns:
        ``{"error": ...}`` when no image is given; otherwise a dict with the
        classification label, both ratios rounded to 4 decimals, and the
        dominant colors (RGB triplets) with their pixel shares.
    """
    if image is None:
        return {"error": "No image provided."}

    # cv.kmeans only accepts an integer cluster count.
    n_colors = max(1, int(round(k)))

    # Convert to BGR for OpenCV, picking the conversion by channel layout.
    if image.ndim == 2:
        bgr = cv.cvtColor(image, cv.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        bgr = cv.cvtColor(image, cv.COLOR_RGBA2BGR)
    else:
        bgr = cv.cvtColor(image, cv.COLOR_RGB2BGR)

    # Color analysis
    indicators = rust_zinc_indicators(bgr, delta=lab_delta)
    classification = classify_from_ratios(
        rustish_ratio=indicators["rustish_ratio"],
        zincish_ratio=indicators["zincish_ratio"],
        rust_thr=rust_thr,
        zinc_thr=zinc_thr,
    )
    palette = dominant_colors_kmeans(bgr, k=n_colors)

    # Format response
    return {
        "classification": classification,
        "rustish_ratio": round(indicators["rustish_ratio"], 4),
        "zincish_ratio": round(indicators["zincish_ratio"], 4),
        "top_colors_rgb": [entry["RGB"] for entry in palette],
        "top_colors_share": [round(entry["share"], 4) for entry in palette],
    }
# --- Gradio Interface ---
# Inputs map positionally onto classify_image_gradio's parameters:
# image, k, rust_thr, zinc_thr, lab_delta.
iface = gr.Interface(
    fn=classify_image_gradio,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        # step=1 keeps k a whole number; without an explicit step the slider
        # can emit fractional values, which k-means cannot use as a cluster count.
        gr.Slider(1, 10, value=3, step=1, label="Number of Dominant Colors (k)"),
        gr.Slider(0.0, 1.0, value=0.01, step=0.01, label="Rust Threshold"),
        gr.Slider(0.0, 1.0, value=0.02, step=0.01, label="Zinc Threshold"),
        gr.Slider(0.0, 20.0, value=6.0, step=0.5, label="Lab Delta"),
    ],
    outputs=gr.JSON(label="Classification Result"),
    title="Image Color Classifier",
    description="Upload an image and classify it as 'rust', 'zinc', or 'normal' based on color heuristics.",
)

# Launch Gradio app
if __name__ == "__main__":
    iface.launch()