File size: 6,954 Bytes
0e61117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#!/usr/bin/env python3
"""
test_heatmap.py
===============

Test script for the Grad-ECLIP heatmap generation pipeline.

This script allows you to test the heatmap generation functionality either
with images from the artifacts directory (using run IDs) or with any local image.

Usage Examples
--------------
# Using a run ID from artifacts directory:
python test_heatmap.py --run-id d72545f1-xxxx-xxxx-xxxx \
                       --sentence "The men held torches with large flames above their heads" \
                       --out torch_heatmap.png

# Using a direct image path:
python test_heatmap.py --image-path ~/Pictures/painting.jpg \
                       --sentence "A beautiful sunset over the mountains" \
                       --layer-idx -2 \
                       --out sunset_heatmap.png

# Minimal usage with defaults:
python test_heatmap.py --image-path test.jpg --sentence "A cat sitting on a mat"

Arguments
---------
--run-id : str
    Run ID from a previous /presign call (maps to data/artifacts/<id>.jpg).
    Mutually exclusive with --image-path; exactly one of the two is required.
    
--image-path : str
    Direct path to an RGB image file
    
--sentence : str
    Text description to generate explanation for (required)
    
--layer-idx : int
    Which vision transformer layer to analyze (default: -1 for last layer)
    
--out : str
    Output PNG filename (default: overlay.png)
    
--alpha : float
    Heatmap overlay opacity, between 0 and 1 (default: 0.45)
    
--colormap : str
    Color scheme for heatmap: 'jet', 'hot', 'viridis', 'plasma', 'cool',
    'rainbow' (default: 'jet')

--verbose : flag
    Print additional debug information
"""

import argparse
import base64
import sys
from pathlib import Path
from typing import Optional

import cv2

# Add parent directory to path for imports (backend/)
sys.path.append(str(Path(__file__).resolve().parent.parent))

from runner.inference import compute_heatmap

# Lookup table from each --colormap choice to the OpenCV colormap constant
# passed to compute_heatmap. Key insertion order is the order in which the
# choices are listed by --help.
COLORMAP_OPTIONS = dict(
    jet=cv2.COLORMAP_JET,
    hot=cv2.COLORMAP_HOT,
    viridis=cv2.COLORMAP_VIRIDIS,
    plasma=cv2.COLORMAP_PLASMA,
    cool=cv2.COLORMAP_COOL,
    rainbow=cv2.COLORMAP_RAINBOW,
)


def resolve_image_path(run_id: Optional[str], image_path: Optional[str]) -> Path:
    """
    Resolve the image path from either a run ID or direct path.

    A non-empty ``run_id`` takes precedence and is mapped to
    ``<project root>/data/artifacts/<run_id>.jpg``. The project root is
    taken to be three directory levels above this file. Otherwise
    ``image_path`` is user-expanded (``~``) and made absolute.

    Args:
        run_id: Optional run ID that maps to data/artifacts/<id>.jpg. Must be
            a bare name without path separators.
        image_path: Optional direct path to image

    Returns:
        Resolved Path object pointing at an existing regular file

    Raises:
        ValueError: If neither a run ID nor an image path is given, if the
            run ID contains path components, or if the resolved path is not
            a regular file
        FileNotFoundError: If the resolved path doesn't exist
    """
    if run_id:
        # Only accept a bare file stem so the lookup cannot escape the
        # artifacts directory (e.g. "../../secret" or "sub/dir").
        if Path(run_id).name != run_id:
            raise ValueError(f"Invalid run ID: {run_id!r}")
        # Get artifacts directory - go up to project root then into data/artifacts
        artifacts_dir = (
            Path(__file__).resolve().parent.parent.parent / "data" / "artifacts"
        )
        resolved_path = artifacts_dir / f"{run_id}.jpg"
    elif image_path:
        resolved_path = Path(image_path).expanduser().resolve()
    else:
        # Previously this fell through to Path(None) and raised TypeError.
        raise ValueError("Either a run ID or an image path must be provided")

    if not resolved_path.exists():
        raise FileNotFoundError(f"Image not found: {resolved_path}")

    if not resolved_path.is_file():
        raise ValueError(f"Path is not a file: {resolved_path}")

    return resolved_path


def save_heatmap_from_data_url(data_url: str, output_path: Path) -> None:
    """
    Extract and save PNG image from base64 data URL.

    The payload is strictly base64-decoded and checked for the PNG file
    signature before anything is written, so a malformed URL never leaves a
    corrupt file on disk. Missing parent directories of ``output_path`` are
    created.

    Args:
        data_url: Data URL of the form ``data:image/png;base64,<payload>``
        output_path: Where to save the decoded image

    Raises:
        ValueError: If the URL lacks the PNG data-URL prefix, the payload is
            not valid base64, or the decoded bytes are not a PNG image
        OSError: If the output file cannot be written
    """
    prefix = "data:image/png;base64,"
    png_signature = b"\x89PNG\r\n\x1a\n"

    if not data_url.startswith(prefix):
        raise ValueError("Expected PNG data URL")

    # Tolerate line-wrapped payloads, but otherwise decode strictly: the
    # default decoder silently discards non-alphabet characters and would
    # write garbage instead of reporting an error.
    base64_data = "".join(data_url[len(prefix):].split())
    try:
        image_bytes = base64.b64decode(base64_data, validate=True)
    except ValueError as exc:  # binascii.Error subclasses ValueError
        raise ValueError(f"Invalid base64 payload in data URL: {exc}") from exc

    if not image_bytes.startswith(png_signature):
        raise ValueError("Decoded data URL payload is not a PNG image")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_bytes(image_bytes)


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser (the module docstring is the epilog)."""
    parser = argparse.ArgumentParser(
        description="Test Grad-ECLIP heatmap generation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Image source (mutually exclusive)
    image_group = parser.add_mutually_exclusive_group(required=True)
    image_group.add_argument(
        "--run-id",
        help="Run ID from previous /presign call (looks in data/artifacts/<id>.jpg)",
    )
    image_group.add_argument("--image-path", help="Direct path to an RGB image file")

    # Required arguments
    parser.add_argument("--sentence", required=True, help="Text description to explain")

    # Optional arguments
    parser.add_argument(
        "--layer-idx",
        type=int,
        default=-1,
        help="Vision transformer layer index (default: -1 for last layer)",
    )
    parser.add_argument(
        "--out",
        default="overlay.png",
        help="Output PNG filename (default: overlay.png)",
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.45,
        help="Heatmap overlay opacity 0-1 (default: 0.45)",
    )
    parser.add_argument(
        "--colormap",
        choices=list(COLORMAP_OPTIONS.keys()),
        default="jet",
        help="Color scheme for heatmap (default: jet)",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print additional debug information"
    )
    return parser


def _fail(message: str, verbose: bool) -> None:
    """Print *message* to stderr, plus the active traceback if *verbose*, then exit 1.

    Must be called from inside an ``except`` block so the traceback refers to
    the exception currently being handled.
    """
    print(message, file=sys.stderr)
    if verbose:
        import traceback

        traceback.print_exc()
    sys.exit(1)


def main() -> None:
    """Main entry point for the test script.

    Parses the command line, generates the Grad-ECLIP heatmap overlay via
    ``compute_heatmap`` and writes it to ``--out``. On failure an ``[ERROR]``
    line is printed to stderr and the process exits with status 1. With
    ``--verbose`` the traceback is printed for every failure category, because
    ``FileNotFoundError``/``ValueError`` may originate deep inside heatmap
    generation rather than from user input.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Validate arguments; the chained comparison also rejects NaN.
    if not 0 <= args.alpha <= 1:
        parser.error("--alpha must be between 0 and 1")

    try:
        # Resolve image path
        image_path = resolve_image_path(args.run_id, args.image_path)

        if args.verbose:
            print(f"[DEBUG] Resolved image path: {image_path}")
            print(f"[DEBUG] Layer index: {args.layer_idx}")
            print(f"[DEBUG] Alpha: {args.alpha}")
            print(f"[DEBUG] Colormap: {args.colormap}")

        # Print info
        print("[INFO] Computing heatmap for:")
        print(f"       Image:    {image_path}")
        print(f"       Sentence: '{args.sentence}'")
        print(f"       Output:   {args.out}")

        # Generate heatmap
        print("[INFO] Generating heatmap...")
        data_url = compute_heatmap(
            str(image_path),
            args.sentence,
            layer_idx=args.layer_idx,
            alpha=args.alpha,
            colormap=COLORMAP_OPTIONS[args.colormap],
        )

        # Save output
        output_path = Path(args.out)
        save_heatmap_from_data_url(data_url, output_path)

        print(f"[SUCCESS] Saved heatmap overlay to: {output_path.absolute()}")

        # Print file info if verbose
        if args.verbose:
            file_size = output_path.stat().st_size
            print(f"[DEBUG] Output file size: {file_size:,} bytes")

    except FileNotFoundError as e:
        _fail(f"[ERROR] {e}", args.verbose)
    except ValueError as e:
        _fail(f"[ERROR] Invalid input: {e}", args.verbose)
    except Exception as e:
        _fail(f"[ERROR] Unexpected error: {e}", args.verbose)


if __name__ == "__main__":
    main()