# Amalfa_Creative_Studio / backend/app/template_flow.py — sushilideaclan01, commit d5c6701
"""
Template-based creative generation flow used by the app API.
Flow:
1) Scrape product URL and pick product image
2) Send one fixed prompt + ordered references directly to image model
3) Return final generated image URLs
"""
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from app.replicate_image import generate_image_sync
from app.scraper import scrape_product
DEFAULT_MODEL_KEY = "nano-banana-2"
REPLICATE_API_KEY = os.getenv("REPLICATE_API_KEY") or os.getenv("REPLICATE_API_TOKEN") or ""
GENERATION_MAX_ATTEMPTS = 3
GENERATION_RETRY_DELAY_SEC = 4
TEMPLATE_FLOW_MAX_WORKERS = max(1, min(8, int(os.getenv("TEMPLATE_FLOW_MAX_WORKERS", "4"))))
log = logging.getLogger("uvicorn.error")
VISION_USER_PROMPT = """Use the first (left) image as the design template and layout reference. Create a high-converting advertisement for the product shown in the second (middle) image. and third (last) image is the brand logo of this jewellery brand product.
Maintain the same structure, typography style, and visual hierarchy from the template, but adapt it creatively to fit the new product.
Focus on:
Clean product placement
Short, benefit-driven copy (not generic)
Modern, premium aesthetic
Feel free to enhance colors, lighting, and composition to make the product stand out and look more desirable."""
def _is_url(value: str) -> bool:
return value.startswith("http://") or value.startswith("https://")
def scrape_product_image_url(product_url: str) -> tuple[str, dict]:
    """Scrape *product_url* and return its first usable product image URL.

    The scraper's ``product_images`` field is treated as a comma-separated
    string; entries are whitespace-stripped and the first one with an
    http(s) scheme is chosen.

    Returns:
        ``(image_url, scraped_data)`` where ``scraped_data`` is the raw
        scraper result.

    Raises:
        ValueError: if no http(s) URL is present in ``product_images``.
    """
    scraped = scrape_product(product_url)
    raw_images = scraped.get("product_images") or ""
    for candidate in raw_images.split(","):
        candidate = candidate.strip()
        if candidate and _is_url(candidate):
            return candidate, scraped
    raise ValueError("No valid product image URL found after scraping.")
def _generate_one_with_retries(
    prompt: str,
    refs: list[str],
    width: int,
    height: int,
    model_key: str,
    output_idx: int,
    num_outputs: int,
) -> str:
    """Generate a single image, retrying up to GENERATION_MAX_ATTEMPTS times.

    ``output_idx`` and ``num_outputs`` are used only for log messages.

    Returns:
        The generated image URL.

    Raises:
        RuntimeError: if every attempt fails; carries the last error message.
    """
    last_err = "Image generation failed."
    for attempt in range(1, GENERATION_MAX_ATTEMPTS + 1):
        log.info(
            "template_flow: generation attempt output=%d/%d attempt=%d/%d",
            output_idx + 1,
            num_outputs,
            attempt,
            GENERATION_MAX_ATTEMPTS,
        )
        url, err = generate_image_sync(
            prompt=prompt,
            model_key=model_key,
            width=width,
            height=height,
            reference_image_urls=refs,
        )
        if url and not err:
            log.info(
                "template_flow: generation success output=%d/%d attempt=%d url=%s",
                output_idx + 1,
                num_outputs,
                attempt,
                url,
            )
            return url
        last_err = err or "Image generation returned no URL."
        log.warning(
            "template_flow: generation failed output=%d/%d attempt=%d err=%s",
            output_idx + 1,
            num_outputs,
            attempt,
            last_err,
        )
        # No point sleeping after the final attempt.
        if attempt < GENERATION_MAX_ATTEMPTS:
            time.sleep(GENERATION_RETRY_DELAY_SEC)
    raise RuntimeError(f"Image generation failed after {GENERATION_MAX_ATTEMPTS} attempts: {last_err}")


def generate_with_nano_banana(
    base_prompt: str,
    reference_image_urls: list[str],
    width: int,
    height: int,
    num_outputs: int,
    model_key: str = DEFAULT_MODEL_KEY,
) -> list[str]:
    """Generate *num_outputs* images from one prompt and a list of reference images.

    Each output is generated independently via ``generate_image_sync`` with up
    to ``GENERATION_MAX_ATTEMPTS`` attempts, sleeping
    ``GENERATION_RETRY_DELAY_SEC`` seconds between failed attempts.

    Side effect: sets ``REPLICATE_API_TOKEN`` in ``os.environ`` to
    ``REPLICATE_API_KEY`` when that key is non-empty.

    Args:
        base_prompt: Text prompt sent unchanged for every output.
        reference_image_urls: Reference image URLs in the order the prompt
            refers to them. Empty and duplicate entries are dropped; the first
            occurrence keeps its position.
        width: Requested output width, forwarded to the model.
        height: Requested output height, forwarded to the model.
        num_outputs: Number of images to generate; 0 or negative yields ``[]``.
        model_key: Model identifier forwarded to ``generate_image_sync``.

    Returns:
        Generated image URLs, one per output, in generation order.

    Raises:
        RuntimeError: if any output still fails after all attempts (URLs of
            earlier successful outputs are not returned).
    """
    # Only export a non-empty key: assigning "" would wipe a
    # REPLICATE_API_TOKEN that was set after this module was imported.
    if REPLICATE_API_KEY:
        os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_KEY
    # Order matters because the prompt refers to references by position, so
    # de-duplicate while keeping first-seen order.
    refs = list(dict.fromkeys(u for u in reference_image_urls if u))
    log.info(
        "template_flow: generation start model=%s refs=%d size=%sx%s outputs=%d",
        model_key,
        len(refs),
        width,
        height,
        num_outputs,
    )
    return [
        _generate_one_with_retries(base_prompt, refs, width, height, model_key, output_idx, num_outputs)
        for output_idx in range(num_outputs)
    ]
def run_template_based_creatives(
    product_url: str,
    template_image_url: str,
    additional_template_image_urls: list[str] | None,
    product_image_urls: list[str] | None,
    logo_image_url: str | None,
    num_outputs: int,
    width: int,
    height: int,
    model_key: str = DEFAULT_MODEL_KEY,
) -> dict:
    """Generate ad creatives for one product from one or more design templates.

    One image is generated per (template, product reference) pair. The jobs run
    concurrently on a thread pool of at most ``TEMPLATE_FLOW_MAX_WORKERS``
    workers, and results are returned in template-major order regardless of
    completion order.

    Args:
        product_url: Product page to scrape. It is always scraped for metadata,
            and also for the product image when *product_image_urls* contains
            no usable URL.
        template_image_url: Primary design template; required.
        additional_template_image_urls: Extra templates. Empty entries and
            duplicates are ignored.
        product_image_urls: User-selected product images. Entries that are not
            http(s) URL strings are ignored. When any remain, they replace the
            scraped image and each one becomes a separate variant.
        logo_image_url: Optional brand logo, sent as the third reference.
        num_outputs: Not used by this flow; exactly one image is produced per
            pair. Kept for API compatibility.
        width: Requested output width.
        height: Requested output height.
        model_key: Image model identifier.

    Returns:
        A dict with ``images`` (generated URLs), ``analysis``/``analyses``
        (per-template prompt info), ``meta`` (request summary) and
        ``product_data`` (raw scraper result).

    Raises:
        ValueError: if *template_image_url* is empty, or if no product image
            is available from either the caller or the scraper.
        Exception: the first failure raised by a generation job is re-raised
            after pending jobs are cancelled.
    """
    if not template_image_url:
        raise ValueError("template_image_url is required")
    log.info(
        "template_flow: run start product_url=%s template_url=%s additional_templates=%d selected_product_refs=%d logo=%s",
        product_url,
        template_image_url,
        len(additional_template_image_urls or []),
        len(product_image_urls or []),
        bool(logo_image_url),
    )

    selected_product_refs = [u for u in (product_image_urls or []) if isinstance(u, str) and _is_url(u)]
    if selected_product_refs:
        # The caller already chose product images, so the scraped image is not
        # needed. Scrape only for metadata rather than failing when the page
        # exposes no usable image.
        product_data = scrape_product(product_url)
        product_image_url = selected_product_refs[0]
    else:
        product_image_url, product_data = scrape_product_image_url(product_url)

    # De-duplicate templates while preserving order (primary template first).
    ordered_templates = list(
        dict.fromkeys(u for u in [template_image_url, *(additional_template_image_urls or [])] if u)
    )
    product_variants = selected_product_refs or [product_image_url]

    analyses: list[dict] = [
        {
            "template_analysis": "direct_template_prompt_mode",
            "image_generation_prompt": VISION_USER_PROMPT,
        }
        for _ in ordered_templates
    ]

    # NOTE(review): VISION_USER_PROMPT always describes a third (logo) image,
    # even when no logo_image_url is supplied and only two references are
    # sent. Confirm whether a logo-less prompt variant is needed.
    logo_refs = [logo_image_url] if logo_image_url else []
    # Built template-major, which is also the final output order.
    jobs: list[tuple[int, int, list[str]]] = [
        (t_idx, p_idx, [template_ref, product_ref, *logo_refs])
        for t_idx, template_ref in enumerate(ordered_templates)
        for p_idx, product_ref in enumerate(product_variants)
    ]

    def _run_variant_job(t_idx: int, p_idx: int, refs: list[str]) -> list[str]:
        """Generate the single image for one (template, product) pair."""
        log.info(
            "template_flow: variant generation template=%d/%d product_ref=%d/%d",
            t_idx + 1,
            len(ordered_templates),
            p_idx + 1,
            len(product_variants),
        )
        return generate_with_nano_banana(
            VISION_USER_PROMPT,
            refs,
            width,
            height,
            1,
            model_key=model_key,
        )

    # Results are keyed by (template, product) so the output order stays
    # deterministic even when workers finish out of order.
    results: dict[tuple[int, int], list[str]] = {}
    if jobs:
        max_workers = min(TEMPLATE_FLOW_MAX_WORKERS, len(jobs))
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            future_to_key = {
                pool.submit(_run_variant_job, t_idx, p_idx, refs): (t_idx, p_idx)
                for t_idx, p_idx, refs in jobs
            }
            for fut in as_completed(future_to_key):
                try:
                    results[future_to_key[fut]] = fut.result()
                except Exception:
                    log.exception("template_flow: one parallel generation job failed")
                    # Cancel jobs that have not started, so one failure does not
                    # keep spending generation calls. Leaving the `with` block
                    # still waits for jobs that are already running.
                    pool.shutdown(wait=False, cancel_futures=True)
                    raise

    output_urls: list[str] = []
    for t_idx, p_idx, _refs in jobs:
        output_urls.extend(results.get((t_idx, p_idx), []))

    log.info(
        "template_flow: run complete product=%s templates=%d product_variants=%d generated_images=%d",
        product_data.get("product_name", ""),
        len(ordered_templates),
        len(product_variants),
        len(output_urls),
    )
    return {
        "images": output_urls,
        "analysis": analyses[0] if analyses else {},
        "analyses": analyses,
        "meta": {
            "product_url": product_url,
            "selected_product_image_url": product_image_url,
            "selected_product_image_urls": selected_product_refs,
            "template_image_url": template_image_url,
            "template_image_urls": ordered_templates,
            "additional_template_image_urls": additional_template_image_urls or [],
            "logo_image_url": logo_image_url,
            "product_name": product_data.get("product_name", ""),
            "model_key": model_key,
            "num_outputs": len(output_urls),
            "width": width,
            "height": height,
        },
        "product_data": product_data,
    }