File size: 6,614 Bytes
9ff9b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import cv2 as cv
import numpy as np
import argparse, os, json
from collections import Counter

# ---------------- Conversions ----------------
def bgr_to_rgb(bgr):
    """Return *bgr* (OpenCV BGR channel order) converted to RGB order."""
    code = cv.COLOR_BGR2RGB
    return cv.cvtColor(bgr, code)


def bgr_to_hsv(bgr):
    """Return *bgr* converted to OpenCV's HSV representation."""
    code = cv.COLOR_BGR2HSV
    return cv.cvtColor(bgr, code)


def bgr_to_lab(bgr):
    """Return *bgr* converted to OpenCV's CIE L*a*b* representation."""
    code = cv.COLOR_BGR2LAB
    return cv.cvtColor(bgr, code)

# ---------------- Stats ----------------
def img_stats(img, space_name):
    """Compute per-channel mean and standard deviation of an image.

    Args:
        img: pixel array, either HxWxC (multi-channel) or HxW (treated as a
            single channel). The callers in this file pass uint8 images.
        space_name: label stored under "space" in the result (e.g. "RGB").

    Returns:
        dict with keys "space", "mean" and "std"; "mean" and "std" are lists
        of Python floats, one entry per channel.
    """
    # A 2-D array has no channel axis; treat it as one channel instead of
    # failing on img.shape[2].
    channels = 1 if img.ndim == 2 else img.shape[2]
    # Flatten once to (num_pixels, channels) and reuse for both reductions.
    pixels = img.reshape(-1, channels)
    means = pixels.mean(axis=0)
    stds = pixels.std(axis=0)
    return {
        "space": space_name,
        "mean": [float(x) for x in means],
        "std": [float(x) for x in stds],
    }

# ---------------- Dominant colors ----------------
def dominant_colors_kmeans(bgr, k=3, max_iter=10):
    """Find the dominant colours of a BGR image with k-means clustering.

    Args:
        bgr: uint8 image whose pixels are BGR triples (reshaped to N x 3).
        k: requested number of clusters. It is clamped to
            [1, number of pixels] because cv.kmeans rejects K < 1 and
            K greater than the number of samples.
        max_iter: maximum k-means iterations per attempt.

    Returns:
        A list of dicts, one per non-empty cluster, ordered by pixel share
        (largest first). Each dict has:
            "share": fraction of pixels assigned to the cluster,
            "BGR", "RGB", "HSV", "Lab": the cluster centre as ints in
            OpenCV's 8-bit encodings of those colour spaces.

    Raises:
        ValueError: if the image contains no pixels.
    """
    data = bgr.reshape((-1, 3)).astype(np.float32)
    n_pixels = data.shape[0]
    if n_pixels == 0:
        raise ValueError("dominant_colors_kmeans: image has no pixels")
    # cv.kmeans errors out for K < 1 or K > number of samples (tiny images).
    k = max(1, min(int(k), n_pixels))

    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
    # 3 attempts with k-means++ seeding; compactness itself is not needed.
    _compactness, labels, centers = cv.kmeans(
        data, k, None, criteria, 3, cv.KMEANS_PP_CENTERS
    )
    # Centres are float32; clip then truncate to uint8 pixel values.
    centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)

    # Convert all centres in one call per colour space by viewing them as a
    # 1 x k image, instead of converting each centre separately.
    centers_img = centers_u8.reshape(1, -1, 3)
    rgb_all = bgr_to_rgb(centers_img).reshape(-1, 3)
    hsv_all = bgr_to_hsv(centers_img).reshape(-1, 3)
    lab_all = bgr_to_lab(centers_img).reshape(-1, 3)

    counts = Counter(labels.flatten())
    total = float(len(labels))

    palette = []
    # most_common() yields clusters by descending pixel count.
    for idx, count in counts.most_common():
        palette.append({
            "share": float(count / total),
            "BGR": [int(x) for x in centers_u8[idx]],
            "RGB": [int(x) for x in rgb_all[idx]],
            "HSV": [int(x) for x in hsv_all[idx]],
            "Lab": [int(x) for x in lab_all[idx]],
        })
    return palette

def make_palette_image(palette, width=600, height=80):
    """Render a palette as a horizontal bar of colour segments.

    Each entry occupies a horizontal span proportional to its "share".
    Segment edges come from rounded cumulative shares, so shares summing to
    1 fill the whole width with no truncation gap. Every drawn entry gets at
    least one pixel column. A 1-px dark separator, BGR (30, 30, 30), is
    drawn at the start of every segment after the first.

    Args:
        palette: sequence of dicts with "share" (float) and "RGB"
            (three ints in 0..255), e.g. the output of dominant_colors_kmeans.
        width: image width in pixels.
        height: image height in pixels.

    Returns:
        uint8 array of shape (height, width, 3) in BGR channel order.
        Columns not covered by any segment stay black.
    """
    img = np.zeros((height, width, 3), dtype=np.uint8)
    cumulative = 0.0
    x0 = 0
    segment_starts = []
    for entry in palette:
        cumulative += entry["share"]
        # Round the cumulative edge, not each width, so truncation errors
        # don't accumulate; keep at least 1 px per segment.
        x1 = min(width, max(x0 + 1, int(round(cumulative * width))))
        if x1 <= x0:
            break  # no columns left for the remaining entries
        r, g, b = entry["RGB"]  # stored as RGB; the image is BGR
        img[:, x0:x1] = (b, g, r)
        segment_starts.append(x0)
        x0 = x1
    # Separators sit exactly on segment boundaries (first column of each
    # segment after the first).
    for x_sep in segment_starts[1:]:
        img[:, x_sep] = (30, 30, 30)
    return img

# ---------------- Heuristics ----------------
def rust_zinc_indicators(bgr, delta=6):
    """
    Heuristic only. Uses Lab:
    - rustish_ratio: fraction of pixels with a* > median(a*) + delta
    - zincish_ratio: fraction of pixels with b* > median(b*) + delta

    The a*/b* channels are taken from OpenCV's 8-bit Lab conversion, so the
    medians and thresholds are in that encoding's units, not raw CIE values.
    Returns a dict with both ratios plus the medians, thresholds and delta
    used, all as floats.
    """
    lab_img = bgr_to_lab(bgr)
    # Lightness is not used by this heuristic.
    _, a_chan, b_chan = cv.split(lab_img)

    a_median = np.median(a_chan)
    b_median = np.median(b_chan)
    a_threshold = a_median + delta
    b_threshold = b_median + delta

    # Mean of a boolean mask == fraction of pixels above the threshold.
    rust_fraction = (a_chan.astype(np.float32) > a_threshold).mean()
    zinc_fraction = (b_chan.astype(np.float32) > b_threshold).mean()

    result = dict(
        rustish_ratio=float(rust_fraction),
        zincish_ratio=float(zinc_fraction),
        a_median=float(a_median),
        b_median=float(b_median),
        a_thresh=float(a_threshold),
        b_thresh=float(b_threshold),
        delta=float(delta),
    )
    return result

# ---------------- Classification ----------------
def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.002, zinc_thr=0.01):
    """
    Map the indicator ratios to a label.

    Rules are checked in order with strict comparisons:
    - "zinc"   if zincish_ratio > zinc_thr
    - "rust"   if rustish_ratio > rust_thr
    - "normal" otherwise
    The zinc rule therefore takes precedence when both ratios exceed their
    thresholds. The defaults (rust_thr=0.002, zinc_thr=0.01) apply only when
    the caller does not pass thresholds explicitly.
    """
    ordered_rules = (
        ("zinc", zincish_ratio, zinc_thr),
        ("rust", rustish_ratio, rust_thr),
    )
    for label, ratio, threshold in ordered_rules:
        if ratio > threshold:
            return label
    return "normal"

# ---------------- Main ----------------
def main():
    """Command-line entry point: analyse one image's colours and write a report.

    Steps:
    - Read ``--img`` and, if ``--resize_max`` > 0, downscale it so the longer
      side is at most that many pixels.
    - Compute per-channel statistics in RGB/HSV/Lab.
    - Extract ``--k`` dominant colours.
    - Compute the Lab rust/zinc indicators and classify the image.

    Outputs:
    - ``<outdir>/<base>_palette.png``.
    - ``<outdir>/<base>_color_report.json``.
    - A short JSON summary printed to stdout.

    Raises:
        RuntimeError: if the input image cannot be read or the palette image
            cannot be written.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--img", required=True, help="path to image")
    ap.add_argument("--k", type=int, default=3, help="number of dominant colors")
    ap.add_argument("--resize_max", type=int, default=1200, help="resize longer side to this (0=off)")
    ap.add_argument("--outdir", default="color_out")
    # thresholds you defined:
    # NOTE(review): these CLI defaults (rust 0.01 / zinc 0.02) differ from the
    # defaults of classify_from_ratios (rust_thr=0.002, zinc_thr=0.01) —
    # confirm which pair is intended.
    ap.add_argument("--rust_thr", type=float, default=0.01)
    ap.add_argument("--zinc_thr", type=float, default=0.02)
    # indicator sensitivity (Lab delta)
    ap.add_argument("--lab_delta", type=float, default=6.0)
    args = ap.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    bgr = cv.imread(args.img, cv.IMREAD_COLOR)
    if bgr is None:
        raise RuntimeError(f"Cannot read image: {args.img}")

    # optional resize (only ever shrinks)
    h, w = bgr.shape[:2]
    if args.resize_max > 0:
        s = max(h, w)
        if s > args.resize_max:
            scale = args.resize_max / float(s)
            # Clamp to >= 1 px: with extreme aspect ratios the short side can
            # truncate to 0, which cv.resize rejects.
            new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
            bgr = cv.resize(bgr, new_size, interpolation=cv.INTER_AREA)

    # color stats
    rgb = bgr_to_rgb(bgr)
    hsv = bgr_to_hsv(bgr)
    lab = bgr_to_lab(bgr)
    stats = [
        img_stats(rgb, "RGB"),
        img_stats(hsv, "HSV"),
        img_stats(lab, "Lab"),
    ]

    # dominant colors
    palette = dominant_colors_kmeans(bgr, k=max(1, args.k))

    # heuristics
    indicators = rust_zinc_indicators(bgr, delta=args.lab_delta)

    # classification using the CLI thresholds
    cls = classify_from_ratios(
        rustish_ratio=indicators["rustish_ratio"],
        zincish_ratio=indicators["zincish_ratio"],
        rust_thr=args.rust_thr,
        zinc_thr=args.zinc_thr,
    )

    # save palette image; cv.imwrite reports failure via its return value
    # rather than raising, so check it explicitly.
    base = os.path.splitext(os.path.basename(args.img))[0]
    pal_img = make_palette_image(palette)
    pal_path = os.path.join(args.outdir, f"{base}_palette.png")
    if not cv.imwrite(pal_path, pal_img):
        raise RuntimeError(f"Cannot write palette image: {pal_path}")

    # build + save JSON
    report = {
        "input": os.path.basename(args.img),
        "size_hw": [int(bgr.shape[0]), int(bgr.shape[1])],
        "color_stats": stats,
        "dominant_colors": palette,    # ordered by share desc
        "heuristics": indicators,
        "classification": cls,
        "thresholds": {"rust_thr": args.rust_thr, "zinc_thr": args.zinc_thr},
        "palette_image": pal_path,
    }
    rep_path = os.path.join(args.outdir, f"{base}_color_report.json")
    with open(rep_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    # console summary
    print(json.dumps({
        "input": report["input"],
        "classification": cls,
        "rustish_ratio": round(indicators["rustish_ratio"], 4),
        "zincish_ratio": round(indicators["zincish_ratio"], 4),
        "top_colors_rgb": [p["RGB"] for p in palette],
        "top_colors_share": [round(p["share"], 4) for p in palette],
        "report_path": rep_path,
        "palette_image": pal_path,
    }, indent=2))



# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()