Spaces:
Running
Running
"""
Test script for snipshot_engine — the new lean translation pipeline.

Tests:
    1. Import check — verify all modules load
    2. Config construction — verify the pydantic config works
    3. In-process translate — run SnipshotTranslator directly (no server)
    4. Server /health — hit the FastAPI health endpoint
    5. Server /translate/raw — send an image and get a translated PNG back
    6. Server /translate — send an image and get a Supabase URL back
    (Steps 5 and 6 are currently commented out in test_server.)

Usage:
    # Tests 1-3 only (no server needed):
    python test.py --local

    # Tests 1-6 (start the server first):
    python main.py
    python test.py

    # Specify a different image or server URL:
    python test.py --image 155.jpg --url http://localhost:9000
"""
| import argparse | |
| import asyncio | |
| import io | |
| import json | |
| import sys | |
| import time | |
| from pathlib import Path | |
async def _run_inpaint_preview(translator, pil_image):
    """Run the pipeline up to inpainting and return (PIL image, region count).

    Stages, in order: detection, OCR, textline merge, region filtering and
    sorting, mask refinement, inpainting. Translation and rendering are not
    invoked here. Whenever a stage leaves nothing to work on, the loaded
    (un-inpainted) pixels are handed to ``dump_image`` and returned together
    with a region count of 0.
    """
    import numpy as np
    from snipshot_engine.config import Inpainter
    from snipshot_engine.utils import (
        load_image,
        dump_image,
        is_valuable_text,
        sort_regions,
        LANGUAGE_ORIENTATION_PRESETS,
    )
    from snipshot_engine.ocr import dispatch as dispatch_ocr
    from snipshot_engine.textline_merge import dispatch as dispatch_textline_merge
    from snipshot_engine.mask_refinement import dispatch as dispatch_mask_refinement
    from snipshot_engine.inpainting import dispatch as dispatch_inpainting

    cfg = translator.config
    detector_cfg = cfg.detector
    img_rgb, img_alpha = load_image(pil_image)

    def _nothing_to_inpaint():
        # Shared early-exit result: original pixels, zero regions.
        return dump_image(pil_image, img_rgb, img_alpha), 0

    # --- Detection ---------------------------------------------------------
    textlines, mask_raw, _ = await translator._detector.infer(
        img_rgb,
        detector_cfg.detection_size,
        detector_cfg.text_threshold,
        detector_cfg.box_threshold,
        detector_cfg.unclip_ratio,
        verbose=False,
    )
    if not textlines:
        return _nothing_to_inpaint()

    # --- OCR: keep only lines that produced non-blank text -------------------
    recognised = await dispatch_ocr(
        cfg.ocr.ocr, img_rgb, textlines, cfg.ocr, translator.device, verbose=False
    )
    textlines = [line for line in recognised if line.text.strip()]
    if not textlines:
        return _nothing_to_inpaint()

    # --- Merge lines into regions, then filter and order them ---------------
    merged = await dispatch_textline_merge(
        textlines, img_rgb.shape[1], img_rgb.shape[0], verbose=False
    )
    text_regions = []
    for candidate in merged:
        if len(candidate.text) >= cfg.ocr.min_text_length and is_valuable_text(candidate.text):
            text_regions.append(candidate)
    if not text_regions:
        return _nothing_to_inpaint()
    text_regions = sort_regions(text_regions, img_rgb.shape[1], img_rgb.shape[0])

    # Tag every region with the target language and, when a preset exists
    # for that language, its orientation.
    target_lang = cfg.translator.target_lang
    for region in text_regions:
        region.target_lang = target_lang
        preset = LANGUAGE_ORIENTATION_PRESETS.get(target_lang)
        if preset:
            region._direction = preset

    # --- Mask refinement -----------------------------------------------------
    # The detector may return no raw mask; substitute an all-zero one.
    if mask_raw is not None:
        seed_mask = mask_raw
    else:
        seed_mask = np.zeros(img_rgb.shape[:2], dtype=np.uint8)
    mask = await dispatch_mask_refinement(
        text_regions,
        img_rgb,
        seed_mask,
        method="fit_text",
        dilation_offset=cfg.mask_dilation_offset,
        ignore_bubble=cfg.ocr.ignore_bubble,
        kernel_size=cfg.kernel_size,
    )

    # --- Inpainting (skipped when the config selects no inpainter) -----------
    inpainter_cfg = cfg.inpainter
    if inpainter_cfg.inpainter == Inpainter.none:
        img_inpainted = img_rgb.copy()
    else:
        img_inpainted = await dispatch_inpainting(
            inpainter_cfg.inpainter,
            img_rgb,
            mask,
            inpainter_cfg,
            inpainter_cfg.inpainting_size,
            translator.device,
            verbose=False,
        )

    return dump_image(pil_image, img_inpainted, img_alpha), len(text_regions)
| def _infer_dynamic_detector_params(width: int, height: int) -> dict: | |
| """Infer detector settings from image dimensions for local test sweeps.""" | |
| long_side = max(width, height) | |
| mp = (width * height) / 1_000_000.0 | |
| # Round to nearest 64 and clamp to practical range. | |
| detection_size = int(round(long_side / 64.0) * 64) | |
| detection_size = max(1024, min(3072, detection_size)) | |
| if mp < 1.0: | |
| box_threshold = 0.55 | |
| elif mp < 2.5: | |
| box_threshold = 0.60 | |
| elif mp < 4.0: | |
| box_threshold = 0.65 | |
| else: | |
| box_threshold = 0.70 | |
| return { | |
| "detection_size": detection_size, | |
| "box_threshold": box_threshold, | |
| } | |
| def _collect_image_paths(image_arg: str): | |
| """Collect input images from a file path or directory.""" | |
| p = Path(image_arg) | |
| exts = {".png", ".jpg", ".jpeg", ".webp"} | |
| if p.is_file(): | |
| return [str(p)] | |
| if p.is_dir(): | |
| out = [] | |
| for fp in sorted(p.iterdir()): | |
| if fp.suffix.lower() not in exts: | |
| continue | |
| stem = fp.stem.lower() | |
| if stem.endswith("_translated") or stem.endswith("_inpainted"): | |
| continue | |
| out.append(str(fp)) | |
| return out | |
| return [image_arg] | |
| # --------------------------------------------------------------------------- | |
| # 1. Import check | |
| # --------------------------------------------------------------------------- | |
def test_imports():
    """Test [1]: check that the package and its pipeline modules import.

    Imports the top-level package names first, then each listed module,
    checking that every expected symbol is an attribute of that module.
    Problems are collected rather than raised so all of them are reported.

    Returns:
        True when nothing failed, False otherwise.
    """
    print("\n[1] Import check...")
    problems = []

    try:
        from snipshot_engine import Config, SnipshotTranslator
        print(" OK snipshot_engine (Config, SnipshotTranslator)")
    except Exception as e:
        problems.append(f"snipshot_engine: {e}")

    # (module name, symbols it must expose); None means "import only".
    targets = (
        ("snipshot_engine.config", None),
        ("snipshot_engine.utils", ["TextBlock", "Quadrilateral", "ModelWrapper", "load_image", "dump_image"]),
        ("snipshot_engine.detection", ["DefaultDetector"]),
        ("snipshot_engine.ocr", ["prepare", "dispatch", "unload"]),
        ("snipshot_engine.textline_merge", ["dispatch"]),
        ("snipshot_engine.translation", ["prepare", "dispatch"]),
        ("snipshot_engine.mask_refinement", ["dispatch"]),
        ("snipshot_engine.inpainting", ["prepare", "dispatch", "unload"]),
        ("snipshot_engine.rendering", ["dispatch"]),
    )
    for mod_name, symbols in targets:
        try:
            # A non-empty fromlist makes __import__ return the leaf module
            # rather than the top-level package.
            mod = __import__(mod_name, fromlist=symbols or ["__name__"])
            if not symbols:
                print(f" OK {mod_name}")
                continue
            missing = [s for s in symbols if not hasattr(mod, s)]
            if missing:
                problems.append(f"{mod_name}: missing {missing}")
            else:
                print(f" OK {mod_name} ({', '.join(symbols)})")
        except Exception as e:
            problems.append(f"{mod_name}: {e}")

    if not problems:
        print(" All imports passed.")
        return True

    for entry in problems:
        print(f" FAIL {entry}")
    return False
| # --------------------------------------------------------------------------- | |
| # 2. Config construction | |
| # --------------------------------------------------------------------------- | |
def test_config():
    """Test [2]: build a default and a customised ``Config``.

    Prints the default backend of each pipeline stage, then constructs a
    config from a nested dict of overrides and asserts the overridden values
    read back as given.

    Returns:
        True (a failed check raises AssertionError instead).
    """
    print("\n[2] Config construction...")
    from snipshot_engine import Config

    # Defaults: report which backend each stage selects out of the box.
    defaults = Config()
    print(f" Default detector: {defaults.detector.detector.value}")
    print(f" Default OCR: {defaults.ocr.ocr.value}")
    print(f" Default translator: {defaults.translator.translator.value}")
    print(f" Default inpainter: {defaults.inpainter.inpainter.value}")
    print(f" Default renderer: {defaults.render.renderer.value}")

    # Overrides, shaped like the payload the frontend would send.
    overrides = {
        "detector": {"detection_size": 1024, "box_threshold": 0.5},
        "translator": {"target_lang": "CHS"},
        "inpainter": {"inpainter": "none"},
        "render": {"direction": "horizontal"},
    }
    custom = Config(**overrides)
    assert custom.detector.detection_size == 1024
    assert custom.translator.target_lang == "CHS"
    assert custom.inpainter.inpainter.value == "none"
    assert custom.render.direction.value == "horizontal"
    print(" Custom config overrides work.")

    print(" Config test passed.")
    return True
| # --------------------------------------------------------------------------- | |
| # 3. In-process translate (local, no server) | |
| # --------------------------------------------------------------------------- | |
async def _process_one_local_image(
    translator,
    image_path: str,
    save_inpainted: bool,
    inpaint_only: bool,
    auto_detector: bool,
):
    """Run test [3] on a single image with an already-loaded translator.

    Outputs are written next to the input as ``<stem>_inpainted.png`` and/or
    ``<stem>_translated.png``. The names are derived with ``pathlib`` so a dot
    in a directory name (``./imgs/foo``) or a missing extension cannot
    truncate the output path.

    Args:
        translator: Object exposing ``config`` and an awaitable
            ``translate(img)``; also passed to ``_run_inpaint_preview``.
        image_path: Path of the image to process.
        save_inpainted: Also write the inpaint-only preview.
        inpaint_only: Write the preview and stop before translation/render.
        auto_detector: Overwrite the translator's detector size/threshold
            with values inferred from this image's dimensions.

    Returns:
        True when the requested outputs were written. None when the image
        file was not found, or when ``translate`` raised an error whose
        message mentions a model together with "not found"/"404". Any other
        exception propagates.
    """
    print(f"\n[3] In-process translate ({image_path})...")
    from PIL import Image

    try:
        # The context manager closes the file handle; convert() yields a
        # loaded image that stays usable after the source is closed.
        with Image.open(image_path) as opened:
            img = opened.convert("RGB")
    except FileNotFoundError:
        print(f" SKIP {image_path} not found")
        return None
    print(f" Loaded {image_path} ({img.size[0]}x{img.size[1]})")

    source = Path(image_path)

    print(" Running translation pipeline...")
    if auto_detector:
        dyn = _infer_dynamic_detector_params(img.size[0], img.size[1])
        translator.config.detector.detection_size = dyn["detection_size"]
        translator.config.detector.box_threshold = dyn["box_threshold"]
        print(
            f" Dynamic detector: size={dyn['detection_size']} "
            f"box_threshold={dyn['box_threshold']:.2f}"
        )

    if save_inpainted or inpaint_only:
        t_preview = time.time()
        inpainted_img, region_count = await _run_inpaint_preview(translator, img)
        inpaint_out = str(source.with_name(f"{source.stem}_inpainted.png"))
        inpainted_img.save(inpaint_out)
        print(
            f" Saved {inpaint_out} ({inpainted_img.size[0]}x{inpainted_img.size[1]}) "
            f"from {region_count} regions in {time.time() - t_preview:.1f}s"
        )
        if inpaint_only:
            print(" Inpaint-only mode complete (translation/render skipped).")
            return True

    t0 = time.time()
    try:
        result = await translator.translate(img)
    except Exception as exc:
        exc_str = str(exc)
        lowered = exc_str.lower()
        # A missing/renamed remote model is reported as a warning rather
        # than a crash; everything else is re-raised untouched.
        if "model" in lowered and ("not found" in lowered or "404" in lowered):
            print(" WARN Groq model not found — update GROQ_MODEL in .env")
            print(f" {exc_str[:200]}")
            print(" (Detection + OCR + merge stages all passed before this point)")
            return None
        raise
    elapsed = time.time() - t0

    out_path = str(source.with_name(f"{source.stem}_translated.png"))
    result.save(out_path)
    print(f" Saved {out_path} ({result.size[0]}x{result.size[1]}) in {elapsed:.1f}s")
    print(" Local translate passed.")
    return True
async def test_local_translate(
    image_paths,
    save_inpainted: bool = False,
    inpaint_only: bool = False,
    target_lang: str = "ENG",
    detection_size: int = 1536,
    box_threshold: float = 0.7,
    inpainting_size: int = 2048,
    auto_detector: bool = False,
):
    """Test [3]: run the in-process pipeline over every image in a batch.

    One ``SnipshotTranslator`` is built on the ``"cpu"`` device with the
    ``"lama_large"`` inpainter, its models are loaded once, and the images
    are then processed one after another.

    Returns:
        True only if every image was reported as successful.
    """
    from snipshot_engine import SnipshotTranslator, Config

    settings = {
        "translator": {"target_lang": target_lang},
        "inpainter": {"inpainter": "lama_large", "inpainting_size": inpainting_size},
        "detector": {"detection_size": detection_size, "box_threshold": box_threshold},
    }
    translator = SnipshotTranslator(Config(**settings), device="cpu")

    print("\n Loading models once for all local image tests...")
    load_started = time.time()
    await translator.load_models()
    print(f" Models loaded in {time.time() - load_started:.1f}s")

    total = len(image_paths)
    batch_started = time.time()
    # Sequential on purpose: each await completes before the next image.
    outcomes = [
        await _process_one_local_image(
            translator,
            path,
            save_inpainted,
            inpaint_only,
            auto_detector,
        )
        for path in image_paths
    ]
    successes = sum(1 for outcome in outcomes if outcome)

    print(f"\n Local summary: {successes}/{total} succeeded in {time.time() - batch_started:.1f}s")
    return successes == total
| # --------------------------------------------------------------------------- | |
| # 4-6. Server endpoint tests | |
| # --------------------------------------------------------------------------- | |
async def test_server(api_url: str, image_path: str):
    """Tests [4]-[6]: exercise the HTTP endpoints of a running server.

    Only the ``/health`` check is live. The ``/translate/raw`` and
    ``/translate`` requests are kept below as commented-out code, so the
    image bytes and config JSON prepared for them are currently unused.

    Returns:
        False when ``/health`` answers with a non-200 status or the
        connection fails; True otherwise.
    """
    import httpx

    print(f"\n Server: {api_url}")
    async with httpx.AsyncClient(timeout=180.0) as client:
        # 4. Health
        print("\n[4] GET /health...")
        try:
            health = await client.get(f"{api_url}/health")
            if health.status_code != 200:
                print(f" FAIL status {health.status_code}: {health.text}")
                return False
            print(f" OK {health.json()}")
        except httpx.ConnectError:
            print(f" FAIL Cannot connect to {api_url}")
            print(" Is the server running? Start it with:")
            print(" python main.py")
            return False

        # Request payload: the image on disk, or a blank placeholder when
        # the file does not exist.
        try:
            image_bytes = Path(image_path).read_bytes()
            print(f" Using {image_path} ({len(image_bytes)} bytes)")
        except FileNotFoundError:
            from PIL import Image as PILImage
            print(f" {image_path} not found — using blank 200x200 test image")
            blank = io.BytesIO()
            PILImage.new("RGB", (200, 200), "white").save(blank, format="PNG")
            image_bytes = blank.getvalue()

        config_json = json.dumps({
            "translator": {"target_lang": "ENG"},
            "detector": {"detection_size": 1536, "box_threshold": 0.7},
        })

        # 5. /translate/raw
        # print("\n[5] POST /translate/raw...")
        # resp = await client.post(
        #     f"{api_url}/translate/raw",
        #     files={"image": ("test.jpg", image_bytes, "image/jpeg")},
        #     data={"config": config_json},
        # )
        # if resp.status_code == 200 and "image/png" in resp.headers.get("content-type", ""):
        #     out_path = "test_snipshot_raw.png"
        #     with open(out_path, "wb") as f:
        #         f.write(resp.content)
        #     print(f" OK Got PNG ({len(resp.content)} bytes) → {out_path}")
        # else:
        #     print(f" FAIL status {resp.status_code}: {resp.text[:200]}")

        # 6. /translate (Supabase upload)
        # print("\n[6] POST /translate (Supabase upload)...")
        # resp = await client.post(
        #     f"{api_url}/translate",
        #     files={"image": ("test.jpg", image_bytes, "image/jpeg")},
        #     data={"config": config_json},
        # )
        # if resp.status_code == 200:
        #     result = resp.json()
        #     print(f" OK success={result.get('success')}")
        #     print(f" URL: {result.get('image_url', 'N/A')}")
        # elif resp.status_code == 502:
        #     print(f" WARN Supabase not configured (expected if no env vars)")
        #     print(f" {resp.json().get('detail', resp.text[:200])}")
        # else:
        #     print(f" FAIL status {resp.status_code}: {resp.text[:200]}")

    print("\n Server tests done.")
    return True
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main():
    """Command-line entry point for the test suite.

    Runs the import and config checks, then the local in-process translate
    over the collected images, then (unless ``--local`` or ``--inpaint-only``
    is given) the server endpoint tests against the first image.

    The process exits with status 1 when the import check fails, when no
    input image is found, or when the local or server tests report a
    failure, so shells and CI jobs can detect a failed run. Otherwise it
    returns normally.
    """
    parser = argparse.ArgumentParser(description="Test snipshot_engine")
    parser.add_argument("--local", action="store_true", help="Run only local tests (1-3), skip server tests")
    parser.add_argument("--save-inpainted", action="store_true", help="Also save an inpaint-only preview image (*_inpainted.png)")
    parser.add_argument("--inpaint-only", action="store_true", help="Run only up to inpainting and save *_inpainted.png (skip translation/render)")
    parser.add_argument("--image", default="test-image-easy1.jpg", help="Path to a test image or directory of images")
    parser.add_argument("--url", default="http://localhost:8001", help="Server URL for endpoint tests (default: http://localhost:8001)")
    parser.add_argument("--target-lang", default="ENG", help="Target language for local translate (default: ENG)")
    parser.add_argument("--detection-size", type=int, default=1536, help="Detector size for local translate (default: 1536)")
    parser.add_argument("--box-threshold", type=float, default=0.7, help="Detector box threshold for local translate (default: 0.7)")
    parser.add_argument("--inpainting-size", type=int, default=2048, help="Inpainting size for local translate (default: 2048)")
    parser.add_argument("--auto-detector", action="store_true", help="Auto-tune detector size/threshold per image dimensions")
    args = parser.parse_args()

    # Inpaint-only runs never reach the server tests.
    if args.inpaint_only:
        args.local = True

    print("=" * 60)
    print(" SnipShot Engine Test Suite")
    print("=" * 60)

    # Always run import + config tests
    if not test_imports():
        print("\nImport check failed — aborting.")
        sys.exit(1)
    all_ok = bool(test_config())

    image_paths = _collect_image_paths(args.image)
    if not image_paths:
        print(f"\nNo input images found for: {args.image}")
        sys.exit(1)
    if len(image_paths) > 1:
        print(f"\n Collected {len(image_paths)} images from {args.image}")

    # Local in-process translate
    local_ok = asyncio.run(
        test_local_translate(
            image_paths,
            save_inpainted=args.save_inpainted or args.inpaint_only,
            inpaint_only=args.inpaint_only,
            target_lang=args.target_lang,
            detection_size=args.detection_size,
            box_threshold=args.box_threshold,
            inpainting_size=args.inpainting_size,
            auto_detector=args.auto_detector,
        )
    )
    all_ok = all_ok and bool(local_ok)

    # Server endpoint tests (unless --local); still run after a local
    # failure so one invocation reports everything.
    if not args.local:
        server_ok = asyncio.run(test_server(args.url, image_paths[0]))
        all_ok = all_ok and bool(server_ok)
    else:
        print("\n --local flag set, skipping server tests (4-6).")

    print("\n" + "=" * 60)
    print(" Done.")
    print("=" * 60)

    # Propagate failures through the exit status instead of always exiting 0.
    if not all_ok:
        sys.exit(1)


if __name__ == "__main__":
    main()