File size: 8,770 Bytes
e1ced8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""execute_crops node — Gemini code_execution for agentic cropping (PoC 1 style)."""
from __future__ import annotations

import io
import logging
import time
import uuid
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed

from google import genai
from google.genai import types
from PIL import Image

from config import CROPPER_MODEL, GOOGLE_API_KEY
from prompts.cropper import CROPPER_PROMPT_TEMPLATE
from state import CropTask, DrawingReaderState, ImageRef
from tools.crop_cache import CropCache
from tools.image_store import ImageStore
from tools.pdf_processor import get_page_image_bytes

logger = logging.getLogger(__name__)

# Type alias for the progress callback.
# Signature: (completed_ref, crop_task, source, completed_count, total_count)
ProgressCallback = Callable[[ImageRef, CropTask, str, int, int], None]

# Retry settings for transient API errors (429 / 503)
MAX_RETRIES = 3
RETRY_BASE_DELAY = 2.0  # seconds


def _extract_last_image(response) -> Image.Image | None:
    """Extract the last generated image from a Gemini code_execution response."""
    last_image = None
    for part in response.candidates[0].content.parts:
        # Try as_image() first
        try:
            img_data = part.as_image()
            if img_data is not None:
                last_image = Image.open(io.BytesIO(img_data.image_bytes))
                continue
        except Exception:
            pass
        # Fallback: inline_data
        try:
            if hasattr(part, "inline_data") and part.inline_data is not None:
                img_bytes = part.inline_data.data
                last_image = Image.open(io.BytesIO(img_bytes))
        except Exception:
            pass
    return last_image


def _execute_single_crop_sync(
    client: genai.Client,
    page_image_bytes: bytes,
    crop_task: CropTask,
    image_store: ImageStore,
) -> tuple[ImageRef, bool]:
    """Execute one crop via Gemini code_execution (synchronous).

    Transient API errors (429 / 503 / UNAVAILABLE) are retried up to
    ``MAX_RETRIES`` times with exponential backoff; other errors re-raise
    immediately.

    Parameters
    ----------
    client
        Shared Gemini client (one per ``execute_crops`` call; used from
        worker threads).
    page_image_bytes
        PNG bytes of the full page the crop is taken from.
    crop_task
        Provides ``crop_instruction``, ``page_num`` and ``label``.
    image_store
        Store that persists the resulting crop image.

    Returns
    -------
    (image_ref, is_fallback)
        ``is_fallback`` is True when Gemini failed to produce a crop and the
        full page image was returned instead.  Fallbacks should NOT be cached.
    """
    prompt = CROPPER_PROMPT_TEMPLATE.format(
        crop_instruction=crop_task["crop_instruction"],
    )

    image_part = types.Part.from_bytes(data=page_image_bytes, mime_type="image/png")

    # Retry loop for transient API errors.
    response = None
    for attempt in range(MAX_RETRIES):
        try:
            response = client.models.generate_content(
                model=CROPPER_MODEL,
                contents=[image_part, prompt],
                config=types.GenerateContentConfig(
                    tools=[types.Tool(code_execution=types.ToolCodeExecution)]
                ),
            )
            break
        except Exception as e:
            err_str = str(e)
            if not ("503" in err_str or "429" in err_str or "UNAVAILABLE" in err_str):
                raise  # Non-transient — let the caller record the failure.
            if attempt == MAX_RETRIES - 1:
                # Final attempt failed: fall through to the full-page fallback
                # WITHOUT the pointless trailing backoff sleep (the original
                # code slept here even though no retry would follow).
                logger.warning(
                    "Crop API error on final attempt %d/%d: %s — using fallback",
                    attempt + 1, MAX_RETRIES, err_str[:120],
                )
                break
            delay = RETRY_BASE_DELAY * (2 ** attempt)
            logger.warning(
                "Crop API error (attempt %d/%d): %s — retrying in %.1fs",
                attempt + 1, MAX_RETRIES, err_str[:120], delay,
            )
            time.sleep(delay)

    # Fall back to the uncropped page when Gemini produced no image at all.
    is_fallback = True
    if response is not None:
        final_image = _extract_last_image(response)
        if final_image is not None:
            is_fallback = False
        else:
            final_image = Image.open(io.BytesIO(page_image_bytes))
    else:
        # All retries exhausted.
        final_image = Image.open(io.BytesIO(page_image_bytes))

    # Short random id keeps filenames unique across concurrent workers.
    crop_id = f"crop_{uuid.uuid4().hex[:6]}"
    ref = image_store.save_crop(
        page_num=crop_task["page_num"],
        crop_id=crop_id,
        image=final_image,
        label=crop_task["label"],
    )
    return ref, is_fallback


def execute_crops(
    state: DrawingReaderState,
    image_store: ImageStore,
    crop_cache: CropCache | None = None,
    progress_callback: ProgressCallback | None = None,
) -> dict:
    """Execute all crop tasks concurrently, reusing cached crops when possible.

    ``image_refs`` in the returned dict preserves the order of
    ``state["crop_tasks"]`` (failed tasks are omitted).  The previous
    implementation appended API results in *completion* order, contradicting
    its own "final ordered results" comment; results are now placed into a
    per-task slot indexed by original position.

    Parameters
    ----------
    state
        Must contain ``page_image_dir``; ``crop_tasks`` defaults to empty.
    image_store
        Store that persists each crop image.
    crop_cache
        Optional cache consulted before, and populated after, API calls.
    progress_callback
        Optional callback invoked on the **main thread** each time a crop
        completes (or is served from cache).  Called with
        ``(image_ref, crop_task, source, completed_count, total_count)``
        where *source* is ``"cached"``, ``"completed"``, or ``"fallback"``.
    """
    crop_tasks = state.get("crop_tasks", [])
    page_image_dir = state["page_image_dir"]

    if not crop_tasks:
        return {"status_message": ["No crop tasks to execute."]}

    total_count = len(crop_tasks)
    completed_count = 0

    # ----- Phase 1: Separate cache hits from tasks that need API calls -----
    # One slot per task so the final list preserves task order even though
    # API results arrive in completion order.
    result_slots: list[ImageRef | None] = [None] * total_count
    tasks_to_execute: list[tuple[int, CropTask]] = []  # (original_index, task)
    cache_hits = 0

    for i, ct in enumerate(crop_tasks):
        if crop_cache is not None:
            cached_ref = crop_cache.lookup(ct["page_num"], ct["crop_instruction"])
            if cached_ref is not None:
                result_slots[i] = cached_ref
                cache_hits += 1
                completed_count += 1
                logger.info(
                    "Reusing cached crop for '%s' (page %d)",
                    ct["label"], ct["page_num"],
                )
                # Notify the UI immediately for each cache hit.
                if progress_callback is not None:
                    progress_callback(
                        cached_ref, ct, "cached", completed_count, total_count,
                    )
                continue
        # Not cached — needs an API call.
        tasks_to_execute.append((i, ct))

    # ----- Phase 2: Execute uncached crops via Gemini -----
    errors: list[str] = []

    if tasks_to_execute:
        client = genai.Client(api_key=GOOGLE_API_KEY)

        with ThreadPoolExecutor(max_workers=min(len(tasks_to_execute), 4)) as pool:
            future_to_task: dict = {}
            for orig_idx, ct in tasks_to_execute:
                page_bytes = get_page_image_bytes(page_image_dir, ct["page_num"])
                future = pool.submit(
                    _execute_single_crop_sync, client, page_bytes, ct, image_store,
                )
                future_to_task[future] = (orig_idx, ct)

            # Process results as they arrive — this runs on the MAIN thread,
            # so we can safely invoke the Streamlit progress callback here.
            for future in as_completed(future_to_task):
                orig_idx, ct = future_to_task[future]
                try:
                    ref, is_fallback = future.result()
                    result_slots[orig_idx] = ref
                    completed_count += 1

                    # Register in cache (only successful targeted crops;
                    # the cache is told about fallbacks so it can skip them).
                    if crop_cache is not None:
                        crop_cache.register(
                            page_num=ct["page_num"],
                            crop_instruction=ct["crop_instruction"],
                            label=ct["label"],
                            image_ref=ref,
                            is_fallback=is_fallback,
                        )

                    # Notify the UI as each crop completes.
                    if progress_callback is not None:
                        source = "fallback" if is_fallback else "completed"
                        progress_callback(
                            ref, ct, source, completed_count, total_count,
                        )

                except Exception as e:
                    completed_count += 1
                    errors.append(f"Crop task {orig_idx} failed: {e}")
                    logger.error("Crop task %d failed: %s", orig_idx, e)

    # ----- Phase 3: Build status message -----
    # Drop slots left empty by failed tasks; ordering matches crop_tasks.
    image_refs = [ref for ref in result_slots if ref is not None]

    api_count = len(tasks_to_execute) - len(errors)
    parts = [f"Completed {len(image_refs)} of {total_count} crops"]
    if cache_hits:
        parts.append(f"({cache_hits} from cache, {api_count} new)")
    if errors:
        parts.append(f"Errors: {'; '.join(errors)}")
    status = ". ".join(parts) + "."

    if crop_cache is not None:
        logger.info(crop_cache.stats)

    return {
        "image_refs": image_refs,
        "status_message": [status],
    }