# Source provenance: uploaded to Hugging Face by Ryan2219 ("Upload 70 files", commit e1ced8e, verified).
"""execute_crops node — Gemini code_execution for agentic cropping (PoC 1 style)."""
from __future__ import annotations
import io
import logging
import time
import uuid
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from google import genai
from google.genai import types
from PIL import Image
from config import CROPPER_MODEL, GOOGLE_API_KEY
from prompts.cropper import CROPPER_PROMPT_TEMPLATE
from state import CropTask, DrawingReaderState, ImageRef
from tools.crop_cache import CropCache
from tools.image_store import ImageStore
from tools.pdf_processor import get_page_image_bytes
logger = logging.getLogger(__name__)
# Type alias for the progress callback.
# Signature: (completed_ref, crop_task, source, completed_count, total_count)
ProgressCallback = Callable[[ImageRef, CropTask, str, int, int], None]
# Retry settings for transient API errors (429 / 503)
MAX_RETRIES = 3
RETRY_BASE_DELAY = 2.0 # seconds
def _extract_last_image(response) -> Image.Image | None:
"""Extract the last generated image from a Gemini code_execution response."""
last_image = None
for part in response.candidates[0].content.parts:
# Try as_image() first
try:
img_data = part.as_image()
if img_data is not None:
last_image = Image.open(io.BytesIO(img_data.image_bytes))
continue
except Exception:
pass
# Fallback: inline_data
try:
if hasattr(part, "inline_data") and part.inline_data is not None:
img_bytes = part.inline_data.data
last_image = Image.open(io.BytesIO(img_bytes))
except Exception:
pass
return last_image
def _execute_single_crop_sync(
    client: genai.Client,
    page_image_bytes: bytes,
    crop_task: CropTask,
    image_store: ImageStore,
) -> tuple[ImageRef, bool]:
    """Execute one crop via Gemini code_execution (synchronous).

    Transient 429/503/UNAVAILABLE errors are retried with exponential
    backoff (``RETRY_BASE_DELAY * 2**attempt``); any other exception
    propagates to the caller immediately.

    Parameters
    ----------
    client
        A shared ``genai.Client``; safe to call from worker threads.
    page_image_bytes
        PNG bytes of the full page the crop is taken from.
    crop_task
        Provides ``crop_instruction``, ``page_num``, and ``label``.
    image_store
        Destination for the resulting crop image.

    Returns
    -------
    (image_ref, is_fallback)
        ``is_fallback`` is True when Gemini failed to produce a crop and the
        full page image was returned instead. Fallbacks should NOT be cached.
    """
    prompt = CROPPER_PROMPT_TEMPLATE.format(
        crop_instruction=crop_task["crop_instruction"],
    )
    image_part = types.Part.from_bytes(data=page_image_bytes, mime_type="image/png")

    # Retry loop for transient API errors.
    response = None
    for attempt in range(MAX_RETRIES):
        try:
            response = client.models.generate_content(
                model=CROPPER_MODEL,
                contents=[image_part, prompt],
                config=types.GenerateContentConfig(
                    tools=[types.Tool(code_execution=types.ToolCodeExecution)]
                ),
            )
            break
        except Exception as e:
            err_str = str(e)
            transient = "503" in err_str or "429" in err_str or "UNAVAILABLE" in err_str
            if not transient:
                raise
            if attempt + 1 < MAX_RETRIES:
                delay = RETRY_BASE_DELAY * (2 ** attempt)
                logger.warning(
                    "Crop API error (attempt %d/%d): %s — retrying in %.1fs",
                    attempt + 1, MAX_RETRIES, err_str[:120], delay,
                )
                time.sleep(delay)
            else:
                # Last attempt failed — don't sleep when no retry follows;
                # fall through to the full-page fallback below.
                logger.warning(
                    "Crop API error (attempt %d/%d): %s — retries exhausted",
                    attempt + 1, MAX_RETRIES, err_str[:120],
                )

    is_fallback = True
    if response is not None:
        final_image = _extract_last_image(response)
        if final_image is not None:
            is_fallback = False
        else:
            # Gemini answered but produced no image — use the whole page.
            final_image = Image.open(io.BytesIO(page_image_bytes))
    else:
        # All retries exhausted — use the whole page.
        final_image = Image.open(io.BytesIO(page_image_bytes))

    crop_id = f"crop_{uuid.uuid4().hex[:6]}"
    ref = image_store.save_crop(
        page_num=crop_task["page_num"],
        crop_id=crop_id,
        image=final_image,
        label=crop_task["label"],
    )
    return ref, is_fallback
def execute_crops(
    state: DrawingReaderState,
    image_store: ImageStore,
    crop_cache: CropCache | None = None,
    progress_callback: ProgressCallback | None = None,
) -> dict:
    """Execute all crop tasks concurrently, reusing cached crops when possible.

    The returned ``image_refs`` list preserves the order of
    ``state["crop_tasks"]`` (tasks that failed are omitted), regardless of
    the order in which cache hits and API calls complete.

    Parameters
    ----------
    state
        Must contain ``page_image_dir``; ``crop_tasks`` is optional.
    image_store
        Destination store passed through to each crop execution.
    crop_cache
        Optional cache; hits skip the API call, and successful new crops
        are registered back into it.
    progress_callback
        Optional callback invoked on the **main thread** each time a crop
        completes (or is served from cache). Called with
        ``(image_ref, crop_task, source, completed_count, total_count)``
        where *source* is ``"cached"``, ``"completed"``, or ``"fallback"``.

    Returns
    -------
    dict
        ``{"image_refs": [...], "status_message": [...]}`` (or only a
        status message when there is nothing to do).
    """
    crop_tasks = state.get("crop_tasks", [])
    page_image_dir = state["page_image_dir"]
    if not crop_tasks:
        return {"status_message": ["No crop tasks to execute."]}

    total_count = len(crop_tasks)
    completed_count = 0

    # ----- Phase 1: Separate cache hits from tasks that need API calls -----
    # Results are slotted by original task index so the final list preserves
    # crop_tasks order no matter the completion order; failed slots stay None.
    results: list[ImageRef | None] = [None] * total_count
    tasks_to_execute: list[tuple[int, CropTask]] = []  # (original_index, task)
    cache_hits = 0
    for i, ct in enumerate(crop_tasks):
        if crop_cache is not None:
            cached_ref = crop_cache.lookup(ct["page_num"], ct["crop_instruction"])
            if cached_ref is not None:
                results[i] = cached_ref
                cache_hits += 1
                completed_count += 1
                logger.info(
                    "Reusing cached crop for '%s' (page %d)",
                    ct["label"], ct["page_num"],
                )
                # Notify the UI immediately for each cache hit
                if progress_callback is not None:
                    progress_callback(
                        cached_ref, ct, "cached", completed_count, total_count,
                    )
                continue
        # Not cached — needs an API call
        tasks_to_execute.append((i, ct))

    # ----- Phase 2: Execute uncached crops via Gemini -----
    errors: list[str] = []
    if tasks_to_execute:
        client = genai.Client(api_key=GOOGLE_API_KEY)
        with ThreadPoolExecutor(max_workers=min(len(tasks_to_execute), 4)) as pool:
            future_to_task: dict = {}
            for orig_idx, ct in tasks_to_execute:
                page_bytes = get_page_image_bytes(page_image_dir, ct["page_num"])
                future = pool.submit(
                    _execute_single_crop_sync, client, page_bytes, ct, image_store,
                )
                future_to_task[future] = (orig_idx, ct)
            # Process results as they arrive — this runs on the MAIN thread,
            # so we can safely invoke the Streamlit progress callback here.
            for future in as_completed(future_to_task):
                orig_idx, ct = future_to_task[future]
                try:
                    ref, is_fallback = future.result()
                    results[orig_idx] = ref
                    completed_count += 1
                    # Register in cache (only successful targeted crops)
                    if crop_cache is not None:
                        crop_cache.register(
                            page_num=ct["page_num"],
                            crop_instruction=ct["crop_instruction"],
                            label=ct["label"],
                            image_ref=ref,
                            is_fallback=is_fallback,
                        )
                    # Notify the UI as each crop completes
                    if progress_callback is not None:
                        source = "fallback" if is_fallback else "completed"
                        progress_callback(
                            ref, ct, source, completed_count, total_count,
                        )
                except Exception as e:
                    completed_count += 1
                    errors.append(f"Crop task {orig_idx} failed: {e}")
                    logger.error("Crop task %d failed: %s", orig_idx, e)

    # Drop failed slots while keeping the original crop_tasks ordering.
    image_refs: list[ImageRef] = [r for r in results if r is not None]

    # ----- Phase 3: Build status message -----
    api_count = len(tasks_to_execute) - len(errors)
    parts = [f"Completed {len(image_refs)} of {total_count} crops"]
    if cache_hits:
        parts.append(f"({cache_hits} from cache, {api_count} new)")
    if errors:
        parts.append(f"Errors: {'; '.join(errors)}")
    status = ". ".join(parts) + "."
    if crop_cache is not None:
        logger.info(crop_cache.stats)
    return {
        "image_refs": image_refs,
        "status_message": [status],
    }