# Hawbeez_Studio / backend/app/replicate_image.py
# sushilideaclan01 — "Update image model options" (33b222b)
"""
Generate images from prompts using Replicate (image models).
Used to turn selected ad creative prompts into actual ad images.
"""
import asyncio
import os
import random
import re
import time
from typing import Any, Optional
import replicate
from replicate.exceptions import ReplicateError
# Models whose inputs receive the caller's reference image URLs (product shots,
# logos, ...). generate_image_sync only forwards reference URLs for keys listed
# here; this is an explicit allowlist, not derived from MODEL_REGISTRY.
REFERENCE_IMAGE_MODELS = {"nano-banana", "nano-banana-2", "grok-imagine"}

# Replicate model registry: image-to-image models only (reference image support).
# Per-entry fields:
#   id               Replicate model identifier passed to client.run().
#   param_name       Input key that receives the aspect-ratio string; only read
#                    when uses_dimensions is False.
#   uses_dimensions  True -> send explicit "width"/"height" inputs instead of an
#                    aspect ratio.
#   reference_key    Optional input key for reference images (default "image_input").
#   reference_single Optional; True -> send only the first reference URL as a
#                    single string instead of a list.
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    "nano-banana": {
        "id": "google/nano-banana",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
    },
    "nano-banana-2": {
        "id": "google/nano-banana-2",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
    },
    "grok-imagine": {
        "id": "xai/grok-imagine-image",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
        # Replicate expects single "image" (uri) for editing; not "image_input" list
        "reference_key": "image",
        "reference_single": True,
    },
    # Disabled models, kept for reference. If one is re-enabled, also decide
    # whether it belongs in REFERENCE_IMAGE_MODELS.
    # "flux-2-max": {
    #     "id": "black-forest-labs/flux-2-max",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "ideogram-v3": {
    #     "id": "ideogram-ai/ideogram-v3-quality",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "photon": {
    #     "id": "luma/photon",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "recraft-v3": {
    #     "id": "recraft-ai/recraft-v3",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "z-image-turbo": {
    #     "id": "prunaai/z-image-turbo",
    #     "param_name": "height",
    #     "uses_dimensions": True,
    # },
    # "seedream-3": {
    #     "id": "bytedance/seedream-3",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
}

DEFAULT_MODEL = "nano-banana"

# Wall-clock limit (seconds) applied by the async generate_image wrapper.
# NOTE(review): worst-case retry sleeps alone can reach or exceed this
# (503: 12 + 24 + 36 + 48 = 120s; 429: up to 4 x 45s), so the last retries may
# not fit inside the async budget — confirm which limit should win.
TIMEOUT = 120

# Retries: 5 attempts total for 503/429 (service overload / rate limit);
# 3 attempts total for other retryable errors (SSL / timeout / connection).
MAX_RETRIES = 2  # retries after the first attempt for SSL/timeout/connection
MAX_RETRIES_SERVER_ERROR = 4  # retries after the first attempt for 503/429
# Base delay for linear backoff on other retryable errors: 5s, 10s, 15s, ...
RETRY_DELAY_SEC = 5
# Base delay for 503: longer linear backoff (12s, 24s, ...) so we don't hammer
# Replicate while it is overloaded.
RETRY_DELAY_503_SEC = 12
# Bounds for the 429 wait (Replicate often says "resets in ~15s").
RATE_LIMIT_MIN_WAIT_SEC = 18
RATE_LIMIT_MAX_WAIT_SEC = 45
# Extracts the seconds count from Replicate 429 hints such as "resets in ~15s"
# (case-insensitive). Minute-based hints like "resets in ~1m" do NOT match.
RATE_LIMIT_RESET_RE = re.compile(r"resets?\s+in\s+~?\s*(\d+)\s*s", re.I)
def _width_height_to_aspect_ratio(width: int, height: int) -> str:
if width == height:
return "1:1"
if width > height:
return "16:9"
return "9:16"
def _get_client():
token = os.environ.get("REPLICATE_API_TOKEN")
if not token:
raise ValueError("REPLICATE_API_TOKEN is not set")
return replicate.Client(api_token=token)
def _extract_image_url(output) -> Optional[str]:
if output is None:
return None
if hasattr(output, "url"):
return getattr(output, "url", None)
if isinstance(output, list) and len(output) > 0:
first = output[0]
return getattr(first, "url", str(first) if isinstance(first, str) else None)
if isinstance(output, str) and output.startswith("http"):
return output
return None
def _is_retryable_error(e: Exception) -> bool:
    """Decide whether an exception from a Replicate call is worth retrying.

    Retryable: ReplicateError with HTTP status 503 or 429, or any exception whose
    (lower-cased) text mentions a timeout, SSL/handshake problem, or a connection
    that was reset/refused/errored.
    """
    if isinstance(e, ReplicateError) and getattr(e, "status", None) in (503, 429):
        return True
    text = (str(e) or "").lower()
    if any(marker in text for marker in ("timed out", "timeout", "ssl", "handshake")):
        return True
    return "connection" in text and any(word in text for word in ("reset", "refused", "error"))
def _retry_delay_sec(e: Exception, attempt: int) -> float:
    """Return how many seconds to sleep before retrying after ``e``.

    ``attempt`` is the 0-based index of the attempt that just failed.

    - 429 (rate limited): use Replicate's reset hint ("resets in ~15s", or a
      minute hint such as "resets in ~1m") plus 2s of slack, clamped to
      [RATE_LIMIT_MIN_WAIT_SEC, RATE_LIMIT_MAX_WAIT_SEC]. Without a hint, wait
      RATE_LIMIT_MIN_WAIT_SEC.
    - 503 (overloaded): linear backoff, RETRY_DELAY_503_SEC * (attempt + 1).
    - Anything else: linear backoff, RETRY_DELAY_SEC * (attempt + 1).
    """
    if isinstance(e, ReplicateError):
        status = getattr(e, "status", None)
        if status == 429:
            detail = (getattr(e, "detail", None) or str(e)) or ""
            hinted_wait: Optional[int] = None
            match = RATE_LIMIT_RESET_RE.search(detail)
            if match:
                hinted_wait = int(match.group(1))
            else:
                # RATE_LIMIT_RESET_RE only understands seconds; also honour minute
                # hints so we don't retry (and burn an attempt) while still throttled.
                minute_match = re.search(
                    r"resets?\s+in\s+~?\s*(\d+)\s*m(?:in|\b)", detail, re.I
                )
                if minute_match:
                    hinted_wait = int(minute_match.group(1)) * 60
            if hinted_wait is None:
                return RATE_LIMIT_MIN_WAIT_SEC
            return min(max(hinted_wait + 2, RATE_LIMIT_MIN_WAIT_SEC), RATE_LIMIT_MAX_WAIT_SEC)
        if status == 503:
            return RETRY_DELAY_503_SEC * (attempt + 1)
    return RETRY_DELAY_SEC * (attempt + 1)
def _user_friendly_error(e: Exception) -> str:
    """Translate an exception into a message suitable for end users.

    - ReplicateError 503 -> "service temporarily unavailable" message.
    - ReplicateError 429 -> rate-limit message.
    - Any exception whose text mentions a timeout, SSL, or handshake problem
      (matched case-insensitively) -> "connection timed out" message.
    - Otherwise the exception's own text, or "Unknown error" if it is empty.
    """
    if isinstance(e, ReplicateError):
        status = getattr(e, "status", None)
        if status == 503:
            return "Replicate service is temporarily unavailable. Please try again in a moment."
        if status == 429:
            return "Rate limit reached for this model. Please wait a minute and try again, or generate fewer ads at once."
    msg = str(e) or "Unknown error"
    # Match case-insensitively, as _is_retryable_error does: exception text often
    # reads "SSLError", "ReadTimeout", "Timed out", ... which a case-sensitive
    # check would miss and leak raw to the user.
    lowered = msg.lower()
    if any(marker in lowered for marker in ("timed out", "timeout", "ssl", "handshake")):
        return "Connection timed out. Please try generating this image again."
    return msg
def generate_image_sync(
    prompt: str,
    model_key: str = DEFAULT_MODEL,
    width: int = 1024,
    height: int = 1024,
    seed: Optional[int] = None,
    reference_image_urls: Optional[list[str]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """
    Run Replicate image generation (blocking). Returns (image_url, error_message).

    Exactly one element of the returned tuple is set: the image URL on success,
    otherwise a user-facing error message.

    Args:
        prompt: text prompt sent to the model.
        model_key: key into MODEL_REGISTRY; unknown keys return an error.
        width, height: requested size. Sent as explicit "width"/"height" for
            models with uses_dimensions, otherwise bucketed into an aspect ratio.
        seed: explicit seed (0 is honoured); a random one is drawn when None.
        reference_image_urls: optional list of reference image URLs (product,
            logo, etc.). Non-string/blank entries are dropped and whitespace is
            stripped. Only forwarded for models in REFERENCE_IMAGE_MODELS; models
            flagged reference_single receive just the first URL. Nano-banana
            supports multiple images and will fuse them according to the prompt.

    Transient failures (per _is_retryable_error) are retried with sleeps from
    _retry_delay_sec: up to MAX_RETRIES_SERVER_ERROR retries for 503/429 and
    MAX_RETRIES for other retryable errors.
    """
    cfg = MODEL_REGISTRY.get(model_key)
    if not cfg:
        return None, f"Unknown model: {model_key}"

    # `seed or random...` would silently replace an explicit seed of 0.
    if seed is None:
        seed = random.randint(1, 2147483647)
    input_data: dict[str, Any] = {"prompt": prompt, "seed": seed}

    urls = [u.strip() for u in (reference_image_urls or []) if isinstance(u, str) and u.strip()]
    if urls and model_key in REFERENCE_IMAGE_MODELS:
        ref_key = cfg.get("reference_key") or "image_input"
        input_data[ref_key] = urls[0] if cfg.get("reference_single") else urls

    if cfg.get("uses_dimensions"):
        input_data["width"] = width
        input_data["height"] = height
    else:
        input_data[cfg["param_name"]] = _width_height_to_aspect_ratio(width, height)

    last_error: Optional[Exception] = None
    for attempt in range(MAX_RETRIES_SERVER_ERROR + 1):
        try:
            # Client is built inside the try so a missing token is reported via
            # the (None, error) return instead of escaping as an exception.
            client = _get_client()
            output = client.run(cfg["id"], input=input_data)
            url = _extract_image_url(output)
        except Exception as e:
            last_error = e
            server_busy = isinstance(e, ReplicateError) and getattr(e, "status", None) in (503, 429)
            max_retries = MAX_RETRIES_SERVER_ERROR if server_busy else MAX_RETRIES
            if attempt < max_retries and _is_retryable_error(e):
                time.sleep(_retry_delay_sec(e, attempt))
                continue
            break
        if url:
            return url, None
        # The call succeeded but nothing URL-like came back; report it rather than
        # returning (None, None), which gives the caller neither a URL nor an error.
        return None, "Image model returned no image URL."
    return None, _user_friendly_error(last_error)
async def generate_image(
    prompt: str,
    model_key: str = DEFAULT_MODEL,
    width: int = 1024,
    height: int = 1024,
    reference_image_urls: Optional[list[str]] = None,
    seed: Optional[int] = None,
    timeout: float = TIMEOUT,
) -> tuple[Optional[str], Optional[str]]:
    """Run generate_image_sync in a worker thread so we don't block the event loop.

    Returns (image_url, error_message) like generate_image_sync. If the whole
    call (including retry sleeps) exceeds ``timeout`` seconds, returns
    (None, message) instead of raising, so callers see a single error convention.

    Note: a thread started by asyncio.to_thread cannot be interrupted; after a
    timeout the worker keeps running in the background until it finishes, and
    its result is discarded.
    """
    try:
        return await asyncio.wait_for(
            asyncio.to_thread(
                generate_image_sync,
                prompt=prompt,
                model_key=model_key,
                width=width,
                height=height,
                seed=seed,
                reference_image_urls=reference_image_urls,
            ),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        return None, "Image generation timed out. Please try generating this image again."