File size: 6,954 Bytes
0e61117 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
#!/usr/bin/env python3
"""
test_heatmap.py
===============
Test script for the Grad-ECLIP heatmap generation pipeline.
This script allows you to test the heatmap generation functionality either
with images from the artifacts directory (using run IDs) or with any local image.
Usage Examples
--------------
# Using a run ID from artifacts directory:
python test_heatmap.py --run-id d72545f1-xxxx-xxxx-xxxx \
--sentence "The men held torches with large flames above their heads" \
--out torch_heatmap.png
# Using a direct image path:
python test_heatmap.py --image-path ~/Pictures/painting.jpg \
--sentence "A beautiful sunset over the mountains" \
--layer-idx -2 \
--out sunset_heatmap.png
# Minimal usage with defaults:
python test_heatmap.py --image-path test.jpg --sentence "A cat sitting on a mat"
Arguments
---------
--run-id : str
Run ID from a previous /presign call (maps to data/artifacts/<id>.jpg)
--image-path : str
Direct path to an RGB image file
--sentence : str
Text description to generate explanation for (required)
--layer-idx : int
Which vision transformer layer to analyze (default: -1 for last layer)
--out : str
Output PNG filename (default: overlay.png)
--alpha : float
Heatmap overlay opacity, between 0 and 1 (default: 0.45)
--colormap : str
Color scheme for heatmap: 'jet', 'hot', 'viridis', 'plasma', 'cool', 'rainbow' (default: 'jet')
--verbose : flag
Print additional debug information
"""
import argparse
import base64
import sys
from pathlib import Path
from typing import Optional
import cv2
# Add parent directory to path for imports (backend/)
sys.path.append(str(Path(__file__).resolve().parent.parent))
from runner.inference import compute_heatmap
# Accepted --colormap names, each resolved to the matching cv2.COLORMAP_* constant
COLORMAP_OPTIONS = {
    name: getattr(cv2, f"COLORMAP_{name.upper()}")
    for name in ("jet", "hot", "viridis", "plasma", "cool", "rainbow")
}
def resolve_image_path(run_id: Optional[str], image_path: Optional[str]) -> Path:
    """
    Resolve the image path from either a run ID or direct path.

    A non-empty ``run_id`` takes precedence and maps to
    ``data/artifacts/<run_id>.jpg`` under the directory three levels above
    this file. Otherwise ``image_path`` is user-expanded (``~``) and made
    absolute.

    Args:
        run_id: Optional run ID that maps to data/artifacts/<id>.jpg
        image_path: Optional direct path to image

    Returns:
        Resolved Path object pointing at an existing file

    Raises:
        FileNotFoundError: If the resolved path doesn't exist
        ValueError: If neither a run ID nor an image path is given, or if
            the resolved path exists but is not a file
    """
    if run_id:
        # Get artifacts directory - go up to project root then into data/artifacts
        artifacts_dir = (
            Path(__file__).resolve().parent.parent.parent / "data" / "artifacts"
        )
        resolved_path = artifacts_dir / f"{run_id}.jpg"
    elif image_path is not None:
        resolved_path = Path(image_path).expanduser().resolve()
    else:
        # Without this guard Path(None) raises an unhelpful TypeError
        raise ValueError("Either run_id or image_path must be provided")

    if not resolved_path.exists():
        raise FileNotFoundError(f"Image not found: {resolved_path}")
    if not resolved_path.is_file():
        raise ValueError(f"Path is not a file: {resolved_path}")
    return resolved_path
def save_heatmap_from_data_url(data_url: str, output_path: Path) -> None:
    """
    Decode a base64 PNG data URL and write the image bytes to disk.

    Args:
        data_url: Data URL that must begin with ``data:image/png;base64,``
        output_path: Destination file for the decoded bytes

    Raises:
        ValueError: If ``data_url`` does not carry the PNG data-URL prefix
    """
    png_prefix = "data:image/png;base64,"
    if not data_url.startswith(png_prefix):
        raise ValueError("Expected PNG data URL")

    # Everything after the first comma is the base64 payload
    _, _, encoded = data_url.partition(",")
    decoded = base64.b64decode(encoded)

    with output_path.open("wb") as sink:
        sink.write(decoded)
def main() -> None:
    """
    Run the heatmap test from the command line.

    Parses the arguments, resolves the input image, calls
    ``compute_heatmap`` and writes the returned PNG data URL to the
    ``--out`` file. Any exception raised along the way is reported on
    stderr as an ``[ERROR]`` line and the process exits with status 1.
    """
    cli = _build_parser()
    opts = cli.parse_args()

    # Range check that type=float alone cannot express
    if not 0 <= opts.alpha <= 1:
        cli.error("--alpha must be between 0 and 1")

    try:
        source_image = resolve_image_path(opts.run_id, opts.image_path)

        if opts.verbose:
            for label, value in (
                ("Resolved image path", source_image),
                ("Layer index", opts.layer_idx),
                ("Alpha", opts.alpha),
                ("Colormap", opts.colormap),
            ):
                print(f"[DEBUG] {label}: {value}")

        # Summary of the request
        print("[INFO] Computing heatmap for:")
        print(f" Image: {source_image}")
        print(f" Sentence: '{opts.sentence}'")
        print(f" Output: {opts.out}")

        print("[INFO] Generating heatmap...")
        overlay_url = compute_heatmap(
            str(source_image),
            opts.sentence,
            layer_idx=opts.layer_idx,
            alpha=opts.alpha,
            colormap=COLORMAP_OPTIONS[opts.colormap],
        )

        destination = Path(opts.out)
        save_heatmap_from_data_url(overlay_url, destination)
        print(f"[SUCCESS] Saved heatmap overlay to: {destination.absolute()}")

        if opts.verbose:
            print(f"[DEBUG] Output file size: {destination.stat().st_size:,} bytes")

    except FileNotFoundError as err:
        print(f"[ERROR] {err}", file=sys.stderr)
        sys.exit(1)
    except ValueError as err:
        print(f"[ERROR] Invalid input: {err}", file=sys.stderr)
        sys.exit(1)
    except Exception as err:
        print(f"[ERROR] Unexpected error: {err}", file=sys.stderr)
        if opts.verbose:
            # Imported lazily: only needed on the verbose failure path
            import traceback

            traceback.print_exc()
        sys.exit(1)


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser describing every option of this script."""
    cli = argparse.ArgumentParser(
        description="Test Grad-ECLIP heatmap generation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Exactly one image source has to be supplied
    source = cli.add_mutually_exclusive_group(required=True)
    source.add_argument(
        "--run-id",
        help="Run ID from previous /presign call (looks in data/artifacts/<id>.jpg)",
    )
    source.add_argument("--image-path", help="Direct path to an RGB image file")

    cli.add_argument("--sentence", required=True, help="Text description to explain")

    cli.add_argument(
        "--layer-idx",
        type=int,
        default=-1,
        help="Vision transformer layer index (default: -1 for last layer)",
    )
    cli.add_argument(
        "--out",
        default="overlay.png",
        help="Output PNG filename (default: overlay.png)",
    )
    cli.add_argument(
        "--alpha",
        type=float,
        default=0.45,
        help="Heatmap overlay opacity 0-1 (default: 0.45)",
    )
    cli.add_argument(
        "--colormap",
        choices=list(COLORMAP_OPTIONS),
        default="jet",
        help="Color scheme for heatmap (default: jet)",
    )
    cli.add_argument(
        "--verbose", action="store_true", help="Print additional debug information"
    )
    return cli
if __name__ == "__main__":
main()
|