Spaces:
Sleeping
Sleeping
Add Image Color Classifier Gradio app
Browse files- .gitignore +1 -0
- api.py +117 -0
- app.py +112 -0
- classify.py +33 -0
- color.py +156 -0
- main.py +190 -0
- requirements.txt +4 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
api.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# api.py
|
| 2 |
+
import cv2 as cv
|
| 3 |
+
import numpy as np
|
| 4 |
+
import json
|
| 5 |
+
from collections import Counter
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
from fastapi import FastAPI, File, UploadFile, Query, HTTPException
|
| 9 |
+
from pydantic import BaseModel
|
| 10 |
+
|
| 11 |
+
# --- Initialize FastAPI App ---
|
| 12 |
+
# Application instance; the title and description are handed to FastAPI as API metadata.
app = FastAPI(
    title="Image Color Classifier API",
    description=(
        "Upload an image to classify it as 'rust', 'zinc', or 'normal' "
        "based on color heuristics."
    ),
)
|
| 16 |
+
|
| 17 |
+
# --- Define the Response Model (for OpenAPI docs and validation) ---
|
| 18 |
+
class ClassificationResponse(BaseModel):
    """Response schema used as ``response_model`` for the ``/classify/`` endpoint."""

    # Name of the uploaded file as reported by the client.
    filename: str
    # Label produced by classify_from_ratios: "zinc", "rust" or "normal".
    classification: str
    # Fraction of pixels whose Lab a-channel exceeds median + delta.
    rustish_ratio: float
    # Fraction of pixels whose Lab b-channel exceeds median + delta.
    zincish_ratio: float
    # Dominant colors as [R, G, B] triples, largest share first.
    top_colors_rgb: List[List[int]]
    # Pixel share of each dominant color, aligned with top_colors_rgb.
    top_colors_share: List[float]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# --- Helper Functions (Copied from your main.py) ---
|
| 28 |
+
|
| 29 |
+
# Color space conversions
|
| 30 |
+
def bgr_to_rgb(bgr):
    """Return ``bgr`` converted with OpenCV's ``COLOR_BGR2RGB`` code."""
    return cv.cvtColor(bgr, cv.COLOR_BGR2RGB)


def bgr_to_hsv(bgr):
    """Return ``bgr`` converted with OpenCV's ``COLOR_BGR2HSV`` code."""
    return cv.cvtColor(bgr, cv.COLOR_BGR2HSV)


def bgr_to_lab(bgr):
    """Return ``bgr`` converted with OpenCV's ``COLOR_BGR2LAB`` code."""
    return cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
|
| 33 |
+
|
| 34 |
+
# Dominant color extraction using KMeans
|
| 35 |
+
def dominant_colors_kmeans(bgr, k=3, max_iter=10):
    """Extract up to ``k`` dominant colors from a BGR image using k-means.

    Args:
        bgr: Image array whose data reshapes to rows of 3 BGR values.
        k: Requested number of clusters. Coerced to ``int`` and clamped to
            ``[1, number of pixels]`` because ``cv.kmeans`` rejects a cluster
            count larger than the number of samples.
        max_iter: Iteration cap used in the k-means termination criteria.

    Returns:
        A list of ``{"share": float, "RGB": [r, g, b]}`` dicts ordered by
        descending pixel share. Clusters that received no pixels are omitted.
        An image with no pixels yields an empty list.
    """
    data = bgr.reshape((-1, 3)).astype(np.float32)
    n_pixels = data.shape[0]
    if n_pixels == 0:
        return []
    k = max(1, min(int(k), n_pixels))

    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
    # 3 attempts with k-means++ center initialisation.
    _, labels, centers = cv.kmeans(data, k, None, criteria, 3, cv.KMEANS_PP_CENTERS)

    centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)
    counts = Counter(labels.ravel().tolist())
    total = float(len(labels))

    palette = []
    for idx, count in counts.most_common():
        # Centers are stored in BGR order; reverse to report RGB.
        b, g, r = (int(c) for c in centers_u8[idx])
        palette.append({"share": count / total, "RGB": [r, g, b]})
    return palette
|
| 52 |
+
|
| 53 |
+
# Heuristic calculation for rust/zinc
|
| 54 |
+
def rust_zinc_indicators(bgr, delta=6.0):
    """Compute heuristic "rustish" and "zincish" pixel ratios in Lab space.

    Args:
        bgr: BGR image passed to ``bgr_to_lab``.
        delta: Offset added to each channel median to form its threshold.

    Returns:
        Dict with ``rustish_ratio`` (fraction of pixels whose a-channel is
        above ``median(a) + delta``) and ``zincish_ratio`` (same for the
        b-channel).
    """
    _lightness, a_chan, b_chan = cv.split(bgr_to_lab(bgr))
    a_cut = np.median(a_chan) + delta
    b_cut = np.median(b_chan) + delta

    rust_mask = a_chan.astype(np.float32) > a_cut
    zinc_mask = b_chan.astype(np.float32) > b_cut
    return {
        "rustish_ratio": float(rust_mask.mean()),
        "zincish_ratio": float(zinc_mask.mean()),
    }
|
| 64 |
+
|
| 65 |
+
# Classification logic
|
| 66 |
+
def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02):
    """Map the two heuristic ratios to a label.

    The zinc check runs first, so it wins when both ratios exceed their
    thresholds. Comparisons are strict (``>``).

    Returns:
        ``"zinc"``, ``"rust"`` or ``"normal"``.
    """
    if zincish_ratio > zinc_thr:
        return "zinc"
    if rustish_ratio > rust_thr:
        return "rust"
    return "normal"
|
| 73 |
+
|
| 74 |
+
# --- API Endpoint ---
|
| 75 |
+
|
| 76 |
+
@app.post("/classify/", response_model=ClassificationResponse)
async def classify_image(
    file: UploadFile = File(..., description="The image file to classify."),
    k: int = Query(3, description="Number of dominant colors to extract."),
    rust_thr: float = Query(0.01, description="Threshold for 'rust' classification."),
    zinc_thr: float = Query(0.02, description="Threshold for 'zinc' classification."),
    lab_delta: float = Query(6.0, description="Sensitivity for heuristic indicators in Lab color space.")
):
    """
    Accepts an image file and returns a classification based on color analysis.
    """
    # 1. Read image bytes from upload
    contents = await file.read()

    # An empty buffer makes cv.imdecode raise instead of returning None,
    # so reject it up front with the same 400 used for undecodable data.
    if not contents:
        raise HTTPException(status_code=400, detail="Invalid image file. Could not decode image.")

    # 2. Convert bytes to a NumPy array and then to an OpenCV image
    bgr = cv.imdecode(np.frombuffer(contents, np.uint8), cv.IMREAD_COLOR)
    if bgr is None:
        raise HTTPException(status_code=400, detail="Invalid image file. Could not decode image.")

    # 3. Perform color analysis and classification
    indicators = rust_zinc_indicators(bgr, delta=lab_delta)
    classification = classify_from_ratios(
        rustish_ratio=indicators["rustish_ratio"],
        zincish_ratio=indicators["zincish_ratio"],
        rust_thr=rust_thr,
        zinc_thr=zinc_thr,
    )
    palette = dominant_colors_kmeans(bgr, k=max(1, k))

    # 4. Format the response
    return {
        # The response model requires a str; uploads may arrive without a filename.
        "filename": file.filename or "",
        "classification": classification,
        "rustish_ratio": round(indicators["rustish_ratio"], 4),
        "zincish_ratio": round(indicators["zincish_ratio"], 4),
        "top_colors_rgb": [p["RGB"] for p in palette],
        "top_colors_share": [round(p["share"], 4) for p in palette],
    }
|
app.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
import cv2 as cv
|
| 3 |
+
import numpy as np
|
| 4 |
+
from collections import Counter
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
# --- Helper Functions ---
|
| 9 |
+
|
| 10 |
+
# Color space conversions
|
| 11 |
+
def bgr_to_rgb(bgr):
    """Return ``bgr`` converted with OpenCV's ``COLOR_BGR2RGB`` code."""
    rgb = cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
    return rgb


def bgr_to_lab(bgr):
    """Return ``bgr`` converted with OpenCV's ``COLOR_BGR2LAB`` code."""
    lab = cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
    return lab
|
| 16 |
+
|
| 17 |
+
# Dominant color extraction using KMeans
|
| 18 |
+
def dominant_colors_kmeans(bgr, k=3, max_iter=10):
    """Cluster the pixels of a BGR image and report the dominant colors.

    Args:
        bgr: Image array whose data reshapes to rows of 3 BGR values.
        k: Requested cluster count. Coerced to ``int`` and limited to
            ``[1, number of pixels]``, since ``cv.kmeans`` fails when asked
            for more clusters than there are samples.
        max_iter: Iteration cap used in the k-means termination criteria.

    Returns:
        List of ``{"share": float, "RGB": [r, g, b]}`` dicts, largest share
        first. Empty clusters are left out; an image without pixels gives
        an empty list.
    """
    samples = bgr.reshape((-1, 3)).astype(np.float32)
    sample_count = samples.shape[0]
    if sample_count == 0:
        return []
    k = max(1, min(int(k), sample_count))

    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
    # 3 attempts with k-means++ center initialisation.
    _, labels, centers = cv.kmeans(samples, k, None, criteria, 3, cv.KMEANS_PP_CENTERS)

    centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)
    counts = Counter(labels.ravel().tolist())
    total = float(len(labels))

    palette = []
    for idx, count in counts.most_common():
        # Centers are in BGR order; swap to RGB for the caller.
        blue, green, red = (int(c) for c in centers_u8[idx])
        palette.append({"share": count / total, "RGB": [red, green, blue]})
    return palette
|
| 35 |
+
|
| 36 |
+
# Heuristic calculation for rust/zinc
|
| 37 |
+
def rust_zinc_indicators(bgr, delta=6.0):
    """Return heuristic pixel ratios derived from the Lab a and b channels.

    Args:
        bgr: BGR image passed to ``bgr_to_lab``.
        delta: Offset added to each channel median to form its threshold.

    Returns:
        ``{"rustish_ratio": ..., "zincish_ratio": ...}`` where each value is
        the fraction of pixels above ``median + delta`` in the a-channel and
        b-channel respectively.
    """
    _lightness, chan_a, chan_b = cv.split(bgr_to_lab(bgr))

    def _fraction_above(channel):
        # Share of pixels strictly greater than this channel's median + delta.
        cutoff = np.median(channel) + delta
        return float((channel.astype(np.float32) > cutoff).mean())

    return {
        "rustish_ratio": _fraction_above(chan_a),
        "zincish_ratio": _fraction_above(chan_b),
    }
|
| 47 |
+
|
| 48 |
+
# Classification logic
|
| 49 |
+
def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02):
    """Turn the two heuristic ratios into ``"zinc"``, ``"rust"`` or ``"normal"``.

    Zinc is tested before rust, so it takes precedence when both ratios are
    strictly above their thresholds.
    """
    is_zinc = zincish_ratio > zinc_thr
    if is_zinc:
        return "zinc"
    return "rust" if rustish_ratio > rust_thr else "normal"
|
| 56 |
+
|
| 57 |
+
# --- Gradio Prediction Function ---
|
| 58 |
+
def classify_image_gradio(
    image: np.ndarray,
    k: int = 3,
    rust_thr: float = 0.01,
    zinc_thr: float = 0.02,
    lab_delta: float = 6.0
) -> Dict[str, Any]:
    """
    Accepts an image (from Gradio upload) and returns classification and color analysis.

    Args:
        image: RGB image array, or ``None`` when nothing was uploaded.
        k: Number of dominant colors; rounded to an integer and floored at 1.
        rust_thr: Threshold applied to ``rustish_ratio``.
        zinc_thr: Threshold applied to ``zincish_ratio``.
        lab_delta: Offset above the Lab channel medians used by the heuristic.

    Returns:
        ``{"error": ...}`` when ``image`` is ``None``; otherwise a dict with
        ``classification``, the two rounded ratios, ``top_colors_rgb`` and
        ``top_colors_share``.
    """
    if image is None:
        return {"error": "No image provided."}

    # A slider input may deliver a non-integer float, which cv.kmeans rejects
    # as a cluster count; normalise to a positive int here.
    k = max(1, int(round(k)))

    # Convert RGB (from Gradio) to BGR for OpenCV
    bgr = cv.cvtColor(image, cv.COLOR_RGB2BGR)

    # Color analysis
    indicators = rust_zinc_indicators(bgr, delta=lab_delta)
    classification = classify_from_ratios(
        rustish_ratio=indicators["rustish_ratio"],
        zincish_ratio=indicators["zincish_ratio"],
        rust_thr=rust_thr,
        zinc_thr=zinc_thr,
    )
    palette = dominant_colors_kmeans(bgr, k=k)

    # Format response
    return {
        "classification": classification,
        "rustish_ratio": round(indicators["rustish_ratio"], 4),
        "zincish_ratio": round(indicators["zincish_ratio"], 4),
        "top_colors_rgb": [p["RGB"] for p in palette],
        "top_colors_share": [round(p["share"], 4) for p in palette],
    }
|
| 94 |
+
|
| 95 |
+
# --- Gradio Interface ---
|
| 96 |
+
# Interface wiring: the inputs are listed in the same order as the
# parameters of classify_image_gradio.
iface = gr.Interface(
    fn=classify_image_gradio,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        # step=1 keeps k integral; a cluster count must be a whole number.
        gr.Slider(1, 10, value=3, step=1, label="Number of Dominant Colors (k)"),
        gr.Slider(0.0, 1.0, value=0.01, step=0.01, label="Rust Threshold"),
        gr.Slider(0.0, 1.0, value=0.02, step=0.01, label="Zinc Threshold"),
        gr.Slider(0.0, 20.0, value=6.0, step=0.5, label="Lab Delta"),
    ],
    outputs=gr.JSON(label="Classification Result"),
    title="Image Color Classifier",
    description="Upload an image and classify it as 'rust', 'zinc', or 'normal' based on color heuristics.",
)

# Launch Gradio app
if __name__ == "__main__":
    iface.launch()
|
classify.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json, argparse
|
| 2 |
+
|
| 3 |
+
def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02):
    """Label a pair of heuristic ratios as ``"zinc"``, ``"rust"`` or ``"normal"``.

    The zinc threshold is checked first; both comparisons are strict.
    """
    if zincish_ratio > zinc_thr:
        label = "zinc"
    elif rustish_ratio > rust_thr:
        label = "rust"
    else:
        label = "normal"
    return label
|
| 10 |
+
|
| 11 |
+
if __name__ == "__main__":
    # CLI: read a JSON report, classify its heuristic ratios and print a summary dict.
    parser = argparse.ArgumentParser()
    parser.add_argument("--report", required=True, help="path to *_color_report.json from color.py")
    parser.add_argument("--rust_thr", type=float, default=0.01)
    parser.add_argument("--zinc_thr", type=float, default=0.02)
    cli = parser.parse_args()

    with open(cli.report, "r") as report_file:
        report = json.load(report_file)

    heuristics = report["heuristics"]
    rustish_ratio = heuristics["rustish_ratio"]
    zincish_ratio = heuristics["zincish_ratio"]

    label = classify_from_ratios(
        rustish_ratio,
        zincish_ratio,
        rust_thr=cli.rust_thr,
        zinc_thr=cli.zinc_thr,
    )

    summary = {
        "file": report["input"],
        "rustish_ratio": rustish_ratio,
        "zincish_ratio": zincish_ratio,
        "classification": label,
    }
    print(summary)
|
color.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2 as cv
|
| 2 |
+
import numpy as np
|
| 3 |
+
import argparse, os, json
|
| 4 |
+
from collections import Counter
|
| 5 |
+
|
| 6 |
+
def bgr_to_rgb(bgr):
    """Convert ``bgr`` using OpenCV's ``COLOR_BGR2RGB`` code."""
    return cv.cvtColor(bgr, cv.COLOR_BGR2RGB)


def bgr_to_hsv(bgr):
    """Convert ``bgr`` using OpenCV's ``COLOR_BGR2HSV`` code."""
    return cv.cvtColor(bgr, cv.COLOR_BGR2HSV)


def bgr_to_lab(bgr):
    """Convert ``bgr`` using OpenCV's ``COLOR_BGR2LAB`` code."""
    return cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
|
| 9 |
+
|
| 10 |
+
def img_stats(img, space_name):
    """Return the per-channel mean and standard deviation of an image.

    Args:
        img: Array of shape HxWxC (the channel count is read from ``img.shape[2]``).
        space_name: Label stored under the ``"space"`` key.

    Returns:
        ``{"space": space_name, "mean": [...], "std": [...]}`` with one float
        per channel.
    """
    pixels = img.reshape(-1, img.shape[2])
    channel_means = pixels.mean(axis=0)
    channel_stds = pixels.std(axis=0)
    return {
        "space": space_name,
        "mean": list(map(float, channel_means)),
        "std": list(map(float, channel_stds)),
    }
|
| 19 |
+
|
| 20 |
+
def dominant_colors_kmeans(bgr, k=3, max_iter=10, seed=123):
    """Find the dominant colors of a BGR image with k-means.

    Args:
        bgr: Image array whose data reshapes to rows of 3 BGR values.
        k: Requested cluster count. Coerced to ``int`` and clamped to
            ``[1, number of pixels]`` because ``cv.kmeans`` rejects more
            clusters than samples.
        max_iter: Iteration cap used in the k-means termination criteria.
        seed: Seed for OpenCV's global RNG, set before clustering so that
            repeated runs give the same palette. Pass ``None`` to leave the
            RNG state untouched.

    Returns:
        List of dicts ordered by descending share, each holding ``share``
        plus the cluster center as ``BGR``, ``RGB``, ``HSV`` and ``Lab`` int
        triples. Empty clusters are omitted; an image without pixels gives
        an empty list.
    """
    # reshape to N x 3
    data = bgr.reshape((-1, 3)).astype(np.float32)
    n_pixels = data.shape[0]
    if n_pixels == 0:
        return []
    k = max(1, min(int(k), n_pixels))

    if seed is not None:
        # The k-means++ initialisation draws from OpenCV's global RNG.
        cv.setRNGSeed(int(seed))

    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
    _, labels, centers = cv.kmeans(data, k, None, criteria, 3, cv.KMEANS_PP_CENTERS)

    # centers are BGR float; convert to uint8
    centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)
    counts = Counter(labels.ravel().tolist())
    total = float(len(labels))

    palette = []
    # most_common() yields clusters sorted by frequency, descending
    for idx, count in counts.most_common():
        # 1x1 image so the cvtColor-based helpers can convert the center.
        pixel = centers_u8[idx].reshape(1, 1, 3)
        palette.append({
            "share": count / total,
            "BGR": pixel.reshape(-1).tolist(),
            "RGB": bgr_to_rgb(pixel).reshape(-1).tolist(),
            "HSV": bgr_to_hsv(pixel).reshape(-1).tolist(),
            "Lab": bgr_to_lab(pixel).reshape(-1).tolist(),
        })
    return palette
|
| 51 |
+
|
| 52 |
+
def make_palette_image(palette, width=600, height=80, pad=2):
    """Render a horizontal strip of color swatches sized by share.

    Swatch boundaries come from the cumulative shares, so the swatches tile
    the whole width (the last one absorbs rounding slack) and the separator
    lines sit exactly on the swatch boundaries.

    Args:
        palette: List of dicts with ``"share"`` and ``"RGB"`` keys. Shares
            are expected to sum to 1.
        width: Output width in pixels.
        height: Output height in pixels.
        pad: Currently unused; kept for interface compatibility.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` in BGR order.
    """
    img = np.zeros((height, width, 3), dtype=np.uint8)
    if not palette:
        return img

    # Right edge of every swatch; each gets at least 1 px while room remains.
    edges = [0]
    cumulative = 0.0
    for p in palette:
        cumulative += p["share"]
        right = max(edges[-1] + 1, int(round(cumulative * width)))
        edges.append(min(width, right))
    edges[-1] = width

    for p, x0, x1 in zip(palette, edges, edges[1:]):
        r, g, b = (int(c) for c in p["RGB"])
        # stored as RGB, drawn as BGR
        img[:, x0:x1] = (b, g, r)

    # thin separators on the first column of each following swatch
    for x_sep in edges[1:-1]:
        if x_sep < width:
            img[:, x_sep] = (30, 30, 30)
    return img
|
| 68 |
+
|
| 69 |
+
def rust_zinc_indicators(bgr, delta=6):
    """Heuristic only, NO detection claims. Gives ratios based on Lab chroma tendencies.

    Args:
        bgr: BGR image passed to ``bgr_to_lab``.
        delta: Offset added to each channel median to form its threshold
            (previously hard-coded to 6).

    Returns:
        Dict with:
        - ``rustish_ratio``: fraction of pixels with a* above ``median(a*) + delta``
        - ``zincish_ratio``: fraction of pixels with b* above ``median(b*) + delta``
        - ``a_median`` / ``b_median``: the channel medians
        - ``a_thresh`` / ``b_thresh``: the thresholds actually applied
    """
    lab = bgr_to_lab(bgr)
    _lightness, a, b = cv.split(lab)
    a_med, b_med = np.median(a), np.median(b)
    a_thr = a_med + delta
    b_thr = b_med + delta

    rustish = (a.astype(np.float32) > a_thr).mean()
    zincish = (b.astype(np.float32) > b_thr).mean()
    return {
        "rustish_ratio": float(rustish),
        "zincish_ratio": float(zincish),
        "a_median": float(a_med),
        "b_median": float(b_med),
        "a_thresh": float(a_thr),
        "b_thresh": float(b_thr),
    }
|
| 85 |
+
|
| 86 |
+
def main():
    """CLI entry point: analyse one image and write a palette PNG and a JSON report.

    Reads ``--img``, optionally downsizes it, computes per-color-space stats,
    dominant colors and the heuristic indicators, writes
    ``<base>_palette.png`` and ``<base>_color_report.json`` into ``--outdir``
    and prints a short JSON summary.

    Raises:
        RuntimeError: if the image cannot be read or the palette image
            cannot be written.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--img", required=True, help="path to image")
    ap.add_argument("--k", type=int, default=3, help="number of dominant colors")
    ap.add_argument("--resize_max", type=int, default=1200, help="resize longer side to this (0=off)")
    ap.add_argument("--outdir", default="color_out")
    args = ap.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    bgr = cv.imread(args.img, cv.IMREAD_COLOR)
    if bgr is None:
        raise RuntimeError(f"Cannot read image: {args.img}")

    # Optional resize to speed up
    h, w = bgr.shape[:2]
    longest = max(h, w)
    if 0 < args.resize_max < longest:
        scale = args.resize_max / float(longest)
        # Clamp to >= 1 px: truncation can give 0 for very elongated images,
        # which cv.resize rejects.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        bgr = cv.resize(bgr, new_size, interpolation=cv.INTER_AREA)

    # Color-space stats
    stats = [
        img_stats(bgr_to_rgb(bgr), "RGB"),  # channels: R,G,B (0-255)
        img_stats(bgr_to_hsv(bgr), "HSV"),  # channels: H(0-179), S(0-255), V(0-255) in OpenCV
        img_stats(bgr_to_lab(bgr), "Lab"),  # channels: L(0-255), a(0-255), b(0-255) in OpenCV's scaled Lab
    ]

    # Dominant colors (k-means)
    palette = dominant_colors_kmeans(bgr, k=max(1, args.k))

    # Heuristic indicators (optional)
    indicators = rust_zinc_indicators(bgr)

    # Save palette image
    pal_img = make_palette_image(palette)
    base = os.path.splitext(os.path.basename(args.img))[0]
    pal_path = os.path.join(args.outdir, f"{base}_palette.png")
    # cv.imwrite reports failure through its return value, not an exception.
    if not cv.imwrite(pal_path, pal_img):
        raise RuntimeError(f"Cannot write palette image: {pal_path}")

    # Build and save JSON
    report = {
        "input": os.path.basename(args.img),
        "size_hw": [int(bgr.shape[0]), int(bgr.shape[1])],
        "color_stats": stats,
        "dominant_colors": palette,  # ordered by share desc
        "heuristics": indicators,
        "palette_image": pal_path,
    }
    rep_path = os.path.join(args.outdir, f"{base}_color_report.json")
    with open(rep_path, "w") as f:
        json.dump(report, f, indent=2)

    # Print a short summary to console
    print(json.dumps({
        "input": report["input"],
        "top_colors_rgb": [p["RGB"] for p in report["dominant_colors"]],
        "top_colors_share": [round(p["share"], 4) for p in report["dominant_colors"]],
        "rustish_ratio": round(report["heuristics"]["rustish_ratio"], 4),
        "zincish_ratio": round(report["heuristics"]["zincish_ratio"], 4),
        "report_path": rep_path,
        "palette_image": pal_path,
    }, indent=2))


if __name__ == "__main__":
    main()
|
main.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2 as cv
|
| 2 |
+
import numpy as np
|
| 3 |
+
import argparse, os, json
|
| 4 |
+
from collections import Counter
|
| 5 |
+
|
| 6 |
+
# ---------------- Conversions ----------------
|
| 7 |
+
def bgr_to_rgb(bgr):
    """BGR -> RGB via ``cv.cvtColor`` with ``COLOR_BGR2RGB``."""
    return cv.cvtColor(bgr, cv.COLOR_BGR2RGB)


def bgr_to_hsv(bgr):
    """BGR -> HSV via ``cv.cvtColor`` with ``COLOR_BGR2HSV``."""
    return cv.cvtColor(bgr, cv.COLOR_BGR2HSV)


def bgr_to_lab(bgr):
    """BGR -> Lab via ``cv.cvtColor`` with ``COLOR_BGR2LAB``."""
    return cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
|
| 10 |
+
|
| 11 |
+
# ---------------- Stats ----------------
|
| 12 |
+
def img_stats(img, space_name):
    """Summarise an HxWxC image as per-channel mean and standard deviation.

    Args:
        img: Array with three dimensions; the channel count is ``img.shape[2]``.
        space_name: Name recorded under ``"space"`` in the result.

    Returns:
        Dict with ``"space"``, ``"mean"`` and ``"std"``; the latter two are
        lists of Python floats, one per channel.
    """
    flat = img.reshape(-1, img.shape[2])
    summary = {"space": space_name}
    summary["mean"] = [float(v) for v in flat.mean(axis=0)]
    summary["std"] = [float(v) for v in flat.std(axis=0)]
    return summary
|
| 21 |
+
|
| 22 |
+
# ---------------- Dominant colors ----------------
|
| 23 |
+
def dominant_colors_kmeans(bgr, k=3, max_iter=10):
    """Find the dominant colors of a BGR image with k-means.

    Args:
        bgr: Image array whose data reshapes to rows of 3 BGR values.
        k: Requested cluster count. Coerced to ``int`` and clamped to
            ``[1, number of pixels]`` because ``cv.kmeans`` rejects more
            clusters than samples.
        max_iter: Iteration cap used in the k-means termination criteria.

    Returns:
        List of dicts ordered by descending share, each holding ``share``
        plus the cluster center as ``BGR``, ``RGB``, ``HSV`` and ``Lab`` int
        triples. Empty clusters are omitted; an image without pixels gives
        an empty list.
    """
    data = bgr.reshape((-1, 3)).astype(np.float32)
    n_pixels = data.shape[0]
    if n_pixels == 0:
        return []
    k = max(1, min(int(k), n_pixels))

    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
    # 3 attempts with k-means++ center initialisation.
    _, labels, centers = cv.kmeans(data, k, None, criteria, 3, cv.KMEANS_PP_CENTERS)

    centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)
    counts = Counter(labels.ravel().tolist())
    total = float(len(labels))

    palette = []
    for idx, count in counts.most_common():
        # 1x1 image so the cvtColor-based helpers can convert the center.
        pixel = centers_u8[idx].reshape(1, 1, 3)
        palette.append({
            "share": count / total,
            "BGR": pixel.reshape(-1).tolist(),
            "RGB": bgr_to_rgb(pixel).reshape(-1).tolist(),
            "HSV": bgr_to_hsv(pixel).reshape(-1).tolist(),
            "Lab": bgr_to_lab(pixel).reshape(-1).tolist(),
        })
    return palette
|
| 48 |
+
|
| 49 |
+
def make_palette_image(palette, width=600, height=80):
    """Render a horizontal strip of color swatches sized by share.

    Swatch boundaries are derived from the cumulative shares, so the strip
    is filled up to its right edge (the last swatch absorbs rounding slack)
    and separators coincide with the swatch boundaries.

    Args:
        palette: List of dicts with ``"share"`` and ``"RGB"`` keys. Shares
            are expected to sum to 1.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` in BGR order.
    """
    img = np.zeros((height, width, 3), dtype=np.uint8)
    if not palette:
        return img

    # Right edge of every swatch; each gets at least 1 px while room remains.
    bounds = [0]
    running = 0.0
    for entry in palette:
        running += entry["share"]
        right = max(bounds[-1] + 1, int(round(running * width)))
        bounds.append(min(width, right))
    bounds[-1] = width

    for entry, left, right in zip(palette, bounds, bounds[1:]):
        r, g, b = (int(c) for c in entry["RGB"])  # stored as RGB
        img[:, left:right] = (b, g, r)  # written as BGR

    # 1-px separators on the first column of each following swatch
    for x_sep in bounds[1:-1]:
        if x_sep < width:
            img[:, x_sep] = (30, 30, 30)
    return img
|
| 61 |
+
|
| 62 |
+
# ---------------- Heuristics ----------------
|
| 63 |
+
def rust_zinc_indicators(bgr, delta=6):
    """
    Heuristic only. Works on the Lab conversion of ``bgr``:
    - rustish_ratio: fraction of pixels with a* > median(a*) + delta
    - zincish_ratio: fraction of pixels with b* > median(b*) + delta

    The returned dict also carries the channel medians, the thresholds
    applied and ``delta`` itself, all as floats.
    """
    _lightness, chan_a, chan_b = cv.split(bgr_to_lab(bgr))

    median_a = np.median(chan_a)
    median_b = np.median(chan_b)
    cut_a = median_a + delta
    cut_b = median_b + delta

    rust_share = (chan_a.astype(np.float32) > cut_a).mean()
    zinc_share = (chan_b.astype(np.float32) > cut_b).mean()

    result = {}
    result["rustish_ratio"] = float(rust_share)
    result["zincish_ratio"] = float(zinc_share)
    result["a_median"] = float(median_a)
    result["b_median"] = float(median_b)
    result["a_thresh"] = float(cut_a)
    result["b_thresh"] = float(cut_b)
    result["delta"] = float(delta)
    return result
|
| 86 |
+
|
| 87 |
+
# ---------------- Classification ----------------
|
| 88 |
+
def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.002, zinc_thr=0.01):
    """
    Classification rule:
    - "zinc" if zincish_ratio > zinc_thr
    - else "rust" if rustish_ratio > rust_thr
    - else "normal"
    """
    if zincish_ratio > zinc_thr:
        return "zinc"
    if rustish_ratio > rust_thr:
        return "rust"
    return "normal"
|
| 101 |
+
|
| 102 |
+
# ---------------- Main ----------------
|
| 103 |
+
def main():
    """CLI entry point: analyse one image, classify it and write the outputs.

    Reads ``--img``, optionally downsizes it, computes per-color-space stats,
    dominant colors, the heuristic indicators and the resulting
    classification, writes ``<base>_palette.png`` and
    ``<base>_color_report.json`` into ``--outdir`` and prints a JSON summary.

    Raises:
        RuntimeError: if the image cannot be read or the palette image
            cannot be written.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--img", required=True, help="path to image")
    ap.add_argument("--k", type=int, default=3, help="number of dominant colors")
    ap.add_argument("--resize_max", type=int, default=1200, help="resize longer side to this (0=off)")
    ap.add_argument("--outdir", default="color_out")
    # classification thresholds
    ap.add_argument("--rust_thr", type=float, default=0.01)
    ap.add_argument("--zinc_thr", type=float, default=0.02)
    # indicator sensitivity (Lab delta)
    ap.add_argument("--lab_delta", type=float, default=6.0)
    args = ap.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    bgr = cv.imread(args.img, cv.IMREAD_COLOR)
    if bgr is None:
        raise RuntimeError(f"Cannot read image: {args.img}")

    # optional resize
    h, w = bgr.shape[:2]
    longest = max(h, w)
    if 0 < args.resize_max < longest:
        scale = args.resize_max / float(longest)
        # Clamp to >= 1 px: truncation can give 0 for very elongated images,
        # which cv.resize rejects.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        bgr = cv.resize(bgr, new_size, interpolation=cv.INTER_AREA)

    # color stats
    stats = [
        img_stats(bgr_to_rgb(bgr), "RGB"),
        img_stats(bgr_to_hsv(bgr), "HSV"),
        img_stats(bgr_to_lab(bgr), "Lab"),
    ]

    # dominant colors
    palette = dominant_colors_kmeans(bgr, k=max(1, args.k))

    # heuristics
    indicators = rust_zinc_indicators(bgr, delta=args.lab_delta)

    # classification using the CLI thresholds
    cls = classify_from_ratios(
        rustish_ratio=indicators["rustish_ratio"],
        zincish_ratio=indicators["zincish_ratio"],
        rust_thr=args.rust_thr,
        zinc_thr=args.zinc_thr,
    )

    # save palette image
    base = os.path.splitext(os.path.basename(args.img))[0]
    pal_img = make_palette_image(palette)
    pal_path = os.path.join(args.outdir, f"{base}_palette.png")
    # cv.imwrite reports failure through its return value, not an exception.
    if not cv.imwrite(pal_path, pal_img):
        raise RuntimeError(f"Cannot write palette image: {pal_path}")

    # build + save JSON
    report = {
        "input": os.path.basename(args.img),
        "size_hw": [int(bgr.shape[0]), int(bgr.shape[1])],
        "color_stats": stats,
        "dominant_colors": palette,  # ordered by share desc
        "heuristics": indicators,
        "classification": cls,
        "thresholds": {"rust_thr": args.rust_thr, "zinc_thr": args.zinc_thr},
        "palette_image": pal_path,
    }
    rep_path = os.path.join(args.outdir, f"{base}_color_report.json")
    with open(rep_path, "w") as f:
        json.dump(report, f, indent=2)

    # console summary
    print(json.dumps({
        "input": report["input"],
        "classification": cls,
        "rustish_ratio": round(indicators["rustish_ratio"], 4),
        "zincish_ratio": round(indicators["zincish_ratio"], 4),
        "top_colors_rgb": [p["RGB"] for p in palette],
        "top_colors_share": [round(p["share"], 4) for p in palette],
        "report_path": rep_path,
        "palette_image": pal_path,
    }, indent=2))


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi[all]
|
| 2 |
+
opencv-python-headless
|
| 3 |
+
numpy
|
| 4 |
+
gradio
|