"""execute_crops node — Gemini code_execution for agentic cropping (PoC 1 style)."""
from __future__ import annotations
import io
import logging
import time
import uuid
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from google import genai
from google.genai import types
from PIL import Image
from config import CROPPER_MODEL, GOOGLE_API_KEY
from prompts.cropper import CROPPER_PROMPT_TEMPLATE
from state import CropTask, DrawingReaderState, ImageRef
from tools.crop_cache import CropCache
from tools.image_store import ImageStore
from tools.pdf_processor import get_page_image_bytes
logger = logging.getLogger(__name__)
# Type alias for the progress callback.
# Signature: (completed_ref, crop_task, source, completed_count, total_count)
ProgressCallback = Callable[[ImageRef, CropTask, str, int, int], None]
# Retry settings for transient API errors (429 / 503)
MAX_RETRIES = 3
RETRY_BASE_DELAY = 2.0 # seconds
def _extract_last_image(response) -> Image.Image | None:
"""Extract the last generated image from a Gemini code_execution response."""
last_image = None
for part in response.candidates[0].content.parts:
# Try as_image() first
try:
img_data = part.as_image()
if img_data is not None:
last_image = Image.open(io.BytesIO(img_data.image_bytes))
continue
except Exception:
pass
# Fallback: inline_data
try:
if hasattr(part, "inline_data") and part.inline_data is not None:
img_bytes = part.inline_data.data
last_image = Image.open(io.BytesIO(img_bytes))
except Exception:
pass
return last_image
def _execute_single_crop_sync(
    client: genai.Client,
    page_image_bytes: bytes,
    crop_task: CropTask,
    image_store: ImageStore,
) -> tuple[ImageRef, bool]:
    """Execute one crop via Gemini code_execution (synchronous).

    Transient API errors (429 / 503 / UNAVAILABLE) are retried with
    exponential backoff; any other exception propagates immediately.

    Parameters
    ----------
    client
        Shared Gemini client used for ``generate_content`` calls.
    page_image_bytes
        PNG-encoded bytes of the full page to crop from.
    crop_task
        Task dict supplying ``crop_instruction``, ``page_num`` and ``label``.
    image_store
        Destination store; the resulting image is persisted via ``save_crop``.

    Returns
    -------
    (image_ref, is_fallback)
        ``is_fallback`` is True when Gemini failed to produce a crop and the
        full page image was returned instead. Fallbacks should NOT be cached.
    """
    prompt = CROPPER_PROMPT_TEMPLATE.format(
        crop_instruction=crop_task["crop_instruction"],
    )
    image_part = types.Part.from_bytes(data=page_image_bytes, mime_type="image/png")

    # Retry loop for transient API errors.
    response = None
    for attempt in range(MAX_RETRIES):
        try:
            response = client.models.generate_content(
                model=CROPPER_MODEL,
                contents=[image_part, prompt],
                config=types.GenerateContentConfig(
                    tools=[types.Tool(code_execution=types.ToolCodeExecution)]
                ),
            )
            break
        except Exception as e:
            err_str = str(e)
            transient = "503" in err_str or "429" in err_str or "UNAVAILABLE" in err_str
            if not transient:
                raise
            if attempt + 1 < MAX_RETRIES:
                delay = RETRY_BASE_DELAY * (2 ** attempt)
                logger.warning(
                    "Crop API error (attempt %d/%d): %s — retrying in %.1fs",
                    attempt + 1, MAX_RETRIES, err_str[:120], delay,
                )
                # BUGFIX: previously we also slept after the FINAL failed
                # attempt, wasting up to RETRY_BASE_DELAY * 2**(MAX_RETRIES-1)
                # seconds when no retry would follow.
                time.sleep(delay)
            else:
                logger.warning(
                    "Crop API error (attempt %d/%d): %s — giving up",
                    attempt + 1, MAX_RETRIES, err_str[:120],
                )

    # Decide between a genuine crop and the full-page fallback.
    is_fallback = True
    if response is not None:
        final_image = _extract_last_image(response)
        if final_image is not None:
            is_fallback = False
        else:
            # Model responded but produced no image — fall back to full page.
            final_image = Image.open(io.BytesIO(page_image_bytes))
    else:
        # All retries exhausted — fall back to full page.
        final_image = Image.open(io.BytesIO(page_image_bytes))

    crop_id = f"crop_{uuid.uuid4().hex[:6]}"
    ref = image_store.save_crop(
        page_num=crop_task["page_num"],
        crop_id=crop_id,
        image=final_image,
        label=crop_task["label"],
    )
    return ref, is_fallback
def execute_crops(
    state: DrawingReaderState,
    image_store: ImageStore,
    crop_cache: CropCache | None = None,
    progress_callback: ProgressCallback | None = None,
) -> dict:
    """Execute all crop tasks concurrently, reusing cached crops when possible.

    Parameters
    ----------
    progress_callback
        Optional callback invoked on the **main thread** each time a crop
        completes (or is served from cache). Called with
        ``(image_ref, crop_task, source, completed_count, total_count)``
        where *source* is ``"cached"``, ``"completed"``, or ``"fallback"``.

    Returns
    -------
    dict
        ``{"image_refs": [...], "status_message": [...]}`` — ``image_refs``
        is ordered to match ``crop_tasks`` (failed tasks are omitted).
    """
    crop_tasks = state.get("crop_tasks", [])
    page_image_dir = state["page_image_dir"]
    if not crop_tasks:
        return {"status_message": ["No crop tasks to execute."]}

    total_count = len(crop_tasks)
    completed_count = 0

    # ----- Phase 1: Separate cache hits from tasks that need API calls -----
    # BUGFIX: results are keyed by original task index. The previous code
    # appended API results in as_completed() (completion) order, so the
    # returned image_refs list no longer matched crop_tasks order whenever
    # concurrent calls finished out of order.
    results: dict[int, ImageRef] = {}
    tasks_to_execute: list[tuple[int, CropTask]] = []  # (original_index, task)
    cache_hits = 0
    for i, ct in enumerate(crop_tasks):
        if crop_cache is not None:
            cached_ref = crop_cache.lookup(ct["page_num"], ct["crop_instruction"])
            if cached_ref is not None:
                results[i] = cached_ref
                cache_hits += 1
                completed_count += 1
                logger.info(
                    "Reusing cached crop for '%s' (page %d)",
                    ct["label"], ct["page_num"],
                )
                # Notify the UI immediately for each cache hit
                if progress_callback is not None:
                    progress_callback(
                        cached_ref, ct, "cached", completed_count, total_count,
                    )
                continue
        # Not cached — needs an API call
        tasks_to_execute.append((i, ct))

    # ----- Phase 2: Execute uncached crops via Gemini -----
    errors: list[str] = []
    if tasks_to_execute:
        client = genai.Client(api_key=GOOGLE_API_KEY)
        with ThreadPoolExecutor(max_workers=min(len(tasks_to_execute), 4)) as pool:
            future_to_idx: dict = {}
            for exec_idx, (_, ct) in enumerate(tasks_to_execute):
                page_bytes = get_page_image_bytes(page_image_dir, ct["page_num"])
                future = pool.submit(
                    _execute_single_crop_sync, client, page_bytes, ct, image_store,
                )
                future_to_idx[future] = exec_idx
            # Process results as they arrive — this runs on the MAIN thread,
            # so we can safely invoke the Streamlit progress callback here.
            for future in as_completed(future_to_idx):
                exec_idx = future_to_idx[future]
                orig_idx, ct = tasks_to_execute[exec_idx]
                try:
                    ref, is_fallback = future.result()
                    results[orig_idx] = ref
                    completed_count += 1
                    # Register in cache (only successful targeted crops)
                    if crop_cache is not None:
                        crop_cache.register(
                            page_num=ct["page_num"],
                            crop_instruction=ct["crop_instruction"],
                            label=ct["label"],
                            image_ref=ref,
                            is_fallback=is_fallback,
                        )
                    # Notify the UI as each crop completes
                    if progress_callback is not None:
                        source = "fallback" if is_fallback else "completed"
                        progress_callback(
                            ref, ct, source, completed_count, total_count,
                        )
                except Exception as e:
                    completed_count += 1
                    errors.append(f"Crop task {orig_idx} failed: {e}")
                    logger.error("Crop task %d failed: %s", orig_idx, e)

    # Assemble results in original crop_tasks order (failed tasks leave gaps).
    image_refs: list[ImageRef] = [results[i] for i in sorted(results)]

    # ----- Phase 3: Build status message -----
    api_count = len(tasks_to_execute) - len(errors)
    parts = [f"Completed {len(image_refs)} of {total_count} crops"]
    if cache_hits:
        parts.append(f"({cache_hits} from cache, {api_count} new)")
    if errors:
        parts.append(f"Errors: {'; '.join(errors)}")
    status = ". ".join(parts) + "."
    if crop_cache is not None:
        logger.info(crop_cache.stats)
    return {
        "image_refs": image_refs,
        "status_message": [status],
    }