Spaces:
Running
Running
Update generator_function/image_function.py
Browse files
generator_function/image_function.py
CHANGED
|
@@ -147,8 +147,8 @@ def _fetch(url: Union[str, Any]) -> Optional[bytes]:
|
|
| 147 |
time.sleep(1)
|
| 148 |
return None
|
| 149 |
|
| 150 |
-
def _process_one(args: Tuple[str, str, str, int]) -> Dict[str, Any]:
|
| 151 |
-
model_key, prompt, aspect_ratio, idx = args
|
| 152 |
out = {"index": idx, "success": False, "source_url": None, "r2_url": None, "error": None}
|
| 153 |
try:
|
| 154 |
urls = _generate_one(model_key, prompt, aspect_ratio)
|
|
@@ -160,22 +160,27 @@ def _process_one(args: Tuple[str, str, str, int]) -> Dict[str, Any]:
|
|
| 160 |
if not b:
|
| 161 |
out["error"] = "Fetch failed"; return out
|
| 162 |
image_with_metadata = meta_data_helper_function(b)
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
out["r2_url"] =
|
|
|
|
| 166 |
else:
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
except Exception as e:
|
| 169 |
out["error"] = str(e)
|
| 170 |
return out
|
| 171 |
|
| 172 |
-
def _generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int) -> Tuple[List[str], List[str], List[str]]:
|
| 173 |
if num_images == 1:
|
| 174 |
-
res = _process_one((model_key, prompt, aspect_ratio, 0))
|
| 175 |
if res["success"]:
|
| 176 |
return [res["r2_url"]], [res["source_url"]], []
|
| 177 |
return [], [], [res["error"] or "Generation failed"]
|
| 178 |
-
args = [(model_key, prompt, aspect_ratio, i) for i in range(num_images)]
|
| 179 |
r2, src, errs = [], [], []
|
| 180 |
with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as ex:
|
| 181 |
for fut in as_completed({ex.submit(_process_one, a): a[3] for a in args}):
|
|
@@ -194,9 +199,9 @@ def _generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, nu
|
|
| 194 |
|
| 195 |
|
| 196 |
|
| 197 |
-
def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int) -> Tuple[List[str], List[str], List[str]]:
|
| 198 |
"""Back-compat public export used by background tasks."""
|
| 199 |
-
return _generate_images_parallel(model_key, aspect_ratio, prompt, num_images)
|
| 200 |
|
| 201 |
def handle_image_generation_optimized(
|
| 202 |
*,
|
|
@@ -208,6 +213,7 @@ def handle_image_generation_optimized(
|
|
| 208 |
category: Optional[str] = None,
|
| 209 |
platform: Optional[str] = None,
|
| 210 |
uid:str,
|
|
|
|
| 211 |
|
| 212 |
):
|
| 213 |
"""
|
|
@@ -223,7 +229,7 @@ def handle_image_generation_optimized(
|
|
| 223 |
|
| 224 |
created_by = uid
|
| 225 |
|
| 226 |
-
results_col = get_results_collection()
|
| 227 |
db_job_id = None
|
| 228 |
if results_col is not None:
|
| 229 |
try:
|
|
@@ -248,8 +254,14 @@ def handle_image_generation_optimized(
|
|
| 248 |
st.info(f"Generating {num_images} image(s)")
|
| 249 |
progress.progress(10, text="Running...")
|
| 250 |
|
| 251 |
-
r2_urls, source_urls, errors = _generate_images_parallel(
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
if results_col is not None and db_job_id:
|
| 255 |
try:
|
|
@@ -268,7 +280,12 @@ def handle_image_generation_optimized(
|
|
| 268 |
|
| 269 |
if urls:
|
| 270 |
with status.container():
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
cols = st.columns(min(4, len(urls)) or 1)
|
| 274 |
image_bytes_list = []
|
|
@@ -276,7 +293,14 @@ def handle_image_generation_optimized(
|
|
| 276 |
for i, u in enumerate(urls):
|
| 277 |
with cols[i % len(cols)]:
|
| 278 |
try:
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
if b is None:
|
| 281 |
st.error("Failed to load image")
|
| 282 |
continue
|
|
|
|
| 147 |
time.sleep(1)
|
| 148 |
return None
|
| 149 |
|
| 150 |
+
def _process_one(args: Tuple[str, str, str, int, bool]) -> Dict[str, Any]:
|
| 151 |
+
model_key, prompt, aspect_ratio, idx, private_mode = args
|
| 152 |
out = {"index": idx, "success": False, "source_url": None, "r2_url": None, "error": None}
|
| 153 |
try:
|
| 154 |
urls = _generate_one(model_key, prompt, aspect_ratio)
|
|
|
|
| 160 |
if not b:
|
| 161 |
out["error"] = "Fetch failed"; return out
|
| 162 |
image_with_metadata = meta_data_helper_function(b)
|
| 163 |
+
if private_mode:
|
| 164 |
+
data_uri = "data:image/png;base64," + base64.b64encode(image_with_metadata).decode("utf-8")
|
| 165 |
+
out["r2_url"] = data_uri
|
| 166 |
+
out["success"] = True
|
| 167 |
else:
|
| 168 |
+
r2 = _upload_to_r2(image_with_metadata)
|
| 169 |
+
if r2:
|
| 170 |
+
out["r2_url"] = r2; out["success"] = True
|
| 171 |
+
else:
|
| 172 |
+
out["error"] = "Upload to R2 failed"
|
| 173 |
except Exception as e:
|
| 174 |
out["error"] = str(e)
|
| 175 |
return out
|
| 176 |
|
| 177 |
+
def _generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int, *, private_mode: bool = False) -> Tuple[List[str], List[str], List[str]]:
|
| 178 |
if num_images == 1:
|
| 179 |
+
res = _process_one((model_key, prompt, aspect_ratio, 0, private_mode))
|
| 180 |
if res["success"]:
|
| 181 |
return [res["r2_url"]], [res["source_url"]], []
|
| 182 |
return [], [], [res["error"] or "Generation failed"]
|
| 183 |
+
args = [(model_key, prompt, aspect_ratio, i, private_mode) for i in range(num_images)]
|
| 184 |
r2, src, errs = [], [], []
|
| 185 |
with ThreadPoolExecutor(max_workers=min(MAX_WORKERS, num_images)) as ex:
|
| 186 |
for fut in as_completed({ex.submit(_process_one, a): a[3] for a in args}):
|
|
|
|
| 199 |
|
| 200 |
|
| 201 |
|
| 202 |
+
def generate_images_parallel(model_key: str, aspect_ratio: str, prompt: str, num_images: int, *, private_mode: bool = False) -> Tuple[List[str], List[str], List[str]]:
|
| 203 |
"""Back-compat public export used by background tasks."""
|
| 204 |
+
return _generate_images_parallel(model_key, aspect_ratio, prompt, num_images, private_mode=private_mode)
|
| 205 |
|
| 206 |
def handle_image_generation_optimized(
|
| 207 |
*,
|
|
|
|
| 213 |
category: Optional[str] = None,
|
| 214 |
platform: Optional[str] = None,
|
| 215 |
uid:str,
|
| 216 |
+
private_mode: bool = False,
|
| 217 |
|
| 218 |
):
|
| 219 |
"""
|
|
|
|
| 229 |
|
| 230 |
created_by = uid
|
| 231 |
|
| 232 |
+
results_col = None if private_mode else get_results_collection()
|
| 233 |
db_job_id = None
|
| 234 |
if results_col is not None:
|
| 235 |
try:
|
|
|
|
| 254 |
st.info(f"Generating {num_images} image(s)")
|
| 255 |
progress.progress(10, text="Running...")
|
| 256 |
|
| 257 |
+
r2_urls, source_urls, errors = _generate_images_parallel(
|
| 258 |
+
model_key,
|
| 259 |
+
aspect_ratio,
|
| 260 |
+
prompt.strip(),
|
| 261 |
+
num_images,
|
| 262 |
+
private_mode=private_mode,
|
| 263 |
+
)
|
| 264 |
+
urls = r2_urls if private_mode else (r2_urls or source_urls)
|
| 265 |
|
| 266 |
if results_col is not None and db_job_id:
|
| 267 |
try:
|
|
|
|
| 280 |
|
| 281 |
if urls:
|
| 282 |
with status.container():
|
| 283 |
+
message = f"Generated {len(urls)} image(s) in {took:.1f}s."
|
| 284 |
+
if not private_mode:
|
| 285 |
+
message += f" Job ID: {db_job_id or 'N/A'}"
|
| 286 |
+
else:
|
| 287 |
+
message += " Private mode: results stay local to this session."
|
| 288 |
+
st.success(message)
|
| 289 |
|
| 290 |
cols = st.columns(min(4, len(urls)) or 1)
|
| 291 |
image_bytes_list = []
|
|
|
|
| 293 |
for i, u in enumerate(urls):
|
| 294 |
with cols[i % len(cols)]:
|
| 295 |
try:
|
| 296 |
+
if isinstance(u, str) and u.startswith("data:image"):
|
| 297 |
+
try:
|
| 298 |
+
_, encoded = u.split(",", 1)
|
| 299 |
+
b = base64.b64decode(encoded)
|
| 300 |
+
except Exception:
|
| 301 |
+
b = None
|
| 302 |
+
else:
|
| 303 |
+
b = _fetch(u)
|
| 304 |
if b is None:
|
| 305 |
st.error("Failed to load image")
|
| 306 |
continue
|