# Hugging Face Spaces app (last recorded build status: Runtime error)
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from ultralytics import YOLO | |
| from sklearn.cluster import KMeans | |
# ---- Model Loading Fix for PyTorch 2.6 ----
def load_yolo_model(weights_path="best.pt"):
    """Load a YOLO model, working around the PyTorch >= 2.6 default.

    Since PyTorch 2.6, torch.load() defaults to weights_only=True, which
    rejects Ultralytics checkpoints (they pickle whole model objects).
    Try the normal Ultralytics loader first; on failure, fall back to an
    explicit weights_only=False load of the checkpoint.

    Args:
        weights_path: Path to the trained .pt checkpoint.

    Returns:
        A ready-to-use ultralytics.YOLO model.
    """
    try:
        # Try normal YOLO loading
        return YOLO(weights_path)
    except Exception as e:
        print(f"[WARN] Normal YOLO load failed: {e}")
        print("[INFO] Retrying with torch.load(weights_only=False)...")
        # SECURITY: weights_only=False executes arbitrary pickled code.
        # Acceptable here only because weights_path is our own local file.
        ckpt = torch.load(weights_path, map_location="cpu", weights_only=False)
        model = YOLO()
        # Ultralytics checkpoints store the module under "model"; older /
        # exported files may be the bare nn.Module itself.
        loaded = ckpt["model"] if isinstance(ckpt, dict) else ckpt
        # Checkpoints are saved in fp16; cast to fp32 for CPU inference.
        model.model = loaded.float()
        return model
# Load model (replace best.pt with your trained weights)
# NOTE: runs once at import time; all Gradio requests share this instance.
model = load_yolo_model("best.pt")
# ---- Color Detection Helper ----
def dominant_color_bgr(bgr_img, k=3, max_samples=10000):
    """Return the dominant color of a cropped BGR image as (R, G, B).

    Clusters the crop's pixels with k-means and returns the center of the
    most populated cluster.

    Args:
        bgr_img: HxWx3 BGR image (OpenCV crop), or None.
        k: Maximum number of color clusters.
        max_samples: Cap on pixels fed to k-means, for speed on big crops.

    Returns:
        Tuple of ints (R, G, B); (0, 0, 0) for an empty/missing crop.
    """
    if bgr_img is None or bgr_img.size == 0:
        return (0, 0, 0)
    pixels = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB).reshape(-1, 3)
    # Subsample large crops: k-means cost scales with pixel count.
    if len(pixels) > max_samples:
        idx = np.random.default_rng(0).choice(len(pixels), max_samples, replace=False)
        pixels = pixels[idx]
    # KMeans raises ValueError when n_clusters > n_samples; tiny detection
    # crops (1-2 px) must shrink k accordingly.
    k = max(1, min(k, len(pixels)))
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=0).fit(pixels)
    counts = np.bincount(kmeans.labels_)
    dominant = kmeans.cluster_centers_[np.argmax(counts)].astype(int)
    return tuple(map(int, dominant))  # (R, G, B)
# ---- Inference Function ----
def predict(image):
    """Detect damage in a car photo and estimate each region's color.

    Args:
        image: PIL.Image in RGB, as delivered by Gradio.

    Returns:
        Tuple of:
            - annotated RGB image (numpy array),
            - list of {"bbox": [x1, y1, x2, y2], "score": float,
              "color": "RGB(r, g, b)"} dicts, one per detection.
    """
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    res = model.predict(img, device="cpu", conf=0.25)[0]
    # res.boxes is None (not merely absent) when nothing is detected.
    if getattr(res, "boxes", None) is not None:
        boxes = res.boxes.xyxy.cpu().numpy()
        classes = res.boxes.cls.cpu().numpy().astype(int)
        scores = res.boxes.conf.cpu().numpy()
    else:
        boxes, classes, scores = [], [], []
    out = img.copy()
    labels = []
    h, w = img.shape[:2]
    for (x1, y1, x2, y2), c, s in zip(boxes, classes, scores):
        x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
        # Clamp to image bounds: model boxes can spill slightly outside,
        # and negative indices would silently wrap in numpy slicing.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)
        cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Crop from the CLEAN image, not `out`: the green rectangle just
        # drawn would otherwise skew the dominant-color estimate.
        crop = img[y1:y2, x1:x2]
        dom = dominant_color_bgr(crop)
        color_name = f"RGB{dom}"
        cv2.putText(out, f"{color_name} {s:.2f}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        labels.append({
            "bbox": [x1, y1, x2, y2],
            "score": float(s),
            "color": color_name
        })
    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    return out_rgb, labels
# ---- Gradio App ----
# PIL input matches predict()'s np.array(image) conversion (RGB in);
# outputs are the annotated numpy image plus per-detection metadata as JSON.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="numpy"), gr.JSON()],
    title="Car Damage + Color Detector (CPU)",
    description="Upload a car photo; model detects damage and estimates dominant color."
)

if __name__ == "__main__":
    demo.launch()