Spaces:
Sleeping
Sleeping
| """ | |
| Generate images from prompts using Replicate (image models). | |
| Used to turn selected ad creative prompts into actual ad images. | |
| """ | |
| import asyncio | |
| import os | |
| import random | |
| import re | |
| import time | |
| from typing import Any, Optional | |
| import replicate | |
| from replicate.exceptions import ReplicateError | |
# Replicate model registry: image-to-image models only (reference image support).
#
# REFERENCE_IMAGE_MODELS lists the registry keys for which generate_image_sync
# forwards reference image URLs; for any other key the URLs are dropped.
REFERENCE_IMAGE_MODELS = {"nano-banana", "nano-banana-2", "grok-imagine"}

# Per-entry fields read by generate_image_sync:
#   id               - model identifier passed to client.run().
#   param_name       - input key that receives the aspect-ratio string when
#                      uses_dimensions is False.
#   uses_dimensions  - True: send explicit "width"/"height" inputs instead of
#                      an aspect ratio.
#   reference_key    - (optional) input key for reference images; falls back
#                      to "image_input" when absent.
#   reference_single - (optional) True: send only the first reference URL as a
#                      single value rather than the whole list.
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    "nano-banana": {
        "id": "google/nano-banana",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
    },
    "nano-banana-2": {
        "id": "google/nano-banana-2",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
    },
    "grok-imagine": {
        "id": "xai/grok-imagine-image",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
        # Replicate expects single "image" (uri) for editing; not "image_input" list
        "reference_key": "image",
        "reference_single": True,
    },
    # Entries below are currently disabled (commented out).
    # "flux-2-max": {
    #     "id": "black-forest-labs/flux-2-max",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "ideogram-v3": {
    #     "id": "ideogram-ai/ideogram-v3-quality",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "photon": {
    #     "id": "luma/photon",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "recraft-v3": {
    #     "id": "recraft-ai/recraft-v3",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "z-image-turbo": {
    #     "id": "prunaai/z-image-turbo",
    #     "param_name": "height",
    #     "uses_dimensions": True,
    # },
    # "seedream-3": {
    #     "id": "bytedance/seedream-3",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
}

# Registry key used when the caller does not pick a model.
DEFAULT_MODEL = "nano-banana"

# Seconds; overall deadline applied by the async generate_image wrapper.
TIMEOUT = 120

# Retries: 5 attempts total for 503/429 (service/rate limit); 3 for other retryable errors
MAX_RETRIES = 2  # SSL/timeout/connection (3 attempts total)
MAX_RETRIES_SERVER_ERROR = 4  # 503/429 get 5 attempts with longer backoff

# Base delay in seconds for generic retryable errors; multiplied by (attempt + 1).
RETRY_DELAY_SEC = 5

# For 503: longer waits so we don't hammer Replicate while it's overloaded
RETRY_DELAY_503_SEC = 12

# For 429 rate limit: wait at least this long (Replicate often says "resets in ~15s")
RATE_LIMIT_MIN_WAIT_SEC = 18
RATE_LIMIT_MAX_WAIT_SEC = 45

# Regex to parse the seconds hint (e.g. "resets in ~15s") from a Replicate 429
# detail. Only a number followed by "s" matches; a minute-style hint such as
# "resets in ~1m" does NOT match, in which case _retry_delay_sec falls back to
# RATE_LIMIT_MIN_WAIT_SEC.
RATE_LIMIT_RESET_RE = re.compile(r"resets?\s+in\s+~?\s*(\d+)\s*s", re.I)
def _width_height_to_aspect_ratio(width: int, height: int) -> str:
    """Map pixel dimensions onto one of three aspect-ratio labels.

    Only the orientation is considered, not the exact proportions:
    wider than tall gives "16:9", taller than wide gives "9:16",
    and equal sides give "1:1".
    """
    if width > height:
        return "16:9"
    if width < height:
        return "9:16"
    return "1:1"
def _get_client():
    """Build a Replicate client from the REPLICATE_API_TOKEN environment variable.

    Raises:
        ValueError: if the variable is unset or empty.
    """
    api_token = os.environ.get("REPLICATE_API_TOKEN")
    if api_token:
        return replicate.Client(api_token=api_token)
    raise ValueError("REPLICATE_API_TOKEN is not set")
def _extract_image_url(output) -> Optional[str]:
    """Pull an image URL out of whatever shape the model output has.

    Handled shapes, checked in this order:
      * None -> None
      * any object with a ``url`` attribute -> that attribute
      * a non-empty list -> the first element's ``url`` attribute, or the
        element itself when it is a string, otherwise None
      * a string starting with "http" -> the string
    Anything else yields None.
    """
    if output is None:
        return None

    absent = object()  # sentinel: distinguishes "no attribute" from "attribute is None"
    direct = getattr(output, "url", absent)
    if direct is not absent:
        return direct

    if isinstance(output, list):
        if not output:
            return None
        head = output[0]
        if hasattr(head, "url"):
            return head.url
        # Only the first element is inspected; later items are ignored.
        return str(head) if isinstance(head, str) else None

    if isinstance(output, str) and output.startswith("http"):
        return output
    return None
def _is_retryable_error(e: Exception) -> bool:
    """Decide whether a failed call is worth retrying.

    Retryable cases:
      * a ReplicateError whose ``status`` is 503 or 429;
      * any exception whose message (case-insensitive) mentions a timeout,
        SSL, or a handshake;
      * any exception whose message mentions "connection" together with
        "reset", "refused" or "error".
    """
    if isinstance(e, ReplicateError) and getattr(e, "status", None) in (503, 429):
        return True

    text = (str(e) or "").lower()
    if any(marker in text for marker in ("timed out", "timeout", "ssl", "handshake")):
        return True

    return "connection" in text and any(
        word in text for word in ("reset", "refused", "error")
    )
def _retry_delay_sec(e: Exception, attempt: int) -> float:
    """Return how many seconds to sleep before the next attempt.

    * ReplicateError with status 429: if the error detail (or, failing that,
      its message) contains a "resets in ~Ns" hint, wait N + 2 seconds clamped
      to [RATE_LIMIT_MIN_WAIT_SEC, RATE_LIMIT_MAX_WAIT_SEC]; without a hint,
      wait RATE_LIMIT_MIN_WAIT_SEC.
    * ReplicateError with status 503: RETRY_DELAY_503_SEC * (attempt + 1).
    * Everything else: RETRY_DELAY_SEC * (attempt + 1).

    The non-429 delays therefore grow linearly with ``attempt``.
    """
    step = attempt + 1
    status = getattr(e, "status", None) if isinstance(e, ReplicateError) else None

    if status == 503:
        return RETRY_DELAY_503_SEC * step
    if status != 429:
        return RETRY_DELAY_SEC * step

    # 429: honour the server's reset hint when one can be parsed.
    hint_source = getattr(e, "detail", None) or str(e) or ""
    hint = RATE_LIMIT_RESET_RE.search(hint_source)
    if hint is None:
        return RATE_LIMIT_MIN_WAIT_SEC
    wanted = int(hint.group(1)) + 2  # small margin past the advertised reset
    return min(max(wanted, RATE_LIMIT_MIN_WAIT_SEC), RATE_LIMIT_MAX_WAIT_SEC)
def _user_friendly_error(e: Exception) -> str:
    """Return a clearer, user-facing message for known error types.

    * ReplicateError with status 503 -> "temporarily unavailable" message.
    * ReplicateError with status 429 -> "rate limit" message.
    * Any exception whose message mentions a timeout, SSL or a handshake
      -> "connection timed out" message.
    * Otherwise the exception's own message, or "Unknown error" when empty.
    """
    if isinstance(e, ReplicateError):
        status = getattr(e, "status", None)
        if status == 503:
            return "Replicate service is temporarily unavailable. Please try again in a moment."
        if status == 429:
            return "Rate limit reached for this model. Please wait a minute and try again, or generate fewer ads at once."

    msg = str(e) or "Unknown error"
    # Match case-insensitively, as _is_retryable_error does: real messages
    # commonly spell these as "SSL", "Timeout" or "Handshake", which a
    # case-sensitive check on the raw text would miss.
    lowered = msg.lower()
    if any(marker in lowered for marker in ("timed out", "timeout", "ssl", "handshake")):
        return "Connection timed out. Please try generating this image again."
    return msg
def generate_image_sync(
    prompt: str,
    model_key: str = DEFAULT_MODEL,
    width: int = 1024,
    height: int = 1024,
    seed: Optional[int] = None,
    reference_image_urls: Optional[list[str]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """
    Run Replicate image generation (blocking). Returns (image_url, error_message).

    Exactly one element of the returned pair is non-None: the URL on success,
    or a user-facing error message on failure. Exceptions raised while
    creating the client or running the model are converted into the error
    message rather than propagated.

    Args:
        prompt: text prompt sent to the model.
        model_key: key into MODEL_REGISTRY; an unknown key returns an error.
        width, height: requested size. Sent as explicit dimensions for models
            with ``uses_dimensions``; otherwise reduced to an aspect-ratio label.
        seed: generation seed. A random one is drawn only when this is None,
            so an explicit 0 is respected.
        reference_image_urls: optional list of reference image URLs (product,
            logo, etc.). Blank and non-string entries are dropped. They are
            forwarded only for models in REFERENCE_IMAGE_MODELS; models marked
            ``reference_single`` receive just the first URL.

    Transient failures (see _is_retryable_error) are retried with a blocking
    sleep between attempts: up to MAX_RETRIES_SERVER_ERROR retries for 503/429
    responses and up to MAX_RETRIES retries for other retryable errors.
    """
    cfg = MODEL_REGISTRY.get(model_key)
    if not cfg:
        return None, f"Unknown model: {model_key}"

    # `seed or random...` would silently discard an explicit seed of 0.
    if seed is None:
        seed = random.randint(1, 2147483647)
    input_data: dict[str, Any] = {"prompt": prompt, "seed": seed}

    urls = [u for u in (reference_image_urls or []) if isinstance(u, str) and u.strip()]
    if urls and model_key in REFERENCE_IMAGE_MODELS:
        ref_key = cfg.get("reference_key") or "image_input"
        input_data[ref_key] = urls[0] if cfg.get("reference_single") else urls

    if cfg.get("uses_dimensions"):
        input_data["width"] = width
        input_data["height"] = height
    else:
        input_data[cfg["param_name"]] = _width_height_to_aspect_ratio(width, height)

    last_error: Optional[Exception] = None
    # Loop to the larger of the two budgets so neither retry class is cut
    # short if the constants are ever retuned; the per-error limit is applied below.
    for attempt in range(max(MAX_RETRIES, MAX_RETRIES_SERVER_ERROR) + 1):
        try:
            # A missing token surfaces here as ValueError and is reported
            # through the same (None, message) path as any other failure.
            client = _get_client()
            output = client.run(cfg["id"], input=input_data)
            url = _extract_image_url(output)
        except Exception as e:  # boundary: convert any failure into an error message
            last_error = e
            is_server_error = (
                isinstance(e, ReplicateError)
                and getattr(e, "status", None) in (503, 429)
            )
            max_retries = MAX_RETRIES_SERVER_ERROR if is_server_error else MAX_RETRIES
            if attempt < max_retries and _is_retryable_error(e):
                time.sleep(_retry_delay_sec(e, attempt))
                continue
            break
        else:
            if not url:
                # The call succeeded but the output had no recognisable URL;
                # report it instead of returning (None, None).
                return None, "Image generation finished but no image URL was returned. Please try again."
            return url, None

    return None, _user_friendly_error(last_error)
async def generate_image(
    prompt: str,
    model_key: str = DEFAULT_MODEL,
    width: int = 1024,
    height: int = 1024,
    reference_image_urls: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Run Replicate in a thread so we don't block the event loop.

    Thin async wrapper around generate_image_sync; arguments and the
    (image_url, error_message) result are passed through unchanged.
    ``seed`` is optional and placed last so existing positional callers are
    unaffected; None lets the sync function pick a random seed.

    Raises:
        asyncio.TimeoutError: if no result arrives within TIMEOUT seconds.
            The timeout only stops the awaiting side: the worker thread is
            not interrupted and may keep running (including its retry sleeps)
            after this coroutine has given up.
    """
    worker = asyncio.to_thread(
        generate_image_sync,
        prompt=prompt,
        model_key=model_key,
        width=width,
        height=height,
        seed=seed,
        reference_image_urls=reference_image_urls,
    )
    return await asyncio.wait_for(worker, timeout=TIMEOUT)