userIdc2024 commited on
Commit
b7548f4
·
verified ·
1 Parent(s): 5557c8f

Update generator_function/image_function.py

Browse files
Files changed (1) hide show
  1. generator_function/image_function.py +40 -16
generator_function/image_function.py CHANGED
@@ -147,8 +147,8 @@ def _fetch(url: Union[str, Any]) -> Optional[bytes]:
147
  time.sleep(1)
148
  return None
149
 
150
- def _process_one(args: Tuple[str, str, str, int]) -> Dict[str, Any]:
151
- model_key, prompt, aspect_ratio, idx = args
152
  out = {"index": idx, "success": False, "source_url": None, "r2_url": None, "error": None}
153
  try:
154
  urls = _generate_one(model_key, prompt, aspect_ratio)
@@ -160,22 +160,27 @@ def _process_one(args: Tuple[str, str, str, int]) -> Dict[str, Any]:
160
  if not b:
161
  out["error"] = "Fetch failed"; return out
162
  image_with_metadata = meta_data_helper_function(b)
163
- r2 = _upload_to_r2(image_with_metadata)
164
- if r2:
165
- out["r2_url"] = r2; out["success"] = True
 
166
  else:
167
- out["error"] = "Upload to R2 failed"
 
 
 
 
168
  except Exception as e:
169
  out["error"] = str(e)
170
  return out
171
 
172
- def _generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int) -> Tuple[List[str], List[str], List[str]]:
173
  if num_images == 1:
174
- res = _process_one((model_key, prompt, aspect_ratio, 0))
175
  if res["success"]:
176
  return [res["r2_url"]], [res["source_url"]], []
177
  return [], [], [res["error"] or "Generation failed"]
178
- args = [(model_key, prompt, aspect_ratio, i) for i in range(num_images)]
179
  r2, src, errs = [], [], []
180
  with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as ex:
181
  for fut in as_completed({ex.submit(_process_one, a): a[3] for a in args}):
@@ -194,9 +199,9 @@ def _generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, nu
194
 
195
 
196
 
197
- def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int) -> Tuple[List[str], List[str], List[str]]:
198
  """Back-compat public export used by background tasks."""
199
- return _generate_images_parallel(model_key, aspect_ratio, prompt, num_images)
200
 
201
  def handle_image_generation_optimized(
202
  *,
@@ -208,6 +213,7 @@ def handle_image_generation_optimized(
208
  category: Optional[str] = None,
209
  platform: Optional[str] = None,
210
  uid:str,
 
211
 
212
  ):
213
  """
@@ -223,7 +229,7 @@ def handle_image_generation_optimized(
223
 
224
  created_by = uid
225
 
226
- results_col = get_results_collection()
227
  db_job_id = None
228
  if results_col is not None:
229
  try:
@@ -248,8 +254,14 @@ def handle_image_generation_optimized(
248
  st.info(f"Generating {num_images} image(s)")
249
  progress.progress(10, text="Running...")
250
 
251
- r2_urls, source_urls, errors = _generate_images_parallel(model_key, aspect_ratio, prompt.strip(), num_images)
252
- urls = r2_urls or source_urls
 
 
 
 
 
 
253
 
254
  if results_col is not None and db_job_id:
255
  try:
@@ -268,7 +280,12 @@ def handle_image_generation_optimized(
268
 
269
  if urls:
270
  with status.container():
271
- st.success(f"Generated {len(urls)} image(s) in {took:.1f}s. Job ID: {db_job_id or 'N/A'}")
 
 
 
 
 
272
 
273
  cols = st.columns(min(4, len(urls)) or 1)
274
  image_bytes_list = []
@@ -276,7 +293,14 @@ def handle_image_generation_optimized(
276
  for i, u in enumerate(urls):
277
  with cols[i % len(cols)]:
278
  try:
279
- b = _fetch(u)
 
 
 
 
 
 
 
280
  if b is None:
281
  st.error("Failed to load image")
282
  continue
 
147
  time.sleep(1)
148
  return None
149
 
150
+ def _process_one(args: Tuple[str, str, str, int, bool]) -> Dict[str, Any]:
151
+ model_key, prompt, aspect_ratio, idx, private_mode = args
152
  out = {"index": idx, "success": False, "source_url": None, "r2_url": None, "error": None}
153
  try:
154
  urls = _generate_one(model_key, prompt, aspect_ratio)
 
160
  if not b:
161
  out["error"] = "Fetch failed"; return out
162
  image_with_metadata = meta_data_helper_function(b)
163
+ if private_mode:
164
+ data_uri = "data:image/png;base64," + base64.b64encode(image_with_metadata).decode("utf-8")
165
+ out["r2_url"] = data_uri
166
+ out["success"] = True
167
  else:
168
+ r2 = _upload_to_r2(image_with_metadata)
169
+ if r2:
170
+ out["r2_url"] = r2; out["success"] = True
171
+ else:
172
+ out["error"] = "Upload to R2 failed"
173
  except Exception as e:
174
  out["error"] = str(e)
175
  return out
176
 
177
+ def _generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int, *, private_mode: bool = False) -> Tuple[List[str], List[str], List[str]]:
178
  if num_images == 1:
179
+ res = _process_one((model_key, prompt, aspect_ratio, 0, private_mode))
180
  if res["success"]:
181
  return [res["r2_url"]], [res["source_url"]], []
182
  return [], [], [res["error"] or "Generation failed"]
183
+ args = [(model_key, prompt, aspect_ratio, i, private_mode) for i in range(num_images)]
184
  r2, src, errs = [], [], []
185
  with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as ex:
186
  for fut in as_completed({ex.submit(_process_one, a): a[3] for a in args}):
 
199
 
200
 
201
 
202
+ def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int, *, private_mode: bool = False) -> Tuple[List[str], List[str], List[str]]:
203
  """Back-compat public export used by background tasks."""
204
+ return _generate_images_parallel(model_key, aspect_ratio, prompt, num_images, private_mode=private_mode)
205
 
206
  def handle_image_generation_optimized(
207
  *,
 
213
  category: Optional[str] = None,
214
  platform: Optional[str] = None,
215
  uid:str,
216
+ private_mode: bool = False,
217
 
218
  ):
219
  """
 
229
 
230
  created_by = uid
231
 
232
+ results_col = None if private_mode else get_results_collection()
233
  db_job_id = None
234
  if results_col is not None:
235
  try:
 
254
  st.info(f"Generating {num_images} image(s)")
255
  progress.progress(10, text="Running...")
256
 
257
+ r2_urls, source_urls, errors = _generate_images_parallel(
258
+ model_key,
259
+ aspect_ratio,
260
+ prompt.strip(),
261
+ num_images,
262
+ private_mode=private_mode,
263
+ )
264
+ urls = r2_urls if private_mode else (r2_urls or source_urls)
265
 
266
  if results_col is not None and db_job_id:
267
  try:
 
280
 
281
  if urls:
282
  with status.container():
283
+ message = f"Generated {len(urls)} image(s) in {took:.1f}s."
284
+ if not private_mode:
285
+ message += f" Job ID: {db_job_id or 'N/A'}"
286
+ else:
287
+ message += " Private mode: results stay local to this session."
288
+ st.success(message)
289
 
290
  cols = st.columns(min(4, len(urls)) or 1)
291
  image_bytes_list = []
 
293
  for i, u in enumerate(urls):
294
  with cols[i % len(cols)]:
295
  try:
296
+ if isinstance(u, str) and u.startswith("data:image"):
297
+ try:
298
+ _, encoded = u.split(",", 1)
299
+ b = base64.b64decode(encoded)
300
+ except Exception:
301
+ b = None
302
+ else:
303
+ b = _fetch(u)
304
  if b is None:
305
  st.error("Failed to load image")
306
  continue