Spaces:
Running
Running
| import os | |
| import io | |
| import zipfile | |
| import time | |
| import logging | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from datetime import datetime, timedelta, date | |
| from typing import Dict, Any, List, Tuple, Optional | |
| import json | |
| import threading | |
| from functools import lru_cache | |
| from urllib.parse import urlparse | |
| import requests | |
| import streamlit as st | |
| from pymongo import MongoClient | |
| import boto3 | |
| import replicate # type: ignore | |
| from uuid import uuid4 | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file (if present) into os.environ before the
# configuration below reads them.
load_dotenv()
logging.basicConfig(level="INFO")
logger = logging.getLogger("imagegen_app")
| # ---------------------------- | |
| # Config / Constants | |
| # ---------------------------- | |
# Credentials and storage settings, read once at import time.
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
MONGO_URI = os.environ.get("MONGO_URI")
MONGO_DB = os.environ.get("MONGO_DB", "adgenesis_image_text")
MONGO_COLLECTION = os.environ.get("MONGO_COLLECTION", "creatives")

# HTTP download behaviour: per-request timeout (seconds) and number of tries.
REQUEST_TIMEOUT = 30
RETRY_ATTEMPTS = 3

# Same formula ThreadPoolExecutor uses for its default max_workers (Python 3.8-3.12).
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
LIBRARY_PAGE_SIZE = 200  # max records returned for the aggregated library view

# Process-wide cap on in-flight image generations across all prompts, so a big
# batch cannot overload Replicate / the network. MAX_WORKERS clamped to [4, 16].
GLOBAL_CONCURRENCY = min(max(MAX_WORKERS, 4), 16)
_GEN_SEMAPHORE = threading.Semaphore(GLOBAL_CONCURRENCY)

_STANDARD_RATIOS = ["1:1", "16:9", "9:16", "3:4", "4:3"]


def _model_entry(model_id: str, ratios: List[str]) -> Dict[str, Any]:
    """Build one MODEL_REGISTRY entry; each entry gets its own copy of *ratios*."""
    return {"id": model_id, "aspect_ratios": list(ratios), "param_name": "aspect_ratio"}


# UI key -> Replicate model id, allowed aspect ratios, and the input name the
# ratio is sent under. Insertion order is the order of the Model dropdown.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    "imagegen-4-ultra": _model_entry("google/imagen-4-ultra", _STANDARD_RATIOS),
    "imagen-4": _model_entry("google/imagen-4", _STANDARD_RATIOS),
    "nano-banana": _model_entry("google/nano-banana", _STANDARD_RATIOS),
    "qwen": _model_entry("qwen/qwen-image", _STANDARD_RATIOS + ["3:2", "2:3"]),
    "seedream-3": _model_entry("bytedance/seedream-3", _STANDARD_RATIOS + ["3:2", "2:3", "21:9"]),
    "recraft-v3": _model_entry(
        "recraft-ai/recraft-v3", ["1:1", "4:3", "3:4", "3:2", "2:3", "16:9", "9:16"]
    ),
}

# Per-thread storage for lazily created service clients.
_thread_local = threading.local()
| # ---------------------------- | |
| # Preflight/debug helpers | |
| # ---------------------------- | |
def show_env_warnings():
    """Surface missing configuration in the Streamlit UI.

    A missing REPLICATE_API_TOKEN is shown as a warning because no client can be
    created, so generation yields "No URLs". Missing R2 or Mongo settings are
    informational only. Without R2, the Replicate source URLs are used instead
    of R2 copies. Without Mongo, results are not saved to the Creative Library.
    """
    if not REPLICATE_API_TOKEN:
        st.warning(
            "Missing **REPLICATE_API_TOKEN** — generation will return “No URLs”.",
            icon="⚠️",
        )
    # One combined notice instead of one st.info per missing variable.
    missing_r2 = [
        name
        for name in ("R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME", "NEW_BASE")
        if not os.getenv(name)
    ]
    if missing_r2:
        st.info(
            f"Optional: {', '.join(missing_r2)} not set — images won’t be copied to R2 "
            "(source URLs will be used).",
            icon="ℹ️",
        )
    if not MONGO_URI:
        st.info(
            "Optional: MONGO_URI not set — results won’t be saved to the Creative Library.",
            icon="ℹ️",
        )
| # ---------------------------- | |
| # Clients (Replicate / Mongo / S3) | |
| # ---------------------------- | |
def get_replicate_client():
    """Return this thread's Replicate client, creating it on first use.

    When REPLICATE_API_TOKEN is not configured, None is cached for the thread
    and returned.
    """
    try:
        return _thread_local.replicate_client
    except AttributeError:
        client = None
        if REPLICATE_API_TOKEN:
            client = replicate.Client(api_token=REPLICATE_API_TOKEN)
        _thread_local.replicate_client = client
        return client
# One process-wide Mongo handle. MongoClient is thread-safe and manages its own
# connection pool, so it is shared rather than opened (and never closed) per thread.
_MONGO_LOCK = threading.Lock()
_MONGO_STATE: Dict[str, Any] = {"collection": None, "failed_at": None}
_MONGO_RETRY_AFTER_S = 60.0  # back-off after a failed connection attempt


def get_mongo_collection():
    """Return the shared Creative Library collection, or None if unavailable.

    The first successful call connects with a 3 s server-selection timeout,
    pings the server, and caches the collection for all threads. After a failed
    attempt, the client is closed. Further attempts are then skipped for
    ``_MONGO_RETRY_AFTER_S`` seconds, so repeated callers do not each wait out
    the timeout.

    Returns:
        The pymongo collection, or None when MONGO_URI is unset or the server
        could not be reached.
    """
    if not MONGO_URI:
        return None
    with _MONGO_LOCK:
        cached = _MONGO_STATE["collection"]
        # Compare with None explicitly: pymongo Collection objects forbid bool().
        if cached is not None:
            return cached
        failed_at = _MONGO_STATE["failed_at"]
        if failed_at is not None and time.monotonic() - failed_at < _MONGO_RETRY_AFTER_S:
            return None
        client = None
        try:
            client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
            client.admin.command("ping")
        except Exception as e:
            logger.error(f"MongoDB connection failed: {e}")
            if client is not None:
                client.close()  # release the background monitor / sockets
            _MONGO_STATE["failed_at"] = time.monotonic()
            return None
        coll = client[MONGO_DB][MONGO_COLLECTION]
        _MONGO_STATE["collection"] = coll
        _MONGO_STATE["failed_at"] = None
        return coll
def get_s3_client():
    """Return this thread's boto3 S3 client for the R2 bucket, creating it lazily.

    Requires R2_ENDPOINT, R2_ACCESS_KEY, R2_SECRET_KEY, R2_BUCKET_NAME and
    NEW_BASE. If any of them is missing, or client construction raises, None is
    cached for the thread and R2 uploads are skipped.
    """
    cache = _thread_local
    if hasattr(cache, "s3_client"):
        return cache.s3_client

    env_names = ("R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME", "NEW_BASE")
    if not all(os.getenv(name) for name in env_names):
        cache.s3_client = None
        return None

    try:
        client = boto3.client(
            "s3",
            endpoint_url=os.getenv("R2_ENDPOINT"),
            aws_access_key_id=os.getenv("R2_ACCESS_KEY"),
            aws_secret_access_key=os.getenv("R2_SECRET_KEY"),
            region_name="auto",
        )
    except Exception as e:
        logger.error(f"S3 client init failed: {e}")
        client = None
    cache.s3_client = client
    return client
| # ---------------------------- | |
| # Core ops: R2 / Generate / Fetch | |
| # ---------------------------- | |
# Leading magic bytes -> (file extension, MIME type) for common model outputs.
_IMAGE_SIGNATURES: Tuple[Tuple[bytes, str, str], ...] = (
    (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
    (b"\xff\xd8\xff", "jpg", "image/jpeg"),
    (b"GIF87a", "gif", "image/gif"),
    (b"GIF89a", "gif", "image/gif"),
)


def sniff_image_type(image_bytes: bytes) -> Tuple[str, str]:
    """Return ``(extension, content_type)`` for *image_bytes*.

    PNG, JPEG, GIF and WebP are recognised by their signatures. Anything else
    falls back to ``("png", "image/png")``, the value previously hard-coded.
    """
    for magic, ext, mime in _IMAGE_SIGNATURES:
        if image_bytes.startswith(magic):
            return ext, mime
    # WebP is a RIFF container: b"RIFF" + 4-byte size + b"WEBP".
    if image_bytes[:4] == b"RIFF" and image_bytes[8:12] == b"WEBP":
        return "webp", "image/webp"
    return "png", "image/png"


def upload_to_r2(image_bytes: bytes) -> Optional[str]:
    """Store *image_bytes* in the R2 bucket and return its public URL.

    The object key is ``adgenesis_image_file/balraaj/images/<uuid>.<ext>``.
    The extension and Content-Type come from the actual image signature, so
    JPEG/WebP outputs are not mislabelled as PNG.

    Returns:
        ``NEW_BASE`` joined with the object key. Returns None when R2 is not
        configured or the upload fails; the caller then uses the source URL.
    """
    s3 = get_s3_client()
    if s3 is None:
        return None
    ext, content_type = sniff_image_type(image_bytes)
    file_key = f"adgenesis_image_file/balraaj/images/{uuid4().hex}.{ext}"
    try:
        s3.put_object(
            Bucket=os.getenv("R2_BUCKET_NAME"),
            Key=file_key,
            Body=image_bytes,
            ContentType=content_type,
        )
        return f"{os.getenv('NEW_BASE').rstrip('/')}/{file_key}"
    except Exception as e:
        logger.error(f"S3 upload failed: {e}")
        return None
def generate_one(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
    """Run one Replicate prediction and return its first image URL.

    Returns a one-element list holding the URL. Returns an empty list when:
    there is no Replicate client, *model_key* is not in MODEL_REGISTRY, the
    prediction raises, or the output has an unrecognised shape.
    """
    client = get_replicate_client()
    if client is None:
        return []
    spec = MODEL_REGISTRY.get(model_key)
    if not spec:
        return []

    def _first_url(result) -> List[str]:
        # Output may be a list of items, a plain string, or an object exposing .url.
        if isinstance(result, list):
            return [str(result[0])] if result else []
        if isinstance(result, str):
            return [result]
        if hasattr(result, "url"):
            return [getattr(result, "url")]
        return []

    try:
        payload = {"prompt": prompt, spec["param_name"]: aspect_ratio}
        return _first_url(client.run(spec["id"], input=payload))
    except Exception as e:
        logger.error(f"Replicate error: {e}")
        return []
def fetch_bytes(url: str) -> Optional[bytes]:
    """Download *url* and return the response body, or None on failure.

    Makes up to RETRY_ATTEMPTS attempts, each with a REQUEST_TIMEOUT-second
    timeout, sleeping 1 s between attempts. 4xx responses other than 408/429
    are not retried, since repeating the request cannot succeed. The response
    is always closed, which returns its pooled connection even on errors.
    Failures are logged but never raised (callers treat None as "no image").
    """
    # Same logger the rest of this module writes to.
    log = logging.getLogger("imagegen_app")
    last_error: Optional[BaseException] = None
    for attempt in range(RETRY_ATTEMPTS):
        try:
            with requests.get(url, timeout=REQUEST_TIMEOUT) as resp:
                resp.raise_for_status()
                return resp.content
        except requests.HTTPError as e:
            status = e.response.status_code if e.response is not None else None
            if status is not None and 400 <= status < 500 and status not in (408, 429):
                log.warning(f"Fetch failed for {url}: HTTP {status} (not retried)")
                return None
            last_error = e
        except Exception as e:  # best-effort by design; see docstring
            last_error = e
        if attempt < RETRY_ATTEMPTS - 1:
            time.sleep(1)
    log.warning(f"Fetch failed for {url} after {RETRY_ATTEMPTS} attempts: {last_error}")
    return None
def _generate_single_image_full(prompt: str, model: str, aspect: str) -> Tuple[Optional[str], Optional[str]]:
    """Produce one image end to end while holding a global concurrency slot.

    Pipeline: generate on Replicate -> download the result -> copy it to R2.

    Returns:
        On success, ``(url, None)``. *url* is the R2 copy if the upload worked,
        otherwise the Replicate source URL. On failure, ``(None, reason)`` with
        reason ``"No URLs"`` or ``"Fetch failed"``.
    """
    _GEN_SEMAPHORE.acquire()  # throttle: the slot is held for the whole pipeline
    try:
        generated = generate_one(model, prompt, aspect)
        if not generated:
            return None, "No URLs"
        source_url = generated[0]
        payload = fetch_bytes(source_url)
        if payload is None:
            return None, "Fetch failed"
        return upload_to_r2(payload) or source_url, None
    finally:
        _GEN_SEMAPHORE.release()
def process_prompt_batch(idx: int, prompt: str, model: str, aspect: str, num_images: int) -> Dict[str, Any]:
    """Generate *num_images* images (at least one) for a single prompt.

    A single image runs inline on the calling thread. More images fan out over
    a small pool of at most 4 threads, and never more than GLOBAL_CONCURRENCY.
    Each image still acquires the global throttle inside
    _generate_single_image_full.

    Returns:
        ``{"idx": idx, "urls": [...], "errors": [...]}``. Both lists are in
        completion order.
    """
    count = max(1, int(num_images))
    if count == 1:
        outcomes = [_generate_single_image_full(prompt, model, aspect)]
    else:
        pool_size = min(4, count, GLOBAL_CONCURRENCY)
        with ThreadPoolExecutor(max_workers=pool_size) as pool:
            pending = [
                pool.submit(_generate_single_image_full, prompt, model, aspect)
                for _ in range(count)
            ]
            outcomes = [done.result() for done in as_completed(pending)]
    return {
        "idx": idx,
        "urls": [url for url, _ in outcomes if url],
        "errors": [err for _, err in outcomes if err],
    }
| # ---------------------------- | |
| # Persistence | |
| # ---------------------------- | |
def save_record(model: str, aspect: str, prompt: str, urls: List[str]) -> Optional[str]:
    """Persist one prompt's generated images to the Creative Library.

    Returns:
        The inserted document id as a string. Returns None when Mongo is
        unavailable or the insert fails; the failure is logged.
    """
    from datetime import timezone  # not among the module-level datetime imports

    coll = get_mongo_collection()
    if coll is None:
        return None
    doc = {
        "model": model,
        "aspect_ratio": aspect,
        "prompt": prompt,
        "urls": urls,
        "num_images": len(urls),
        "lob": "balraaj",
        # Timezone-aware UTC instead of the deprecated datetime.utcnow(). PyMongo
        # stores aware and naive-UTC datetimes as the same BSON UTC instant.
        "created_at": datetime.now(timezone.utc),
    }
    try:
        return str(coll.insert_one(doc).inserted_id)
    except Exception as e:
        logger.error(f"Mongo insert failed: {e}")
        return None
def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
    """Fetch Creative Library documents whose ``created_at`` is in ``[start, end)``.

    Only documents with ``lob == "balraaj"`` are returned, newest first, capped
    at LIBRARY_PAGE_SIZE. Returns [] when Mongo is unavailable or the query fails.
    """
    coll = get_mongo_collection()
    if coll is None:
        return []
    criteria = {"created_at": {"$gte": start, "$lt": end}, "lob": "balraaj"}
    try:
        cursor = coll.find(criteria, sort=[("created_at", -1)], limit=LIBRARY_PAGE_SIZE)
        return list(cursor)
    except Exception as e:
        logger.error(f"Mongo query failed: {e}")
        return []
| # ---------------------------- | |
| # Gallery helpers | |
| # ---------------------------- | |
def display_gallery(urls: List[str], num_cols: int = 4):
    """Show *urls* as an image grid.

    Args:
        urls: Image URLs, placed left-to-right, top-to-bottom in this order.
        num_cols: Number of grid columns (default 4, clamped to at least 1).

    Images are downloaded server-side with fetch_bytes, several at a time, and
    then rendered on the calling thread. A slot whose download or rendering
    fails shows "Failed to load image" instead of staying blank.
    """
    if not urls:
        return
    num_cols = max(1, int(num_cols))
    cols = st.columns(num_cols)
    # Downloads are I/O bound, so overlap them. map() keeps results in input order.
    with ThreadPoolExecutor(max_workers=min(8, len(urls))) as pool:
        images = list(pool.map(fetch_bytes, urls))
    for i, img in enumerate(images):
        with cols[i % num_cols]:
            # fetch_bytes reports failure by returning None (it does not raise).
            if not img:
                st.error("Failed to load image")
                continue
            try:
                st.image(img, use_container_width=True)
            except Exception:
                st.error("Failed to load image")
def bulk_zip(urls: List[str]):
    """Offer a "Download All" ZIP archive of the images at *urls*.

    Entries are named ``image_<n>.png`` by 1-based position in *urls*. Downloads
    run concurrently, with order preserved. URLs that fail to download are left
    out (with a note), and no button is shown if nothing could be downloaded.
    """
    if not urls:
        return
    with ThreadPoolExecutor(max_workers=min(8, len(urls))) as pool:
        payloads = list(pool.map(fetch_bytes, urls))
    buf = io.BytesIO()
    written = 0
    # ZIP_STORED: image formats are already compressed, so deflating them again
    # burns CPU for almost no size reduction.
    with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_STORED) as z:
        for i, data in enumerate(payloads, 1):
            if data:
                z.writestr(f"image_{i}.png", data)
                written += 1
    if written == 0:
        st.warning("None of the images could be downloaded, so no ZIP is available.")
        return
    if written < len(urls):
        st.caption(f"{len(urls) - written} image(s) could not be downloaded and are missing from the ZIP.")
    buf.seek(0)
    st.download_button("Download All", buf, "images.zip", "application/zip", use_container_width=True)
| # ---------------------------- | |
| # JSON loader & batch run (parallel across prompts) | |
| # ---------------------------- | |
def load_json(file) -> List[str]:
    """Parse an uploaded JSON file into a list of stripped, non-empty prompts.

    Accepts ``{"prompts": ["...", ...]}`` or, for convenience, a bare
    ``["...", ...]`` list. Non-string and blank entries are dropped. A leading
    UTF-8 byte-order mark, as written by some Windows editors, is tolerated.

    Args:
        file: Object with ``getvalue() -> bytes`` (e.g. a Streamlit upload or
            ``io.BytesIO``).

    Raises:
        ValueError: on undecodable bytes or invalid JSON (both are ValueError
            subclasses), on the wrong shape, or when no usable prompt remains.
    """
    # "utf-8-sig" strips an optional BOM. With plain "utf-8" the BOM survives
    # and json.loads rejects the document.
    data = json.loads(file.getvalue().decode("utf-8-sig"))
    prompts = data.get("prompts") if isinstance(data, dict) else data
    if not isinstance(prompts, list):
        raise ValueError("JSON must be { 'prompts': [ '...','...' ] } (strings only).")
    out = [p.strip() for p in prompts if isinstance(p, str) and p.strip()]
    if not out:
        raise ValueError("No valid prompts found.")
    return out
def summarize_errors(errors: List[str]) -> str:
    """Join error messages, collapsing repeats into a count.

    >>> summarize_errors(["No URLs", "No URLs", "Fetch failed"])
    'No URLs ×2, Fetch failed'
    """
    counts: Dict[str, int] = {}
    for msg in errors:
        counts[msg] = counts.get(msg, 0) + 1
    return ", ".join(msg if n == 1 else f"{msg} ×{n}" for msg, n in counts.items())


def run_batch(prompts: List[str], model: str, aspect: str, num_images: int):
    """Generate images for every prompt in parallel, showing live progress.

    Each prompt gets a status row:
    - success: every image worked;
    - warning: some images failed;
    - error: no image was produced.
    Prompts with at least one image are saved via save_record. Afterwards all
    images are shown in a gallery with a ZIP download.
    """
    total = len(prompts)
    rows = [st.empty() for _ in prompts]
    progress = st.progress(0.0, text=f"0/{total}")
    all_urls: List[str] = []
    # One outer task per prompt; each may open a small inner pool. The global
    # semaphore inside the per-image pipeline caps in-flight generations.
    outer_workers = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
    with st.spinner("Generating images..."):
        with ThreadPoolExecutor(max_workers=outer_workers) as ex:
            futs = {
                ex.submit(process_prompt_batch, i, p, model, aspect, num_images): i
                for i, p in enumerate(prompts, 1)
            }
            # Streamlit elements are only updated here, on the calling thread.
            for done, fut in enumerate(as_completed(futs), 1):
                i = futs[fut]
                try:
                    res = fut.result()
                except Exception as e:
                    res = {"idx": i, "urls": [], "errors": [str(e)]}
                urls = res["urls"]
                errors = res.get("errors") or []
                row = rows[i - 1]
                if urls:
                    save_record(model, aspect, prompts[i - 1], urls)
                    all_urls.extend(urls)
                    if errors:
                        row.warning(
                            f"Prompt {i}/{total} ⚠️ ({len(urls)} images, "
                            f"{len(errors)} failed: {summarize_errors(errors)})"
                        )
                    else:
                        row.success(f"Prompt {i}/{total} ✅ ({len(urls)} images)")
                else:
                    row.error(f"Prompt {i}/{total} ❌ ({summarize_errors(errors) or 'No images'})")
                progress.progress(done / total, text=f"{done}/{total}")
    if all_urls:
        st.subheader("Gallery")
        display_gallery(all_urls)
        bulk_zip(all_urls)
    else:
        st.info("No images to display.")
| # ---------------------------- | |
| # Pages | |
| # ---------------------------- | |
def render_json_page():
    """Page: upload a prompts JSON file, pick model/aspect/count, and generate."""
    st.subheader("Generate from JSON")
    show_env_warnings()
    uploaded = st.file_uploader("Upload JSON", type=["json"])
    model_col, aspect_col, count_col = st.columns([1, 1, 1])
    with model_col:
        model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
    with aspect_col:
        # Aspect choices depend on the model selected above.
        aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
    with count_col:
        num_images = st.slider("Images per prompt", min_value=1, max_value=50, value=1, step=1)

    if not uploaded:
        st.caption('Expected: { "prompts": ["prompt 1", "prompt 2", ...] }')
        return
    try:
        prompts = load_json(uploaded)
        with st.expander("Preview prompts", expanded=False):
            st.json(prompts)
        if st.button("Generate", type="primary", use_container_width=True):
            run_batch(prompts, model, aspect, num_images)
    except Exception as e:
        st.error(str(e))
def render_library_page():
    """Page: browse images saved to the Creative Library within a date range.

    Both dates are inclusive and treated as UTC calendar days. This matches the
    UTC ``created_at`` timestamps written by save_record.
    """
    from datetime import timezone  # not among the module-level datetime imports

    st.subheader("Creative Library")
    show_env_warnings()
    # Explicit UTC "today" instead of the deprecated datetime.utcnow().
    today = datetime.now(timezone.utc).date()
    start = st.date_input("Start date", today - timedelta(days=30))
    end = st.date_input("End date", today)
    if start > end:
        # An inverted range matches nothing. Say why, instead of reporting
        # "No records found".
        st.error("Start date must be on or before the end date.")
        return
    # Naive midnight boundaries; the upper bound is exclusive, so add one day.
    start_dt = datetime.combine(start, datetime.min.time())
    end_dt = datetime.combine(end + timedelta(days=1), datetime.min.time())
    records = query_records(start_dt, end_dt)
    all_urls: List[str] = [u for r in records for u in (r.get("urls") or [])]
    if not all_urls:
        st.info("No records found in the selected range.")
        return
    st.caption(f"Showing {len(all_urls)} images from {len(records)} records")
    display_gallery(all_urls)
    bulk_zip(all_urls)
| # ---------------------------- | |
| # Auth | |
| # ---------------------------- | |
def check_token(tok: str) -> bool:
    """Return True if *tok* matches the ACCESS_TOKEN environment variable.

    When ACCESS_TOKEN is unset or empty, every token is accepted (development
    access). The comparison uses ``hmac.compare_digest``, so its running time
    does not reveal how many leading characters of the secret matched.
    """
    import hmac

    acc = os.getenv("ACCESS_TOKEN")
    if not acc:
        # No token configured — allow dev access.
        return True
    # Compare UTF-8 bytes: compare_digest rejects str arguments containing non-ASCII.
    return hmac.compare_digest((tok or "").encode("utf-8"), acc.encode("utf-8"))
def main_app():
    """Render the authenticated app: page setup, title, and sidebar page router."""
    # The title separator is a bullet; the garbled "β’" was mis-decoded UTF-8.
    st.set_page_config(page_title="File-to-Image • Creative Library", layout="wide")
    st.title("File-to-Image Generator")
    pages = {
        "Generate from JSON": render_json_page,
        "Creative Library": render_library_page,
    }
    page = st.sidebar.radio("Menu", list(pages))
    pages[page]()
def main():
    """Gate the app behind an access token kept in session state.

    Until ``st.session_state["auth"]`` is truthy, only the unlock form is shown.
    A valid token sets the flag and reruns the script, which then enters
    main_app().
    """
    if st.session_state.get("auth"):
        main_app()
        return
    st.markdown("## Access Required")
    token = st.text_input("Enter Access Token", type="password")
    if not st.button("Unlock App"):
        return
    if not check_token(token):
        st.error("Invalid token.")
        return
    st.session_state["auth"] = True
    st.rerun()


if __name__ == "__main__":
    main()