File size: 7,735 Bytes
0e61117 4ac1f80 0e61117 4ac1f80 0e61117 4ac1f80 0e61117 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
#!/usr/bin/env python3
"""
test_patch_inference.py
======================
Test script for the patch-based region inference functionality.
Tests both whole-image inference and grid-cell specific inference.
Usage:
python test_patch_inference.py
This script demonstrates:
1. Whole-image inference (traditional CLIP similarity)
2. Grid-cell specific inference (click on region → relevant sentences)
3. Comparison between different regions
"""
import json
import sys
from pathlib import Path
from typing import Any, Dict, List
# Add parent directory to path for imports (backend/)
sys.path.append(str(Path(__file__).resolve().parent.parent))
from runner.inference import run_inference
def print_results(results: List[Dict[str, Any]], title: str, max_display: int = 5):
"""Pretty print inference results."""
print(f"\n{title}")
print("=" * len(title))
for i, result in enumerate(results[:max_display], 1):
print(f"\n{i}. {result['english_original'][:100]}...")
print(f" Score: {result['score']:.4f}")
print(f" Work: {result['work']}")
print(f" ID: {result['id']}")
if len(results) > max_display:
print(f"\n... and {len(results) - max_display} more results")
def test_grid_visualization(grid_size=(7, 7)):
"""Show ASCII grid with coordinates for reference."""
print("\nGrid Reference (row, col):")
print("=" * 50)
rows, cols = grid_size
for r in range(rows):
row_str = ""
for c in range(cols):
row_str += f"({r},{c}) "
print(row_str)
print()
def main():
"""Run comprehensive patch inference tests."""
# Configuration - Updated to use new directory structure
project_root = Path(__file__).resolve().parent.parent.parent
IMAGE_PATH = str(
project_root
/ "frontend"
/ "images"
/ "examples"
/ "Giotto_-_Scrovegni_-_-31-_-_Kiss_of_Judas.jpg"
)
GRID_SIZE = (7, 7) # 7x7 grid matches ViT-B/32 patch grid
TOP_K = 25 # Number of results to return
print("Patch Inference Test Suite")
print("=" * 60)
print(f"Image: {IMAGE_PATH}")
print(f"Grid Size: {GRID_SIZE[0]}x{GRID_SIZE[1]} (matches ViT-B/32 patch grid)")
print(f"Top K: {TOP_K}")
# Check if image exists
if not Path(IMAGE_PATH).exists():
print(f"\nError: Image not found at {IMAGE_PATH}")
print("Please update IMAGE_PATH to point to a valid test image.")
# Test 1: Whole-image inference (baseline)
print("\n\nTest 1: Whole-Image Inference")
print("-" * 60)
whole_image_results = run_inference(
image_path=IMAGE_PATH, cell=None, top_k=TOP_K # None means whole image
)
print_results(whole_image_results, "Whole Image Results")
# Show grid reference
test_grid_visualization(GRID_SIZE)
# Test 2: Center region inference
print("\n\nTest 2: Center Region Inference")
print("-" * 60)
center_row, center_col = GRID_SIZE[0] // 2, GRID_SIZE[1] // 2
print(f"Testing center cell: ({center_row}, {center_col})")
center_results = run_inference(
image_path=IMAGE_PATH,
cell=(center_row, center_col),
grid_size=GRID_SIZE,
top_k=TOP_K,
)
print_results(center_results, f"Center Cell ({center_row},{center_col}) Results")
# Test 3: Corner regions comparison
print("\n\nTest 3: Corner Regions Comparison")
print("-" * 60)
corners = {
"Top-Left": (0, 0),
"Top-Right": (0, GRID_SIZE[1] - 1),
"Bottom-Left": (GRID_SIZE[0] - 1, 0),
"Bottom-Right": (GRID_SIZE[0] - 1, GRID_SIZE[1] - 1),
}
for corner_name, (row, col) in corners.items():
print(f"\n{corner_name} Corner ({row}, {col}):")
corner_results = run_inference(
image_path=IMAGE_PATH,
cell=(row, col),
grid_size=GRID_SIZE,
top_k=3, # Show fewer results for comparison
)
for i, result in enumerate(corner_results[:3], 1):
print(f" {i}. {result['english_original'][:60]}...")
print(f" Score: {result['score']:.4f}")
# Test 4: Specific region of interest
print("\n\nTest 4: Custom Region Test")
print("-" * 60)
# You can modify these coordinates to test specific regions
custom_row, custom_col = 2, 4 # Example: row 2, column 4
print(f"Testing custom cell: ({custom_row}, {custom_col})")
custom_results = run_inference(
image_path=IMAGE_PATH,
cell=(custom_row, custom_col),
grid_size=GRID_SIZE,
top_k=TOP_K,
)
print_results(custom_results, f"Custom Cell ({custom_row},{custom_col}) Results")
# Test 5: Compare scores between whole image and regions
print("\n\nTest 5: Score Comparison Analysis")
print("-" * 60)
# Get the top sentence from whole image
if whole_image_results:
top_sentence_id = whole_image_results[0]["id"]
top_sentence_text = whole_image_results[0]["english_original"]
print(f"Top whole-image sentence: {top_sentence_text[:80]}...")
print(f"Whole-image score: {whole_image_results[0]['score']:.4f}")
# Check this sentence's score in different regions
print("\nScore for this sentence in different regions:")
test_cells = [
("Center", (center_row, center_col)),
("Top-Left", (0, 0)),
("Bottom-Right", (GRID_SIZE[0] - 1, GRID_SIZE[1] - 1)),
]
for region_name, (row, col) in test_cells:
region_results = run_inference(
image_path=IMAGE_PATH, cell=(row, col), grid_size=GRID_SIZE, top_k=TOP_K
)
# Find the score for our target sentence
region_score = None
region_rank = None
for rank, result in enumerate(region_results, 1):
if result["id"] == top_sentence_id:
region_score = result["score"]
region_rank = rank
break
if region_score:
print(
f" {region_name} ({row},{col}): score={region_score:.4f}, rank={region_rank}"
)
else:
print(f" {region_name} ({row},{col}): Not in top {TOP_K}")
# Summary statistics
print("\n\nSummary")
print("=" * 60)
print("✓ Whole-image inference tested")
print("✓ Region-specific inference tested")
print("✓ Multiple grid cells compared")
print("\nThe patch inference system is working correctly!")
# Optional: Save results for further analysis
save_results = input("\nSave detailed results to JSON? (y/n): ").lower() == "y"
if save_results:
results_data = {
"image_path": IMAGE_PATH,
"grid_size": GRID_SIZE,
"whole_image": whole_image_results,
"center_cell": {
"position": [center_row, center_col],
"results": center_results,
},
"corners": {
name: {
"position": list(pos),
"results": run_inference(
IMAGE_PATH, cell=pos, grid_size=GRID_SIZE, top_k=5
),
}
for name, pos in corners.items()
},
}
# Save to data/outputs directory
outputs_dir = project_root / "runner" / "tests" / "test-outputs"
outputs_dir.mkdir(exist_ok=True)
output_path = outputs_dir / "patch_inference_test_results.json"
with open(output_path, "w") as f:
json.dump(results_data, f, indent=2)
print(f"\nResults saved to: {output_path}")
if __name__ == "__main__":
main()
|