"""Streamlit playground for depth estimation and object-distance estimation.

The user uploads an image, picks one of two tasks in the sidebar, tunes the
task's hyperparameters and presses a run button.  Results are kept in
``st.session_state`` so that they survive the script reruns Streamlit
performs on every widget interaction (including download buttons).
"""
import io
import json
from typing import Optional

import cv2
import numpy as np
import pandas as pd
import streamlit as st

from depth_estimation import (
    depth_to_heatmap,
    load_midas,
    midas_depth,
    sgbm_depth,
)
from object_distance import (
    compute_evaluation_metrics,
    draw_detections,
    estimate_distances,
    estimate_focal_length,
    load_yolo,
    run_yolo,
)

st.set_page_config(page_title="CV Task Playground", layout="wide")

MIDAS_MODELS = ["MiDaS_small", "DPT_Hybrid", "DPT_Large", "MiDaS"]
YOLO_MODELS = ["yolov5n", "yolov5s", "yolov5m", "yolov5l", "yolov5x"]

TASK_DEPTH = "Depth Estimation"
TASK_DISTANCE = "Object Distance"

# Session-state slots holding the most recent result of each task.
DEPTH_RESULT_KEY = "depth_result"
DETECTION_RESULT_KEY = "detection_result"

# Column order of the detections table (kept even when there are no rows).
DETECTION_COLUMNS = [
    "label",
    "confidence",
    "pixel_height",
    "known_height_m",
    "bbox_depth_median",
    "dist_pinhole_m",
    "dist_midas_m",
    "final_distance_m",
    "method",
]


@st.cache_resource(show_spinner=False)
def get_midas_bundle(model_type: str):
    """Return the result of ``load_midas(model_type)``, cached per model type.

    Callers in this file unpack the result as ``(model, transform, device)``.
    """
    return load_midas(model_type)


@st.cache_resource(show_spinner=False)
def _load_yolo_cached(model_name: str, _conf_thresh: float, _iou_thresh: float):
    """Load a YOLO model once per ``model_name``.

    The threshold parameters start with an underscore so Streamlit leaves
    them out of the cache key; otherwise every slider position would load
    and retain another copy of the same weights.
    """
    return load_yolo(model_name, conf_thresh=_conf_thresh, iou_thresh=_iou_thresh)


def get_yolo_model(model_name: str, conf_thresh: float, iou_thresh: float):
    """Return the cached YOLO model configured with the given thresholds.

    The model is cached by name only; the thresholds are (re)applied through
    its ``conf`` and ``iou`` attributes on every call.
    """
    model = _load_yolo_cached(model_name, conf_thresh, iou_thresh)
    model.conf = conf_thresh
    model.iou = iou_thresh
    return model


def decode_uploaded_image(uploaded_file) -> np.ndarray:
    """Decode an uploaded file into a colour image array via OpenCV.

    Raises:
        ValueError: if the upload is empty or cannot be decoded as an image.
    """
    # getvalue() returns the whole buffer regardless of the current stream
    # position; fall back to read() for plain file-like objects.
    getvalue = getattr(uploaded_file, "getvalue", None)
    raw = getvalue() if callable(getvalue) else uploaded_file.read()
    if not raw:
        raise ValueError("The uploaded file is empty.")

    img = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Could not decode the uploaded image.")
    return img


def bgr_to_rgb(img: np.ndarray) -> np.ndarray:
    """Convert an image from OpenCV's BGR channel order to RGB."""
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def image_download_bytes(img: np.ndarray) -> bytes:
    """Encode ``img`` as PNG and return the encoded bytes.

    Raises:
        ValueError: if OpenCV reports that encoding failed.
    """
    ok, encoded = cv2.imencode(".png", img)
    if not ok:
        raise ValueError("Could not encode image for download.")
    return encoded.tobytes()


def _distance_sort_key(det: dict) -> float:
    """Sort key placing detections without a distance after all others."""
    distance = det.get("distance")
    return distance if distance is not None else float("inf")


def detections_to_dataframe(detections: list[dict]) -> pd.DataFrame:
    """Build the detections table, ordered by ascending final distance.

    Each detection must provide ``"label"`` and ``"conf"``; every other field
    is optional and shows up as a missing value when absent.
    """
    rows = [
        {
            "label": det["label"],
            "confidence": round(det["conf"], 4),
            "pixel_height": det.get("pixel_height"),
            "known_height_m": det.get("known_height_m"),
            "bbox_depth_median": det.get("bbox_depth_median"),
            "dist_pinhole_m": det.get("dist_pinhole"),
            "dist_midas_m": det.get("dist_midas"),
            "final_distance_m": det.get("distance"),
            "method": det.get("method"),
        }
        for det in sorted(detections, key=_distance_sort_key)
    ]
    return pd.DataFrame(rows, columns=DETECTION_COLUMNS)


def _json_default(obj):
    """``json.dumps`` fallback that converts NumPy scalars and arrays."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _clear_result(key: str) -> None:
    """Drop a stored task result, if any."""
    if key in st.session_state:
        del st.session_state[key]


def _stored_result(key: str, source_id: tuple) -> Optional[dict]:
    """Return the stored result for ``key`` if it belongs to the current upload."""
    if key not in st.session_state:
        return None
    result = st.session_state[key]
    return result if result.get("source") == source_id else None


def run_depth_pipeline(
    img: np.ndarray,
    *,
    baseline_shift_pct: float,
    block_size: int,
    uniqueness_ratio: int,
    speckle_window_size: int,
    speckle_range: int,
    midas_model_type: str,
) -> dict:
    """Run SGBM and MiDaS depth estimation and bundle everything to display."""
    depth_cl, left_img, right_img = sgbm_depth(
        img,
        baseline_shift_pct=baseline_shift_pct,
        block_size=block_size,
        uniqueness_ratio=uniqueness_ratio,
        speckle_window_size=speckle_window_size,
        speckle_range=speckle_range,
    )
    midas_model, midas_transform, midas_device = get_midas_bundle(midas_model_type)
    depth_ml = midas_depth(img, midas_model, midas_transform, midas_device)
    classical_heatmap = depth_to_heatmap(depth_cl)
    midas_heatmap = depth_to_heatmap(depth_ml)

    return {
        "summary": {
            "midas_model": midas_model_type,
            "baseline_shift_pct": baseline_shift_pct,
            "block_size": block_size,
            "uniqueness_ratio": uniqueness_ratio,
            "speckle_window_size": speckle_window_size,
            "speckle_range": speckle_range,
            "classical_mean_depth": float(depth_cl.mean()),
            "midas_mean_depth": float(depth_ml.mean()),
        },
        "midas_model_type": midas_model_type,
        "left_img": left_img,
        "right_img": right_img,
        "classical_heatmap": classical_heatmap,
        "midas_heatmap": midas_heatmap,
        # Encode once here so reruns do not re-encode the PNGs.
        "classical_png": image_download_bytes(classical_heatmap),
        "midas_png": image_download_bytes(midas_heatmap),
    }


def render_depth_results(result: dict, summary_col) -> None:
    """Draw the summary, stereo pair, heatmaps and downloads of a depth run."""
    with summary_col:
        st.subheader("Run Summary")
        st.json(result["summary"])

    c1, c2 = st.columns(2)
    with c1:
        st.subheader("Classical Stereo Pair")
        st.image(bgr_to_rgb(result["left_img"]), caption="Left view", use_container_width=True)
        st.image(
            bgr_to_rgb(result["right_img"]),
            caption="Synthetic right view",
            use_container_width=True,
        )
    with c2:
        st.subheader("Depth Heatmaps")
        st.image(
            bgr_to_rgb(result["classical_heatmap"]),
            caption="Classical SGBM",
            use_container_width=True,
        )
        st.image(
            bgr_to_rgb(result["midas_heatmap"]),
            caption=f"MiDaS ({result['midas_model_type']})",
            use_container_width=True,
        )

    dl1, dl2 = st.columns(2)
    with dl1:
        st.download_button(
            "Download classical heatmap",
            data=result["classical_png"],
            file_name="classical_heatmap.png",
            mime="image/png",
        )
    with dl2:
        st.download_button(
            "Download MiDaS heatmap",
            data=result["midas_png"],
            file_name="midas_heatmap.png",
            mime="image/png",
        )


def run_detection_pipeline(
    img: np.ndarray,
    *,
    yolo_model_name: str,
    conf_thresh: float,
    iou_thresh: float,
    midas_model_type: str,
    focal_length: float,
    depth_inner_ratio: float,
    min_depth_value: float,
    blend_weight_pinhole: float,
) -> Optional[dict]:
    """Run detection plus distance estimation.

    Returns:
        A dict with everything needed for display, or ``None`` when the
        detector returned no detections.
    """
    yolo_model = get_yolo_model(yolo_model_name, conf_thresh, iou_thresh)
    detections = run_yolo(yolo_model, img, conf_thresh=conf_thresh)
    if not detections:
        return None

    focal_length = float(focal_length)
    midas_model, midas_transform, midas_device = get_midas_bundle(midas_model_type)
    depth_map = midas_depth(img, midas_model, midas_transform, midas_device)
    detections, eval_context = estimate_distances(
        detections,
        depth_map,
        focal_length=focal_length,
        inner_ratio=depth_inner_ratio,
        min_depth_value=min_depth_value,
        blend_weight_pinhole=blend_weight_pinhole,
    )
    metrics = compute_evaluation_metrics(detections, focal_length, eval_context)
    annotated = draw_detections(img, detections)
    det_df = detections_to_dataframe(detections)

    return {
        "summary": {
            "yolo_model": yolo_model_name,
            "midas_model": midas_model_type,
            "focal_length_px": focal_length,
            "confidence_threshold": conf_thresh,
            "iou_threshold": iou_thresh,
            "depth_inner_ratio": depth_inner_ratio,
            "min_depth_value": min_depth_value,
            "blend_weight_pinhole": blend_weight_pinhole,
            "detections": len(detections),
        },
        "annotated": annotated,
        "depth_heatmap": depth_to_heatmap(depth_map),
        "det_df": det_df,
        "metrics": metrics,
        "annotated_png": image_download_bytes(annotated),
        "csv_bytes": det_df.to_csv(index=False).encode("utf-8"),
        # The default handler keeps NumPy values in the metrics from
        # raising TypeError during serialization.
        "metrics_bytes": json.dumps(metrics, indent=2, default=_json_default).encode("utf-8"),
    }


def render_detection_results(result: dict, summary_col) -> None:
    """Draw the summary, images, table, metrics and downloads of a detection run."""
    with summary_col:
        st.subheader("Run Summary")
        st.json(result["summary"])

    c1, c2 = st.columns(2)
    with c1:
        st.subheader("Annotated Output")
        st.image(bgr_to_rgb(result["annotated"]), use_container_width=True)
    with c2:
        st.subheader("MiDaS Depth")
        st.image(bgr_to_rgb(result["depth_heatmap"]), use_container_width=True)

    st.subheader("Detected Objects")
    st.dataframe(result["det_df"], use_container_width=True)

    st.subheader("Evaluation Metrics")
    st.json(result["metrics"])

    d1, d2, d3 = st.columns(3)
    with d1:
        st.download_button(
            "Download annotated image",
            data=result["annotated_png"],
            file_name="detections_with_distance.png",
            mime="image/png",
        )
    with d2:
        st.download_button(
            "Download detections CSV",
            data=result["csv_bytes"],
            file_name="detection_distances.csv",
            mime="text/csv",
        )
    with d3:
        st.download_button(
            "Download metrics JSON",
            data=result["metrics_bytes"],
            file_name="metrics.json",
            mime="application/json",
        )


st.title("Computer Vision Task Playground")
st.write(
    "Upload an image, switch between the two tasks, and tune the main "
    "hyperparameters interactively."
)

with st.sidebar:
    st.header("Controls")
    task = st.radio("Task", [TASK_DEPTH, TASK_DISTANCE], index=0)
    uploaded_file = st.file_uploader(
        "Upload an image",
        type=["png", "jpg", "jpeg", "bmp", "webp"],
    )

if uploaded_file is None:
    st.info("Upload an image to begin.")
    st.stop()

try:
    img = decode_uploaded_image(uploaded_file)
except Exception as exc:
    st.error(str(exc))
    st.stop()

# Identifies the current upload so results of a previous image are not shown.
source_id = (getattr(uploaded_file, "name", None), getattr(uploaded_file, "size", None))

left_col, right_col = st.columns([1, 1])
with left_col:
    st.subheader("Uploaded Image")
    st.image(bgr_to_rgb(img), use_container_width=True)

if task == TASK_DEPTH:
    with st.sidebar:
        st.subheader("Depth Parameters")
        baseline_shift_pct = st.slider("Stereo baseline shift (%)", 1, 12, 3) / 100.0
        block_size = st.slider("SGBM block size", 3, 15, 7, step=2)
        uniqueness_ratio = st.slider("SGBM uniqueness ratio", 1, 25, 10)
        speckle_window_size = st.slider("SGBM speckle window", 0, 200, 100)
        speckle_range = st.slider("SGBM speckle range", 0, 10, 2)
        midas_model_type = st.selectbox("MiDaS model", MIDAS_MODELS, index=0)
        run_depth = st.button("Run Depth Estimation", type="primary")

    if run_depth:
        # Forget the previous run so a failure never leaves stale output behind.
        _clear_result(DEPTH_RESULT_KEY)
        with st.spinner("Running depth estimation..."):
            try:
                depth_result = run_depth_pipeline(
                    img,
                    baseline_shift_pct=baseline_shift_pct,
                    block_size=block_size,
                    uniqueness_ratio=uniqueness_ratio,
                    speckle_window_size=speckle_window_size,
                    speckle_range=speckle_range,
                    midas_model_type=midas_model_type,
                )
            except Exception as exc:
                st.error(f"Depth estimation failed: {exc}")
                st.stop()
        depth_result["source"] = source_id
        st.session_state[DEPTH_RESULT_KEY] = depth_result

    # Render from session state so the output survives reruns triggered by
    # the download buttons or by moving a slider.
    stored_depth = _stored_result(DEPTH_RESULT_KEY, source_id)
    if stored_depth is not None:
        render_depth_results(stored_depth, right_col)
else:
    with st.sidebar:
        st.subheader("Detection Parameters")
        yolo_model_name = st.selectbox("YOLO model", YOLO_MODELS, index=1)
        conf_thresh = st.slider("Confidence threshold", 0.05, 0.95, 0.35, step=0.05)
        iou_thresh = st.slider("NMS IoU threshold", 0.10, 0.95, 0.45, step=0.05)
        midas_model_type = st.selectbox("MiDaS model", MIDAS_MODELS, index=0)
        focal_mode = st.radio(
            "Focal length mode",
            ["Estimate from FOV", "Manual pixels"],
            index=0,
        )
        if focal_mode == "Estimate from FOV":
            fov_deg = st.slider("Horizontal FOV (deg)", 30, 120, 60)
            focal_length = estimate_focal_length(img.shape[1], fov_deg=fov_deg)
        else:
            focal_length = st.number_input(
                "Focal length (px)",
                min_value=50.0,
                value=800.0,
                step=10.0,
            )
        depth_inner_ratio = st.slider("Depth sampling inner box", 0.10, 1.00, 0.60, step=0.05)
        min_depth_value = st.slider("Minimum valid MiDaS depth", 0.0, 0.2, 0.02, step=0.01)
        blend_weight_pinhole = st.slider("Blend weight: pinhole", 0.0, 1.0, 0.55, step=0.05)
        run_detection = st.button("Run Object Distance", type="primary")

    if run_detection:
        _clear_result(DETECTION_RESULT_KEY)
        with st.spinner("Running detection and distance estimation..."):
            try:
                detection_result = run_detection_pipeline(
                    img,
                    yolo_model_name=yolo_model_name,
                    conf_thresh=conf_thresh,
                    iou_thresh=iou_thresh,
                    midas_model_type=midas_model_type,
                    focal_length=focal_length,
                    depth_inner_ratio=depth_inner_ratio,
                    min_depth_value=min_depth_value,
                    blend_weight_pinhole=blend_weight_pinhole,
                )
            except Exception as exc:
                st.error(f"Object-distance pipeline failed: {exc}")
                st.stop()
        # Handled outside the try block so stopping the script never depends
        # on how Streamlit's stop exception relates to ``Exception``.
        if detection_result is None:
            st.warning("No objects detected with the current settings.")
            st.stop()
        detection_result["source"] = source_id
        st.session_state[DETECTION_RESULT_KEY] = detection_result

    stored_detection = _stored_result(DETECTION_RESULT_KEY, source_id)
    if stored_detection is not None:
        render_detection_results(stored_detection, right_col)