"""annotate_crops node — nano-banana (Gemini image generation) for semantic annotation."""
from __future__ import annotations
import io
from concurrent.futures import ThreadPoolExecutor, as_completed
from google import genai
from google.genai import types
from PIL import Image
from config import ANNOTATOR_MODEL, GOOGLE_API_KEY
from prompts.annotator import ANNOTATION_WRAPPER
from state import DrawingReaderState, ImageRef
from tools.image_store import ImageStore
def _extract_generated_image(response) -> Image.Image | None:
"""Extract the generated image from a Gemini image-generation response."""
for part in response.candidates[0].content.parts:
if part.inline_data is not None:
return Image.open(io.BytesIO(part.inline_data.data))
return None
def _annotate_single_crop_sync(
    client: genai.Client,
    crop_ref: ImageRef,
    annotation_prompt: str,
    image_store: ImageStore,
) -> ImageRef | None:
    """Run nano-banana on one crop and persist the annotated result (blocking).

    Returns the saved ImageRef for the annotated image, or None when the model
    response contained no generated image.
    """
    prompt = ANNOTATION_WRAPPER.format(annotation_prompt=annotation_prompt)
    image_part = types.Part.from_bytes(
        data=image_store.load_bytes(crop_ref),
        mime_type="image/png",
    )
    response = client.models.generate_content(
        model=ANNOTATOR_MODEL,
        contents=[image_part, prompt],
        config=types.GenerateContentConfig(response_modalities=["TEXT", "IMAGE"]),
    )
    generated = _extract_generated_image(response)
    if generated is None:
        return None
    return image_store.save_annotated(crop_ref, generated)
def annotate_crops(state: DrawingReaderState, image_store: ImageStore) -> dict:
    """Run nano-banana annotation on crops that need it."""
    crop_tasks = state.get("crop_tasks", [])
    image_refs = state.get("image_refs", [])

    # Pair each crop task with its crop by position.  execute_crops appends its
    # output to image_refs, so the current batch is the tail: taking only the
    # last len(crop_tasks) crops keeps loop-back rounds from pairing tasks
    # with crops produced in earlier rounds.  zip() truncates to the shorter
    # side, which also covers the case of fewer crops than tasks.
    all_crops = [ref for ref in image_refs if ref["crop_type"] == "crop"]
    recent_crops = all_crops[-len(crop_tasks):] if crop_tasks else []
    pending: list[tuple[ImageRef, str]] = [
        (crop, task["annotation_prompt"])
        for task, crop in zip(crop_tasks, recent_crops)
        if task["annotate"] and task["annotation_prompt"]
    ]

    if not pending:
        return {"status_message": ["No annotation needed for these crops."]}

    client = genai.Client(api_key=GOOGLE_API_KEY)

    # Fan out over a thread pool rather than asyncio so we never fight
    # Streamlit's own event loop.  Results land by index so completion
    # order does not matter.
    outcomes: list[ImageRef | None | Exception] = [None] * len(pending)
    with ThreadPoolExecutor(max_workers=min(len(pending), 4)) as pool:
        future_map = {
            pool.submit(_annotate_single_crop_sync, client, ref, prompt, image_store): slot
            for slot, (ref, prompt) in enumerate(pending)
        }
        for fut in as_completed(future_map):
            slot = future_map[fut]
            try:
                outcomes[slot] = fut.result()
            except Exception as exc:
                # Keep the exception in place of the result so one bad crop
                # doesn't abort the whole batch; reported below.
                outcomes[slot] = exc

    annotated: list[ImageRef] = []
    problems: list[str] = []
    for slot, outcome in enumerate(outcomes):
        if isinstance(outcome, Exception):
            problems.append(f"Annotation {slot} failed: {outcome}")
        elif outcome is None:
            problems.append(f"Annotation {slot} returned no image")
        else:
            annotated.append(outcome)

    status = f"Annotated {len(annotated)} of {len(pending)} crops."
    if problems:
        status += f" Issues: {'; '.join(problems)}"
    return {
        "image_refs": annotated,
        "status_message": [status],
    }
|