"""annotate_crops node — nano-banana (Gemini image generation) for semantic annotation.""" from __future__ import annotations import io from concurrent.futures import ThreadPoolExecutor, as_completed from google import genai from google.genai import types from PIL import Image from config import ANNOTATOR_MODEL, GOOGLE_API_KEY from prompts.annotator import ANNOTATION_WRAPPER from state import DrawingReaderState, ImageRef from tools.image_store import ImageStore def _extract_generated_image(response) -> Image.Image | None: """Extract the generated image from a Gemini image-generation response.""" for part in response.candidates[0].content.parts: if part.inline_data is not None: return Image.open(io.BytesIO(part.inline_data.data)) return None def _annotate_single_crop_sync( client: genai.Client, crop_ref: ImageRef, annotation_prompt: str, image_store: ImageStore, ) -> ImageRef | None: """Annotate one crop using nano-banana (synchronous).""" crop_bytes = image_store.load_bytes(crop_ref) full_prompt = ANNOTATION_WRAPPER.format(annotation_prompt=annotation_prompt) response = client.models.generate_content( model=ANNOTATOR_MODEL, contents=[ types.Part.from_bytes(data=crop_bytes, mime_type="image/png"), full_prompt, ], config=types.GenerateContentConfig( response_modalities=["TEXT", "IMAGE"], ), ) annotated_image = _extract_generated_image(response) if annotated_image is None: return None ref = image_store.save_annotated(crop_ref, annotated_image) return ref def annotate_crops(state: DrawingReaderState, image_store: ImageStore) -> dict: """Run nano-banana annotation on crops that need it.""" crop_tasks = state.get("crop_tasks", []) image_refs = state.get("image_refs", []) # Build a mapping: find crops that need annotation. # The most recent batch of crops corresponds to the current crop_tasks. # Take the LAST len(crop_tasks) crops from image_refs to match by position, # so that on loop-back rounds we only match against the newest crops. crops_needing_annotation: list[tuple[ImageRef, str]] = [] all_crops = [r for r in image_refs if r["crop_type"] == "crop"] # Only the tail — the most recent batch produced by execute_crops recent_crops = all_crops[-len(crop_tasks):] if crop_tasks else [] for i, task in enumerate(crop_tasks): if task["annotate"] and task["annotation_prompt"] and i < len(recent_crops): crops_needing_annotation.append( (recent_crops[i], task["annotation_prompt"]) ) if not crops_needing_annotation: return {"status_message": ["No annotation needed for these crops."]} client = genai.Client(api_key=GOOGLE_API_KEY) # Use a thread pool instead of asyncio to avoid event-loop conflicts # with Streamlit's own event loop. results: list[ImageRef | None | Exception] = [None] * len(crops_needing_annotation) with ThreadPoolExecutor(max_workers=min(len(crops_needing_annotation), 4)) as pool: future_to_idx = {} for i, (ref, prompt) in enumerate(crops_needing_annotation): future = pool.submit( _annotate_single_crop_sync, client, ref, prompt, image_store, ) future_to_idx[future] = i for future in as_completed(future_to_idx): idx = future_to_idx[future] try: results[idx] = future.result() except Exception as e: results[idx] = e annotated_refs: list[ImageRef] = [] errors: list[str] = [] for i, result in enumerate(results): if isinstance(result, Exception): errors.append(f"Annotation {i} failed: {result}") elif result is not None: annotated_refs.append(result) else: errors.append(f"Annotation {i} returned no image") status = f"Annotated {len(annotated_refs)} of {len(crops_needing_annotation)} crops." if errors: status += f" Issues: {'; '.join(errors)}" return { "image_refs": annotated_refs, "status_message": [status], }