Amalfa_Creative_Studio / backend /app /replicate_image.py
sushilideaclan01's picture
Add Canva integration: Implement OAuth flow, token management, and asset export functionality
2f18e46
"""
Generate images from prompts using Replicate or Kie.ai (image models).
Used to turn selected ad creative prompts into actual ad images.
Model keys ending in "-kie" (nano-banana-2-kie, nano-banana-pro-kie) are routed to Kie.ai when
KIE_API_KEY (or KIE_AI_API_KEY) is set; all other model keys run on Replicate.
"""
import asyncio
import os
import random
import re
import time
from typing import Any, Optional
import replicate
from replicate.exceptions import ReplicateError
# Registry keys that are served by Kie.ai instead of Replicate. Routing only
# happens when a Kie API key is present in the environment (see _use_kie).
KIE_MODELS = {
    "nano-banana-2-kie",
    "nano-banana-pro-kie",
}

# Registry keys that accept reference images (image-to-image / multi-reference).
# generate_image_sync only forwards reference URLs for models listed here.
REFERENCE_IMAGE_MODELS = {
    "nano-banana",
    "nano-banana-pro",
    "nano-banana-2",
    "nano-banana-pro-kie",
    "nano-banana-2-kie",
}
# Per-model configuration, keyed by the model key callers pass in.
#   id              - model identifier passed to the provider client
#   param_name      - input field that receives the aspect-ratio string
#                     (used when uses_dimensions is False)
#   uses_dimensions - True if the model takes explicit width/height instead
#   label           - human-readable name
#   provider        - "kie" for Kie.ai-backed entries (absent for Replicate)
#   kie_model       - model name forwarded to the Kie.ai generator
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    # --- Replicate-hosted models ---
    "nano-banana": {
        "id": "google/nano-banana",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
        "label": "Nano Banana (Replicate)",
    },
    "nano-banana-2": {
        "id": "google/nano-banana-2",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
        "label": "Nano Banana 2 (Replicate)",
    },
    "nano-banana-pro": {
        "id": "google/nano-banana-pro",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
        "label": "Nano Banana Pro (Replicate)",
    },
    # --- Kie.ai-hosted models ---
    "nano-banana-2-kie": {
        "id": "kie.ai/nano-banana-2",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
        "label": "Nano Banana 2 (Kie AI)",
        "provider": "kie",
        "kie_model": "nano-banana-2",
    },
    "nano-banana-pro-kie": {
        "id": "kie.ai/nano-banana-pro",
        "param_name": "aspect_ratio",
        "uses_dimensions": False,
        "label": "Nano Banana Pro (Kie AI)",
        "provider": "kie",
        "kie_model": "nano-banana-pro",
    },
    # --- Currently disabled entries (kept for reference) ---
    # "flux-2-max": {
    #     "id": "black-forest-labs/flux-2-max",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "ideogram-v3": {
    #     "id": "ideogram-ai/ideogram-v3-quality",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "photon": {
    #     "id": "luma/photon",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "recraft-v3": {
    #     "id": "recraft-ai/recraft-v3",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
    # "z-image-turbo": {
    #     "id": "prunaai/z-image-turbo",
    #     "param_name": "height",
    #     "uses_dimensions": True,
    # },
    # "seedream-3": {
    #     "id": "bytedance/seedream-3",
    #     "param_name": "aspect_ratio",
    #     "uses_dimensions": False,
    # },
}
# Model key used when the caller does not choose one.
DEFAULT_MODEL = "nano-banana-2"

# Overall time budget, in seconds, for one async generate_image call.
TIMEOUT = 120

# Retry budgets count *re*-tries, so total attempts = value + 1.
MAX_RETRIES = 2  # SSL / timeout / connection errors: 3 attempts total
MAX_RETRIES_SERVER_ERROR = 4  # 503 / 429 / high demand: 5 attempts total

# Base delay (seconds) for ordinary retryable errors; scaled by attempt number.
RETRY_DELAY_SEC = 5

# Base delay (seconds) for 503 / high-demand errors: back off harder so an
# overloaded service is not hammered.
RETRY_DELAY_503_SEC = 12

# Bounds (seconds) on the wait after a 429 rate-limit response.
RATE_LIMIT_MIN_WAIT_SEC = 18
RATE_LIMIT_MAX_WAIT_SEC = 45

# Extracts the number of seconds from a rate-limit hint such as
# "resets in ~15s". Only a seconds suffix is matched; a hint expressed in
# minutes (e.g. "resets in ~1m") does not match.
RATE_LIMIT_RESET_RE = re.compile(r"resets?\s+in\s+~?\s*(\d+)\s*s", re.IGNORECASE)
def _width_height_to_aspect_ratio(width: int, height: int) -> str:
if width == height:
return "1:1"
if width > height:
return "16:9"
return "9:16"
def _get_client():
token = os.environ.get("REPLICATE_API_TOKEN")
if not token:
raise ValueError("REPLICATE_API_TOKEN is not set")
return replicate.Client(api_token=token)
def _extract_image_url(output) -> Optional[str]:
if output is None:
return None
if hasattr(output, "url"):
return getattr(output, "url", None)
if isinstance(output, list) and len(output) > 0:
first = output[0]
return getattr(first, "url", str(first) if isinstance(first, str) else None)
if isinstance(output, str) and output.startswith("http"):
return output
return None
def _error_message(e: Exception) -> str:
    """Return the most informative text available for an exception.

    For a ReplicateError with a non-blank ``detail`` attribute, that detail is
    used; otherwise the stripped ``str(e)`` is returned (possibly empty).
    """
    if isinstance(e, ReplicateError):
        detail = getattr(e, "detail", None)
        if detail is not None:
            detail_text = str(detail).strip()
            if detail_text:
                return detail_text
    return (str(e) or "").strip()
def _is_retryable_error(e: Exception) -> bool:
    """Tell whether an error looks transient and is worth retrying.

    Retryable: a ReplicateError with status 503 or 429, or any exception whose
    message (case-insensitive) mentions E003 / high demand, a timeout, an
    SSL / handshake problem, or a connection reset / refused / error.
    """
    if isinstance(e, ReplicateError) and getattr(e, "status", None) in (503, 429):
        return True

    text = _error_message(e).lower()

    # Substrings that on their own mark the failure as transient.
    transient_markers = (
        "e003",
        "high demand",
        "timed out",
        "timeout",
        "ssl",
        "handshake",
    )
    if any(marker in text for marker in transient_markers):
        return True

    # "connection" only counts together with a word describing the failure.
    if "connection" not in text:
        return False
    return any(word in text for word in ("reset", "refused", "error"))
def _retry_delay_sec(e: Exception, attempt: int) -> float:
    """Compute how many seconds to sleep before the next retry.

    * 429 (ReplicateError): use the "resets in ~Ns" hint from the error text
      plus a 2 s margin, clamped to
      [RATE_LIMIT_MIN_WAIT_SEC, RATE_LIMIT_MAX_WAIT_SEC]; without a hint,
      wait RATE_LIMIT_MIN_WAIT_SEC.
    * 503 (ReplicateError) or an E003 / high-demand message:
      RETRY_DELAY_503_SEC * (attempt + 1).
    * Anything else: RETRY_DELAY_SEC * (attempt + 1).

    The backoff grows linearly with ``attempt`` (0-based).
    """
    step = attempt + 1
    status = getattr(e, "status", None) if isinstance(e, ReplicateError) else None

    if status == 429:
        hint = RATE_LIMIT_RESET_RE.search(_error_message(e))
        if hint is None:
            return RATE_LIMIT_MIN_WAIT_SEC
        suggested = int(hint.group(1)) + 2
        return min(max(suggested, RATE_LIMIT_MIN_WAIT_SEC), RATE_LIMIT_MAX_WAIT_SEC)

    if status == 503:
        return RETRY_DELAY_503_SEC * step

    text = _error_message(e).lower()
    overloaded = "e003" in text or "high demand" in text
    base_delay = RETRY_DELAY_503_SEC if overloaded else RETRY_DELAY_SEC
    return base_delay * step
def _user_friendly_error(e: Exception) -> str:
    """Translate an exception into a message suitable for showing to a user.

    Known conditions are mapped to fixed, short sentences:
      * ReplicateError with status 503 or 429;
      * E003 / high-demand messages (drops trace IDs from the raw text);
      * timeout / SSL / handshake messages.
    An empty message becomes a generic fallback; anything else is returned
    as-is.

    Args:
        e: The exception raised while generating the image.

    Returns:
        A non-empty, human-readable error string.
    """
    if isinstance(e, ReplicateError):
        status = getattr(e, "status", None)
        if status == 503:
            return "Replicate service is temporarily unavailable. Please try again in a moment."
        if status == 429:
            return "Rate limit reached for this model. Please wait a minute and try again, or generate fewer ads at once."

    msg = _error_message(e)
    if not msg:
        return "Something went wrong while generating the image. Please try again."

    # Match case-insensitively everywhere: real-world messages say "SSL",
    # "Timeout", "Handshake", etc., and the retry classifier lowercases too.
    lowered = msg.lower()
    if "e003" in lowered or "high demand" in lowered:
        return "Service is busy due to high demand. Please try again in a few minutes."
    if any(marker in lowered for marker in ("timed out", "timeout", "ssl", "handshake")):
        return "Connection timed out. Please try generating this image again."
    return msg
def _use_kie(model_key: str) -> bool:
    """Decide whether a request for ``model_key`` should go to Kie.ai.

    True only when the key is listed in KIE_MODELS and the first non-empty of
    the KIE_API_KEY / KIE_AI_API_KEY environment variables is not blank.
    """
    if model_key not in KIE_MODELS:
        return False
    for env_name in ("KIE_API_KEY", "KIE_AI_API_KEY"):
        value = os.environ.get(env_name)
        if value:
            # The first non-empty variable wins, even if it is only whitespace.
            return bool(value.strip())
    return False
def generate_image_sync(
    prompt: str,
    model_key: str = DEFAULT_MODEL,
    width: int = 1024,
    height: int = 1024,
    seed: Optional[int] = None,
    reference_image_urls: Optional[list[str]] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Run image generation (blocking) and return ``(image_url, error_message)``.

    Exactly one element of the returned tuple is set: the URL on success, or a
    user-facing error message on failure.

    Routing: model keys listed in KIE_MODELS go to Kie.ai when a Kie API key is
    configured; if the key is missing, an error is returned instead of sending
    a Kie-only model id to Replicate. Every other registered model runs on
    Replicate.

    Transient Replicate failures are retried with a sleep between attempts:
    up to MAX_RETRIES_SERVER_ERROR retries for 503 / 429 / high-demand errors,
    up to MAX_RETRIES for other retryable errors.

    Args:
        prompt: Text prompt for the image.
        model_key: Key into MODEL_REGISTRY.
        width: Requested width; sent as-is for dimension-based models,
            otherwise only used to derive an aspect-ratio string.
        height: Requested height; same treatment as ``width``.
        seed: Seed for the Replicate run. ``None`` (or 0) picks a random seed.
            Not forwarded on the Kie.ai path.
        reference_image_urls: Optional reference image URLs (product, logo,
            etc.). Blank / non-string entries are dropped; the rest are sent
            only for models in REFERENCE_IMAGE_MODELS.

    Returns:
        ``(url, None)`` on success, ``(None, message)`` on failure.
    """
    cfg = MODEL_REGISTRY.get(model_key)

    # Kie.ai-backed model with an API key configured: delegate entirely.
    if _use_kie(model_key):
        from app.kie_image import generate_image_sync as kie_generate

        kie_model = (cfg or {}).get("kie_model") or "nano-banana-pro"
        return kie_generate(
            prompt=prompt,
            model=kie_model,
            width=width,
            height=height,
            reference_image_urls=reference_image_urls,
        )

    if not cfg:
        return None, f"Unknown model: {model_key}"

    # A Kie-only model reached this point because no Kie API key is set. Its
    # registry id is not a Replicate model, so fail fast with a clear message
    # rather than letting the Replicate call fail obscurely.
    if model_key in KIE_MODELS or cfg.get("provider") == "kie":
        return None, f"Model {model_key} requires Kie.ai, but KIE_API_KEY is not set"

    # `or` (not `is None`) is deliberate here to keep existing behavior:
    # both None and 0 select a random seed.
    seed = seed or random.randint(1, 2147483647)
    input_data: dict[str, Any] = {"prompt": prompt, "seed": seed}

    urls = [u for u in (reference_image_urls or []) if u and isinstance(u, str) and u.strip()]
    if urls and model_key in REFERENCE_IMAGE_MODELS:
        ref_key = cfg.get("reference_key") or "image_input"
        # Some models take a single reference image rather than a list.
        input_data[ref_key] = urls[0] if cfg.get("reference_single") else urls

    if cfg.get("uses_dimensions"):
        input_data["width"] = width
        input_data["height"] = height
    else:
        input_data[cfg["param_name"]] = _width_height_to_aspect_ratio(width, height)

    last_error: Optional[Exception] = None
    for attempt in range(MAX_RETRIES_SERVER_ERROR + 1):
        try:
            client = _get_client()
            output = client.run(cfg["id"], input=input_data)
            url = _extract_image_url(output)
        except Exception as e:  # boundary: every failure becomes an error message
            last_error = e
            lowered = _error_message(e).lower()
            is_server_or_rate_limit = (
                isinstance(e, ReplicateError) and getattr(e, "status", None) in (503, 429)
            ) or "e003" in lowered or "high demand" in lowered
            # Server / rate-limit errors get the larger retry budget.
            max_retries = MAX_RETRIES_SERVER_ERROR if is_server_or_rate_limit else MAX_RETRIES
            if attempt < max_retries and _is_retryable_error(e):
                time.sleep(_retry_delay_sec(e, attempt))
                continue
            break
        else:
            if url:
                return url, None
            # The run finished but produced nothing usable; report it instead
            # of returning (None, None), which callers cannot tell from success.
            return None, "Image generation finished but no image URL was returned. Please try again."

    return None, _user_friendly_error(last_error)
async def generate_image(
    prompt: str,
    model_key: str = DEFAULT_MODEL,
    width: int = 1024,
    height: int = 1024,
    reference_image_urls: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> tuple[Optional[str], Optional[str]]:
    """Async wrapper around generate_image_sync that keeps the event loop free.

    The blocking generation (including its retry sleeps) runs in a worker
    thread via ``asyncio.to_thread`` and is bounded by TIMEOUT seconds.

    Args:
        prompt: Text prompt for the image.
        model_key: Key into MODEL_REGISTRY.
        width: Requested width in pixels.
        height: Requested height in pixels.
        reference_image_urls: Optional reference image URLs.
        seed: Optional seed, forwarded to generate_image_sync. Added last so
            existing positional callers are unaffected.

    Returns:
        ``(image_url, error_message)`` as produced by generate_image_sync.

    Raises:
        asyncio.TimeoutError: if generation does not finish within TIMEOUT
            seconds. NOTE(review): the worker thread is not stopped by the
            timeout and may keep running in the background.
    """
    return await asyncio.wait_for(
        asyncio.to_thread(
            generate_image_sync,
            prompt=prompt,
            model_key=model_key,
            width=width,
            height=height,
            seed=seed,
            reference_image_urls=reference_image_urls,
        ),
        timeout=TIMEOUT,
    )