File size: 4,369 Bytes
e1ced8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""annotate_crops node — nano-banana (Gemini image generation) for semantic annotation."""
from __future__ import annotations

import io
from concurrent.futures import ThreadPoolExecutor, as_completed

from google import genai
from google.genai import types
from PIL import Image

from config import ANNOTATOR_MODEL, GOOGLE_API_KEY
from prompts.annotator import ANNOTATION_WRAPPER
from state import DrawingReaderState, ImageRef
from tools.image_store import ImageStore


def _extract_generated_image(response) -> Image.Image | None:
    """Extract the generated image from a Gemini image-generation response."""
    for part in response.candidates[0].content.parts:
        if part.inline_data is not None:
            return Image.open(io.BytesIO(part.inline_data.data))
    return None


def _annotate_single_crop_sync(
    client: genai.Client,
    crop_ref: ImageRef,
    annotation_prompt: str,
    image_store: ImageStore,
) -> ImageRef | None:
    """Run one synchronous nano-banana annotation pass over a single crop.

    Sends the crop's PNG bytes together with the wrapped annotation prompt
    to the image-generation model, then persists whatever image comes back
    through the store. Returns the stored ImageRef, or None when the model
    produced no image part in its response.
    """
    source_bytes = image_store.load_bytes(crop_ref)
    wrapped_prompt = ANNOTATION_WRAPPER.format(annotation_prompt=annotation_prompt)

    response = client.models.generate_content(
        model=ANNOTATOR_MODEL,
        contents=[
            types.Part.from_bytes(data=source_bytes, mime_type="image/png"),
            wrapped_prompt,
        ],
        config=types.GenerateContentConfig(response_modalities=["TEXT", "IMAGE"]),
    )

    generated = _extract_generated_image(response)
    if generated is None:
        return None
    return image_store.save_annotated(crop_ref, generated)


def annotate_crops(state: DrawingReaderState, image_store: ImageStore) -> dict:
    """Run nano-banana annotation on crops that need it.

    Pairs the current crop_tasks with the most recent batch of crops in
    image_refs, fans the annotation calls out over a small thread pool, and
    returns the newly annotated refs plus a human-readable status line.
    """
    crop_tasks = state.get("crop_tasks", [])
    image_refs = state.get("image_refs", [])

    # execute_crops appends its output to image_refs, so the newest batch of
    # plain crops sits at the tail. Match it positionally against crop_tasks
    # so that loop-back rounds never pair tasks with older crops.
    plain_crops = [ref for ref in image_refs if ref["crop_type"] == "crop"]
    latest_batch = plain_crops[-len(crop_tasks):] if crop_tasks else []

    pending: list[tuple[ImageRef, str]] = [
        (latest_batch[idx], task["annotation_prompt"])
        for idx, task in enumerate(crop_tasks)
        if task["annotate"] and task["annotation_prompt"] and idx < len(latest_batch)
    ]

    if not pending:
        return {"status_message": ["No annotation needed for these crops."]}

    client = genai.Client(api_key=GOOGLE_API_KEY)

    # Threads rather than asyncio: Streamlit runs its own event loop, and
    # nesting a second loop inside it causes conflicts.
    outcomes: list[ImageRef | None | Exception] = [None] * len(pending)

    with ThreadPoolExecutor(max_workers=min(len(pending), 4)) as pool:
        slot_of = {
            pool.submit(_annotate_single_crop_sync, client, ref, prompt, image_store): idx
            for idx, (ref, prompt) in enumerate(pending)
        }
        for done in as_completed(slot_of):
            slot = slot_of[done]
            try:
                outcomes[slot] = done.result()
            except Exception as exc:
                outcomes[slot] = exc

    annotated: list[ImageRef] = []
    problems: list[str] = []
    for idx, outcome in enumerate(outcomes):
        if isinstance(outcome, Exception):
            problems.append(f"Annotation {idx} failed: {outcome}")
        elif outcome is None:
            problems.append(f"Annotation {idx} returned no image")
        else:
            annotated.append(outcome)

    status = f"Annotated {len(annotated)} of {len(pending)} crops."
    if problems:
        status += f" Issues: {'; '.join(problems)}"

    return {
        "image_refs": annotated,
        "status_message": [status],
    }