Spaces:
Sleeping
Sleeping
| """ | |
| Deterministic ablation: Stage 1 (VLM only) vs Stage 2 (VLM + depth context) | |
| vs Stage 3 (VLM + depth context + YOLOv8n). | |
| All metrics are reference-free — no human-written descriptions are required. | |
| Ground truth is derived from the depth context preamble itself (which contains | |
| the exact spatial facts injected into the VLM) and from a fixed spatial | |
| vocabulary. | |
| Metrics computed | |
| ---------------- | |
| Spatial Term Density (STD) | |
| Count of directional/distance terms per 100 words. Stage 2 and Stage 3 | |
| descriptions should contain substantially more spatial language than Stage 1. | |
| Preamble BERTScore (P/R/F1) | |
| BERTScore [Zhang et al., ICLR 2020] computed between each stage's | |
| description and the depth context preamble that was prepended to that | |
| stage's prompt. Stage 3 is scored against its OWN preamble (which | |
| includes per-object YOLO measurements absent from Stage 2's preamble). | |
| Stage 1 (no preamble) descriptions are scored against the Stage 2 preamble | |
| as a null baseline. | |
| Typical usage (library):: | |
| from src.evaluation.bertscore_ablation import run_ablation | |
| rows, summary = run_ablation(image_paths) | |
| CLI usage:: | |
| python -m src.evaluation.bertscore_ablation \\ | |
| --images data/test_images/ \\ | |
| --output outputs/results/bertscore_ablation.csv | |
| """ | |
| import argparse | |
| import csv | |
| import re | |
| import sys | |
| import traceback | |
| from pathlib import Path | |
| from typing import Sequence | |
| import numpy as np | |
| import torch | |
| from bert_score import score as compute_bertscore | |
| from PIL import Image | |
| from ..config import RESULTS_DIR | |
| from ..pipeline import Pipeline | |
# ---------------------------------------------------------------------------
# Spatial vocabulary for Spatial Term Density (STD)
# ---------------------------------------------------------------------------
# Directional words that place an object relative to the viewer or to
# another object.
_DIRECTION_TERMS: frozenset[str] = frozenset({
    "left", "right", "centre", "center", "ahead", "behind",
    "front", "back", "beside", "between",
})
# Distance words and units.  Digits are not part of word tokens, so "2m"
# and "2 m" both yield the token "m".
_DISTANCE_TERMS: frozenset[str] = frozenset({
    "cm", "metre", "metres", "meter", "meters", "m",
    "near", "nearby", "close", "far", "away", "approximately", "about",
})
_SPATIAL_TERMS: frozenset[str] = _DIRECTION_TERMS | _DISTANCE_TERMS

# A word is a run of ASCII letters, optionally joined to more letters by a
# straight or typographic apostrophe, so contractions like "I'm" and
# "don't" stay single tokens.  Splitting on the apostrophe would turn "I'm"
# into "i" + "m" and miscount the "m" as the metre unit.
_WORD_RE = re.compile(r"[a-z]+(?:['\u2019][a-z]+)*")


def _spatial_term_density(text: str) -> float:
    """Count spatial vocabulary terms per 100 words.

    The text is lower-cased and split into words of ASCII letters; digits
    and punctuation separate words, while apostrophe contractions are kept
    whole.

    Args:
        text: Raw description string.

    Returns:
        Spatial term density (float, ≥ 0); 0.0 when *text* has no words.
    """
    words = _WORD_RE.findall(text.lower())
    if not words:
        return 0.0
    hits = sum(1 for w in words if w in _SPATIAL_TERMS)
    return hits / len(words) * 100.0
# ---------------------------------------------------------------------------
# CSV schema
# ---------------------------------------------------------------------------
# Output column order.  csv.DictWriter uses this list for the header row
# and for the field order of every data row.
_FIELDNAMES = (
    ["image"]
    # Stage 1 (VLM only): latency and spatial-term density.
    + ["s1_total_s", "s1_spatial_density"]
    # Stage 2 (VLM + depth context): latency breakdown, density, and the
    # density uplift over Stage 1 (s2 density minus s1 density).
    + ["s2_depth_s", "s2_vlm_s", "s2_total_s",
       "s2_spatial_density", "spatial_uplift"]
    # BERTScore of the Stage 2 description against the injected preamble.
    + ["preamble_P", "preamble_R", "preamble_F1"]
    # Null baseline: Stage 1 description against the same preamble.
    + ["baseline_P", "baseline_R", "baseline_F1"]
    # Faithfulness gain from adding depth context (S2 F1 minus S1 F1).
    + ["delta_faith_F1"]
    # Stage 3 (VLM + depth + YOLO): latency breakdown, density, detections.
    + ["s3_total_s", "s3_depth_s", "s3_detect_s", "s3_vlm_s",
       "s3_spatial_density", "s3_num_objects"]
    # Stage 3 description vs its own YOLO-enriched preamble, then the
    # F1 deltas S3 - S1 and S3 - S2 (the headline S2 -> S3 increment).
    + ["preamble_F1_s3",
       "delta_faith_F1_s3_vs_s1", "delta_faith_F1_s3_vs_s2"]
    # Raw text columns.
    + ["s1_description", "s2_description", "s3_description",
       "depth_context"]
)
# Sentinel timing dicts used when a stage errors out on one image.
# Every value is 0.0.  These are shared templates: take a .copy() before
# handing one out so the module-level dicts are never mutated.
_TIMING1_ZERO: dict[str, float] = dict.fromkeys(
    ("vlm_s", "total_s", "vram_mb"), 0.0
)
_TIMING2_ZERO: dict[str, float] = dict.fromkeys(
    ("depth_s", "vlm_s", "total_s", "vram_mb"), 0.0
)
# Stage 3 additionally tracks YOLO time and the detection count.
_TIMING3_ZERO: dict[str, float] = dict.fromkeys(
    ("depth_s", "yolo_s", "vlm_s", "total_s", "vram_mb", "n_detections"), 0.0
)
# ---------------------------------------------------------------------------
# Main function
# ---------------------------------------------------------------------------
def _format_stage_error(tag: str) -> str:
    """Build, print and return a one-line message for the active exception.

    Must be called from inside an ``except`` block, because it reads the
    exception currently being handled via ``traceback.format_exc``.
    """
    msg = f"{tag}: {traceback.format_exc(limit=1).strip()}"
    print(f" WARNING: {msg}")
    return msg


def _run_one_image(pipeline: "Pipeline", img_path: Path, skip_stage3: bool) -> dict:
    """Run Stages 1-3 on one image, isolating failures per stage.

    Returns:
        A flat dict with, for each stage ``sK`` (K = 1, 2, 3):
        ``sK_desc`` (description), ``sK_t`` (timing dict) and ``sK_err``
        (error message, ``""`` when the stage succeeded or was skipped),
        plus ``s2_ctx`` / ``s3_ctx`` (the preamble each stage returned).
        A failed or skipped stage keeps an empty description/preamble and
        a zeroed timing dict.  If the image cannot be loaded, the load
        error is recorded for every stage.
    """
    rec: dict = {
        "s1_desc": "", "s1_t": _TIMING1_ZERO.copy(), "s1_err": "",
        "s2_desc": "", "s2_ctx": "", "s2_t": _TIMING2_ZERO.copy(), "s2_err": "",
        "s3_desc": "", "s3_ctx": "", "s3_t": _TIMING3_ZERO.copy(), "s3_err": "",
    }

    try:
        # ``with`` closes the underlying file; np.array copies the pixels
        # out of the converted image before the file is closed.
        with Image.open(img_path) as img:
            frame_rgb = np.array(img.convert("RGB"))
    except Exception:
        msg = _format_stage_error("LOAD_ERROR")
        rec["s1_err"] = rec["s2_err"] = rec["s3_err"] = msg
        return rec

    # Stage 1
    try:
        desc1, t1 = pipeline.run_stage1(frame_rgb)
        print(f" S1 {t1['total_s']:.2f}s | {desc1[:80]}...")
    except Exception:
        rec["s1_err"] = _format_stage_error("STAGE1_ERROR")
    else:
        rec["s1_desc"], rec["s1_t"] = desc1, t1

    # Stage 2 — also captures the preamble (depth context)
    try:
        desc2, ctx2, t2 = pipeline.run_stage2(frame_rgb)
        print(f" S2 {t2['total_s']:.2f}s | {desc2[:80]}...")
    except Exception:
        rec["s2_err"] = _format_stage_error("STAGE2_ERROR")
    else:
        rec["s2_desc"], rec["s2_ctx"], rec["s2_t"] = desc2, ctx2, t2

    # Stage 3 — own preamble includes per-object YOLO measurements
    if not skip_stage3:
        try:
            desc3, ctx3, t3 = pipeline.run_stage3(frame_rgb)
            n_det = int(t3.get("n_detections", 0))
            print(
                f" S3 {t3['total_s']:.2f}s | "
                f"objects={n_det} | {desc3[:80]}..."
            )
        except Exception:
            rec["s3_err"] = _format_stage_error("STAGE3_ERROR")
        else:
            rec["s3_desc"], rec["s3_ctx"], rec["s3_t"] = desc3, ctx3, t3

    return rec


def _or_space(texts: list[str]) -> list[str]:
    """Replace empty strings (failed/skipped stages) with a single space."""
    return [t if t else " " for t in texts]


def _bertscore_lists(
    cands: list[str], refs: list[str], device: str
) -> tuple[list[float], list[float], list[float]]:
    """Score candidates against references in one batched BERTScore call.

    Returns:
        Per-pair precision, recall and F1 as plain Python lists.
    """
    prec, rec, f1 = compute_bertscore(
        cands, refs, lang="en", device=device, verbose=False
    )
    return prec.tolist(), rec.tolist(), f1.tolist()


def run_ablation(
    image_paths: Sequence[str | Path],
    output_csv: Path | None = None,
    bert_device: str = "cpu",
    skip_stage3: bool = False,
) -> tuple[list[dict], dict]:
    """Run deterministic Stage 1 / Stage 2 / Stage 3 ablation over a set of images.

    No human-written references are needed. Ground truth is derived from:

    - The depth context preamble injected into each stage (Preamble BERTScore)
    - A fixed spatial vocabulary (Spatial Term Density)

    Descriptions are collected from all images first, then BERTScore is
    computed in one batched call per comparison (more efficient than N
    individual calls).

    A stage that fails on an image does not abort the run.
    - Its timings are zero and its description is empty; a single space is
      scored by BERTScore in its place.
    - That stage's description column in the CSV carries the error message,
      so the failure stays visible.
    - Failed images are still included in the MEAN/STD rows.

    Args:
        image_paths: Ordered list of image file paths.
        output_csv: Destination CSV file. Defaults to
            ``RESULTS_DIR/bertscore_ablation.csv``.
        bert_device: Device for BERTScore model (``"cpu"`` keeps it off the
            GPU that the pipeline models occupy).
        skip_stage3: If True, ``run_stage3`` is never called and all Stage 3
            numeric columns (including the S3 F1 deltas) are zero.
            Use when YOLO is unavailable.

    Returns:
        rows: List of per-image result dicts (same columns as CSV).
        summary: Dict with mean/std for each numeric column.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    image_paths = [Path(p) for p in image_paths]
    n = len(image_paths)
    if n == 0:
        # Fail fast, before loading any pipeline model: there is nothing to
        # score and no meaningful summary to compute.
        raise ValueError("run_ablation() needs at least one image path")

    if output_csv is None:
        output_csv = RESULTS_DIR / "bertscore_ablation.csv"
    output_csv = Path(output_csv)
    output_csv.parent.mkdir(parents=True, exist_ok=True)

    pipeline = Pipeline()

    # ── Step 1: run all stages on every image ─────────────────────────────────
    records: list[dict] = []
    for i, img_path in enumerate(image_paths):
        print(f"[{i + 1}/{n}] {img_path.name}", flush=True)
        records.append(_run_one_image(pipeline, img_path, skip_stage3))

    def column(key: str) -> list:
        return [rec[key] for rec in records]

    s1_descriptions = column("s1_desc")
    s2_descriptions = column("s2_desc")
    s3_descriptions = column("s3_desc")
    s2_preambles = column("s2_ctx")
    s3_preambles = column("s3_ctx")

    # ── Step 2: Spatial Term Density (no model required) ─────────────────────
    s1_densities = [_spatial_term_density(d) for d in s1_descriptions]
    s2_densities = [_spatial_term_density(d) for d in s2_descriptions]
    s3_densities = [_spatial_term_density(d) for d in s3_descriptions]

    # The pipeline is not used past this point.  Dropping our reference and
    # emptying the CUDA cache gives a GPU BERTScore model more headroom; the
    # models are only actually freed if nothing else still references them.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ── Step 3: Preamble BERTScore (one batched call per comparison) ──────────
    # Stage 1 and Stage 2 are both scored against the Stage 2 preamble.
    # Stage 3 is scored against its own (YOLO-enriched) preamble.
    refs_s2 = _or_space(s2_preambles)
    refs_s3 = _or_space(s3_preambles)

    print("\nComputing Preamble BERTScore (Stage 1 baseline)...", flush=True)
    bP1, bR1, bF1_1 = _bertscore_lists(_or_space(s1_descriptions), refs_s2, bert_device)
    print("Computing Preamble BERTScore (Stage 2)...", flush=True)
    bP2, bR2, bF1_2 = _bertscore_lists(_or_space(s2_descriptions), refs_s2, bert_device)
    if skip_stage3:
        bF1_3 = [0.0] * n
    else:
        print("Computing Preamble BERTScore (Stage 3 vs own preamble)...", flush=True)
        _, _, bF1_3 = _bertscore_lists(_or_space(s3_descriptions), refs_s3, bert_device)

    # ── Step 4: assemble per-image rows ───────────────────────────────────────
    rows: list[dict] = []
    for i, (img_path, rec) in enumerate(zip(image_paths, records)):
        t1, t2, t3 = rec["s1_t"], rec["s2_t"], rec["s3_t"]
        # A skipped Stage 3 has a 0.0 placeholder F1; subtracting real F1
        # values from it would write spurious negative deltas.
        if skip_stage3:
            d31 = d32 = 0.0
        else:
            d31 = bF1_3[i] - bF1_1[i]
            d32 = bF1_3[i] - bF1_2[i]
        rows.append({
            "image": img_path.name,
            # Stage 1
            "s1_total_s": round(t1["total_s"], 3),
            "s1_spatial_density": round(s1_densities[i], 2),
            # Stage 2
            "s2_depth_s": round(t2.get("depth_s", 0.0), 3),
            "s2_vlm_s": round(t2.get("vlm_s", 0.0), 3),
            "s2_total_s": round(t2["total_s"], 3),
            "s2_spatial_density": round(s2_densities[i], 2),
            "spatial_uplift": round(s2_densities[i] - s1_densities[i], 2),
            # Preamble BERTScore — Stage 2
            "preamble_P": round(bP2[i], 4),
            "preamble_R": round(bR2[i], 4),
            "preamble_F1": round(bF1_2[i], 4),
            # Preamble BERTScore — Stage 1 null baseline
            "baseline_P": round(bP1[i], 4),
            "baseline_R": round(bR1[i], 4),
            "baseline_F1": round(bF1_1[i], 4),
            # Faithfulness gain S1 → S2
            "delta_faith_F1": round(bF1_2[i] - bF1_1[i], 4),
            # Stage 3
            "s3_total_s": round(t3["total_s"], 3),
            "s3_depth_s": round(t3.get("depth_s", 0.0), 3),
            "s3_detect_s": round(t3.get("yolo_s", 0.0), 3),
            "s3_vlm_s": round(t3.get("vlm_s", 0.0), 3),
            "s3_spatial_density": round(s3_densities[i], 2),
            "s3_num_objects": int(t3.get("n_detections", 0)),
            # Preamble BERTScore — Stage 3 vs own preamble
            "preamble_F1_s3": round(bF1_3[i], 4),
            # Faithfulness deltas
            "delta_faith_F1_s3_vs_s1": round(d31, 4),
            "delta_faith_F1_s3_vs_s2": round(d32, 4),
            # Raw text — each column reports only its own stage's error
            "s1_description": rec["s1_desc"] or rec["s1_err"],
            "s2_description": rec["s2_desc"] or rec["s2_err"],
            "s3_description": rec["s3_desc"] or rec["s3_err"],
            "depth_context": rec["s2_ctx"],
        })

    # ── Step 5: compute summary statistics ────────────────────────────────────
    text_cols = {"s1_description": "", "s2_description": "",
                 "s3_description": "", "depth_context": ""}
    # Every CSV column except the image name and the raw-text columns.
    numeric_cols = [c for c in _FIELDNAMES if c != "image" and c not in text_cols]

    summary: dict = {}
    for col in numeric_cols:
        vals = np.array([r[col] for r in rows], dtype=np.float64)
        summary[f"{col}_mean"] = round(float(vals.mean()), 4)
        summary[f"{col}_std"] = round(float(vals.std()), 4)

    mean_row: dict = {"image": "MEAN"} | {
        col: summary[f"{col}_mean"] for col in numeric_cols
    } | text_cols
    std_row: dict = {"image": "STD"} | {
        col: summary[f"{col}_std"] for col in numeric_cols
    } | text_cols

    # ── Step 6: write CSV ─────────────────────────────────────────────────────
    with open(output_csv, "w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=_FIELDNAMES)
        writer.writeheader()
        writer.writerows(rows)
        writer.writerow({})  # blank separator before summary
        writer.writerow(mean_row)
        writer.writerow(std_row)

    # ── Step 7: print summary ─────────────────────────────────────────────────
    _print_summary(summary, n, output_csv, skip_stage3=skip_stage3)
    return rows, summary
# ---------------------------------------------------------------------------
# Pretty-print helper
# ---------------------------------------------------------------------------
def _print_summary(
    summary: dict,
    n: int,
    output_csv: Path,
    skip_stage3: bool = False,
) -> None:
    """Print a human-readable summary table to stdout.

    Args:
        summary: Mapping of ``"<column>_mean"`` / ``"<column>_std"`` keys to
            numbers, one pair per numeric CSV column.
        n: Number of images evaluated (shown in the title).
        output_csv: CSV path, echoed in the footer.
        skip_stage3: When True, Stage 3 cells show "—" instead of their zero
            placeholders, and the Stage 3 gain and latency rows are omitted.
    """
    label_w, cell_w = 28, 15
    na = "—"
    sep = "-" * (1 + label_w + 3 * (1 + cell_w))

    def line(label: str, c1: str = "", c2: str = "", c3: str = "") -> str:
        # Every cell is right-aligned in the same fixed width so the header
        # lines up with the "mean±std" cells (13-14 characters wide).
        return f" {label:<{label_w}}" + "".join(
            f" {cell:>{cell_w}}" for cell in (c1, c2, c3)
        )

    def mean_std(col: str, signed: bool = False) -> str:
        mean, std = summary[f"{col}_mean"], summary[f"{col}_std"]
        return f"{mean:+.4f}±{std:.4f}" if signed else f"{mean:.4f}±{std:.4f}"

    def timing_line(label: str, col: str, unit: str = "s") -> str:
        return f" {label:<{label_w}} {summary[col + '_mean']:>10.2f}{unit}"

    s3_label = "Stage 3*" if skip_stage3 else "Stage 3"
    print(f"\n{sep}")
    print(f" Depth-Aware Ablation ({n} images) — reference-free metrics")
    if skip_stage3:
        print(" * Stage 3 skipped (--skip-stage3)")
    print(sep)
    print(line("Metric", "Stage 1", "Stage 2", s3_label))
    print(sep)

    # Spatial Term Density.  The ``+`` format flag signs the uplift either
    # way; a hard-coded "+" would print "+-0.50" for a negative uplift.
    s3_density = na if skip_stage3 else f"{summary['s3_spatial_density_mean']:.2f}"
    print(
        line(
            "Spatial Term Density",
            f"{summary['s1_spatial_density_mean']:.2f}",
            f"{summary['s2_spatial_density_mean']:.2f}",
            s3_density,
        )
        + f"  (S2 {summary['spatial_uplift_mean']:+.2f})"
    )
    print(sep)

    # Preamble BERTScore (Stage 3 only has an F1 column)
    print(" Preamble BERTScore (vs depth context preamble)")
    print(line("P", mean_std("baseline_P"), mean_std("preamble_P"), na))
    print(line("R", mean_std("baseline_R"), mean_std("preamble_R"), na))
    s3_f1 = na if skip_stage3 else mean_std("preamble_F1_s3")
    print(line("F1", mean_std("baseline_F1"), mean_std("preamble_F1"), s3_f1))
    print(sep)

    # Faithfulness deltas
    print(line("Faith. gain S1→S2 (ΔF1)", "", mean_std("delta_faith_F1", signed=True)))
    if not skip_stage3:
        print(line(
            "Faith. gain S1→S3 (ΔF1)", "", "",
            mean_std("delta_faith_F1_s3_vs_s1", signed=True),
        ))
        print(line(
            "Faith. gain S2→S3 (ΔF1)", "", "",
            mean_std("delta_faith_F1_s3_vs_s2", signed=True),
        ) + "  ← headline")
    print(sep)

    # Latency
    print(timing_line("Latency S1 (mean)", "s1_total_s"))
    print(timing_line("Latency S2 (mean)", "s2_total_s"))
    print(timing_line("  of which depth", "s2_depth_s"))
    print(timing_line("  of which VLM", "s2_vlm_s"))
    if not skip_stage3:
        print(timing_line("Latency S3 (mean)", "s3_total_s"))
        print(timing_line("  of which depth", "s3_depth_s"))
        print(timing_line("  of which detect", "s3_detect_s"))
        print(timing_line("  of which VLM", "s3_vlm_s"))
        print(timing_line("Objects/image (mean)", "s3_num_objects", unit=""))
    print(sep)
    print(f" Results written to: {output_csv}")
    print(sep)
| # --------------------------------------------------------------------------- | |
| # CLI entry point | |
| # --------------------------------------------------------------------------- | |
| def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: | |
| p = argparse.ArgumentParser( | |
| description=( | |
| "Deterministic ablation: Stage 1 vs Stage 2 vs Stage 3. " | |
| "No human references required." | |
| ) | |
| ) | |
| p.add_argument( | |
| "--images", | |
| required=True, | |
| help=( | |
| "Directory of images (sorted alphabetically) OR a text file " | |
| "listing one image path per line." | |
| ), | |
| ) | |
| p.add_argument( | |
| "--output", | |
| default=None, | |
| help="Destination CSV (default: outputs/results/bertscore_ablation.csv).", | |
| ) | |
| p.add_argument( | |
| "--bert-device", | |
| default="cpu", | |
| help="Device for BERTScore model: 'cpu' or 'cuda' (default: cpu).", | |
| ) | |
| p.add_argument( | |
| "--skip-stage3", | |
| action="store_true", | |
| default=False, | |
| help=( | |
| "Skip Stage 3 (YOLOv8n + depth + VLM). " | |
| "Use when ultralytics/YOLO is not installed. " | |
| "Stage 3 columns will be filled with zeros." | |
| ), | |
| ) | |
| return p.parse_args(argv) | |
def _load_image_paths(images_arg: str) -> list[Path]:
    """Resolve the ``--images`` argument into a list of image paths.

    Args:
        images_arg: Either a directory or a manifest file.
            - Directory: its image files (matched by extension,
              case-insensitively) are returned sorted by path.
            - Manifest: a text file with one image path per line. Blank
              lines are ignored; the order and the paths are kept as written.

    Returns:
        The list of image paths.

    Raises:
        FileNotFoundError: If the directory contains no image files, the
            manifest lists no paths, or ``images_arg`` does not exist.
    """
    src = Path(images_arg)
    if src.is_dir():
        exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
        # is_file() drops sub-directories whose names merely end in an
        # image extension (e.g. "frames.png/").
        paths = sorted(
            f for f in src.iterdir() if f.is_file() and f.suffix.lower() in exts
        )
        if not paths:
            raise FileNotFoundError(f"No images found in directory: {src}")
        return paths

    # Treat as a manifest file (read_text raises FileNotFoundError if absent).
    lines = src.read_text(encoding="utf-8").splitlines()
    paths = [Path(line.strip()) for line in lines if line.strip()]
    if not paths:
        # Mirror the empty-directory case instead of returning [] silently.
        raise FileNotFoundError(f"No image paths listed in manifest: {src}")
    return paths
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point.

    Parses *argv* (``None`` lets argparse use ``sys.argv[1:]``), resolves
    the image list, announces the run and delegates to ``run_ablation``.
    """
    opts = _parse_args(argv)
    paths = _load_image_paths(opts.images)

    note = " (Stage 3 skipped)" if opts.skip_stage3 else ""
    print(f"Running ablation on {len(paths)} images (reference-free){note}...")

    csv_path = None if not opts.output else Path(opts.output)
    run_ablation(
        image_paths=paths,
        output_csv=csv_path,
        bert_device=opts.bert_device,
        skip_stage3=opts.skip_stage3,
    )
# Script entry point, e.g. ``python -m src.evaluation.bertscore_ablation``;
# main() parses sys.argv itself when given no arguments.
if __name__ == "__main__":
    main()