monesh2212 commited on
Commit
a836cf6
·
verified ·
1 Parent(s): 8fff0c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -82
app.py CHANGED
@@ -1,82 +1,189 @@
1
- # app.py
2
- import gradio as gr
3
- import onnxruntime as ort
4
- import numpy as np
5
- from PIL import Image, ImageDraw, ImageFont
6
- import cv2
7
- import io
8
-
9
- # Load labels
10
- with open("labels.txt", "r") as f:
11
- LABELS = [x.strip() for x in f.readlines()]
12
-
13
- # Load ONNX model
14
- sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
15
-
16
- # Helper: simple preprocess - adapt to your model's needs
17
- def preprocess(pil_img, input_size=(640,640)):
18
- img = pil_img.convert("RGB")
19
- img = img.resize(input_size)
20
- np_img = np.array(img).astype(np.float32) / 255.0 # normalize if model expected 0-1
21
- # Change to channels-first if model expects (1,C,H,W)
22
- np_img = np.transpose(np_img, (2,0,1))[np.newaxis, ...]
23
- return np_img
24
-
25
- # Helper: very basic NMS and postprocess - adapt as necessary
26
- def postprocess(outputs, orig_w, orig_h, conf_threshold=0.3, iou_threshold=0.45):
27
- # This section depends on your model outputs.
28
- # Example: suppose outputs[0] -> [N, 6] with (x1,y1,x2,y2,score,class)
29
- preds = outputs[0]
30
- boxes = []
31
- for row in preds:
32
- x1,y1,x2,y2,score,cls = row
33
- if score < conf_threshold:
34
- continue
35
- # scale coords back to original image size if your model used 640x640
36
- boxes.append({
37
- "box": [x1*orig_w, y1*orig_h, x2*orig_w, y2*orig_h],
38
- "score": float(score),
39
- "class": int(cls)
40
- })
41
- # (Optional) Apply NMS here if model doesn't already do it
42
- return boxes
43
-
44
- def draw_boxes(pil_img, boxes):
45
- img = pil_img.convert("RGB")
46
- draw = ImageDraw.Draw(img)
47
- for b in boxes:
48
- x1,y1,x2,y2 = b["box"]
49
- label = LABELS[b["class"]] if 0 <= b["class"] < len(LABELS) else str(b["class"])
50
- draw.rectangle([x1,y1,x2,y2], outline="red", width=3)
51
- draw.text((x1, y1-10), f"{label} {b['score']:.2f}", fill="red")
52
- return img
53
-
54
- def predict(image):
55
- if image is None:
56
- return None, "No image"
57
- pil = Image.fromarray(image.astype('uint8')) if isinstance(image, np.ndarray) else Image.open(io.BytesIO(image.read()))
58
- orig_w, orig_h = pil.size
59
- input_tensor = preprocess(pil) # adapt input_size if needed
60
-
61
- # Run ONNX
62
- input_name = sess.get_inputs()[0].name
63
- outputs = sess.run(None, {input_name: input_tensor})
64
- # Postprocess according to your model's output structure
65
- boxes = postprocess(outputs, orig_w, orig_h)
66
- out_img = draw_boxes(pil, boxes)
67
- txt = "\n".join([f"{LABELS[b['class']]}: {b['score']:.2f}" for b in boxes]) if boxes else "No detections"
68
- return out_img, txt
69
-
70
- # Gradio UI
71
- title = "ONNX Demo"
72
- desc = "Upload an image or use webcam. Adapt preprocessing/postprocessing per your model."
73
-
74
- iface = gr.Interface(fn=predict,
75
- inputs=gr.Image(source="upload", tool="editor", type="numpy"),
76
- outputs=[gr.Image(type="pil"), gr.Textbox()],
77
- title=title,
78
- description=desc,
79
- examples=None)
80
-
81
- if __name__ == "__main__":
82
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import onnxruntime as ort
4
+ import numpy as np
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ import time, os
7
+
8
+ # ---------------------------
9
+ # CONFIG
10
+ # ---------------------------
11
+ MODEL_PATH = "model.onnx" # Ensure model.onnx is in repo
12
+ LABELS_PATH = "labels.txt" # Optional: one label per line
13
+ CONF_THRESHOLD = 0.35
14
+ PREVIEW_INPUT_SIZE = (640, 640) # Change if model expects different input size
15
+
16
+ # ---------------------------
17
+ # LOAD LABELS
18
+ # ---------------------------
19
+ if os.path.exists(LABELS_PATH):
20
+ with open(LABELS_PATH, "r") as f:
21
+ LABELS = [l.strip() for l in f.readlines() if l.strip()]
22
+ else:
23
+ LABELS = None
24
+
25
+ # ---------------------------
26
+ # LOAD MODEL
27
+ # ---------------------------
28
+ print(f"Loading ONNX model from: {MODEL_PATH}")
29
+ sess = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
30
+
31
+ print("\nONNX Model Inputs:")
32
+ for i, inp in enumerate(sess.get_inputs()):
33
+ print(f" input[{i}] name={inp.name}, shape={inp.shape}, dtype={inp.type}")
34
+
35
+ print("\nONNX Model Outputs:")
36
+ for i, out in enumerate(sess.get_outputs()):
37
+ print(f" output[{i}] name={out.name}, shape={out.shape}, dtype={out.type}")
38
+
39
+ # ---------------------------
40
+ # PREPROCESS FUNCTION
41
+ # ---------------------------
42
+ def preprocess_frame(frame_np, input_size=PREVIEW_INPUT_SIZE):
43
+ img = Image.fromarray(frame_np.astype("uint8"), "RGB")
44
+ img_resized = img.resize(input_size)
45
+ arr = np.array(img_resized).astype(np.float32) / 255.0 # normalize 0..1
46
+ arr = np.transpose(arr, (2, 0, 1))[np.newaxis, ...] # to NCHW
47
+ return arr
48
+
49
+ # ---------------------------
50
+ # POSTPROCESS FUNCTION (FIXED)
51
+ # ---------------------------
52
+ def postprocess_outputs(outputs, orig_w, orig_h, conf_thresh=0.35, debug=False):
53
+ outs = [o if isinstance(o, np.ndarray) else np.array(o) for o in outputs]
54
+ if len(outs) == 0:
55
+ return []
56
+
57
+ cand = None
58
+ for o in outs:
59
+ if o.ndim >= 2 and o.shape[-1] >= 4:
60
+ cand = o
61
+ break
62
+ if cand is None:
63
+ cand = outs[0]
64
+
65
+ if cand.ndim == 3 and cand.shape[0] == 1:
66
+ cand = cand[0]
67
+
68
+ detections = []
69
+
70
+ if debug:
71
+ print("Raw chosen output shape:", cand.shape)
72
+ try:
73
+ print("Sample rows:", cand.reshape(-1, cand.shape[-1])[:5])
74
+ except Exception:
75
+ pass
76
+
77
+ # Case 1: Nx6
78
+ if cand.ndim == 2 and cand.shape[1] == 6:
79
+ for r in cand:
80
+ x1, y1, x2, y2, score, cls = r
81
+ if score < conf_thresh:
82
+ continue
83
+ if max(x1, y1, x2, y2) <= 1.0:
84
+ x1, y1, x2, y2 = x1*orig_w, y1*orig_h, x2*orig_w, y2*orig_h
85
+ detections.append({"box": [x1, y1, x2, y2], "score": float(score), "class": int(cls)})
86
+ return detections
87
+
88
+ # Case 2: YOLO-style Nx(5+num_classes)
89
+ if cand.ndim == 2 and cand.shape[1] >= 6:
90
+ for r in cand:
91
+ cx, cy, w, h = r[0], r[1], r[2], r[3]
92
+ obj_conf = float(r[4])
93
+ class_probs = r[5:]
94
+ best_idx = int(np.argmax(class_probs)) if class_probs.size > 0 else 0
95
+ cls_conf = float(class_probs[best_idx]) if class_probs.size > 0 else 1.0
96
+ score = obj_conf * cls_conf
97
+ if score < conf_thresh:
98
+ continue
99
+ if max(cx, cy, w, h) <= 1.0:
100
+ x1 = (cx - w/2) * orig_w
101
+ y1 = (cy - h/2) * orig_h
102
+ x2 = (cx + w/2) * orig_w
103
+ y2 = (cy + h/2) * orig_h
104
+ else:
105
+ x1, y1, x2, y2 = cx - w/2, cy - h/2, cx + w/2, cy + h/2
106
+ detections.append({"box": [x1, y1, x2, y2], "score": score, "class": best_idx})
107
+ return detections
108
+
109
+ # Case 3: Separate outputs (boxes, scores, labels)
110
+ if len(outs) >= 3:
111
+ boxes_arr = next((o for o in outs if o.ndim == 2 and o.shape[1] == 4), None)
112
+ scores_arr = next((o for o in outs if o.ndim <= 2 and o.size == boxes_arr.shape[0]), None) if boxes_arr is not None else None
113
+ labels_arr = next((o for o in outs if o.ndim <= 2 and o.size == boxes_arr.shape[0]), None) if boxes_arr is not None else None
114
+ if boxes_arr is not None:
115
+ for i, bx in enumerate(boxes_arr):
116
+ score = float(scores_arr[i]) if scores_arr is not None else 1.0
117
+ if score < conf_thresh:
118
+ continue
119
+ if max(bx) <= 1.0:
120
+ x1, y1, x2, y2 = bx[0]*orig_w, bx[1]*orig_h, bx[2]*orig_w, bx[3]*orig_h
121
+ else:
122
+ x1, y1, x2, y2 = bx
123
+ detections.append({"box": [x1, y1, x2, y2], "score": score, "class": int(labels_arr[i]) if labels_arr is not None else 0})
124
+ return detections
125
+
126
+ if debug:
127
+ print("Could not parse model outputs automatically.")
128
+ return detections
129
+
130
+ # ---------------------------
131
+ # DRAW BOXES ON IMAGE
132
+ # ---------------------------
133
+ def draw_boxes_on_image(pil_img, detections):
134
+ img = pil_img.convert("RGB")
135
+ draw = ImageDraw.Draw(img)
136
+ font = ImageFont.load_default()
137
+ for d in detections:
138
+ x1, y1, x2, y2 = d["box"]
139
+ label = str(d["class"])
140
+ if LABELS and 0 <= d["class"] < len(LABELS):
141
+ label = LABELS[d["class"]]
142
+ txt = f"{label} {d['score']:.2f}"
143
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
144
+ draw.text((x1, max(0, y1 - 12)), txt, fill="red", font=font)
145
+ return img
146
+
147
+ # ---------------------------
148
+ # MAIN PREDICT FUNCTION
149
+ # ---------------------------
150
+ def predict_live(frame):
151
+ if frame is None:
152
+ return None, "No frame"
153
+ t0 = time.time()
154
+ orig_h, orig_w = frame.shape[0], frame.shape[1]
155
+ input_tensor = preprocess_frame(frame, PREVIEW_INPUT_SIZE)
156
+ input_name = sess.get_inputs()[0].name
157
+ try:
158
+ outputs = sess.run(None, {input_name: input_tensor})
159
+ except Exception as e:
160
+ return None, f"ONNX runtime error: {e}"
161
+
162
+ detections = postprocess_outputs(outputs, orig_w, orig_h, conf_thresh=CONF_THRESHOLD, debug=True)
163
+ pil_img = Image.fromarray(frame.astype("uint8"), "RGB")
164
+ out_img = draw_boxes_on_image(pil_img, detections)
165
+
166
+ t1 = time.time()
167
+ debug_txt = (
168
+ f"Model: {os.path.basename(MODEL_PATH)}\n"
169
+ f"Input shape: {sess.get_inputs()[0].shape}\n"
170
+ f"Output(s): {[o.shape for o in sess.get_outputs()]}\n"
171
+ f"Detections: {len(detections)}\n"
172
+ f"Inference time: {(t1 - t0)*1000:.1f} ms"
173
+ )
174
+ return out_img, debug_txt
175
+
176
+ # ---------------------------
177
+ # GRADIO INTERFACE
178
+ # ---------------------------
179
+ iface = gr.Interface(
180
+ fn=predict_live,
181
+ inputs=gr.Image(source="webcam", type="numpy"),
182
+ outputs=[gr.Image(type="pil"), gr.Textbox(lines=6)],
183
+ live=True,
184
+ title="ONNX Live Detection",
185
+ description="Real-time detection using your ONNX model. Adjust CONF_THRESHOLD or input size if needed."
186
+ )
187
+
188
+ if __name__ == "__main__":
189
+ iface.launch(server_name="0.0.0.0", server_port=7860)