File size: 2,754 Bytes
7855e61
4f3bee8
 
26e6ab2
4f3bee8
 
7855e61
24e4dba
26e6ab2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f3bee8
26e6ab2
 
 
 
4f3bee8
 
 
26e6ab2
 
24e4dba
26e6ab2
4f3bee8
 
 
26e6ab2
 
 
 
24e4dba
4f3bee8
 
26e6ab2
 
 
 
4f3bee8
 
 
26e6ab2
 
 
 
 
 
 
 
 
4f3bee8
 
 
7855e61
26e6ab2
 
24e4dba
4f3bee8
24e4dba
4f3bee8
 
26e6ab2
24e4dba
7855e61
24e4dba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from sklearn.cluster import KMeans


# ---- Model Loading Fix for PyTorch 2.6 ----
def load_yolo_model(weights_path="best.pt"):
    """
    Load a YOLO model, tolerating PyTorch >= 2.6's ``weights_only`` default.

    First tries the normal ``YOLO(weights_path)`` constructor. If that raises
    (for example because ``torch.load`` rejects the pickled model class under
    ``weights_only=True``), the checkpoint is loaded manually with
    ``weights_only=False`` and its network is attached to a fresh ``YOLO``
    wrapper.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects and can run
    code embedded in the file. Only use the fallback with trusted checkpoints.

    Args:
        weights_path: Path to the ``.pt`` checkpoint.

    Returns:
        The loaded ``YOLO`` model.

    Raises:
        KeyError: if the manually loaded checkpoint holds neither an ``"ema"``
            nor a ``"model"`` network.
        Any exception ``torch.load`` raises if the fallback load also fails
        (e.g. the file does not exist).
    """
    try:
        # Try normal YOLO loading
        return YOLO(weights_path)
    except Exception as e:
        print(f"[WARN] Normal YOLO load failed: {e}")
        print("[INFO] Retrying with torch.load(weights_only=False)...")

    # Trusted-file-only: this unpickles arbitrary Python objects.
    ckpt = torch.load(weights_path, map_location="cpu", weights_only=False)

    if isinstance(ckpt, dict):
        # Ultralytics training checkpoints keep the network under "ema" and may
        # leave "model" as None (e.g. an un-stripped best.pt from an interrupted
        # run), so prefer "ema" and fall back to "model".
        net = ckpt.get("ema")
        if net is None:
            net = ckpt.get("model")
    else:
        # The file is a bare pickled nn.Module rather than a checkpoint dict.
        net = ckpt

    if net is None:
        raise KeyError(
            f"Checkpoint {weights_path!r} contains no 'ema' or 'model' network"
        )

    # NOTE(review): a no-arg YOLO() first loads Ultralytics' default weights
    # (which may trigger a download) before its .model is replaced below.
    model = YOLO()
    # Replace the wrapped network with the full nn.Module from the checkpoint
    # (not a state dict). Cast to fp32 because training checkpoints are stored
    # in half precision, and switch to inference mode.
    model.model = net.float().eval()
    return model


# Build the detector once at import time; point weights_path at your own trained checkpoint.
model = load_yolo_model(weights_path="best.pt")


# ---- Color Detection Helper ----
def dominant_color_bgr(bgr_img, k=3):
    """Return the dominant color of a BGR image crop as an ``(R, G, B)`` tuple.

    Pixels are clustered with k-means. The center of the most populated
    cluster is returned, rounded to the nearest integer per channel.

    Args:
        bgr_img: BGR image array (OpenCV channel order), or ``None``.
        k: Requested number of clusters. It is capped at the number of pixels,
            because ``KMeans`` raises when ``n_clusters > n_samples`` (tiny crops).

    Returns:
        A tuple of three ints ``(R, G, B)``; ``(0, 0, 0)`` for a ``None`` or
        empty image.
    """
    if bgr_img is None or bgr_img.size == 0:
        return (0, 0, 0)

    pixels = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB).reshape(-1, 3)
    # Avoid "n_samples < n_clusters" ValueError on crops with fewer than k pixels.
    n_clusters = min(k, len(pixels))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(pixels)

    counts = np.bincount(kmeans.labels_)
    center = kmeans.cluster_centers_[np.argmax(counts)]
    # Round rather than truncate, so e.g. 127.9 maps to 128 instead of 127.
    dominant = np.rint(center).astype(int)
    return tuple(int(v) for v in dominant)  # (R, G, B)


# ---- Inference Function ----
def _detection_arrays(res):
    """Return ``(boxes_xyxy, class_ids, scores)`` numpy arrays for one result.

    All three are empty when the result has no ``boxes`` attribute or it is ``None``.
    """
    det = getattr(res, "boxes", None)
    if det is None:
        return np.empty((0, 4)), np.empty(0, dtype=int), np.empty(0)
    return (
        det.xyxy.cpu().numpy(),
        det.cls.cpu().numpy().astype(int),
        det.conf.cpu().numpy(),
    )


def predict(image):
    """Detect objects in ``image`` and estimate each detection's dominant color.

    Args:
        image: RGB image as delivered by ``gr.Image(type="pil")``, or ``None``
            when the user submits without uploading anything.

    Returns:
        ``(annotated_rgb, labels)``:

        - ``annotated_rgb`` is an RGB numpy image with boxes and captions drawn.
        - ``labels`` is a list of dicts with ``bbox``, ``score``, ``color`` and
          ``class`` keys.

        Returns ``(None, [])`` when ``image`` is ``None``.
    """
    if image is None:
        return None, []

    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    res = model.predict(img, device="cpu", conf=0.25)[0]
    boxes, classes, scores = _detection_arrays(res)
    names = getattr(res, "names", None) or {}

    height, width = img.shape[:2]
    out = img.copy()
    labels = []
    for (x1, y1, x2, y2), c, s in zip(boxes, classes, scores):
        x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
        # Clamp to the image so negative indices cannot wrap around in slicing.
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)

        # Crop from the untouched input image: cropping `out` would include the
        # green border (and earlier captions) already drawn, skewing the color.
        dom = dominant_color_bgr(img[y1:y2, x1:x2])
        color_name = f"RGB{dom}"
        class_name = names.get(int(c), str(int(c)))

        cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the caption inside the frame for boxes touching the top edge.
        cv2.putText(out, f"{color_name} {s:.2f}", (x1, max(y1 - 10, 15)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        labels.append({
            "bbox": [x1, y1, x2, y2],
            "score": float(s),
            "color": color_name,
            "class": class_name,
        })

    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    return out_rgb, labels


# ---- Gradio App ----
# Wire `predict` into a web UI.
# Input: one PIL image.
# Outputs: the annotated numpy image and a JSON list of detections.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="numpy"),  # annotated frame
        gr.JSON(),  # per-detection bbox / score / color
    ],
    title="Car Damage + Color Detector (CPU)",
    description=(
        "Upload a car photo; model detects damage and estimates dominant color."
    ),
)

if __name__ == "__main__":
    # Start the local server only when executed as a script, not on import.
    demo.launch()