| | """execute_crops node — Gemini code_execution for agentic cropping (PoC 1 style)."""
|
| | from __future__ import annotations
|
| |
|
| | import io
|
| | import logging
|
| | import time
|
| | import uuid
|
| | from collections.abc import Callable
|
| | from concurrent.futures import ThreadPoolExecutor, as_completed
|
| |
|
| | from google import genai
|
| | from google.genai import types
|
| | from PIL import Image
|
| |
|
| | from config import CROPPER_MODEL, GOOGLE_API_KEY
|
| | from prompts.cropper import CROPPER_PROMPT_TEMPLATE
|
| | from state import CropTask, DrawingReaderState, ImageRef
|
| | from tools.crop_cache import CropCache
|
| | from tools.image_store import ImageStore
|
| | from tools.pdf_processor import get_page_image_bytes
|
| |
|
# Module-level logger shared by every function in this file; this module
# does not configure handlers or levels itself.
logger = logging.getLogger(__name__)
|
| |
|
| |
|
| |
|
# Signature of the optional progress hook accepted by ``execute_crops``:
# ``(image_ref, crop_task, source, completed_count, total_count) -> None``
# where *source* is ``"cached"``, ``"completed"``, or ``"fallback"``.
ProgressCallback = Callable[[ImageRef, CropTask, str, int, int], None]
|
| |
|
| |
|
# Maximum number of ``generate_content`` attempts made for a single crop.
MAX_RETRIES = 3

# Base back-off in seconds; the wait after failed attempt ``k`` (0-based)
# is ``RETRY_BASE_DELAY * 2 ** k``.
RETRY_BASE_DELAY = 2.0
|
| |
|
| |
|
def _extract_last_image(response) -> Image.Image | None:
    """Extract the last generated image from a Gemini code_execution response.

    Only the first candidate is inspected.  Each part is probed first through
    ``part.as_image()`` and, when that yields nothing or raises, through the
    raw ``part.inline_data.data`` bytes.  When several parts carry an image,
    the last one that can be opened wins.

    Parameters
    ----------
    response
        Object returned by ``client.models.generate_content``.

    Returns
    -------
    PIL.Image.Image | None
        The last image that could be opened, or ``None`` when the response
        has no candidates, no content, no parts, or no decodable image.
    """
    # An empty or blocked response may have ``candidates`` unset/empty, and a
    # candidate may have no ``content`` or ``parts``.  Treat every such shape
    # as "no image" so the caller can fall back instead of crashing.
    candidates = getattr(response, "candidates", None)
    if not candidates:
        return None
    content = getattr(candidates[0], "content", None)
    parts = getattr(content, "parts", None)
    if not parts:
        return None

    last_image = None
    for part in parts:
        # Preferred path: let the SDK helper decode the part.
        try:
            img_data = part.as_image()
            if img_data is not None:
                last_image = Image.open(io.BytesIO(img_data.image_bytes))
                continue
        except Exception:
            # Deliberate best effort: any failure here just means we try the
            # raw inline_data bytes below.
            pass

        # Fallback path: open the raw inline bytes directly.
        try:
            inline = getattr(part, "inline_data", None)
            if inline is not None:
                last_image = Image.open(io.BytesIO(inline.data))
        except Exception:
            # Inline data that is not a decodable image; skip this part and
            # keep whatever image was found earlier.
            pass
    return last_image
|
| |
|
| |
|
def _execute_single_crop_sync(
    client: genai.Client,
    page_image_bytes: bytes,
    crop_task: CropTask,
    image_store: ImageStore,
) -> tuple[ImageRef, bool]:
    """Execute one crop via Gemini code_execution (synchronous).

    The request is attempted up to ``MAX_RETRIES`` times.  Errors whose text
    contains ``503``, ``429`` or ``UNAVAILABLE`` are treated as transient and
    retried with exponential back-off; no sleep is performed after the final
    attempt because nothing follows it.

    Parameters
    ----------
    client
        Gemini client used for ``models.generate_content``.
    page_image_bytes
        Encoded full-page image sent to the model and reused as the fallback.
    crop_task
        Task mapping; ``crop_instruction``, ``page_num`` and ``label`` are read.
    image_store
        Store whose ``save_crop`` persists the resulting image.

    Returns
    -------
    (image_ref, is_fallback)
        ``is_fallback`` is True when Gemini failed to produce a crop and the
        full page image was returned instead. Fallbacks should NOT be cached.

    Raises
    ------
    Exception
        Any non-transient error raised by ``generate_content`` is re-raised
        unchanged.
    """
    prompt = CROPPER_PROMPT_TEMPLATE.format(
        crop_instruction=crop_task["crop_instruction"],
    )
    # NOTE(review): assumes the page image bytes are PNG-encoded — confirm
    # against get_page_image_bytes.
    image_part = types.Part.from_bytes(data=page_image_bytes, mime_type="image/png")

    response = None
    for attempt in range(MAX_RETRIES):
        try:
            response = client.models.generate_content(
                model=CROPPER_MODEL,
                contents=[image_part, prompt],
                config=types.GenerateContentConfig(
                    tools=[types.Tool(code_execution=types.ToolCodeExecution)]
                ),
            )
            break
        except Exception as e:
            err_str = str(e)
            is_transient = any(
                marker in err_str for marker in ("503", "429", "UNAVAILABLE")
            )
            if not is_transient:
                raise

            if attempt + 1 >= MAX_RETRIES:
                # Last attempt: sleeping would only delay the fallback.
                logger.warning(
                    "Crop API error (attempt %d/%d): %s — retries exhausted, "
                    "falling back to full page",
                    attempt + 1, MAX_RETRIES, err_str[:120],
                )
                break

            delay = RETRY_BASE_DELAY * (2 ** attempt)
            logger.warning(
                "Crop API error (attempt %d/%d): %s — retrying in %.1fs",
                attempt + 1, MAX_RETRIES, err_str[:120], delay,
            )
            time.sleep(delay)

    # Use the model's crop when one was produced; otherwise fall back to the
    # untouched full page and flag the result so callers do not cache it.
    final_image = _extract_last_image(response) if response is not None else None
    is_fallback = final_image is None
    if is_fallback:
        final_image = Image.open(io.BytesIO(page_image_bytes))

    crop_id = f"crop_{uuid.uuid4().hex[:6]}"
    ref = image_store.save_crop(
        page_num=crop_task["page_num"],
        crop_id=crop_id,
        image=final_image,
        label=crop_task["label"],
    )
    return ref, is_fallback
|
| |
|
| |
|
def execute_crops(
    state: DrawingReaderState,
    image_store: ImageStore,
    crop_cache: CropCache | None = None,
    progress_callback: ProgressCallback | None = None,
) -> dict:
    """Execute all crop tasks concurrently, reusing cached crops when possible.

    Parameters
    ----------
    state
        Graph state; ``crop_tasks`` is read (defaulting to empty) and, when
        there is at least one task, ``page_image_dir``.
    image_store
        Store passed through to each crop worker for persisting results.
    crop_cache
        Optional cache consulted before calling the API and updated with
        every API result (fallbacks are passed with ``is_fallback=True``).
    progress_callback
        Optional callback invoked on the thread that called this function
        (never a worker thread) each time a crop completes or is served
        from cache. Called with
        ``(image_ref, crop_task, source, completed_count, total_count)``
        where *source* is ``"cached"``, ``"completed"``, or ``"fallback"``.
        It is not invoked for tasks that fail.

    Returns
    -------
    dict
        ``{"status_message": [...]}`` when there are no tasks; otherwise
        ``{"image_refs": [...], "status_message": [...]}``.  ``image_refs``
        holds cache hits in task order followed by API results in
        completion order.
    """
    crop_tasks = state.get("crop_tasks", [])
    if not crop_tasks:
        return {"status_message": ["No crop tasks to execute."]}

    # Read only once there is work to do, so an empty task list does not
    # require the page-image directory to be present in the state.
    page_image_dir = state["page_image_dir"]

    total_count = len(crop_tasks)
    completed_count = 0

    image_refs: list[ImageRef] = []
    tasks_to_execute: list[tuple[int, CropTask]] = []
    cache_hits = 0

    # Phase 1: serve whatever the cache already holds.
    for i, ct in enumerate(crop_tasks):
        cached_ref = None
        if crop_cache is not None:
            cached_ref = crop_cache.lookup(ct["page_num"], ct["crop_instruction"])

        if cached_ref is None:
            tasks_to_execute.append((i, ct))
            continue

        image_refs.append(cached_ref)
        cache_hits += 1
        completed_count += 1
        logger.info(
            "Reusing cached crop for '%s' (page %d)",
            ct["label"], ct["page_num"],
        )
        if progress_callback is not None:
            progress_callback(
                cached_ref, ct, "cached", completed_count, total_count,
            )

    errors: list[str] = []

    # Phase 2: run the remaining tasks through the API on a small pool.
    if tasks_to_execute:
        client = genai.Client(api_key=GOOGLE_API_KEY)

        with ThreadPoolExecutor(max_workers=min(len(tasks_to_execute), 4)) as pool:
            future_to_task: dict = {}
            for orig_idx, ct in tasks_to_execute:
                # A page that cannot be loaded fails only its own task; the
                # rest of the batch still runs.
                try:
                    page_bytes = get_page_image_bytes(page_image_dir, ct["page_num"])
                except Exception as e:
                    completed_count += 1
                    errors.append(f"Crop task {orig_idx} failed: {e}")
                    logger.error("Crop task %d failed: %s", orig_idx, e)
                    continue

                future = pool.submit(
                    _execute_single_crop_sync, client, page_bytes, ct, image_store,
                )
                future_to_task[future] = (orig_idx, ct)

            for future in as_completed(future_to_task):
                orig_idx, ct = future_to_task[future]

                # Only the crop itself decides success or failure; keeping
                # the try this narrow prevents a cache/callback error from
                # double-counting the task and reporting a saved crop as
                # failed.
                try:
                    ref, is_fallback = future.result()
                except Exception as e:
                    completed_count += 1
                    errors.append(f"Crop task {orig_idx} failed: {e}")
                    logger.error("Crop task %d failed: %s", orig_idx, e)
                    continue

                image_refs.append(ref)
                completed_count += 1

                if crop_cache is not None:
                    try:
                        crop_cache.register(
                            page_num=ct["page_num"],
                            crop_instruction=ct["crop_instruction"],
                            label=ct["label"],
                            image_ref=ref,
                            is_fallback=is_fallback,
                        )
                    except Exception:
                        # The crop is already saved; a cache failure must not
                        # discard it.
                        logger.exception(
                            "Registering crop task %d in the cache failed", orig_idx,
                        )

                if progress_callback is not None:
                    source = "fallback" if is_fallback else "completed"
                    try:
                        progress_callback(
                            ref, ct, source, completed_count, total_count,
                        )
                    except Exception:
                        # Progress reporting is best effort and must not
                        # affect the crop result.
                        logger.exception(
                            "Progress callback for crop task %d failed", orig_idx,
                        )

    # Summarise the run for the status stream.
    api_count = len(tasks_to_execute) - len(errors)
    parts = [f"Completed {len(image_refs)} of {total_count} crops"]
    if cache_hits:
        parts.append(f"({cache_hits} from cache, {api_count} new)")
    if errors:
        parts.append(f"Errors: {'; '.join(errors)}")
    status = ". ".join(parts) + "."

    if crop_cache is not None:
        logger.info(crop_cache.stats)

    return {
        "image_refs": image_refs,
        "status_message": [status],
    }
|
| |
|