Spaces:
Sleeping
Sleeping
| import io | |
| import json | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from depth_estimation import ( | |
| depth_to_heatmap, | |
| load_midas, | |
| midas_depth, | |
| sgbm_depth, | |
| ) | |
| from object_distance import ( | |
| compute_evaluation_metrics, | |
| draw_detections, | |
| estimate_distances, | |
| estimate_focal_length, | |
| load_yolo, | |
| run_yolo, | |
| ) | |
# Browser-tab title and wide layout; this is the first Streamlit call the script makes.
st.set_page_config(
    layout="wide",
    page_title="CV Task Playground",
)

# Choices offered by the "MiDaS model" select boxes (forwarded to load_midas).
MIDAS_MODELS = [
    "MiDaS_small",
    "DPT_Hybrid",
    "DPT_Large",
    "MiDaS",
]

# Choices offered by the "YOLO model" select box (forwarded to load_yolo).
YOLO_MODELS = [
    "yolov5n",
    "yolov5s",
    "yolov5m",
    "yolov5l",
    "yolov5x",
]
# Streamlit re-executes the whole script on every interaction, so without a
# resource cache the MiDaS weights would be reloaded on every button click.
# show_spinner=False because callers already wrap this in their own st.spinner;
# max_entries bounds how many distinct MiDaS variants stay resident at once.
@st.cache_resource(show_spinner=False, max_entries=2)
def get_midas_bundle(model_type: str):
    """Load (once) and return the MiDaS bundle for *model_type*.

    Args:
        model_type: One of ``MIDAS_MODELS``; passed straight to ``load_midas``.

    Returns:
        Whatever ``load_midas`` returns. Callers in this file unpack it as
        ``(model, transform, device)``. The object is shared by every rerun and
        session that asks for the same ``model_type``.
    """
    return load_midas(model_type)
# Cached across Streamlit reruns so the YOLO weights are not reloaded on every
# click. The key includes both thresholds, so each slider combination creates a
# separate entry; max_entries keeps that from accumulating one model per setting.
@st.cache_resource(show_spinner=False, max_entries=2)
def get_yolo_model(model_name: str, conf_thresh: float, iou_thresh: float):
    """Load (once per argument combination) and return a YOLO model.

    Args:
        model_name: One of ``YOLO_MODELS``; passed to ``load_yolo``.
        conf_thresh: Confidence threshold forwarded to ``load_yolo``.
        iou_thresh: NMS IoU threshold forwarded to ``load_yolo``.

    Returns:
        The model object produced by ``load_yolo``. It is shared by every rerun
        and session that requests the same arguments.
    """
    return load_yolo(model_name, conf_thresh=conf_thresh, iou_thresh=iou_thresh)
def decode_uploaded_image(uploaded_file) -> np.ndarray:
    """Decode an uploaded image file into an OpenCV BGR array.

    Args:
        uploaded_file: File-like object (e.g. from ``st.file_uploader``).
            ``getvalue()`` is preferred because it returns the whole buffer
            regardless of the current read position; ``read()`` is the
            fallback for objects without it.

    Returns:
        The decoded image as produced by ``cv2.imdecode(..., cv2.IMREAD_COLOR)``,
        i.e. a 3-channel BGR array.

    Raises:
        ValueError: If the upload is empty or OpenCV cannot decode it.
    """
    getvalue = getattr(uploaded_file, "getvalue", None)
    raw = getvalue() if callable(getvalue) else uploaded_file.read()
    data = np.frombuffer(raw, dtype=np.uint8)
    # cv2.imdecode asserts on an empty buffer (an opaque cv2.error) rather than
    # returning None, so reject empty input with a readable message first.
    if data.size == 0:
        raise ValueError("The uploaded file is empty.")
    img = cv2.imdecode(data, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Could not decode the uploaded image.")
    return img
def bgr_to_rgb(img: np.ndarray) -> np.ndarray:
    """Convert an OpenCV BGR image to RGB channel order for display.

    ``st.image`` interprets arrays as RGB by default, while OpenCV decodes and
    draws in BGR, so every image shown by this app passes through here.
    """
    conversion_code = cv2.COLOR_BGR2RGB
    rgb_image = cv2.cvtColor(img, conversion_code)
    return rgb_image
def image_download_bytes(img: np.ndarray) -> bytes:
    """PNG-encode *img* (OpenCV channel order) into bytes for ``st.download_button``.

    Raises:
        ValueError: If ``cv2.imencode`` reports failure.
    """
    success, png_buffer = cv2.imencode(".png", img)
    if success:
        return png_buffer.tobytes()
    raise ValueError("Could not encode image for download.")
# Column order of the detections table shown in the UI and exported as CSV.
_DETECTION_COLUMNS = (
    "label",
    "confidence",
    "pixel_height",
    "known_height_m",
    "bbox_depth_median",
    "dist_pinhole_m",
    "dist_midas_m",
    "final_distance_m",
    "method",
)


def _distance_sort_key(det: dict) -> tuple[int, float]:
    """Sort nearest-first; detections with a missing or NaN distance go last."""
    dist = det.get("distance")
    # pd.isna covers both None and NaN; NaN would otherwise make sort order
    # arbitrary because every comparison with it is False.
    if dist is None or pd.isna(dist):
        return (1, 0.0)
    return (0, float(dist))


def detections_to_dataframe(detections: list[dict]) -> pd.DataFrame:
    """Flatten detection dicts into a table ordered by final distance.

    Args:
        detections: Dicts that must contain ``"label"`` and ``"conf"``; the
            remaining keys (``"distance"``, ``"pixel_height"``, ...) are
            optional and become None when absent.

    Returns:
        A DataFrame with the columns in ``_DETECTION_COLUMNS``, nearest object
        first. Detections without a usable distance are kept, in input order,
        at the end. An empty input still yields those columns, so the CSV
        export keeps its header.
    """
    rows = [
        {
            "label": det["label"],
            "confidence": round(det["conf"], 4),
            "pixel_height": det.get("pixel_height"),
            "known_height_m": det.get("known_height_m"),
            "bbox_depth_median": det.get("bbox_depth_median"),
            "dist_pinhole_m": det.get("dist_pinhole"),
            "dist_midas_m": det.get("dist_midas"),
            "final_distance_m": det.get("distance"),
            "method": det.get("method"),
        }
        for det in sorted(detections, key=_distance_sort_key)
    ]
    return pd.DataFrame(rows, columns=list(_DETECTION_COLUMNS))
st.title("Computer Vision Task Playground")
st.write("Upload an image, switch between the two tasks, and tune the main hyperparameters interactively.")

# Global sidebar controls: which task to run and the input image.
st.sidebar.header("Controls")
task = st.sidebar.radio("Task", ["Depth Estimation", "Object Distance"], index=0)
accepted_types = ["png", "jpg", "jpeg", "bmp", "webp"]
uploaded_file = st.sidebar.file_uploader(
    "Upload an image",
    type=accepted_types,
)

# Nothing below can run without an image, so halt this script run early.
if uploaded_file is None:
    st.info("Upload an image to begin.")
    st.stop()

try:
    img = decode_uploaded_image(uploaded_file)
except Exception as exc:
    st.error(str(exc))
    st.stop()

# Left half previews the input; right half receives each task's run summary.
left_col, right_col = st.columns([1, 1])
left_col.subheader("Uploaded Image")
left_col.image(bgr_to_rgb(img), use_container_width=True)
if task == "Depth Estimation":
    # Depth task: classical SGBM on a synthetic stereo pair, compared with MiDaS.
    st.sidebar.subheader("Depth Parameters")
    # The slider works in whole percent; sgbm_depth receives a fraction.
    baseline_shift_pct = st.sidebar.slider("Stereo baseline shift (%)", 1, 12, 3) / 100.0
    block_size = st.sidebar.slider("SGBM block size", 3, 15, 7, step=2)
    uniqueness_ratio = st.sidebar.slider("SGBM uniqueness ratio", 1, 25, 10)
    speckle_window_size = st.sidebar.slider("SGBM speckle window", 0, 200, 100)
    speckle_range = st.sidebar.slider("SGBM speckle range", 0, 10, 2)
    midas_model_type = st.sidebar.selectbox("MiDaS model", MIDAS_MODELS, index=0)
    run_depth = st.sidebar.button("Run Depth Estimation", type="primary")

    # NOTE(review): st.button is True only on the rerun caused by the click.
    # Any later interaction (a download click also reruns the script) resets it,
    # so these results disappear; consider persisting them in st.session_state.
    if run_depth:
        # Shared by the sgbm_depth call and the run summary (same key order).
        sgbm_settings = {
            "baseline_shift_pct": baseline_shift_pct,
            "block_size": block_size,
            "uniqueness_ratio": uniqueness_ratio,
            "speckle_window_size": speckle_window_size,
            "speckle_range": speckle_range,
        }
        with st.spinner("Running depth estimation..."):
            try:
                classical_depth, stereo_left, stereo_right = sgbm_depth(img, **sgbm_settings)
                midas_model, midas_transform, midas_device = get_midas_bundle(midas_model_type)
                learned_depth = midas_depth(img, midas_model, midas_transform, midas_device)
                classical_heatmap = depth_to_heatmap(classical_depth)
                midas_heatmap = depth_to_heatmap(learned_depth)
            except Exception as exc:
                st.error(f"Depth estimation failed: {exc}")
                st.stop()

        right_col.subheader("Run Summary")
        right_col.json({
            "midas_model": midas_model_type,
            **sgbm_settings,
            "classical_mean_depth": float(classical_depth.mean()),
            "midas_mean_depth": float(learned_depth.mean()),
        })

        stereo_col, heatmap_col = st.columns(2)
        stereo_col.subheader("Classical Stereo Pair")
        stereo_col.image(bgr_to_rgb(stereo_left), caption="Left view", use_container_width=True)
        stereo_col.image(bgr_to_rgb(stereo_right), caption="Synthetic right view", use_container_width=True)
        heatmap_col.subheader("Depth Heatmaps")
        heatmap_col.image(bgr_to_rgb(classical_heatmap), caption="Classical SGBM", use_container_width=True)
        heatmap_col.image(bgr_to_rgb(midas_heatmap), caption=f"MiDaS ({midas_model_type})", use_container_width=True)

        heatmap_downloads = (
            ("Download classical heatmap", classical_heatmap, "classical_heatmap.png"),
            ("Download MiDaS heatmap", midas_heatmap, "midas_heatmap.png"),
        )
        for column, (button_label, heatmap, download_name) in zip(st.columns(2), heatmap_downloads):
            column.download_button(
                button_label,
                data=image_download_bytes(heatmap),
                file_name=download_name,
                mime="image/png",
            )
| else: | |
| with st.sidebar: | |
| st.subheader("Detection Parameters") | |
| yolo_model_name = st.selectbox("YOLO model", YOLO_MODELS, index=1) | |
| conf_thresh = st.slider("Confidence threshold", 0.05, 0.95, 0.35, step=0.05) | |
| iou_thresh = st.slider("NMS IoU threshold", 0.10, 0.95, 0.45, step=0.05) | |
| midas_model_type = st.selectbox("MiDaS model", MIDAS_MODELS, index=0) | |
| focal_mode = st.radio("Focal length mode", ["Estimate from FOV", "Manual pixels"], index=0) | |
| if focal_mode == "Estimate from FOV": | |
| fov_deg = st.slider("Horizontal FOV (deg)", 30, 120, 60) | |
| focal_length = estimate_focal_length(img.shape[1], fov_deg=fov_deg) | |
| else: | |
| focal_length = st.number_input("Focal length (px)", min_value=50.0, value=800.0, step=10.0) | |
| depth_inner_ratio = st.slider("Depth sampling inner box", 0.10, 1.00, 0.60, step=0.05) | |
| min_depth_value = st.slider("Minimum valid MiDaS depth", 0.0, 0.2, 0.02, step=0.01) | |
| blend_weight_pinhole = st.slider("Blend weight: pinhole", 0.0, 1.0, 0.55, step=0.05) | |
| run_detection = st.button("Run Object Distance", type="primary") | |
| if run_detection: | |
| with st.spinner("Running detection and distance estimation..."): | |
| try: | |
| yolo_model = get_yolo_model(yolo_model_name, conf_thresh, iou_thresh) | |
| yolo_model.conf = conf_thresh | |
| yolo_model.iou = iou_thresh | |
| detections = run_yolo(yolo_model, img, conf_thresh=conf_thresh) | |
| if not detections: | |
| st.warning("No objects detected with the current settings.") | |
| st.stop() | |
| midas_model, midas_transform, midas_device = get_midas_bundle(midas_model_type) | |
| depth_map = midas_depth(img, midas_model, midas_transform, midas_device) | |
| detections, eval_context = estimate_distances( | |
| detections, | |
| depth_map, | |
| focal_length=float(focal_length), | |
| inner_ratio=depth_inner_ratio, | |
| min_depth_value=min_depth_value, | |
| blend_weight_pinhole=blend_weight_pinhole, | |
| ) | |
| metrics = compute_evaluation_metrics(detections, float(focal_length), eval_context) | |
| annotated = draw_detections(img, detections) | |
| depth_heatmap = depth_to_heatmap(depth_map) | |
| det_df = detections_to_dataframe(detections) | |
| except Exception as exc: | |
| st.error(f"Object-distance pipeline failed: {exc}") | |
| st.stop() | |
| with right_col: | |
| st.subheader("Run Summary") | |
| st.json({ | |
| "yolo_model": yolo_model_name, | |
| "midas_model": midas_model_type, | |
| "focal_length_px": float(focal_length), | |
| "confidence_threshold": conf_thresh, | |
| "iou_threshold": iou_thresh, | |
| "depth_inner_ratio": depth_inner_ratio, | |
| "min_depth_value": min_depth_value, | |
| "blend_weight_pinhole": blend_weight_pinhole, | |
| "detections": len(detections), | |
| }) | |
| c1, c2 = st.columns(2) | |
| with c1: | |
| st.subheader("Annotated Output") | |
| st.image(bgr_to_rgb(annotated), use_container_width=True) | |
| with c2: | |
| st.subheader("MiDaS Depth") | |
| st.image(bgr_to_rgb(depth_heatmap), use_container_width=True) | |
| st.subheader("Detected Objects") | |
| st.dataframe(det_df, use_container_width=True) | |
| st.subheader("Evaluation Metrics") | |
| st.json(metrics) | |
| csv_bytes = det_df.to_csv(index=False).encode("utf-8") | |
| metrics_bytes = json.dumps(metrics, indent=2).encode("utf-8") | |
| d1, d2, d3 = st.columns(3) | |
| with d1: | |
| st.download_button( | |
| "Download annotated image", | |
| data=image_download_bytes(annotated), | |
| file_name="detections_with_distance.png", | |
| mime="image/png", | |
| ) | |
| with d2: | |
| st.download_button( | |
| "Download detections CSV", | |
| data=csv_bytes, | |
| file_name="detection_distances.csv", | |
| mime="text/csv", | |
| ) | |
| with d3: | |
| st.download_button( | |
| "Download metrics JSON", | |
| data=metrics_bytes, | |
| file_name="metrics.json", | |
| mime="application/json", | |
| ) | |