userIdc2024 commited on
Commit
c66ef08
·
verified ·
1 Parent(s): ef9af6f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +154 -400
src/streamlit_app.py CHANGED
@@ -33,20 +33,19 @@ RETRY_ATTEMPTS = 3
33
  LIBRARY_PAGE_SIZE = 20
34
 
35
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
36
- "imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"],"param_name": "aspect_ratio"},
37
- "imagen-4": {"id": "google/imagen-4","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"],"param_name": "aspect_ratio"},
38
- "nano-banana": {"id": "google/nano-banana","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3"],"param_name": "aspect_ratio"},
39
- "qwen": {"id": "qwen/qwen-image","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3"],"param_name": "aspect_ratio"},
40
- "seedream-3": {"id": "bytedance/seedream-3","aspect_ratios": ["1:1", "16:9", "9:16", "3:4", "4:3", "3:2", "2:3", "21:9"],"param_name": "aspect_ratio"},
41
- "recraft-v3": {"id": "recraft-ai/recraft-v3","aspect_ratios": ["1:1", "4:3", "3:4", "3:2", "2:3", "16:9", "9:16", "1:2", "2:1", "7:5", "5:7", "4:5", "5:4", "3:5", "5:3"],"param_name": "aspect_ratio"},
42
- "photon": {"id": "luma/photon","aspect_ratios": ["1:1", "3:4", "4:3", "9:16", "16:9", "9:21", "21:9"],"param_name": "aspect_ratio"},
43
- "ideogram-v3-quality": {"id": "ideogram-ai/ideogram-v3-quality","aspect_ratios": ["1:3", "3:1", "1:2", "2:1", "9:16", "16:9", "10:16", "16:10", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "1:1"],"param_name": "aspect_ratio"},
44
  }
45
 
46
  _thread_local = threading.local()
47
 
48
  def get_mongo_collection():
49
- if not hasattr(_thread_local, 'mongo_collection'):
50
  if not MONGO_URI:
51
  _thread_local.mongo_collection = None
52
  return None
@@ -54,7 +53,7 @@ def get_mongo_collection():
54
  client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
55
  db = client[MONGO_DB]
56
  collection = db[MONGO_COLLECTION]
57
- client.admin.command('ping')
58
  _thread_local.mongo_collection = collection
59
  except Exception as e:
60
  logger.error(f"MongoDB connection failed: {e}")
@@ -62,8 +61,8 @@ def get_mongo_collection():
62
  return _thread_local.mongo_collection
63
 
64
  def get_s3_client():
65
- if not hasattr(_thread_local, 's3_client'):
66
- required_vars = ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME"]
67
  missing = [var for var in required_vars if not os.getenv(var)]
68
  if missing:
69
  _thread_local.s3_client = None
@@ -89,23 +88,19 @@ def upload_to_r2_optimized(image_bytes: bytes) -> Optional[str]:
89
  s3_client = get_s3_client()
90
  if not s3_client:
91
  return None
92
- for attempt in range(RETRY_ATTEMPTS):
93
- try:
94
- filename = f"{uuid4().hex}.png"
95
- file_key = f"adgenesis_image_file/balraaj/images/{filename}"
96
- s3_client.put_object(
97
- Bucket=os.getenv("R2_BUCKET_NAME"),
98
- Key=file_key,
99
- Body=image_bytes,
100
- ContentType="image/png",
101
- )
102
- r2_url = f'{os.getenv("NEW_BASE").rstrip("/")}/{file_key}'
103
- return r2_url
104
- except Exception:
105
- if attempt == RETRY_ATTEMPTS - 1:
106
- return None
107
- time.sleep(2 ** attempt)
108
- return None
109
 
110
  def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
111
  if not REPLICATE_API_TOKEN:
@@ -113,425 +108,184 @@ def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str)
113
  config = get_model_config(model_key)
114
  if not config:
115
  return []
116
- for attempt in range(RETRY_ATTEMPTS):
117
- try:
118
- model_id = config["id"]
119
- ar_param = config["param_name"]
120
- inputs = {"prompt": prompt, ar_param: aspect_ratio}
121
- output = replicate.run(model_id, input=inputs)
122
- urls: List[str] = []
123
- if isinstance(output, list) and output:
124
- first = output[0]
125
- url = getattr(first, "url", str(first))
126
- urls = [url]
127
- elif isinstance(output, str):
128
- urls = [output]
129
- elif hasattr(output, "url"):
130
- urls = [getattr(output, "url")]
131
- if urls:
132
- return urls
133
- except Exception:
134
- if attempt == RETRY_ATTEMPTS - 1:
135
- return []
136
- time.sleep(1)
137
  return []
138
 
139
- def fetch_image_bytes_optimized(url: Union[str, Any]) -> Optional[bytes]:
140
- url_str = getattr(url, "url", str(url))
141
  for attempt in range(RETRY_ATTEMPTS):
142
  try:
143
- response = requests.get(
144
- url_str,
145
- timeout=REQUEST_TIMEOUT,
146
- headers={"Cache-Control": "no-cache","Pragma": "no-cache","User-Agent": "Mozilla/5.0 (compatible; ImageBot/1.0)"},
147
- stream=True
148
- )
149
  response.raise_for_status()
150
- content = b""
151
- for chunk in response.iter_content(chunk_size=8192):
152
- content += chunk
153
- return content
154
  except Exception:
155
  if attempt == RETRY_ATTEMPTS - 1:
156
  return None
157
  time.sleep(1)
158
  return None
159
 
160
- def process_single_image(args: Tuple[str, str, str, int]) -> Dict[str, Any]:
161
  model_key, prompt, aspect_ratio, index = args
162
  result = {"index": index,"success": False,"source_url": None,"r2_url": None,"error": None}
163
- try:
164
- urls = generate_one_image_optimized(model_key, prompt, aspect_ratio)
165
- if not urls:
166
- result["error"] = "No URLs returned from generation"
167
- return result
168
- source_url = urls[0]
169
- result["source_url"] = getattr(source_url, "url", str(source_url))
170
- img_bytes = fetch_image_bytes_optimized(source_url)
171
- if not img_bytes:
172
- result["error"] = "Failed to fetch image bytes"
173
- return result
174
- r2_url = upload_to_r2_optimized(img_bytes)
175
- if r2_url:
176
- result["r2_url"] = r2_url
177
- result["success"] = True
178
- else:
179
- result["error"] = "Failed to upload to R2"
180
- except Exception as e:
181
- result["error"] = str(e)
182
  return result
183
 
184
- def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int) -> Tuple[List[str], List[str], List[str]]:
185
- if num_images == 1:
186
- result = process_single_image((model_key, prompt, aspect_ratio, 0))
187
- if result["success"]:
188
- return [result["r2_url"]], [result["source_url"]], []
189
- else:
190
- return [], [], [result["error"] or "Generation failed"]
191
- args_list = [(model_key, prompt, aspect_ratio, i) for i in range(num_images)]
192
- all_r2_urls: List[str] = []
193
- all_source_urls: List[str] = []
194
- generation_errors: List[str] = []
195
- with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as executor:
196
- future_to_index = {executor.submit(process_single_image, args): args[3] for args in args_list}
197
- for future in as_completed(future_to_index):
198
- try:
199
- result = future.result()
200
- if result["success"]:
201
- all_r2_urls.append(result["r2_url"])
202
- if result["source_url"]:
203
- all_source_urls.append(result["source_url"])
204
- else:
205
- error_msg = result["error"] or f"Generation {result['index']} failed"
206
- generation_errors.append(error_msg)
207
- except Exception as e:
208
- generation_errors.append(f"Future execution error: {str(e)}")
209
- return all_r2_urls, all_source_urls, generation_errors
210
 
211
- def save_creative_record_optimized(model_key: str, aspect_ratio: str, prompt: str, urls: List[str]) -> Optional[str]:
212
  collection = get_mongo_collection()
213
  if collection is None:
214
  return None
215
  try:
216
- doc = {
217
- "model": model_key,
218
- "aspect_ratio": aspect_ratio,
219
- "prompt": prompt,
220
- "urls": urls,
221
- "num_images": len(urls),
222
- "lob": "balraaj",
223
- "created_at": datetime.utcnow()
224
- }
225
- result = collection.insert_one(doc)
226
- return str(result.inserted_id)
227
- except Exception:
228
  return None
229
 
230
  @st.cache_data(ttl=300)
231
- def query_creatives_optimized(start_dt: datetime, end_dt: datetime, page: int = 0) -> Tuple[List[Dict[str, Any]], int]:
232
  collection = get_mongo_collection()
233
  if collection is None:
234
- return [], 0
235
  try:
236
- total_count = collection.count_documents({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
237
- cursor = (
238
- collection.find({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
239
- .sort("created_at", -1)
240
- .skip(page * LIBRARY_PAGE_SIZE)
241
- .limit(LIBRARY_PAGE_SIZE)
242
- )
243
- records = list(cursor)
244
- return records, total_count
245
  except Exception:
246
- return [], 0
247
-
248
- @st.cache_data(ttl=3600)
249
- def get_image_bytes_cached(url: str) -> Optional[bytes]:
250
- return fetch_image_bytes_optimized(url)
251
 
252
- def display_image_with_download_optimized(url: Any, button_label: str = "Download image"):
253
- url_str = getattr(url, "url", str(url))
254
  try:
255
- img_bytes = get_image_bytes_cached(url_str)
256
- if img_bytes is None:
257
  st.error("Failed to load image")
258
  return
259
  st.image(img_bytes, use_container_width=True)
260
- path = urlparse(url_str).path
261
- base = os.path.basename(path) or "image.png"
262
  if not os.path.splitext(base)[1]:
263
- base = f"{base}.png"
264
- st.download_button(label=button_label, data=img_bytes, file_name=base, mime="image/png", use_container_width=True)
265
  except Exception as e:
266
  st.error(f"Failed to display image: {e}")
267
 
268
- def bulk_download_button(urls: List[str], filename: str = "images_bundle.zip"):
269
- if not urls:
270
- return
271
- zip_buffer = io.BytesIO()
272
- with zipfile.ZipFile(zip_buffer, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
273
- for idx, url in enumerate(urls):
274
- try:
275
- img_bytes = fetch_image_bytes_optimized(url)
276
- if img_bytes:
277
- path = urlparse(str(url)).path
278
- base = os.path.basename(path) or f"image_{idx}.png"
279
- if not os.path.splitext(base)[1]:
280
- base = f"{base}.png"
281
- zip_file.writestr(base, img_bytes)
282
- except Exception:
283
- pass
284
- zip_buffer.seek(0)
285
- st.download_button("Download All Images", data=zip_buffer, file_name=filename, mime="application/zip", use_container_width=True)
286
 
287
- def main_app():
288
- st.set_page_config(page_title="File-to-Image Creative Library", layout="wide")
289
- st.title("File-to-Image Generator")
290
- with st.sidebar:
291
- page = st.radio(" ", ["Generate from JSON", "Creative Library"], index=0)
292
- if page == "Generate from JSON":
293
- render_json_page()
294
- elif page == "Creative Library":
295
- render_library_page()
296
-
297
- def render_library_page():
298
- st.subheader("Creative Library")
299
- if "library_page" not in st.session_state:
300
- st.session_state.library_page = 0
301
- today_utc = datetime.utcnow().date()
302
- default_start = today_utc - timedelta(days=30)
303
- c1, c2, c3 = st.columns([1, 1, 1])
304
- with c1:
305
- start_date: date = st.date_input("Start date", value=default_start)
306
- with c2:
307
- end_date: date = st.date_input("End date", value=today_utc)
308
- with c3:
309
- if st.button("Apply Filters", use_container_width=True):
310
- st.session_state.library_page = 0
311
- st.cache_data.clear()
312
- start_dt = datetime.combine(start_date, datetime.min.time())
313
- end_dt = datetime.combine(end_date + timedelta(days=1), datetime.min.time())
314
- records, total_count = query_creatives_optimized(start_dt, end_dt, st.session_state.library_page)
315
- if not records and st.session_state.library_page == 0:
316
- st.info("No creatives found for the selected dates.")
317
- return
318
- start_idx = st.session_state.library_page * LIBRARY_PAGE_SIZE + 1
319
- end_idx = min(start_idx + len(records) - 1, total_count)
320
- st.caption(f"Showing {start_idx}-{end_idx} of {total_count} items")
321
- if total_count > LIBRARY_PAGE_SIZE:
322
- col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
323
- with col1:
324
- if st.button("First", disabled=st.session_state.library_page == 0):
325
- st.session_state.library_page = 0
326
- st.rerun()
327
- with col2:
328
- if st.button("Previous", disabled=st.session_state.library_page == 0):
329
- st.session_state.library_page = max(0, st.session_state.library_page - 1)
330
- st.rerun()
331
- with col3:
332
- max_page = (total_count - 1) // LIBRARY_PAGE_SIZE
333
- if st.button("Next", disabled=st.session_state.library_page >= max_page):
334
- st.session_state.library_page += 1
335
- st.rerun()
336
- with col4:
337
- max_page = (total_count - 1) // LIBRARY_PAGE_SIZE
338
- if st.button("Last", disabled=st.session_state.library_page >= max_page):
339
- st.session_state.library_page = max_page
340
- st.rerun()
341
- display_creative_grid_optimized(records)
342
-
343
- def display_creative_grid_optimized(records: List[Dict[str, Any]]):
344
- if not records:
345
- return
346
- items: List[Tuple[str, str, str, datetime]] = []
347
- for doc in records:
348
- created_at = doc.get("created_at", datetime.utcnow())
349
- model = doc.get("model", "?")
350
- aspect = doc.get("aspect_ratio", "?")
351
- urls = doc.get("urls", []) or []
352
- for url in urls:
353
- items.append((url, model, aspect, created_at))
354
- if not items:
355
- return
356
- if "selected_urls" not in st.session_state:
357
- st.session_state.selected_urls = set()
358
- colA, colB, colC = st.columns([1, 1, 2])
359
- with colA:
360
- if st.button("Select all on this page", use_container_width=True, disabled=not items):
361
- for (url, _, _, _) in items:
362
- st.session_state.selected_urls.add(getattr(url, "url", str(url)))
363
- with colB:
364
- if st.button("Clear selection", use_container_width=True, disabled=not st.session_state.selected_urls):
365
- st.session_state.selected_urls = set()
366
- with colC:
367
- if st.session_state.selected_urls:
368
- zip_bytes = build_zip_from_selected()
369
- st.download_button("Download selected (.zip)", data=zip_bytes, file_name="selected_images.zip", mime="application/zip", use_container_width=True)
370
- cols = st.columns(4)
371
- for i, (url, model, aspect, created_at) in enumerate(items):
372
- with cols[i % 4]:
373
- try:
374
- display_image_with_download_optimized(url)
375
- st.caption(f"{model} • {aspect} • {created_at.strftime('%Y-%m-%d %H:%M UTC')}")
376
- url_str = getattr(url, "url", str(url))
377
- key = f"lib_sel_{hash(url_str)}_{i}"
378
- checked = st.checkbox("Select", key=key, value=(url_str in st.session_state.selected_urls))
379
- if checked:
380
- st.session_state.selected_urls.add(url_str)
381
- else:
382
- st.session_state.selected_urls.discard(url_str)
383
- except Exception as e:
384
- st.error(f"Failed to display image: {e}")
385
-
386
- def build_zip_from_selected() -> bytes:
387
- buf = io.BytesIO()
388
- with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
389
- for i, u in enumerate(sorted(list(st.session_state.selected_urls)), 1):
390
- data = get_image_bytes_cached(u)
391
- if data:
392
- path = urlparse(u).path
393
- base = os.path.basename(path) or f"image_{i}.png"
394
  if not os.path.splitext(base)[1]:
395
- base = f"{base}.png"
396
- zf.writestr(base, data)
397
- buf.seek(0)
398
- return buf.getvalue()
399
-
400
- @lru_cache(maxsize=1)
401
- def check_token_cached(user_token: str) -> Tuple[bool, str]:
402
- ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
403
- if not ACCESS_TOKEN:
404
- return False, "Server error: Access token not configured."
405
- if user_token == ACCESS_TOKEN:
406
- return True, ""
407
- return False, "Invalid token."
408
-
409
- def load_json_prompts(file) -> List[Dict[str, Any]]:
410
- raw = file.getvalue().decode("utf-8", errors="replace")
411
- data = json.loads(raw)
412
-
413
- if not isinstance(data, dict) or "prompts" not in data or not isinstance(data["prompts"], list):
414
- raise ValueError("Invalid JSON. Expected an object with a 'prompts' array of strings.")
415
-
416
- prompts_out: List[Dict[str, Any]] = []
417
- for i, item in enumerate(data["prompts"], 1):
418
- if not isinstance(item, str) or not item.strip():
419
- raise ValueError(f"'prompts[{i-1}]' must be a non-empty string.")
420
- prompts_out.append({"id": f"p{i}", "content": item.strip()})
421
 
422
- return prompts_out
 
 
 
 
 
423
 
424
  def render_json_page():
425
  st.subheader("Generate from JSON Prompts")
426
- up = st.file_uploader("Upload prompts JSON", type=["json"])
427
-
428
- col1, col2 = st.columns([1, 1])
429
- with col1:
430
- default_model = st.selectbox("Default Model", list(MODEL_REGISTRY.keys()), index=0)
431
- with col2:
432
- aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
433
- default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
434
-
435
- debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
436
-
437
  if up:
438
  try:
439
- prompts_list = load_json_prompts(up)
440
- if not prompts_list:
441
- st.error("No prompts found in the JSON.")
442
- return
443
-
444
- with st.expander("Preview normalized prompts", expanded=False):
445
- st.json(prompts_list, expanded=False)
446
-
447
- if st.button("Generate for All Prompts", type="primary", use_container_width=True):
448
- handle_bulk_json_generation(prompts_list, default_model, default_aspect, debug_mode)
449
- except json.JSONDecodeError as e:
450
- st.error(f"Invalid JSON: {e}")
451
  except Exception as e:
452
- st.error(f"Failed to read prompts: {e}")
453
- else:
454
- st.caption('Expected format: { "prompts": ["prompt 1", "prompt 2", ...] }')
455
-
456
- def handle_bulk_json_generation(prompts: List[Dict[str, Any]], default_model: str, default_aspect: str, debug_mode: bool):
457
- if not REPLICATE_API_TOKEN:
458
- st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
459
- return
460
-
461
- total = len(prompts)
462
- overall_progress = st.progress(0, text=f"Starting batch • 0/{total}")
463
- all_generated_urls: List[str] = []
464
- errors_total: List[str] = []
465
- start_time = time.time()
466
-
467
- for idx, p in enumerate(prompts, 1):
468
- model_key = default_model
469
- aspect_ratio = default_aspect
470
- num_images = 1 # exactly one image per prompt
471
- prompt_text = str(p.get("content", "")).strip()
472
-
473
- block = st.container(border=True)
474
- with block:
475
- st.markdown(f"**Prompt {idx}/{total}** — Model: `{model_key}` • Aspect: `{aspect_ratio}` • Num: `1`")
476
- st.code(prompt_text or "(empty)", language="markdown")
477
-
478
- if not prompt_text:
479
- st.error("Prompt text is empty. Skipping.")
480
- overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
481
- continue
482
-
483
- r2_urls, src_urls, gen_errors = generate_images_parallel(model_key, aspect_ratio, prompt_text, num_images)
484
- rec_id = None
485
- if r2_urls:
486
- rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
487
-
488
- if r2_urls:
489
- st.success(f"Generated 1 image. DB: {rec_id or 'N/A'}")
490
- display_image_gallery_optimized(r2_urls)
491
- bulk_download_button(r2_urls, filename=f"prompt_{idx}_image.zip")
492
- all_generated_urls.extend(r2_urls)
493
- elif src_urls:
494
- st.warning("Image generated but R2 upload failed. Showing original:")
495
- display_image_gallery_optimized(src_urls)
496
- bulk_download_button(src_urls, filename=f"prompt_{idx}_image.zip")
497
- all_generated_urls.extend(src_urls)
498
- else:
499
- st.error("No image was generated for this prompt.")
500
-
501
- if gen_errors and debug_mode:
502
- with st.expander("Errors", expanded=False):
503
- for e in gen_errors:
504
- st.error(e)
505
- errors_total.extend(gen_errors)
506
 
507
- overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
 
 
 
 
 
 
 
 
 
 
 
508
 
509
- elapsed = time.time() - start_time
510
- st.success(f"Batch complete in {elapsed:.1f}s. Total prompts: {total}.")
511
- if all_generated_urls:
512
- st.subheader("Download All Generated")
513
- bulk_download_button(all_generated_urls, filename="all_prompts_images.zip")
514
- if errors_total and debug_mode:
515
- with st.expander("All Errors Summary", expanded=False):
516
- for e in errors_total:
517
- st.error(e)
518
 
519
  def main():
520
- st.set_page_config(page_title="File-to-Image • Creative Library", layout="wide")
521
- if "authenticated" not in st.session_state:
522
- st.session_state["authenticated"] = False
523
- if not st.session_state["authenticated"]:
524
- st.markdown("## Access Required")
525
- token_input = st.text_input("Enter Access Token", type="password")
526
- if st.button("Unlock App"):
527
- ok, error_msg = check_token_cached(token_input)
528
- if ok:
529
- st.session_state["authenticated"] = True
530
- st.rerun()
531
- else:
532
- st.error(error_msg)
533
- else:
534
- main_app()
535
 
536
- if __name__ == "__main__":
537
  main()
 
33
  LIBRARY_PAGE_SIZE = 20
34
 
35
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
36
+ "imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
37
+ "imagen-4": {"id": "google/imagen-4","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
38
+ "qwen": {"id": "qwen/qwen-image","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"],"param_name": "aspect_ratio"},
39
+ "seedream-3": {"id": "bytedance/seedream-3","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"],"param_name": "aspect_ratio"},
40
+ "recraft-v3": {"id": "recraft-ai/recraft-v3","aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"],"param_name": "aspect_ratio"},
41
+ "photon": {"id": "luma/photon","aspect_ratios": ["1:1","3:4","4:3","9:16","16:9","21:9"],"param_name": "aspect_ratio"},
42
+ "ideogram-v3-quality": {"id": "ideogram-ai/ideogram-v3-quality","aspect_ratios": ["1:1","16:9","9:16","2:3","3:2","4:5","5:4"],"param_name": "aspect_ratio"},
 
43
  }
44
 
45
  _thread_local = threading.local()
46
 
47
  def get_mongo_collection():
48
+ if not hasattr(_thread_local, "mongo_collection"):
49
  if not MONGO_URI:
50
  _thread_local.mongo_collection = None
51
  return None
 
53
  client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=3000)
54
  db = client[MONGO_DB]
55
  collection = db[MONGO_COLLECTION]
56
+ client.admin.command("ping")
57
  _thread_local.mongo_collection = collection
58
  except Exception as e:
59
  logger.error(f"MongoDB connection failed: {e}")
 
61
  return _thread_local.mongo_collection
62
 
63
  def get_s3_client():
64
+ if not hasattr(_thread_local, "s3_client"):
65
+ required_vars = ["R2_ENDPOINT","R2_ACCESS_KEY","R2_SECRET_KEY","R2_BUCKET_NAME"]
66
  missing = [var for var in required_vars if not os.getenv(var)]
67
  if missing:
68
  _thread_local.s3_client = None
 
88
  s3_client = get_s3_client()
89
  if not s3_client:
90
  return None
91
+ try:
92
+ filename = f"{uuid4().hex}.png"
93
+ file_key = f"adgenesis_image_file/json/images/{filename}"
94
+ s3_client.put_object(
95
+ Bucket=os.getenv("R2_BUCKET_NAME"),
96
+ Key=file_key,
97
+ Body=image_bytes,
98
+ ContentType="image/png",
99
+ )
100
+ return f"{os.getenv('NEW_BASE').rstrip('/')}/{file_key}"
101
+ except Exception as e:
102
+ logger.error(f"S3 upload failed: {e}")
103
+ return None
 
 
 
 
104
 
105
  def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
106
  if not REPLICATE_API_TOKEN:
 
108
  config = get_model_config(model_key)
109
  if not config:
110
  return []
111
+ try:
112
+ model_id = config["id"]
113
+ ar_param = config["param_name"]
114
+ inputs = {"prompt": prompt, ar_param: aspect_ratio}
115
+ output = replicate.run(model_id, input=inputs)
116
+ if isinstance(output, list) and output:
117
+ return [str(output[0])]
118
+ elif isinstance(output, str):
119
+ return [output]
120
+ elif hasattr(output, "url"):
121
+ return [getattr(output, "url")]
122
+ except Exception as e:
123
+ logger.error(f"Replicate error: {e}")
 
 
 
 
 
 
 
 
124
  return []
125
 
126
+ def fetch_image_bytes_optimized(url: str) -> Optional[bytes]:
 
127
  for attempt in range(RETRY_ATTEMPTS):
128
  try:
129
+ response = requests.get(url, timeout=REQUEST_TIMEOUT, stream=True)
 
 
 
 
 
130
  response.raise_for_status()
131
+ return response.content
 
 
 
132
  except Exception:
133
  if attempt == RETRY_ATTEMPTS - 1:
134
  return None
135
  time.sleep(1)
136
  return None
137
 
138
+ def process_single_image(args: Tuple[str,str,str,int]) -> Dict[str,Any]:
139
  model_key, prompt, aspect_ratio, index = args
140
  result = {"index": index,"success": False,"source_url": None,"r2_url": None,"error": None}
141
+ urls = generate_one_image_optimized(model_key, prompt, aspect_ratio)
142
+ if not urls:
143
+ result["error"] = "No URLs returned"
144
+ return result
145
+ source_url = urls[0]
146
+ result["source_url"] = source_url
147
+ img_bytes = fetch_image_bytes_optimized(source_url)
148
+ if not img_bytes:
149
+ result["error"] = "Failed to fetch image bytes"
150
+ return result
151
+ r2_url = upload_to_r2_optimized(img_bytes)
152
+ if r2_url:
153
+ result["r2_url"] = r2_url
154
+ result["success"] = True
155
+ else:
156
+ result["error"] = "Failed to upload to R2"
 
 
 
157
  return result
158
 
159
+ def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str) -> Tuple[List[str],List[str],List[str]]:
160
+ result = process_single_image((model_key, prompt, aspect_ratio, 0))
161
+ if result["success"]:
162
+ return [result["r2_url"]], [result["source_url"]], []
163
+ else:
164
+ return [], [], [result["error"] or "Generation failed"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ def save_creative_record_optimized(model_key:str,aspect_ratio:str,prompt:str,urls:List[str]) -> Optional[str]:
167
  collection = get_mongo_collection()
168
  if collection is None:
169
  return None
170
  try:
171
+ doc = {"model": model_key,"aspect_ratio": aspect_ratio,"prompt": prompt,"urls": urls,"num_images": len(urls),"lob": "json_batch","created_at": datetime.utcnow()}
172
+ return str(collection.insert_one(doc).inserted_id)
173
+ except Exception as e:
174
+ logger.error(f"Mongo insert failed: {e}")
 
 
 
 
 
 
 
 
175
  return None
176
 
177
  @st.cache_data(ttl=300)
178
+ def query_creatives_optimized(start_dt:datetime,end_dt:datetime,page:int=0)->Tuple[List[Dict[str,Any]],int]:
179
  collection = get_mongo_collection()
180
  if collection is None:
181
+ return [],0
182
  try:
183
+ total_count = collection.count_documents({"created_at":{"$gte":start_dt,"$lt":end_dt},"lob":"json_batch"})
184
+ cursor = collection.find({"created_at":{"$gte":start_dt,"$lt":end_dt},"lob":"json_batch"}).sort("created_at",-1).skip(page*LIBRARY_PAGE_SIZE).limit(LIBRARY_PAGE_SIZE)
185
+ return list(cursor), total_count
 
 
 
 
 
 
186
  except Exception:
187
+ return [],0
 
 
 
 
188
 
189
+ def display_image_with_download_optimized(url:str):
 
190
  try:
191
+ img_bytes = fetch_image_bytes_optimized(url)
192
+ if not img_bytes:
193
  st.error("Failed to load image")
194
  return
195
  st.image(img_bytes, use_container_width=True)
196
+ base = os.path.basename(urlparse(url).path) or "image.png"
 
197
  if not os.path.splitext(base)[1]:
198
+ base += ".png"
199
+ st.download_button("Download image", data=img_bytes, file_name=base, mime="image/png", use_container_width=True)
200
  except Exception as e:
201
  st.error(f"Failed to display image: {e}")
202
 
203
+ def display_image_gallery_optimized(urls: List[str]):
204
+ if not urls: return
205
+ cols = st.columns(min(4,len(urls)))
206
+ for i,url in enumerate(urls):
207
+ with cols[i % len(cols)]:
208
+ display_image_with_download_optimized(url)
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
+ def bulk_download_button(urls: List[str], filename="images_bundle.zip"):
211
+ if not urls: return
212
+ zip_buffer = io.BytesIO()
213
+ with zipfile.ZipFile(zip_buffer,"w",compression=zipfile.ZIP_DEFLATED) as zipf:
214
+ for i,url in enumerate(urls,1):
215
+ img = fetch_image_bytes_optimized(url)
216
+ if img:
217
+ base = os.path.basename(urlparse(url).path) or f"img_{i}.png"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  if not os.path.splitext(base)[1]:
219
+ base += ".png"
220
+ zipf.writestr(base,img)
221
+ zip_buffer.seek(0)
222
+ st.download_button("Download All Images",data=zip_buffer,file_name=filename,mime="application/zip",use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
+ def load_json_prompts(file)->List[Dict[str,str]]:
225
+ raw=file.getvalue().decode("utf-8")
226
+ data=json.loads(raw)
227
+ if not isinstance(data,dict) or "prompts" not in data or not isinstance(data["prompts"],list):
228
+ raise ValueError("JSON must be { 'prompts': [ 'string', ... ] }")
229
+ return [{"id":f"p{i}","content":p.strip()} for i,p in enumerate(data["prompts"],1) if isinstance(p,str) and p.strip()]
230
 
231
  def render_json_page():
232
  st.subheader("Generate from JSON Prompts")
233
+ up=st.file_uploader("Upload prompts JSON",type=["json"])
234
+ col1,col2=st.columns([1,1])
235
+ with col1: default_model=st.selectbox("Default Model",list(MODEL_REGISTRY.keys()),0)
236
+ with col2: default_aspect=st.selectbox("Default Aspect Ratio",MODEL_REGISTRY[default_model]["aspect_ratios"],0)
237
+ debug=st.checkbox("Debug Mode",False)
 
 
 
 
 
 
238
  if up:
239
  try:
240
+ prompts=load_json_prompts(up)
241
+ st.json(prompts)
242
+ if st.button("Generate for All Prompts",type="primary",use_container_width=True):
243
+ handle_bulk_json_generation(prompts,default_model,default_aspect,debug)
 
 
 
 
 
 
 
 
244
  except Exception as e:
245
+ st.error(str(e))
246
+
247
+ def handle_bulk_json_generation(prompts:List[Dict[str,str]],default_model:str,default_aspect:str,debug:bool):
248
+ total=len(prompts)
249
+ all_urls=[]
250
+ for i,p in enumerate(prompts,1):
251
+ st.markdown(f"**Prompt {i}/{total}** {p['content']}")
252
+ r2,src,errs=generate_images_parallel(default_model,default_aspect,p["content"])
253
+ if r2:
254
+ save_creative_record_optimized(default_model,default_aspect,p["content"],r2)
255
+ display_image_gallery_optimized(r2)
256
+ all_urls.extend(r2)
257
+ elif src:
258
+ display_image_gallery_optimized(src)
259
+ all_urls.extend(src)
260
+ else:
261
+ st.error("No image generated")
262
+ if errs and debug:
263
+ st.error(errs)
264
+ if all_urls:
265
+ st.subheader("Download All")
266
+ bulk_download_button(all_urls,"all_prompts_images.zip")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
+ def render_library_page():
269
+ st.subheader("Creative Library")
270
+ today=datetime.utcnow().date()
271
+ start_date=st.date_input("Start date",today-timedelta(days=30))
272
+ end_date=st.date_input("End date",today)
273
+ start_dt=datetime.combine(start_date,datetime.min.time())
274
+ end_dt=datetime.combine(end_date+timedelta(days=1),datetime.min.time())
275
+ records,total=query_creatives_optimized(start_dt,end_dt,0)
276
+ st.caption(f"{total} items")
277
+ for rec in records:
278
+ urls=rec.get("urls",[])
279
+ if urls: display_image_gallery_optimized(urls)
280
 
281
+ def main_app():
282
+ st.title("File-to-Image Generator")
283
+ page=st.sidebar.radio("Navigation",["Generate from JSON","Creative Library"])
284
+ if page=="Generate from JSON": render_json_page()
285
+ else: render_library_page()
 
 
 
 
286
 
287
  def main():
288
+ main_app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
+ if __name__=="__main__":
291
  main()