userIdc2024 commited on
Commit
6e3f68d
·
verified ·
1 Parent(s): d0f8f72

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +5 -74
src/streamlit_app.py CHANGED
@@ -272,84 +272,15 @@ def bulk_download_button(urls: List[str], filename: str = "images_bundle.zip"):
272
  st.download_button("Download All Images", data=zip_buffer, file_name=filename, mime="application/zip", use_container_width=True)
273
 
274
  def main_app():
275
- st.set_page_config(page_title="Image Generator + Creative Library", layout="wide")
276
- st.title("Multi Model Image Generator")
277
  with st.sidebar:
278
- page = st.radio(" ", ["Generate Bulk Images", "Generate from JSON", "Creative Library"], index=0)
279
- if page == "Generate Bulk Images":
280
- render_generate_page()
281
- elif page == "Generate from JSON":
282
  render_json_page()
283
  elif page == "Creative Library":
284
  render_library_page()
285
 
286
- def render_generate_page():
287
- colA, colB = st.columns([1, 1])
288
- with colA:
289
- model_key = st.selectbox("Model", list(MODEL_REGISTRY.keys()), index=0)
290
- aspect_options = MODEL_REGISTRY[model_key]["aspect_ratios"]
291
- aspect_ratio = st.selectbox("Aspect Ratio", aspect_options, index=0)
292
- num_images = st.slider("Number of images", min_value=1, max_value=50, value=1, step=1)
293
- with colB:
294
- prompt = st.text_area("Prompt", placeholder="Describe the image you want to generate...", height=160)
295
- debug_mode = st.checkbox("Debug Mode")
296
- if st.button("Generate Images", type="primary", use_container_width=True):
297
- handle_image_generation_optimized(model_key, aspect_ratio, prompt, num_images, debug_mode)
298
-
299
- def handle_image_generation_optimized(model_key: str, aspect_ratio: str, prompt: str, num_images: int, debug_mode: bool = False):
300
- if not REPLICATE_API_TOKEN:
301
- st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
302
- return
303
- if not prompt.strip():
304
- st.warning("Please enter a prompt.")
305
- return
306
- progress = st.progress(0, text="Starting generation...")
307
- status_container = st.empty()
308
- start_time = time.time()
309
- try:
310
- with status_container.container():
311
- st.info(f"Generating {num_images} image(s) in parallel...")
312
- progress.progress(0.1, text="Initializing parallel generation...")
313
- all_r2_urls, all_source_urls, generation_errors = generate_images_parallel(model_key, aspect_ratio, prompt.strip(), num_images)
314
- progress.progress(0.8, text="Saving results...")
315
- rec_id = None
316
- if all_r2_urls:
317
- rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt.strip(), all_r2_urls)
318
- progress.progress(1.0, text="Complete!")
319
- generation_time = time.time() - start_time
320
- if all_r2_urls:
321
- with status_container.container():
322
- st.success(f"Generated {len(all_r2_urls)} image(s) in {generation_time:.1f}s. Saved to DB: {rec_id or 'N/A'}")
323
- display_image_gallery_optimized(all_r2_urls)
324
- bulk_download_button(all_r2_urls, filename="generated_images.zip")
325
- elif all_source_urls:
326
- with status_container.container():
327
- st.warning("Images generated but R2 upload failed. Showing originals:")
328
- display_image_gallery_optimized(all_source_urls)
329
- bulk_download_button(all_source_urls, filename="generated_images.zip")
330
- else:
331
- with status_container.container():
332
- st.error("No images were generated.")
333
- if generation_errors and debug_mode:
334
- with st.expander("Generation Errors", expanded=True):
335
- for error in generation_errors:
336
- st.error(f"{error}")
337
- except Exception as e:
338
- with status_container.container():
339
- st.error(f"Generation failed: {str(e)}")
340
-
341
- def display_image_gallery_optimized(urls: List[str]):
342
- if not urls:
343
- return
344
- num_cols = min(4, len(urls)) if len(urls) > 1 else 1
345
- cols = st.columns(num_cols)
346
- for idx, url in enumerate(urls):
347
- with cols[idx % num_cols]:
348
- try:
349
- display_image_with_download_optimized(url)
350
- except Exception as e:
351
- st.error(f"Failed to display image: {e}")
352
-
353
  def render_library_page():
354
  st.subheader("Creative Library")
355
  if "library_page" not in st.session_state:
@@ -567,7 +498,7 @@ def handle_bulk_json_generation(prompts: List[Dict[str, Any]], default_model: st
567
  st.error(e)
568
 
569
  def main():
570
- st.set_page_config(page_title="Bulk Creative Generation", layout="wide")
571
  if "authenticated" not in st.session_state:
572
  st.session_state["authenticated"] = False
573
  if not st.session_state["authenticated"]:
 
272
  st.download_button("Download All Images", data=zip_buffer, file_name=filename, mime="application/zip", use_container_width=True)
273
 
274
  def main_app():
275
+ st.set_page_config(page_title="File-to-Image Creative Library", layout="wide")
276
+ st.title("File-to-Image Generator")
277
  with st.sidebar:
278
+ page = st.radio(" ", ["Generate from JSON", "Creative Library"], index=0)
279
+ if page == "Generate from JSON":
 
 
280
  render_json_page()
281
  elif page == "Creative Library":
282
  render_library_page()
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  def render_library_page():
285
  st.subheader("Creative Library")
286
  if "library_page" not in st.session_state:
 
498
  st.error(e)
499
 
500
  def main():
501
+ st.set_page_config(page_title="File-to-Image Creative Library", layout="wide")
502
  if "authenticated" not in st.session_state:
503
  st.session_state["authenticated"] = False
504
  if not st.session_state["authenticated"]: