# cv_project_2 / depth_estimation.py
# Uploaded by 1javid ("Upload 3 files"), commit f0640c4 (verified)
"""
Subtask 1 – Depth Estimation
1. Classical method : SGBM Stereo Matching on a synthesised stereo pair
2. ML-based method : Actual MiDaS (MiDaS_small) via torch.hub
3. Both rendered as heatmaps (hot colours = close, cold colours = far)
Usage:
python depth_estimation.py <image_path> [output_dir]
Example:
python depth_estimation.py street.jpg output/
"""
import sys
import os
import builtins
import csv
import cv2
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
import torch
# ═══════════════════════════════════════════════════════════
# 0. LOAD IMAGE (real image required)
# ═══════════════════════════════════════════════════════════
def load_image(path: str) -> np.ndarray:
    """
    Read an image from disk with OpenCV, terminating the program on failure.

    Args:
        path: filesystem path of the image to load.

    Returns:
        The array returned by ``cv2.imread`` (BGR channel order, as assumed
        by the rest of this script).

    Exits:
        Calls ``sys.exit`` with an error message when ``path`` is empty,
        does not exist, or cannot be decoded by OpenCV.
    """
    if not (path and os.path.exists(path)):
        sys.exit(
            f"ERROR: Image not found: '{path}'\n"
            "Usage: python depth_estimation.py <image_path>\n"
            "Example: python depth_estimation.py street.jpg"
        )

    image = cv2.imread(path)
    if image is None:
        sys.exit(f"ERROR: Could not read image: '{path}'")

    height, width = image.shape[:2]
    channels = image.shape[2]
    print(f"Loaded: {path} {width}x{height} ({channels} channels)")
    return image
# ═══════════════════════════════════════════════════════════
# 1. CLASSICAL METHOD – SGBM STEREO MATCHING
# ═══════════════════════════════════════════════════════════
def synthesise_stereo_pair(
    img: np.ndarray,
    baseline_shift_pct: float = 0.03
) -> tuple:
    """
    Simulate a rectified stereo pair from a single (monocular) image.

    A per-pixel disparity seed is estimated from two monocular cues:
      - Focus sharpness (Laplacian magnitude): sharp regions -> close
      - Vertical position (perspective prior): lower in frame -> close
    The cues are blended 50/50, Gaussian-smoothed, min-max normalised to
    [0, 1] and scaled to at most ``int(baseline_shift_pct * width)`` pixels.

    The seed drives a horizontal backward warp that produces the right view.
    In a rectified pair, a scene point at column x in the left image appears
    at column x - d (d >= 0) in the right image. This is the convention
    OpenCV's StereoSGBM searches with ``minDisparity=0``. The right view is
    therefore sampled as right(x) ~= left(x + d).

    Args:
        img: BGR image (as returned by ``cv2.imread``).
        baseline_shift_pct: maximum synthetic disparity as a fraction of
            the image width.

    Returns:
        (left, right, max_shift): ``left`` is a copy of ``img``, ``right``
        is the warped view, and ``max_shift`` is the largest disparity
        (in pixels) applied by the warp.
    """
    h, w = img.shape[:2]
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Sharpness cue: textured / in-focus regions are treated as nearer.
    lap = cv2.Laplacian(gray.astype(np.float32), cv2.CV_32F)
    sharpness = gaussian_filter(np.abs(lap), sigma=5)
    sharpness = sharpness / (sharpness.max() + 1e-6)

    # Vertical prior: 0 on the top row, 1 on the bottom row.
    vert = np.linspace(0, 1, h)[:, None] * np.ones((h, w))

    # Combine, smooth, and rescale to [0, 1].
    closeness = 0.5 * sharpness + 0.5 * vert
    closeness = gaussian_filter(closeness.astype(np.float32), sigma=10)
    closeness = (closeness - closeness.min()) / (closeness.max() - closeness.min() + 1e-6)

    max_shift = int(w * baseline_shift_pct)
    disp_seed = (closeness * max_shift).astype(np.float32)

    # Backward warp: sample the left image d pixels to the RIGHT, so content
    # moves LEFT in the right view (positive disparity, as SGBM expects).
    # Subtracting d here would produce negative disparities, which a
    # matcher with minDisparity=0 cannot recover.
    map_x = np.tile(np.arange(w, dtype=np.float32), (h, 1)) + disp_seed
    map_y = np.tile(np.arange(h, dtype=np.float32)[:, None], (1, w))
    right = cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR,
                      borderMode=cv2.BORDER_REPLICATE)
    return img.copy(), right, max_shift
def sgbm_depth(
    img: np.ndarray,
    baseline_shift_pct: float = 0.03,
    block_size: int = 7,
    uniqueness_ratio: int = 10,
    speckle_window_size: int = 100,
    speckle_range: int = 2
) -> tuple:
    """
    Classical depth via OpenCV Semi-Global Block Matching (Hirschmüller 2008).

    A stereo pair is synthesised from ``img`` (see
    ``synthesise_stereo_pair``). Both views are converted to grayscale and
    matched with ``cv2.StereoSGBM`` in ``MODE_SGBM_3WAY``.

    SGBM aggregates a per-pixel matching cost along several 1-D scanline
    paths, adding smoothness penalties:
      - P1: a disparity change of +/-1 between neighbours.
      - P2: larger disparity jumps.
    OpenCV uses a Birchfield-Tomasi sub-pixel matching cost, not census or
    mutual information. ``MODE_SGBM_3WAY`` is a faster variant; it is not
    the full 8-path two-pass ``MODE_HH``.

    Post-processing steps:
      1. Invalid (negative) disparities are clamped to 0.
      2. The map is bilateral-filtered.
      3. The map is min-max normalised to [0, 1].
      4. When ``cv2.ximgproc`` (opencv-contrib) is available, the map is
         refined with a guided filter, using the left grayscale view as the
         guide. Otherwise this step is skipped with a message.

    Args:
        img: BGR input image.
        baseline_shift_pct: synthetic baseline as a fraction of image width.
        block_size: SGBM block size; forced to be odd and >= 3.
        uniqueness_ratio, speckle_window_size, speckle_range: passed
            through to StereoSGBM.

    Returns:
        (depth_norm, left_img, right_img):
          - depth_norm: closeness map in [0, 1], where 1 = close.
          - left_img, right_img: the two views of the synthetic stereo pair.
    """
    left_img, right_img, max_shift = synthesise_stereo_pair(
        img, baseline_shift_pct=baseline_shift_pct
    )
    left_g = cv2.cvtColor(left_img, cv2.COLOR_BGR2GRAY)
    right_g = cv2.cvtColor(right_img, cv2.COLOR_BGR2GRAY)

    # numDisparities must be a positive multiple of 16. Use the smallest one
    # strictly greater than max_shift (at least 16).
    num_disp = max(16, ((max_shift // 16) + 1) * 16)

    # Block size: odd and at least 3.
    block = max(3, int(block_size))
    if block % 2 == 0:
        block += 1

    # OpenCV's suggested penalties are 8*cn*B^2 and 32*cn*B^2, where cn is
    # the channel count of the images passed to compute(). These images are
    # single-channel grayscale, so cn is 1 (not 3).
    cn = 1 if left_g.ndim == 2 else left_g.shape[2]
    matcher = cv2.StereoSGBM_create(
        minDisparity=0,
        numDisparities=num_disp,
        blockSize=block,
        P1=8 * cn * block ** 2,    # penalty for a +/-1 disparity change
        P2=32 * cn * block ** 2,   # penalty for larger disparity jumps
        disp12MaxDiff=5,
        uniquenessRatio=uniqueness_ratio,
        speckleWindowSize=speckle_window_size,
        speckleRange=speckle_range,
        mode=cv2.STEREO_SGBM_MODE_SGBM_3WAY,
    )
    # compute() returns fixed-point disparities scaled by 16.
    disp = matcher.compute(left_g, right_g).astype(np.float32) / 16.0
    # Invalid pixels are reported below minDisparity (negative here);
    # treat them as zero disparity, i.e. "far".
    disp = np.maximum(disp, 0)

    # Edge-preserving smoothing (bilateral keeps object boundaries clean).
    disp = cv2.bilateralFilter(disp, d=9, sigmaColor=75, sigmaSpace=75)

    # Normalise to [0, 1]: high disparity = close = 1.
    d = (disp - disp.min()) / (disp.max() - disp.min() + 1e-6)

    # Guided-filter refinement sharpens depth edges using the image as guide.
    # cv2.ximgproc only ships with opencv-contrib, so degrade gracefully.
    ximgproc = getattr(cv2, "ximgproc", None)
    if ximgproc is not None:
        d_8u = (d * 255).clip(0, 255).astype(np.uint8)
        d = ximgproc.guidedFilter(
            guide=left_g, src=d_8u, radius=8, eps=200, dDepth=cv2.CV_32F)
        d = np.clip(d / (d.max() + 1e-6), 0, 1)
    else:
        print(" [ SGBM ] cv2.ximgproc not available (requires "
              "opencv-contrib-python); skipping guided-filter refinement.")
    return d, left_img, right_img
# ═══════════════════════════════════════════════════════════
# 2. ML-BASED METHOD – Actual MiDaS (MiDaS_small)
# ═══════════════════════════════════════════════════════════
def load_midas(model_type: str = "MiDaS_small"):
    """
    Load a MiDaS model and its matching input transform from torch.hub
    (``intel-isl/MiDaS``).

    ``model_type`` values, roughly from largest/slowest to smallest/fastest:
        "DPT_Large"   - DPT-Large (ViT-L/16 backbone, best quality)
        "DPT_Hybrid"  - DPT-Hybrid (ResNet-50 + ViT-B/16 hybrid backbone)
        "MiDaS"       - MiDaS v2.1 large (ResNeXt-101 backbone)
        "MiDaS_small" - MiDaS v2.1 small (EfficientNet-Lite3, fast) <- default

    The preprocessing transform must match the model family:
        "MiDaS_small" -> ``small_transform``
        "MiDaS"       -> ``default_transform`` (ImageNet normalisation)
        anything else -> ``dpt_transform``

    torch.hub caches weights (by default under ~/.cache/torch/hub/) after
    the first download.

    Returns:
        (model, transform, device):
          - model: moved to ``device`` and set to eval mode.
          - transform: the preprocessing callable.
          - device: the ``torch.device`` used (CUDA when available, else CPU).
    """
    print(f"[ MiDaS ] Loading model '{model_type}' from torch.hub ...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")

    # Hugging Face / Gradio deployments are non-interactive. Some MiDaS
    # variants (notably MiDaS_small) may trigger a *secondary* torch.hub
    # download from `rwightman/gen-efficientnet-pytorch` without
    # `trust_repo=True`. That would prompt for confirmation and crash with
    # EOFError.
    #
    # This is handled in two layers:
    #   1) Pre-trust the dependency repo (best-effort; failure is logged,
    #      not fatal).
    #   2) During the actual MiDaS load, temporarily auto-answer any prompt.
    if model_type == "MiDaS_small":
        try:
            torch.hub.load(
                "rwightman/gen-efficientnet-pytorch",
                "tf_efficientnet_lite3",
                pretrained=True,
                trust_repo=True,
            )
        except Exception as exc:  # best-effort: the main load may still succeed
            print(f" (pre-trust of gen-efficientnet-pytorch failed: {exc!r}; continuing)")

    _orig_input = builtins.input
    try:
        builtins.input = lambda *_args, **_kwargs: "y"
        model = torch.hub.load("intel-isl/MiDaS", model_type, trust_repo=True)
        model.to(device).eval()
    finally:
        builtins.input = _orig_input

    transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
    if model_type == "MiDaS_small":
        transform = transforms.small_transform
    elif model_type == "MiDaS":
        # MiDaS v2.1 large expects ImageNet mean/std normalisation. The DPT
        # transform normalises with 0.5/0.5 and would skew its input.
        transform = transforms.default_transform
    else:
        transform = transforms.dpt_transform

    n_params = sum(p.numel() for p in model.parameters())
    print(f" Model loaded ({n_params:,} parameters)")
    return model, transform, device
def midas_depth(
    img: np.ndarray,
    model,
    transform,
    device: torch.device
) -> np.ndarray:
    """
    Run MiDaS on a BGR image and return a closeness map.

    MiDaS predicts relative *inverse* depth (disparity-like): larger values
    mean closer surfaces. The prediction is upsampled to the input
    resolution and min-max normalised so that 1 = closest and 0 = farthest.

    Pipeline:
        BGR image
        -> RGB conversion
        -> ``transform`` (MiDaS preprocessing from ``load_midas``: resize +
           normalisation; the network input size depends on the transform)
        -> model forward pass (under ``torch.no_grad``)
        -> bilinear upsample to the original (h, w)
        -> min-max normalise to [0, 1]

    Args:
        img: BGR image.
        model, transform, device: as returned by ``load_midas``.

    Returns:
        float32 closeness map of shape (h, w) in [0, 1]. It is all zeros if
        the prediction is constant.
    """
    h, w = img.shape[:2]
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Preprocess: resize + normalise (also adds the batch dimension).
    input_batch = transform(img_rgb).to(device)

    with torch.no_grad():
        prediction = model(input_batch)
        # NOTE(review): assumes a (1, H', W') output for a single image, as
        # MiDaS returns. Index the batch/channel dims explicitly rather than
        # squeeze(): squeeze() would also drop spatial dims when h or w is 1.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=(h, w),
            mode="bilinear",
            align_corners=False,
        )[0, 0]

    depth = prediction.cpu().numpy()

    # Higher MiDaS value = closer; rescale to [0, 1] keeping that ordering.
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
    return depth.astype(np.float32)
# ═══════════════════════════════════════════════════════════
# 3. VISUALISATION
# ═══════════════════════════════════════════════════════════
def depth_to_heatmap(depth: np.ndarray) -> np.ndarray:
    """
    Render a closeness map as an 8-bit BGR image with matplotlib's turbo
    colormap.

    Args:
        depth: 2-D map in [0, 1], where 1 = close.

    Returns:
        uint8 BGR image of shape depth.shape + (3,), ready for cv2.imwrite.
    """
    turbo = plt.get_cmap("turbo")
    rgba = turbo(depth)
    rgb_u8 = (rgba[:, :, :3] * 255).astype(np.uint8)
    bgr_u8 = cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR)
    return bgr_u8
def compute_depth_metrics(img: np.ndarray, depth_cl: np.ndarray, depth_ml: np.ndarray) -> dict:
    """
    Summarise two closeness maps without any ground truth (diagnostics only).

    Returned keys:
        classical_mean, classical_std - mean / std of ``depth_cl``
        midas_mean, midas_std         - mean / std of ``depth_ml``
        cross_pearson                 - Pearson correlation of the two maps
        edge_align_classical          - correlation of the Sobel gradient
                                        magnitude of ``depth_cl`` with that
                                        of the grayscale image
        edge_align_midas              - same, for ``depth_ml``

    Correlation entries are ``None`` only when the flattened inputs are empty.
    """
    luminance = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

    def _sobel_magnitude(arr: np.ndarray) -> np.ndarray:
        dx = cv2.Sobel(arr, cv2.CV_32F, 1, 0, ksize=3)
        dy = cv2.Sobel(arr, cv2.CV_32F, 0, 1, ksize=3)
        return np.sqrt(dx * dx + dy * dy)

    def _pearson(u: np.ndarray, v: np.ndarray) -> float | None:
        u = u.reshape(-1)
        v = v.reshape(-1)
        if u.size == 0:
            return None
        # Work on float64 copies so the caller's arrays are never modified.
        u = u.astype(np.float64)
        v = v.astype(np.float64)
        u -= u.mean()
        v -= v.mean()
        # Tiny epsilon keeps constant inputs from dividing by zero (-> 0.0).
        norm = (np.sqrt((u * u).sum()) * np.sqrt((v * v).sum())) + 1e-12
        return float((u * v).sum() / norm)

    image_grad = _sobel_magnitude(luminance)
    return {
        "classical_mean": float(depth_cl.mean()),
        "classical_std": float(depth_cl.std()),
        "midas_mean": float(depth_ml.mean()),
        "midas_std": float(depth_ml.std()),
        "cross_pearson": _pearson(depth_cl, depth_ml),
        "edge_align_classical": _pearson(_sobel_magnitude(depth_cl), image_grad),
        "edge_align_midas": _pearson(_sobel_magnitude(depth_ml), image_grad),
    }
def depth_metrics_table(metrics: dict) -> list[list[str]]:
    """
    Build the compact ``[metric, value]`` rows for the evaluation CSV.

    Rows always follow the same fixed order of seven metrics, whatever the
    insertion order of ``metrics``. Formatting rules:
      - floats are rendered with 4 decimals;
      - missing or ``None`` values become "N/A";
      - anything else goes through ``str()``.
    """
    ordered_keys = (
        "classical_mean",
        "classical_std",
        "midas_mean",
        "midas_std",
        "cross_pearson",
        "edge_align_classical",
        "edge_align_midas",
    )

    def _render(value) -> str:
        if value is None:
            return "N/A"
        return f"{value:.4f}" if isinstance(value, float) else str(value)

    return [[key, _render(metrics.get(key))] for key in ordered_keys]
def save_depth_evaluation(out_dir: str, metrics: dict) -> str:
    """
    Write the key depth metrics to ``<out_dir>/evaluation/metrics_table.csv``.

    The ``evaluation`` sub-directory is created if needed. The CSV has a
    ``metric,value`` header followed by the rows from
    ``depth_metrics_table``.

    Returns:
        Path of the written CSV file.
    """
    target_dir = os.path.join(out_dir, "evaluation")
    os.makedirs(target_dir, exist_ok=True)
    csv_path = os.path.join(target_dir, "metrics_table.csv")

    with open(csv_path, "w", newline="", encoding="utf-8") as handle:
        out = csv.writer(handle)
        out.writerow(["metric", "value"])
        out.writerows(depth_metrics_table(metrics))

    print(f"Saved -> {csv_path}")
    return csv_path
def visualise_results(
    img: np.ndarray,
    depth_cl: np.ndarray,
    depth_ml: np.ndarray,
    out_path: str = "output/depth_estimation_subtask1.png"
) -> None:
    """
    Compose and save a 2 x 3 comparison figure.

    Top row:
        Col 1 - original image
        Col 2 - classical (SGBM) closeness heatmap blended over the image,
                with scan lines at 25 / 50 / 75 % of the height
        Col 3 - MiDaS closeness heatmap, same overlay
    Bottom row:
        Col 1 - both methods along the middle scan line
        Col 2 - SGBM profiles along all three scan lines
        Col 3 - MiDaS profiles along all three scan lines

    Args:
        img: BGR image.
        depth_cl, depth_ml: closeness maps in [0, 1] (1 = close). They are
            indexed with the image's height and width, so they must share
            its spatial size.
        out_path: PNG destination; its parent directory is created if needed.
    """
    h, w = img.shape[:2]
    ncols = 3
    fig = plt.figure(figsize=(ncols * 5.6, 11), dpi=130)
    fig.patch.set_facecolor("#1a1a2e")
    titles = [
        "Original Image",
        "Classical Depth\n(SGBM Stereo Matching)",
        "ML-Based Depth\n(MiDaS_small — actual model)",
    ]
    depths = [None, depth_cl, depth_ml]
    ax_top = [fig.add_subplot(2, ncols, c + 1) for c in range(ncols)]
    ax_bot = [fig.add_subplot(2, ncols, ncols + c + 1) for c in range(ncols)]

    # The RGB view of the input and the colormap are the same for every
    # panel, so compute them once.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    turbo = plt.get_cmap("turbo")

    # -- Top row: original image, then heatmaps blended over it --
    for ax, title, d in zip(ax_top, titles, depths):
        ax.set_title(title, color="white", fontsize=10, fontweight="bold", pad=8)
        ax.axis("off")
        if d is None:
            ax.imshow(rgb)
            continue
        # 22 % image + 78 % colormap keeps scene structure visible.
        cmap_arr = turbo(d)[:, :, :3]
        blended = rgb.astype(np.float32) / 255 * 0.22 + cmap_arr * 0.78
        ax.imshow(blended)
        sm = plt.cm.ScalarMappable(cmap="turbo",
                                   norm=plt.Normalize(vmin=0, vmax=1))
        sm.set_array([])
        cb = plt.colorbar(sm, ax=ax, fraction=0.03, pad=0.02)
        cb.set_label("Near -> Far", color="white", fontsize=7)
        cb.set_ticks([0, 0.5, 1])
        cb.set_ticklabels(["Far", "Mid", "Near"], color="white", fontsize=7)
        cb.ax.yaxis.set_tick_params(color="white")

    # -- Scan lines on the heatmap panels --
    scan_ys = [int(h * f) for f in [0.25, 0.50, 0.75]]
    scan_colors = ["#ff6b6b", "#ffd93d", "#6bcb77"]
    for ax in ax_top[1:]:
        for sy, sc in zip(scan_ys, scan_colors):
            ax.axhline(sy, color=sc, linewidth=1.2, alpha=0.75)

    # -- Bottom row: closeness profiles along the scan lines --
    x = np.arange(w)
    method_maps = [depth_cl, depth_ml]
    method_names = ["Classical (SGBM)", "MiDaS (actual)"]
    line_styles = ["-", "--"]
    for col, ax in enumerate(ax_bot):
        ax.set_facecolor("#16213e")
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        for side in ("bottom", "left"):
            ax.spines[side].set_color("#555")
        ax.tick_params(colors="#888", labelsize=7)
        ax.set_xlim(0, w - 1)
        ax.set_ylim(-0.05, 1.05)
        ax.set_xlabel("Pixel x", color="#aaa", fontsize=8)
        ax.set_ylabel("Closeness (1 = near)", color="#aaa", fontsize=8)
        if col == 0:
            # Compare both methods at the middle scan line.
            ax.set_title("Method comparison — middle scan line",
                         color="white", fontsize=9, pad=6)
            sy = scan_ys[1]
            for mp, nm, style in zip(method_maps, method_names, line_styles):
                ax.plot(x, mp[sy, :], linestyle=style, linewidth=1.6, label=nm)
            ax.legend(fontsize=8, framealpha=0.25, labelcolor="white")
        else:
            # One method per panel: all three scan lines.
            mp = method_maps[col - 1]
            nm = method_names[col - 1]
            ax.set_title(f"{nm} — scan-line profiles",
                         color="white", fontsize=9, pad=6)
            for sy, sc in zip(scan_ys, scan_colors):
                ax.plot(x, mp[sy, :], color=sc, linewidth=1.4,
                        label=f"y = {sy}")
            ax.legend(fontsize=7, framealpha=0.25, labelcolor="white")

    # -- Colour scale strip along the bottom of the figure --
    ax_s = fig.add_axes([0.05, 0.01, 0.90, 0.022])
    ax_s.imshow(np.linspace(0, 1, 512).reshape(1, -1),
                aspect="auto", cmap="turbo")
    ax_s.set_yticks([])
    ax_s.set_xticks([0, 170, 341, 511])
    ax_s.set_xticklabels(
        ["Far (cold / blue)", "Mid-far", "Mid-close", "Close (hot / red)"],
        color="white", fontsize=8
    )

    plt.suptitle(
        "Subtask 1 — Classical (SGBM) vs ML-Based (MiDaS) Depth Estimation\n"
        "Heatmap: red/hot = close blue/cold = far",
        color="white", fontsize=13, fontweight="bold", y=1.003
    )
    plt.tight_layout(rect=[0, 0.05, 1, 1])
    os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)
    plt.savefig(out_path, dpi=130, bbox_inches="tight",
                facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"Saved -> {out_path}")
# ═══════════════════════════════════════════════════════════
# 4. MAIN
# ═══════════════════════════════════════════════════════════
def main() -> None:
    """
    Command-line entry point.

    Usage: python depth_estimation.py <image_path> [output_dir]

    Runs the classical SGBM pipeline and MiDaS_small on the input image.
    It writes the following into ``output_dir`` (default ``output``):
        classical_heatmap.png, midas_heatmap.png - turbo heatmaps
        stereo_left.png, stereo_right.png        - synthetic stereo pair
        depth_estimation_subtask1.png            - comparison figure
        evaluation/metrics_table.csv             - diagnostic metrics

    Exits with a usage message when no image path is given.
    """
    args = sys.argv[1:]
    if not args:
        sys.exit(
            "Usage: python depth_estimation.py <image_path> [output_dir]\n"
            "Example: python depth_estimation.py street.jpg output/"
        )
    image_path = args[0]
    out_dir = args[1] if len(args) > 1 else "output"

    img = load_image(image_path)

    # Classical branch: synthetic stereo pair + SGBM.
    print("\n[ Classical ] Running SGBM stereo matching ...")
    depth_cl, left_img, right_img = sgbm_depth(img)
    print(f" Done. depth in [0,1] mean={depth_cl.mean():.3f}")

    # Learned branch: MiDaS_small from torch.hub.
    print("\n[ MiDaS ] Loading and running MiDaS_small ...")
    midas_model, midas_transform, device = load_midas("MiDaS_small")
    depth_ml = midas_depth(img, midas_model, midas_transform, device)
    print(f" Done. depth in [0,1] mean={depth_ml.mean():.3f}")

    os.makedirs(out_dir, exist_ok=True)

    def _write(filename: str, image) -> None:
        cv2.imwrite(os.path.join(out_dir, filename), image)

    _write("classical_heatmap.png", depth_to_heatmap(depth_cl))
    _write("midas_heatmap.png", depth_to_heatmap(depth_ml))
    _write("stereo_left.png", left_img)
    _write("stereo_right.png", right_img)

    print("\n[ Visualise ] Compositing final figure ...")
    figure_path = os.path.join(out_dir, "depth_estimation_subtask1.png")
    visualise_results(img, depth_cl, depth_ml, out_path=figure_path)

    print("\n[ Eval ] Writing evaluation table ...")
    save_depth_evaluation(out_dir, compute_depth_metrics(img, depth_cl, depth_ml))

    print(f"\nDone. Outputs written to: {out_dir}/")


if __name__ == "__main__":
    main()