Anvit25's picture
Add Image Color Classifier Gradio app
9ff9b70
raw
history blame
4.23 kB
# api.py
import cv2 as cv
import numpy as np
import json
from collections import Counter
from typing import List
from fastapi import FastAPI, File, UploadFile, Query, HTTPException
from pydantic import BaseModel
# --- Initialize FastAPI App ---
# Application instance; the title/description are passed to FastAPI as API metadata.
app = FastAPI(
    title="Image Color Classifier API",
    description=(
        "Upload an image to classify it as 'rust', 'zinc', or 'normal' "
        "based on color heuristics."
    ),
)
# --- Define the Response Model (for OpenAPI docs and validation) ---
class ClassificationResponse(BaseModel):
    """Schema of the JSON body returned by the ``/classify/`` endpoint."""

    # Name of the uploaded file.
    filename: str
    # Label chosen by the heuristic: "zinc", "rust" or "normal".
    classification: str
    # Fraction of pixels whose Lab `a` value exceeds the image median by the delta.
    rustish_ratio: float
    # Fraction of pixels whose Lab `b` value exceeds the image median by the delta.
    zincish_ratio: float
    # Dominant colors as [R, G, B] triples, largest cluster first.
    top_colors_rgb: List[List[int]]
    # Pixel share of each dominant color, aligned with `top_colors_rgb`.
    top_colors_share: List[float]
# --- Helper Functions (Copied from your main.py) ---
# Color space conversions
def bgr_to_rgb(bgr):
    """Return the BGR image converted to RGB via ``cv.cvtColor``."""
    rgb = cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
    return rgb
def bgr_to_hsv(bgr):
    """Return the BGR image converted to HSV via ``cv.cvtColor``."""
    hsv = cv.cvtColor(bgr, cv.COLOR_BGR2HSV)
    return hsv
def bgr_to_lab(bgr):
    """Return the BGR image converted to CIE Lab via ``cv.cvtColor``."""
    lab = cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
    return lab
# Dominant color extraction using KMeans
def dominant_colors_kmeans(bgr, k=3, max_iter=10):
    """Extract the dominant colors of a BGR image with k-means clustering.

    Args:
        bgr: Image array whose last axis holds the three BGR channels; it is
            flattened to an ``(n_pixels, 3)`` sample matrix.
        k: Requested number of clusters. Clamped to the number of pixels,
            since ``cv.kmeans`` cannot form more clusters than samples.
        max_iter: Maximum number of iterations per k-means attempt.

    Returns:
        A list of ``{"share": float, "RGB": [r, g, b]}`` dicts, one per
        non-empty cluster, ordered from the largest pixel share to the
        smallest. Ties are ordered by cluster index.
    """
    pixels = bgr.reshape((-1, 3)).astype(np.float32)
    n_pixels = pixels.shape[0]
    # Tiny images (or a large requested k) would otherwise make cv.kmeans raise.
    k = min(k, n_pixels)

    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
    # attempts=3 with k-means++ seeding.
    _, labels, centers = cv.kmeans(
        pixels, k, None, criteria, 3, cv.KMEANS_PP_CENTERS
    )
    centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)

    # Vectorised per-cluster pixel counts; avoids a Python-level loop over
    # every pixel label.
    counts = np.bincount(labels.ravel(), minlength=k)
    order = np.argsort(-counts, kind="stable")

    palette = []
    for idx in order:
        count = int(counts[idx])
        if count == 0:
            # Clusters that received no pixels are not reported.
            continue
        # Centers are stored in BGR order; reverse them to report RGB.
        blue, green, red = (int(c) for c in centers_u8[idx])
        palette.append({"share": count / float(n_pixels), "RGB": [red, green, blue]})
    return palette
# Heuristic calculation for rust/zinc
def rust_zinc_indicators(bgr, delta=6.0):
    """Compute the 'rustish' and 'zincish' pixel ratios of a BGR image.

    The image is converted to Lab. For the `a` channel (rustish) and the `b`
    channel (zincish), the ratio is the fraction of pixels whose value is
    greater than that channel's median plus ``delta``.

    Returns:
        ``{"rustish_ratio": float, "zincish_ratio": float}``.
    """
    lab_image = bgr_to_lab(bgr)
    _, chan_a, chan_b = cv.split(lab_image)

    def fraction_above_median(channel):
        # Threshold is relative to the image's own median for this channel.
        threshold = np.median(channel) + delta
        above = channel.astype(np.float32) > threshold
        return float(above.mean())

    return {
        "rustish_ratio": fraction_above_median(chan_a),
        "zincish_ratio": fraction_above_median(chan_b),
    }
# Classification logic
def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02):
    """Map the two indicator ratios to a label.

    Returns "zinc" when ``zincish_ratio`` is strictly above ``zinc_thr``,
    otherwise "rust" when ``rustish_ratio`` is strictly above ``rust_thr``,
    otherwise "normal".
    """
    # Zinc is tested first, so it wins when both ratios exceed their thresholds.
    if zincish_ratio > zinc_thr:
        return "zinc"
    if rustish_ratio > rust_thr:
        return "rust"
    return "normal"
# --- API Endpoint ---
@app.post("/classify/", response_model=ClassificationResponse)
async def classify_image(
    file: UploadFile = File(..., description="The image file to classify."),
    k: int = Query(3, description="Number of dominant colors to extract."),
    rust_thr: float = Query(0.01, description="Threshold for 'rust' classification."),
    zinc_thr: float = Query(0.02, description="Threshold for 'zinc' classification."),
    lab_delta: float = Query(6.0, description="Sensitivity for heuristic indicators in Lab color space."),
):
    """
    Accepts an image file and returns a classification based on color analysis.

    Responds with HTTP 400 when the upload is empty or cannot be decoded as an image.
    """
    # NOTE(review): the analysis below is synchronous, CPU-bound work executed
    # inside this coroutine; consider offloading it to a thread pool if
    # concurrent throughput matters.

    # 1. Read image bytes from upload
    contents = await file.read()
    if not contents:
        # cv.imdecode raises on an empty buffer; report it as a client error.
        raise HTTPException(status_code=400, detail="Invalid image file. Could not decode image.")

    # 2. Convert bytes to a NumPy array and then to an OpenCV image
    nparr = np.frombuffer(contents, np.uint8)
    try:
        bgr = cv.imdecode(nparr, cv.IMREAD_COLOR)
    except cv.error:
        # Treat decoder failures the same as an undecodable image.
        bgr = None
    if bgr is None:
        raise HTTPException(status_code=400, detail="Invalid image file. Could not decode image.")

    # 3. Perform color analysis and classification
    indicators = rust_zinc_indicators(bgr, delta=lab_delta)
    classification = classify_from_ratios(
        rustish_ratio=indicators["rustish_ratio"],
        zincish_ratio=indicators["zincish_ratio"],
        rust_thr=rust_thr,
        zinc_thr=zinc_thr,
    )

    # k comes from the client: keep it at least 1 and never above the number
    # of pixels, because k-means cannot form more clusters than samples.
    n_pixels = bgr.shape[0] * bgr.shape[1]
    n_clusters = max(1, min(k, n_pixels))
    palette = dominant_colors_kmeans(bgr, k=n_clusters)

    # 4. Format the response
    return {
        # The upload may carry no filename; the response model requires a str.
        "filename": file.filename or "",
        "classification": classification,
        "rustish_ratio": round(indicators["rustish_ratio"], 4),
        "zincish_ratio": round(indicators["zincish_ratio"], 4),
        "top_colors_rgb": [entry["RGB"] for entry in palette],
        "top_colors_share": [round(entry["share"], 4) for entry in palette],
    }