Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +52 -33
src/streamlit_app.py
CHANGED
|
@@ -213,7 +213,15 @@ def save_creative_record_optimized(model_key: str, aspect_ratio: str, prompt: st
|
|
| 213 |
if collection is None:
|
| 214 |
return None
|
| 215 |
try:
|
| 216 |
-
doc = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
result = collection.insert_one(doc)
|
| 218 |
return str(result.inserted_id)
|
| 219 |
except Exception:
|
|
@@ -225,8 +233,13 @@ def query_creatives_optimized(start_dt: datetime, end_dt: datetime, page: int =
|
|
| 225 |
if collection is None:
|
| 226 |
return [], 0
|
| 227 |
try:
|
| 228 |
-
total_count = collection.count_documents({"created_at": {"$gte": start_dt, "$lt": end_dt},"lob": "balraaj"})
|
| 229 |
-
cursor =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
records = list(cursor)
|
| 231 |
return records, total_count
|
| 232 |
except Exception:
|
|
@@ -396,97 +409,103 @@ def check_token_cached(user_token: str) -> Tuple[bool, str]:
|
|
| 396 |
def load_json_prompts(file) -> List[Dict[str, Any]]:
|
| 397 |
raw = file.getvalue().decode("utf-8", errors="replace")
|
| 398 |
data = json.loads(raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
prompts_out: List[Dict[str, Any]] = []
|
| 400 |
-
|
| 401 |
-
if isinstance(item, str):
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
if "num" in item: obj["num"] = int(item["num"])
|
| 406 |
-
if "aspect_ratio" in item: obj["aspect_ratio"] = str(item["aspect_ratio"])
|
| 407 |
-
if "model" in item: obj["model"] = str(item["model"])
|
| 408 |
-
prompts_out.append(obj)
|
| 409 |
-
if isinstance(data, dict) and "prompts" in data and isinstance(data["prompts"], list):
|
| 410 |
-
for i, item in enumerate(data["prompts"], 1):
|
| 411 |
-
push(i, item)
|
| 412 |
-
elif isinstance(data, list):
|
| 413 |
-
for i, item in enumerate(data, 1):
|
| 414 |
-
push(i, item)
|
| 415 |
return prompts_out
|
| 416 |
|
| 417 |
def render_json_page():
|
| 418 |
st.subheader("Generate from JSON Prompts")
|
| 419 |
up = st.file_uploader("Upload prompts JSON", type=["json"])
|
| 420 |
-
|
|
|
|
| 421 |
with col1:
|
| 422 |
default_model = st.selectbox("Default Model", list(MODEL_REGISTRY.keys()), index=0)
|
| 423 |
with col2:
|
| 424 |
aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
|
| 425 |
default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
|
| 426 |
-
|
| 427 |
-
default_num = st.slider("Default Images per Prompt", 1, 20, 1, 1, key="json_default_num")
|
| 428 |
debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
|
|
|
|
| 429 |
if up:
|
| 430 |
try:
|
| 431 |
prompts_list = load_json_prompts(up)
|
| 432 |
if not prompts_list:
|
| 433 |
st.error("No prompts found in the JSON.")
|
| 434 |
return
|
|
|
|
| 435 |
with st.expander("Preview normalized prompts", expanded=False):
|
| 436 |
st.json(prompts_list, expanded=False)
|
|
|
|
| 437 |
if st.button("Generate for All Prompts", type="primary", use_container_width=True):
|
| 438 |
-
handle_bulk_json_generation(prompts_list, default_model, default_aspect,
|
| 439 |
except json.JSONDecodeError as e:
|
| 440 |
st.error(f"Invalid JSON: {e}")
|
| 441 |
except Exception as e:
|
| 442 |
st.error(f"Failed to read prompts: {e}")
|
| 443 |
else:
|
| 444 |
-
st.caption(
|
| 445 |
|
| 446 |
-
def handle_bulk_json_generation(prompts: List[Dict[str, Any]], default_model: str, default_aspect: str,
|
| 447 |
if not REPLICATE_API_TOKEN:
|
| 448 |
st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
|
| 449 |
return
|
|
|
|
| 450 |
total = len(prompts)
|
| 451 |
overall_progress = st.progress(0, text=f"Starting batch • 0/{total}")
|
| 452 |
all_generated_urls: List[str] = []
|
| 453 |
errors_total: List[str] = []
|
| 454 |
start_time = time.time()
|
|
|
|
| 455 |
for idx, p in enumerate(prompts, 1):
|
| 456 |
-
model_key =
|
| 457 |
-
aspect_ratio =
|
| 458 |
-
num_images =
|
| 459 |
prompt_text = str(p.get("content", "")).strip()
|
|
|
|
| 460 |
block = st.container(border=True)
|
| 461 |
with block:
|
| 462 |
-
st.markdown(f"**Prompt {idx}/{total}** — Model: `{model_key}` • Aspect: `{aspect_ratio}` • Num: `
|
| 463 |
st.code(prompt_text or "(empty)", language="markdown")
|
|
|
|
| 464 |
if not prompt_text:
|
| 465 |
st.error("Prompt text is empty. Skipping.")
|
| 466 |
overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
|
| 467 |
continue
|
|
|
|
| 468 |
r2_urls, src_urls, gen_errors = generate_images_parallel(model_key, aspect_ratio, prompt_text, num_images)
|
| 469 |
rec_id = None
|
| 470 |
if r2_urls:
|
| 471 |
rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
|
|
|
|
| 472 |
if r2_urls:
|
| 473 |
-
st.success(f"Generated
|
| 474 |
display_image_gallery_optimized(r2_urls)
|
| 475 |
-
bulk_download_button(r2_urls, filename=f"prompt_{idx}
|
| 476 |
all_generated_urls.extend(r2_urls)
|
| 477 |
elif src_urls:
|
| 478 |
-
st.warning("
|
| 479 |
display_image_gallery_optimized(src_urls)
|
| 480 |
-
bulk_download_button(src_urls, filename=f"prompt_{idx}
|
| 481 |
all_generated_urls.extend(src_urls)
|
| 482 |
else:
|
| 483 |
-
st.error("No
|
|
|
|
| 484 |
if gen_errors and debug_mode:
|
| 485 |
with st.expander("Errors", expanded=False):
|
| 486 |
for e in gen_errors:
|
| 487 |
st.error(e)
|
| 488 |
errors_total.extend(gen_errors)
|
|
|
|
| 489 |
overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
|
|
|
|
| 490 |
elapsed = time.time() - start_time
|
| 491 |
st.success(f"Batch complete in {elapsed:.1f}s. Total prompts: {total}.")
|
| 492 |
if all_generated_urls:
|
|
|
|
| 213 |
if collection is None:
|
| 214 |
return None
|
| 215 |
try:
|
| 216 |
+
doc = {
|
| 217 |
+
"model": model_key,
|
| 218 |
+
"aspect_ratio": aspect_ratio,
|
| 219 |
+
"prompt": prompt,
|
| 220 |
+
"urls": urls,
|
| 221 |
+
"num_images": len(urls),
|
| 222 |
+
"lob": "balraaj",
|
| 223 |
+
"created_at": datetime.utcnow()
|
| 224 |
+
}
|
| 225 |
result = collection.insert_one(doc)
|
| 226 |
return str(result.inserted_id)
|
| 227 |
except Exception:
|
|
|
|
| 233 |
if collection is None:
|
| 234 |
return [], 0
|
| 235 |
try:
|
| 236 |
+
total_count = collection.count_documents({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
|
| 237 |
+
cursor = (
|
| 238 |
+
collection.find({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
|
| 239 |
+
.sort("created_at", -1)
|
| 240 |
+
.skip(page * LIBRARY_PAGE_SIZE)
|
| 241 |
+
.limit(LIBRARY_PAGE_SIZE)
|
| 242 |
+
)
|
| 243 |
records = list(cursor)
|
| 244 |
return records, total_count
|
| 245 |
except Exception:
|
|
|
|
| 409 |
def load_json_prompts(file) -> List[Dict[str, Any]]:
|
| 410 |
raw = file.getvalue().decode("utf-8", errors="replace")
|
| 411 |
data = json.loads(raw)
|
| 412 |
+
|
| 413 |
+
if not isinstance(data, dict) or "prompts" not in data or not isinstance(data["prompts"], list):
|
| 414 |
+
raise ValueError("Invalid JSON. Expected an object with a 'prompts' array of strings.")
|
| 415 |
+
|
| 416 |
prompts_out: List[Dict[str, Any]] = []
|
| 417 |
+
for i, item in enumerate(data["prompts"], 1):
|
| 418 |
+
if not isinstance(item, str) or not item.strip():
|
| 419 |
+
raise ValueError(f"'prompts[{i-1}]' must be a non-empty string.")
|
| 420 |
+
prompts_out.append({"id": f"p{i}", "content": item.strip()})
|
| 421 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
return prompts_out
|
| 423 |
|
| 424 |
def render_json_page():
|
| 425 |
st.subheader("Generate from JSON Prompts")
|
| 426 |
up = st.file_uploader("Upload prompts JSON", type=["json"])
|
| 427 |
+
|
| 428 |
+
col1, col2 = st.columns([1, 1])
|
| 429 |
with col1:
|
| 430 |
default_model = st.selectbox("Default Model", list(MODEL_REGISTRY.keys()), index=0)
|
| 431 |
with col2:
|
| 432 |
aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
|
| 433 |
default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
|
| 434 |
+
|
|
|
|
| 435 |
debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
|
| 436 |
+
|
| 437 |
if up:
|
| 438 |
try:
|
| 439 |
prompts_list = load_json_prompts(up)
|
| 440 |
if not prompts_list:
|
| 441 |
st.error("No prompts found in the JSON.")
|
| 442 |
return
|
| 443 |
+
|
| 444 |
with st.expander("Preview normalized prompts", expanded=False):
|
| 445 |
st.json(prompts_list, expanded=False)
|
| 446 |
+
|
| 447 |
if st.button("Generate for All Prompts", type="primary", use_container_width=True):
|
| 448 |
+
handle_bulk_json_generation(prompts_list, default_model, default_aspect, debug_mode)
|
| 449 |
except json.JSONDecodeError as e:
|
| 450 |
st.error(f"Invalid JSON: {e}")
|
| 451 |
except Exception as e:
|
| 452 |
st.error(f"Failed to read prompts: {e}")
|
| 453 |
else:
|
| 454 |
+
st.caption('Expected format: { "prompts": ["prompt 1", "prompt 2", ...] }')
|
| 455 |
|
| 456 |
+
def handle_bulk_json_generation(prompts: List[Dict[str, Any]], default_model: str, default_aspect: str, debug_mode: bool):
|
| 457 |
if not REPLICATE_API_TOKEN:
|
| 458 |
st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
|
| 459 |
return
|
| 460 |
+
|
| 461 |
total = len(prompts)
|
| 462 |
overall_progress = st.progress(0, text=f"Starting batch • 0/{total}")
|
| 463 |
all_generated_urls: List[str] = []
|
| 464 |
errors_total: List[str] = []
|
| 465 |
start_time = time.time()
|
| 466 |
+
|
| 467 |
for idx, p in enumerate(prompts, 1):
|
| 468 |
+
model_key = default_model
|
| 469 |
+
aspect_ratio = default_aspect
|
| 470 |
+
num_images = 1 # exactly one image per prompt
|
| 471 |
prompt_text = str(p.get("content", "")).strip()
|
| 472 |
+
|
| 473 |
block = st.container(border=True)
|
| 474 |
with block:
|
| 475 |
+
st.markdown(f"**Prompt {idx}/{total}** — Model: `{model_key}` • Aspect: `{aspect_ratio}` • Num: `1`")
|
| 476 |
st.code(prompt_text or "(empty)", language="markdown")
|
| 477 |
+
|
| 478 |
if not prompt_text:
|
| 479 |
st.error("Prompt text is empty. Skipping.")
|
| 480 |
overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
|
| 481 |
continue
|
| 482 |
+
|
| 483 |
r2_urls, src_urls, gen_errors = generate_images_parallel(model_key, aspect_ratio, prompt_text, num_images)
|
| 484 |
rec_id = None
|
| 485 |
if r2_urls:
|
| 486 |
rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
|
| 487 |
+
|
| 488 |
if r2_urls:
|
| 489 |
+
st.success(f"Generated 1 image. DB: {rec_id or 'N/A'}")
|
| 490 |
display_image_gallery_optimized(r2_urls)
|
| 491 |
+
bulk_download_button(r2_urls, filename=f"prompt_{idx}_image.zip")
|
| 492 |
all_generated_urls.extend(r2_urls)
|
| 493 |
elif src_urls:
|
| 494 |
+
st.warning("Image generated but R2 upload failed. Showing original:")
|
| 495 |
display_image_gallery_optimized(src_urls)
|
| 496 |
+
bulk_download_button(src_urls, filename=f"prompt_{idx}_image.zip")
|
| 497 |
all_generated_urls.extend(src_urls)
|
| 498 |
else:
|
| 499 |
+
st.error("No image was generated for this prompt.")
|
| 500 |
+
|
| 501 |
if gen_errors and debug_mode:
|
| 502 |
with st.expander("Errors", expanded=False):
|
| 503 |
for e in gen_errors:
|
| 504 |
st.error(e)
|
| 505 |
errors_total.extend(gen_errors)
|
| 506 |
+
|
| 507 |
overall_progress.progress(min(idx / total, 1.0), text=f"Processed {idx}/{total}")
|
| 508 |
+
|
| 509 |
elapsed = time.time() - start_time
|
| 510 |
st.success(f"Batch complete in {elapsed:.1f}s. Total prompts: {total}.")
|
| 511 |
if all_generated_urls:
|