# snipshot-backend / test.py
# Famanias
# Deploy to Hugging Face
# 0f6f6c1
"""
Test script for snipshot_engine — the new lean translation pipeline.
Tests:
1. Import check — verify all modules load
2. Config construction — verify pydantic config works
3. In-process translate — run SnipshotTranslator directly (no server)
4. Server /health — hit the FastAPI health endpoint
5. Server /translate/raw — send an image and get translated PNG back (currently commented out in test_server)
6. Server /translate — send an image and get Supabase URL back (currently commented out in test_server)
Usage:
# Test 1-3 only (no server needed):
python test.py --local
# Test 1-6 (start the server first):
python main.py
python test.py
# Specify a different image or server URL:
python test.py --image 155.jpg --url http://localhost:9000
"""
import argparse
import asyncio
import io
import json
import sys
import time
from pathlib import Path
async def _run_inpaint_preview(translator, pil_image):
    """Execute the pipeline stages that precede translation and rendering.

    Runs detection, OCR, text-line merging, region filtering/sorting, mask
    refinement and (unless the configured inpainter is ``Inpainter.none``)
    inpainting.

    Returns:
        A ``(image, region_count)`` tuple. ``image`` is whatever
        ``dump_image`` produces (a PIL image, per the original docstring).
        When any stage leaves nothing to work on, the un-inpainted pixels
        are dumped and ``region_count`` is 0.
    """
    import numpy as np
    from snipshot_engine.config import Inpainter
    from snipshot_engine.utils import (
        load_image,
        dump_image,
        is_valuable_text,
        sort_regions,
        LANGUAGE_ORIENTATION_PRESETS,
    )
    from snipshot_engine.ocr import dispatch as dispatch_ocr
    from snipshot_engine.textline_merge import dispatch as dispatch_textline_merge
    from snipshot_engine.mask_refinement import dispatch as dispatch_mask_refinement
    from snipshot_engine.inpainting import dispatch as dispatch_inpainting

    settings = translator.config
    rgb, alpha = load_image(pil_image)

    def _nothing_to_inpaint():
        # Shared early exit: hand back the source pixels with a zero count.
        return dump_image(pil_image, rgb, alpha), 0

    # --- detection ---------------------------------------------------------
    detector_cfg = settings.detector
    lines, raw_mask, _ = await translator._detector.infer(
        rgb,
        detector_cfg.detection_size,
        detector_cfg.text_threshold,
        detector_cfg.box_threshold,
        detector_cfg.unclip_ratio,
        verbose=False,
    )
    if not lines:
        return _nothing_to_inpaint()

    # --- OCR; drop lines whose recognised text is blank ---------------------
    lines = await dispatch_ocr(
        settings.ocr.ocr, rgb, lines, settings.ocr, translator.device, verbose=False
    )
    lines = [line for line in lines if line.text.strip()]
    if not lines:
        return _nothing_to_inpaint()

    # --- merge lines into regions and keep only the useful ones -------------
    merged = await dispatch_textline_merge(lines, rgb.shape[1], rgb.shape[0], verbose=False)
    regions = []
    for candidate in merged:
        if len(candidate.text) < settings.ocr.min_text_length:
            continue
        if not is_valuable_text(candidate.text):
            continue
        regions.append(candidate)
    if not regions:
        return _nothing_to_inpaint()

    regions = sort_regions(regions, rgb.shape[1], rgb.shape[0])

    # --- tag every region with the target language / orientation preset -----
    language = settings.translator.target_lang
    orientation = LANGUAGE_ORIENTATION_PRESETS.get(language)
    for region in regions:
        region.target_lang = language
        if orientation:
            region._direction = orientation

    # --- mask refinement (fall back to an all-zero mask if none was produced)
    if mask_is_missing := raw_mask is None:
        seed_mask = np.zeros(rgb.shape[:2], dtype=np.uint8)
    else:
        seed_mask = raw_mask
    refined_mask = await dispatch_mask_refinement(
        regions,
        rgb,
        seed_mask,
        method="fit_text",
        dilation_offset=settings.mask_dilation_offset,
        ignore_bubble=settings.ocr.ignore_bubble,
        kernel_size=settings.kernel_size,
    )

    # --- inpainting ----------------------------------------------------------
    if settings.inpainter.inpainter == Inpainter.none:
        cleaned = rgb.copy()
    else:
        cleaned = await dispatch_inpainting(
            settings.inpainter.inpainter,
            rgb,
            refined_mask,
            settings.inpainter,
            settings.inpainter.inpainting_size,
            translator.device,
            verbose=False,
        )
    return dump_image(pil_image, cleaned, alpha), len(regions)
def _infer_dynamic_detector_params(width: int, height: int) -> dict:
    """Derive detector settings from the image dimensions.

    Returns a dict with two keys:

    * ``detection_size`` — the longer edge rounded to a multiple of 64 and
      clamped to the range 1024..3072.
    * ``box_threshold`` — 0.55 / 0.60 / 0.65 / 0.70, stepping up as the
      image area crosses 1.0, 2.5 and 4.0 megapixels.
    """
    megapixels = (width * height) / 1_000_000.0

    # Snap the longer edge to the 64-pixel grid, then clamp.
    snapped = int(round(max(width, height) / 64.0) * 64)
    if snapped > 3072:
        snapped = 3072
    if snapped < 1024:
        snapped = 1024

    # (exclusive upper bound in megapixels, threshold); anything larger
    # than the last bound keeps the default below.
    chosen_threshold = 0.70
    for upper_bound, threshold in ((1.0, 0.55), (2.5, 0.60), (4.0, 0.65)):
        if megapixels < upper_bound:
            chosen_threshold = threshold
            break

    return {"detection_size": snapped, "box_threshold": chosen_threshold}
def _collect_image_paths(image_arg: str) -> list:
    """Resolve ``image_arg`` to the list of image paths to process.

    * An existing file is returned as a single-element list.
    * A directory is scanned (non-recursively, in sorted order) for regular
      files ending in .png/.jpg/.jpeg/.webp, skipping this script's own
      outputs (stems ending in ``_translated`` or ``_inpainted``).
    * Anything else is handed back unchanged so the caller can report the
      missing path itself.

    Returns:
        A list of path strings; it may be empty for a directory that holds
        no eligible images.
    """
    target = Path(image_arg)
    if target.is_file():
        return [str(target)]
    if not target.is_dir():
        return [image_arg]

    image_suffixes = {".png", ".jpg", ".jpeg", ".webp"}
    generated_markers = ("_translated", "_inpainted")
    return [
        str(entry)
        for entry in sorted(target.iterdir())
        # is_file() keeps out sub-directories that merely carry an image
        # suffix; opening one later would raise instead of being skipped.
        if entry.is_file()
        and entry.suffix.lower() in image_suffixes
        and not entry.stem.lower().endswith(generated_markers)
    ]
# ---------------------------------------------------------------------------
# 1. Import check
# ---------------------------------------------------------------------------
def test_imports():
    """Check that the engine package and its sub-modules can be imported.

    Every sub-module is imported by name and, where a symbol list is given,
    probed for each expected attribute. Problems are collected rather than
    raised so that all of them are reported in one run.

    Returns:
        True when everything imported and exposed the expected names,
        False otherwise (after printing one FAIL line per problem).
    """
    print("\n[1] Import check...")
    problems = []

    # Top-level package first: these are the two names the other tests use.
    try:
        from snipshot_engine import Config, SnipshotTranslator
    except Exception as exc:
        problems.append(f"snipshot_engine: {exc}")
    else:
        print(" OK snipshot_engine (Config, SnipshotTranslator)")

    # (module path, attributes that must exist on it — None means import only)
    expected = [
        ("snipshot_engine.config", None),
        ("snipshot_engine.utils", ["TextBlock", "Quadrilateral", "ModelWrapper", "load_image", "dump_image"]),
        ("snipshot_engine.detection", ["DefaultDetector"]),
        ("snipshot_engine.ocr", ["prepare", "dispatch", "unload"]),
        ("snipshot_engine.textline_merge", ["dispatch"]),
        ("snipshot_engine.translation", ["prepare", "dispatch"]),
        ("snipshot_engine.mask_refinement", ["dispatch"]),
        ("snipshot_engine.inpainting", ["prepare", "dispatch", "unload"]),
        ("snipshot_engine.rendering", ["dispatch"]),
    ]
    for module_path, required in expected:
        try:
            # A non-empty fromlist makes __import__ return the sub-module
            # itself instead of the top-level package.
            module = __import__(module_path, fromlist=required or ["__name__"])
            if not required:
                print(f" OK {module_path}")
                continue
            absent = [name for name in required if not hasattr(module, name)]
            if absent:
                problems.append(f"{module_path}: missing {absent}")
                continue
            print(f" OK {module_path} ({', '.join(required)})")
        except Exception as exc:
            problems.append(f"{module_path}: {exc}")

    if not problems:
        print(" All imports passed.")
        return True

    for problem in problems:
        print(f" FAIL {problem}")
    return False
# ---------------------------------------------------------------------------
# 2. Config construction
# ---------------------------------------------------------------------------
def test_config():
    """Build a default and an overridden ``Config`` and verify the overrides.

    Prints the default choice for each pipeline stage, then constructs a
    config from nested dicts (the shape a frontend would send) and asserts
    that the overridden values are the ones read back.

    Returns:
        True. A mismatch surfaces as an ``AssertionError``.
    """
    print("\n[2] Config construction...")
    from snipshot_engine import Config

    # Defaults: report which implementation each stage selects.
    defaults = Config()
    detector_name = defaults.detector.detector.value
    print(f" Default detector: {detector_name}")
    ocr_name = defaults.ocr.ocr.value
    print(f" Default OCR: {ocr_name}")
    translator_name = defaults.translator.translator.value
    print(f" Default translator: {translator_name}")
    inpainter_name = defaults.inpainter.inpainter.value
    print(f" Default inpainter: {inpainter_name}")
    renderer_name = defaults.render.renderer.value
    print(f" Default renderer: {renderer_name}")

    # Overrides supplied as plain nested dicts.
    overrides = {
        "detector": {"detection_size": 1024, "box_threshold": 0.5},
        "translator": {"target_lang": "CHS"},
        "inpainter": {"inpainter": "none"},
        "render": {"direction": "horizontal"},
    }
    overridden = Config(**overrides)
    assert overridden.detector.detection_size == 1024
    assert overridden.translator.target_lang == "CHS"
    assert overridden.inpainter.inpainter.value == "none"
    assert overridden.render.direction.value == "horizontal"

    print(" Custom config overrides work.")
    print(" Config test passed.")
    return True
# ---------------------------------------------------------------------------
# 3. In-process translate (local, no server)
# ---------------------------------------------------------------------------
def _derive_output_path(image_path: str, suffix: str) -> str:
    """Return the sibling of ``image_path`` named ``<stem><suffix>``."""
    source = Path(image_path)
    # Work on the file name only: splitting the whole string on "." breaks
    # for extension-less files and for paths such as "./photo".
    return str(source.with_name(source.stem + suffix))


async def _process_one_local_image(
    translator,
    image_path: str,
    save_inpainted: bool,
    inpaint_only: bool,
    auto_detector: bool,
):
    """Run the in-process pipeline on one image and save the results.

    Args:
        translator: Loaded translator whose ``translate`` coroutine is used;
            its ``config.detector`` is modified in place when
            ``auto_detector`` is set.
        image_path: Path of the image to process.
        save_inpainted: Also write ``<stem>_inpainted.png`` next to the input.
        inpaint_only: Stop after the inpaint preview (implies saving it).
        auto_detector: Derive detector size/threshold from the image size.

    Returns:
        True when the requested outputs were written; None when the image
        was skipped (missing or unreadable) or the translation failed with
        a "model not found" style error. Other exceptions propagate.
    """
    print(f"\n[3] In-process translate ({image_path})...")
    from PIL import Image

    try:
        # convert() forces the decode, so the file handle can be released
        # as soon as the with-block ends.
        with Image.open(image_path) as source:
            img = source.convert("RGB")
        print(f" Loaded {image_path} ({img.size[0]}x{img.size[1]})")
    except FileNotFoundError:
        print(f" SKIP {image_path} not found")
        return None
    except OSError as exc:
        # Directories, permission problems and undecodable files: skip this
        # image instead of aborting the remaining ones.
        print(f" SKIP {image_path} could not be read: {exc}")
        return None

    print(" Running translation pipeline...")
    if auto_detector:
        dyn = _infer_dynamic_detector_params(img.size[0], img.size[1])
        translator.config.detector.detection_size = dyn["detection_size"]
        translator.config.detector.box_threshold = dyn["box_threshold"]
        print(
            f" Dynamic detector: size={dyn['detection_size']} "
            f"box_threshold={dyn['box_threshold']:.2f}"
        )

    if save_inpainted or inpaint_only:
        t_preview = time.perf_counter()
        inpainted_img, region_count = await _run_inpaint_preview(translator, img)
        inpaint_out = _derive_output_path(image_path, "_inpainted.png")
        inpainted_img.save(inpaint_out)
        print(
            f" Saved {inpaint_out} ({inpainted_img.size[0]}x{inpainted_img.size[1]}) "
            f"from {region_count} regions in {time.perf_counter() - t_preview:.1f}s"
        )
    if inpaint_only:
        print(" Inpaint-only mode complete (translation/render skipped).")
        return True

    t0 = time.perf_counter()
    try:
        result = await translator.translate(img)
    except Exception as exc:
        exc_str = str(exc)
        lowered = exc_str.lower()
        # A missing remote model is a configuration problem, not a pipeline
        # failure: warn and report the image as not completed.
        if "model" in lowered and ("not found" in lowered or "404" in lowered):
            print(" WARN Groq model not found — update GROQ_MODEL in .env")
            print(f" {exc_str[:200]}")
            print(" (Detection + OCR + merge stages all passed before this point)")
            return None
        raise
    elapsed = time.perf_counter() - t0

    out_path = _derive_output_path(image_path, "_translated.png")
    result.save(out_path)
    print(f" Saved {out_path} ({result.size[0]}x{result.size[1]}) in {elapsed:.1f}s")
    print(" Local translate passed.")
    return True
async def test_local_translate(
    image_paths,
    save_inpainted: bool = False,
    inpaint_only: bool = False,
    target_lang: str = "ENG",
    detection_size: int = 1536,
    box_threshold: float = 0.7,
    inpainting_size: int = 2048,
    auto_detector: bool = False,
):
    """Translate every image in ``image_paths`` with one shared translator.

    A single ``SnipshotTranslator`` is created on the CPU device with the
    ``lama_large`` inpainter and its models are loaded once; the images are
    then processed one after another.

    Returns:
        True only if every image was processed successfully.
    """
    from snipshot_engine import SnipshotTranslator, Config

    overrides = {
        "translator": {"target_lang": target_lang},
        "inpainter": {"inpainter": "lama_large", "inpainting_size": inpainting_size},
        "detector": {"detection_size": detection_size, "box_threshold": box_threshold},
    }
    engine = SnipshotTranslator(Config(**overrides), device="cpu")

    print("\n Loading models once for all local image tests...")
    load_started = time.time()
    await engine.load_models()
    print(f" Models loaded in {time.time() - load_started:.1f}s")

    total = len(image_paths)
    batch_started = time.time()
    # Awaited inside the comprehension, so images still run strictly in order.
    outcomes = [
        await _process_one_local_image(
            engine,
            path,
            save_inpainted,
            inpaint_only,
            auto_detector,
        )
        for path in image_paths
    ]
    successes = sum(1 for outcome in outcomes if outcome)

    print(f"\n Local summary: {successes}/{total} succeeded in {time.time() - batch_started:.1f}s")
    return successes == total
# ---------------------------------------------------------------------------
# 4-6. Server endpoint tests
# ---------------------------------------------------------------------------
async def test_server(api_url: str, image_path: str):
    """Run the server endpoint checks against a running instance.

    Only the ``/health`` probe (step 4) is executed at the moment. The
    payload for steps 5 and 6 is still prepared, but the requests
    themselves are commented out below.

    Args:
        api_url: Base URL of the server; a trailing slash is tolerated.
        image_path: Image intended for upload; if the file does not exist a
            blank 200x200 white PNG is generated instead.

    Returns:
        False when the health probe fails (connection or transport error,
        or a non-200 status), True otherwise.
    """
    import httpx

    print(f"\n Server: {api_url}")
    # Avoid requesting "//health" when the URL is given with a trailing "/".
    base_url = api_url.rstrip("/")

    async with httpx.AsyncClient(timeout=180.0) as client:
        # 4. Health
        print("\n[4] GET /health...")
        try:
            resp = await client.get(f"{base_url}/health")
        except httpx.ConnectError:
            print(f" FAIL Cannot connect to {api_url}")
            print(" Is the server running? Start it with:")
            print(" python main.py")
            return False
        except httpx.RequestError as exc:
            # Timeouts, protocol errors, etc.: report them as a failed
            # check instead of ending the run with a traceback.
            print(f" FAIL Request to {base_url}/health failed: {exc!r}")
            return False
        if resp.status_code != 200:
            print(f" FAIL status {resp.status_code}: {resp.text}")
            return False
        print(f" OK {resp.json()}")

        # Load image bytes (payload for the disabled steps 5 and 6 below).
        try:
            with open(image_path, "rb") as f:
                image_bytes = f.read()
            print(f" Using {image_path} ({len(image_bytes)} bytes)")
        except FileNotFoundError:
            from PIL import Image as PILImage
            print(f" {image_path} not found — using blank 200x200 test image")
            buf = io.BytesIO()
            PILImage.new("RGB", (200, 200), "white").save(buf, format="PNG")
            image_bytes = buf.getvalue()

        config_json = json.dumps({
            "translator": {"target_lang": "ENG"},
            "detector": {"detection_size": 1536, "box_threshold": 0.7},
        })

        # NOTE: steps 5 and 6 are disabled; the code is kept for reference.
        # 5. /translate/raw
        # print("\n[5] POST /translate/raw...")
        # resp = await client.post(
        #     f"{api_url}/translate/raw",
        #     files={"image": ("test.jpg", image_bytes, "image/jpeg")},
        #     data={"config": config_json},
        # )
        # if resp.status_code == 200 and "image/png" in resp.headers.get("content-type", ""):
        #     out_path = "test_snipshot_raw.png"
        #     with open(out_path, "wb") as f:
        #         f.write(resp.content)
        #     print(f" OK Got PNG ({len(resp.content)} bytes) → {out_path}")
        # else:
        #     print(f" FAIL status {resp.status_code}: {resp.text[:200]}")

        # 6. /translate (Supabase upload)
        # print("\n[6] POST /translate (Supabase upload)...")
        # resp = await client.post(
        #     f"{api_url}/translate",
        #     files={"image": ("test.jpg", image_bytes, "image/jpeg")},
        #     data={"config": config_json},
        # )
        # if resp.status_code == 200:
        #     result = resp.json()
        #     print(f" OK success={result.get('success')}")
        #     print(f" URL: {result.get('image_url', 'N/A')}")
        # elif resp.status_code == 502:
        #     print(f" WARN Supabase not configured (expected if no env vars)")
        #     print(f" {resp.json().get('detail', resp.text[:200])}")
        # else:
        #     print(f" FAIL status {resp.status_code}: {resp.text[:200]}")

    print("\n Server tests done.")
    return True
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse the command line and run the test stages in order.

    Stages: import check, config check, in-process translation of the
    collected images, then (unless ``--local`` or ``--inpaint-only`` is
    given) the server endpoint checks.

    The process exits with status 1 when the import check fails, when no
    input image is found, or when the local or server stage reports a
    failure, so callers can rely on the exit code.
    """
    parser = argparse.ArgumentParser(description="Test snipshot_engine")
    parser.add_argument("--local", action="store_true", help="Run only local tests (1-3), skip server tests")
    parser.add_argument("--save-inpainted", action="store_true", help="Also save an inpaint-only preview image (*_inpainted.png)")
    parser.add_argument("--inpaint-only", action="store_true", help="Run only up to inpainting and save *_inpainted.png (skip translation/render)")
    parser.add_argument("--image", default="test-image-easy1.jpg", help="Path to a test image or directory of images")
    parser.add_argument("--url", default="http://localhost:8001", help="Server URL for endpoint tests (default: http://localhost:8001)")
    parser.add_argument("--target-lang", default="ENG", help="Target language for local translate (default: ENG)")
    parser.add_argument("--detection-size", type=int, default=1536, help="Detector size for local translate (default: 1536)")
    parser.add_argument("--box-threshold", type=float, default=0.7, help="Detector box threshold for local translate (default: 0.7)")
    parser.add_argument("--inpainting-size", type=int, default=2048, help="Inpainting size for local translate (default: 2048)")
    parser.add_argument("--auto-detector", action="store_true", help="Auto-tune detector size/threshold per image dimensions")
    args = parser.parse_args()

    # Inpaint-only runs never reach the server stage.
    if args.inpaint_only:
        args.local = True

    print("=" * 60)
    print(" SnipShot Engine Test Suite")
    print("=" * 60)

    # Always run import + config tests
    if not test_imports():
        print("\nImport check failed — aborting.")
        sys.exit(1)
    test_config()

    image_paths = _collect_image_paths(args.image)
    if not image_paths:
        print(f"\nNo input images found for: {args.image}")
        sys.exit(1)
    if len(image_paths) > 1:
        print(f"\n Collected {len(image_paths)} images from {args.image}")

    # Local in-process translate
    local_ok = asyncio.run(
        test_local_translate(
            image_paths,
            save_inpainted=args.save_inpainted or args.inpaint_only,
            inpaint_only=args.inpaint_only,
            target_lang=args.target_lang,
            detection_size=args.detection_size,
            box_threshold=args.box_threshold,
            inpainting_size=args.inpainting_size,
            auto_detector=args.auto_detector,
        )
    )

    # Server endpoint tests (unless --local)
    server_ok = True
    if not args.local:
        server_ok = asyncio.run(test_server(args.url, image_paths[0]))
    else:
        print("\n --local flag set, skipping server tests (4-6).")

    print("\n" + "=" * 60)
    print(" Done.")
    print("=" * 60)

    # Propagate stage failures through the exit status, as the import
    # check and the no-images case above already do.
    if not (local_ok and server_ok):
        sys.exit(1)


if __name__ == "__main__":
    main()