Anvit25 commited on
Commit
9ff9b70
·
1 Parent(s): f94cf95

Add Image Color Classifier Gradio app

Browse files
Files changed (7) hide show
  1. .gitignore +1 -0
  2. api.py +117 -0
  3. app.py +112 -0
  4. classify.py +33 -0
  5. color.py +156 -0
  6. main.py +190 -0
  7. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
api.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api.py
2
+ import cv2 as cv
3
+ import numpy as np
4
+ import json
5
+ from collections import Counter
6
+ from typing import List
7
+
8
+ from fastapi import FastAPI, File, UploadFile, Query, HTTPException
9
+ from pydantic import BaseModel
10
+
11
+ # --- Initialize FastAPI App ---
12
+ app = FastAPI(
13
+ title="Image Color Classifier API",
14
+ description="Upload an image to classify it as 'rust', 'zinc', or 'normal' based on color heuristics."
15
+ )
16
+
17
+ # --- Define the Response Model (for OpenAPI docs and validation) ---
18
+ class ClassificationResponse(BaseModel):
19
+ filename: str
20
+ classification: str
21
+ rustish_ratio: float
22
+ zincish_ratio: float
23
+ top_colors_rgb: List[List[int]]
24
+ top_colors_share: List[float]
25
+
26
+
27
+ # --- Helper Functions (Copied from your main.py) ---
28
+
29
+ # Color space conversions
30
+ def bgr_to_rgb(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
31
+ def bgr_to_hsv(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2HSV)
32
+ def bgr_to_lab(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
33
+
34
+ # Dominant color extraction using KMeans
35
+ def dominant_colors_kmeans(bgr, k=3, max_iter=10):
36
+ data = bgr.reshape((-1, 3)).astype(np.float32)
37
+ criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
38
+ flags = cv.KMEANS_PP_CENTERS
39
+ _, labels, centers = cv.kmeans(data, k, None, criteria, 3, flags)
40
+ centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)
41
+ counts = Counter(labels.flatten())
42
+ total = float(len(labels))
43
+
44
+ idx_sorted = [i for i, _ in counts.most_common()]
45
+ palette = []
46
+ for idx in idx_sorted:
47
+ bgr_c = centers_u8[idx].tolist()
48
+ rgb_c = bgr_to_rgb(np.array([[bgr_c]], dtype=np.uint8)).reshape(-1).tolist()
49
+ share = counts[idx] / total
50
+ palette.append({"share": float(share), "RGB": [int(x) for x in rgb_c]})
51
+ return palette
52
+
53
+ # Heuristic calculation for rust/zinc
54
+ def rust_zinc_indicators(bgr, delta=6.0):
55
+ lab = bgr_to_lab(bgr)
56
+ _, a, b = cv.split(lab)
57
+ a_med, b_med = np.median(a), np.median(b)
58
+ a_thr = a_med + delta
59
+ b_thr = b_med + delta
60
+
61
+ rustish = (a.astype(np.float32) > a_thr).mean()
62
+ zincish = (b.astype(np.float32) > b_thr).mean()
63
+ return {"rustish_ratio": float(rustish), "zincish_ratio": float(zincish)}
64
+
65
+ # Classification logic
66
+ def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02):
67
+ if zincish_ratio > zinc_thr:
68
+ return "zinc"
69
+ elif rustish_ratio > rust_thr:
70
+ return "rust"
71
+ else:
72
+ return "normal"
73
+
74
+ # --- API Endpoint ---
75
+
76
+ @app.post("/classify/", response_model=ClassificationResponse)
77
+ async def classify_image(
78
+ file: UploadFile = File(..., description="The image file to classify."),
79
+ k: int = Query(3, description="Number of dominant colors to extract."),
80
+ rust_thr: float = Query(0.01, description="Threshold for 'rust' classification."),
81
+ zinc_thr: float = Query(0.02, description="Threshold for 'zinc' classification."),
82
+ lab_delta: float = Query(6.0, description="Sensitivity for heuristic indicators in Lab color space.")
83
+ ):
84
+ """
85
+ Accepts an image file and returns a classification based on color analysis.
86
+ """
87
+ # 1. Read image bytes from upload
88
+ contents = await file.read()
89
+
90
+ # 2. Convert bytes to a NumPy array and then to an OpenCV image
91
+ nparr = np.frombuffer(contents, np.uint8)
92
+ bgr = cv.imdecode(nparr, cv.IMREAD_COLOR)
93
+
94
+ if bgr is None:
95
+ raise HTTPException(status_code=400, detail="Invalid image file. Could not decode image.")
96
+
97
+ # 3. Perform color analysis and classification
98
+ indicators = rust_zinc_indicators(bgr, delta=lab_delta)
99
+ classification = classify_from_ratios(
100
+ rustish_ratio=indicators["rustish_ratio"],
101
+ zincish_ratio=indicators["zincish_ratio"],
102
+ rust_thr=rust_thr,
103
+ zinc_thr=zinc_thr
104
+ )
105
+ palette = dominant_colors_kmeans(bgr, k=max(1, k))
106
+
107
+ # 4. Format the response
108
+ response_data = {
109
+ "filename": file.filename,
110
+ "classification": classification,
111
+ "rustish_ratio": round(indicators["rustish_ratio"], 4),
112
+ "zincish_ratio": round(indicators["zincish_ratio"], 4),
113
+ "top_colors_rgb": [p["RGB"] for p in palette],
114
+ "top_colors_share": [round(p["share"], 4) for p in palette]
115
+ }
116
+
117
+ return response_data
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gradio_app.py
2
+ import cv2 as cv
3
+ import numpy as np
4
+ from collections import Counter
5
+ from typing import List, Dict, Any
6
+ import gradio as gr
7
+
8
+ # --- Helper Functions ---
9
+
10
+ # Color space conversions
11
+ def bgr_to_rgb(bgr):
12
+ return cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
13
+
14
+ def bgr_to_lab(bgr):
15
+ return cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
16
+
17
+ # Dominant color extraction using KMeans
18
+ def dominant_colors_kmeans(bgr, k=3, max_iter=10):
19
+ data = bgr.reshape((-1, 3)).astype(np.float32)
20
+ criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
21
+ flags = cv.KMEANS_PP_CENTERS
22
+ _, labels, centers = cv.kmeans(data, k, None, criteria, 3, flags)
23
+ centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)
24
+ counts = Counter(labels.flatten())
25
+ total = float(len(labels))
26
+
27
+ idx_sorted = [i for i, _ in counts.most_common()]
28
+ palette = []
29
+ for idx in idx_sorted:
30
+ bgr_c = centers_u8[idx].tolist()
31
+ rgb_c = bgr_to_rgb(np.array([[bgr_c]], dtype=np.uint8)).reshape(-1).tolist()
32
+ share = counts[idx] / total
33
+ palette.append({"share": float(share), "RGB": [int(x) for x in rgb_c]})
34
+ return palette
35
+
36
+ # Heuristic calculation for rust/zinc
37
+ def rust_zinc_indicators(bgr, delta=6.0):
38
+ lab = bgr_to_lab(bgr)
39
+ _, a, b = cv.split(lab)
40
+ a_med, b_med = np.median(a), np.median(b)
41
+ a_thr = a_med + delta
42
+ b_thr = b_med + delta
43
+
44
+ rustish = (a.astype(np.float32) > a_thr).mean()
45
+ zincish = (b.astype(np.float32) > b_thr).mean()
46
+ return {"rustish_ratio": float(rustish), "zincish_ratio": float(zincish)}
47
+
48
+ # Classification logic
49
+ def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02):
50
+ if zincish_ratio > zinc_thr:
51
+ return "zinc"
52
+ elif rustish_ratio > rust_thr:
53
+ return "rust"
54
+ else:
55
+ return "normal"
56
+
57
+ # --- Gradio Prediction Function ---
58
+ def classify_image_gradio(
59
+ image: np.ndarray,
60
+ k: int = 3,
61
+ rust_thr: float = 0.01,
62
+ zinc_thr: float = 0.02,
63
+ lab_delta: float = 6.0
64
+ ) -> Dict[str, Any]:
65
+ """
66
+ Accepts an image (from Gradio upload) and returns classification and color analysis.
67
+ """
68
+ if image is None:
69
+ return {"error": "No image provided."}
70
+
71
+ # Convert RGB (from Gradio) to BGR for OpenCV
72
+ bgr = cv.cvtColor(image, cv.COLOR_RGB2BGR)
73
+
74
+ # Color analysis
75
+ indicators = rust_zinc_indicators(bgr, delta=lab_delta)
76
+ classification = classify_from_ratios(
77
+ rustish_ratio=indicators["rustish_ratio"],
78
+ zincish_ratio=indicators["zincish_ratio"],
79
+ rust_thr=rust_thr,
80
+ zinc_thr=zinc_thr
81
+ )
82
+ palette = dominant_colors_kmeans(bgr, k=max(1, k))
83
+
84
+ # Format response
85
+ response_data = {
86
+ "classification": classification,
87
+ "rustish_ratio": round(indicators["rustish_ratio"], 4),
88
+ "zincish_ratio": round(indicators["zincish_ratio"], 4),
89
+ "top_colors_rgb": [p["RGB"] for p in palette],
90
+ "top_colors_share": [round(p["share"], 4) for p in palette]
91
+ }
92
+
93
+ return response_data
94
+
95
+ # --- Gradio Interface ---
96
+ iface = gr.Interface(
97
+ fn=classify_image_gradio,
98
+ inputs=[
99
+ gr.Image(type="numpy", label="Upload Image"),
100
+ gr.Slider(1, 10, value=3, label="Number of Dominant Colors (k)"),
101
+ gr.Slider(0.0, 1.0, value=0.01, step=0.01, label="Rust Threshold"),
102
+ gr.Slider(0.0, 1.0, value=0.02, step=0.01, label="Zinc Threshold"),
103
+ gr.Slider(0.0, 20.0, value=6.0, step=0.5, label="Lab Delta")
104
+ ],
105
+ outputs=gr.JSON(label="Classification Result"),
106
+ title="Image Color Classifier",
107
+ description="Upload an image and classify it as 'rust', 'zinc', or 'normal' based on color heuristics."
108
+ )
109
+
110
+ # Launch Gradio app
111
+ if __name__ == "__main__":
112
+ iface.launch()
classify.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, argparse
2
+
3
+ def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02):
4
+ if zincish_ratio > zinc_thr:
5
+ return "zinc"
6
+ elif rustish_ratio > rust_thr:
7
+ return "rust"
8
+ else:
9
+ return "normal"
10
+
11
+ if __name__ == "__main__":
12
+ ap = argparse.ArgumentParser()
13
+ ap.add_argument("--report", required=True, help="path to *_color_report.json from color.py")
14
+ ap.add_argument("--rust_thr", type=float, default=0.01)
15
+ ap.add_argument("--zinc_thr", type=float, default=0.02)
16
+ args = ap.parse_args()
17
+
18
+ with open(args.report, "r") as f:
19
+ rep = json.load(f)
20
+
21
+ rustish_ratio = rep["heuristics"]["rustish_ratio"]
22
+ zincish_ratio = rep["heuristics"]["zincish_ratio"]
23
+
24
+ classification = classify_from_ratios(rustish_ratio, zincish_ratio,
25
+ rust_thr=args.rust_thr,
26
+ zinc_thr=args.zinc_thr)
27
+
28
+ print({
29
+ "file": rep["input"],
30
+ "rustish_ratio": rustish_ratio,
31
+ "zincish_ratio": zincish_ratio,
32
+ "classification": classification
33
+ })
color.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2 as cv
2
+ import numpy as np
3
+ import argparse, os, json
4
+ from collections import Counter
5
+
6
+ def bgr_to_rgb(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
7
+ def bgr_to_hsv(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2HSV)
8
+ def bgr_to_lab(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
9
+
10
+ def img_stats(img, space_name):
11
+ # img is uint8, shape HxWxC
12
+ means = img.reshape(-1, img.shape[2]).mean(axis=0)
13
+ stds = img.reshape(-1, img.shape[2]).std(axis=0)
14
+ return {
15
+ "space": space_name,
16
+ "mean": [float(x) for x in means],
17
+ "std": [float(x) for x in stds]
18
+ }
19
+
20
+ def dominant_colors_kmeans(bgr, k=3, max_iter=10, seed=123):
21
+ # reshape to N x 3
22
+ data = bgr.reshape((-1, 3)).astype(np.float32)
23
+
24
+ # kmeans
25
+ criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
26
+ flags = cv.KMEANS_PP_CENTERS
27
+ compactness, labels, centers = cv.kmeans(data, k, None, criteria, 3, flags)
28
+
29
+ # centers are BGR float; convert to uint8
30
+ centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)
31
+ counts = Counter(labels.flatten())
32
+ total = float(len(labels))
33
+
34
+ # sort by frequency desc
35
+ idx_sorted = [i for i,_ in counts.most_common()]
36
+ palette = []
37
+ for idx in idx_sorted:
38
+ bgr_c = centers_u8[idx].tolist()
39
+ rgb_c = bgr_to_rgb(np.array([[bgr_c]], dtype=np.uint8)).reshape(-1).tolist()
40
+ hsv_c = bgr_to_hsv(np.array([[bgr_c]], dtype=np.uint8)).reshape(-1).tolist()
41
+ lab_c = bgr_to_lab(np.array([[bgr_c]], dtype=np.uint8)).reshape(-1).tolist()
42
+ share = counts[idx] / total
43
+ palette.append({
44
+ "share": float(share),
45
+ "BGR": [int(x) for x in bgr_c],
46
+ "RGB": [int(x) for x in rgb_c],
47
+ "HSV": [int(x) for x in hsv_c],
48
+ "Lab": [int(x) for x in lab_c],
49
+ })
50
+ return palette
51
+
52
+ def make_palette_image(palette, width=600, height=80, pad=2):
53
+ # palette: list of dicts with 'share' and 'RGB'
54
+ img = np.zeros((height, width, 3), dtype=np.uint8)
55
+ x = 0
56
+ for p in palette:
57
+ w = max(1, int(p["share"] * width))
58
+ color = tuple(p["RGB"]) # RGB
59
+ # convert to BGR for OpenCV drawing
60
+ bgr = (int(color[2]), int(color[1]), int(color[0]))
61
+ cv.rectangle(img, (x, 0), (min(width-1, x+w-1), height-1), bgr, -1)
62
+ x += w
63
+ # thin separators
64
+ for i in range(1, len(palette)):
65
+ x_sep = int(sum([pp["share"] for pp in palette[:i]]) * width)
66
+ cv.line(img, (x_sep, 0), (x_sep, height-1), (30,30,30), 1)
67
+ return img
68
+
69
+ def rust_zinc_indicators(bgr):
70
+ """Heuristic only, NO detection claims. Gives ratios based on Lab chroma tendencies:
71
+ - 'rustish_ratio': fraction of pixels with a* significantly above median (reddish/brownish)
72
+ - 'zincish_ratio': fraction of pixels with b* significantly above median (yellowish)
73
+ """
74
+ lab = bgr_to_lab(bgr)
75
+ L, a, b = cv.split(lab)
76
+ a_med, b_med = np.median(a), np.median(b)
77
+ a_thr = a_med + 6 # tweak if needed
78
+ b_thr = b_med + 6
79
+
80
+ rustish = (a.astype(np.float32) > a_thr).mean()
81
+ zincish = (b.astype(np.float32) > b_thr).mean()
82
+ return {"rustish_ratio": float(rustish), "zincish_ratio": float(zincish),
83
+ "a_median": float(a_med), "b_median": float(b_med),
84
+ "a_thresh": float(a_thr), "b_thresh": float(b_thr)}
85
+
86
+ def main():
87
+ ap = argparse.ArgumentParser()
88
+ ap.add_argument("--img", required=True, help="path to image")
89
+ ap.add_argument("--k", type=int, default=3, help="number of dominant colors")
90
+ ap.add_argument("--resize_max", type=int, default=1200, help="resize longer side to this (0=off)")
91
+ ap.add_argument("--outdir", default="color_out")
92
+ args = ap.parse_args()
93
+
94
+ os.makedirs(args.outdir, exist_ok=True)
95
+
96
+ bgr = cv.imread(args.img, cv.IMREAD_COLOR)
97
+ if bgr is None:
98
+ raise RuntimeError(f"Cannot read image: {args.img}")
99
+
100
+ # Optional resize to speed up
101
+ h, w = bgr.shape[:2]
102
+ if args.resize_max > 0:
103
+ s = max(h, w)
104
+ if s > args.resize_max:
105
+ scale = args.resize_max / float(s)
106
+ bgr = cv.resize(bgr, (int(w*scale), int(h*scale)), interpolation=cv.INTER_AREA)
107
+
108
+ # Color-space stats
109
+ rgb = bgr_to_rgb(bgr)
110
+ hsv = bgr_to_hsv(bgr)
111
+ lab = bgr_to_lab(bgr)
112
+
113
+ stats = [
114
+ img_stats(rgb, "RGB"), # channels: R,G,B (0-255)
115
+ img_stats(hsv, "HSV"), # channels: H(0-179), S(0-255), V(0-255) in OpenCV
116
+ img_stats(lab, "Lab"), # channels: L(0-255), a(0-255), b(0-255) in OpenCV's scaled Lab
117
+ ]
118
+
119
+ # Dominant colors (k-means)
120
+ palette = dominant_colors_kmeans(bgr, k=max(1, args.k))
121
+
122
+ # Heuristic indicators (optional)
123
+ indicators = rust_zinc_indicators(bgr)
124
+
125
+ # Save palette image
126
+ pal_img = make_palette_image(palette)
127
+ base = os.path.splitext(os.path.basename(args.img))[0]
128
+ pal_path = os.path.join(args.outdir, f"{base}_palette.png")
129
+ cv.imwrite(pal_path, pal_img)
130
+
131
+ # Build and save JSON
132
+ report = {
133
+ "input": os.path.basename(args.img),
134
+ "size_hw": [int(bgr.shape[0]), int(bgr.shape[1])],
135
+ "color_stats": stats,
136
+ "dominant_colors": palette, # ordered by share desc
137
+ "heuristics": indicators,
138
+ "palette_image": pal_path
139
+ }
140
+ rep_path = os.path.join(args.outdir, f"{base}_color_report.json")
141
+ with open(rep_path, "w") as f:
142
+ json.dump(report, f, indent=2)
143
+
144
+ # Print a short summary to console
145
+ print(json.dumps({
146
+ "input": report["input"],
147
+ "top_colors_rgb": [p["RGB"] for p in report["dominant_colors"]],
148
+ "top_colors_share": [round(p["share"], 4) for p in report["dominant_colors"]],
149
+ "rustish_ratio": round(report["heuristics"]["rustish_ratio"], 4),
150
+ "zincish_ratio": round(report["heuristics"]["zincish_ratio"], 4),
151
+ "report_path": rep_path,
152
+ "palette_image": pal_path
153
+ }, indent=2))
154
+
155
+ if __name__ == "__main__":
156
+ main()
main.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2 as cv
2
+ import numpy as np
3
+ import argparse, os, json
4
+ from collections import Counter
5
+
6
+ # ---------------- Conversions ----------------
7
+ def bgr_to_rgb(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
8
+ def bgr_to_hsv(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2HSV)
9
+ def bgr_to_lab(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2LAB)
10
+
11
+ # ---------------- Stats ----------------
12
+ def img_stats(img, space_name):
13
+ # img is uint8, shape HxWxC
14
+ means = img.reshape(-1, img.shape[2]).mean(axis=0)
15
+ stds = img.reshape(-1, img.shape[2]).std(axis=0)
16
+ return {
17
+ "space": space_name,
18
+ "mean": [float(x) for x in means],
19
+ "std": [float(x) for x in stds]
20
+ }
21
+
22
+ # ---------------- Dominant colors ----------------
23
+ def dominant_colors_kmeans(bgr, k=3, max_iter=10):
24
+ data = bgr.reshape((-1, 3)).astype(np.float32)
25
+ criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
26
+ flags = cv.KMEANS_PP_CENTERS
27
+ compactness, labels, centers = cv.kmeans(data, k, None, criteria, 3, flags)
28
+ centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)
29
+ counts = Counter(labels.flatten())
30
+ total = float(len(labels))
31
+
32
+ idx_sorted = [i for i,_ in counts.most_common()]
33
+ palette = []
34
+ for idx in idx_sorted:
35
+ bgr_c = centers_u8[idx].tolist()
36
+ rgb_c = bgr_to_rgb(np.array([[bgr_c]], dtype=np.uint8)).reshape(-1).tolist()
37
+ hsv_c = bgr_to_hsv(np.array([[bgr_c]], dtype=np.uint8)).reshape(-1).tolist()
38
+ lab_c = bgr_to_lab(np.array([[bgr_c]], dtype=np.uint8)).reshape(-1).tolist()
39
+ share = counts[idx] / total
40
+ palette.append({
41
+ "share": float(share),
42
+ "BGR": [int(x) for x in bgr_c],
43
+ "RGB": [int(x) for x in rgb_c],
44
+ "HSV": [int(x) for x in hsv_c],
45
+ "Lab": [int(x) for x in lab_c],
46
+ })
47
+ return palette
48
+
49
+ def make_palette_image(palette, width=600, height=80):
50
+ img = np.zeros((height, width, 3), dtype=np.uint8)
51
+ x = 0
52
+ for p in palette:
53
+ w = max(1, int(p["share"] * width))
54
+ r,g,b = p["RGB"] # stored as RGB
55
+ cv.rectangle(img, (x, 0), (min(width-1, x+w-1), height-1), (b,g,r), -1) # convert to BGR for draw
56
+ x += w
57
+ for i in range(1, len(palette)):
58
+ x_sep = int(sum([pp["share"] for pp in palette[:i]]) * width)
59
+ cv.line(img, (x_sep, 0), (x_sep, height-1), (30,30,30), 1)
60
+ return img
61
+
62
+ # ---------------- Heuristics ----------------
63
+ def rust_zinc_indicators(bgr, delta=6):
64
+ """
65
+ Heuristic only. Uses Lab:
66
+ - rustish_ratio: fraction of pixels with a* > median(a*) + delta
67
+ - zincish_ratio: fraction of pixels with b* > median(b*) + delta
68
+ """
69
+ lab = bgr_to_lab(bgr)
70
+ L, a, b = cv.split(lab)
71
+ a_med, b_med = np.median(a), np.median(b)
72
+ a_thr = a_med + delta
73
+ b_thr = b_med + delta
74
+
75
+ rustish = (a.astype(np.float32) > a_thr).mean()
76
+ zincish = (b.astype(np.float32) > b_thr).mean()
77
+ return {
78
+ "rustish_ratio": float(rustish),
79
+ "zincish_ratio": float(zincish),
80
+ "a_median": float(a_med),
81
+ "b_median": float(b_med),
82
+ "a_thresh": float(a_thr),
83
+ "b_thresh": float(b_thr),
84
+ "delta": float(delta)
85
+ }
86
+
87
+ # ---------------- Classification ----------------
88
+ def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.002, zinc_thr=0.01):
89
+ """
90
+ Your rule:
91
+ - zinc if zincish_ratio > 0.01
92
+ - else rust if rustish_ratio > 0.002
93
+ - else normal
94
+ """
95
+ if zincish_ratio > zinc_thr:
96
+ return "zinc"
97
+ elif rustish_ratio > rust_thr:
98
+ return "rust"
99
+ else:
100
+ return "normal"
101
+
102
+ # ---------------- Main ----------------
103
+ def main():
104
+ ap = argparse.ArgumentParser()
105
+ ap.add_argument("--img", required=True, help="path to image")
106
+ ap.add_argument("--k", type=int, default=3, help="number of dominant colors")
107
+ ap.add_argument("--resize_max", type=int, default=1200, help="resize longer side to this (0=off)")
108
+ ap.add_argument("--outdir", default="color_out")
109
+ # thresholds you defined:
110
+ ap.add_argument("--rust_thr", type=float, default=0.01)
111
+ ap.add_argument("--zinc_thr", type=float, default=0.02)
112
+ # indicator sensitivity (Lab delta)
113
+ ap.add_argument("--lab_delta", type=float, default=6.0)
114
+ args = ap.parse_args()
115
+
116
+ os.makedirs(args.outdir, exist_ok=True)
117
+
118
+ bgr = cv.imread(args.img, cv.IMREAD_COLOR)
119
+ if bgr is None:
120
+ raise RuntimeError(f"Cannot read image: {args.img}")
121
+
122
+ # optional resize
123
+ h, w = bgr.shape[:2]
124
+ if args.resize_max > 0:
125
+ s = max(h, w)
126
+ if s > args.resize_max:
127
+ scale = args.resize_max / float(s)
128
+ bgr = cv.resize(bgr, (int(w*scale), int(h*scale)), interpolation=cv.INTER_AREA)
129
+
130
+ # color stats
131
+ rgb = bgr_to_rgb(bgr)
132
+ hsv = bgr_to_hsv(bgr)
133
+ lab = bgr_to_lab(bgr)
134
+ stats = [
135
+ img_stats(rgb, "RGB"),
136
+ img_stats(hsv, "HSV"),
137
+ img_stats(lab, "Lab"),
138
+ ]
139
+
140
+ # dominant colors
141
+ palette = dominant_colors_kmeans(bgr, k=max(1, args.k))
142
+
143
+ # heuristics
144
+ indicators = rust_zinc_indicators(bgr, delta=args.lab_delta)
145
+
146
+ # classification using your thresholds
147
+ cls = classify_from_ratios(
148
+ rustish_ratio=indicators["rustish_ratio"],
149
+ zincish_ratio=indicators["zincish_ratio"],
150
+ rust_thr=args.rust_thr,
151
+ zinc_thr=args.zinc_thr
152
+ )
153
+
154
+ # save palette image
155
+ base = os.path.splitext(os.path.basename(args.img))[0]
156
+ pal_img = make_palette_image(palette)
157
+ pal_path = os.path.join(args.outdir, f"{base}_palette.png")
158
+ cv.imwrite(pal_path, pal_img)
159
+
160
+ # build + save JSON
161
+ report = {
162
+ "input": os.path.basename(args.img),
163
+ "size_hw": [int(bgr.shape[0]), int(bgr.shape[1])],
164
+ "color_stats": stats,
165
+ "dominant_colors": palette, # ordered by share desc
166
+ "heuristics": indicators,
167
+ "classification": cls,
168
+ "thresholds": {"rust_thr": args.rust_thr, "zinc_thr": args.zinc_thr},
169
+ "palette_image": pal_path
170
+ }
171
+ rep_path = os.path.join(args.outdir, f"{base}_color_report.json")
172
+ with open(rep_path, "w") as f:
173
+ json.dump(report, f, indent=2)
174
+
175
+ # console summary
176
+ print(json.dumps({
177
+ "input": report["input"],
178
+ "classification": cls,
179
+ "rustish_ratio": round(indicators["rustish_ratio"], 4),
180
+ "zincish_ratio": round(indicators["zincish_ratio"], 4),
181
+ "top_colors_rgb": [p["RGB"] for p in palette],
182
+ "top_colors_share": [round(p["share"], 4) for p in palette],
183
+ "report_path": rep_path,
184
+ "palette_image": pal_path
185
+ }, indent=2))
186
+
187
+
188
+
189
+ if __name__ == "__main__":
190
+ main()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ "fastapi[all]"
2
+ opencv-python-headless
3
+ numpy
4
+ gardio