# File-To-Images / src/streamlit_app.py
# Last change: "Update src/streamlit_app.py" by userIdc2024 (commit 7b7c637, verified)
import os
import io
import zipfile
import time
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, date
from typing import Dict, Any, List, Tuple, Optional
import json
import threading
from functools import lru_cache
from urllib.parse import urlparse
import requests
import streamlit as st
from pymongo import MongoClient
import boto3
import replicate # type: ignore
from uuid import uuid4
from dotenv import load_dotenv
load_dotenv()  # pull environment variables from a local .env file, if present
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("imagegen_app")
# ----------------------------
# Config / Constants
# ----------------------------
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")  # required for generation
MONGO_URI = os.getenv("MONGO_URI")  # optional: enables the Creative Library
MONGO_DB = os.getenv("MONGO_DB", "adgenesis_image_text")
MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
REQUEST_TIMEOUT = 30  # seconds per HTTP request (image downloads)
RETRY_ATTEMPTS = 3  # download attempts in fetch_bytes before giving up
MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)  # upper bound for thread pools
LIBRARY_PAGE_SIZE = 200 # aggregated view
# Global throttle so we don't overload Replicate / network
# This caps total in-flight image generations across all prompts.
GLOBAL_CONCURRENCY = max(4, min(16, MAX_WORKERS))
_GEN_SEMAPHORE = threading.Semaphore(GLOBAL_CONCURRENCY)
# Model key -> Replicate model id, the aspect ratios it accepts, and the name
# of the input parameter that carries the aspect ratio for that model.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
"imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
"imagen-4": {"id": "google/imagen-4","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
"nano-banana": {"id": "google/nano-banana","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
"qwen": {"id": "qwen/qwen-image","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"],"param_name": "aspect_ratio"},
"seedream-3": {"id": "bytedance/seedream-3","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"],"param_name": "aspect_ratio"},
"recraft-v3": {"id": "recraft-ai/recraft-v3","aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"],"param_name": "aspect_ratio"},
}
# Per-thread cache for lazily-built API clients (Replicate / Mongo / S3).
_thread_local = threading.local()
# ----------------------------
# Preflight/debug helpers
# ----------------------------
def show_env_warnings():
    """Render Streamlit hints for any missing environment configuration."""
    # The Replicate token is the only hard requirement for generation.
    if not REPLICATE_API_TOKEN:
        st.warning("Missing **REPLICATE_API_TOKEN** β€” generation will return β€˜No URLs’.", icon="⚠️")
    # R2 settings are optional; without them the Replicate source URLs are used.
    r2_vars = ("R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME", "NEW_BASE")
    for name in (v for v in r2_vars if not os.getenv(v)):
        st.info(f"Optional: {name} not set β†’ images won’t be copied to R2 (source URLs will be used).", icon="ℹ️")
    # Mongo is optional; without it nothing is persisted to the library.
    if not MONGO_URI:
        st.info("Optional: MONGO_URI not set β†’ results won’t be saved to the Creative Library.", icon="ℹ️")
# ----------------------------
# Clients (Replicate / Mongo / S3)
# ----------------------------
def get_replicate_client():
    """Return a lazily-built, per-thread Replicate client (None without a token)."""
    try:
        return _thread_local.replicate_client
    except AttributeError:
        pass
    client = replicate.Client(api_token=REPLICATE_API_TOKEN) if REPLICATE_API_TOKEN else None
    _thread_local.replicate_client = client
    return client
def get_mongo_collection():
    """Return a cached per-thread Mongo collection handle.

    Yields None (and caches the None) when MONGO_URI is unset or the
    server does not answer a ping within 3 seconds.
    """
    if hasattr(_thread_local, "mongo_collection"):
        return _thread_local.mongo_collection
    handle = None
    if MONGO_URI:
        try:
            client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
            candidate = client[MONGO_DB][MONGO_COLLECTION]
            # Fail fast if the server is unreachable before caching the handle.
            client.admin.command("ping")
            handle = candidate
        except Exception as e:
            logger.error(f"MongoDB connection failed: {e}")
    _thread_local.mongo_collection = handle
    return handle
def get_s3_client():
    """Return a cached per-thread boto3 S3 client configured for R2.

    Returns None (cached) when any required R2 env var is missing or
    client construction raises.
    """
    if hasattr(_thread_local, "s3_client"):
        return _thread_local.s3_client
    needed = ("R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME", "NEW_BASE")
    if not all(os.getenv(name) for name in needed):
        _thread_local.s3_client = None
        return None
    try:
        client = boto3.client(
            "s3",
            endpoint_url=os.getenv("R2_ENDPOINT"),
            aws_access_key_id=os.getenv("R2_ACCESS_KEY"),
            aws_secret_access_key=os.getenv("R2_SECRET_KEY"),
            region_name="auto",
        )
    except Exception as e:
        logger.error(f"S3 client init failed: {e}")
        client = None
    _thread_local.s3_client = client
    return client
# ----------------------------
# Core ops: R2 / Generate / Fetch
# ----------------------------
def upload_to_r2(image_bytes: bytes) -> Optional[str]:
    """Upload PNG bytes to R2 and return the public URL, or None on failure.

    Returns None when R2 is not configured (get_s3_client returns None) or
    the upload raises; callers fall back to the Replicate source URL then.
    """
    s3 = get_s3_client()
    if s3 is None:
        return None
    try:
        # Unique key per upload. The previous code computed this filename but
        # never used it — the key ended in a fixed literal, so every upload
        # overwrote the same R2 object.
        filename = f"{uuid4().hex}.png"
        file_key = f"adgenesis_image_file/balraaj/images/{filename}"
        s3.put_object(
            Bucket=os.getenv("R2_BUCKET_NAME"),
            Key=file_key,
            Body=image_bytes,
            ContentType="image/png",
        )
        # NEW_BASE is guaranteed non-empty here: get_s3_client returns None
        # unless all R2 env vars (including NEW_BASE) are set.
        return f"{os.getenv('NEW_BASE').rstrip('/')}/{file_key}"
    except Exception as e:
        logger.error(f"S3 upload failed: {e}")
        return None
def generate_one(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
    """Run one Replicate generation and normalize the output.

    Returns a list holding at most one image URL; [] when the client or
    model is unavailable, the call fails, or no URL can be extracted.
    """
    client = get_replicate_client()
    model = MODEL_REGISTRY.get(model_key)
    if client is None or not model:
        return []
    try:
        result = client.run(model["id"], input={"prompt": prompt, model["param_name"]: aspect_ratio})
        # Replicate models return a list, a bare string, or a file-like
        # object exposing .url — normalize all three shapes.
        if isinstance(result, list):
            return [str(result[0])] if result else []
        if isinstance(result, str):
            return [result]
        if hasattr(result, "url"):
            return [result.url]
        return []
    except Exception as e:
        logger.error(f"Replicate error: {e}")
        return []
def fetch_bytes(url: str) -> Optional[bytes]:
    """Download url and return the response body.

    Retries up to RETRY_ATTEMPTS times with a 1-second pause between
    failures; returns None once every attempt has failed.
    """
    remaining = RETRY_ATTEMPTS
    while remaining > 0:
        remaining -= 1
        try:
            resp = requests.get(url, timeout=REQUEST_TIMEOUT, stream=True)
            resp.raise_for_status()
            return resp.content
        except Exception:
            if remaining == 0:
                return None
            time.sleep(1)
    return None
def _generate_single_image_full(prompt: str, model: str, aspect: str) -> Tuple[Optional[str], Optional[str]]:
    """Run one image end-to-end under the global throttle.

    Pipeline: Replicate generate -> download bytes -> copy to R2.
    Returns (url, None) on success — the R2 URL when the copy worked,
    otherwise the Replicate source URL — or (None, error_text) on failure.
    """
    with _GEN_SEMAPHORE:  # cap total in-flight generations process-wide
        candidates = generate_one(model, prompt, aspect)
        if not candidates:
            return None, "No URLs"
        source_url = candidates[0]
        payload = fetch_bytes(source_url)
        if payload is None:
            return None, "Fetch failed"
        hosted = upload_to_r2(payload)
        return (hosted if hosted else source_url), None
def process_prompt_batch(idx: int, prompt: str, model: str, aspect: str, num_images: int) -> Dict[str, Any]:
    """Generate num_images images for a single prompt.

    Fans out over a small inner pool (each task still bounded by the global
    semaphore) and returns {"idx": idx, "urls": [...], "errors": [...]}.
    """
    count = max(1, int(num_images))
    collected: List[str] = []
    failures: List[str] = []

    def record(url: Optional[str], err: Optional[str]) -> None:
        # Accumulate one (url, error) result pair from the pipeline.
        if url:
            collected.append(url)
        if err:
            failures.append(err)

    if count == 1:
        record(*_generate_single_image_full(prompt, model, aspect))
    else:
        # Keep inner parallelism small so one prompt can't hog the throttle.
        pool_size = min(4, count, GLOBAL_CONCURRENCY)
        with ThreadPoolExecutor(max_workers=pool_size) as pool:
            pending = [pool.submit(_generate_single_image_full, prompt, model, aspect) for _ in range(count)]
            for done in as_completed(pending):
                record(*done.result())
    return {"idx": idx, "urls": collected, "errors": failures}
# ----------------------------
# Persistence
# ----------------------------
def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
    """Persist one generation result to the Mongo collection.

    Returns the inserted id as a string, or None when Mongo is not
    configured/reachable or the insert fails.
    """
    coll = get_mongo_collection()
    if coll is None:
        return None
    document = {
        "model": model,
        "aspect_ratio": aspect,
        "prompt": prompt,
        "urls": urls,
        "num_images": len(urls),
        "lob": "balraaj",
        "created_at": datetime.utcnow(),
    }
    try:
        result = coll.insert_one(document)
    except Exception as e:
        logger.error(f"Mongo insert failed: {e}")
        return None
    return str(result.inserted_id)
@st.cache_data(ttl=300)
def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
    """Fetch up to LIBRARY_PAGE_SIZE 'balraaj' records with created_at in
    [start, end), newest first.

    Cached by Streamlit for 5 minutes; returns [] when Mongo is not
    configured or the query fails.
    """
    coll = get_mongo_collection()
    if coll is None:
        return []
    criteria = {"created_at": {"$gte": start, "$lt": end}, "lob": "balraaj"}
    try:
        cursor = coll.find(criteria).sort("created_at", -1).limit(LIBRARY_PAGE_SIZE)
        return list(cursor)
    except Exception as e:
        logger.error(f"Mongo query failed: {e}")
        return []
# ----------------------------
# Gallery helpers
# ----------------------------
def display_gallery(urls: List[str]):
    """Render urls as a 4-column image grid; a failed download shows an error."""
    if not urls:
        return
    columns = st.columns(4)
    for position, url in enumerate(urls):
        with columns[position % 4]:
            try:
                payload = fetch_bytes(url)
                if payload:
                    st.image(payload, use_container_width=True)
            except Exception:
                st.error("Failed to load image")
def bulk_zip(urls: List[str]):
    """Offer a single ZIP download containing every fetchable image in urls.

    Images that fail to download are silently skipped.
    """
    if not urls:
        return
    archive = io.BytesIO()
    with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for number, url in enumerate(urls, start=1):
            payload = fetch_bytes(url)
            if payload:
                zf.writestr(f"image_{number}.png", payload)
    archive.seek(0)
    st.download_button("Download All", archive, "images.zip", "application/zip", use_container_width=True)
# ----------------------------
# JSON loader & batch run (parallel across prompts)
# ----------------------------
def load_json(file) -> List[str]:
    """Parse an uploaded JSON file of shape {"prompts": [str, ...]}.

    Returns the stripped, non-empty prompt strings (non-string entries are
    skipped). Raises ValueError when the top-level shape is wrong or no
    usable prompt remains after filtering.
    """
    payload = json.loads(file.getvalue().decode("utf-8"))
    prompts = payload.get("prompts") if isinstance(payload, dict) else None
    if not isinstance(prompts, list):
        raise ValueError("JSON must be { 'prompts': [ '...','...' ] } (strings only).")
    cleaned = [item.strip() for item in prompts if isinstance(item, str) and item.strip()]
    if not cleaned:
        raise ValueError("No valid prompts found.")
    return cleaned
def run_batch(prompts: List[str], model: str, aspect: str, num_images: int):
    """Drive a whole batch: fan out across prompts, stream per-prompt status
    rows, persist successful results, then render the combined gallery."""
    total = len(prompts)
    status_rows = [st.empty() for _ in prompts]
    progress = st.progress(0.0, text=f"0/{total}")
    gallery_urls: List[str] = []
    completed = 0
    # Parallelism across prompts; per-image pressure is capped globally.
    worker_count = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
    with st.spinner("Generating images..."):
        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            pending = {
                pool.submit(process_prompt_batch, number, text, model, aspect, num_images): number
                for number, text in enumerate(prompts, 1)
            }
            for future in as_completed(pending):
                number = pending[future]
                try:
                    result = future.result()
                except Exception as e:
                    result = {"idx": number, "urls": [], "errors": [str(e)]}
                row = status_rows[number - 1]
                if result["urls"]:
                    save_record(model, aspect, prompts[number - 1], result["urls"])
                    row.success(f"Prompt {number}/{total} βœ“ ({len(result['urls'])} images)")
                    gallery_urls.extend(result["urls"])
                else:
                    err_msg = ", ".join(result.get("errors") or ["No images"])
                    row.error(f"Prompt {number}/{total} βœ— ({err_msg})")
                completed += 1
                progress.progress(completed / total, text=f"{completed}/{total}")
    if gallery_urls:
        st.subheader("Gallery")
        display_gallery(gallery_urls)
        bulk_zip(gallery_urls)
    else:
        st.info("No images to display.")
# ----------------------------
# Pages
# ----------------------------
def render_json_page():
    """Page: upload a prompts JSON, choose model/aspect/count, run the batch."""
    st.subheader("Generate from JSON")
    show_env_warnings()
    uploaded = st.file_uploader("Upload JSON", type=["json"])
    col_model, col_aspect, col_count = st.columns([1, 1, 1])
    with col_model:
        model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
    with col_aspect:
        # Aspect choices depend on the selected model.
        aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
    with col_count:
        num_images = st.slider("Images per prompt", min_value=1, max_value=50, value=1, step=1)
    if not uploaded:
        st.caption('Expected: { "prompts": ["prompt 1", "prompt 2", ...] }')
        return
    # Any failure (bad JSON or a batch error) is surfaced inline.
    try:
        prompts = load_json(uploaded)
        with st.expander("Preview prompts", expanded=False):
            st.json(prompts)
        if st.button("Generate", type="primary", use_container_width=True):
            run_batch(prompts, model, aspect, num_images)
    except Exception as e:
        st.error(str(e))
def render_library_page():
    """Page: browse previously generated images by creation-date range."""
    st.subheader("Creative Library")
    show_env_warnings()
    today = datetime.utcnow().date()
    start = st.date_input("Start date", today - timedelta(days=30))
    end = st.date_input("End date", today)
    # Turn the inclusive date pair into a half-open datetime window.
    window_start = datetime.combine(start, datetime.min.time())
    window_end = datetime.combine(end + timedelta(days=1), datetime.min.time())
    records = query_records(window_start, window_end)
    urls: List[str] = [u for rec in records for u in (rec.get("urls", []) or [])]
    if not urls:
        st.info("No records found in the selected range.")
        return
    st.caption(f"Showing {len(urls)} images from {len(records)} records")
    display_gallery(urls)
    bulk_zip(urls)
# ----------------------------
# Auth
# ----------------------------
def check_token(tok: str) -> bool:
    """Check tok against the ACCESS_TOKEN environment variable.

    Returns True when no ACCESS_TOKEN is configured (open dev access).
    The previous @lru_cache(maxsize=1) froze the env lookup at first call
    (stale if ACCESS_TOKEN changes at runtime) and only ever remembered
    the last token tried — caching buys nothing here, so it was removed.
    The comparison is constant-time to avoid leaking the token via timing.
    """
    import hmac  # stdlib; local import keeps the file header untouched

    expected = os.getenv("ACCESS_TOKEN")
    if not expected:
        # No token configured → allow dev access
        return True
    return hmac.compare_digest(tok.encode(), expected.encode())
def main_app():
    """Authenticated top-level UI: page chrome plus sidebar navigation."""
    st.set_page_config(page_title="File-to-Image β€’ Creative Library", layout="wide")
    st.title("File-to-Image Generator")
    pages = {
        "Generate from JSON": render_json_page,
        "Creative Library": render_library_page,
    }
    choice = st.sidebar.radio("Menu", list(pages.keys()))
    pages[choice]()
def main():
    """Entry point: show the token gate until authenticated, then the app."""
    if st.session_state.get("auth"):
        main_app()
        return
    st.markdown("## Access Required")
    token = st.text_input("Enter Access Token", type="password")
    if st.button("Unlock App"):
        if not check_token(token):
            st.error("Invalid token.")
        else:
            st.session_state["auth"] = True
            st.rerun()
# Script entry point (Streamlit re-executes the file top-to-bottom on each rerun).
if __name__ == "__main__":
    main()