Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,48 +1,85 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from ultralytics import YOLO
|
| 5 |
from sklearn.cluster import KMeans
|
| 6 |
|
| 7 |
-
# load YOLO model on CPU
|
| 8 |
-
model = YOLO("best.pt") # replace with your trained weights
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def dominant_color_bgr(bgr_img, k=3):
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
| 12 |
kmeans = KMeans(n_clusters=k, random_state=0).fit(img)
|
| 13 |
counts = np.bincount(kmeans.labels_)
|
| 14 |
dominant = kmeans.cluster_centers_[np.argmax(counts)].astype(int)
|
| 15 |
-
return tuple(map(int, dominant))
|
|
|
|
| 16 |
|
|
|
|
| 17 |
def predict(image):
|
| 18 |
img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 19 |
res = model.predict(img, device="cpu", conf=0.25)[0]
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
| 23 |
|
| 24 |
out = img.copy()
|
| 25 |
labels = []
|
| 26 |
-
for (x1,y1,x2,y2), c, s in zip(boxes, classes, scores):
|
| 27 |
-
x1,y1,x2,y2 = map(int, (x1,y1,x2,y2))
|
| 28 |
-
cv2.rectangle(out, (x1,y1), (x2,y2), (0,255,0), 2)
|
|
|
|
| 29 |
crop = out[y1:y2, x1:x2]
|
| 30 |
-
if crop.size == 0: continue
|
| 31 |
dom = dominant_color_bgr(crop)
|
| 32 |
color_name = f"RGB{dom}"
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
|
| 38 |
return out_rgb, labels
|
| 39 |
|
|
|
|
|
|
|
| 40 |
demo = gr.Interface(
|
| 41 |
fn=predict,
|
| 42 |
inputs=gr.Image(type="pil"),
|
| 43 |
outputs=[gr.Image(type="numpy"), gr.JSON()],
|
| 44 |
title="Car Damage + Color Detector (CPU)",
|
| 45 |
-
description="Upload a car photo; model detects damage and estimates color."
|
| 46 |
)
|
| 47 |
|
| 48 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import cv2
|
| 3 |
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
from ultralytics import YOLO
|
| 6 |
from sklearn.cluster import KMeans
|
| 7 |
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
# ---- Model Loading Fix for PyTorch 2.6 ----
def load_yolo_model(weights_path="best.pt"):
    """Load a YOLO model, working around the PyTorch >= 2.6 weights_only default.

    Tries the normal ultralytics loader first. If that raises (PyTorch 2.6
    switched torch.load to weights_only=True by default, which rejects
    checkpoints containing pickled Python objects), retries with an explicit
    torch.load(weights_only=False) and splices the checkpoint's model into a
    fresh YOLO wrapper.

    Args:
        weights_path: Path to the trained weights file.

    Returns:
        A ready-to-use ultralytics YOLO model.
    """
    try:
        # Normal path: ultralytics handles torch.load internally.
        model = YOLO(weights_path)
    except Exception as e:  # broad on purpose: any load failure triggers the fallback
        print(f"[WARN] Normal YOLO load failed: {e}")
        print("[INFO] Retrying with torch.load(weights_only=False)...")

        # SECURITY: weights_only=False unpickles arbitrary objects from the
        # file — only use it on checkpoints you trust (here: our own best.pt).
        ckpt = torch.load(weights_path, map_location="cpu", weights_only=False)
        model = YOLO()
        model.model = ckpt["model"]  # splice the checkpoint's module in
    # Return on both paths (success and fallback).
    return model


# Load model (replace best.pt with your trained weights)
model = load_yolo_model("best.pt")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ---- Color Detection Helper ----
def dominant_color_bgr(bgr_img, k=3):
    """Return the dominant color of a BGR crop as an (R, G, B) int tuple.

    The crop is converted to RGB, flattened to a pixel list, and clustered
    with k-means; the centroid of the largest cluster is the dominant color.
    Note: despite the `bgr` in the name, the RETURN value is in RGB order.

    Args:
        bgr_img: Image crop in OpenCV BGR order (may be None or empty).
        k: Desired number of color clusters.

    Returns:
        (R, G, B) ints; (0, 0, 0) for a missing or empty crop.
    """
    if bgr_img is None or bgr_img.size == 0:
        return (0, 0, 0)
    img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB).reshape(-1, 3)
    # KMeans raises ValueError when there are fewer samples than clusters —
    # clamp k so tiny crops (1-2 px boxes) still work.
    n_clusters = max(1, min(k, img.shape[0]))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(img)
    counts = np.bincount(kmeans.labels_)
    dominant = kmeans.cluster_centers_[np.argmax(counts)].astype(int)
    return tuple(map(int, dominant))  # (R, G, B)
|
| 42 |
+
|
| 43 |
|
| 44 |
+
# ---- Inference Function ----
def predict(image):
    """Detect damage on a car photo and annotate each box with its dominant color.

    Args:
        image: PIL image (RGB) supplied by the Gradio input component.

    Returns:
        Tuple of (annotated RGB numpy image,
                  list of dicts with "bbox", "class", "score", "color").
    """
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    res = model.predict(img, device="cpu", conf=0.25)[0]

    # Pull detections out once instead of re-probing `res` per field.
    det = getattr(res, "boxes", None)
    if det is not None:
        boxes = det.xyxy.cpu().numpy()
        classes = det.cls.cpu().numpy().astype(int)
        scores = det.conf.cpu().numpy()
    else:
        boxes, classes, scores = [], [], []

    out = img.copy()
    h, w = out.shape[:2]
    labels = []
    for (x1, y1, x2, y2), c, s in zip(boxes, classes, scores):
        # Clamp to the image: detector boxes can spill slightly outside, and
        # a negative index would silently wrap when slicing the crop below.
        x1, y1 = max(0, int(x1)), max(0, int(y1))
        x2, y2 = min(w, int(x2)), min(h, int(y2))
        cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2)

        crop = out[y1:y2, x1:x2]
        dom = dominant_color_bgr(crop)  # handles empty crops internally
        color_name = f"RGB{dom}"

        cv2.putText(out, f"{color_name} {s:.2f}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        labels.append({
            "bbox": [x1, y1, x2, y2],
            "class": int(c),  # class id was detected but previously dropped
            "score": float(s),
            "color": color_name
        })

    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    return out_rgb, labels
|
| 74 |
|
| 75 |
+
|
| 76 |
+
# ---- Gradio App ----
# Single-function UI: one image in, annotated image + detection JSON out.
demo = gr.Interface(
    fn=predict,
    title="Car Damage + Color Detector (CPU)",
    description="Upload a car photo; model detects damage and estimates dominant color.",
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="numpy"), gr.JSON()],
)
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|