|
|
|
|
|
""" |
|
|
test_patch_inference.py
=======================

Test script for the patch-based region inference functionality.
Tests both whole-image inference and grid-cell specific inference.

Usage:
    python test_patch_inference.py

This script demonstrates:
1. Whole-image inference (traditional CLIP similarity)
2. Grid-cell specific inference (click on region → relevant sentences)
3. Comparison between different regions
|
|
""" |
|
|
|
|
|
import json |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List |
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).resolve().parent.parent)) |
|
|
|
|
|
from runner.inference import run_inference |
|
|
|
|
|
|
|
|
def print_results(results: List[Dict[str, Any]], title: str, max_display: int = 5):
    """Print a titled, numbered listing of inference results.

    Args:
        results: Result dictionaries; each entry is read for the keys
            ``english_original``, ``score``, ``work`` and ``id``.
        title: Heading printed above the listing and underlined with ``=``.
        max_display: Maximum number of entries printed in full; any
            remaining entries are summarised by a single count line.
    """
    underline = "=" * len(title)
    print(f"\n{title}")
    print(underline)

    for position, entry in enumerate(results[:max_display], start=1):
        # Only the first 100 characters of the sentence are shown.
        excerpt = entry["english_original"][:100]
        print(f"\n{position}. {excerpt}...")
        print(f" Score: {entry['score']:.4f}")
        print(f" Work: {entry['work']}")
        print(f" ID: {entry['id']}")

    hidden_count = len(results) - max_display
    if hidden_count > 0:
        print(f"\n... and {hidden_count} more results")
|
|
|
|
|
|
|
|
def test_grid_visualization(grid_size=(7, 7)):
    """Print an ASCII table of ``(row,col)`` coordinates for the grid.

    Args:
        grid_size: ``(rows, cols)`` pair giving the grid dimensions.
    """
    print("\nGrid Reference (row, col):")
    print("=" * 50)

    n_rows, n_cols = grid_size
    for row_idx in range(n_rows):
        # Each cell label keeps its trailing space, including the last one.
        labels = (f"({row_idx},{col_idx}) " for col_idx in range(n_cols))
        print("".join(labels))
    print()
|
|
|
|
|
|
|
|
def main():
    """Run the patch-inference checks end to end and print the results.

    The steps, in order:

    1. Whole-image inference (``cell=None``).
    2. Inference for the center grid cell.
    3. Inference for the four corner cells (top 3 results each).
    4. Inference for one fixed custom cell.
    5. For the top whole-image sentence, its score and rank in a few
       individual cells.

    Afterwards the user is asked whether the collected results should be
    written to ``runner/tests/test-outputs/patch_inference_test_results.json``
    below the project root.

    Returns early, after printing an error, when the example image does not
    exist, so inference is never attempted on a missing file.
    """
    project_root = Path(__file__).resolve().parent.parent.parent
    IMAGE_PATH = str(
        project_root
        / "frontend"
        / "images"
        / "examples"
        / "Giotto_-_Scrovegni_-_-31-_-_Kiss_of_Judas.jpg"
    )
    GRID_SIZE = (7, 7)
    TOP_K = 25

    print("Patch Inference Test Suite")
    print("=" * 60)
    print(f"Image: {IMAGE_PATH}")
    print(f"Grid Size: {GRID_SIZE[0]}x{GRID_SIZE[1]} (matches ViT-B/32 patch grid)")
    print(f"Top K: {TOP_K}")

    if not Path(IMAGE_PATH).exists():
        print(f"\nError: Image not found at {IMAGE_PATH}")
        print("Please update IMAGE_PATH to point to a valid test image.")
        # Nothing below can work without the image, so stop here instead of
        # letting the first inference call fail on a missing file.
        return

    # --- Test 1: whole image -------------------------------------------
    print("\n\nTest 1: Whole-Image Inference")
    print("-" * 60)
    whole_image_results = run_inference(
        image_path=IMAGE_PATH, cell=None, top_k=TOP_K
    )
    print_results(whole_image_results, "Whole Image Results")

    test_grid_visualization(GRID_SIZE)

    # --- Test 2: center cell -------------------------------------------
    print("\n\nTest 2: Center Region Inference")
    print("-" * 60)
    center_row, center_col = GRID_SIZE[0] // 2, GRID_SIZE[1] // 2
    print(f"Testing center cell: ({center_row}, {center_col})")

    center_results = run_inference(
        image_path=IMAGE_PATH,
        cell=(center_row, center_col),
        grid_size=GRID_SIZE,
        top_k=TOP_K,
    )
    print_results(center_results, f"Center Cell ({center_row},{center_col}) Results")

    # --- Test 3: corner cells ------------------------------------------
    print("\n\nTest 3: Corner Regions Comparison")
    print("-" * 60)

    corners = {
        "Top-Left": (0, 0),
        "Top-Right": (0, GRID_SIZE[1] - 1),
        "Bottom-Left": (GRID_SIZE[0] - 1, 0),
        "Bottom-Right": (GRID_SIZE[0] - 1, GRID_SIZE[1] - 1),
    }

    for corner_name, (row, col) in corners.items():
        print(f"\n{corner_name} Corner ({row}, {col}):")
        corner_results = run_inference(
            image_path=IMAGE_PATH,
            cell=(row, col),
            grid_size=GRID_SIZE,
            top_k=3,
        )

        for i, result in enumerate(corner_results[:3], 1):
            print(f" {i}. {result['english_original'][:60]}...")
            print(f" Score: {result['score']:.4f}")

    # --- Test 4: one hand-picked cell ----------------------------------
    print("\n\nTest 4: Custom Region Test")
    print("-" * 60)

    custom_row, custom_col = 2, 4
    print(f"Testing custom cell: ({custom_row}, {custom_col})")

    custom_results = run_inference(
        image_path=IMAGE_PATH,
        cell=(custom_row, custom_col),
        grid_size=GRID_SIZE,
        top_k=TOP_K,
    )
    print_results(custom_results, f"Custom Cell ({custom_row},{custom_col}) Results")

    # --- Test 5: how the best whole-image sentence fares per region ----
    print("\n\nTest 5: Score Comparison Analysis")
    print("-" * 60)

    if whole_image_results:
        top_sentence_id = whole_image_results[0]["id"]
        top_sentence_text = whole_image_results[0]["english_original"]

        print(f"Top whole-image sentence: {top_sentence_text[:80]}...")
        print(f"Whole-image score: {whole_image_results[0]['score']:.4f}")

        print("\nScore for this sentence in different regions:")

        test_cells = [
            ("Center", (center_row, center_col)),
            ("Top-Left", (0, 0)),
            ("Bottom-Right", (GRID_SIZE[0] - 1, GRID_SIZE[1] - 1)),
        ]

        for region_name, (row, col) in test_cells:
            region_results = run_inference(
                image_path=IMAGE_PATH, cell=(row, col), grid_size=GRID_SIZE, top_k=TOP_K
            )

            region_score = None
            region_rank = None
            for rank, result in enumerate(region_results, 1):
                if result["id"] == top_sentence_id:
                    region_score = result["score"]
                    region_rank = rank
                    break

            # Compare against None explicitly: a score of exactly 0.0 is a
            # valid hit and must not be reported as "not in top K".
            if region_score is not None:
                print(
                    f" {region_name} ({row},{col}): score={region_score:.4f}, rank={region_rank}"
                )
            else:
                print(f" {region_name} ({row},{col}): Not in top {TOP_K}")

    print("\n\nSummary")
    print("=" * 60)
    print("✓ Whole-image inference tested")
    print("✓ Region-specific inference tested")
    print("✓ Multiple grid cells compared")
    print("\nThe patch inference system is working correctly!")

    try:
        answer = input("\nSave detailed results to JSON? (y/n): ")
    except EOFError:
        # No interactive stdin (piped input, CI, ...): treat as "no".
        answer = ""
    save_results = answer.strip().lower() == "y"

    if save_results:
        results_data = {
            "image_path": IMAGE_PATH,
            "grid_size": GRID_SIZE,
            "whole_image": whole_image_results,
            "center_cell": {
                "position": [center_row, center_col],
                "results": center_results,
            },
            # Corner results are recomputed here with top_k=5; Test 3 above
            # only requested the top 3 for display.
            "corners": {
                name: {
                    "position": list(pos),
                    "results": run_inference(
                        IMAGE_PATH, cell=pos, grid_size=GRID_SIZE, top_k=5
                    ),
                }
                for name, pos in corners.items()
            },
        }

        outputs_dir = project_root / "runner" / "tests" / "test-outputs"
        # parents=True so a fresh checkout without the intermediate
        # directories does not fail with FileNotFoundError.
        outputs_dir.mkdir(parents=True, exist_ok=True)
        output_path = outputs_dir / "patch_inference_test_results.json"

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results_data, f, indent=2)
        print(f"\nResults saved to: {output_path}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the suite only when executed as a script, not when imported.
    main()
|
|
|