userIdc2024 commited on
Commit
b449ae1
·
verified ·
1 Parent(s): 16b22f0

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +585 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,587 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import zipfile
4
+ import replicate
5
+ import time
6
+ import logging
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from datetime import datetime, timedelta, date
9
+ from typing import Dict, Any, List, Tuple, Optional, Union
10
+ import requests
11
  import streamlit as st
12
+ from pymongo import MongoClient
13
+ import boto3
14
+ from uuid import uuid4
15
+ from dotenv import load_dotenv
16
+ from urllib.parse import urlparse
17
+ import threading
18
+ from functools import lru_cache
19
+ import json
20
 
21
+ load_dotenv()
22
+
23
+ logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger("imagegen_app")
25
+
26
+ REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
27
+ MONGO_URI = os.getenv("MONGO_URI")
28
+ MONGO_DB = os.getenv("MONGO_DB", "adgenesis_image_text")
29
+ MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
30
+ MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
31
+ REQUEST_TIMEOUT = 30
32
+ RETRY_ATTEMPTS = 3
33
+ LIBRARY_PAGE_SIZE = 20
34
+
35
+ MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
36
+ "imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"],"param_name": "aspect_ratio"},
37
+ "imagen-4": {"id": "google/imagen-4","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"],"param_name": "aspect_ratio"},
38
+ "nano-banana": {"id": "google/nano-banana","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"],"param_name": "aspect_ratio"},
39
+ "qwen": {"id": "qwen/qwen-image","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3"],"param_name": "aspect_ratio"},
40
+ "seedream-3": {"id": "bytedance/seedream-3","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3", "21:9"],"param_name": "aspect_ratio"},
41
+ "recraft-v3": {"id": "recraft-ai/recraft-v3","aspect_ratios": ["1:1", "4:3", "3:4", "3:2", "2:3", "16:9", "9:16", "1:2", "2:1", "7:5", "5:7", "4:5", "5:4", "3:5", "5:3"],"param_name": "aspect_ratio"},
42
+ "photon": {"id": "luma/photon","aspect_ratios": ["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],"param_name": "aspect_ratio"},
43
+ "ideogram-v3-quality": {"id": "ideogram-ai/ideogram-v3-quality","aspect_ratios": ["1:3", "3:1", "1:2", "2:1", "9:16", "16:9", "10:16", "16:10", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "1:1"],"param_name": "aspect_ratio"},
44
+ }
45
+
46
+ _thread_local = threading.local()
47
+
48
+ def get_mongo_collection():
49
+ if not hasattr(_thread_local, 'mongo_collection'):
50
+ if not MONGO_URI:
51
+ _thread_local.mongo_collection = None
52
+ return None
53
+ try:
54
+ client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
55
+ db = client[MONGO_DB]
56
+ collection = db[MONGO_COLLECTION]
57
+ client.admin.command('ping')
58
+ _thread_local.mongo_collection = collection
59
+ except Exception as e:
60
+ logger.error(f"MongoDB connection failed: {e}")
61
+ _thread_local.mongo_collection = None
62
+ return _thread_local.mongo_collection
63
+
64
+ def get_s3_client():
65
+ if not hasattr(_thread_local, 's3_client'):
66
+ required_vars = ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME"]
67
+ missing = [var for var in required_vars if not os.getenv(var)]
68
+ if missing:
69
+ _thread_local.s3_client = None
70
+ return None
71
+ try:
72
+ _thread_local.s3_client = boto3.client(
73
+ "s3",
74
+ endpoint_url=os.getenv("R2_ENDPOINT"),
75
+ aws_access_key_id=os.getenv("R2_ACCESS_KEY"),
76
+ aws_secret_access_key=os.getenv("R2_SECRET_KEY"),
77
+ region_name="auto",
78
+ )
79
+ except Exception as e:
80
+ logger.error(f"S3 client initialization failed: {e}")
81
+ _thread_local.s3_client = None
82
+ return _thread_local.s3_client
83
+
84
+ @lru_cache(maxsize=128)
85
+ def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
86
+ return MODEL_REGISTRY.get(model_key)
87
+
88
+ def upload_to_r2_optimized(image_bytes: bytes) -> Optional[str]:
89
+ s3_client = get_s3_client()
90
+ if not s3_client:
91
+ return None
92
+ for attempt in range(RETRY_ATTEMPTS):
93
+ try:
94
+ filename = f"{uuid4().hex}.png"
95
+ file_key = f"adgenesis_image_file/balraaj/images/{filename}"
96
+ s3_client.put_object(
97
+ Bucket=os.getenv("R2_BUCKET_NAME"),
98
+ Key=file_key,
99
+ Body=image_bytes,
100
+ ContentType="image/png",
101
+ )
102
+ r2_url = f'{os.getenv("NEW_BASE").rstrip("/")}/{file_key}'
103
+ return r2_url
104
+ except Exception:
105
+ if attempt == RETRY_ATTEMPTS - 1:
106
+ return None
107
+ time.sleep(2 ** attempt)
108
+ return None
109
+
110
+ def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
111
+ if not REPLICATE_API_TOKEN:
112
+ return []
113
+ config = get_model_config(model_key)
114
+ if not config:
115
+ return []
116
+ for attempt in range(RETRY_ATTEMPTS):
117
+ try:
118
+ model_id = config["id"]
119
+ ar_param = config["param_name"]
120
+ inputs = {"prompt": prompt, ar_param: aspect_ratio}
121
+ output = replicate.run(model_id, input=inputs)
122
+ urls: List[str] = []
123
+ if isinstance(output, list) and output:
124
+ first = output[0]
125
+ url = getattr(first, "url", str(first))
126
+ urls = [url]
127
+ elif isinstance(output, str):
128
+ urls = [output]
129
+ elif hasattr(output, "url"):
130
+ urls = [getattr(output, "url")]
131
+ if urls:
132
+ return urls
133
+ except Exception:
134
+ if attempt == RETRY_ATTEMPTS - 1:
135
+ return []
136
+ time.sleep(1)
137
+ return []
138
+
139
+ def fetch_image_bytes_optimized(url: Union[str, Any]) -> Optional[bytes]:
140
+ url_str = getattr(url, "url", str(url))
141
+ for attempt in range(RETRY_ATTEMPTS):
142
+ try:
143
+ response = requests.get(
144
+ url_str,
145
+ timeout=REQUEST_TIMEOUT,
146
+ headers={"Cache-Control": "no-cache","Pragma": "no-cache","User-Agent": "Mozilla/5.0 (compatible; ImageBot/1.0)"},
147
+ stream=True
148
+ )
149
+ response.raise_for_status()
150
+ content = b""
151
+ for chunk in response.iter_content(chunk_size=8192):
152
+ content += chunk
153
+ return content
154
+ except Exception:
155
+ if attempt == RETRY_ATTEMPTS - 1:
156
+ return None
157
+ time.sleep(1)
158
+ return None
159
+
160
+ def process_single_image(args: Tuple[str, str, str, int]) -> Dict[str, Any]:
161
+ model_key, prompt, aspect_ratio, index = args
162
+ result = {"index": index,"success": False,"source_url": None,"r2_url": None,"error": None}
163
+ try:
164
+ urls = generate_one_image_optimized(model_key, prompt, aspect_ratio)
165
+ if not urls:
166
+ result["error"] = "No URLs returned from generation"
167
+ return result
168
+ source_url = urls[0]
169
+ result["source_url"] = getattr(source_url, "url", str(source_url))
170
+ img_bytes = fetch_image_bytes_optimized(source_url)
171
+ if not img_bytes:
172
+ result["error"] = "Failed to fetch image bytes"
173
+ return result
174
+ r2_url = upload_to_r2_optimized(img_bytes)
175
+ if r2_url:
176
+ result["r2_url"] = r2_url
177
+ result["success"] = True
178
+ else:
179
+ result["error"] = "Failed to upload to R2"
180
+ except Exception as e:
181
+ result["error"] = str(e)
182
+ return result
183
+
184
+ def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int) -> Tuple[List[str], List[str], List[str]]:
185
+ if num_images == 1:
186
+ result = process_single_image((model_key, prompt, aspect_ratio, 0))
187
+ if result["success"]:
188
+ return [result["r2_url"]], [result["source_url"]], []
189
+ else:
190
+ return [], [], [result["error"] or "Generation failed"]
191
+ args_list = [(model_key, prompt, aspect_ratio, i) for i in range(num_images)]
192
+ all_r2_urls: List[str] = []
193
+ all_source_urls: List[str] = []
194
+ generation_errors: List[str] = []
195
+ with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as executor:
196
+ future_to_index = {executor.submit(process_single_image, args): args[3] for args in args_list}
197
+ for future in as_completed(future_to_index):
198
+ try:
199
+ result = future.result()
200
+ if result["success"]:
201
+ all_r2_urls.append(result["r2_url"])
202
+ if result["source_url"]:
203
+ all_source_urls.append(result["source_url"])
204
+ else:
205
+ error_msg = result["error"] or f"Generation {result['index']} failed"
206
+ generation_errors.append(error_msg)
207
+ except Exception as e:
208
+ generation_errors.append(f"Future execution error: {str(e)}")
209
+ return all_r2_urls, all_source_urls, generation_errors
210
+
211
+ def save_creative_record_optimized(model_key: str, aspect_ratio: str, prompt: str, urls: List[str]) -> Optional[str]:
212
+ collection = get_mongo_collection()
213
+ if collection is None:
214
+ return None
215
+ try:
216
+ doc = {"model": model_key,"aspect_ratio": aspect_ratio,"prompt": prompt,"urls": urls,"num_images": len(urls),"lob": "balraaj","created_at": datetime.utcnow()}
217
+ result = collection.insert_one(doc)
218
+ return str(result.inserted_id)
219
+ except Exception:
220
+ return None
221
+
222
+ @st.cache_data(ttl=300)
223
+ def query_creatives_optimized(start_dt: datetime, end_dt: datetime, page: int = 0) -> Tuple[List[Dict[str, Any]], int]:
224
+ collection = get_mongo_collection()
225
+ if collection is None:
226
+ return [], 0
227
+ try:
228
+ total_count = collection.count_documents({"created_at": {"$gte": start_dt, "$lt": end_dt},"lob": "balraaj"})
229
+ cursor = collection.find({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"}).sort("created_at", -1).skip(page * LIBRARY_PAGE_SIZE).limit(LIBRARY_PAGE_SIZE)
230
+ records = list(cursor)
231
+ return records, total_count
232
+ except Exception:
233
+ return [], 0
234
+
235
+ @st.cache_data(ttl=3600)
236
+ def get_image_bytes_cached(url: str) -> Optional[bytes]:
237
+ return fetch_image_bytes_optimized(url)
238
+
239
+ def display_image_with_download_optimized(url: Any, button_label: str = "Download image"):
240
+ url_str = getattr(url, "url", str(url))
241
+ try:
242
+ img_bytes = get_image_bytes_cached(url_str)
243
+ if img_bytes is None:
244
+ st.error("Failed to load image")
245
+ return
246
+ st.image(img_bytes, use_container_width=True)
247
+ path = urlparse(url_str).path
248
+ base = os.path.basename(path) or "image.png"
249
+ if not os.path.splitext(base)[1]:
250
+ base = f"{base}.png"
251
+ st.download_button(label=button_label, data=img_bytes, file_name=base, mime="image/png", use_container_width=True)
252
+ except Exception as e:
253
+ st.error(f"Failed to display image: {e}")
254
+
255
+ def bulk_download_button(urls: List[str], filename: str = "images_bundle.zip"):
256
+ if not urls:
257
+ return
258
+ zip_buffer = io.BytesIO()
259
+ with zipfile.ZipFile(zip_buffer, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
260
+ for idx, url in enumerate(urls):
261
+ try:
262
+ img_bytes = fetch_image_bytes_optimized(url)
263
+ if img_bytes:
264
+ path = urlparse(str(url)).path
265
+ base = os.path.basename(path) or f"image_{idx}.png"
266
+ if not os.path.splitext(base)[1]:
267
+ base = f"{base}.png"
268
+ zip_file.writestr(base, img_bytes)
269
+ except Exception:
270
+ pass
271
+ zip_buffer.seek(0)
272
+ st.download_button("Download All Images", data=zip_buffer, file_name=filename, mime="application/zip", use_container_width=True)
273
+
274
+ def main_app():
275
+ st.set_page_config(page_title="Image Generator + Creative Library", layout="wide")
276
+ st.title("Multi Model Image Generator")
277
+ with st.sidebar:
278
+ page = st.radio(" ", ["Generate Bulk Images", "Generate from JSON", "Creative Library"], index=0)
279
+ if page == "Generate Bulk Images":
280
+ render_generate_page()
281
+ elif page == "Generate from JSON":
282
+ render_json_page()
283
+ elif page == "Creative Library":
284
+ render_library_page()
285
+
286
+ def render_generate_page():
287
+ colA, colB = st.columns([1, 1])
288
+ with colA:
289
+ model_key = st.selectbox("Model", list(MODEL_REGISTRY.keys()), index=0)
290
+ aspect_options = MODEL_REGISTRY[model_key]["aspect_ratios"]
291
+ aspect_ratio = st.selectbox("Aspect Ratio", aspect_options, index=0)
292
+ num_images = st.slider("Number of images", min_value=1, max_value=50, value=1, step=1)
293
+ with colB:
294
+ prompt = st.text_area("Prompt", placeholder="Describe the image you want to generate...", height=160)
295
+ debug_mode = st.checkbox("Debug Mode")
296
+ if st.button("Generate Images", type="primary", use_container_width=True):
297
+ handle_image_generation_optimized(model_key, aspect_ratio, prompt, num_images, debug_mode)
298
+
299
+ def handle_image_generation_optimized(model_key: str, aspect_ratio: str, prompt: str, num_images: int, debug_mode: bool = False):
300
+ if not REPLICATE_API_TOKEN:
301
+ st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
302
+ return
303
+ if not prompt.strip():
304
+ st.warning("Please enter a prompt.")
305
+ return
306
+ progress = st.progress(0, text="Starting generation...")
307
+ status_container = st.empty()
308
+ start_time = time.time()
309
+ try:
310
+ with status_container.container():
311
+ st.info(f"Generating {num_images} image(s) in parallel...")
312
+ progress.progress(0.1, text="Initializing parallel generation...")
313
+ all_r2_urls, all_source_urls, generation_errors = generate_images_parallel(model_key, aspect_ratio, prompt.strip(), num_images)
314
+ progress.progress(0.8, text="Saving results...")
315
+ rec_id = None
316
+ if all_r2_urls:
317
+ rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt.strip(), all_r2_urls)
318
+ progress.progress(1.0, text="Complete!")
319
+ generation_time = time.time() - start_time
320
+ if all_r2_urls:
321
+ with status_container.container():
322
+ st.success(f"Generated {len(all_r2_urls)} image(s) in {generation_time:.1f}s. Saved to DB: {rec_id or 'N/A'}")
323
+ display_image_gallery_optimized(all_r2_urls)
324
+ bulk_download_button(all_r2_urls, filename="generated_images.zip")
325
+ elif all_source_urls:
326
+ with status_container.container():
327
+ st.warning("Images generated but R2 upload failed. Showing originals:")
328
+ display_image_gallery_optimized(all_source_urls)
329
+ bulk_download_button(all_source_urls, filename="generated_images.zip")
330
+ else:
331
+ with status_container.container():
332
+ st.error("No images were generated.")
333
+ if generation_errors and debug_mode:
334
+ with st.expander("Generation Errors", expanded=True):
335
+ for error in generation_errors:
336
+ st.error(f"{error}")
337
+ except Exception as e:
338
+ with status_container.container():
339
+ st.error(f"Generation failed: {str(e)}")
340
+
341
+ def display_image_gallery_optimized(urls: List[str]):
342
+ if not urls:
343
+ return
344
+ num_cols = min(4, len(urls)) if len(urls) > 1 else 1
345
+ cols = st.columns(num_cols)
346
+ for idx, url in enumerate(urls):
347
+ with cols[idx % num_cols]:
348
+ try:
349
+ display_image_with_download_optimized(url)
350
+ except Exception as e:
351
+ st.error(f"Failed to display image: {e}")
352
+
353
+ def render_library_page():
354
+ st.subheader("Creative Library")
355
+ if "library_page" not in st.session_state:
356
+ st.session_state.library_page = 0
357
+ today_utc = datetime.utcnow().date()
358
+ default_start = today_utc - timedelta(days=30)
359
+ c1, c2, c3 = st.columns([1, 1, 1])
360
+ with c1:
361
+ start_date: date = st.date_input("Start date", value=default_start)
362
+ with c2:
363
+ end_date: date = st.date_input("End date", value=today_utc)
364
+ with c3:
365
+ if st.button("Apply Filters", use_container_width=True):
366
+ st.session_state.library_page = 0
367
+ st.cache_data.clear()
368
+ start_dt = datetime.combine(start_date, datetime.min.time())
369
+ end_dt = datetime.combine(end_date + timedelta(days=1), datetime.min.time())
370
+ records, total_count = query_creatives_optimized(start_dt, end_dt, st.session_state.library_page)
371
+ if not records and st.session_state.library_page == 0:
372
+ st.info("No creatives found for the selected dates.")
373
+ return
374
+ start_idx = st.session_state.library_page * LIBRARY_PAGE_SIZE + 1
375
+ end_idx = min(start_idx + len(records) - 1, total_count)
376
+ st.caption(f"Showing {start_idx}-{end_idx} of {total_count} items")
377
+ if total_count > LIBRARY_PAGE_SIZE:
378
+ col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
379
+ with col1:
380
+ if st.button("First", disabled=st.session_state.library_page == 0):
381
+ st.session_state.library_page = 0
382
+ st.rerun()
383
+ with col2:
384
+ if st.button("Previous", disabled=st.session_state.library_page == 0):
385
+ st.session_state.library_page = max(0, st.session_state.library_page - 1)
386
+ st.rerun()
387
+ with col3:
388
+ max_page = (total_count - 1) // LIBRARY_PAGE_SIZE
389
+ if st.button("Next", disabled=st.session_state.library_page >= max_page):
390
+ st.session_state.library_page += 1
391
+ st.rerun()
392
+ with col4:
393
+ max_page = (total_count - 1) // LIBRARY_PAGE_SIZE
394
+ if st.button("Last", disabled=st.session_state.library_page >= max_page):
395
+ st.session_state.library_page = max_page
396
+ st.rerun()
397
+ display_creative_grid_optimized(records)
398
+
399
+ def display_creative_grid_optimized(records: List[Dict[str, Any]]):
400
+ if not records:
401
+ return
402
+ items: List[Tuple[str, str, str, datetime]] = []
403
+ for doc in records:
404
+ created_at = doc.get("created_at", datetime.utcnow())
405
+ model = doc.get("model", "?")
406
+ aspect = doc.get("aspect_ratio", "?")
407
+ urls = doc.get("urls", []) or []
408
+ for url in urls:
409
+ items.append((url, model, aspect, created_at))
410
+ if not items:
411
+ return
412
+ if "selected_urls" not in st.session_state:
413
+ st.session_state.selected_urls = set()
414
+ colA, colB, colC = st.columns([1, 1, 2])
415
+ with colA:
416
+ if st.button("Select all on this page", use_container_width=True, disabled=not items):
417
+ for (url, _, _, _) in items:
418
+ st.session_state.selected_urls.add(getattr(url, "url", str(url)))
419
+ with colB:
420
+ if st.button("Clear selection", use_container_width=True, disabled=not st.session_state.selected_urls):
421
+ st.session_state.selected_urls = set()
422
+ with colC:
423
+ if st.session_state.selected_urls:
424
+ zip_bytes = build_zip_from_selected()
425
+ st.download_button("Download selected (.zip)", data=zip_bytes, file_name="selected_images.zip", mime="application/zip", use_container_width=True)
426
+ cols = st.columns(4)
427
+ for i, (url, model, aspect, created_at) in enumerate(items):
428
+ with cols[i % 4]:
429
+ try:
430
+ display_image_with_download_optimized(url)
431
+ st.caption(f"{model} • {aspect} • {created_at.strftime('%Y-%m-%d %H:%M UTC')}")
432
+ url_str = getattr(url, "url", str(url))
433
+ key = f"lib_sel_{hash(url_str)}_{i}"
434
+ checked = st.checkbox("Select", key=key, value=(url_str in st.session_state.selected_urls))
435
+ if checked:
436
+ st.session_state.selected_urls.add(url_str)
437
+ else:
438
+ st.session_state.selected_urls.discard(url_str)
439
+ except Exception as e:
440
+ st.error(f"Failed to display image: {e}")
441
+
442
+ def build_zip_from_selected() -> bytes:
443
+ buf = io.BytesIO()
444
+ with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
445
+ for i, u in enumerate(sorted(list(st.session_state.selected_urls)), 1):
446
+ data = get_image_bytes_cached(u)
447
+ if data:
448
+ path = urlparse(u).path
449
+ base = os.path.basename(path) or f"image_{i}.png"
450
+ if not os.path.splitext(base)[1]:
451
+ base = f"{base}.png"
452
+ zf.writestr(base, data)
453
+ buf.seek(0)
454
+ return buf.getvalue()
455
+
456
+ @lru_cache(maxsize=1)
457
+ def check_token_cached(user_token: str) -> Tuple[bool, str]:
458
+ ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
459
+ if not ACCESS_TOKEN:
460
+ return False, "Server error: Access token not configured."
461
+ if user_token == ACCESS_TOKEN:
462
+ return True, ""
463
+ return False, "Invalid token."
464
+
465
+ def load_json_prompts(file) -> List[Dict[str, Any]]:
466
+ raw = file.getvalue().decode("utf-8", errors="replace")
467
+ data = json.loads(raw)
468
+ prompts_out: List[Dict[str, Any]] = []
469
+ def push(idx: int, item: Any):
470
+ if isinstance(item, str):
471
+ prompts_out.append({"id": f"p{idx}", "content": item})
472
+ elif isinstance(item, dict) and "content" in item:
473
+ obj = {"id": str(item.get("id", f"p{idx}")), "content": str(item["content"])}
474
+ if "num" in item: obj["num"] = int(item["num"])
475
+ if "aspect_ratio" in item: obj["aspect_ratio"] = str(item["aspect_ratio"])
476
+ if "model" in item: obj["model"] = str(item["model"])
477
+ prompts_out.append(obj)
478
+ if isinstance(data, dict) and "prompts" in data and isinstance(data["prompts"], list):
479
+ for i, item in enumerate(data["prompts"], 1):
480
+ push(i, item)
481
+ elif isinstance(data, list):
482
+ for i, item in enumerate(data, 1):
483
+ push(i, item)
484
+ return prompts_out
485
+
486
+ def render_json_page():
487
+ st.subheader("Generate from JSON Prompts")
488
+ up = st.file_uploader("Upload prompts JSON", type=["json"])
489
+ col1, col2, col3 = st.columns([1,1,1])
490
+ with col1:
491
+ default_model = st.selectbox("Default Model", list(MODEL_REGISTRY.keys()), index=0)
492
+ with col2:
493
+ aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
494
+ default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
495
+ with col3:
496
+ default_num = st.slider("Default Images per Prompt", 1, 20, 1, 1, key="json_default_num")
497
+ debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
498
+ if up:
499
+ try:
500
+ prompts_list = load_json_prompts(up)
501
+ if not prompts_list:
502
+ st.error("No prompts found in the JSON.")
503
+ return
504
+ with st.expander("Preview normalized prompts", expanded=False):
505
+ st.json(prompts_list, expanded=False)
506
+ if st.button("Generate for All Prompts", type="primary", use_container_width=True):
507
+ handle_bulk_json_generation(prompts_list, default_model, default_aspect, default_num, debug_mode)
508
+ except json.JSONDecodeError as e:
509
+ st.error(f"Invalid JSON: {e}")
510
+ except Exception as e:
511
+ st.error(f"Failed to read prompts: {e}")
512
+ else:
513
+ st.caption("Your JSON can be a list of strings, a list of objects with `content`, or `{ 'prompts': [...] }`. Optional per-prompt keys: `model`, `aspect_ratio`, `num`.")
514
+
515
+ def handle_bulk_json_generation(prompts: List[Dict[str, Any]], default_model: str, default_aspect: str, default_num: int, debug_mode: bool):
516
+ if not REPLICATE_API_TOKEN:
517
+ st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
518
+ return
519
+ total = len(prompts)
520
+ overall_progress = st.progress(0, text=f"Starting batch • 0/{total}")
521
+ all_generated_urls: List[str] = []
522
+ errors_total: List[str] = []
523
+ start_time = time.time()
524
+ for idx, p in enumerate(prompts, 1):
525
+ model_key = str(p.get("model", default_model))
526
+ aspect_ratio = str(p.get("aspect_ratio", default_aspect))
527
+ num_images = int(p.get("num", default_num))
528
+ prompt_text = str(p.get("content", "")).strip()
529
+ block = st.container(border=True)
530
+ with block:
531
+ st.markdown(f"**Prompt {idx}/{total}** — Model: `{model_key}` • Aspect: `{aspect_ratio}` • Num: `{num_images}`")
532
+ st.code(prompt_text or "(empty)", language="markdown")
533
+ if not prompt_text:
534
+ st.error("Prompt text is empty. Skipping.")
535
+ overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
536
+ continue
537
+ r2_urls, src_urls, gen_errors = generate_images_parallel(model_key, aspect_ratio, prompt_text, num_images)
538
+ rec_id = None
539
+ if r2_urls:
540
+ rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
541
+ if r2_urls:
542
+ st.success(f"Generated {len(r2_urls)} image(s). DB: {rec_id or 'N/A'}")
543
+ display_image_gallery_optimized(r2_urls)
544
+ bulk_download_button(r2_urls, filename=f"prompt_{idx}_images.zip")
545
+ all_generated_urls.extend(r2_urls)
546
+ elif src_urls:
547
+ st.warning("Images generated but R2 upload failed. Showing originals:")
548
+ display_image_gallery_optimized(src_urls)
549
+ bulk_download_button(src_urls, filename=f"prompt_{idx}_images.zip")
550
+ all_generated_urls.extend(src_urls)
551
+ else:
552
+ st.error("No images were generated for this prompt.")
553
+ if gen_errors and debug_mode:
554
+ with st.expander("Errors", expanded=False):
555
+ for e in gen_errors:
556
+ st.error(e)
557
+ errors_total.extend(gen_errors)
558
+ overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
559
+ elapsed = time.time() - start_time
560
+ st.success(f"Batch complete in {elapsed:.1f}s. Total prompts: {total}.")
561
+ if all_generated_urls:
562
+ st.subheader("Download All Generated")
563
+ bulk_download_button(all_generated_urls, filename="all_prompts_images.zip")
564
+ if errors_total and debug_mode:
565
+ with st.expander("All Errors Summary", expanded=False):
566
+ for e in errors_total:
567
+ st.error(e)
568
+
569
+ def main():
570
+ st.set_page_config(page_title="Bulk Creative Generation", layout="wide")
571
+ if "authenticated" not in st.session_state:
572
+ st.session_state["authenticated"] = False
573
+ if not st.session_state["authenticated"]:
574
+ st.markdown("## Access Required")
575
+ token_input = st.text_input("Enter Access Token", type="password")
576
+ if st.button("Unlock App"):
577
+ ok, error_msg = check_token_cached(token_input)
578
+ if ok:
579
+ st.session_state["authenticated"] = True
580
+ st.rerun()
581
+ else:
582
+ st.error(error_msg)
583
+ else:
584
+ main_app()
585
+
586
+ if __name__ == "__main__":
587
+ main()