csmith715 commited on
Commit
4ee1abf
·
1 Parent(s): 81ffce9

Adding Radio button

Browse files
Files changed (2) hide show
  1. app.py +90 -85
  2. weld_tiling.py +1 -1
app.py CHANGED
@@ -1,102 +1,107 @@
1
  import gradio as gr
2
  import numpy as np
3
  from ultralytics import YOLO
4
- # import cv2
5
  from weld_tiling import detect_tiled_softnms, draw_detections
6
 
7
  # Load model once at startup
8
- model = YOLO("best_7-15-25.pt")
9
 
10
  # Class names (must match your training config)
11
- CLASS_NAMES = [
12
- "Valve", "Butterfly Valve", "Flange", "PRV", "Reducer",
13
- "shop_bw", "shop_sw", "Union", "Weld-o-let",
14
- "field_bw", "field_sw", "Insulation"
15
- ]
16
-
17
- # def detect_weld_types(image: np.ndarray) -> tuple[np.ndarray, str]:
18
- # try:
19
- # # I'm not sure how to apply this to the original existing code
20
- # results = detect_tiled_softnms(
21
- # model, image,
22
- # tile_size=1024, overlap=0.23,
23
- # per_tile_conf=0.2, per_tile_iou=0.7,
24
- # softnms_iou=0.6, softnms_method="hard", softnms_sigma=0.5,
25
- # final_conf=0.38, device=None, imgsz=1280
26
- # )
27
- # # results = model(image)
28
- #
29
- # boxes = results[0].boxes
30
- # class_ids = boxes.cls.cpu().numpy().astype(int) if boxes.cls is not None else []
31
- #
32
- # weld_counts = {}
33
- # for cls_id in class_ids:
34
- # if 0 <= cls_id < len(CLASS_NAMES):
35
- # name = CLASS_NAMES[cls_id]
36
- # weld_counts[name] = weld_counts.get(name, 0) + 1
37
- #
38
- # annotated_img = results[0].plot() # BGR
39
- # annotated_img = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
40
- #
41
- # weld_summary = "\n".join(f"{k}: {v}" for k, v in weld_counts.items()) or "No weld types found."
42
- #
43
- # return annotated_img, str(weld_summary)
44
- #
45
- # except Exception as e:
46
- # print("Error:", e)
47
- # # Ensure fallback types are strictly correct
48
- # return image, "Detection error occurred."
49
-
50
-
51
- # Build Gradio UI
52
-
53
- def detect_weld_types(image: np.ndarray) -> tuple[np.ndarray, str]:
54
- """
55
- Gradio expects/returns RGB numpy arrays. The tiler also works in RGB.
56
- """
57
- try:
58
- # Run tiled inference
59
- out = detect_tiled_softnms(
60
- model, image,
61
- tile_size=1024, overlap=0.23,
62
- per_tile_conf=0.20, per_tile_iou=0.70,
63
- softnms_iou=0.60, softnms_method="hard", softnms_sigma=0.50,
64
- final_conf=0.38,
65
- imgsz=1280, # keep >= tile_size; int or [h, w]
66
- device=None
67
- )
68
-
69
- boxes = out["boxes"] # (N,4) xyxy in full-image pixels
70
- confs = out["conf"] # (N,)
71
- cls_ids = out["cls"] # (N,)
72
- class_names = out["names"] # {id: name}
73
-
74
- # Count per class using model-provided names
75
- counts = {}
76
- for cid in cls_ids:
77
- cname = class_names.get(int(cid), str(int(cid)))
78
- counts[cname] = counts.get(cname, 0) + 1
79
-
80
- # Make an annotated image (RGB in, RGB out)
81
- annotated_img = draw_detections(image.copy(), boxes, confs, cls_ids, class_names)
82
-
83
- # Pretty summary text (sorted by count desc)
84
- if counts:
85
- summary_lines = [f"{k}: {v}" for k, v in sorted(counts.items(), key=lambda kv: -kv[1])]
86
- summary = "\n".join(summary_lines)
87
  else:
88
- summary = "No weld types found."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- return annotated_img, summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- except Exception as e:
93
- print("Error in detect_weld_types:", repr(e))
94
- return image, "Detection error occurred."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
  app = gr.Interface(
98
- fn=detect_weld_types,
99
- inputs=gr.Image(type="numpy", label="Upload a welding diagram"),
 
 
 
100
  outputs=[
101
  gr.Image(type="numpy", label="Detected Welds"),
102
  gr.Textbox(label="Weld Types Found")
 
1
  import gradio as gr
2
  import numpy as np
3
  from ultralytics import YOLO
4
+ import cv2
5
  from weld_tiling import detect_tiled_softnms, draw_detections
6
 
7
  # Load model once at startup
8
+ MODEL = YOLO("best_7-15-25.pt")
9
 
10
  # Class names (must match your training config)
11
+ # CLASS_NAMES = [
12
+ # "Valve", "Butterfly Valve", "Flange", "PRV", "Reducer",
13
+ # "shop_bw", "shop_sw", "Union", "Weld-o-let",
14
+ # "field_bw", "field_sw", "Insulation"
15
+ # ]
16
+
17
+ class DetectWelds:
18
+ def __init__(self):
19
+ self.model = MODEL
20
+ self.class_names = [
21
+ "Valve", "Butterfly Valve", "Flange", "PRV", "Reducer", "shop_bw",
22
+ "shop_sw", "Union", "Weld-o-let", "field_bw", "field_sw", "Insulation"
23
+ ]
24
+
25
+ def weld_detection(self, input_image: np.ndarray, prediction_type):
26
+ if prediction_type == "tiling":
27
+ annotated_image, welds = self.detect_weld_types(input_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  else:
29
+ annotated_image, welds = self.detect_weld_types_no_tiling(input_image)
30
+ return annotated_image, welds
31
+
32
+ def detect_weld_types_no_tiling(self, image: np.ndarray) -> tuple[np.ndarray, str]:
33
+ try:
34
+ results = self.model(image)
35
+ boxes = results[0].boxes
36
+ class_ids = boxes.cls.cpu().numpy().astype(int) if boxes.cls is not None else []
37
+ weld_counts = {}
38
+ for cls_id in class_ids:
39
+ if 0 <= cls_id < len(self.class_names):
40
+ name = self.class_names[cls_id]
41
+ weld_counts[name] = weld_counts.get(name, 0) + 1
42
+
43
+ annotated_img = results[0].plot() # BGR
44
+ annotated_img = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
45
+
46
+ weld_summary = "\n".join(f"{k}: {v}" for k, v in weld_counts.items()) or "No weld types found."
47
+
48
+ return annotated_img, str(weld_summary)
49
+
50
+ except Exception as e:
51
+ print("Error:", e)
52
+ # Ensure fallback types are strictly correct
53
+ return image, "Detection error occurred."
54
 
55
+ def detect_weld_types(self, image: np.ndarray) -> tuple[np.ndarray, str]:
56
+ """
57
+ Gradio expects/returns RGB numpy arrays. The tiler also works in RGB.
58
+ """
59
+ try:
60
+ # Run tiled inference
61
+ out = detect_tiled_softnms(
62
+ self.model, image,
63
+ tile_size=512, overlap=0.23,
64
+ per_tile_conf=0.20, per_tile_iou=0.70,
65
+ softnms_iou=0.60, softnms_method="hard", softnms_sigma=0.50,
66
+ final_conf=0.5,
67
+ imgsz=1280, # keep >= tile_size; int or [h, w]
68
+ device=None
69
+ )
70
 
71
+ boxes = out["boxes"] # (N,4) xyxy in full-image pixels
72
+ confs = out["conf"] # (N,)
73
+ cls_ids = out["cls"] # (N,)
74
+ class_names = out["names"] # {id: name}
75
+
76
+ # Count per class using model-provided names
77
+ counts = {}
78
+ for cid in cls_ids:
79
+ cname = class_names.get(int(cid), str(int(cid)))
80
+ counts[cname] = counts.get(cname, 0) + 1
81
+
82
+ # Make an annotated image (RGB in, RGB out)
83
+ annotated_img = draw_detections(image.copy(), boxes, confs, cls_ids, class_names)
84
+
85
+ # Pretty summary text (sorted by count desc)
86
+ if counts:
87
+ summary_lines = [f"{k}: {v}" for k, v in sorted(counts.items(), key=lambda kv: -kv[1])]
88
+ summary = "\n".join(summary_lines)
89
+ else:
90
+ summary = "No weld types found."
91
+
92
+ return annotated_img, summary
93
+
94
+ except Exception as e:
95
+ print("Error in detect_weld_types:", repr(e))
96
+ return image, "Detection error occurred."
97
 
98
 
99
  app = gr.Interface(
100
+ fn=DetectWelds().weld_detection,
101
+ inputs=[
102
+ gr.Image(type="numpy", label="Upload a welding diagram"),
103
+ gr.Radio(['tiling', 'no tiling'], label="Tiling Option", value="tiling")
104
+ ],
105
  outputs=[
106
  gr.Image(type="numpy", label="Detected Welds"),
107
  gr.Textbox(label="Weld Types Found")
weld_tiling.py CHANGED
@@ -173,7 +173,7 @@ def detect_tiled_softnms(
173
  # ✅ Ultralytics now requires imgsz as int or [h,w]
174
  resolved_imgsz = tile_size if imgsz is None else imgsz
175
 
176
- results = model.predict(
177
  source=tile,
178
  conf=per_tile_conf,
179
  iou=per_tile_iou,
 
173
  # ✅ Ultralytics now requires imgsz as int or [h,w]
174
  resolved_imgsz = tile_size if imgsz is None else imgsz
175
 
176
+ results = model(
177
  source=tile,
178
  conf=per_tile_conf,
179
  iou=per_tile_iou,