File size: 4,230 Bytes
9ff9b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# api.py
import json
from collections import Counter
from typing import List

import cv2 as cv
import numpy as np
from fastapi import FastAPI, File, UploadFile, Query, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

# --- Initialize FastAPI App ---
# Title and description are shown in the generated OpenAPI docs.
app = FastAPI(
    title="Image Color Classifier API",
    description=(
        "Upload an image to classify it as 'rust', 'zinc', or 'normal' "
        "based on color heuristics."
    ),
)

# --- Define the Response Model (for OpenAPI docs and validation) ---
# Field docs live in comments: pydantic would publish a class docstring as the
# schema description in the OpenAPI output.
class ClassificationResponse(BaseModel):
    filename: str  # name of the uploaded file, as reported by the UploadFile
    classification: str  # "zinc", "rust" or "normal", as produced by classify_from_ratios
    rustish_ratio: float  # fraction of pixels whose Lab 'a' channel exceeds its median + lab_delta (rounded to 4 dp)
    zincish_ratio: float  # fraction of pixels whose Lab 'b' channel exceeds its median + lab_delta (rounded to 4 dp)
    top_colors_rgb: List[List[int]]  # k-means cluster centres as [R, G, B], most common cluster first
    top_colors_share: List[float]  # fraction of pixels in each cluster, same order as top_colors_rgb (rounded to 4 dp)


# --- Helper Functions (copied from main.py) ---

# Color space conversions
def bgr_to_rgb(bgr):
    """Convert an OpenCV BGR image to RGB channel order."""
    conversion = cv.COLOR_BGR2RGB
    return cv.cvtColor(bgr, conversion)


def bgr_to_hsv(bgr):
    """Convert an OpenCV BGR image to the HSV colour space."""
    conversion = cv.COLOR_BGR2HSV
    return cv.cvtColor(bgr, conversion)


def bgr_to_lab(bgr):
    """Convert an OpenCV BGR image to the Lab colour space."""
    conversion = cv.COLOR_BGR2LAB
    return cv.cvtColor(bgr, conversion)

# Dominant color extraction using KMeans
def dominant_colors_kmeans(bgr, k=3, max_iter=10):
    """Find the dominant colours of an image with k-means clustering.

    Args:
        bgr: 3-channel image in OpenCV BGR order; it is flattened to an
            (N, 3) array of float32 samples.
        k: Requested number of clusters. Clamped to ``[1, number of pixels]``
            because ``cv.kmeans`` needs at least ``k`` samples.
        max_iter: Maximum number of k-means iterations per attempt.

    Returns:
        A list of ``{"share": float, "RGB": [r, g, b]}`` dicts. There is one
        entry per cluster that received at least one pixel. Entries are sorted
        by descending ``share``; ties keep cluster-index order. An image with
        no pixels yields an empty list.
    """
    data = bgr.reshape((-1, 3)).astype(np.float32)
    n_samples = data.shape[0]
    if n_samples == 0:
        return []
    k = max(1, min(int(k), n_samples))

    criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0)
    # 3 attempts with k-means++ seeding; OpenCV returns the most compact run.
    _, labels, centers = cv.kmeans(data, k, None, criteria, 3, cv.KMEANS_PP_CENTERS)
    centers_u8 = np.clip(centers, 0, 255).astype(np.uint8)

    # np.bincount tallies cluster sizes in C. Counting the label array with a
    # Counter would iterate every pixel in Python, which is slow on large images.
    counts = np.bincount(labels.ravel(), minlength=len(centers_u8))
    total = float(labels.size)

    palette = []
    for idx in np.argsort(-counts, kind="stable"):
        if counts[idx] == 0:
            continue  # clusters that received no pixels are not reported
        # Centres are in BGR order; emit them as RGB.
        blue, green, red = (int(v) for v in centers_u8[idx])
        palette.append({"share": float(counts[idx] / total), "RGB": [red, green, blue]})
    return palette

# Heuristic calculation for rust/zinc
def rust_zinc_indicators(bgr, delta=6.0):
    """Compute the rust/zinc colour indicators for a BGR image.

    The image is converted to Lab. For the ``a`` channel and the ``b`` channel
    independently, a pixel counts when its value exceeds that channel's
    image-wide median plus ``delta``.

    Args:
        bgr: Image in OpenCV BGR order.
        delta: Offset added to each channel's median to form its threshold.

    Returns:
        ``{"rustish_ratio": float, "zincish_ratio": float}``: the fraction of
        pixels above the ``a`` threshold and the ``b`` threshold respectively.
    """
    _, a_channel, b_channel = cv.split(bgr_to_lab(bgr))

    def fraction_above_median(channel):
        threshold = np.median(channel) + delta
        # Compare in float32 against the (possibly fractional) threshold.
        return float((channel.astype(np.float32) > threshold).mean())

    return {
        "rustish_ratio": fraction_above_median(a_channel),
        "zincish_ratio": fraction_above_median(b_channel),
    }

# Classification logic
def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02):
    """Map the heuristic ratios to a class label.

    Zinc takes precedence: ``"zinc"`` is returned whenever ``zincish_ratio``
    strictly exceeds ``zinc_thr``. In that case ``rustish_ratio`` is not
    examined at all. Otherwise ``"rust"`` is returned if ``rustish_ratio``
    strictly exceeds ``rust_thr``, and ``"normal"`` if neither does.
    """
    return (
        "zinc" if zincish_ratio > zinc_thr
        else "rust" if rustish_ratio > rust_thr
        else "normal"
    )

# --- API Endpoint ---

def _analyze_image_bytes(contents, k, rust_thr, zinc_thr, lab_delta):
    """Decode uploaded image bytes and run the colour analysis synchronously.

    This is kept separate from the endpoint so the CPU-bound work (decoding,
    Lab conversion, k-means) can run in a worker thread rather than on the
    event loop.

    Returns:
        ``None`` if the bytes cannot be decoded as an image; otherwise a tuple
        ``(indicators, classification, palette)``.
    """
    nparr = np.frombuffer(contents, np.uint8)
    try:
        bgr = cv.imdecode(nparr, cv.IMREAD_COLOR)
    except cv.error:
        # Treat an OpenCV decoding error like any other undecodable upload.
        bgr = None
    if bgr is None:
        return None

    indicators = rust_zinc_indicators(bgr, delta=lab_delta)
    classification = classify_from_ratios(
        rustish_ratio=indicators["rustish_ratio"],
        zincish_ratio=indicators["zincish_ratio"],
        rust_thr=rust_thr,
        zinc_thr=zinc_thr,
    )
    # cv.kmeans needs at least k samples, so never request more clusters than pixels.
    n_pixels = bgr.shape[0] * bgr.shape[1]
    palette = dominant_colors_kmeans(bgr, k=min(max(1, k), n_pixels))
    return indicators, classification, palette


@app.post("/classify/", response_model=ClassificationResponse)
async def classify_image(
    file: UploadFile = File(..., description="The image file to classify."),
    k: int = Query(3, description="Number of dominant colors to extract."),
    rust_thr: float = Query(0.01, description="Threshold for 'rust' classification."),
    zinc_thr: float = Query(0.02, description="Threshold for 'zinc' classification."),
    lab_delta: float = Query(6.0, description="Sensitivity for heuristic indicators in Lab color space.")
):
    """
    Accepts an image file and returns a classification based on color analysis.
    """
    # (The docstring above is published as the OpenAPI description, so internal
    # notes are kept in comments.)
    # Responds 400 when the upload is empty or cannot be decoded as an image.

    # 1. Read image bytes from the upload.
    contents = await file.read()
    if not contents:
        # cv.imdecode asserts on an empty buffer, which would surface as an HTTP 500.
        raise HTTPException(status_code=400, detail="Invalid image file. The uploaded file is empty.")

    # 2-3. Decode, analyse and classify in a worker thread. Image decoding and
    # k-means are CPU-bound; run inline in this coroutine they would block the
    # event loop, and with it every other request, until they finish.
    result = await run_in_threadpool(
        _analyze_image_bytes, contents, k, rust_thr, zinc_thr, lab_delta
    )
    if result is None:
        raise HTTPException(status_code=400, detail="Invalid image file. Could not decode image.")
    indicators, classification, palette = result

    # 4. Format the response (ratios and shares rounded to 4 decimal places).
    return {
        "filename": file.filename,
        "classification": classification,
        "rustish_ratio": round(indicators["rustish_ratio"], 4),
        "zincish_ratio": round(indicators["zincish_ratio"], 4),
        "top_colors_rgb": [p["RGB"] for p in palette],
        "top_colors_share": [round(p["share"], 4) for p in palette],
    }