cv_project_2 / streamlit_app.py
1javid's picture
Upload 4 files
b6066e7 verified
import io
import json
import cv2
import numpy as np
import pandas as pd
import streamlit as st
from depth_estimation import (
depth_to_heatmap,
load_midas,
midas_depth,
sgbm_depth,
)
from object_distance import (
compute_evaluation_metrics,
draw_detections,
estimate_distances,
estimate_focal_length,
load_yolo,
run_yolo,
)
# Page-level Streamlit configuration: browser tab title and wide layout.
st.set_page_config(page_title="CV Task Playground", layout="wide")

# Model names offered in the sidebar "MiDaS model" select boxes; the chosen
# value is passed to get_midas_bundle().
MIDAS_MODELS = ["MiDaS_small", "DPT_Hybrid", "DPT_Large", "MiDaS"]

# Model names offered in the sidebar "YOLO model" select box; the chosen
# value is passed to get_yolo_model().
YOLO_MODELS = ["yolov5n", "yolov5s", "yolov5m", "yolov5l", "yolov5x"]
@st.cache_resource(show_spinner=False)
def get_midas_bundle(model_type: str):
    """Return the MiDaS bundle for ``model_type`` via Streamlit's resource cache.

    Args:
        model_type: MiDaS variant name forwarded unchanged to ``load_midas``.

    Returns:
        Whatever ``load_midas`` returns; callers in this file unpack it as
        ``(model, transform, device)``.
    """
    bundle = load_midas(model_type)
    return bundle
@st.cache_resource(show_spinner=False)
def get_yolo_model(model_name: str, conf_thresh: float, iou_thresh: float):
    """Return a YOLO model built by ``load_yolo`` via Streamlit's resource cache.

    Args:
        model_name: YOLO variant name forwarded to ``load_yolo``.
        conf_thresh: Confidence threshold forwarded as ``conf_thresh``.
        iou_thresh: NMS IoU threshold forwarded as ``iou_thresh``.

    Returns:
        The object returned by ``load_yolo``.
    """
    # NOTE(review): both thresholds are function arguments, so presumably each
    # distinct slider combination gets its own cached model -- confirm against
    # the st.cache_resource documentation.
    thresholds = {"conf_thresh": conf_thresh, "iou_thresh": iou_thresh}
    return load_yolo(model_name, **thresholds)
def decode_uploaded_image(uploaded_file) -> np.ndarray:
    """Decode an uploaded image file into an OpenCV colour image.

    Args:
        uploaded_file: File-like object holding the encoded image bytes. If it
            exposes ``getvalue()`` (as ``io.BytesIO`` does) the whole buffer is
            used; otherwise the bytes are obtained with ``read()``.

    Returns:
        The array produced by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``.

    Raises:
        ValueError: If the upload contains no bytes or cannot be decoded.
    """
    # getvalue() does not depend on the current stream position, so a file
    # object that was already read once still yields its full content.
    get_all = getattr(uploaded_file, "getvalue", None)
    raw = get_all() if callable(get_all) else uploaded_file.read()

    data = np.frombuffer(raw, dtype=np.uint8)
    # Reject empty buffers here so the caller gets the same friendly error as
    # for any other undecodable upload instead of a low-level decoder failure.
    if data.size == 0:
        raise ValueError("Could not decode the uploaded image.")

    img = cv2.imdecode(data, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Could not decode the uploaded image.")
    return img
def bgr_to_rgb(img: np.ndarray) -> np.ndarray:
    """Return ``img`` converted with ``cv2.COLOR_BGR2RGB``.

    Used before handing OpenCV images to ``st.image``.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return rgb
def image_download_bytes(img: np.ndarray) -> bytes:
    """Encode ``img`` as PNG and return the raw bytes for a download button.

    Raises:
        ValueError: If ``cv2.imencode`` reports failure.
    """
    success, png_buffer = cv2.imencode(".png", img)
    if success:
        return png_buffer.tobytes()
    raise ValueError("Could not encode image for download.")
def detections_to_dataframe(detections: list[dict]) -> pd.DataFrame:
    """Build a display/export table from detection dictionaries.

    Rows are ordered by ascending ``"distance"``; detections whose distance is
    missing or ``None`` are placed last, keeping their input order.

    Args:
        detections: Detection dicts. ``"label"`` and ``"conf"`` are required;
            every other field is read with ``dict.get`` and may be absent.

    Returns:
        A DataFrame with one row per detection. The column set is fixed, so an
        empty ``detections`` list still yields a frame (and CSV) with headers.
    """
    columns = [
        "label",
        "confidence",
        "pixel_height",
        "known_height_m",
        "bbox_depth_median",
        "dist_pinhole_m",
        "dist_midas_m",
        "final_distance_m",
        "method",
    ]

    def sort_key(det: dict) -> float:
        # Infinity sorts unknown distances after every real measurement.
        distance = det.get("distance")
        return float("inf") if distance is None else distance

    rows = [
        {
            "label": det["label"],
            "confidence": round(det["conf"], 4),
            "pixel_height": det.get("pixel_height"),
            "known_height_m": det.get("known_height_m"),
            "bbox_depth_median": det.get("bbox_depth_median"),
            "dist_pinhole_m": det.get("dist_pinhole"),
            "dist_midas_m": det.get("dist_midas"),
            "final_distance_m": det.get("distance"),
            "method": det.get("method"),
        }
        for det in sorted(detections, key=sort_key)
    ]
    return pd.DataFrame(rows, columns=columns)
# Page header shown above all task output.
st.title("Computer Vision Task Playground")
st.write("Upload an image, switch between the two tasks, and tune the main hyperparameters interactively.")

# Sidebar controls shared by both tasks: task selector and image upload.
with st.sidebar:
    st.header("Controls")
    task = st.radio("Task", ["Depth Estimation", "Object Distance"], index=0)
    uploaded_file = st.file_uploader(
        "Upload an image",
        type=["png", "jpg", "jpeg", "bmp", "webp"],
    )

# Nothing below can run without an image, so halt this script run here.
if uploaded_file is None:
    st.info("Upload an image to begin.")
    st.stop()

# Decode once up front; both tasks reuse `img`. Any decoding failure is shown
# to the user and ends this script run.
try:
    img = decode_uploaded_image(uploaded_file)
except Exception as exc:
    st.error(str(exc))
    st.stop()

# Two equal-width columns: the left shows the upload, the right is filled with
# the run summary by whichever task executes below.
left_col, right_col = st.columns([1, 1])
with left_col:
    st.subheader("Uploaded Image")
    # The decoded image is converted with bgr_to_rgb before display.
    st.image(bgr_to_rgb(img), use_container_width=True)
# Task dispatch: each branch adds its own sidebar parameters, runs its pipeline
# when the run button is pressed, and renders results plus download buttons.
if task == "Depth Estimation":
    with st.sidebar:
        st.subheader("Depth Parameters")
        # The slider is in percent; dividing by 100 passes a fraction on.
        baseline_shift_pct = st.slider("Stereo baseline shift (%)", 1, 12, 3) / 100.0
        # Starting at 3 with step=2 yields only odd block sizes.
        block_size = st.slider("SGBM block size", 3, 15, 7, step=2)
        uniqueness_ratio = st.slider("SGBM uniqueness ratio", 1, 25, 10)
        speckle_window_size = st.slider("SGBM speckle window", 0, 200, 100)
        speckle_range = st.slider("SGBM speckle range", 0, 10, 2)
        midas_model_type = st.selectbox("MiDaS model", MIDAS_MODELS, index=0)
        run_depth = st.button("Run Depth Estimation", type="primary")

    # NOTE(review): everything below is rendered only while `run_depth` is
    # truthy. Presumably st.button is True only for the rerun caused by its
    # click, so pressing a download button would rerun the script and clear
    # these results -- confirm against the Streamlit button documentation.
    if run_depth:
        with st.spinner("Running depth estimation..."):
            try:
                # Classical path: sgbm_depth returns the depth map plus the
                # left/right images it worked on.
                depth_cl, left_img, right_img = sgbm_depth(
                    img,
                    baseline_shift_pct=baseline_shift_pct,
                    block_size=block_size,
                    uniqueness_ratio=uniqueness_ratio,
                    speckle_window_size=speckle_window_size,
                    speckle_range=speckle_range,
                )
                # Learned path: MiDaS model fetched through the resource cache.
                midas_model, midas_transform, midas_device = get_midas_bundle(midas_model_type)
                depth_ml = midas_depth(img, midas_model, midas_transform, midas_device)
                classical_heatmap = depth_to_heatmap(depth_cl)
                midas_heatmap = depth_to_heatmap(depth_ml)
            except Exception as exc:
                # Show the failure and end this script run; nothing to render.
                st.error(f"Depth estimation failed: {exc}")
                st.stop()

        # Echo the parameters used and the mean of each depth map.
        with right_col:
            st.subheader("Run Summary")
            st.json({
                "midas_model": midas_model_type,
                "baseline_shift_pct": baseline_shift_pct,
                "block_size": block_size,
                "uniqueness_ratio": uniqueness_ratio,
                "speckle_window_size": speckle_window_size,
                "speckle_range": speckle_range,
                "classical_mean_depth": float(depth_cl.mean()),
                "midas_mean_depth": float(depth_ml.mean()),
            })

        # Left: the stereo pair used by SGBM. Right: both depth heatmaps.
        c1, c2 = st.columns(2)
        with c1:
            st.subheader("Classical Stereo Pair")
            st.image(bgr_to_rgb(left_img), caption="Left view", use_container_width=True)
            st.image(bgr_to_rgb(right_img), caption="Synthetic right view", use_container_width=True)
        with c2:
            st.subheader("Depth Heatmaps")
            st.image(bgr_to_rgb(classical_heatmap), caption="Classical SGBM", use_container_width=True)
            st.image(bgr_to_rgb(midas_heatmap), caption=f"MiDaS ({midas_model_type})", use_container_width=True)

        # PNG downloads of both heatmaps.
        dl1, dl2 = st.columns(2)
        with dl1:
            st.download_button(
                "Download classical heatmap",
                data=image_download_bytes(classical_heatmap),
                file_name="classical_heatmap.png",
                mime="image/png",
            )
        with dl2:
            st.download_button(
                "Download MiDaS heatmap",
                data=image_download_bytes(midas_heatmap),
                file_name="midas_heatmap.png",
                mime="image/png",
            )
else:
    # "Object Distance" task (the only other option of the task radio).
    with st.sidebar:
        st.subheader("Detection Parameters")
        yolo_model_name = st.selectbox("YOLO model", YOLO_MODELS, index=1)
        conf_thresh = st.slider("Confidence threshold", 0.05, 0.95, 0.35, step=0.05)
        iou_thresh = st.slider("NMS IoU threshold", 0.10, 0.95, 0.45, step=0.05)
        midas_model_type = st.selectbox("MiDaS model", MIDAS_MODELS, index=0)
        focal_mode = st.radio("Focal length mode", ["Estimate from FOV", "Manual pixels"], index=0)
        if focal_mode == "Estimate from FOV":
            # Focal length derived from the image width (img.shape[1]) and the
            # chosen horizontal field of view.
            fov_deg = st.slider("Horizontal FOV (deg)", 30, 120, 60)
            focal_length = estimate_focal_length(img.shape[1], fov_deg=fov_deg)
        else:
            # Focal length entered directly in pixels.
            focal_length = st.number_input("Focal length (px)", min_value=50.0, value=800.0, step=10.0)
        depth_inner_ratio = st.slider("Depth sampling inner box", 0.10, 1.00, 0.60, step=0.05)
        min_depth_value = st.slider("Minimum valid MiDaS depth", 0.0, 0.2, 0.02, step=0.01)
        blend_weight_pinhole = st.slider("Blend weight: pinhole", 0.0, 1.0, 0.55, step=0.05)
        run_detection = st.button("Run Object Distance", type="primary")

    # NOTE(review): as in the depth branch, results are rendered only while
    # `run_detection` is truthy; presumably a download click reruns the script
    # and clears them -- confirm against the Streamlit button documentation.
    if run_detection:
        with st.spinner("Running detection and distance estimation..."):
            try:
                yolo_model = get_yolo_model(yolo_model_name, conf_thresh, iou_thresh)
                # The thresholds are also assigned on the returned model
                # object before inference.
                yolo_model.conf = conf_thresh
                yolo_model.iou = iou_thresh
                detections = run_yolo(yolo_model, img, conf_thresh=conf_thresh)
                if not detections:
                    # NOTE(review): this st.stop() sits inside the try block;
                    # it relies on Streamlit's stop signal not being caught by
                    # `except Exception` -- confirm.
                    st.warning("No objects detected with the current settings.")
                    st.stop()
                midas_model, midas_transform, midas_device = get_midas_bundle(midas_model_type)
                depth_map = midas_depth(img, midas_model, midas_transform, midas_device)
                # estimate_distances returns the detections plus a context
                # object that is passed on to the metrics computation.
                detections, eval_context = estimate_distances(
                    detections,
                    depth_map,
                    focal_length=float(focal_length),
                    inner_ratio=depth_inner_ratio,
                    min_depth_value=min_depth_value,
                    blend_weight_pinhole=blend_weight_pinhole,
                )
                metrics = compute_evaluation_metrics(detections, float(focal_length), eval_context)
                annotated = draw_detections(img, detections)
                depth_heatmap = depth_to_heatmap(depth_map)
                det_df = detections_to_dataframe(detections)
            except Exception as exc:
                # Show the failure and end this script run; nothing to render.
                st.error(f"Object-distance pipeline failed: {exc}")
                st.stop()

        # Echo the parameters used and the number of detections.
        with right_col:
            st.subheader("Run Summary")
            st.json({
                "yolo_model": yolo_model_name,
                "midas_model": midas_model_type,
                "focal_length_px": float(focal_length),
                "confidence_threshold": conf_thresh,
                "iou_threshold": iou_thresh,
                "depth_inner_ratio": depth_inner_ratio,
                "min_depth_value": min_depth_value,
                "blend_weight_pinhole": blend_weight_pinhole,
                "detections": len(detections),
            })

        # Left: annotated detections. Right: MiDaS depth heatmap.
        c1, c2 = st.columns(2)
        with c1:
            st.subheader("Annotated Output")
            st.image(bgr_to_rgb(annotated), use_container_width=True)
        with c2:
            st.subheader("MiDaS Depth")
            st.image(bgr_to_rgb(depth_heatmap), use_container_width=True)

        st.subheader("Detected Objects")
        st.dataframe(det_df, use_container_width=True)
        st.subheader("Evaluation Metrics")
        st.json(metrics)

        # Serialise the table and metrics once for the download buttons.
        csv_bytes = det_df.to_csv(index=False).encode("utf-8")
        metrics_bytes = json.dumps(metrics, indent=2).encode("utf-8")
        d1, d2, d3 = st.columns(3)
        with d1:
            st.download_button(
                "Download annotated image",
                data=image_download_bytes(annotated),
                file_name="detections_with_distance.png",
                mime="image/png",
            )
        with d2:
            st.download_button(
                "Download detections CSV",
                data=csv_bytes,
                file_name="detection_distances.csv",
                mime="text/csv",
            )
        with d3:
            st.download_button(
                "Download metrics JSON",
                data=metrics_bytes,
                file_name="metrics.json",
                mime="application/json",
            )