Spaces:
Sleeping
Sleeping
| # api.py | |
| import cv2 as cv | |
| import numpy as np | |
| import json | |
| from collections import Counter | |
| from typing import List | |
| from fastapi import FastAPI, File, UploadFile, Query, HTTPException | |
| from pydantic import BaseModel | |
| # --- Initialize FastAPI App --- | |
| app = FastAPI( | |
| title="Image Color Classifier API", | |
| description="Upload an image to classify it as 'rust', 'zinc', or 'normal' based on color heuristics." | |
| ) | |
| # --- Define the Response Model (for OpenAPI docs and validation) --- | |
| class ClassificationResponse(BaseModel): | |
| filename: str | |
| classification: str | |
| rustish_ratio: float | |
| zincish_ratio: float | |
| top_colors_rgb: List[List[int]] | |
| top_colors_share: List[float] | |
| # --- Helper Functions (Copied from your main.py) --- | |
| # Color space conversions | |
| def bgr_to_rgb(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2RGB) | |
| def bgr_to_hsv(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2HSV) | |
| def bgr_to_lab(bgr): return cv.cvtColor(bgr, cv.COLOR_BGR2LAB) | |
| # Dominant color extraction using KMeans | |
| def dominant_colors_kmeans(bgr, k=3, max_iter=10): | |
| data = bgr.reshape((-1, 3)).astype(np.float32) | |
| criteria = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, max_iter, 1.0) | |
| flags = cv.KMEANS_PP_CENTERS | |
| _, labels, centers = cv.kmeans(data, k, None, criteria, 3, flags) | |
| centers_u8 = np.clip(centers, 0, 255).astype(np.uint8) | |
| counts = Counter(labels.flatten()) | |
| total = float(len(labels)) | |
| idx_sorted = [i for i, _ in counts.most_common()] | |
| palette = [] | |
| for idx in idx_sorted: | |
| bgr_c = centers_u8[idx].tolist() | |
| rgb_c = bgr_to_rgb(np.array([[bgr_c]], dtype=np.uint8)).reshape(-1).tolist() | |
| share = counts[idx] / total | |
| palette.append({"share": float(share), "RGB": [int(x) for x in rgb_c]}) | |
| return palette | |
| # Heuristic calculation for rust/zinc | |
| def rust_zinc_indicators(bgr, delta=6.0): | |
| lab = bgr_to_lab(bgr) | |
| _, a, b = cv.split(lab) | |
| a_med, b_med = np.median(a), np.median(b) | |
| a_thr = a_med + delta | |
| b_thr = b_med + delta | |
| rustish = (a.astype(np.float32) > a_thr).mean() | |
| zincish = (b.astype(np.float32) > b_thr).mean() | |
| return {"rustish_ratio": float(rustish), "zincish_ratio": float(zincish)} | |
| # Classification logic | |
| def classify_from_ratios(rustish_ratio, zincish_ratio, rust_thr=0.01, zinc_thr=0.02): | |
| if zincish_ratio > zinc_thr: | |
| return "zinc" | |
| elif rustish_ratio > rust_thr: | |
| return "rust" | |
| else: | |
| return "normal" | |
| # --- API Endpoint --- | |
| async def classify_image( | |
| file: UploadFile = File(..., description="The image file to classify."), | |
| k: int = Query(3, description="Number of dominant colors to extract."), | |
| rust_thr: float = Query(0.01, description="Threshold for 'rust' classification."), | |
| zinc_thr: float = Query(0.02, description="Threshold for 'zinc' classification."), | |
| lab_delta: float = Query(6.0, description="Sensitivity for heuristic indicators in Lab color space.") | |
| ): | |
| """ | |
| Accepts an image file and returns a classification based on color analysis. | |
| """ | |
| # 1. Read image bytes from upload | |
| contents = await file.read() | |
| # 2. Convert bytes to a NumPy array and then to an OpenCV image | |
| nparr = np.frombuffer(contents, np.uint8) | |
| bgr = cv.imdecode(nparr, cv.IMREAD_COLOR) | |
| if bgr is None: | |
| raise HTTPException(status_code=400, detail="Invalid image file. Could not decode image.") | |
| # 3. Perform color analysis and classification | |
| indicators = rust_zinc_indicators(bgr, delta=lab_delta) | |
| classification = classify_from_ratios( | |
| rustish_ratio=indicators["rustish_ratio"], | |
| zincish_ratio=indicators["zincish_ratio"], | |
| rust_thr=rust_thr, | |
| zinc_thr=zinc_thr | |
| ) | |
| palette = dominant_colors_kmeans(bgr, k=max(1, k)) | |
| # 4. Format the response | |
| response_data = { | |
| "filename": file.filename, | |
| "classification": classification, | |
| "rustish_ratio": round(indicators["rustish_ratio"], 4), | |
| "zincish_ratio": round(indicators["zincish_ratio"], 4), | |
| "top_colors_rgb": [p["RGB"] for p in palette], | |
| "top_colors_share": [round(p["share"], 4) for p in palette] | |
| } | |
| return response_data |