userIdc2024 commited on
Commit
aded900
·
verified ·
1 Parent(s): 6fb57f4

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +131 -326
src/streamlit_app.py CHANGED
@@ -7,7 +7,7 @@ import time
7
  import logging
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
  from datetime import datetime, timedelta, date
10
- from typing import Dict, Any, List, Tuple, Optional, Union
11
  import requests
12
  import streamlit as st
13
  from pymongo import MongoClient
@@ -37,22 +37,18 @@ REQUEST_TIMEOUT = 30
37
  RETRY_ATTEMPTS = 3
38
  LIBRARY_PAGE_SIZE = 20
39
 
40
- # Model registry (subset with common ARs)
41
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
42
  "imagegen-4-ultra": {"id": "google/imagen-4-ultra", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
43
  "imagen-4": {"id": "google/imagen-4", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
44
- "nano-banana": {"id": "google/nano-banana", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
45
  "qwen": {"id": "qwen/qwen-image", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"], "param_name": "aspect_ratio"},
46
  "seedream-3": {"id": "bytedance/seedream-3", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"], "param_name": "aspect_ratio"},
47
  "recraft-v3": {"id": "recraft-ai/recraft-v3", "aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"], "param_name": "aspect_ratio"},
48
- "photon": {"id": "luma/photon", "aspect_ratios": ["1:1","3:4","4:3","9:16","16:9","21:9"], "param_name": "aspect_ratio"},
49
- "ideogram-v3-quality":{"id": "ideogram-ai/ideogram-v3-quality", "aspect_ratios": ["1:1","16:9","9:16","2:3","3:2","4:5","5:4"], "param_name": "aspect_ratio"},
50
  }
51
 
52
  _thread_local = threading.local()
53
 
54
  # ----------------------------
55
- # Infra helpers (Mongo / S3)
56
  # ----------------------------
57
  def get_mongo_collection():
58
  if not hasattr(_thread_local, 'mongo_collection'):
@@ -73,8 +69,7 @@ def get_mongo_collection():
73
  def get_s3_client():
74
  if not hasattr(_thread_local, 's3_client'):
75
  required_vars = ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME"]
76
- missing = [var for var in required_vars if not os.getenv(var)]
77
- if missing:
78
  _thread_local.s3_client = None
79
  return None
80
  try:
@@ -95,403 +90,213 @@ def get_model_config(model_key: str) -> Optional[Dict[str, Any]]:
95
  return MODEL_REGISTRY.get(model_key)
96
 
97
  # ----------------------------
98
- # R2 upload
99
  # ----------------------------
100
- def upload_to_r2_optimized(image_bytes: bytes) -> Optional[str]:
101
- s3_client = get_s3_client()
102
- if not s3_client:
103
  return None
104
  try:
105
  filename = f"{uuid4().hex}.png"
106
  file_key = f"adgenesis_image_file/balraaj/images/{filename}"
107
- s3_client.put_object(
108
  Bucket=os.getenv("R2_BUCKET_NAME"),
109
  Key=file_key,
110
  Body=image_bytes,
111
  ContentType="image/png",
112
  )
113
- r2_url = f'{os.getenv("NEW_BASE").rstrip("/")}/{file_key}'
114
- return r2_url
115
  except Exception as e:
116
- logger.error(f"S3 upload failed: {e}")
117
  return None
118
 
119
- # ----------------------------
120
- # Generation & fetching
121
- # ----------------------------
122
- def generate_one_image_optimized(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
123
  if not REPLICATE_API_TOKEN:
124
  return []
125
  config = get_model_config(model_key)
126
  if not config:
127
  return []
128
  try:
129
- model_id = config["id"]
130
- ar_param = config["param_name"]
131
- inputs = {"prompt": prompt, ar_param: aspect_ratio}
132
- output = replicate.run(model_id, input=inputs)
133
- # Normalize to list[str]
134
  if isinstance(output, list) and output:
135
- first = output[0]
136
- return [getattr(first, "url", str(first))]
137
  elif isinstance(output, str):
138
  return [output]
139
- elif hasattr(output, "url"):
140
- return [getattr(output, "url")]
141
  return []
142
  except Exception as e:
143
  logger.error(f"Replicate error: {e}")
144
  return []
145
 
146
- def fetch_image_bytes_optimized(url: str) -> Optional[bytes]:
147
- for attempt in range(RETRY_ATTEMPTS):
148
  try:
149
- response = requests.get(url, timeout=REQUEST_TIMEOUT, stream=True)
150
- response.raise_for_status()
151
- return response.content
152
  except Exception:
153
- if attempt == RETRY_ATTEMPTS - 1:
154
- return None
155
  time.sleep(1)
156
  return None
157
 
158
- def process_single_image(args: Tuple[str, str, str, int]) -> Dict[str, Any]:
159
- model_key, prompt, aspect_ratio, index = args
160
- result = {"index": index, "success": False, "source_url": None, "r2_url": None, "error": None}
161
- urls = generate_one_image_optimized(model_key, prompt, aspect_ratio)
162
  if not urls:
163
- result["error"] = "No URLs returned from generation"
164
- return result
165
- source_url = urls[0]
166
- result["source_url"] = source_url
167
- img_bytes = fetch_image_bytes_optimized(source_url)
168
  if not img_bytes:
169
- result["error"] = "Failed to fetch image bytes"
170
- return result
171
- r2_url = upload_to_r2_optimized(img_bytes)
172
- if r2_url:
173
- result["r2_url"] = r2_url
174
- result["success"] = True
175
- else:
176
- result["error"] = "Failed to upload to R2"
177
- return result
178
-
179
- def generate_one_per_prompt(model_key: str, aspect_ratio: str, prompt: str) -> Tuple[List[str], List[str], List[str]]:
180
- """One image per prompt (no parallel within a prompt)."""
181
- res = process_single_image((model_key, prompt, aspect_ratio, 0))
182
- if res["success"]:
183
- return [res["r2_url"]], [res["source_url"]], []
184
- else:
185
- return [], [], [res["error"] or "Generation failed"]
186
 
187
  # ----------------------------
188
  # Persistence
189
  # ----------------------------
190
- def save_creative_record_optimized(model_key: str, aspect_ratio: str, prompt: str, urls: List[str]) -> Optional[str]:
191
- collection = get_mongo_collection()
192
- if collection is None:
193
  return None
194
  try:
195
- doc = {
196
- "model": model_key,
197
- "aspect_ratio": aspect_ratio,
198
  "prompt": prompt,
199
  "urls": urls,
200
- "num_images": len(urls),
201
  "lob": "balraaj",
202
  "created_at": datetime.utcnow()
203
- }
204
- ins = collection.insert_one(doc)
205
- return str(ins.inserted_id)
206
  except Exception as e:
207
  logger.error(f"Mongo insert failed: {e}")
208
  return None
209
 
210
  @st.cache_data(ttl=300)
211
- def query_creatives_optimized(start_dt: datetime, end_dt: datetime, page: int = 0) -> Tuple[List[Dict[str, Any]], int]:
212
- collection = get_mongo_collection()
213
- if collection is None:
214
- return [], 0
215
  try:
216
- total_count = collection.count_documents({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
217
- cursor = (
218
- collection.find({"created_at": {"$gte": start_dt, "$lt": end_dt}, "lob": "balraaj"})
219
- .sort("created_at", -1)
220
- .skip(page * LIBRARY_PAGE_SIZE)
221
- .limit(LIBRARY_PAGE_SIZE)
222
- )
223
- return list(cursor), total_count
224
  except Exception:
225
- return [], 0
226
 
227
  # ----------------------------
228
- # UI helpers: images
229
  # ----------------------------
230
- @st.cache_data(ttl=3600)
231
- def get_image_bytes_cached(url: str) -> Optional[bytes]:
232
- return fetch_image_bytes_optimized(url)
233
-
234
- def display_image_with_download_optimized(url: str):
235
- try:
236
- img_bytes = get_image_bytes_cached(url)
237
- if not img_bytes:
238
- st.error("Failed to load image")
239
- return
240
- st.image(img_bytes, use_container_width=True)
241
- base = os.path.basename(urlparse(url).path) or "image.png"
242
- if not os.path.splitext(base)[1]:
243
- base = f"{base}.png"
244
- st.download_button(
245
- label="Download image",
246
- data=img_bytes,
247
- file_name=base,
248
- mime="image/png",
249
- use_container_width=True
250
- )
251
- except Exception as e:
252
- st.error(f"Failed to display image: {e}")
253
-
254
- def display_image_gallery_optimized(urls: List[str]):
255
- if not urls:
256
- return
257
- num_cols = min(4, max(1, len(urls)))
258
- cols = st.columns(num_cols)
259
- for i, url in enumerate(urls):
260
- with cols[i % num_cols]:
261
- display_image_with_download_optimized(url)
262
-
263
- def bulk_download_button(urls: List[str], filename: str = "images_bundle.zip"):
264
- if not urls:
265
- return
266
- zip_buffer = io.BytesIO()
267
- with zipfile.ZipFile(zip_buffer, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
268
- for idx, url in enumerate(urls, 1):
269
  try:
270
- img_bytes = fetch_image_bytes_optimized(url)
271
- if img_bytes:
272
- path = urlparse(url).path
273
- base = os.path.basename(path) or f"image_{idx}.png"
274
- if not os.path.splitext(base)[1]:
275
- base = f"{base}.png"
276
- zip_file.writestr(base, img_bytes)
277
- except Exception:
278
- pass
279
- zip_buffer.seek(0)
280
- st.download_button(
281
- "Download All Images",
282
- data=zip_buffer,
283
- file_name=filename,
284
- mime="application/zip",
285
- use_container_width=True
286
- )
287
-
288
- # ----------------------------
289
- # JSON loader (STRICT)
290
- # ----------------------------
291
- def load_json_prompts(file) -> List[Dict[str, Any]]:
292
- raw = file.getvalue().decode("utf-8", errors="replace")
293
- data = json.loads(raw)
294
- if not isinstance(data, dict) or "prompts" not in data or not isinstance(data["prompts"], list):
295
- raise ValueError("Invalid JSON. Expected an object with a 'prompts' array of strings.")
296
- prompts_out: List[Dict[str, Any]] = []
297
- for i, item in enumerate(data["prompts"], 1):
298
- if not isinstance(item, str) or not item.strip():
299
- raise ValueError(f"'prompts[{i-1}]' must be a non-empty string.")
300
- prompts_out.append({"id": f"p{i}", "content": item.strip()})
301
- return prompts_out
302
 
303
  # ----------------------------
304
- # JSON page (parallel across prompts)
305
  # ----------------------------
306
- def _run_single_prompt(idx: int, prompt_text: str, model_key: str, aspect_ratio: str):
307
- r2_urls, src_urls, gen_errors = generate_one_per_prompt(model_key, aspect_ratio, prompt_text)
308
- rec_id = None
309
- if r2_urls:
310
- rec_id = save_creative_record_optimized(model_key, aspect_ratio, prompt_text, r2_urls)
311
- return {
312
- "idx": idx,
313
- "prompt": prompt_text,
314
- "r2_urls": r2_urls,
315
- "src_urls": src_urls,
316
- "errors": gen_errors,
317
- "rec_id": rec_id,
318
- }
319
-
320
- def render_json_page():
321
- st.subheader("Generate from JSON Prompts")
322
- up = st.file_uploader("Upload prompts JSON", type=["json"])
323
-
324
- col1, col2 = st.columns([1, 1])
325
- with col1:
326
- default_model = st.selectbox("Default Model", list(MODEL_REGISTRY.keys()), index=0)
327
- with col2:
328
- aspect_options = MODEL_REGISTRY[default_model]["aspect_ratios"]
329
- default_aspect = st.selectbox("Default Aspect Ratio", aspect_options, index=0, key="json_default_ar")
330
-
331
- debug_mode = st.checkbox("Debug Mode", value=False, key="json_debug")
332
-
333
- if up:
334
- try:
335
- prompts_list = load_json_prompts(up)
336
- with st.expander("Preview normalized prompts", expanded=False):
337
- st.json(prompts_list, expanded=False)
338
-
339
- if st.button("Generate for All Prompts", type="primary", use_container_width=True):
340
- handle_bulk_json_generation_parallel(prompts_list, default_model, default_aspect, debug_mode)
341
- except json.JSONDecodeError as e:
342
- st.error(f"Invalid JSON: {e}")
343
- except Exception as e:
344
- st.error(f"Failed to read prompts: {e}")
345
- else:
346
- st.caption('Expected format: { "prompts": ["prompt 1", "prompt 2", ...] }')
347
 
348
- def handle_bulk_json_generation_parallel(prompts: List[Dict[str, str]], default_model: str, default_aspect: str, debug: bool):
349
- if not REPLICATE_API_TOKEN:
350
- st.error("Missing REPLICATE_API_TOKEN. Set it as an environment variable.")
351
- return
352
  total = len(prompts)
353
- if total == 0:
354
- st.info("No prompts to process.")
355
- return
356
-
357
- # Placeholders for stable on-page order
358
- blocks = [st.container(border=True) for _ in range(total)]
359
- progress = st.progress(0, text=f"Starting batch • 0/{total}")
360
-
361
- all_urls: List[str] = []
362
- completed = 0
363
-
364
- max_workers = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
365
-
366
- with st.spinner("Generating images..."):
367
- futures = {}
368
- with ThreadPoolExecutor(max_workers=max_workers) as ex:
369
- for i, p in enumerate(prompts, 1):
370
- prompt_text = p.get("content", "").strip()
371
- if not prompt_text:
372
- # render immediately as invalid
373
- with blocks[i-1]:
374
- st.markdown(f"**Prompt {i}/{total}** — (empty)")
375
- st.error("Prompt text is empty. Skipping.")
376
- completed += 1
377
- progress.progress(completed / total, text=f"Processed {completed}/{total}")
378
- continue
379
- futures[ex.submit(_run_single_prompt, i, prompt_text, default_model, default_aspect)] = i
380
-
381
- for fut in as_completed(futures):
382
- i = futures[fut]
383
- try:
384
- res = fut.result()
385
- except Exception as e:
386
- res = {"idx": i, "prompt": "", "r2_urls": [], "src_urls": [], "errors": [str(e)], "rec_id": None}
387
-
388
- with blocks[i-1]:
389
- st.markdown(f"**Prompt {i}/{total}** — Model: `{default_model}` • Aspect: `{default_aspect}` • Num: `1`")
390
- st.code(res.get("prompt") or "(empty)", language="markdown")
391
-
392
- if res["r2_urls"]:
393
- st.success(f"Generated 1 image. DB: {res['rec_id'] or 'N/A'}")
394
- display_image_gallery_optimized(res["r2_urls"])
395
- bulk_download_button(res["r2_urls"], filename=f"prompt_{i}_image.zip")
396
- all_urls.extend(res["r2_urls"])
397
- elif res["src_urls"]:
398
- st.warning("Image generated but R2 upload failed. Showing original:")
399
- display_image_gallery_optimized(res["src_urls"])
400
- bulk_download_button(res["src_urls"], filename=f"prompt_{i}_image.zip")
401
- all_urls.extend(res["src_urls"])
402
- else:
403
- st.error("No image was generated for this prompt.")
404
-
405
- if res.get("errors") and debug:
406
- for e in res["errors"]:
407
- st.error(e)
408
-
409
- completed += 1
410
- progress.progress(completed / total, text=f"Processed {completed}/{total}")
411
 
412
- # Final all-images gallery & ZIP
413
  if all_urls:
414
- st.subheader("All Images Gallery")
415
- display_image_gallery_optimized(all_urls)
416
- st.subheader("Download All Generated")
417
- bulk_download_button(all_urls, filename="all_prompts_images.zip")
418
 
419
  # ----------------------------
420
- # Creative Library page
421
  # ----------------------------
422
- def render_library_page():
423
- st.subheader("Creative Library")
424
- if "library_page" not in st.session_state:
425
- st.session_state.library_page = 0
426
-
427
- today_utc = datetime.utcnow().date()
428
- default_start = today_utc - timedelta(days=30)
429
-
430
- c1, c2, c3 = st.columns([1, 1, 1])
431
- with c1:
432
- start_date: date = st.date_input("Start date", value=default_start)
433
- with c2:
434
- end_date: date = st.date_input("End date", value=today_utc)
435
- with c3:
436
- if st.button("Apply Filters", use_container_width=True):
437
- st.session_state.library_page = 0
438
- st.cache_data.clear()
439
-
440
- start_dt = datetime.combine(start_date, datetime.min.time())
441
- end_dt = datetime.combine(end_date + timedelta(days=1), datetime.min.time())
442
 
443
- records, total_count = query_creatives_optimized(start_dt, end_dt, st.session_state.library_page)
444
- if not records and st.session_state.library_page == 0:
445
- st.info("No creatives found for the selected dates.")
446
- return
 
447
 
448
- st.caption(f"Total items: {total_count}")
449
- # simple gallery by record
450
- for rec in records:
451
- urls = rec.get("urls", []) or []
452
- if urls:
453
- display_image_gallery_optimized(urls)
 
 
 
 
 
 
 
 
 
454
 
455
  # ----------------------------
456
  # Auth
457
  # ----------------------------
458
  @lru_cache(maxsize=1)
459
- def check_token_cached(user_token: str) -> Tuple[bool, str]:
460
- ACCESS_TOKEN = os.getenv("ACCESS_TOKEN")
461
- if not ACCESS_TOKEN:
462
- return False, "Server error: Access token not configured."
463
- if user_token == ACCESS_TOKEN:
464
- return True, ""
465
- return False, "Invalid token."
466
 
467
- # ----------------------------
468
- # App shell
469
- # ----------------------------
470
  def main_app():
471
- st.set_page_config(page_title="File-to-Image • Creative Library", layout="wide")
472
  st.title("File-to-Image Generator")
473
- with st.sidebar:
474
- page = st.radio("Navigation", ["Generate from JSON", "Creative Library"], index=0)
475
- if page == "Generate from JSON":
476
- render_json_page()
477
- else:
478
- render_library_page()
479
 
480
  def main():
481
- if "authenticated" not in st.session_state:
482
- st.session_state["authenticated"] = False
483
- if not st.session_state["authenticated"]:
484
  st.markdown("## Access Required")
485
- token_input = st.text_input("Enter Access Token", type="password")
486
- if st.button("Unlock App"):
487
- ok, error_msg = check_token_cached(token_input)
488
- if ok:
489
- st.session_state["authenticated"] = True
490
- st.rerun()
491
- else:
492
- st.error(error_msg)
493
  else:
494
  main_app()
495
 
496
- if __name__ == "__main__":
497
  main()
 
7
  import logging
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
  from datetime import datetime, timedelta, date
10
+ from typing import Dict, Any, List, Tuple, Optional
11
  import requests
12
  import streamlit as st
13
  from pymongo import MongoClient
 
37
  RETRY_ATTEMPTS = 3
38
  LIBRARY_PAGE_SIZE = 20
39
 
 
40
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
41
  "imagegen-4-ultra": {"id": "google/imagen-4-ultra", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
42
  "imagen-4": {"id": "google/imagen-4", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"], "param_name": "aspect_ratio"},
 
43
  "qwen": {"id": "qwen/qwen-image", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3"], "param_name": "aspect_ratio"},
44
  "seedream-3": {"id": "bytedance/seedream-3", "aspect_ratios": ["1:1","16:9","9:16","3:4","4:3","3:2","2:3","21:9"], "param_name": "aspect_ratio"},
45
  "recraft-v3": {"id": "recraft-ai/recraft-v3", "aspect_ratios": ["1:1","4:3","3:4","3:2","2:3","16:9","9:16"], "param_name": "aspect_ratio"},
 
 
46
  }
47
 
48
  _thread_local = threading.local()
49
 
50
  # ----------------------------
51
+ # Infra helpers
52
  # ----------------------------
53
  def get_mongo_collection():
54
  if not hasattr(_thread_local, 'mongo_collection'):
 
69
  def get_s3_client():
70
  if not hasattr(_thread_local, 's3_client'):
71
  required_vars = ["R2_ENDPOINT", "R2_ACCESS_KEY", "R2_SECRET_KEY", "R2_BUCKET_NAME"]
72
+ if any(not os.getenv(v) for v in required_vars):
 
73
  _thread_local.s3_client = None
74
  return None
75
  try:
 
90
  return MODEL_REGISTRY.get(model_key)
91
 
92
  # ----------------------------
93
+ # Upload & generation
94
  # ----------------------------
95
+ def upload_to_r2(image_bytes: bytes) -> Optional[str]:
96
+ s3 = get_s3_client()
97
+ if not s3:
98
  return None
99
  try:
100
  filename = f"{uuid4().hex}.png"
101
  file_key = f"adgenesis_image_file/balraaj/images/{filename}"
102
+ s3.put_object(
103
  Bucket=os.getenv("R2_BUCKET_NAME"),
104
  Key=file_key,
105
  Body=image_bytes,
106
  ContentType="image/png",
107
  )
108
+ return f"{os.getenv('NEW_BASE').rstrip('/')}/{file_key}"
 
109
  except Exception as e:
110
+ logger.error(f"Upload failed: {e}")
111
  return None
112
 
113
+ def generate_one(model_key: str, prompt: str, aspect_ratio: str) -> List[str]:
 
 
 
114
  if not REPLICATE_API_TOKEN:
115
  return []
116
  config = get_model_config(model_key)
117
  if not config:
118
  return []
119
  try:
120
+ output = replicate.run(config["id"], input={"prompt": prompt, config["param_name"]: aspect_ratio})
 
 
 
 
121
  if isinstance(output, list) and output:
122
+ return [str(output[0])]
 
123
  elif isinstance(output, str):
124
  return [output]
 
 
125
  return []
126
  except Exception as e:
127
  logger.error(f"Replicate error: {e}")
128
  return []
129
 
130
+ def fetch_bytes(url: str) -> Optional[bytes]:
131
+ for _ in range(RETRY_ATTEMPTS):
132
  try:
133
+ r = requests.get(url, timeout=REQUEST_TIMEOUT, stream=True)
134
+ r.raise_for_status()
135
+ return r.content
136
  except Exception:
 
 
137
  time.sleep(1)
138
  return None
139
 
140
+ def process_prompt(i: int, text: str, model: str, aspect: str):
141
+ urls = generate_one(model, text, aspect)
 
 
142
  if not urls:
143
+ return {"idx": i, "urls": [], "error": "No URLs"}
144
+ img_bytes = fetch_bytes(urls[0])
 
 
 
145
  if not img_bytes:
146
+ return {"idx": i, "urls": [], "error": "Fetch failed"}
147
+ r2 = upload_to_r2(img_bytes)
148
+ return {"idx": i, "urls": [r2 or urls[0]], "error": None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  # ----------------------------
151
  # Persistence
152
  # ----------------------------
153
+ def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
154
+ coll = get_mongo_collection()
155
+ if not coll:
156
  return None
157
  try:
158
+ return str(coll.insert_one({
159
+ "model": model,
160
+ "aspect_ratio": aspect,
161
  "prompt": prompt,
162
  "urls": urls,
 
163
  "lob": "balraaj",
164
  "created_at": datetime.utcnow()
165
+ }).inserted_id)
 
 
166
  except Exception as e:
167
  logger.error(f"Mongo insert failed: {e}")
168
  return None
169
 
170
  @st.cache_data(ttl=300)
171
+ def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
172
+ coll = get_mongo_collection()
173
+ if not coll:
174
+ return []
175
  try:
176
+ return list(coll.find(
177
+ {"created_at": {"$gte": start, "$lt": end}, "lob": "balraaj"}
178
+ ).sort("created_at", -1).limit(LIBRARY_PAGE_SIZE))
 
 
 
 
 
179
  except Exception:
180
+ return []
181
 
182
  # ----------------------------
183
+ # Gallery helpers
184
  # ----------------------------
185
+ def display_gallery(urls: List[str]):
186
+ if not urls: return
187
+ cols = st.columns(4)
188
+ for i, u in enumerate(urls):
189
+ with cols[i % 4]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  try:
191
+ img = fetch_bytes(u)
192
+ if img:
193
+ st.image(img, use_container_width=True)
194
+ except:
195
+ st.error("Failed")
196
+
197
+ def bulk_zip(urls: List[str]):
198
+ buf = io.BytesIO()
199
+ with zipfile.ZipFile(buf, "w") as z:
200
+ for i, u in enumerate(urls, 1):
201
+ data = fetch_bytes(u)
202
+ if data:
203
+ name = f"image_{i}.png"
204
+ z.writestr(name, data)
205
+ buf.seek(0)
206
+ st.download_button("Download All", buf, "images.zip", "application/zip")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  # ----------------------------
209
+ # JSON loader & runner
210
  # ----------------------------
211
+ def load_json(file) -> List[str]:
212
+ data = json.loads(file.getvalue().decode("utf-8"))
213
+ if not isinstance(data, dict) or "prompts" not in data:
214
+ raise ValueError("JSON must be { 'prompts': [ ... ] }")
215
+ return [p for p in data["prompts"] if isinstance(p, str) and p.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ def run_batch(prompts: List[str], model: str, aspect: str):
 
 
 
218
  total = len(prompts)
219
+ status = [st.empty() for _ in prompts]
220
+ progress = st.progress(0, f"0/{total}")
221
+
222
+ all_urls = []
223
+ with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, total)) as ex:
224
+ futs = {ex.submit(process_prompt, i, p, model, aspect): i for i,p in enumerate(prompts,1)}
225
+ done = 0
226
+ for f in as_completed(futs):
227
+ i = futs[f]
228
+ try:
229
+ res = f.result()
230
+ except Exception as e:
231
+ res = {"idx": i, "urls": [], "error": str(e)}
232
+ if res["urls"]:
233
+ save_record(model, aspect, prompts[i-1], res["urls"])
234
+ status[i-1].success(f"Prompt {i}/{total} ✓")
235
+ all_urls.extend(res["urls"])
236
+ else:
237
+ status[i-1].error(f"Prompt {i}/{total} ✗ ({res['error']})")
238
+ done += 1
239
+ progress.progress(done/total, f"{done}/{total}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
 
241
  if all_urls:
242
+ st.subheader("Gallery")
243
+ display_gallery(all_urls)
244
+ bulk_zip(all_urls)
 
245
 
246
  # ----------------------------
247
+ # Pages
248
  # ----------------------------
249
+ def render_json_page():
250
+ st.subheader("Generate from JSON")
251
+ up = st.file_uploader("Upload JSON", type=["json"])
252
+ col1,col2 = st.columns([1,1])
253
+ with col1: model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
254
+ with col2: aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ if up:
257
+ prompts = load_json(up)
258
+ st.json(prompts)
259
+ if st.button("Generate", type="primary", use_container_width=True):
260
+ run_batch(prompts, model, aspect)
261
 
262
+ def render_library_page():
263
+ st.subheader("Creative Library")
264
+ today = datetime.utcnow().date()
265
+ start = st.date_input("Start", today - timedelta(days=30))
266
+ end = st.date_input("End", today)
267
+ records = query_records(datetime.combine(start, datetime.min.time()),
268
+ datetime.combine(end+timedelta(days=1), datetime.min.time()))
269
+ all_urls = []
270
+ for r in records:
271
+ all_urls.extend(r.get("urls", []))
272
+ if all_urls:
273
+ display_gallery(all_urls)
274
+ bulk_zip(all_urls)
275
+ else:
276
+ st.info("No records found.")
277
 
278
  # ----------------------------
279
  # Auth
280
  # ----------------------------
281
  @lru_cache(maxsize=1)
282
+ def check_token(tok: str):
283
+ return tok == os.getenv("ACCESS_TOKEN")
 
 
 
 
 
284
 
 
 
 
285
  def main_app():
 
286
  st.title("File-to-Image Generator")
287
+ page = st.sidebar.radio("Menu", ["Generate from JSON","Creative Library"])
288
+ if page=="Generate from JSON": render_json_page()
289
+ else: render_library_page()
 
 
 
290
 
291
  def main():
292
+ if not st.session_state.get("auth"):
 
 
293
  st.markdown("## Access Required")
294
+ t = st.text_input("Token", type="password")
295
+ if st.button("Unlock"):
296
+ if check_token(t): st.session_state["auth"]=True; st.rerun()
297
+ else: st.error("Invalid token")
 
 
 
 
298
  else:
299
  main_app()
300
 
301
+ if __name__=="__main__":
302
  main()