userIdc2024 committed on
Commit
5628bc4
·
verified ·
1 Parent(s): 0428d2d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +59 -25
src/streamlit_app.py CHANGED
@@ -35,7 +35,12 @@ MONGO_COLLECTION = os.getenv("MONGO_COLLECTION", "creatives")
35
  REQUEST_TIMEOUT = 30
36
  RETRY_ATTEMPTS = 3
37
  MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
38
- LIBRARY_PAGE_SIZE = 200 # aggregate many for one gallery view
 
 
 
 
 
39
 
40
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
41
  "imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
@@ -65,7 +70,6 @@ def show_env_warnings():
65
  # ----------------------------
66
  def get_replicate_client():
67
  if not hasattr(_thread_local, "replicate_client"):
68
- # Explicit client avoids env-specific issues with module-level run()
69
  _thread_local.replicate_client = replicate.Client(api_token=REPLICATE_API_TOKEN) if REPLICATE_API_TOKEN else None
70
  return _thread_local.replicate_client
71
 
@@ -160,22 +164,46 @@ def fetch_bytes(url: str) -> Optional[bytes]:
160
  time.sleep(1)
161
  return None
162
 
163
- def process_prompt(i: int, text: str, model: str, aspect: str) -> Dict[str, Any]:
164
  """
165
- One image per prompt:
166
- - generate via Replicate
167
- - try to upload to R2
168
- - fallback to source url if R2 not available
169
  """
170
- urls = generate_one(model, text, aspect)
 
171
  if not urls:
172
- return {"idx": i, "urls": [], "error": "No URLs"}
173
  src = urls[0]
174
  data = fetch_bytes(src)
175
  if data is None:
176
- return {"idx": i, "urls": [], "error": "Fetch failed"}
177
  r2 = upload_to_r2(data)
178
- return {"idx": i, "urls": [r2 or src], "error": None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  # ----------------------------
181
  # Persistence
@@ -190,6 +218,7 @@ def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
190
  "aspect_ratio": aspect,
191
  "prompt": prompt,
192
  "urls": urls,
 
193
  "lob": "balraaj",
194
  "created_at": datetime.utcnow(),
195
  }).inserted_id)
@@ -200,7 +229,6 @@ def save_record(model: str, aspect: str, prompt: str, urls: List[str]):
200
  @st.cache_data(ttl=300)
201
  def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
202
  coll = get_mongo_collection()
203
- # FIX: compare with None explicitly (avoids NotImplementedError)
204
  if coll is None:
205
  return []
206
  try:
@@ -240,7 +268,7 @@ def bulk_zip(urls: List[str]):
240
  st.download_button("Download All", buf, "images.zip", "application/zip", use_container_width=True)
241
 
242
  # ----------------------------
243
- # JSON loader & batch run (parallel)
244
  # ----------------------------
245
  def load_json(file) -> List[str]:
246
  data = json.loads(file.getvalue().decode("utf-8"))
@@ -251,7 +279,7 @@ def load_json(file) -> List[str]:
251
  raise ValueError("No valid prompts found.")
252
  return out
253
 
254
- def run_batch(prompts: List[str], model: str, aspect: str):
255
  total = len(prompts)
256
  rows = [st.empty() for _ in prompts]
257
  progress = st.progress(0.0, text=f"0/{total}")
@@ -259,22 +287,25 @@ def run_batch(prompts: List[str], model: str, aspect: str):
259
  all_urls: List[str] = []
260
  done = 0
261
 
262
- max_workers = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
263
  with st.spinner("Generating images..."):
264
- with ThreadPoolExecutor(max_workers=max_workers) as ex:
265
- futs = {ex.submit(process_prompt, i, p, model, aspect): i for i, p in enumerate(prompts, 1)}
266
  for fut in as_completed(futs):
267
  i = futs[fut]
268
  try:
269
  res = fut.result()
270
  except Exception as e:
271
- res = {"idx": i, "urls": [], "error": str(e)}
 
272
  if res["urls"]:
273
  save_record(model, aspect, prompts[i-1], res["urls"])
274
- rows[i-1].success(f"Prompt {i}/{total} βœ“")
275
  all_urls.extend(res["urls"])
276
  else:
277
- rows[i-1].error(f"Prompt {i}/{total} βœ— ({res['error']})")
 
 
278
  done += 1
279
  progress.progress(done/total, text=f"{done}/{total}")
280
 
@@ -291,12 +322,15 @@ def run_batch(prompts: List[str], model: str, aspect: str):
291
  def render_json_page():
292
  st.subheader("Generate from JSON")
293
  show_env_warnings()
 
294
  up = st.file_uploader("Upload JSON", type=["json"])
295
- col1, col2 = st.columns([1, 1])
296
- with col1:
297
  model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
298
- with col2:
299
  aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
 
 
300
 
301
  if up:
302
  try:
@@ -304,7 +338,7 @@ def render_json_page():
304
  with st.expander("Preview prompts", expanded=False):
305
  st.json(prompts)
306
  if st.button("Generate", type="primary", use_container_width=True):
307
- run_batch(prompts, model, aspect)
308
  except Exception as e:
309
  st.error(str(e))
310
  else:
@@ -339,7 +373,7 @@ def render_library_page():
339
  def check_token(tok: str) -> bool:
340
  acc = os.getenv("ACCESS_TOKEN")
341
  if not acc:
342
- # If ACCESS_TOKEN is not configured, allow through to avoid lockout in dev.
343
  return True
344
  return tok == acc
345
 
 
35
  REQUEST_TIMEOUT = 30
36
  RETRY_ATTEMPTS = 3
37
  MAX_WORKERS = min(32, (os.cpu_count() or 1) + 4)
38
+ LIBRARY_PAGE_SIZE = 200 # aggregated view
39
+
40
+ # Global throttle so we don't overload Replicate / network
41
+ # This caps total in-flight image generations across all prompts.
42
+ GLOBAL_CONCURRENCY = max(4, min(16, MAX_WORKERS))
43
+ _GEN_SEMAPHORE = threading.Semaphore(GLOBAL_CONCURRENCY)
44
 
45
  MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
46
  "imagegen-4-ultra": {"id": "google/imagen-4-ultra","aspect_ratios": ["1:1","16:9","9:16","3:4","4:3"],"param_name": "aspect_ratio"},
 
70
  # ----------------------------
71
def get_replicate_client():
    """Return a thread-local Replicate client, creating it on first use.

    When REPLICATE_API_TOKEN is not configured, None is stored and returned;
    the None is cached too, so the token check runs at most once per thread.
    """
    if not hasattr(_thread_local, "replicate_client"):
        token = REPLICATE_API_TOKEN
        _thread_local.replicate_client = (
            replicate.Client(api_token=token) if token else None
        )
    return _thread_local.replicate_client
75
 
 
164
  time.sleep(1)
165
  return None
166
 
def _generate_single_image_full(prompt: str, model: str, aspect: str) -> Tuple[Optional[str], Optional[str]]:
    """Generate one image end-to-end under the global concurrency throttle.

    Pipeline: generate via Replicate -> fetch image bytes -> upload to R2,
    falling back to the Replicate source URL when R2 is unavailable.

    Args:
        prompt: text prompt for the image model.
        model: key into MODEL_REGISTRY.
        aspect: aspect-ratio string supported by the chosen model.

    Returns:
        (final_url, None) on success, where final_url is the R2 URL or,
        failing that, the source URL; (None, error_text) on failure.
    """
    # Hold the semaphore for the WHOLE pipeline so GLOBAL_CONCURRENCY truly
    # caps in-flight work — previously only the generate call was throttled,
    # letting fetch/upload network traffic pile up unbounded.
    with _GEN_SEMAPHORE:
        urls = generate_one(model, prompt, aspect)
        if not urls:
            return None, "No URLs"
        src = urls[0]
        data = fetch_bytes(src)
        if data is None:
            return None, "Fetch failed"
        r2 = upload_to_r2(data)
        return (r2 or src), None
def process_prompt_batch(idx: int, prompt: str, model: str, aspect: str, num_images: int) -> Dict[str, Any]:
    """Generate ``num_images`` images for a single prompt.

    Fans out over a small inner thread pool; the global semaphore inside
    _generate_single_image_full still bounds total in-flight work, so one
    prompt cannot exhaust the global budget.

    Returns:
        {"idx": idx, "urls": [successful urls], "errors": [error texts]}
    """
    count = max(1, int(num_images))
    collected: List[str] = []
    failures: List[str] = []

    def _record(outcome) -> None:
        # Each outcome is (url_or_None, error_or_None).
        url, err = outcome
        if url:
            collected.append(url)
        if err:
            failures.append(err)

    if count == 1:
        _record(_generate_single_image_full(prompt, model, aspect))
    else:
        # Keep inner parallelism small relative to the global cap.
        workers = min(4, count, GLOBAL_CONCURRENCY)
        with ThreadPoolExecutor(max_workers=workers) as pool:
            pending = [
                pool.submit(_generate_single_image_full, prompt, model, aspect)
                for _ in range(count)
            ]
            for finished in as_completed(pending):
                _record(finished.result())

    return {"idx": idx, "urls": collected, "errors": failures}
 
208
  # ----------------------------
209
  # Persistence
 
218
  "aspect_ratio": aspect,
219
  "prompt": prompt,
220
  "urls": urls,
221
+ "num_images": len(urls),
222
  "lob": "balraaj",
223
  "created_at": datetime.utcnow(),
224
  }).inserted_id)
 
229
  @st.cache_data(ttl=300)
230
  def query_records(start: datetime, end: datetime) -> List[Dict[str, Any]]:
231
  coll = get_mongo_collection()
 
232
  if coll is None:
233
  return []
234
  try:
 
268
  st.download_button("Download All", buf, "images.zip", "application/zip", use_container_width=True)
269
 
270
  # ----------------------------
271
+ # JSON loader & batch run (parallel across prompts)
272
  # ----------------------------
273
  def load_json(file) -> List[str]:
274
  data = json.loads(file.getvalue().decode("utf-8"))
 
279
  raise ValueError("No valid prompts found.")
280
  return out
281
 
282
+ def run_batch(prompts: List[str], model: str, aspect: str, num_images: int):
283
  total = len(prompts)
284
  rows = [st.empty() for _ in prompts]
285
  progress = st.progress(0.0, text=f"0/{total}")
 
287
  all_urls: List[str] = []
288
  done = 0
289
 
290
+ outer_workers = min(MAX_WORKERS, max(2, (os.cpu_count() or 2)))
291
  with st.spinner("Generating images..."):
292
+ with ThreadPoolExecutor(max_workers=outer_workers) as ex:
293
+ futs = {ex.submit(process_prompt_batch, i, p, model, aspect, num_images): i for i, p in enumerate(prompts, 1)}
294
  for fut in as_completed(futs):
295
  i = futs[fut]
296
  try:
297
  res = fut.result()
298
  except Exception as e:
299
+ res = {"idx": i, "urls": [], "errors": [str(e)]}
300
+
301
  if res["urls"]:
302
  save_record(model, aspect, prompts[i-1], res["urls"])
303
+ rows[i-1].success(f"Prompt {i}/{total} βœ“ ({len(res['urls'])} images)")
304
  all_urls.extend(res["urls"])
305
  else:
306
+ err_msg = ", ".join(res.get("errors") or ["No images"])
307
+ rows[i-1].error(f"Prompt {i}/{total} βœ— ({err_msg})")
308
+
309
  done += 1
310
  progress.progress(done/total, text=f"{done}/{total}")
311
 
 
322
  def render_json_page():
323
  st.subheader("Generate from JSON")
324
  show_env_warnings()
325
+
326
  up = st.file_uploader("Upload JSON", type=["json"])
327
+ c1, c2, c3 = st.columns([1, 1, 1])
328
+ with c1:
329
  model = st.selectbox("Model", list(MODEL_REGISTRY.keys()), 0)
330
+ with c2:
331
  aspect = st.selectbox("Aspect", MODEL_REGISTRY[model]["aspect_ratios"], 0)
332
+ with c3:
333
+ num_images = st.slider("Images per prompt", min_value=1, max_value=12, value=1, step=1)
334
 
335
  if up:
336
  try:
 
338
  with st.expander("Preview prompts", expanded=False):
339
  st.json(prompts)
340
  if st.button("Generate", type="primary", use_container_width=True):
341
+ run_batch(prompts, model, aspect, num_images)
342
  except Exception as e:
343
  st.error(str(e))
344
  else:
 
373
def check_token(tok: str) -> bool:
    """Validate an access token against the ACCESS_TOKEN environment variable.

    When ACCESS_TOKEN is unset (or empty), authentication is effectively
    disabled and any token is accepted — avoids lockout in dev environments.
    """
    expected = os.getenv("ACCESS_TOKEN")
    return True if not expected else tok == expected