| | import logging |
| | import os |
| | from io import BytesIO |
| |
|
| | |
# Best-effort call to dotenv's load_dotenv() before the os.environ lookups
# further down in this module. Any exception raised here (including the
# import itself failing) is deliberately ignored so the module still imports.
try:
    from dotenv import load_dotenv

    load_dotenv()
except Exception:
    pass
| |
|
| | import base64 |
| | import cv2 |
| | import numpy as np |
| | from PIL import Image |
| | import google.generativeai as genai |
| |
|
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Model identifier, read once at import time; the GEMINI_IMAGE_MODEL
# environment variable overrides the built-in default.
DEFAULT_MODEL_ID = os.environ.get("GEMINI_IMAGE_MODEL", "gemini-3-pro-image-preview")

# Instruction text used when the caller supplies no prompt, read once at
# import time; the GEMINI_IMAGE_PROMPT environment variable overrides it.
DEFAULT_PROMPT = os.environ.get(
    "GEMINI_IMAGE_PROMPT",
    (
        "TASK TYPE: STRICT IMAGE INPAINTING — OBJECT REMOVAL ONLY\n\n"
        "You are given:\n"
        "1) An original image\n"
        "2) A binary mask image\n\n"
        "MASK RULE (MANDATORY):\n"
        "• White pixels (#FFFFFF) indicate the exact region to be REMOVED.\n"
        "• Black pixels (#000000) indicate regions that MUST remain completely unchanged.\n\n"
        "PRIMARY OBJECTIVE:\n"
        "Completely delete everything inside the white masked area.\n"
        "The object in the white region must be fully removed with no visible remnants,\n"
        "no partial shapes, no outlines, no shadows, and no color traces.\n\n"
        "INPAINTING INSTRUCTIONS:\n"
        "Ignore the content of the white masked area entirely.\n"
        "Reconstruct that region using ONLY surrounding background information.\n"
        "Extend nearby background textures, patterns, and structures naturally.\n"
        "Match lighting direction, brightness, contrast, color temperature, and noise.\n"
        "Continue edges, lines, and surfaces realistically across the removed area.\n"
        "Blend boundaries smoothly so the edit is visually undetectable.\n\n"
        "Return only the edited image.\n\n"
        "STRICT CONSTRAINTS:\n"
        "• Do NOT generate or keep any part of the removed object.\n"
        "• Do NOT invent new objects or details.\n"
        "• Do NOT repaint, modify, blur, or enhance any black (unmasked) area.\n"
        "• Do NOT change the original image composition.\n"
        "• Do NOT change camera angle, perspective, or scale.\n\n"
        "QUALITY REQUIREMENTS:\n"
        "• No ghosting or transparent object remains.\n"
        "• No edge halos or smearing.\n"
        "• No repeated textures or patchy fills.\n"
        "• Result must look like the object never existed.\n\n"
        "FAILURE CONDITIONS (MUST BE AVOIDED):\n"
        "If any object fragment, outline, shadow, or color from the removed object\n"
        "is still visible, the result is incorrect and must be re-generated."
    ),
)

# Cached GenerativeModel instance; stays None until _get_gemini_model()
# builds it on first use.
_GENAI_MODEL: genai.GenerativeModel | None = None
| |
|
| |
|
def _resize_mask(mask: np.ndarray, target_hw: tuple[int, int]) -> np.ndarray:
    """Return ``mask`` scaled to ``target_hw``, given as (height, width).

    A mask whose first two dimensions already equal the target is handed back
    untouched (the very same array object). Any other mask is rescaled with
    nearest-neighbour interpolation.
    """
    height, width = target_hw
    if mask.shape[:2] != (height, width):
        # cv2.resize takes its size argument as (width, height).
        return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    return mask
| |
|
| |
|
def _binary_mask_from_rgba(mask: np.ndarray, invert_mask: bool) -> np.ndarray:
    """
    Collapse a 3- or 4-channel mask into a single-channel uint8 mask of 0/255.

    Pixel selection:
    - when the mean alpha is below 240, pixels with alpha < 128 are selected;
    - otherwise pixels whose grayscale value is above 128 are selected.

    Selected pixels come out as 255 when ``invert_mask`` is true. When it is
    false the result is flipped, so selected pixels come out as 0.
    """
    if mask.shape[2] != 3:
        rgb = mask[:, :, :3]
        alpha = mask[:, :, 3]
    else:
        # No alpha channel supplied: treat every pixel as fully opaque.
        rgb = mask
        alpha = np.full(mask.shape[:2], 255, dtype=np.uint8)

    if alpha.mean() < 240:
        # Enough transparency present: the alpha channel carries the mask.
        selected = alpha < 128
    else:
        # Essentially opaque: fall back to brightness of the colour channels.
        selected = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY) > 128

    if not invert_mask:
        selected = ~selected

    return np.where(selected, 255, 0).astype(np.uint8)
| |
|
| |
|
def _pil_to_png_bytes(img: Image.Image) -> bytes:
    """Serialize ``img`` as PNG in memory and return the encoded bytes."""
    with BytesIO() as stream:
        img.save(stream, format="PNG")
        # getvalue() returns the whole buffer regardless of stream position.
        return stream.getvalue()
| |
|
| |
|
def _get_gemini_model() -> genai.GenerativeModel:
    """Return the shared GenerativeModel, creating it on the first call.

    The API key is the first non-empty value among the GEMINI_API_KEY,
    GOOGLE_API_KEY and GOOGLE_GENAI_API_KEY environment variables. The model
    id comes from GEMINI_IMAGE_MODEL, falling back to DEFAULT_MODEL_ID.

    Raises:
        RuntimeError: if none of the API-key variables holds a value.
    """
    global _GENAI_MODEL
    if _GENAI_MODEL is not None:
        return _GENAI_MODEL

    api_key = None
    for env_name in ("GEMINI_API_KEY", "GOOGLE_API_KEY", "GOOGLE_GENAI_API_KEY"):
        api_key = os.environ.get(env_name)
        if api_key:
            break
    if not api_key:
        raise RuntimeError("Set Gemini API key via GEMINI_API_KEY / GOOGLE_API_KEY / GOOGLE_GENAI_API_KEY")

    genai.configure(api_key=api_key)
    _GENAI_MODEL = genai.GenerativeModel(os.environ.get("GEMINI_IMAGE_MODEL", DEFAULT_MODEL_ID))
    return _GENAI_MODEL
| |
|
| |
|
def _raise_if_blocked(candidates) -> None:
    """Raise RuntimeError if any candidate has a finish reason treated as blocked."""
    for candidate in candidates:
        finish_reason = getattr(candidate, "finish_reason", None)
        if not finish_reason:
            continue
        # NOTE(review): in the google.generativeai FinishReason enum 2 is
        # MAX_TOKENS and 3 is SAFETY; confirm 17/2 are the intended codes.
        if finish_reason in (17, 2):
            safety_ratings = getattr(candidate, "safety_ratings", [])
            log.error("Gemini blocked the request. Finish reason: %s, Safety ratings: %s", finish_reason, safety_ratings)
            raise RuntimeError(f"Gemini API blocked the content (finish_reason={finish_reason}). The image may violate safety policies.")
        # 1 is FinishReason.STOP, i.e. a normal completion; warning about it
        # would fire on every successful call.
        if finish_reason != 1:
            log.warning("Gemini finished with non-zero reason: %s", finish_reason)


def _first_image_from_candidates(candidates) -> Image.Image | None:
    """Return the first inline image found in the candidates as RGB, or None."""
    for idx, candidate in enumerate(candidates):
        content = getattr(candidate, "content", None)
        if not content:
            log.debug("Candidate %d has no content", idx)
            continue
        response_parts = getattr(content, "parts", None)
        if not response_parts:
            log.debug("Candidate %d content has no parts", idx)
            continue
        log.debug("Candidate %d has %d parts", idx, len(response_parts))

        for part_idx, part in enumerate(response_parts):
            inline = getattr(part, "inline_data", None)
            if inline:
                log.debug("Part %d has inline_data, mime_type: %s", part_idx, getattr(inline, "mime_type", None))
                if inline.data:
                    data = inline.data
                    # Payload may arrive base64-encoded as text instead of raw bytes.
                    if isinstance(data, str):
                        data = base64.b64decode(data)
                    image = Image.open(BytesIO(data)).convert("RGB")
                    log.info("Successfully extracted image from Gemini response")
                    return image
            else:
                text = getattr(part, "text", None)
                if text:
                    log.warning("Gemini returned text instead of image in part %d: %s", part_idx, text[:200])
    return None


def _log_missing_image_diagnostics(response, candidates) -> None:
    """Log whatever is known about a response that carried no image (best effort)."""
    try:
        response_text = str(response)
        log.error("Gemini generate_content returned no image. Full response (first 1000 chars): %s", response_text[:1000])

        if hasattr(response, "prompt_feedback"):
            feedback = response.prompt_feedback
            log.error("Prompt feedback: %s", feedback)

        for idx, candidate in enumerate(candidates):
            finish_reason = getattr(candidate, "finish_reason", None)
            log.error("Candidate %d finish_reason: %s", idx, finish_reason)
    except Exception:
        # Diagnostics must never mask the real failure raised by the caller.
        log.debug("Could not log Gemini response diagnostics", exc_info=True)


def _call_gemini_edit(
    image_rgb: np.ndarray,
    mask_bw: np.ndarray,
    prompt: str | None,
    target_size: tuple[int, int],
) -> Image.Image:
    """
    Ask Gemini (API-key-only generate_content) to remove the masked region.

    Two PNG images are sent after the prompt text:
    - Image A: a copy of the source with every masked pixel painted white;
    - Image B: the binary mask itself.

    Args:
        image_rgb: Source image array; masked pixels are painted white in a copy.
        mask_bw: Mask array; every pixel > 0 marks the region to remove. It is
            used to index ``image_rgb``, so it must match its height/width.
        prompt: Instruction text; DEFAULT_PROMPT is used when falsy.
        target_size: Size the returned image is resized to if it differs,
            in PIL order (width, height).

    Returns:
        The edited image as an RGB PIL image of size ``target_size``.

    Raises:
        RuntimeError: if the API call fails, returns no candidates, reports a
            blocking finish reason, or returns no image.
    """
    model = _get_gemini_model()

    mask_image = Image.fromarray(mask_bw).convert("L")

    # Paint the removal region white so the model sees it directly in the photo.
    guidance_rgb = image_rgb.copy()
    guidance_rgb[mask_bw > 0] = 255
    guidance_image = Image.fromarray(guidance_rgb).convert("RGB")

    mask_bytes = _pil_to_png_bytes(mask_image)
    guidance_bytes = _pil_to_png_bytes(guidance_image)

    effective_prompt = (
        (prompt or DEFAULT_PROMPT).strip()
        + "\nIMAGE ORDER:\n"
        + "Image A: Original photo with the removal region painted white.\n"
        + "Image B: Binary mask (white=remove, black=keep). Use this mask to decide what to remove.\n"
    )
    log.info(
        "Calling Gemini generate_content model=%s (mask-guided remove) mask_pixels=%d",
        model.model_name,
        int((mask_bw > 0).sum()),
    )

    # Part order must match the "IMAGE ORDER" text appended to the prompt.
    content = [
        effective_prompt,
        {"mime_type": "image/png", "data": guidance_bytes},
        {"mime_type": "image/png", "data": mask_bytes},
    ]

    try:
        response = model.generate_content(content, stream=False)
    except Exception as gen_err:
        log.error("Gemini generate_content raised exception: %s", gen_err, exc_info=True)
        raise RuntimeError(f"Gemini API error: {gen_err}") from gen_err

    candidates = getattr(response, "candidates", [])
    if not candidates:
        log.error("Gemini returned no candidates")
        raise RuntimeError("Gemini API returned no candidates. The request may have been blocked.")

    _raise_if_blocked(candidates)

    output_img: Image.Image | None = None
    try:
        log.debug("Number of candidates: %d", len(candidates))
        output_img = _first_image_from_candidates(candidates)
    except Exception as err:
        # A malformed payload is reported below as "no image" rather than crashing here.
        log.error("Failed to parse Gemini response image: %s", err, exc_info=True)

    if output_img is None:
        _log_missing_image_diagnostics(response, candidates)
        raise RuntimeError("Gemini generate_content returned no image. Check logs for details.")

    # The model may answer at a different resolution; restore the source size.
    if output_img.size != target_size:
        output_img = output_img.resize(target_size, Image.Resampling.LANCZOS)

    return output_img
| |
|
| |
|
def process_inpaint(
    image: np.ndarray,
    mask: np.ndarray,
    invert_mask: bool = True,
    prompt: str | None = None,
) -> np.ndarray:
    """
    Run mask-guided object removal on ``image`` through the Gemini edit call.

    The image is normalised to RGB, the mask is converted to RGBA, reduced to
    a 0/255 binary mask (``invert_mask`` is passed straight to that step) and
    resized to the image's height/width. The edited result is returned as a
    numpy array at the source image's size.
    """
    source_pil = Image.fromarray(image).convert("RGBA").convert("RGB")
    image_rgb = np.array(source_pil)
    height, width = image_rgb.shape[:2]

    mask_pil = Image.fromarray(mask).convert("RGBA")
    mask_bw = _resize_mask(
        _binary_mask_from_rgba(np.array(mask_pil), invert_mask),
        (height, width),
    )

    # PIL sizes are (width, height), the reverse of numpy's shape order.
    edited = _call_gemini_edit(image_rgb, mask_bw, prompt, (width, height))
    return np.array(edited)
| |
|