Spaces:
Runtime error
Runtime error
File size: 2,754 Bytes
7855e61 4f3bee8 26e6ab2 4f3bee8 7855e61 24e4dba 26e6ab2 4f3bee8 26e6ab2 4f3bee8 26e6ab2 24e4dba 26e6ab2 4f3bee8 26e6ab2 24e4dba 4f3bee8 26e6ab2 4f3bee8 26e6ab2 4f3bee8 7855e61 26e6ab2 24e4dba 4f3bee8 24e4dba 4f3bee8 26e6ab2 24e4dba 7855e61 24e4dba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from sklearn.cluster import KMeans
# ---- Model Loading Fix for PyTorch 2.6 ----
def load_yolo_model(weights_path="best.pt"):
    """Load a YOLO model, working around the PyTorch >=2.6 weights_only default.

    PyTorch 2.6 changed ``torch.load`` to ``weights_only=True`` by default,
    which breaks loading of older Ultralytics checkpoints that pickle whole
    model objects. Strategy: try the normal Ultralytics loader first; on
    failure, fall back to an explicit ``torch.load(..., weights_only=False)``.

    Args:
        weights_path: Path to the ``.pt`` checkpoint file.

    Returns:
        An ``ultralytics.YOLO`` instance ready for inference.
    """
    try:
        # Normal path: let Ultralytics handle checkpoint parsing itself.
        model = YOLO(weights_path)
    except Exception as e:
        print(f"[WARN] Normal YOLO load failed: {e}")
        print("[INFO] Retrying with torch.load(weights_only=False)...")
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects —
        # only do this with checkpoints from a trusted source.
        ckpt = torch.load(weights_path, map_location="cpu", weights_only=False)
        model = YOLO()
        # Ultralytics checkpoints are usually dicts with a "model" key, but a
        # raw torch.save(model) file is the module itself — handle both.
        model.model = ckpt["model"] if isinstance(ckpt, dict) else ckpt
    return model
# Module-level model load (runs once at import): replace "best.pt" with the
# path to your trained weights if different.
model = load_yolo_model("best.pt")
# ---- Color Detection Helper ----
def dominant_color_bgr(bgr_img, k=3):
    """Return the dominant color of a BGR crop as an ``(R, G, B)`` int tuple.

    Runs k-means over all pixels and returns the centroid of the largest
    cluster. Returns ``(0, 0, 0)`` for a ``None`` or empty crop.

    Args:
        bgr_img: Cropped image in OpenCV BGR order (H x W x 3 ndarray).
        k: Maximum number of color clusters to fit.
    """
    if bgr_img is None or bgr_img.size == 0:
        return (0, 0, 0)
    pixels = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB).reshape(-1, 3)
    # KMeans raises ValueError when n_samples < n_clusters, which happens on
    # very small detection crops — clamp k to the available pixel count.
    n_clusters = max(1, min(k, pixels.shape[0]))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(pixels)
    counts = np.bincount(kmeans.labels_)
    dominant = kmeans.cluster_centers_[np.argmax(counts)].astype(int)
    return tuple(map(int, dominant))  # (R, G, B)
# ---- Inference Function ----
def predict(image):
    """Detect damage regions in a car photo and estimate each box's color.

    Args:
        image: PIL image in RGB (as delivered by ``gr.Image(type="pil")``).

    Returns:
        Tuple of (annotated RGB ndarray, list of dicts with ``bbox`` as
        ``[x1, y1, x2, y2]``, ``score`` as float, and ``color`` as
        ``"RGB(r, g, b)"`` string).
    """
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    res = model.predict(img, device="cpu", conf=0.25)[0]
    out = img.copy()
    labels = []
    det = getattr(res, "boxes", None)
    if det is not None and len(det) > 0:
        boxes = det.xyxy.cpu().numpy()
        classes = det.cls.cpu().numpy().astype(int)
        scores = det.conf.cpu().numpy()
        h, w = img.shape[:2]
        for (x1, y1, x2, y2), c, s in zip(boxes, classes, scores):
            x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
            # Clamp to image bounds: negative or oversized coords would make
            # the slice below silently wrong (negative indices wrap around).
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(w, x2), min(h, y2)
            # Crop from the CLEAN image, not `out` — previously the crop was
            # taken after drawing, so the green box edges biased the
            # dominant-color estimate.
            crop = img[y1:y2, x1:x2]
            dom = dominant_color_bgr(crop)
            color_name = f"RGB{dom}"
            cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Keep the label on-canvas when the box touches the top edge.
            ty = y1 - 10 if y1 - 10 > 10 else y1 + 20
            cv2.putText(out, f"{color_name} {s:.2f}", (x1, ty),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            labels.append({
                "bbox": [x1, y1, x2, y2],
                "score": float(s),
                "color": color_name
            })
    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    return out_rgb, labels
# ---- Gradio App ----
# Single-image UI: PIL image in; annotated numpy image + JSON detections out.
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs=[gr.Image(type="numpy"), gr.JSON()],
title="Car Damage + Color Detector (CPU)",
description="Upload a car photo; model detects damage and estimates dominant color."
)
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()
|