# AutoTeckIQ / app.py
# Last commit: "Update app.py" by viswanani (26e6ab2, verified)
import gradio as gr
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from sklearn.cluster import KMeans
# ---- Model Loading Fix for PyTorch 2.6 ----
def load_yolo_model(weights_path="best.pt"):
    """
    Load a YOLO model, working around the PyTorch >= 2.6 ``weights_only`` default.

    PyTorch 2.6 made ``torch.load`` default to ``weights_only=True``, which can
    reject full pickled ultralytics checkpoints. If the normal load fails, the
    same ``YOLO(weights_path)`` call is retried with ``torch.load`` temporarily
    forced to ``weights_only=False``. The model is therefore still built by
    ``YOLO`` itself, with this checkpoint's task, class names and overrides,
    instead of a default model whose ``.model`` is swapped afterwards.

    Security: ``weights_only=False`` unpickles arbitrary objects, so only use it
    with checkpoint files you trust.

    Args:
        weights_path: Path to the ``.pt`` checkpoint.

    Returns:
        A ``YOLO`` instance loaded from ``weights_path``.

    Raises:
        Whatever the retried ``YOLO(weights_path)`` call raises if the
        checkpoint still cannot be loaded.
    """
    try:
        # Try normal YOLO loading
        return YOLO(weights_path)
    except Exception as e:  # broad on purpose: we always fall back to one retry
        print(f"[WARN] Normal YOLO load failed: {e}")
        print("[INFO] Retrying with torch.load(weights_only=False)...")

    original_load = torch.load

    def _load_without_weights_only(*args, **kwargs):
        # The whole point of the retry is to disable the weights-only
        # unpickler, so override any value the caller passed.
        kwargs["weights_only"] = False
        return original_load(*args, **kwargs)

    torch.load = _load_without_weights_only
    try:
        return YOLO(weights_path)
    finally:
        # Always restore the global so the patch never leaks to other callers.
        torch.load = original_load
# Build the detector once at import time; predict() reuses this module-level
# instance on every request. Point weights_path at your own trained weights.
model = load_yolo_model(weights_path="best.pt")
# ---- Color Detection Helper ----
def dominant_color_bgr(bgr_img, k=3):
    """
    Return the dominant colour of a BGR image crop as an ``(R, G, B)`` tuple.

    The pixels are clustered with k-means. The centre of the most populated
    cluster is returned, rounded to the nearest integer per channel.

    Args:
        bgr_img: Image in OpenCV's BGR channel order (3 channels), or ``None``.
        k: Requested number of clusters. It is capped at the number of pixels,
            because ``KMeans`` raises ``ValueError`` when there are fewer
            samples than clusters (e.g. a 1x2 crop).

    Returns:
        ``(R, G, B)`` as Python ints; ``(0, 0, 0)`` for ``None`` or an empty image.
    """
    if bgr_img is None or bgr_img.size == 0:
        return (0, 0, 0)
    pixels = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB).reshape(-1, 3)
    n_clusters = min(k, len(pixels))
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(pixels)
    counts = np.bincount(kmeans.labels_)
    # Round rather than truncate so e.g. 254.9 maps to 255, not 254.
    dominant = np.rint(kmeans.cluster_centers_[np.argmax(counts)])
    return tuple(int(channel) for channel in dominant)  # (R, G, B)
# ---- Inference Function ----
def predict(image):
    """
    Detect damage in *image* and annotate each box with its dominant colour.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading.

    Returns:
        ``(annotated_rgb, labels)``, where:
        - ``annotated_rgb`` is the image as an RGB numpy array, with boxes and
          captions drawn on it.
        - ``labels`` has one dict per detection with ``bbox``, ``score``,
          ``color`` and ``class`` keys.

        Returns ``(None, [])`` when no image is given.
    """
    if image is None:
        return None, []

    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    height, width = img.shape[:2]
    res = model.predict(img, device="cpu", conf=0.25)[0]

    # hasattr() is True even when boxes is None (non-detection models),
    # so test the value itself.
    det = getattr(res, "boxes", None)
    if det is None:
        boxes, classes, scores = [], [], []
    else:
        boxes = det.xyxy.cpu().numpy()
        classes = det.cls.cpu().numpy().astype(int)
        scores = det.conf.cpu().numpy()
    names = getattr(res, "names", None) or {}

    out = img.copy()
    labels = []
    for (x1, y1, x2, y2), c, s in zip(boxes, classes, scores):
        # Clamp to the image so a negative coordinate cannot wrap the slice.
        x1, x2 = (min(max(int(v), 0), width) for v in (x1, x2))
        y1, y2 = (min(max(int(v), 0), height) for v in (y1, y2))
        cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Sample colour from the untouched image. Cropping `out` would include
        # the green rectangle just drawn (and earlier boxes' annotations),
        # biasing the result toward green.
        dom = dominant_color_bgr(img[y1:y2, x1:x2])
        color_name = f"RGB{dom}"
        # Keep the caption on-screen for boxes touching the top edge.
        cv2.putText(out, f"{color_name} {s:.2f}", (x1, max(y1 - 10, 12)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        class_id = int(c)
        labels.append({
            "bbox": [x1, y1, x2, y2],
            "score": float(s),
            "color": color_name,
            "class": names.get(class_id, str(class_id)),
        })
    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    return out_rgb, labels
# ---- Gradio App ----
# One PIL image in; the annotated image (numpy) and the per-detection JSON out,
# matching predict()'s (out_rgb, labels) return value.
demo = gr.Interface(
    predict,
    gr.Image(type="pil"),
    [gr.Image(type="numpy"), gr.JSON()],
    title="Car Damage + Color Detector (CPU)",
    description="Upload a car photo; model detects damage and estimates dominant color.",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()