"""
Generate images from prompts using Replicate (image models).

Used to turn selected ad creative prompts into actual ad images.
"""
import asyncio
import os
import random
import re
import time
from typing import Any, Optional

import replicate
from replicate.exceptions import ReplicateError

# Replicate model registry: image-to-image models only (reference image support)
REFERENCE_IMAGE_MODELS = {"nano-banana", "nano-banana-2", "grok-imagine"}

# Per-model settings read by generate_image_sync:
#   id               - Replicate model identifier passed to client.run
#   param_name       - input key that receives the aspect ratio string
#   uses_dimensions  - when True, width/height are sent instead of an aspect ratio
#   reference_key    - (optional) input key for reference images; default "image_input"
#   reference_single - (optional) send only the first reference URL instead of a list
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    "nano-banana": {
        "id": "google/nano-banana",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
    },
    "nano-banana-2": {
        "id": "google/nano-banana-2",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
    },
    "grok-imagine": {
        "id": "xai/grok-imagine-image",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
        # Replicate expects single "image" (uri) for editing; not "image_input" list
        "reference_key": "image",
        "reference_single": True,
    },
    # "flux-2-max": {
    #     "id": "black-forest-labs/flux-2-max",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "ideogram-v3": {
    #     "id": "ideogram-ai/ideogram-v3-quality",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "photon": {
    #     "id": "luma/photon",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "recraft-v3": {
    #     "id": "recraft-ai/recraft-v3",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "z-image-turbo": {
    #     "id": "prunaai/z-image-turbo",
    #     "param_name": "height",
    #     "uses_dimensions": True,
    # },
    # "seedream-3": {
    #     "id": "bytedance/seedream-3",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
}

DEFAULT_MODEL = "nano-banana"

# Overall deadline (seconds) applied by the async wrapper generate_image.
# NOTE(review): the 503 backoff alone can sleep 12+24+36+48 = 120s, so late
# retries cannot finish inside this deadline - confirm the intended budget.
TIMEOUT = 120

# Retries: 5 attempts total for 503/429 (service/rate limit); 3 for other retryable errors
MAX_RETRIES = 2  # SSL/timeout/connection (3 attempts total)
MAX_RETRIES_SERVER_ERROR = 4  # 503/429 get 5 attempts with longer backoff
RETRY_DELAY_SEC = 5
# For 503: longer waits so we don't hammer Replicate while it's overloaded
RETRY_DELAY_503_SEC = 12
# For 429 rate limit: wait at least this long (Replicate often says "resets in ~15s")
RATE_LIMIT_MIN_WAIT_SEC = 18
RATE_LIMIT_MAX_WAIT_SEC = 45
# Regex to parse "resets in ~15s" or "resets in ~1m" from Replicate 429 detail.
# Group 1 is the number, group 2 the unit ("s" or "m"); "ms" is deliberately
# not treated as minutes.
RATE_LIMIT_RESET_RE = re.compile(r"resets?\s+in\s+~?\s*(\d+)\s*(s|m(?!s))", re.I)


def _width_height_to_aspect_ratio(width: int, height: int) -> str:
    """Map pixel dimensions to one of three coarse aspect ratios.

    Square -> "1:1", any landscape -> "16:9", any portrait -> "9:16".
    """
    if width == height:
        return "1:1"
    if width > height:
        return "16:9"
    return "9:16"


def _get_client():
    """Build a Replicate client from the REPLICATE_API_TOKEN environment variable.

    Raises:
        ValueError: if the token is missing or empty.
    """
    token = os.environ.get("REPLICATE_API_TOKEN")
    if not token:
        raise ValueError("REPLICATE_API_TOKEN is not set")
    return replicate.Client(api_token=token)


def _extract_image_url(output) -> Optional[str]:
    """Pull an image URL out of whatever shape ``client.run`` returned.

    Handles an object exposing ``.url``, a non-empty list (first element is
    used: its ``.url`` or the element itself when it is a string), or a bare
    string starting with "http". Returns None for anything else.
    """
    if output is None:
        return None
    if hasattr(output, "url"):
        return getattr(output, "url", None)
    if isinstance(output, list) and len(output) > 0:
        first = output[0]
        return getattr(first, "url", str(first) if isinstance(first, str) else None)
    if isinstance(output, str) and output.startswith("http"):
        return output
    return None


def _is_retryable_error(e: Exception) -> bool:
    """True for transient errors: 503/429 from Replicate, SSL/timeout/connection errors."""
    if isinstance(e, ReplicateError):
        status = getattr(e, "status", None)
        if status in (503, 429):
            return True
    # Everything else is classified by message text, case-insensitively.
    msg = (str(e) or "").lower()
    if "timed out" in msg or "timeout" in msg:
        return True
    if "ssl" in msg or "handshake" in msg:
        return True
    if "connection" in msg and ("reset" in msg or "refused" in msg or "error" in msg):
        return True
    return False


def _retry_delay_sec(e: Exception, attempt: int) -> float:
    """Seconds to wait before retry. 429: use reset hint; 503: longer backoff; else linear.

    ``attempt`` is the zero-based index of the attempt that just failed.
    """
    if isinstance(e, ReplicateError):
        status = getattr(e, "status", None)
        if status == 429:
            detail = str(getattr(e, "detail", None) or str(e) or "")
            match = RATE_LIMIT_RESET_RE.search(detail)
            if match:
                sec = int(match.group(1))
                if match.group(2).lower() == "m":
                    sec *= 60  # hint was given in minutes
                # +2s safety margin, clamped to the configured window.
                return min(max(sec + 2, RATE_LIMIT_MIN_WAIT_SEC), RATE_LIMIT_MAX_WAIT_SEC)
            return RATE_LIMIT_MIN_WAIT_SEC
        if status == 503:
            return RETRY_DELAY_503_SEC * (attempt + 1)
    return RETRY_DELAY_SEC * (attempt + 1)


def _user_friendly_error(e: Exception) -> str:
    """Return a clearer message for known error types, else the raw error text."""
    if isinstance(e, ReplicateError):
        status = getattr(e, "status", None)
        if status == 503:
            return "Replicate service is temporarily unavailable. Please try again in a moment."
        if status == 429:
            return "Rate limit reached for this model. Please wait a minute and try again, or generate fewer ads at once."
    msg = str(e) or "Unknown error"
    # Match case-insensitively, consistent with _is_retryable_error, so that
    # messages such as "SSLError" or "Timeout" are recognised too.
    lowered = msg.lower()
    if any(token in lowered for token in ("timed out", "timeout", "ssl", "handshake")):
        return "Connection timed out. Please try generating this image again."
    return msg


def generate_image_sync(
    prompt: str,
    model_key: str = DEFAULT_MODEL,
    width: int = 1024,
    height: int = 1024,
    seed: Optional[int] = None,
    reference_image_urls: Optional[list[str]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """
    Run Replicate image generation (blocking). Returns (image_url, error_message).

    Exactly one element of the returned tuple is non-None: the URL on success,
    or a user-facing error message on failure. Errors are never raised.

    seed: explicit seed (0 is honoured); a random one is drawn when None.
    reference_image_urls: optional list of reference image URLs (product, logo, etc.).
    Nano-banana supports multiple images and will fuse them according to the prompt.
    """
    cfg = MODEL_REGISTRY.get(model_key)
    if not cfg:
        return None, f"Unknown model: {model_key}"

    # `is None` rather than truthiness so an explicit seed of 0 is not replaced.
    if seed is None:
        seed = random.randint(1, 2147483647)
    input_data: dict[str, Any] = {"prompt": prompt, "seed": seed}

    # Drop non-string / blank entries before handing references to the model.
    urls = [u for u in (reference_image_urls or []) if isinstance(u, str) and u.strip()]
    if urls and model_key in REFERENCE_IMAGE_MODELS:
        ref_key = cfg.get("reference_key") or "image_input"
        input_data[ref_key] = urls[0] if cfg.get("reference_single") else urls

    if cfg.get("uses_dimensions"):
        input_data["width"] = width
        input_data["height"] = height
    else:
        input_data[cfg["param_name"]] = _width_height_to_aspect_ratio(width, height)

    last_error: Optional[Exception] = None
    # Loop to the larger retry budget; the per-error budget is enforced below.
    max_attempts = max(MAX_RETRIES, MAX_RETRIES_SERVER_ERROR) + 1
    for attempt in range(max_attempts):
        try:
            client = _get_client()
            output = client.run(cfg["id"], input=input_data)
        except Exception as e:  # boundary: every failure becomes an error message
            last_error = e
            is_server_error = isinstance(e, ReplicateError) and getattr(
                e, "status", None
            ) in (503, 429)
            max_retries = MAX_RETRIES_SERVER_ERROR if is_server_error else MAX_RETRIES
            if attempt < max_retries and _is_retryable_error(e):
                time.sleep(_retry_delay_sec(e, attempt))
                continue
            break
        else:
            url = _extract_image_url(output)
            if not url:
                # Avoid reporting (None, None), which reads as success with no image.
                return None, "Image generation finished but no image URL was returned."
            return url, None

    return None, _user_friendly_error(last_error)


async def generate_image(
    prompt: str,
    model_key: str = DEFAULT_MODEL,
    width: int = 1024,
    height: int = 1024,
    reference_image_urls: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Run Replicate in a thread so we don't block the event loop.

    Returns (image_url, error_message) like generate_image_sync. If the call
    does not finish within TIMEOUT seconds, an error tuple is returned rather
    than raising. The worker thread itself cannot be cancelled and may keep
    running in the background after the timeout.
    """
    try:
        return await asyncio.wait_for(
            asyncio.to_thread(
                generate_image_sync,
                prompt=prompt,
                model_key=model_key,
                width=width,
                height=height,
                seed=seed,
                reference_image_urls=reference_image_urls,
            ),
            timeout=TIMEOUT,
        )
    except asyncio.TimeoutError:
        return (
            None,
            f"Image generation timed out after {TIMEOUT} seconds. "
            "Please try generating this image again.",
        )