import cv2
import gradio as gr
import numpy as np
import torch
from sklearn.cluster import KMeans
from ultralytics import YOLO


# ---- Model Loading Fix for PyTorch 2.6 ----
def load_yolo_model(weights_path="best.pt"):
    """Load a YOLO model, falling back to an explicit full unpickle.

    The regular ``YOLO(weights_path)`` constructor is tried first. If it
    raises, the checkpoint is read directly with
    ``torch.load(..., weights_only=False)`` and its ``"model"`` entry is
    attached to a freshly constructed ``YOLO`` wrapper.

    Args:
        weights_path: Path to the trained checkpoint file.

    Returns:
        A ``YOLO`` instance ready for ``.predict``.

    Raises:
        Whatever ``torch.load`` / ``YOLO()`` raise if the fallback also
        fails (for example when the file does not exist).
    """
    try:
        # Try normal YOLO loading
        return YOLO(weights_path)
    except Exception as e:  # deliberate catch-all: any failure triggers the retry
        print(f"[WARN] Normal YOLO load failed: {e}")
        print("[INFO] Retrying with torch.load(weights_only=False)...")

    # SECURITY: weights_only=False performs a full pickle load, which can
    # execute arbitrary code embedded in the file. Only use checkpoints
    # from a trusted source.
    ckpt = torch.load(weights_path, map_location="cpu", weights_only=False)
    # NOTE(review): YOLO() with no arguments presumably builds the library's
    # default model first (which may need its default weights) - confirm.
    model = YOLO()
    # The checkpoint's "model" entry is assigned as-is; no state dict is
    # loaded here.
    model.model = ckpt["model"]
    return model


# Load model (replace best.pt with your trained weights)
model = load_yolo_model("best.pt")


# ---- Color Detection Helper ----
def dominant_color_bgr(bgr_img, k=3):
    """Return the dominant color of a BGR image as an ``(R, G, B)`` tuple.

    The pixels are clustered with KMeans and the centre of the most
    populated cluster is returned, truncated to integers.

    Args:
        bgr_img: Image array in BGR channel order, or ``None``.
        k: Requested number of clusters. It is capped at the number of
            pixels so that very small crops do not make KMeans raise.

    Returns:
        ``(R, G, B)`` as Python ints; ``(0, 0, 0)`` for a missing or empty
        image.
    """
    if bgr_img is None or bgr_img.size == 0:
        return (0, 0, 0)

    pixels = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB).reshape(-1, 3)
    # KMeans rejects n_clusters > n_samples, which happens for 1-2 pixel crops.
    n_clusters = min(k, len(pixels))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(pixels)
    counts = np.bincount(kmeans.labels_)
    dominant = kmeans.cluster_centers_[np.argmax(counts)].astype(int)
    return tuple(map(int, dominant))  # (R, G, B)


# ---- Inference Function ----
def predict(image):
    """Run detection on an uploaded image and annotate each detection.

    Args:
        image: The input picture (the interface below supplies a PIL
            image), or ``None`` when nothing was uploaded.

    Returns:
        A pair ``(annotated_rgb, labels)``. ``annotated_rgb`` is the input
        with boxes and captions drawn on it, in RGB order, or ``None`` when
        no image was given. ``labels`` is a list with one dict per
        detection holding ``bbox``, ``score``, ``color``, ``class_id`` and
        ``class_name``.
    """
    if image is None:
        # Submitting the form without a picture should not raise.
        return None, []

    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    res = model.predict(img, device="cpu", conf=0.25)[0]

    # The attribute can exist and still be None, so test the value itself.
    detections = getattr(res, "boxes", None)
    if detections is None:
        boxes, classes, scores = [], [], []
    else:
        boxes = detections.xyxy.cpu().numpy()
        classes = detections.cls.cpu().numpy().astype(int)
        scores = detections.conf.cpu().numpy()

    # NOTE(review): assumes the result exposes an id -> name mapping as
    # `names`; when it does not, the numeric id is reported instead.
    names = getattr(res, "names", None)

    height, width = img.shape[:2]
    out = img.copy()
    labels = []
    for (x1, y1, x2, y2), c, s in zip(boxes, classes, scores):
        x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))

        # Crop from the untouched input, not from `out`: the annotations
        # drawn on `out` would otherwise leak into the color estimate.
        # Clamp to the image so a negative coordinate cannot wrap the slice.
        crop = img[max(y1, 0):min(y2, height), max(x1, 0):min(x2, width)]
        dom = dominant_color_bgr(crop)
        color_name = f"RGB{dom}"

        class_id = int(c)
        if isinstance(names, dict):
            class_name = str(names.get(class_id, class_id))
        else:
            class_name = str(class_id)

        cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Keep the caption on the canvas for boxes touching the top edge.
        text_y = y1 - 10 if y1 - 10 > 10 else y1 + 20
        cv2.putText(
            out,
            f"{color_name} {s:.2f}",
            (x1, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (255, 255, 255),
            2,
        )
        labels.append({
            "bbox": [x1, y1, x2, y2],
            "score": float(s),
            "color": color_name,
            "class_id": class_id,
            "class_name": class_name,
        })

    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    return out_rgb, labels


# ---- Gradio App ----
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="numpy"), gr.JSON()],
    title="Car Damage + Color Detector (CPU)",
    description="Upload a car photo; model detects damage and estimates dominant color."
)

if __name__ == "__main__":
    demo.launch()