Jewellery_Variation / src /streamlit_app.py
userIdc2024's picture
Update src/streamlit_app.py
f8c7098 verified
Raw
History Blame Contribute Delete
66.6 kB
import io
import json
import logging
import os
import re
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from urllib.parse import urlparse
import requests
import streamlit as st
from bs4 import BeautifulSoup
from openai import OpenAI
from PIL import Image
from design_generation import (
OUTPUT_DIR,
analyze_jewellery,
create_zip_from_images,
generate_design_directions,
generate_final_images,
generate_six_campaign_images,
save_uploaded_files,
)
from dotenv import load_dotenv
# Load variables from a local .env file (if present) into os.environ.
# NOTE(review): this runs after `design_generation` is imported above; if that
# module reads environment variables at import time it will not see values that
# only exist in .env — confirm.
load_dotenv()
# Root logging configuration: INFO level, timestamped single-line records.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%H:%M:%S",
)
# Application logger shared by the research / validation / download helpers below.
logger = logging.getLogger("amalfa")
# Maximum number of keywords kept by parse_keywords().
MAX_KEYWORDS = 40
# NOTE(review): not referenced in the visible part of this file — presumably a
# cap on reference images supplied elsewhere in the UI; confirm.
MAX_REFERENCES = 10
# Default `count` for research_products_with_deep_research().
DEFAULT_PRODUCTS_PER_KEYWORD = 10
# Preset search keywords; presumably offered as selectable options in the UI
# further down (not referenced in the visible part of this file).
# NOTE(review): the hair entries ("pearl hair clip", "claw hair clip", ...) and
# "decorative fancy buttons" conflict with the research prompt, which tells the
# model to exclude hair accessories and unrelated accessories.
KEYWORD_OPTIONS = [
"layered necklace set",
"pearl choker necklace",
"rhinestone choker",
"herringbone necklace",
"cuban link chain necklace",
"snake chain necklace",
"paperclip chain necklace",
"name nameplate necklace",
"butterfly pendant necklace",
"lariat necklace",
"huggie hoop earrings",
"chunky gold hoop earrings",
"tassel drop earrings",
"crystal statement earrings",
"ear cuff no piercing",
"threader earrings",
"stackable rings set",
"signet ring",
"tennis bracelet",
"cuban link bracelet",
"beaded charm bracelet",
"evil eye bracelet",
"cuff bangle",
"body chain",
"waist chain belly chain",
"hand chain bracelet",
"anklet",
"rhinestone bra strap shoulder chain",
"pearl hair clip",
"rhinestone hair pin",
"claw hair clip",
"hair jewelry chain",
"brooch pin",
"saree brooch pin",
"decorative fancy buttons",
"18k gold plated jewelry",
"stainless steel waterproof jewelry",
"anti tarnish jewelry",
"925 sterling silver jewelry",
"y2k jewelry",
]
# Module-level OpenAI client shared by every research/validation helper; the
# API key comes from the OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Limit concurrent OpenAI work so parallel keyword threads don't trigger 429s.
# A slot is held per *step*, not per API call: _research_keyword_pure holds one
# slot for the whole research step (web search + JSON formatting, two sequential
# calls) and one for the batch-validation step. So up to 6 such steps run at
# once across all keyword threads; additional threads block here.
_API_SEMAPHORE = threading.Semaphore(6)
# Cap total concurrent image downloads across all keyword threads.
_IMG_SEMAPHORE = threading.Semaphore(16)
# ── Page config (must be first Streamlit call) ────────────────────────────────
# Title separator and favicon use the intended "·" and 💎 characters (the
# previous literals were UTF-8 bytes mis-decoded as Windows-1253).
st.set_page_config(
    page_title="Amalfa · Jewellery AI Studio",
    page_icon="💎",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ── Custom CSS ────────────────────────────────────────────────────────────────
# Brand accent colour, interpolated into the stylesheet below (primary buttons,
# sidebar button, active step, selected tab).
GOLD = "#B8972A"
# Global stylesheet injected on every script run. The string is an f-string, so
# literal CSS braces are doubled ({{ }}) and only {GOLD} is substituted.
st.markdown(f"""
<style>
/* Background */
[data-testid="stAppViewContainer"] {{
background: #F9F8F6;
}}
[data-testid="stHeader"] {{
background: transparent;
}}
/* Typography */
h1 {{ color: #1A1A1A; font-weight: 700; letter-spacing: -0.5px; }}
h2, h3, h4 {{ color: #2C2C2C; font-weight: 600; }}
/* Sidebar */
[data-testid="stSidebar"] {{
background: #1C1C1C;
border-right: 1px solid #333;
}}
[data-testid="stSidebar"] p,
[data-testid="stSidebar"] label,
[data-testid="stSidebar"] span,
[data-testid="stSidebar"] div {{
color: #E8E0D0 !important;
}}
[data-testid="stSidebar"] h1,
[data-testid="stSidebar"] h2,
[data-testid="stSidebar"] h3 {{
color: #F5ECD7 !important;
}}
[data-testid="stSidebar"] textarea,
[data-testid="stSidebar"] input {{
background: #2A2A2A !important;
border-color: #444 !important;
color: #F0EAD8 !important;
}}
[data-testid="stSidebar"] textarea::placeholder,
[data-testid="stSidebar"] input::placeholder {{
color: #888888 !important;
opacity: 1 !important;
}}
/* Primary buttons */
.stButton > button[kind="primary"] {{
background: {GOLD};
color: white;
border: none;
border-radius: 6px;
font-weight: 600;
letter-spacing: 0.3px;
transition: background 0.2s;
}}
.stButton > button[kind="primary"]:hover {{
background: #9A7B22;
border: none;
}}
/* Sidebar sign-out button */
[data-testid="stSidebar"] .stButton > button {{
background: transparent !important;
border: 1px solid {GOLD} !important;
color: {GOLD} !important;
border-radius: 6px;
font-weight: 600;
width: 100%;
}}
[data-testid="stSidebar"] .stButton > button:hover {{
background: {GOLD} !important;
color: white !important;
}}
/* Step bar */
.step-bar {{
display: flex;
align-items: center;
margin: 1rem 0 1.5rem;
flex-wrap: wrap;
gap: 0;
}}
.step {{
display: flex;
align-items: center;
gap: 7px;
font-size: 0.82rem;
font-weight: 500;
color: #AAAAAA;
white-space: nowrap;
}}
.step.active {{
color: {GOLD};
font-weight: 700;
}}
.step.done {{ color: #5BA85A; }}
.step-num {{
width: 26px;
height: 26px;
border-radius: 50%;
background: #DDDDDD;
display: flex;
align-items: center;
justify-content: center;
font-size: 0.75rem;
font-weight: 700;
flex-shrink: 0;
color: #666;
}}
.step.active .step-num {{
background: {GOLD};
color: white;
}}
.step.done .step-num {{
background: #5BA85A;
color: white;
}}
.step-line {{
flex: 1;
height: 2px;
background: #DDDDDD;
margin: 0 8px;
min-width: 24px;
}}
/* Tabs */
.stTabs [data-baseweb="tab-list"] {{
border-bottom: 2px solid #E0DEDA;
gap: 0;
}}
.stTabs [data-baseweb="tab"] {{
font-size: 0.95rem;
font-weight: 600;
padding: 0.65rem 1.5rem;
color: #888;
border-bottom: 3px solid transparent;
}}
.stTabs [aria-selected="true"] {{
color: {GOLD} !important;
border-bottom-color: {GOLD} !important;
}}
/* Expander */
[data-testid="stExpander"] {{
border: 1px solid #E5E2DC;
border-radius: 10px;
background: white;
margin-bottom: 1rem;
box-shadow: 0 1px 4px rgba(0,0,0,0.05);
}}
[data-testid="stExpander"] summary {{
font-weight: 600;
font-size: 1rem;
color: #2C2C2C;
padding: 1rem 1.25rem;
}}
/* Download button */
[data-testid="stDownloadButton"] button {{
background: white;
border: 1px solid #C9A96E;
color: #B8872A;
border-radius: 6px;
font-weight: 500;
font-size: 0.82rem;
padding: 0.35rem 0.75rem;
transition: background 0.15s;
}}
[data-testid="stDownloadButton"] button:hover {{
background: #FDF8EE;
}}
/* Caption */
.stCaption {{
color: #888 !important;
font-size: 0.78rem !important;
}}
/* Login form card area */
.login-wrap {{
max-width: 380px;
margin: 0 auto;
padding: 2rem;
background: white;
border-radius: 16px;
border: 1px solid #E5E2DC;
box-shadow: 0 4px 24px rgba(0,0,0,0.08);
}}
</style>
""", unsafe_allow_html=True)
# ── Helpers ───────────────────────────────────────────────────────────────────
def slugify(value: str) -> str:
    """Turn *value* into a lowercase, underscore-separated identifier.

    After lowercasing and trimming, every run of characters outside
    ``[a-z0-9]`` collapses to a single underscore, and leading/trailing
    underscores are removed. Returns ``"item"`` when nothing usable remains.
    """
    collapsed = re.sub(r"[^a-z0-9]+", "_", value.lower().strip())
    trimmed = collapsed.strip("_")
    return trimmed if trimmed else "item"
def show_image_with_download(img: dict, key_prefix: str):
image_path = img["path"]
image_name = Path(image_path).name
st.image(image_path, use_container_width=True)
with open(image_path, "rb") as file:
st.download_button(
label="Download",
data=file,
file_name=image_name,
mime="image/jpeg",
key=f"download_{key_prefix}_{image_name}",
use_container_width=True,
)
def render_step_bar(labels: List[str], done: List[int], current: int) -> None:
def cls(n):
if n in done:
return "done"
if n == current:
return "active"
return ""
html = '<div class="step-bar">'
for i, label in enumerate(labels, 1):
c = cls(i)
num = "βœ“" if c == "done" else str(i)
html += f'<div class="step {c}"><div class="step-num">{num}</div><span>{label}</span></div>'
if i < len(labels):
html += '<div class="step-line"></div>'
html += "</div>"
st.markdown(html, unsafe_allow_html=True)
_SIGNAL_LABELS: Dict[str, tuple] = {
"official_bestseller": ("Bestseller", "#2A7A3C"),
"retailer_bestseller": ("Top Seller", "#2A7A3C"),
"trending": ("Trending", "#B86A10"),
"selling_fast": ("Selling Fast", "#B86A10"),
"sold_out_or_restocked":("Sells Out", "#B86A10"),
"viral_social": ("Viral", "#7033A0"),
"celebrity_worn": ("Celebrity Pick", "#7033A0"),
"editorial_feature": ("Editor's Pick", "#7033A0"),
"high_review_volume": ("Highly Reviewed", "#1A62A8"),
"top_rated": ("Top Rated", "#1A62A8"),
"iconic_long_term": ("Iconic", "#555555"),
"other_verified": ("Verified", "#888888"),
}
_CONFIDENCE_DOT = {"high": "●", "medium": "◐", "low": "β—‹"}
def popularity_badge_html(signal: str, confidence: str) -> str:
label, color = _SIGNAL_LABELS.get(signal, ("Popular", "#888888"))
dot = _CONFIDENCE_DOT.get(confidence, "")
return (
f'<span style="background:{color};color:white;font-size:0.66rem;font-weight:700;'
f'padding:2px 8px;border-radius:3px;display:inline-block;">{label}</span>'
f'<span style="font-size:0.66rem;color:{color};margin-left:4px;" '
f'title="Confidence: {confidence}">{dot}</span>'
)
def download_image_from_url(image_url: str, filename_prefix: str = "product") -> str:
headers = {
"User-Agent": (
"Mozilla/5.0 AppleWebKit/537.36 "
"(KHTML, like Gecko) Chrome/120.0 Safari/537.36"
)
}
if image_url.startswith("//"):
image_url = "https:" + image_url
response = requests.get(image_url, headers=headers, timeout=25)
response.raise_for_status()
try:
img = Image.open(io.BytesIO(response.content))
img.verify()
except Exception:
raise ValueError(f"URL did not return a valid image: {image_url}")
ext = Path(urlparse(image_url).path).suffix.lower()
if ext not in [".jpg", ".jpeg", ".png", ".webp"]:
content_type = response.headers.get("content-type", "")
if "png" in content_type:
ext = ".png"
elif "webp" in content_type:
ext = ".webp"
else:
ext = ".jpg"
output_path = OUTPUT_DIR / f"{filename_prefix}_{uuid.uuid4().hex}{ext}"
with open(output_path, "wb") as f:
f.write(response.content)
return str(output_path)
def fetch_product_image_from_page(product_url: str) -> str:
headers = {
"User-Agent": (
"Mozilla/5.0 AppleWebKit/537.36 "
"(KHTML, like Gecko) Chrome/120.0 Safari/537.36"
)
}
response = requests.get(product_url, headers=headers, timeout=25)
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
image_url = None
og_image = soup.find("meta", property="og:image")
if og_image and og_image.get("content"):
image_url = og_image["content"]
if not image_url:
twitter_image = soup.find("meta", attrs={"name": "twitter:image"})
if twitter_image and twitter_image.get("content"):
image_url = twitter_image["content"]
if not image_url:
image_meta = soup.find("meta", attrs={"itemprop": "image"})
if image_meta and image_meta.get("content"):
image_url = image_meta["content"]
if not image_url:
for img in soup.find_all("img"):
candidate = (
img.get("src")
or img.get("data-src")
or img.get("data-original")
or img.get("data-image")
)
if not candidate and img.get("srcset"):
candidate = img.get("srcset").split(",")[0].strip().split(" ")[0]
if candidate:
image_url = candidate
break
if not image_url:
raise ValueError("No product image found on product page.")
if image_url.startswith("//"):
image_url = "https:" + image_url
if image_url.startswith("/"):
parsed = urlparse(product_url)
image_url = f"{parsed.scheme}://{parsed.netloc}{image_url}"
return download_image_from_url(image_url=image_url, filename_prefix="scraped_product")
def parse_keywords(raw_keywords: str) -> List[str]:
    """Split free-text keyword input into a de-duplicated list.

    Commas and line breaks both separate keywords. Each entry is trimmed and
    blanks are dropped. Duplicates are matched case-insensitively, keeping the
    first spelling seen. The result preserves input order and holds at most
    ``MAX_KEYWORDS`` entries.
    """
    pieces = (chunk.strip() for chunk in raw_keywords.replace(",", "\n").splitlines())
    first_spelling: Dict[str, str] = {}
    for piece in pieces:
        if piece:
            first_spelling.setdefault(piece.lower(), piece)
    return list(first_spelling.values())[:MAX_KEYWORDS]
@st.cache_data(show_spinner=False)
def validate_product_with_ai(
    keyword: str,
    product_name: str,
    brand: str,
    product_url: str,
    image_url: str,
    category: str,
) -> bool:
    """Ask the model whether a single product is a genuine jewellery match.

    Results are memoised by ``st.cache_data`` on all arguments. Errors from the
    OpenAI call itself propagate to the caller. An unparseable reply, or one
    without an affirmative ``accept`` value, yields ``False``.
    """
    response = client.responses.create(
        model="gpt-5.5",
        input=(
            f"You are validating ecommerce product research for a jewellery-only app.\n\n"
            f"Keyword: {keyword}\n"
            f"Product name: {product_name}\n"
            f"Brand: {brand}\n"
            f"Product URL: {product_url}\n"
            f"Image URL: {image_url}\n"
            f"Category: {category}\n\n"
            f"Return ONLY valid JSON:\n"
            f'{{"accept": true, "reason": ""}}\n\n'
            f"Accept only if:\n"
            f"- The product is a jewellery product.\n"
            f"- The product matches the keyword intent.\n"
            f"- The product URL looks like a real product page.\n"
            f"- The image URL, if present, appears related to the product.\n"
            f"- It is not a shoe, bag, clothing, beauty, perfume, watch, or unrelated accessory.\n"
        ),
        text={"format": {"type": "json_object"}},
    )
    try:
        data = json.loads(response.output_text)
    except (ValueError, TypeError) as exc:
        logger.warning("[%s] unparseable validation reply for %r: %s", keyword, product_name, exc)
        return False
    accept = data.get("accept") if isinstance(data, dict) else None
    # bool("false") is True, so string answers must be interpreted explicitly.
    if isinstance(accept, str):
        return accept.strip().lower() in {"true", "yes", "1"}
    return bool(accept)
def batch_validate_products(
keyword: str,
products: List[Dict[str, Any]],
) -> List[bool]:
"""Validate a batch of products in a single AI call.
Products already marked high-confidence by the research step are auto-accepted."""
if not products:
return []
results: List[Optional[bool]] = [None] * len(products)
to_validate_indices: List[int] = []
for i, p in enumerate(products):
if p.get("popularity_confidence") == "high":
results[i] = True # already verified by research β€” skip re-validation
else:
to_validate_indices.append(i)
if not to_validate_indices:
return results # type: ignore
candidates = [products[i] for i in to_validate_indices]
product_list = "\n".join(
f"{j + 1}. Name: {p.get('product_name', '')} | Brand: {p.get('brand', '')} "
f"| Category: {p.get('category', '')} | Signal: {p.get('popularity_signal', '')} "
f"| Confidence: {p.get('popularity_confidence', '')} | URL: {p.get('product_url', '')}"
for j, p in enumerate(candidates)
)
response = client.responses.create(
model="gpt-5.5",
input=(
f"You are validating products for a jewellery-only app. Keyword: '{keyword}'\n\n"
f"{product_list}\n\n"
f"Return ONLY valid JSON: {{\"valid\": [1, 3, ...]}} β€” 1-based indices of products that:\n"
f"- Are jewellery (necklace, ring, earring, bracelet, pendant, anklet, etc.)\n"
f"- Match the keyword intent\n"
f"- Are NOT shoes, bags, clothing, beauty, perfume, watches, or unrelated accessories\n"
f"- Have a plausible product page URL\n"
),
text={"format": {"type": "json_object"}},
)
try:
valid_set = set(json.loads(response.output_text).get("valid", []))
for list_pos, original_idx in enumerate(to_validate_indices):
results[original_idx] = (list_pos + 1) in valid_set
except Exception:
for idx in to_validate_indices:
results[idx] = True
return results # type: ignore
@st.cache_data(show_spinner=False)
def cached_get_product_card_image(
    image_url: str,
    product_url: str,
    product_name: str,
) -> Optional[str]:
    """Fetch a product card image, trying the direct image URL first.

    Falls back to scraping the product page when the direct URL is missing or
    fails. Failures are logged, never raised (best effort).

    Returns:
        Path of the saved image, or ``None`` when neither source yields one.
        Note that ``st.cache_data`` memoises ``None`` too, so a failed lookup is
        not retried with the same arguments while the cache entry lives.
    """
    image_url = image_url.strip() if image_url else ""
    product_url = product_url.strip() if product_url else ""
    product_name = product_name.strip() if product_name else "product"

    if image_url:
        try:
            path = download_image_from_url(
                image_url=image_url,
                # Cap the slug so long model-generated names cannot exceed
                # filesystem filename limits once the uuid suffix is added.
                filename_prefix=slugify(product_name)[:60],
            )
        except Exception as exc:  # best effort: fall through to the product page
            logger.warning("Direct image download failed for %s: %s", image_url, exc)
        else:
            if path and Path(path).exists():
                return path

    if product_url:
        try:
            path = fetch_product_image_from_page(product_url)
        except Exception as exc:  # best effort: report "no image" below
            logger.warning("Product-page image scrape failed for %s: %s", product_url, exc)
        else:
            if path and Path(path).exists():
                return path

    return None
# Canonical values for the normalised popularity fields.
_CONFIDENCE_LEVELS = {"high", "medium", "low"}
_RECENCY_LEVELS = {"current", "recent", "long_term", "unknown"}


def _normalize_and_dedupe_products(
    products: List[Any],
    count: int,
) -> List[Dict[str, Any]]:
    """Clean up model-returned product dicts and keep at most *count* of them.

    Drops non-dict entries and entries missing a name, brand or product URL.
    Keeps only the first product per URL (case-insensitive, trailing slash
    ignored) and per brand (case-insensitive). Input order is preserved.

    Surviving dicts are normalised in place:
    - name, brand and URL are stripped;
    - ``popularity_confidence`` becomes high/medium/low (default ``"low"``);
    - ``popularity_recency`` becomes current/recent/long_term/unknown
      (default ``"unknown"``);
    - a present but non-list ``tags`` value becomes ``[]``.
    """
    unique_products: List[Dict[str, Any]] = []
    seen_urls: set = set()
    seen_brands: set = set()
    for product in products:
        if len(unique_products) >= count:
            break
        if not isinstance(product, dict):
            continue
        product_name = str(product.get("product_name") or "").strip()
        brand = str(product.get("brand") or "").strip()
        product_url = str(product.get("product_url") or "").strip()
        if not (product_name and brand and product_url):
            continue
        normalized_url = product_url.lower().rstrip("/")
        normalized_brand = brand.lower()
        # One product per brand for maximum diversity. A separate (name, brand)
        # check is unnecessary: any repeat of that pair also repeats the brand.
        if normalized_url in seen_urls or normalized_brand in seen_brands:
            continue

        # Always write canonical values back so downstream exact comparisons
        # (e.g. popularity_confidence == "high") see "high", not " High ".
        confidence = str(product.get("popularity_confidence") or "").strip().lower()
        recency = (
            str(product.get("popularity_recency") or "")
            .strip()
            .lower()
            .replace("-", "_")
            .replace(" ", "_")
        )
        product["product_name"] = product_name
        product["brand"] = brand
        product["product_url"] = product_url
        product["popularity_confidence"] = (
            confidence if confidence in _CONFIDENCE_LEVELS else "low"
        )
        product["popularity_recency"] = (
            recency if recency in _RECENCY_LEVELS else "unknown"
        )
        if not isinstance(product.get("tags", []), list):
            product["tags"] = []

        seen_urls.add(normalized_url)
        seen_brands.add(normalized_brand)
        unique_products.append(product)
    return unique_products


def research_products_with_deep_research(
    keyword: str,
    count: int = DEFAULT_PRODUCTS_PER_KEYWORD,
    exclude_names: Optional[List[str]] = None,
    exclude_brands: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Research internationally popular, trending, bestselling jewellery products.

    The function performs:
        1. Web-grounded product research (web_search_preview tool call).
        2. Structured JSON normalization (second, JSON-mode call).
        3. Basic validation and deduplication (_normalize_and_dedupe_products).

    Args:
        keyword: Search keyword; surrounding whitespace is stripped.
        count: Maximum number of products to return; ``<= 0`` returns ``[]``.
        exclude_names: Product names the model is told not to return (first 100).
        exclude_brands: Brands the model is told not to use (first 100).

    Returns:
        List of product dictionaries (see the JSON structure in the format call).

    Raises:
        ValueError: empty keyword, or the format call did not return valid JSON.
    """
    keyword = keyword.strip()
    if not keyword:
        raise ValueError("keyword cannot be empty")
    if count <= 0:
        return []

    exclude_names = [
        str(name).strip()
        for name in (exclude_names or [])
        if str(name).strip()
    ]
    exclude_brands = [
        str(brand).strip()
        for brand in (exclude_brands or [])
        if str(brand).strip()
    ]
    exclusion_sections: List[str] = []
    if exclude_names:
        exclusion_sections.append(
            "Do not include any of these already collected products:\n"
            + "\n".join(f"- {name}" for name in exclude_names[:100])
        )
    if exclude_brands:
        exclusion_sections.append(
            "Do not include products from these already used brands:\n"
            + "\n".join(f"- {brand}" for brand in exclude_brands[:100])
        )
    exclude_text = (
        "\n\n".join(exclusion_sections)
        if exclusion_sections
        else "No additional product or brand exclusions were provided."
    )

    # NOTE(review): the prompt excludes "hair accessories", which contradicts
    # hair-related entries in KEYWORD_OPTIONS (e.g. "claw hair clip").
    research_prompt = f"""
You are a jewellery ecommerce product research specialist.
Search the web for up to {count} unique jewellery products relevant to the
keyword "{keyword}".
PRIMARY OBJECTIVE
Find products that have genuine current popularity signals. Prioritize products
that are currently bestselling, trending, frequently purchased, viral,
waitlisted, selling out, highly reviewed, editor-featured, celebrity-worn,
social-media popular, or prominently placed in a retailer's bestseller or
trending collection.
Do not select a product only because its brand is famous. The individual
product should have a verifiable popularity signal whenever possible.
POPULARITY EVIDENCE
For every selected product, identify at least one verifiable popularity signal,
such as:
- Listed on an official "Bestsellers", "Most Popular", "Trending",
"Most Loved", "Top Rated", or "Selling Fast" page.
- Explicitly marked as a bestseller, trending product, popular pick,
frequently purchased item, low-stock item, or selling-fast item.
- A substantial number of customer ratings or reviews.
- Recent editorial coverage describing the product as viral, cult-favorite,
trending, iconic, highly demanded, or frequently sold out.
- Recent social-media, celebrity, influencer, or fashion-trend attention.
- Prominent placement on a major retailer's jewellery bestseller page.
- Evidence of repeat sell-outs, restocks, waitlists, or strong customer demand.
Never state that a product is a bestseller, viral, trending, or highly popular
unless the available source supports that statement.
SOURCE PRIORITY
Use sources in this order:
1. Official brand product pages.
2. Official brand bestseller or trending collections.
3. Major international retailers with dedicated product pages.
4. Reputable fashion, jewellery, lifestyle, or commerce publications.
5. Reputable retailer bestseller rankings or customer-review evidence.
PRODUCT REQUIREMENTS
- Include jewellery products only.
- The product must be directly relevant to "{keyword}".
- Exclude Indian brands.
- Exclude brands founded in India, headquartered primarily in India, or
predominantly identified as Indian jewellery brands.
- Exclude watches and smart watches.
- Exclude shoes, bags, clothing, perfume, beauty products, eyewear,
hair accessories, and unrelated accessories.
- Every result must represent a distinct jewellery product.
- Every result must have a unique product-page URL.
- Do not return the same product from different retailers.
- Do not return colour, size, metal, gemstone, finish, or material variants of
the same base product as separate products.
- Prefer one product per brand.
- Only use a second product from the same brand when there are not enough
strong alternatives and the designs are materially different.
- Return a mix of brands, jewellery styles, design types, materials, and price
points.
- Prefer products that are currently purchasable.
- Do not use homepages, search pages, collection pages, category pages, or
social-media posts as the product_url.
- A popularity evidence page may be a separate URL from the product page.
DATA TO COLLECT
For every product collect:
- product_name
- brand
- product_url
- image_url
- approximate price
- currency
- category
- tags
- why the product is popular
- popularity signal type
- popularity evidence URL
- popularity evidence summary
- whether the popularity evidence is current
- confidence in the popularity claim: high, medium, or low
IMAGE RULES
- image_url must be a direct product image URL when one can be verified.
- Do not place the product-page URL in image_url.
- Do not invent an image URL.
- Leave the image URL empty when a reliable direct URL cannot be found.
PRICE RULES
- Prefer the price displayed on the selected product page.
- Preserve the displayed currency.
- Prefer the normal listed price.
- When only a sale price is available, use the sale price.
- Do not perform currency conversion.
- Do not invent missing prices.
RECENCY AND RANKING
Prioritize current demand signals over old popularity.
Prefer evidence from the last 12 months. Older evidence may be used for an
enduring iconic or consistently bestselling product, but clearly identify it as
long-term popularity rather than a current trend.
Rank the final products using this order:
1. Strong current bestseller or trending evidence.
2. Strong sales-demand evidence such as sell-outs, waitlists, or restocks.
3. High review volume or top-rated retailer evidence.
4. Recent reputable editorial, celebrity, or social-trend evidence.
5. Long-term iconic popularity.
If fewer than {count} products can be verified, return fewer products.
Never fabricate products or evidence merely to meet the requested count.
ADDITIONAL EXCLUSIONS
{exclude_text}
Return concise but sufficiently detailed research. Include the product URL and
popularity evidence URL for every product whenever available.
""".strip()

    t0 = time.time()
    logger.info("[%s] Starting web search (target %d products)", keyword, count)
    # web_search_preview cannot be combined with JSON mode — two calls are required.
    search_response = client.responses.create(
        model="gpt-5.5",
        input=research_prompt,
        tools=[{"type": "web_search_preview"}],
    )
    search_text = search_response.output_text.strip()
    logger.info("[%s] Search call completed in %.1fs", keyword, time.time() - t0)
    if not search_text:
        return []

    t1 = time.time()
    format_response = client.responses.create(
        model="gpt-5.5",
        input=(
            f"Convert this jewellery research into a JSON object for keyword '{keyword}'.\n\n"
            f"{search_text}\n\n"
            f"Return ONLY valid JSON with this structure — no markdown, no commentary:\n"
            f'{{ "keyword": "{keyword}", "products": [ {{ "product_name": "", '
            f'"brand": "", "product_url": "", "image_url": "", "approx_price": "", '
            f'"currency": "", "why_popular": "", "popularity_signal": "", '
            f'"popularity_evidence_url": "", "popularity_evidence_summary": "", '
            f'"popularity_recency": "", "popularity_confidence": "", '
            f'"category": "", "tags": [] }} ] }}'
        ),
        text={"format": {"type": "json_object"}},
    )
    raw_text = format_response.output_text.strip()
    logger.info("[%s] Format call completed in %.1fs", keyword, time.time() - t1)

    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Research response was not valid JSON:\n{raw_text}") from exc
    # JSON mode should yield an object, but tolerate a bare array/scalar.
    products = data.get("products", []) if isinstance(data, dict) else []
    if not isinstance(products, list):
        return []
    return _normalize_and_dedupe_products(products, count)
def _research_keyword_pure(
keyword: str,
target_count: int,
on_progress: Optional[Callable[[int, int, str], None]] = None,
) -> List[Dict[str, Any]]:
"""Research one keyword in a single pass β€” no retry rounds."""
t_start = time.time()
# ── 1. fetch candidates ───────────────────────────────────────────────────
if on_progress:
on_progress(0, target_count, f"'{keyword}' β€” searching…")
try:
with _API_SEMAPHORE:
candidates = research_products_with_deep_research(
keyword=keyword,
count=max(target_count * 3, 15),
)
except Exception as e:
logger.error("[%s] research call failed: %s", keyword, e)
return []
logger.info("[%s] %d candidates returned in %.1fs", keyword, len(candidates), time.time() - t_start)
if not candidates:
return []
# ── 2. batch AI validation ────────────────────────────────────────────────
if on_progress:
on_progress(0, target_count, f"'{keyword}' β€” validating {len(candidates)} candidates…")
high_conf = sum(1 for p in candidates if p.get("popularity_confidence") == "high")
logger.info("[%s] validating %d candidates (%d high-conf auto-pass)", keyword, len(candidates), high_conf)
try:
with _API_SEMAPHORE:
valid_flags = batch_validate_products(keyword, candidates)
except Exception as e:
logger.error("[%s] validation failed: %s β€” accepting all", keyword, e)
valid_flags = [True] * len(candidates)
valid_products = [p for p, ok in zip(candidates, valid_flags) if ok]
logger.info("[%s] %d/%d passed validation", keyword, len(valid_products), len(candidates))
if not valid_products:
return []
# ── 3. parallel image downloads ───────────────────────────────────────────
if on_progress:
on_progress(0, target_count, f"'{keyword}' β€” downloading images…")
def _dl(p: Dict[str, Any]):
with _IMG_SEMAPHORE:
return p, cached_get_product_card_image(
image_url=p.get("image_url", ""),
product_url=p.get("product_url", ""),
product_name=p.get("product_name", "product"),
)
products_with_images: List[Dict[str, Any]] = []
seen_images: set = set()
# 4 workers per keyword β€” multiple keywords run in parallel so keep this
# modest to avoid overwhelming the image hosts.
with ThreadPoolExecutor(max_workers=min(len(valid_products), 4)) as pool:
dl_futures = {pool.submit(_dl, p): p for p in valid_products}
for future in as_completed(dl_futures):
if len(products_with_images) >= target_count:
break
try:
product, image_path = future.result()
if not image_path or not Path(image_path).exists():
continue
image_key = str(Path(image_path).resolve())
if image_key in seen_images:
continue
seen_images.add(image_key)
product["local_image_path"] = image_path
product["source_keyword"] = keyword
product["id"] = uuid.uuid4().hex
products_with_images.append(product)
if on_progress:
on_progress(len(products_with_images), target_count,
f"'{keyword}' β€” {len(products_with_images)}/{target_count} found")
except Exception as e:
logger.warning("[%s] image download error: %s", keyword, e)
elapsed = time.time() - t_start
logger.info("[%s] done: %d products in %.1fs", keyword, len(products_with_images), elapsed)
return products_with_images[:target_count]
def research_until_image_count(keyword: str, target_count: int) -> List[Dict[str, Any]]:
    """Research a single keyword while showing a Streamlit progress bar.

    The bar tracks ``found / target`` (capped at 1.0) and is removed once
    research finishes, including when it raises.
    """
    progress = st.progress(0, text=f"Searching for '{keyword}'…")

    def _update(found: int, target: int, msg: str) -> None:
        fraction = min(found / target, 1.0) if target else 0.0
        progress.progress(fraction, text=msg)

    try:
        return _research_keyword_pure(keyword, target_count, on_progress=_update)
    finally:
        progress.empty()
def research_products_for_keywords(
    keywords: List[str],
    count_per_keyword: int,
) -> List[Dict[str, Any]]:
    """Research all keywords in parallel (up to 10 concurrent workers).

    An empty *keywords* list returns ``[]``. A single keyword is delegated to
    research_until_image_count. Otherwise keywords run in a thread pool while
    the main thread renders an overall progress bar and a per-keyword status
    summary. Both widgets are cleared afterwards, even on error.

    Returns:
        All products, grouped in the order of *keywords*. A keyword whose
        worker raised contributes no products.
    """
    if not keywords:
        # ThreadPoolExecutor(max_workers=0) would raise ValueError.
        return []
    if len(keywords) == 1:
        return research_until_image_count(keywords[0], count_per_keyword)

    n = len(keywords)
    # Up to 10 keywords researched simultaneously.
    # _API_SEMAPHORE (6 slots) caps concurrent OpenAI steps even though up to
    # 10 threads may be active — extra threads queue at the semaphore.
    workers = min(n, 10)
    logger.info("Starting parallel research: %d keywords, %d workers", n, workers)
    overall = st.progress(0, text=f"Researching {n} keywords in parallel…")
    status_area = st.empty()

    # Per-keyword state, updated only by the main thread (as_completed loop).
    kw_state: Dict[str, str] = {kw: "pending" for kw in keywords}
    kw_counts: Dict[str, int] = {}

    def _render_status() -> None:
        done_kws = [kw for kw in keywords if kw_state[kw] == "done"]
        fail_kws = [kw for kw in keywords if kw_state[kw] == "failed"]
        pend_kws = [kw for kw in keywords if kw_state[kw] == "pending"]
        total_products = sum(kw_counts.values())
        lines: List[str] = []
        if done_kws:
            badges = " · ".join(
                f"`{kw}` ({kw_counts.get(kw, 0)})" for kw in done_kws
            )
            lines.append(f"✅ **Done ({len(done_kws)}):** {badges}")
        if fail_kws:
            lines.append("❌ **Failed:** " + " · ".join(f"`{kw}`" for kw in fail_kws))
        if pend_kws:
            lines.append(f"⏳ **In progress / queued:** {len(pend_kws)} keyword(s)")
        lines.append(f"**Products collected so far:** {total_products}")
        # Two trailing spaces = markdown hard line break.
        status_area.markdown("  \n".join(lines))

    results: Dict[str, List[Dict[str, Any]]] = {}
    try:
        # as_completed iterates in the main thread — safe for all st.* calls.
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = {
                pool.submit(_research_keyword_pure, kw, count_per_keyword): kw
                for kw in keywords
            }
            for future in as_completed(futures):
                kw = futures[future]
                try:
                    products = future.result()
                    results[kw] = products
                    kw_state[kw] = "done"
                    kw_counts[kw] = len(products)
                    logger.info("Keyword '%s' complete: %d products", kw, len(products))
                except Exception as e:
                    results[kw] = []
                    kw_state[kw] = "failed"
                    logger.error("Keyword '%s' failed: %s", kw, e)
                done_count = sum(1 for s in kw_state.values() if s in ("done", "failed"))
                overall.progress(
                    done_count / n,
                    text=f"Completed {done_count}/{n} keywords…",
                )
                _render_status()
    finally:
        overall.empty()
        status_area.empty()

    all_products: List[Dict[str, Any]] = []
    for kw in keywords:
        all_products.extend(results.get(kw, []))
    logger.info("All keywords done: %d total products", len(all_products))
    return all_products
# ── Authentication ─────────────────────────────────────────────────────────────
# Credentials are read from the environment. If either variable is unset or
# empty, sign-in is refused outright rather than accepting an empty submission.
APP_USERNAME = os.getenv("APP_USERNAME")
APP_PASSWORD = os.getenv("APP_PASSWORD")


def _credentials_match(username: Optional[str], password: Optional[str]) -> bool:
    """Return True when *username* and *password* match the configured credentials.

    Both values are compared as UTF-8 bytes with ``hmac.compare_digest``, so
    response timing does not reveal how many leading characters matched. Both
    comparisons always run. Returns False when either credential is not
    configured.
    """
    import hmac  # stdlib; imported locally to keep this section self-contained

    if not APP_USERNAME or not APP_PASSWORD:
        return False
    user_ok = hmac.compare_digest((username or "").encode("utf-8"), APP_USERNAME.encode("utf-8"))
    pass_ok = hmac.compare_digest((password or "").encode("utf-8"), APP_PASSWORD.encode("utf-8"))
    return user_ok and pass_ok


if not st.session_state.get("authenticated"):
    st.markdown(
        """
        <div style="text-align:center;padding-top:3rem;padding-bottom:1.5rem;">
            <div style="font-size:2.5rem;">πŸ’Ž</div>
            <h1 style="font-size:1.9rem;letter-spacing:3px;margin:0.25rem 0 0.5rem;">AMALFA</h1>
            <p style="color:#999;font-size:0.95rem;margin:0;">Jewellery AI Studio</p>
        </div>
        """,
        unsafe_allow_html=True,
    )
    _, col_c, _ = st.columns([1, 1.1, 1])
    with col_c:
        with st.form("login_form"):
            st.markdown("#### Sign In")
            username = st.text_input("Username", placeholder="Enter your username")
            password = st.text_input("Password", type="password", placeholder="Enter your password")
            submitted = st.form_submit_button("Sign In", use_container_width=True, type="primary")
        if submitted:
            if not APP_USERNAME or not APP_PASSWORD:
                st.error("Sign-in is not configured. Set APP_USERNAME and APP_PASSWORD.")
            elif _credentials_match(username, password):
                st.session_state["authenticated"] = True
                st.rerun()
            else:
                st.error("Invalid username or password.")
    st.stop()
# ── Sidebar ────────────────────────────────────────────────────────────────────
# Branding, the per-keyword product count setting, limits info and sign-out.
with st.sidebar:
    sidebar_brand_html = (
        "<div style='font-size:1.4rem;font-weight:700;letter-spacing:2px;color:#F5ECD7;margin-bottom:0.25rem;'>πŸ’Ž AMALFA</div>"
        "<div style='font-size:0.75rem;color:#888;margin-bottom:1.5rem;'>Jewellery AI Studio</div>"
    )
    st.markdown(sidebar_brand_html, unsafe_allow_html=True)
    st.markdown("---")
    st.markdown("**Research Settings**")
    # Read by the research workflow below as the target product count per keyword.
    count_per_keyword = st.number_input(
        "Products per keyword",
        help="How many products to fetch per keyword (max 10)",
        value=DEFAULT_PRODUCTS_PER_KEYWORD,
        min_value=1,
        max_value=10,
    )
    st.markdown("---")
    sidebar_limits_html = (
        "<div style='font-size:0.72rem;color:#555;margin-bottom:1rem;'>"
        f"Max keywords: {MAX_KEYWORDS} Β· Max refs: {MAX_REFERENCES}"
        "</div>"
    )
    st.markdown(sidebar_limits_html, unsafe_allow_html=True)
    sign_out_clicked = st.button("Sign Out", use_container_width=True)
    if sign_out_clicked:
        # Dropping all session state also drops the "authenticated" flag.
        st.session_state.clear()
        st.rerun()
# ── Page header ────────────────────────────────────────────────────────────────
page_title_md = "## Jewellery Variation Generator"
page_subtitle = "Generate professional AI-powered jewellery variations from keywords or your own reference images."
st.markdown(page_title_md)
st.caption(page_subtitle)
st.markdown("---")
# ── Main tabs ──────────────────────────────────────────────────────────────────
# Two independent workflows, each keeping its own session-state keys.
main_tab_labels = ["Research from Keywords", "Upload Reference Images"]
tab_research, tab_upload = st.tabs(main_tab_labels)
# ══════════════════════════════════════════════════════════════════════════════
# RESEARCH WORKFLOW
# ══════════════════════════════════════════════════════════════════════════════
with tab_research:
    # A step counts as complete once the session-state key it produces exists.
    # The "current" step is the one after the furthest completed step.
    r_done: List[int] = [
        step_no
        for step_no, state_key in enumerate(
            (
                "researched_products",
                "processed_product_refs",
                "design_options_per_product",
                "all_final_images",
            ),
            start=1,
        )
        if state_key in st.session_state
    ]
    r_current = max(r_done) + 1 if r_done else 1
    render_step_bar(
        ["Research Products", "Select Products", "Generate Designs", "Final Images"],
        r_done,
        r_current,
    )
# ── Step 1: Research ──────────────────────────────────────────────────────
with st.expander("Step 1 Β· Research Products from Keywords", expanded=(r_current == 1)):
selected_keywords = st.multiselect(
"Select keywords to research",
options=KEYWORD_OPTIONS,
key="kw_dropdown",
help=f"Pick any of the {len(KEYWORD_OPTIONS)} jewellery keywords to research. "
"Selected keywords are combined with anything typed below.",
)
st.markdown("Or enter your own keywords β€” one per line (or comma-separated). Each keyword is searched independently.")
raw_keywords = st.text_area(
"Keywords",
placeholder=(
"layered necklace set\n"
"pearl choker necklace\n"
"minimal gold earrings\n"
"stackable rings\n"
"silver charm bracelet"
),
height=130,
label_visibility="collapsed",
key="kw_input",
)
combined_raw_keywords = "\n".join(selected_keywords + [raw_keywords])
col_btn, col_info = st.columns([1, 3])
with col_btn:
research_btn = st.button("Search Products", type="primary", use_container_width=True, key="btn_research")
with col_info:
if combined_raw_keywords.strip():
kws = parse_keywords(combined_raw_keywords)
total = len(kws) * count_per_keyword
st.caption(f"{len(kws)} keyword(s) Β· up to {total} products Β· {count_per_keyword}/keyword")
if research_btn:
keywords = parse_keywords(combined_raw_keywords)
if not keywords:
st.error("Please enter at least one keyword.")
else:
try:
with st.status(f"Researching {len(keywords)} keyword(s)…", expanded=True) as status_ui:
products = research_products_for_keywords(
keywords=keywords,
count_per_keyword=count_per_keyword,
)
st.session_state["researched_products"] = products
st.session_state["research_keywords_order"] = keywords
for key in ["processed_product_refs", "design_options_per_product",
"frozen_selected_campaigns", "all_final_images"]:
st.session_state.pop(key, None)
status_ui.update(label=f"Done β€” {len(products)} products found.", state="complete")
st.rerun()
except Exception as e:
st.error(f"Research failed: {e}")
# ── Step 2: Select Products ───────────────────────────────────────────────
if "researched_products" in st.session_state:
products = st.session_state["researched_products"]
step2_label = f"Step 2 Β· Select Products ({len(products)} found)"
with st.expander(step2_label, expanded=(r_current == 2)):
if not products:
st.warning("No products with usable images found. Try different keywords.")
else:
st.markdown(f"**{len(products)} products found.** Tick the ones you want to process.")
selected_product_ids: set = set()
cols_per_row = 4
# Group products by source keyword, preserving research order
kw_order = st.session_state.get("research_keywords_order", [])
kw_products: Dict[str, List] = {}
for p in products:
kw = p.get("source_keyword", "Other")
kw_products.setdefault(kw, []).append(p)
ordered_kws = [kw for kw in kw_order if kw in kw_products]
for kw in kw_products:
if kw not in ordered_kws:
ordered_kws.append(kw)
for kw in ordered_kws:
kw_prods = kw_products[kw]
st.markdown(f"**{kw}** β€” {len(kw_prods)} product(s)")
rows = [kw_prods[i : i + cols_per_row] for i in range(0, len(kw_prods), cols_per_row)]
for row in rows:
cols = st.columns(cols_per_row)
for col, product in zip(cols, row):
with col:
pid = product.get("id", "")
pname = product.get("product_name", "Unnamed")
brand = product.get("brand", "")
purl = product.get("product_url", "")
local_image_path = product.get("local_image_path")
signal = product.get("popularity_signal", "")
confidence = product.get("popularity_confidence", "")
why_popular = product.get("why_popular", "")
approx_price = product.get("approx_price", "")
currency = product.get("currency", "")
if local_image_path and Path(local_image_path).exists():
st.image(local_image_path, use_container_width=True)
checked = st.checkbox(
pname[:32] + ("…" if len(pname) > 32 else ""),
key=f"sel_{pid}",
help=f"{pname} by {brand}",
)
st.caption(brand)
if signal:
st.markdown(
popularity_badge_html(signal, confidence),
unsafe_allow_html=True,
)
if why_popular:
st.caption(why_popular[:90] + ("…" if len(why_popular) > 90 else ""))
row_parts = []
if approx_price:
row_parts.append(f"{currency} {approx_price}".strip())
if purl:
row_parts.append(f"[View β†—]({purl})")
if row_parts:
st.markdown(" Β· ".join(row_parts))
if checked:
selected_product_ids.add(pid)
st.markdown("---")
col_proc, col_cnt = st.columns([1, 3])
with col_proc:
process_btn = st.button(
f"Process {len(selected_product_ids)} Selected" if selected_product_ids else "Select products above",
type="primary",
use_container_width=True,
disabled=not selected_product_ids,
key="btn_process",
)
with col_cnt:
if selected_product_ids:
st.caption(f"{len(selected_product_ids)} product(s) selected")
if process_btn and selected_product_ids:
selected = [p for p in products if p.get("id") in selected_product_ids]
processed_refs = [
{"product": p, "image_path": p["local_image_path"]}
for p in selected
if p.get("local_image_path") and Path(p["local_image_path"]).exists()
]
st.session_state["processed_product_refs"] = processed_refs
for key in ["design_options_per_product", "frozen_selected_campaigns", "all_final_images"]:
st.session_state.pop(key, None)
st.rerun()
# ── Step 3: Generate Designs + Pick Direction ─────────────────────────────
if "processed_product_refs" in st.session_state:
processed_refs = st.session_state["processed_product_refs"]
with st.expander(
f"Step 3 Β· Generate & Choose Design Directions ({len(processed_refs)} product(s))",
expanded=(r_current == 3),
):
st.markdown(
f"**{len(processed_refs)} product(s)** ready. "
"Generate 6 design direction previews per product, then pick one direction per product."
)
if st.button("Generate Design Directions", type="primary", key="btn_gen_designs"):
design_options_per_product: Dict[str, Any] = {}
for item in processed_refs:
product = item["product"]
image_path = item["image_path"]
pname = product.get("product_name", "Product")
brand = product.get("brand", "")
pkey = slugify(f"{pname}_{brand}")
with st.status(f"Processing {pname}…", expanded=False) as s:
try:
s.update(label=f"Analysing {pname}…")
analysis = analyze_jewellery(image_path)
s.update(label=f"Building directions for {pname}…")
directions = generate_design_directions(analysis)
s.update(label=f"Rendering previews for {pname}…")
design_options = generate_six_campaign_images(
reference_images=[image_path],
analysis=analysis,
directions=directions,
)
design_options_per_product[pkey] = {
"product": product,
"image_path": image_path,
"analysis": analysis,
"design_options": design_options,
}
s.update(label=f"{pname} β€” directions ready", state="complete")
except Exception as e:
s.update(label=f"{pname} β€” failed", state="error")
st.error(f"{pname}: {e}")
st.session_state["design_options_per_product"] = design_options_per_product
for key in ["frozen_selected_campaigns", "all_final_images"]:
st.session_state.pop(key, None)
st.rerun()
if "design_options_per_product" in st.session_state:
design_opts = st.session_state["design_options_per_product"]
selected_campaigns: Dict[str, Any] = {}
for pkey, data in design_opts.items():
product = data["product"]
design_options = data["design_options"]
pname = product.get("product_name", pkey)
brand = product.get("brand", "")
st.markdown(f"#### {pname} β€” {brand}")
st.caption("Review the 6 design directions, then choose one below.")
d_cols = st.columns(3)
for idx, design in enumerate(design_options):
with d_cols[idx % 3]:
st.image(design["path"], use_container_width=True)
st.caption(design["name"])
chosen = st.radio(
"Choose direction",
options=[d["name"] for d in design_options],
key=f"radio_{pkey}",
horizontal=True,
)
selected_campaign = next(d for d in design_options if d["name"] == chosen)
selected_campaigns[pkey] = {
"campaign": selected_campaign,
"analysis": data["analysis"],
"product": product,
"image_path": data["image_path"],
}
st.divider()
st.markdown(f"**{len(selected_campaigns)} direction(s) selected.** Add any design instructions, then generate.")
user_prompt = st.text_area(
"Design instructions (optional)",
placeholder="e.g. use a rose gold finish, add sapphire accents, keep it minimal...",
help="Describe design changes you want β€” material, colour, stone, style tweaks.",
key="r_design_prompt",
)
if st.button("Generate Final Images", type="primary", key="btn_gen_final"):
st.session_state["frozen_selected_campaigns"] = {
k: dict(v) for k, v in selected_campaigns.items()
}
all_final_images: Dict[str, Any] = {}
frozen = st.session_state["frozen_selected_campaigns"]
for pkey, data in frozen.items():
product = data["product"]
campaign = data["campaign"]
analysis = data["analysis"]
pname = product.get("product_name", pkey)
with st.status(f"Generating final images for {pname}…", expanded=False) as s:
try:
final_images = generate_final_images(
reference_images=[campaign["path"]],
selected_campaign=campaign,
analysis=analysis,
user_prompt=user_prompt,
)
all_final_images[pkey] = {"product": product, "final_images": final_images}
s.update(label=f"{pname} β€” done", state="complete")
except Exception as e:
s.update(label=f"{pname} β€” failed", state="error")
st.error(f"{pname}: {e}")
st.session_state["all_final_images"] = all_final_images
st.rerun()
# ── Step 4: Final Images ──────────────────────────────────────────────────
if "all_final_images" in st.session_state:
with st.expander("Step 4 Β· Final Generated Images", expanded=True):
for pkey, data in st.session_state["all_final_images"].items():
product = data["product"]
final_images = data["final_images"]
pname = product.get("product_name", pkey)
brand = product.get("brand", "")
col_h, col_dl = st.columns([3, 1])
with col_h:
st.markdown(f"#### {pname}")
st.caption(brand)
with col_dl:
zip_path = str(OUTPUT_DIR / f"final_{pkey}.zip")
create_zip_from_images(final_images, zip_path)
with open(zip_path, "rb") as zf:
st.download_button(
"⬇ Download ZIP",
data=zf,
file_name=f"final_{pkey}.zip",
mime="application/zip",
key=f"dl_zip_{pkey}",
use_container_width=True,
)
flat_lay = [img for img in final_images if img["type"] == "Flat lay"]
model = [img for img in final_images if img["type"] == "Model"]
product_closeup = [img for img in final_images if img["type"] == "Product closeup"]
model_closeup = [img for img in final_images if img["type"] == "Model closeup"]
img_tabs = st.tabs(["Flat Lay", "Model", "Product Closeup", "Model Closeup"])
for img_tab, imgs, prefix in zip(
img_tabs,
[flat_lay, model, product_closeup, model_closeup],
["flat", "model", "pclose", "mclose"],
):
with img_tab:
if imgs:
cols = st.columns(min(len(imgs), 3))
for i, img in enumerate(imgs):
with cols[i % 3]:
show_image_with_download(img, f"{prefix}_{pkey}_{i}")
else:
st.caption("No images of this type.")
st.divider()
# ══════════════════════════════════════════════════════════════════════════════
# UPLOAD WORKFLOW
# ══════════════════════════════════════════════════════════════════════════════
with tab_upload:
    # Each step is complete once the session-state key it produces exists.
    # The "current" step follows the furthest completed one.
    upload_step_keys = ("upload_ref_paths", "upload_campaign_options", "upload_final_images")
    u_done: List[int] = [
        number
        for number, state_key in enumerate(upload_step_keys, start=1)
        if state_key in st.session_state
    ]
    u_current = max(u_done) + 1 if u_done else 1
    render_step_bar(
        ["Upload Images", "Choose Direction", "Final Images"],
        u_done,
        u_current,
    )
# ── Step 1: Upload ────────────────────────────────────────────────────────
with st.expander("Step 1 Β· Upload Reference Images", expanded=(u_current == 1)):
st.markdown(
f"Upload up to **{MAX_REFERENCES}** reference jewellery images. "
"The **first image** is used as the primary reference for AI analysis."
)
upload_files = st.file_uploader(
"Reference Images",
type=["png", "jpg", "jpeg", "webp"],
accept_multiple_files=True,
key="upload_flow2",
label_visibility="collapsed",
)
if upload_files:
upload_files = upload_files[:MAX_REFERENCES]
new_file_names = [f.name for f in upload_files]
if st.session_state.get("upload_file_names") != new_file_names:
upload_ref_paths = save_uploaded_files(upload_files)
st.session_state["upload_file_names"] = new_file_names
st.session_state["upload_ref_paths"] = upload_ref_paths
for key in ["upload_analysis", "upload_directions", "upload_campaign_options",
"upload_selected_campaign", "upload_final_images"]:
st.session_state.pop(key, None)
else:
upload_ref_paths = st.session_state.get("upload_ref_paths", [])
st.markdown(f"**{len(upload_ref_paths)} image(s) uploaded.**")
preview_cols = st.columns(min(len(upload_ref_paths), 5))
for idx, image_path in enumerate(upload_ref_paths):
with preview_cols[idx % 5]:
caption = "Primary Reference" if idx == 0 else f"Reference {idx + 1}"
st.image(image_path, caption=caption, use_container_width=True)
st.markdown("---")
if st.button("Generate Design Options", type="primary", key="btn_upload_gen"):
upload_primary = upload_ref_paths[0]
with st.status("Generating design options…", expanded=True) as s:
try:
s.update(label="Analysing jewellery image…")
analysis = analyze_jewellery(upload_primary)
st.session_state["upload_analysis"] = analysis
s.update(label="Generating design directions…")
directions = generate_design_directions(analysis)
st.session_state["upload_directions"] = directions
s.update(label="Rendering preview images…")
campaign_options = generate_six_campaign_images(
reference_images=upload_ref_paths,
analysis=analysis,
directions=directions,
)
st.session_state["upload_campaign_options"] = campaign_options
s.update(label="Design options ready!", state="complete")
st.rerun()
except Exception as e:
s.update(label="Generation failed.", state="error")
st.error(f"{e}")
# ── Step 2: Choose Direction ──────────────────────────────────────────────
if "upload_campaign_options" in st.session_state:
campaign_options = st.session_state["upload_campaign_options"]
with st.expander("Step 2 Β· Choose a Design Direction", expanded=(u_current == 2)):
st.markdown("Review the 6 design directions and select one to generate your final campaign images.")
d_cols = st.columns(3)
for idx, campaign in enumerate(campaign_options):
with d_cols[idx % 3]:
st.image(campaign["path"], use_container_width=True)
st.caption(campaign["name"])
selected_name = st.radio(
"Select direction",
options=[c["name"] for c in campaign_options],
key="upload_radio",
horizontal=True,
)
selected_campaign = next(c for c in campaign_options if c["name"] == selected_name)
st.session_state["upload_selected_campaign"] = selected_campaign
st.markdown("---")
col_prev, col_action = st.columns([1, 2])
with col_prev:
st.image(selected_campaign["path"], caption=f"Selected: {selected_campaign['name']}", use_container_width=True)
with col_action:
st.markdown(f"**Selected direction:** {selected_campaign['name']}")
upload_user_prompt = st.text_area(
"Design instructions (optional)",
placeholder="e.g. use a rose gold finish, add sapphire accents, keep it minimal...",
help="Describe design changes you want β€” material, colour, stone, style tweaks.",
key="u_design_prompt",
)
if st.button("Generate Final Images", type="primary", key="btn_upload_final"):
with st.status("Generating final images…", expanded=True) as s:
try:
final_images = generate_final_images(
reference_images=[selected_campaign["path"]],
selected_campaign=selected_campaign,
analysis=st.session_state["upload_analysis"],
user_prompt=upload_user_prompt,
)
st.session_state["upload_final_images"] = final_images
s.update(label="Final images ready!", state="complete")
st.rerun()
except Exception as e:
s.update(label="Generation failed.", state="error")
st.error(f"{e}")
# ── Step 3: Final Images ──────────────────────────────────────────────────
if "upload_final_images" in st.session_state:
final_images = st.session_state["upload_final_images"]
with st.expander("Step 3 Β· Final Generated Images", expanded=True):
col_h, col_dl = st.columns([3, 1])
with col_h:
st.markdown("#### Final Images")
st.caption("8 images across 4 photography styles.")
with col_dl:
zip_path = str(OUTPUT_DIR / "upload_final_generated_images.zip")
create_zip_from_images(final_images, zip_path)
with open(zip_path, "rb") as zf:
st.download_button(
"⬇ Download ZIP",
data=zf,
file_name="final_generated_images.zip",
mime="application/zip",
key="upload_dl_zip",
use_container_width=True,
)
flat_lay = [img for img in final_images if img["type"] == "Flat lay"]
model = [img for img in final_images if img["type"] == "Model"]
product_closeup = [img for img in final_images if img["type"] == "Product closeup"]
model_closeup = [img for img in final_images if img["type"] == "Model closeup"]
img_tabs = st.tabs(["Flat Lay", "Model", "Product Closeup", "Model Closeup"])
for img_tab, imgs, prefix in zip(
img_tabs,
[flat_lay, model, product_closeup, model_closeup],
["uflat", "umodel", "upclose", "umclose"],
):
with img_tab:
if imgs:
cols = st.columns(min(len(imgs), 3))
for i, img in enumerate(imgs):
with cols[i % 3]:
show_image_with_download(img, f"{prefix}_{i}")
else:
st.caption("No images of this type.")