# Hugging Face Space page header (extraction artifact): the Space reports
# "Runtime error" at startup — likely the pyttsx3 audio backend failing on
# the headless host; see the guarded voice-alert code below.
| import gradio as gr | |
| import onnxruntime as ort | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import time, os | |
| import pyttsx3 # for optional voice alerts | |
# ---------------------------
# CONFIG
# ---------------------------
MODEL_PATH = "model.onnx"        # path to the ONNX detection model
INPUT_SIZE = (640, 640)          # (width, height) the network expects
CONF_THRESHOLD_DEFAULT = 0.35    # initial value of the confidence slider

# Initialize voice engine
# NOTE(review): pyttsx3.init() needs a working audio/TTS backend; on a
# headless server (e.g. a Hugging Face Space) this can raise at import time.
engine = pyttsx3.init()
engine.setProperty("rate", 180)  # speech rate, words per minute

# Load model (CPU-only provider)
print(f"Loading ONNX model from: {MODEL_PATH}")
sess = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])

# Dump the model's input/output signatures at startup to help debug
# shape or dtype mismatches.
print("\nONNX Model Inputs:")
for i, inp in enumerate(sess.get_inputs()):
    print(f"  Input[{i}] name={inp.name}, shape={inp.shape}, dtype={inp.type}")
print("\nONNX Model Outputs:")
for i, out in enumerate(sess.get_outputs()):
    print(f"  Output[{i}] name={out.name}, shape={out.shape}, dtype={out.type}")
# Preprocess
def preprocess_frame(frame_np):
    """Convert an HxWx3 uint8 RGB frame into a 1x3x640x640 float32 tensor.

    Pixels are scaled to [0, 1] and laid out NCHW. A plain resize is used
    (no aspect-preserving letterbox), which matches how postprocessing
    rescales boxes back to the original frame.
    """
    resized = Image.fromarray(frame_np.astype("uint8"), "RGB").resize(INPUT_SIZE)
    chw = np.asarray(resized, dtype=np.float32).transpose(2, 0, 1) / 255.0
    return chw[np.newaxis, ...]
# Postprocess
def postprocess_outputs(outputs, orig_w, orig_h, conf_thresh=0.35):
    """Decode raw YOLO-style model output into pixel-space detections.

    Assumes the first output is (1, N, 6+) or (N, 6+) with rows of
    [cx, cy, w, h, obj_conf, class_probs...]. Coordinates entirely within
    [0, 1] are treated as normalized and scaled to the original frame;
    otherwise they are assumed to already be in pixels.

    Returns a list of {"box": [x1, y1, x2, y2], "score": float, "class": int}.
    """
    primary = np.array(outputs[0])
    if primary.ndim == 3 and primary.shape[0] == 1:
        primary = primary[0]

    results = []
    if primary.ndim != 2 or primary.shape[1] < 6:
        # Unexpected layout — nothing we can decode.
        return results

    for row in primary:
        cx, cy, w, h = row[0], row[1], row[2], row[3]
        obj_conf = float(row[4])
        probs = row[5:]
        if probs.size > 0:
            cls = int(np.argmax(probs))
            score = obj_conf * float(probs[cls])
        else:
            cls = 0
            score = obj_conf * 1.0
        if score < conf_thresh:
            continue

        half_w, half_h = w / 2, h / 2
        if max(cx, cy, w, h) <= 1.0:
            # Normalized xywh -> pixel xyxy.
            box = [(cx - half_w) * orig_w, (cy - half_h) * orig_h,
                   (cx + half_w) * orig_w, (cy + half_h) * orig_h]
        else:
            # Already pixel-space xywh -> xyxy.
            box = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
        results.append({"box": box, "score": score, "class": cls})
    return results
# Draw boxes
def draw_boxes_on_image(pil_img, detections):
    """Return a copy of *pil_img* with a red box and score label per detection."""
    canvas = pil_img.convert("RGB")
    pen = ImageDraw.Draw(canvas)
    font = ImageFont.load_default()
    for det in detections:
        x1, y1, x2, y2 = det["box"]
        caption = f"Class {det['class']} {det['score']:.2f}"
        pen.rectangle([x1, y1, x2, y2], outline="red", width=3)
        # Clamp the label y so it stays visible when a box touches the top edge.
        pen.text((x1, max(0, y1 - 12)), caption, fill="red", font=font)
    return canvas
# Voice alert
last_spoken = ""  # message announced on the previous call, for dedup

def speak_alert(detections):
    """Announce the detected class set via TTS, best-effort.

    Speaks only when the set of detected classes differs from the last
    announcement, so the same scene is not repeated every frame. TTS
    failures (busy pyttsx3 loop, missing audio backend on a headless
    server) are logged and swallowed so inference keeps running.
    """
    global last_spoken
    if not detections:
        return
    labels_detected = [f"class {d['class']}" for d in detections]
    # sorted() makes the message order deterministic; bare set() iteration
    # order varies with string hashing across runs, which would make the
    # dedup comparison re-trigger identical alerts.
    msg = ", ".join(sorted(set(labels_detected)))
    if msg != last_spoken:
        try:
            engine.say(f"Detected: {msg}")
            engine.runAndWait()
        except Exception as exc:
            # pyttsx3 raises RuntimeError if a loop is already running, and
            # backends fail entirely on audio-less hosts — don't crash the app.
            print(f"Voice alert failed: {exc}")
        # Record the message even if speech failed, to avoid retry spam.
        last_spoken = msg
# Main function
def predict_live(frame, conf_threshold):
    """Run one detection pass on a webcam frame.

    Args:
        frame: HxWx3 RGB numpy array from Gradio's webcam, or None when the
            camera has not delivered a frame yet.
        conf_threshold: minimum score for a detection to be kept.

    Returns:
        (annotated PIL image, debug text), or (None, "No frame").
    """
    if frame is None:
        return None, "No frame"

    height, width = frame.shape[:2]
    tensor = preprocess_frame(frame)
    feed = {sess.get_inputs()[0].name: tensor}
    raw_outputs = sess.run(None, feed)

    detections = postprocess_outputs(raw_outputs, width, height,
                                     conf_thresh=conf_threshold)
    annotated = draw_boxes_on_image(
        Image.fromarray(frame.astype("uint8"), "RGB"), detections
    )
    speak_alert(detections)

    debug_txt = (
        f"Model: {os.path.basename(MODEL_PATH)}\n"
        f"Detections: {len(detections)}"
    )
    return annotated, debug_txt
# Gradio interface with webcam + slider.
# live=True streams frames continuously into predict_live, so each webcam
# frame triggers a full preprocess -> ONNX run -> draw -> speak cycle.
iface = gr.Interface(
    fn=predict_live,
    inputs=[
        gr.Image(sources=["webcam"], type="numpy", label="Live Camera"),
        gr.Slider(0.05, 0.9, value=CONF_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold")
    ],
    outputs=[gr.Image(type="pil"), gr.Textbox(lines=4)],
    live=True,
    title="ONNX Live Camera Detection",
    description="Continuous live detection with bounding boxes + voice alerts"
)

if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces — required when running inside a container.
    iface.launch(server_name="0.0.0.0", server_port=7860)