userIdc2024 commited on
Commit
ebef0e6
·
verified ·
1 Parent(s): 5559ce1

Update src/app_pages/image_generation.py

Browse files
Files changed (1) hide show
  1. src/app_pages/image_generation.py +77 -38
src/app_pages/image_generation.py CHANGED
@@ -76,8 +76,18 @@ MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
76
  ],
77
  "param_name": "aspect_ratio",
78
  },
79
- }
 
 
 
 
 
 
 
 
80
 
 
 
81
 
82
  def retrieve_from_store(vector_store_id: str, user_query: str, category: str) -> str:
83
  client = get_openai_client()
@@ -248,50 +258,45 @@ def generate_ad_images_replicate(
248
  prompt: str,
249
  replicate_model_key: str,
250
  aspect_ratio: str,
 
 
 
251
  ) -> List[Image.Image]:
252
- if not REPLICATE_API_TOKEN:
253
- raise RuntimeError("REPLICATE_API_TOKEN is missing; cannot use Replicate models.")
254
- if replicate_model_key not in MODEL_REGISTRY:
255
- raise KeyError(f"Unknown replicate model key: {replicate_model_key}")
256
 
257
  model_cfg = MODEL_REGISTRY[replicate_model_key]
258
  model_id = model_cfg["id"]
259
- param_name = model_cfg.get("param_name", "aspect_ratio")
260
 
261
- try:
262
- output = replicate.run(
263
- model_id,
264
- input={"prompt": prompt, param_name: aspect_ratio},
265
- )
266
- except Exception as e:
267
- raise RuntimeError(f"Replicate run failed ({model_id}): {e}") from e
 
 
 
 
 
 
268
 
269
- items: List[Any] = output if isinstance(output, list) else [output]
270
 
 
271
  images: List[Image.Image] = []
 
272
  for item in items[:1]:
273
- try:
274
- if hasattr(item, "read"):
275
- data = item.read()
276
- images.append(Image.open(BytesIO(data)))
277
- continue
278
-
279
- url = str(item)
280
- if url.startswith(("http://", "https://")):
281
- with urllib.request.urlopen(url) as resp:
282
- data = resp.read()
283
- images.append(Image.open(BytesIO(data)))
284
- continue
285
-
286
- raise RuntimeError(f"Unsupported Replicate output item: {type(item)} / {item}")
287
- except Exception as e:
288
- raise RuntimeError(f"Failed to read Replicate output: {e}") from e
289
 
290
- if not images:
291
- raise RuntimeError("Replicate returned no usable image outputs.")
292
  return images
293
 
294
 
 
 
295
  def _image_task(
296
  idx: int,
297
  prompt: str,
@@ -300,13 +305,29 @@ def _image_task(
300
  size: str,
301
  quality: str,
302
  aspect_ratio: str,
 
 
 
303
  ) -> Tuple[int, str, Optional[Image.Image], Optional[bytes], Optional[str], Optional[str]]:
304
  try:
305
  if model_key in OPENAI_IMAGE_MODELS:
306
  openai_model_id = OPENAI_IMAGE_MODELS[model_key]
307
- img = generate_ad_images_openai(prompt, model_id=openai_model_id, n=1, size=size, quality=quality)[0]
 
 
 
 
 
 
308
  else:
309
- img = generate_ad_images_replicate(prompt, replicate_model_key=model_key, aspect_ratio=aspect_ratio)[0]
 
 
 
 
 
 
 
310
 
311
  buf = BytesIO()
312
  img.save(buf, format="PNG", optimize=True)
@@ -318,6 +339,7 @@ def _image_task(
318
  app_type="adgen",
319
  format="PNG",
320
  )
 
321
  if not url:
322
  return idx, prompt, img, png_bytes, None, "R2 upload failed"
323
 
@@ -339,6 +361,7 @@ def render_image_generation_page(uid: str | None = None):
339
  "Imagen 4 Ultra": "imagegen-4-ultra",
340
  "Imagen 4": "imagen-4",
341
  "gpt-image-1.5": "gpt-image-1.5",
 
342
  }
343
 
344
  try:
@@ -352,9 +375,25 @@ def render_image_generation_page(uid: str | None = None):
352
 
353
  model_label = st.selectbox("Image Model Name (required)", list(UI_MODEL_OPTIONS.keys()))
354
  model_key = UI_MODEL_OPTIONS[model_label]
355
-
356
  aspect_ratio = "1:1"
357
- if model_key not in OPENAI_IMAGE_MODELS:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  ratios = MODEL_REGISTRY[model_key]["aspect_ratios"]
359
  aspect_ratio = st.selectbox(
360
  "Aspect ratio",
@@ -442,7 +481,7 @@ def render_image_generation_page(uid: str | None = None):
442
  try:
443
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
444
  futures = [
445
- ex.submit(_image_task, i, p, category, model_key, size, quality, aspect_ratio)
446
  for i, p in enumerate(output.prompt, start=1)
447
  ]
448
  results = [f.result() for f in as_completed(futures)]
@@ -518,4 +557,4 @@ def render_image_generation_page(uid: str | None = None):
518
  file_name="ad_images.zip",
519
  mime="application/zip",
520
  key="zip_dl_last",
521
- )
 
76
  ],
77
  "param_name": "aspect_ratio",
78
  },
79
+ "z-image-turbo": {
80
+ "id": "prunaai/z-image-turbo",
81
+ "aspect_ratios": [
82
+ "match_input_image", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"
83
+ ],
84
+ "z-image-turbo": {
85
+ "id": "prunaai/z-image-turbo",
86
+ "type": "custom",
87
+ },
88
 
89
+ },
90
+ }
91
 
92
  def retrieve_from_store(vector_store_id: str, user_query: str, category: str) -> str:
93
  client = get_openai_client()
 
258
  prompt: str,
259
  replicate_model_key: str,
260
  aspect_ratio: str,
261
+ z_width: int | None = None,
262
+ z_height: int | None = None,
263
+ z_output_quality: int | None = None,
264
  ) -> List[Image.Image]:
 
 
 
 
265
 
266
  model_cfg = MODEL_REGISTRY[replicate_model_key]
267
  model_id = model_cfg["id"]
 
268
 
269
+ if replicate_model_key == "z-image-turbo":
270
+ input_payload = {
271
+ "prompt": prompt,
272
+ "width": z_width or 1024,
273
+ "height": z_height or 1024,
274
+ "output_quality": z_output_quality or 80,
275
+ }
276
+ else:
277
+ param_name = model_cfg.get("param_name", "aspect_ratio")
278
+ input_payload = {
279
+ "prompt": prompt,
280
+ param_name: aspect_ratio,
281
+ }
282
 
283
+ output = replicate.run(model_id, input=input_payload)
284
 
285
+ items = output if isinstance(output, list) else [output]
286
  images: List[Image.Image] = []
287
+
288
  for item in items[:1]:
289
+ if hasattr(item, "read"):
290
+ images.append(Image.open(BytesIO(item.read())))
291
+ else:
292
+ with urllib.request.urlopen(str(item)) as resp:
293
+ images.append(Image.open(BytesIO(resp.read())))
 
 
 
 
 
 
 
 
 
 
 
294
 
 
 
295
  return images
296
 
297
 
298
+
299
+
300
  def _image_task(
301
  idx: int,
302
  prompt: str,
 
305
  size: str,
306
  quality: str,
307
  aspect_ratio: str,
308
+ z_width: int,
309
+ z_height: int,
310
+ z_output_quality: int,
311
  ) -> Tuple[int, str, Optional[Image.Image], Optional[bytes], Optional[str], Optional[str]]:
312
  try:
313
  if model_key in OPENAI_IMAGE_MODELS:
314
  openai_model_id = OPENAI_IMAGE_MODELS[model_key]
315
+ img = generate_ad_images_openai(
316
+ prompt,
317
+ model_id=openai_model_id,
318
+ n=1,
319
+ size=size,
320
+ quality=quality
321
+ )[0]
322
  else:
323
+ img = generate_ad_images_replicate(
324
+ prompt,
325
+ replicate_model_key=model_key,
326
+ aspect_ratio=aspect_ratio,
327
+ z_width=z_width,
328
+ z_height=z_height,
329
+ z_output_quality=z_output_quality,
330
+ )[0]
331
 
332
  buf = BytesIO()
333
  img.save(buf, format="PNG", optimize=True)
 
339
  app_type="adgen",
340
  format="PNG",
341
  )
342
+
343
  if not url:
344
  return idx, prompt, img, png_bytes, None, "R2 upload failed"
345
 
 
361
  "Imagen 4 Ultra": "imagegen-4-ultra",
362
  "Imagen 4": "imagen-4",
363
  "gpt-image-1.5": "gpt-image-1.5",
364
+ "z-image-turbo":"z-image-turbo"
365
  }
366
 
367
  try:
 
375
 
376
  model_label = st.selectbox("Image Model Name (required)", list(UI_MODEL_OPTIONS.keys()))
377
  model_key = UI_MODEL_OPTIONS[model_label]
 
378
  aspect_ratio = "1:1"
379
+ z_width = 1024
380
+ z_height = 1024
381
+ z_output_quality = 80
382
+
383
+ if model_key == "z-image-turbo":
384
+ col1, col2, col3 = st.columns(3)
385
+
386
+ with col1:
387
+ z_width = st.slider("Width", 64, 1440, 1024, step=64)
388
+
389
+ with col2:
390
+ z_height = st.slider("Height", 64, 1440, 1024, step=64)
391
+
392
+ with col3:
393
+ z_output_quality = st.slider("Output Quality", 0, 100, 80)
394
+
395
+
396
+ elif model_key not in OPENAI_IMAGE_MODELS:
397
  ratios = MODEL_REGISTRY[model_key]["aspect_ratios"]
398
  aspect_ratio = st.selectbox(
399
  "Aspect ratio",
 
481
  try:
482
  with ThreadPoolExecutor(max_workers=max_workers) as ex:
483
  futures = [
484
+ ex.submit(_image_task, i, p, category, model_key, size, quality, aspect_ratio,z_width,z_height,z_output_quality)
485
  for i, p in enumerate(output.prompt, start=1)
486
  ]
487
  results = [f.result() for f in as_completed(futures)]
 
557
  file_name="ad_images.zip",
558
  mime="application/zip",
559
  key="zip_dl_last",
560
+ )