# Standard library
import io
import json
import logging
import os
import threading
import time
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, date
from functools import lru_cache
from typing import Dict, Any, List, Tuple, Optional
from urllib.parse import urlparse
from uuid import uuid4

# Third-party
import boto3
import replicate  # type: ignore
import requests
import streamlit as st
from dotenv import load_dotenv
from pymongo import MongoClient

# Pull credentials/config from a local .env file into the process environment.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("imagegen_app")

# ----------------------------
# Config / Constants
# ----------------------------
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
MONGO_URI = os.getenv("MONGO_URI")
MONGO_DB = os.getenv("MONGO_DB", "adgenesis_image_text")
MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")

REQUEST_TIMEOUT = 30  # seconds per outbound HTTP request
RETRY_ATTEMPTS = 3    # image-download retries before giving up
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
LIBRARY_PAGE_SIZE = 200  # aggregated view

# Global throttle so we don't overload Replicate / network
# This caps total in-flight image generations across all prompts.
# Cap on simultaneous in-flight generations across ALL prompts (4..16).
GLOBAL_CONCURRENCY = max(4, min(16, MAX_WORKERS))
_GEN_SEMAPHORE = threading.Semaphore(GLOBAL_CONCURRENCY)

# Supported models: Replicate model id, allowed aspect ratios, and the name of
# the input parameter that carries the aspect ratio.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    "imagegen-4-ultra": {"id": "google/imagen-4-ultra", "aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"], "param_name": "aspect_ratio"},
    "imagen-4": {"id": "google/imagen-4", "aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"], "param_name": "aspect_ratio"},
    "nano-banana": {"id": "google/nano-banana", "aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"], "param_name": "aspect_ratio"},
    "qwen": {"id": "qwen/qwen-image", "aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3"], "param_name": "aspect_ratio"},
    "seedream-3": {"id": "bytedance/seedream-3", "aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3", "21:9"], "param_name": "aspect_ratio"},
    "recraft-v3": {"id": "recraft-ai/recraft-v3", "aspect_ratios": ["1:1", "4:3", "3:4", "3:2", "2:3", "16:9", "9:16"], "param_name": "aspect_ratio"},
}

# Per-thread client cache: Replicate/Mongo/S3 clients are not guaranteed
# thread-safe, so each worker thread lazily builds its own.
_thread_local = threading.local()


# ----------------------------
# Preflight/debug helpers
# ----------------------------
def show_env_warnings():
    """Surface missing/optional environment configuration in the Streamlit UI."""
    if not REPLICATE_API_TOKEN:
        st.warning("Missing **REPLICATE_API_TOKEN** — generation will return ‘No URLs’.", icon="⚠️")
    for v in ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME", "NEW_BASE"]:
        if not os.getenv(v):
            st.info(f"Optional: {v} not set → images won’t be copied to R2 (source URLs will be used).", icon="ℹ️")
    if not MONGO_URI:
        st.info("Optional: MONGO_URI not set → results won’t be saved to the Creative Library.", icon="ℹ️")


# ----------------------------
# Clients (Replicate / Mongo / S3)
# ----------------------------
def get_replicate_client():
    """Return a per-thread Replicate client, or None if no API token is set."""
    if not hasattr(_thread_local, "replicate_client"):
        _thread_local.replicate_client = replicate.Client(api_token=REPLICATE_API_TOKEN) if REPLICATE_API_TOKEN else None
    return _thread_local.replicate_client


def get_mongo_collection():
    """Return a per-thread Mongo collection handle, or None if unavailable.

    Pings the server once at construction so a bad URI fails fast; the
    (collection or None) result is cached on the thread for reuse.
    """
    if not hasattr(_thread_local, "mongo_collection"):
        if not MONGO_URI:
            _thread_local.mongo_collection = None
            return None
        try:
            client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
            db = client[MONGO_DB]
            coll = db[MONGO_COLLECTION]
            client.admin.command("ping")  # fail fast on unreachable server
            _thread_local.mongo_collection = coll
        except Exception as e:
            logger.error(f"MongoDB connection failed: {e}")
            _thread_local.mongo_collection = None
    return _thread_local.mongo_collection


def get_s3_client():
    """Return a per-thread S3-compatible (Cloudflare R2) client, or None.

    Requires all R2_* env vars plus NEW_BASE (public URL base); if any is
    missing, R2 upload is disabled and None is cached for this thread.
    """
    if not hasattr(_thread_local, "s3_client"):
        required = ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME", "NEW_BASE"]
        if any(not os.getenv(k) for k in required):
            _thread_local.s3_client = None
            return None
        try:
            _thread_local.s3_client = boto3.client(
                "s3",
                endpoint_url=os.getenv("R2_ENDPOINT"),
                aws_access_key_id=os.getenv("R2_ACCESS_KEY"),
                aws_secret_access_key=os.getenv("R2_SECRET_KEY"),
                region_name="auto",
            )
        except Exception as e:
            logger.error(f"S3 client init failed: {e}")
            _thread_local.s3_client = None
    return _thread_local.s3_client


# ----------------------------
# Core ops: R2 / Generate / Fetch
# ----------------------------
def upload_to_r2(image_bytes: bytes) -> Optional[str]:
    """Upload PNG bytes to R2 under a unique key; return the public URL or None.

    BUGFIX: the key previously hard-coded the placeholder text "(unknown)"
    (the generated `filename` was never used), so every upload overwrote the
    same object. The key now embeds the per-image UUID filename.
    """
    s3 = get_s3_client()
    if s3 is None:
        return None
    try:
        filename = f"{uuid4().hex}.png"
        file_key = f"adgenesis_image_file/balraaj/images/{filename}"
        s3.put_object(
            Bucket=os.getenv("R2_BUCKET_NAME"),
            Key=file_key,
            Body=image_bytes,
            ContentType="image/png",
        )
        return f"{os.getenv('NEW_BASE').rstrip('/')}/{file_key}"
    except Exception as e:
        logger.error(f"S3 upload failed: {e}")
        return None


def generate_one(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
    """
    Returns: [image_url] from Replicate (or [])

    Normalizes the three output shapes Replicate models return here:
    a list of URLs, a single URL string, or a file-like object with `.url`.
    """
    client = get_replicate_client()
    if client is None:
        return []
    model = MODEL_REGISTRY.get(model_key)
    if not model:
        return []
    try:
        output = client.run(model["id"], input={"prompt": prompt, model["param_name"]: aspect_ratio})
        if isinstance(output, list) and output:
            return [str(output[0])]
        if isinstance(output, str):
            return [output]
        if hasattr(output, "url"):
            return [getattr(output, "url")]
        return []
    except Exception as e:
        logger.error(f"Replicate error: {e}")
        return []


def fetch_bytes(url: str) -> Optional[bytes]:
    """Download `url` with up to RETRY_ATTEMPTS tries; return bytes or None.

    Uses a context manager so the connection is released even on failure
    (the previous version left streamed responses unclosed).
    """
    for attempt in range(RETRY_ATTEMPTS):
        try:
            with requests.get(url, timeout=REQUEST_TIMEOUT, stream=True) as r:
                r.raise_for_status()
                return r.content
        except Exception:
            if attempt == RETRY_ATTEMPTS - 1:
                return None
            time.sleep(1)
    return None


def _generate_single_image_full(prompt: str, model: str, aspect: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Throttled: generate -> fetch -> upload -> return final_url or src_url
    Returns (final_url_or_src, error_text)
    """
    with _GEN_SEMAPHORE:  # throttle: caps total in-flight generations
        urls = generate_one(model, prompt, aspect)
        if not urls:
            return None, "No URLs"
        src = urls[0]
        data = fetch_bytes(src)
        if data is None:
            return None, "Fetch failed"
        # Prefer the durable R2 copy; fall back to the source URL.
        r2 = upload_to_r2(data)
        return (r2 or src), None


def process_prompt_batch(idx: int, prompt: str, model: str, aspect: str, num_images: int) -> Dict[str, Any]:
    """
    Generate N images for one prompt.
    Uses a small threadpool but still obeys GLOBAL throttle.

    Returns {"idx": idx, "urls": [...], "errors": [...]}.
    """
    n = max(1, int(num_images))
    urls: List[str] = []
    errs: List[str] = []
    # limit inner parallelism to keep global pressure sane
    inner_workers = min(min(4, n), GLOBAL_CONCURRENCY)
    if n == 1:
        u, e = _generate_single_image_full(prompt, model, aspect)
        if u:
            urls.append(u)
        if e:
            errs.append(e)
    else:
        with ThreadPoolExecutor(max_workers=inner_workers) as ex:
            futures = [ex.submit(_generate_single_image_full, prompt, model, aspect) for _ in range(n)]
            for fut in as_completed(futures):
                u, e = fut.result()
                if u:
                    urls.append(u)
                if e:
                    errs.append(e)
    return {"idx": idx, "urls": urls, "errors": errs}


# ----------------------------
# Persistence
# ----------------------------
def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
    """Insert one generation record into Mongo; return inserted id str or None."""
    coll = get_mongo_collection()
    if coll is None:
        return None
    try:
        return str(coll.insert_one({
            "model": model,
            "aspect_ratio": aspect,
            "prompt": prompt,
            "urls": urls,
            "num_images": len(urls),
            "lob": "balraaj",
            # NOTE: utcnow() is naive and deprecated in 3.12; kept for
            # compatibility with existing stored documents/queries.
            "created_at": datetime.utcnow(),
        }).inserted_id)
    except Exception as e:
        logger.error(f"Mongo insert failed: {e}")
        return None


@st.cache_data(ttl=300)
def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
    """Fetch recent records in [start, end), newest first, capped at page size."""
    coll = get_mongo_collection()
    if coll is None:
        return []
    try:
        return list(coll.find(
            {"created_at": {"$gte": start, "$lt": end}, "lob": "balraaj"}
        ).sort("created_at", -1).limit(LIBRARY_PAGE_SIZE))
    except Exception as e:
        logger.error(f"Mongo query failed: {e}")
        return []


# ----------------------------
# Gallery helpers
# ----------------------------
def display_gallery(urls: List[str]):
    """Render images in a 4-column grid; individual failures don't abort the grid."""
    if not urls:
        return
    cols = st.columns(4)
    for i, u in enumerate(urls):
        with cols[i % 4]:
            try:
                img = fetch_bytes(u)
                if img:
                    st.image(img, use_container_width=True)
            except Exception:
                st.error("Failed to load image")


def bulk_zip(urls: List[str]):
    """Offer a single ZIP download containing every fetchable image."""
    if not urls:
        return
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as z:
        for i, u in enumerate(urls, 1):
            data = fetch_bytes(u)
            if data:
                z.writestr(f"image_{i}.png", data)
    buf.seek(0)
    st.download_button("Download All", buf, "images.zip", "application/zip", use_container_width=True)


# ----------------------------
# JSON loader & batch run (parallel across prompts)
# ----------------------------
def load_json(file) -> List[str]:
    """Parse an uploaded JSON file of shape {"prompts": [str, ...]}.

    Raises ValueError on wrong shape or when no non-empty string prompts remain.
    """
    data = json.loads(file.getvalue().decode("utf-8"))
    if not isinstance(data, dict) or "prompts" not in data or not isinstance(data["prompts"], list):
        raise ValueError("JSON must be { 'prompts': [ '...','...' ] } (strings only).")
    out = [p.strip() for p in data["prompts"] if isinstance(p, str) and p.strip()]
    if not out:
        raise ValueError("No valid prompts found.")
    return out


def run_batch(prompts: List[str], model: str, aspect: str, num_images: int):
    """Fan prompts out across a threadpool, stream per-prompt status rows,
    persist successes, then render a combined gallery + ZIP download."""
    total = len(prompts)
    rows = [st.empty() for _ in prompts]
    progress = st.progress(0.0, text=f"0/{total}")
    all_urls: List[str] = []
    done = 0
    outer_workers = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
    with st.spinner("Generating images..."):
        with ThreadPoolExecutor(max_workers=outer_workers) as ex:
            # futs maps future -> 1-based prompt index for status reporting
            futs = {ex.submit(process_prompt_batch, i, p, model, aspect, num_images): i
                    for i, p in enumerate(prompts, 1)}
            for fut in as_completed(futs):
                i = futs[fut]
                try:
                    res = fut.result()
                except Exception as e:
                    res = {"idx": i, "urls": [], "errors": [str(e)]}
                if res["urls"]:
                    save_record(model, aspect, prompts[i-1], res["urls"])
                    rows[i-1].success(f"Prompt {i}/{total} ✓ ({len(res['urls'])} images)")
                    all_urls.extend(res["urls"])
                else:
                    err_msg = ", ".join(res.get("errors") or ["No images"])
                    rows[i-1].error(f"Prompt {i}/{total} ✗ ({err_msg})")
                done += 1
                progress.progress(done/total, text=f"{done}/{total}")
    if all_urls:
        st.subheader("Gallery")
        display_gallery(all_urls)
        bulk_zip(all_urls)
    else:
        st.info("No images to display.")


# ----------------------------
# Pages
# ----------------------------
def render_json_page():
    """Page: upload a prompts JSON, pick model/aspect/count, run the batch."""
    st.subheader("Generate from JSON")
    show_env_warnings()
    up = st.file_uploader("Upload JSON", type=["json"])
    c1, c2, c3 = st.columns([1, 1, 1])
    with c1:
        model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
    with c2:
        # Aspect options depend on the chosen model.
        aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
    with c3:
        num_images = st.slider("Images per prompt", min_value=1, max_value=50, value=1, step=1)
    if up:
        try:
            prompts = load_json(up)
            with st.expander("Preview prompts", expanded=False):
                st.json(prompts)
            if st.button("Generate", type="primary", use_container_width=True):
                run_batch(prompts, model, aspect, num_images)
        except Exception as e:
            st.error(str(e))
    else:
        st.caption('Expected: { "prompts": ["prompt 1", "prompt 2", ...] }')


def render_library_page():
    """Page: browse previously generated images within a date range."""
    st.subheader("Creative Library")
    show_env_warnings()
    today = datetime.utcnow().date()
    start = st.date_input("Start date", today - timedelta(days=30))
    end = st.date_input("End date", today)
    start_dt = datetime.combine(start, datetime.min.time())
    # End is inclusive in the UI, so query up to midnight of the next day.
    end_dt = datetime.combine(end + timedelta(days=1), datetime.min.time())
    records = query_records(start_dt, end_dt)
    all_urls: List[str] = []
    for r in records:
        all_urls.extend(r.get("urls", []) or [])
    if all_urls:
        st.caption(f"Showing {len(all_urls)} images from {len(records)} records")
        display_gallery(all_urls)
        bulk_zip(all_urls)
    else:
        st.info("No records found in the selected range.")


# ----------------------------
# Auth
# ----------------------------
@lru_cache(maxsize=1)
def check_token(tok: str) -> bool:
    """Validate an access token against ACCESS_TOKEN; allow all if unset.

    Uses hmac.compare_digest for a constant-time comparison so the check
    doesn't leak the token length/prefix via timing (previously plain ==).
    """
    import hmac  # local import: stdlib, avoids touching the module header
    acc = os.getenv("ACCESS_TOKEN")
    if not acc:
        # No token configured → allow dev access
        return True
    return hmac.compare_digest(tok, acc)


def main_app():
    """Top-level app shell: page config, title, sidebar navigation."""
    st.set_page_config(page_title="File-to-Image • Creative Library", layout="wide")
    st.title("File-to-Image Generator")
    page = st.sidebar.radio("Menu", ["Generate from JSON", "Creative Library"])
    if page == "Generate from JSON":
        render_json_page()
    else:
        render_library_page()


def main():
    """Gate the app behind a token prompt stored in session state."""
    if not st.session_state.get("auth"):
        st.markdown("## Access Required")
        token = st.text_input("Enter Access Token", type="password")
        if st.button("Unlock App"):
            if check_token(token):
                st.session_state["auth"] = True
                st.rerun()
            else:
                st.error("Invalid token.")
    else:
        main_app()


if __name__ == "__main__":
    main()