""" Generate images from prompts using Replicate or Kie.ai (image models). Used to turn selected ad creative prompts into actual ad images. When KIE_API_KEY is set and model is nano-banana-pro, uses Kie.ai; otherwise Replicate. """ import asyncio import os import random import re import time from typing import Any, Optional import replicate from replicate.exceptions import ReplicateError # Model keys that use Kie.ai when KIE_API_KEY is set KIE_MODELS = {"nano-banana-pro-kie", "nano-banana-2-kie"} # Models that support reference images (image-to-image / multi-reference) REFERENCE_IMAGE_MODELS = { "nano-banana", "nano-banana-2", "nano-banana-pro", "nano-banana-2-kie", "nano-banana-pro-kie", } MODEL_REGISTRY: dict[str, dict[str, Any]] = { "nano-banana": { "id": "google/nano-banana", "param_name": "aspect_ratio", "uses_dimensions": False, "label": "Nano Banana (Replicate)", }, "nano-banana-2": { "id": "google/nano-banana-2", "param_name": "aspect_ratio", "uses_dimensions": False, "label": "Nano Banana 2 (Replicate)", }, "nano-banana-pro": { "id": "google/nano-banana-pro", "param_name": "aspect_ratio", "uses_dimensions": False, "label": "Nano Banana Pro (Replicate)", }, "nano-banana-2-kie": { "id": "kie.ai/nano-banana-2", "param_name": "aspect_ratio", "uses_dimensions": False, "label": "Nano Banana 2 (Kie AI)", "provider": "kie", "kie_model": "nano-banana-2", }, "nano-banana-pro-kie": { "id": "kie.ai/nano-banana-pro", "param_name": "aspect_ratio", "uses_dimensions": False, "label": "Nano Banana Pro (Kie AI)", "provider": "kie", "kie_model": "nano-banana-pro", }, # "flux-2-max": { # "id": "black-forest-labs/flux-2-max", # "param_name": "aspect_ratio", # "uses_dimensions": False, # }, # "ideogram-v3": { # "id": "ideogram-ai/ideogram-v3-quality", # "param_name": "aspect_ratio", # "uses_dimensions": False, # }, # "photon": { # "id": "luma/photon", # "param_name": "aspect_ratio", # "uses_dimensions": False, # }, # "recraft-v3": { # "id": "recraft-ai/recraft-v3", # "param_name": "aspect_ratio", # "uses_dimensions": False, # }, # "z-image-turbo": { # "id": "prunaai/z-image-turbo", # "param_name": "height", # "uses_dimensions": True, # }, # "seedream-3": { # "id": "bytedance/seedream-3", # "param_name": "aspect_ratio", # "uses_dimensions": False, # }, } DEFAULT_MODEL = "nano-banana-2" TIMEOUT = 120 # Retries: 5 attempts total for 503/429 (service/rate limit); 3 for other retryable errors MAX_RETRIES = 2 # SSL/timeout/connection (3 attempts total) MAX_RETRIES_SERVER_ERROR = 4 # 503/429 get 5 attempts with longer backoff RETRY_DELAY_SEC = 5 # For 503: longer waits so we don't hammer Replicate while it's overloaded RETRY_DELAY_503_SEC = 12 # For 429 rate limit: wait at least this long (Replicate often says "resets in ~15s") RATE_LIMIT_MIN_WAIT_SEC = 18 RATE_LIMIT_MAX_WAIT_SEC = 45 # Regex to parse "resets in ~15s" or "resets in ~1m" from Replicate 429 detail RATE_LIMIT_RESET_RE = re.compile(r"resets?\s+in\s+~?\s*(\d+)\s*s", re.I) def _width_height_to_aspect_ratio(width: int, height: int) -> str: if width == height: return "1:1" if width > height: return "16:9" return "9:16" def _get_client(): token = os.environ.get("REPLICATE_API_TOKEN") if not token: raise ValueError("REPLICATE_API_TOKEN is not set") return replicate.Client(api_token=token) def _extract_image_url(output) -> Optional[str]: if output is None: return None if hasattr(output, "url"): return getattr(output, "url", None) if isinstance(output, list) and len(output) > 0: first = output[0] return getattr(first, "url", str(first) if isinstance(first, str) else None) if isinstance(output, str) and output.startswith("http"): return output return None def _error_message(e: Exception) -> str: """Get the best available message from an exception (e.g. ReplicateError.detail).""" if isinstance(e, ReplicateError): detail = getattr(e, "detail", None) if detail is not None and str(detail).strip(): return str(detail).strip() return (str(e) or "").strip() def _is_retryable_error(e: Exception) -> bool: """True for transient errors: 503/429 from Replicate, E003/high demand, SSL/timeout/connection errors.""" if isinstance(e, ReplicateError): status = getattr(e, "status", None) if status in (503, 429): return True msg = _error_message(e).lower() if "e003" in msg or "high demand" in msg or "unavailable due to high demand" in msg: return True if "timed out" in msg or "timeout" in msg: return True if "ssl" in msg or "handshake" in msg: return True if "connection" in msg and ("reset" in msg or "refused" in msg or "error" in msg): return True return False def _retry_delay_sec(e: Exception, attempt: int) -> float: """Seconds to wait before retry. 429: use reset hint; 503/E003/high demand: longer backoff; else exponential.""" if isinstance(e, ReplicateError): status = getattr(e, "status", None) if status == 429: detail = _error_message(e) match = RATE_LIMIT_RESET_RE.search(detail) if match: sec = int(match.group(1)) return min(max(sec + 2, RATE_LIMIT_MIN_WAIT_SEC), RATE_LIMIT_MAX_WAIT_SEC) return RATE_LIMIT_MIN_WAIT_SEC if status == 503: return RETRY_DELAY_503_SEC * (attempt + 1) msg = _error_message(e).lower() if "e003" in msg or "high demand" in msg or "unavailable due to high demand" in msg: return RETRY_DELAY_503_SEC * (attempt + 1) return RETRY_DELAY_SEC * (attempt + 1) def _user_friendly_error(e: Exception) -> str: """Return a clearer message for known error types.""" if isinstance(e, ReplicateError): status = getattr(e, "status", None) if status == 503: return "Replicate service is temporarily unavailable. Please try again in a moment." if status == 429: return "Rate limit reached for this model. Please wait a minute and try again, or generate fewer ads at once." msg = _error_message(e) if not msg: return "Something went wrong while generating the image. Please try again." # Normalize Replicate's E003 / high-demand message to a short, user-friendly line (no trace IDs) if "e003" in msg.lower() or "high demand" in msg.lower() or "unavailable due to high demand" in msg.lower(): return "Service is busy due to high demand. Please try again in a few minutes." if "timed out" in msg or "timeout" in msg or "ssl" in msg or "handshake" in msg: return "Connection timed out. Please try generating this image again." return msg def _use_kie(model_key: str) -> bool: """True if this model should use Kie.ai and an API key is configured.""" if model_key not in KIE_MODELS: return False return bool((os.environ.get("KIE_API_KEY") or os.environ.get("KIE_AI_API_KEY") or "").strip()) def generate_image_sync( prompt: str, model_key: str = DEFAULT_MODEL, width: int = 1024, height: int = 1024, seed: Optional[int] = None, reference_image_urls: Optional[list[str]] = None, ) -> tuple[Optional[str], Optional[str]]: """ Run image generation (blocking). Returns (image_url, error_message). Uses Kie.ai for nano-banana-pro when KIE_API_KEY is set; otherwise Replicate. reference_image_urls: optional list of reference image URLs (product, logo, etc.). Nano-banana supports multiple images and will fuse them according to the prompt. """ # Route nano-banana-pro to Kie.ai when KIE_API_KEY is set if _use_kie(model_key): from app.kie_image import generate_image_sync as kie_generate cfg = MODEL_REGISTRY.get(model_key) or {} kie_model = cfg.get("kie_model") or "nano-banana-pro" return kie_generate( prompt=prompt, model=kie_model, width=width, height=height, reference_image_urls=reference_image_urls, ) cfg = MODEL_REGISTRY.get(model_key) if not cfg: return None, f"Unknown model: {model_key}" seed = seed or random.randint(1, 2147483647) input_data = {"prompt": prompt, "seed": seed} urls = [u for u in (reference_image_urls or []) if u and isinstance(u, str) and u.strip()] if urls and model_key in REFERENCE_IMAGE_MODELS: ref_key = cfg.get("reference_key") or "image_input" if cfg.get("reference_single"): input_data[ref_key] = urls[0] else: input_data[ref_key] = urls if cfg.get("uses_dimensions"): input_data["width"] = width input_data["height"] = height else: input_data[cfg["param_name"]] = _width_height_to_aspect_ratio(width, height) last_error = None for attempt in range(MAX_RETRIES_SERVER_ERROR + 1): try: client = _get_client() output = client.run(cfg["id"], input=input_data) url = _extract_image_url(output) return url, None except Exception as e: last_error = e is_server_or_rate_limit = ( isinstance(e, ReplicateError) and getattr(e, "status", None) in (503, 429) ) or ( "e003" in _error_message(e).lower() or "high demand" in _error_message(e).lower() ) max_retries = MAX_RETRIES_SERVER_ERROR if is_server_or_rate_limit else MAX_RETRIES if attempt < max_retries and _is_retryable_error(e): time.sleep(_retry_delay_sec(e, attempt)) continue break return None, _user_friendly_error(last_error) async def generate_image( prompt: str, model_key: str = DEFAULT_MODEL, width: int = 1024, height: int = 1024, reference_image_urls: Optional[list[str]] = None, ) -> tuple[Optional[str], Optional[str]]: """Run Replicate in a thread so we don't block the event loop.""" return await asyncio.wait_for( asyncio.to_thread( generate_image_sync, prompt=prompt, model_key=model_key, width=width, height=height, reference_image_urls=reference_image_urls, ), timeout=TIMEOUT, )