viswanani commited on
Commit
26e6ab2
·
verified ·
1 Parent(s): 43e80e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -15
app.py CHANGED
@@ -1,48 +1,85 @@
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
 
4
  from ultralytics import YOLO
5
  from sklearn.cluster import KMeans
6
 
7
- # load YOLO model on CPU
8
- model = YOLO("best.pt") # replace with your trained weights
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def dominant_color_bgr(bgr_img, k=3):
11
- img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB).reshape(-1,3)
 
 
 
12
  kmeans = KMeans(n_clusters=k, random_state=0).fit(img)
13
  counts = np.bincount(kmeans.labels_)
14
  dominant = kmeans.cluster_centers_[np.argmax(counts)].astype(int)
15
- return tuple(map(int, dominant))
 
16
 
 
17
  def predict(image):
18
  img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
19
  res = model.predict(img, device="cpu", conf=0.25)[0]
20
- boxes = res.boxes.xyxy.cpu().numpy() if hasattr(res, 'boxes') else []
21
- classes = res.boxes.cls.cpu().numpy().astype(int) if hasattr(res, 'boxes') else []
22
- scores = res.boxes.conf.cpu().numpy() if hasattr(res, 'boxes') else []
 
23
 
24
  out = img.copy()
25
  labels = []
26
- for (x1,y1,x2,y2), c, s in zip(boxes, classes, scores):
27
- x1,y1,x2,y2 = map(int, (x1,y1,x2,y2))
28
- cv2.rectangle(out, (x1,y1), (x2,y2), (0,255,0), 2)
 
29
  crop = out[y1:y2, x1:x2]
30
- if crop.size == 0: continue
31
  dom = dominant_color_bgr(crop)
32
  color_name = f"RGB{dom}"
33
- cv2.putText(out, f"{color_name} {s:.2f}", (x1, y1-10),
34
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
35
- labels.append({"bbox":[x1,y1,x2,y2],"score":float(s),"color":color_name})
 
 
 
 
 
 
36
 
37
  out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
38
  return out_rgb, labels
39
 
 
 
40
  demo = gr.Interface(
41
  fn=predict,
42
  inputs=gr.Image(type="pil"),
43
  outputs=[gr.Image(type="numpy"), gr.JSON()],
44
  title="Car Damage + Color Detector (CPU)",
45
- description="Upload a car photo; model detects damage and estimates color."
46
  )
47
 
48
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
+ import torch
5
  from ultralytics import YOLO
6
  from sklearn.cluster import KMeans
7
 
 
 
8
 
9
+ # ---- Model Loading Fix for PyTorch 2.6 ----
10
+ def load_yolo_model(weights_path="best.pt"):
11
+ """
12
+ Loads YOLO model safely in PyTorch >=2.6
13
+ If torch.load fails due to weights_only, force weights_only=False.
14
+ """
15
+ try:
16
+ # Try normal YOLO loading
17
+ model = YOLO(weights_path)
18
+ except Exception as e:
19
+ print(f"[WARN] Normal YOLO load failed: {e}")
20
+ print("[INFO] Retrying with torch.load(weights_only=False)...")
21
+
22
+ ckpt = torch.load(weights_path, map_location="cpu", weights_only=False)
23
+ model = YOLO()
24
+ model.model = ckpt["model"] # load model state
25
+ return model
26
+
27
+
28
+ # Load model (replace best.pt with your trained weights)
29
+ model = load_yolo_model("best.pt")
30
+
31
+
32
+ # ---- Color Detection Helper ----
33
  def dominant_color_bgr(bgr_img, k=3):
34
+ """Return dominant RGB color of a cropped BGR image."""
35
+ if bgr_img is None or bgr_img.size == 0:
36
+ return (0, 0, 0)
37
+ img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB).reshape(-1, 3)
38
  kmeans = KMeans(n_clusters=k, random_state=0).fit(img)
39
  counts = np.bincount(kmeans.labels_)
40
  dominant = kmeans.cluster_centers_[np.argmax(counts)].astype(int)
41
+ return tuple(map(int, dominant)) # (R, G, B)
42
+
43
 
44
+ # ---- Inference Function ----
45
  def predict(image):
46
  img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
47
  res = model.predict(img, device="cpu", conf=0.25)[0]
48
+
49
+ boxes = res.boxes.xyxy.cpu().numpy() if hasattr(res, "boxes") else []
50
+ classes = res.boxes.cls.cpu().numpy().astype(int) if hasattr(res, "boxes") else []
51
+ scores = res.boxes.conf.cpu().numpy() if hasattr(res, "boxes") else []
52
 
53
  out = img.copy()
54
  labels = []
55
+ for (x1, y1, x2, y2), c, s in zip(boxes, classes, scores):
56
+ x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
57
+ cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2)
58
+
59
  crop = out[y1:y2, x1:x2]
 
60
  dom = dominant_color_bgr(crop)
61
  color_name = f"RGB{dom}"
62
+
63
+ cv2.putText(out, f"{color_name} {s:.2f}", (x1, y1 - 10),
64
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
65
+
66
+ labels.append({
67
+ "bbox": [x1, y1, x2, y2],
68
+ "score": float(s),
69
+ "color": color_name
70
+ })
71
 
72
  out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
73
  return out_rgb, labels
74
 
75
+
76
+ # ---- Gradio App ----
77
  demo = gr.Interface(
78
  fn=predict,
79
  inputs=gr.Image(type="pil"),
80
  outputs=[gr.Image(type="numpy"), gr.JSON()],
81
  title="Car Damage + Color Detector (CPU)",
82
+ description="Upload a car photo; model detects damage and estimates dominant color."
83
  )
84
 
85
  if __name__ == "__main__":