userIdc2024 commited on
Commit
fd185fb
·
verified ·
1 Parent(s): 6e3f68d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +52 -33
src/streamlit_app.py CHANGED
@@ -213,7 +213,15 @@ def save_creative_record_optimized(model_key: str, aspect_ratio: str, prompt: st
213
  if collection is None:
214
  return None
215
  try:
216
- doc = {"model": model_key,"aspect_ratio": aspect_ratio,"prompt": prompt,"urls": urls,"num_images": len(urls),"lob": "balraaj","created_at": datetime.utcnow()}
 
 
 
 
 
 
 
 
217
  result = collection.insert_one(doc)
218
  return str(result.inserted_id)
219
  except Exception:
@@ -225,8 +233,13 @@ def query_creatives_optimized(start_dt: datetime, end_dt: datetime, page: int =
225
  if collection is None:
226
  return [], 0
227
  try:
228
- total_count = collection.count_documents({"created_at": {"$gte": start_dt, "$lt": end_dt},"lob": "balraaj"})
229
- cursor = collection.find({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"}).sort("created_at", -1).skip(page * LIBRARY_PAGE_SIZE).limit(LIBRARY_PAGE_SIZE)
 
 
 
 
 
230
  records = list(cursor)
231
  return records, total_count
232
  except Exception:
@@ -396,97 +409,103 @@ def check_token_cached(user_token: str) -> Tuple[bool, str]:
396
  def load_json_prompts(file) -> List[Dict[str, Any]]:
397
  raw = file.getvalue().decode("utf-8", errors="replace")
398
  data = json.loads(raw)
 
 
 
 
399
  prompts_out: List[Dict[str, Any]] = []
400
- def push(idx: int, item: Any):
401
- if isinstance(item, str):
402
- prompts_out.append({"id": f"p{idx}", "content": item})
403
- elif isinstance(item, dict) and "content" in item:
404
- obj = {"id": str(item.get("id", f"p{idx}")), "content": str(item["content"])}
405
- if "num" in item: obj["num"] = int(item["num"])
406
- if "aspect_ratio" in item: obj["aspect_ratio"] = str(item["aspect_ratio"])
407
- if "model" in item: obj["model"] = str(item["model"])
408
- prompts_out.append(obj)
409
- if isinstance(data, dict) and "prompts" in data and isinstance(data["prompts"], list):
410
- for i, item in enumerate(data["prompts"], 1):
411
- push(i, item)
412
- elif isinstance(data, list):
413
- for i, item in enumerate(data, 1):
414
- push(i, item)
415
  return prompts_out
416
 
417
  def render_json_page():
418
  st.subheader("Generate from JSON Prompts")
419
  up = st.file_uploader("Upload prompts JSON", type=["json"])
420
- col1, col2, col3 = st.columns([1,1,1])
 
421
  with col1:
422
  default_model = st.selectbox("Default Model", list(MODEL_REGISTRY.keys()), index=0)
423
  with col2:
424
  aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
425
  default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
426
- with col3:
427
- default_num = st.slider("Default Images per Prompt", 1, 20, 1, 1, key="json_default_num")
428
  debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
 
429
  if up:
430
  try:
431
  prompts_list = load_json_prompts(up)
432
  if not prompts_list:
433
  st.error("No prompts found in the JSON.")
434
  return
 
435
  with st.expander("Preview normalized prompts", expanded=False):
436
  st.json(prompts_list, expanded=False)
 
437
  if st.button("Generate for All Prompts", type="primary", use_container_width=True):
438
- handle_bulk_json_generation(prompts_list, default_model, default_aspect, default_num, debug_mode)
439
  except json.JSONDecodeError as e:
440
  st.error(f"Invalid JSON: {e}")
441
  except Exception as e:
442
  st.error(f"Failed to read prompts: {e}")
443
  else:
444
- st.caption("Your JSON can be a list of strings, a list of objects with `content`, or `{ 'prompts': [...] }`. Optional per-prompt keys: `model`, `aspect_ratio`, `num`.")
445
 
446
- def handle_bulk_json_generation(prompts: List[Dict[str, Any]], default_model: str, default_aspect: str, default_num: int, debug_mode: bool):
447
  if not REPLICATE_API_TOKEN:
448
  st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
449
  return
 
450
  total = len(prompts)
451
  overall_progress = st.progress(0, text=f"Starting batch • 0/{total}")
452
  all_generated_urls: List[str] = []
453
  errors_total: List[str] = []
454
  start_time = time.time()
 
455
  for idx, p in enumerate(prompts, 1):
456
- model_key = str(p.get("model", default_model))
457
- aspect_ratio = str(p.get("aspect_ratio", default_aspect))
458
- num_images = int(p.get("num", default_num))
459
  prompt_text = str(p.get("content", "")).strip()
 
460
  block = st.container(border=True)
461
  with block:
462
- st.markdown(f"**Prompt {idx}/{total}** — Model: `{model_key}` • Aspect: `{aspect_ratio}` • Num: `{num_images}`")
463
  st.code(prompt_text or "(empty)", language="markdown")
 
464
  if not prompt_text:
465
  st.error("Prompt text is empty. Skipping.")
466
  overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
467
  continue
 
468
  r2_urls, src_urls, gen_errors = generate_images_parallel(model_key, aspect_ratio, prompt_text, num_images)
469
  rec_id = None
470
  if r2_urls:
471
  rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
 
472
  if r2_urls:
473
- st.success(f"Generated {len(r2_urls)} image(s). DB: {rec_id or 'N/A'}")
474
  display_image_gallery_optimized(r2_urls)
475
- bulk_download_button(r2_urls, filename=f"prompt_{idx}_images.zip")
476
  all_generated_urls.extend(r2_urls)
477
  elif src_urls:
478
- st.warning("Images generated but R2 upload failed. Showing originals:")
479
  display_image_gallery_optimized(src_urls)
480
- bulk_download_button(src_urls, filename=f"prompt_{idx}_images.zip")
481
  all_generated_urls.extend(src_urls)
482
  else:
483
- st.error("No images were generated for this prompt.")
 
484
  if gen_errors and debug_mode:
485
  with st.expander("Errors", expanded=False):
486
  for e in gen_errors:
487
  st.error(e)
488
  errors_total.extend(gen_errors)
 
489
  overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
 
490
  elapsed = time.time() - start_time
491
  st.success(f"Batch complete in {elapsed:.1f}s. Total prompts: {total}.")
492
  if all_generated_urls:
 
213
  if collection is None:
214
  return None
215
  try:
216
+ doc = {
217
+ "model": model_key,
218
+ "aspect_ratio": aspect_ratio,
219
+ "prompt": prompt,
220
+ "urls": urls,
221
+ "num_images": len(urls),
222
+ "lob": "balraaj",
223
+ "created_at": datetime.utcnow()
224
+ }
225
  result = collection.insert_one(doc)
226
  return str(result.inserted_id)
227
  except Exception:
 
233
  if collection is None:
234
  return [], 0
235
  try:
236
+ total_count = collection.count_documents({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
237
+ cursor = (
238
+ collection.find({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
239
+ .sort("created_at", -1)
240
+ .skip(page * LIBRARY_PAGE_SIZE)
241
+ .limit(LIBRARY_PAGE_SIZE)
242
+ )
243
  records = list(cursor)
244
  return records, total_count
245
  except Exception:
 
409
  def load_json_prompts(file) -> List[Dict[str, Any]]:
410
  raw = file.getvalue().decode("utf-8", errors="replace")
411
  data = json.loads(raw)
412
+
413
+ if not isinstance(data, dict) or "prompts" not in data or not isinstance(data["prompts"], list):
414
+ raise ValueError("Invalid JSON. Expected an object with a 'prompts' array of strings.")
415
+
416
  prompts_out: List[Dict[str, Any]] = []
417
+ for i, item in enumerate(data["prompts"], 1):
418
+ if not isinstance(item, str) or not item.strip():
419
+ raise ValueError(f"'prompts[{i-1}]' must be a non-empty string.")
420
+ prompts_out.append({"id": f"p{i}", "content": item.strip()})
421
+
 
 
 
 
 
 
 
 
 
 
422
  return prompts_out
423
 
424
  def render_json_page():
425
  st.subheader("Generate from JSON Prompts")
426
  up = st.file_uploader("Upload prompts JSON", type=["json"])
427
+
428
+ col1, col2 = st.columns([1, 1])
429
  with col1:
430
  default_model = st.selectbox("Default Model", list(MODEL_REGISTRY.keys()), index=0)
431
  with col2:
432
  aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
433
  default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
434
+
 
435
  debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
436
+
437
  if up:
438
  try:
439
  prompts_list = load_json_prompts(up)
440
  if not prompts_list:
441
  st.error("No prompts found in the JSON.")
442
  return
443
+
444
  with st.expander("Preview normalized prompts", expanded=False):
445
  st.json(prompts_list, expanded=False)
446
+
447
  if st.button("Generate for All Prompts", type="primary", use_container_width=True):
448
+ handle_bulk_json_generation(prompts_list, default_model, default_aspect, debug_mode)
449
  except json.JSONDecodeError as e:
450
  st.error(f"Invalid JSON: {e}")
451
  except Exception as e:
452
  st.error(f"Failed to read prompts: {e}")
453
  else:
454
+ st.caption('Expected format: { "prompts": ["prompt 1", "prompt 2", ...] }')
455
 
456
+ def handle_bulk_json_generation(prompts: List[Dict[str, Any]], default_model: str, default_aspect: str, debug_mode: bool):
457
  if not REPLICATE_API_TOKEN:
458
  st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
459
  return
460
+
461
  total = len(prompts)
462
  overall_progress = st.progress(0, text=f"Starting batch • 0/{total}")
463
  all_generated_urls: List[str] = []
464
  errors_total: List[str] = []
465
  start_time = time.time()
466
+
467
  for idx, p in enumerate(prompts, 1):
468
+ model_key = default_model
469
+ aspect_ratio = default_aspect
470
+ num_images = 1 # exactly one image per prompt
471
  prompt_text = str(p.get("content", "")).strip()
472
+
473
  block = st.container(border=True)
474
  with block:
475
+ st.markdown(f"**Prompt {idx}/{total}** — Model: `{model_key}` • Aspect: `{aspect_ratio}` • Num: `1`")
476
  st.code(prompt_text or "(empty)", language="markdown")
477
+
478
  if not prompt_text:
479
  st.error("Prompt text is empty. Skipping.")
480
  overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
481
  continue
482
+
483
  r2_urls, src_urls, gen_errors = generate_images_parallel(model_key, aspect_ratio, prompt_text, num_images)
484
  rec_id = None
485
  if r2_urls:
486
  rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
487
+
488
  if r2_urls:
489
+ st.success(f"Generated 1 image. DB: {rec_id or 'N/A'}")
490
  display_image_gallery_optimized(r2_urls)
491
+ bulk_download_button(r2_urls, filename=f"prompt_{idx}_image.zip")
492
  all_generated_urls.extend(r2_urls)
493
  elif src_urls:
494
+ st.warning("Image generated but R2 upload failed. Showing original:")
495
  display_image_gallery_optimized(src_urls)
496
+ bulk_download_button(src_urls, filename=f"prompt_{idx}_image.zip")
497
  all_generated_urls.extend(src_urls)
498
  else:
499
+ st.error("No image was generated for this prompt.")
500
+
501
  if gen_errors and debug_mode:
502
  with st.expander("Errors", expanded=False):
503
  for e in gen_errors:
504
  st.error(e)
505
  errors_total.extend(gen_errors)
506
+
507
  overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
508
+
509
  elapsed = time.time() - start_time
510
  st.success(f"Batch complete in {elapsed:.1f}s. Total prompts: {total}.")
511
  if all_generated_urls: