Velvessence commited on
Commit
1cc6d0e
·
verified ·
1 Parent(s): 0bc8e5a

Upload nxdify_node.py

Browse files
Files changed (1) hide show
  1. nxdify_node.py +499 -0
nxdify_node.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import time
4
+ import tempfile
5
+ import hashlib
6
+ import asyncio
7
+ import concurrent.futures
8
+ from typing import Tuple, Dict
9
+ from PIL import Image
10
+ import torch
11
+ import numpy as np
12
+ import fal_client as fal
13
+ from fal_client import client
14
+ import aiohttp
15
+
16
+
17
+ class NxdifyNode:
18
+ """
19
+ ComfyUI node for Nxdify image generation using FAL AI Seedream 4.5.
20
+ Takes 4 reference images (Face, Body, Breasts, Dynamic Pose) and generates variations.
21
+ """
22
+
23
+ @classmethod
24
+ def INPUT_TYPES(cls):
25
+ return {
26
+ "required": {
27
+ "face_image": ("IMAGE",),
28
+ "body_image": ("IMAGE",),
29
+ "breasts_image": ("IMAGE",),
30
+ "dynamic_pose_image": ("IMAGE",),
31
+ "prompt": ("STRING", {
32
+ "multiline": True,
33
+ "default": ""
34
+ }),
35
+ "fal_api_key": ("STRING", {
36
+ "default": "",
37
+ "password": True
38
+ }),
39
+ "quality": (["auto_4K", "auto_2K"], {
40
+ "default": "auto_4K"
41
+ }),
42
+ }
43
+ }
44
+
45
+ RETURN_TYPES = ("IMAGE",)
46
+ RETURN_NAMES = ("image",)
47
+ FUNCTION = "execute"
48
+ CATEGORY = "image/generation"
49
+
50
+ MAX_IMAGE_SIZE = 5 * 1024 * 1024 # 5MB
51
+ MAX_CONCURRENT = 8
52
+
53
+ # Class-level cache for uploaded reference image URLs (hash -> URL)
54
+ _image_url_cache: Dict[str, str] = {}
55
+
56
+ def compress_image_bytes_max(self, image_bytes: bytes, max_bytes: int) -> bytes:
57
+ """
58
+ Compress image to fit under max_bytes.
59
+ Strategy:
60
+ 1. Try reducing JPEG quality (start at 92, down to 52)
61
+ 2. If still too large, downscale image (start at 100%, down to 45%)
62
+ 3. Repeat until under limit or minimums reached
63
+ """
64
+ if len(image_bytes) <= max_bytes:
65
+ return image_bytes
66
+
67
+ # Convert to PIL Image
68
+ img = Image.open(io.BytesIO(image_bytes))
69
+ img = img.convert("RGB")
70
+ base_w, base_h = img.size
71
+
72
+ quality = 92
73
+ scale = 1.0
74
+
75
+ for _ in range(20): # Max 20 iterations
76
+ w = max(1, int(base_w * scale))
77
+ h = max(1, int(base_h * scale))
78
+
79
+ # Resize if needed
80
+ working = img if (w == base_w and h == base_h) else img.resize((w, h), Image.Resampling.LANCZOS)
81
+
82
+ # Save as JPEG
83
+ buf = io.BytesIO()
84
+ working.save(buf, format="JPEG", quality=quality, optimize=True)
85
+ data = buf.getvalue()
86
+
87
+ if len(data) <= max_bytes:
88
+ return data
89
+
90
+ # Reduce quality first
91
+ if quality > 52:
92
+ quality = max(52, quality - 10)
93
+ continue
94
+
95
+ # Then downscale
96
+ if scale > 0.45:
97
+ scale = scale * 0.85
98
+ quality = 92 # Reset quality
99
+ continue
100
+
101
+ # Can't compress further
102
+ return data
103
+
104
+ return image_bytes
105
+
106
+ def tensor_to_bytes(self, tensor: torch.Tensor) -> bytes:
107
+ """Convert ComfyUI image tensor to JPEG bytes."""
108
+ # ComfyUI IMAGE tensors are in BHWC format (batch, height, width, channels)
109
+ # Remove batch dimension to get HWC
110
+ if len(tensor.shape) == 4:
111
+ img_array = tensor[0].cpu().numpy() # Shape: (height, width, channels)
112
+ else:
113
+ img_array = tensor.cpu().numpy()
114
+
115
+ # Ensure values are in 0-255 range and convert to uint8
116
+ img_array = (np.clip(img_array, 0.0, 1.0) * 255.0).astype(np.uint8)
117
+
118
+ # Handle alpha channel if present
119
+ if img_array.shape[2] == 4:
120
+ # Convert RGBA to RGB with white background
121
+ alpha = img_array[:, :, 3:4].astype(np.float32) / 255.0
122
+ rgb = img_array[:, :, :3].astype(np.float32)
123
+ img_array = (rgb * alpha + 255 * (1 - alpha)).astype(np.uint8)
124
+ elif img_array.shape[2] == 1:
125
+ # Handle grayscale - convert to RGB
126
+ img_array = np.repeat(img_array, 3, axis=2)
127
+
128
+ # Convert to PIL Image
129
+ img = Image.fromarray(img_array)
130
+
131
+ # Convert to RGB if needed
132
+ if img.mode != "RGB":
133
+ img = img.convert("RGB")
134
+
135
+ # Save to bytes
136
+ buf = io.BytesIO()
137
+ img.save(buf, format="JPEG", quality=95, optimize=True)
138
+ return buf.getvalue()
139
+
140
+ def _compute_image_hash(self, image_bytes: bytes) -> str:
141
+ """Compute SHA256 hash of image bytes for caching."""
142
+ return hashlib.sha256(image_bytes).hexdigest()
143
+
144
+ def _upload_file_sync(self, tmp_path: str) -> str:
145
+ """Synchronous wrapper for upload_file to use with asyncio.to_thread."""
146
+ return fal.upload_file(tmp_path)
147
+
148
+ async def upload_ref_with_retry(self, image_bytes: bytes, use_cache: bool = True, max_attempts: int = 3) -> str:
149
+ """Upload image with retry on timeout. Optionally use cache to avoid re-uploading."""
150
+ upload_start = time.time()
151
+ original_size = len(image_bytes)
152
+
153
+ # Check cache first if enabled
154
+ if use_cache:
155
+ image_hash = self._compute_image_hash(image_bytes)
156
+ if image_hash in self._image_url_cache:
157
+ print(f"[Nxdify] Image found in cache (hash: {image_hash[:16]}...), skipping upload")
158
+ return self._image_url_cache[image_hash]
159
+
160
+ # Compress image first
161
+ print(f"[Nxdify] Compressing image (original: {original_size} bytes)...")
162
+ compressed = self.compress_image_bytes_max(image_bytes, self.MAX_IMAGE_SIZE)
163
+ compression_ratio = (1 - len(compressed) / original_size) * 100 if original_size > 0 else 0
164
+ print(f"[Nxdify] Compressed to {len(compressed)} bytes ({compression_ratio:.1f}% reduction)")
165
+
166
+ # Create temporary file
167
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
168
+ tmp.write(compressed)
169
+ tmp_path = tmp.name
170
+
171
+ timeout_errors = []
172
+ try:
173
+ for attempt in range(max_attempts):
174
+ try:
175
+ print(f"[Nxdify] Uploading image (attempt {attempt + 1}/{max_attempts})...")
176
+ attempt_start = time.time()
177
+
178
+ # Upload to FAL (uses FAL_KEY environment variable set in process_async)
179
+ # upload_file expects a file path, not a BytesIO object
180
+ result = await asyncio.to_thread(self._upload_file_sync, tmp_path)
181
+
182
+ attempt_elapsed = time.time() - attempt_start
183
+ print(f"[Nxdify] Upload completed in {attempt_elapsed:.2f} seconds")
184
+
185
+ if isinstance(result, dict) and "url" in result:
186
+ url = result["url"]
187
+ elif isinstance(result, str):
188
+ url = result
189
+ else:
190
+ raise ValueError(f"Unexpected upload response: {result}")
191
+
192
+ # Cache the URL if caching is enabled
193
+ if use_cache:
194
+ image_hash = self._compute_image_hash(image_bytes)
195
+ self._image_url_cache[image_hash] = url
196
+
197
+ total_upload_time = time.time() - upload_start
198
+ print(f"[Nxdify] Image upload successful (total time: {total_upload_time:.2f} seconds)")
199
+ return url
200
+
201
+ except Exception as e:
202
+ # Check if this is a 408 Request Timeout error
203
+ error_str = str(e)
204
+ error_lower = error_str.lower()
205
+ is_408_timeout = (
206
+ "408" in error_str or
207
+ "request timeout" in error_lower or
208
+ "http/1.1 408" in error_lower or
209
+ "http 408" in error_lower
210
+ )
211
+
212
+ # Check for other timeout errors
213
+ is_timeout = (
214
+ is_408_timeout or
215
+ "timeout" in error_lower or
216
+ isinstance(e, (TimeoutError, asyncio.TimeoutError)) or
217
+ (isinstance(e, aiohttp.ClientError) and "timeout" in error_lower)
218
+ )
219
+
220
+ if is_408_timeout:
221
+ timeout_errors.append(f"Attempt {attempt + 1}: HTTP 408 Request Timeout")
222
+
223
+ # If this is the last attempt and we had 408 timeouts, raise helpful exception
224
+ if attempt == max_attempts - 1:
225
+ if timeout_errors:
226
+ print(f"[Nxdify] Upload failed after {max_attempts} attempts")
227
+ raise RuntimeError(
228
+ f"Upload timed out after {max_attempts} attempts with HTTP 408 Request Timeout errors. "
229
+ f"The image may be too large. Please resize the image to a smaller resolution and try again. "
230
+ f"Errors: {'; '.join(timeout_errors)}"
231
+ )
232
+ print(f"[Nxdify] Upload failed on final attempt: {e}")
233
+ raise
234
+
235
+ # If timeout error, retry with backoff
236
+ if is_timeout:
237
+ backoff = 2 + attempt * 3 # Exponential backoff: 2s, 5s, 8s
238
+ print(f"[Nxdify] Upload timeout error (attempt {attempt + 1}): {error_str[:100]}. Retrying in {backoff} seconds...")
239
+ await asyncio.sleep(backoff)
240
+ continue
241
+
242
+ # Non-timeout error, fail immediately
243
+ print(f"[Nxdify] Upload failed with non-timeout error: {error_str[:100]}")
244
+ raise
245
+ finally:
246
+ # Clean up temp file
247
+ try:
248
+ os.unlink(tmp_path)
249
+ except OSError:
250
+ pass
251
+
252
+ def _subscribe_sync(self, endpoint: str, arguments: dict):
253
+ """Subscribe to FAL API job synchronously (handles submit + polling internally)."""
254
+ print(f"[Nxdify] Submitting job to FAL API: {endpoint}")
255
+ start_time = time.time()
256
+ result = fal.subscribe(endpoint, arguments=arguments, with_logs=False)
257
+ elapsed = time.time() - start_time
258
+ print(f"[Nxdify] FAL API job completed in {elapsed:.2f} seconds")
259
+ return result
260
+
261
+ async def generate_image(
262
+ self,
263
+ face_url: str,
264
+ body_url: str,
265
+ breasts_url: str,
266
+ dynamic_pose_url: str,
267
+ prompt: str,
268
+ quality: str
269
+ ) -> Image.Image:
270
+ """Generate image using FAL AI Seedream 4.5 API."""
271
+ print(f"[Nxdify] Starting image generation with quality: {quality}")
272
+ print(f"[Nxdify] Reference images: face={face_url[:50]}..., body={body_url[:50]}..., breasts={breasts_url[:50]}..., pose={dynamic_pose_url[:50]}...")
273
+
274
+ image_urls = [face_url, body_url, breasts_url, dynamic_pose_url]
275
+
276
+ arguments = {
277
+ "prompt": prompt,
278
+ "image_size": quality,
279
+ "num_images": 1,
280
+ "max_images": 1,
281
+ "enable_safety_checker": False,
282
+ "image_urls": image_urls
283
+ }
284
+
285
+ print(f"[Nxdify] Calling FAL API subscribe (this will poll internally)...")
286
+ subscribe_start = time.time()
287
+
288
+ # Use subscribe which handles submit + polling internally
289
+ result = await asyncio.to_thread(
290
+ self._subscribe_sync,
291
+ "fal-ai/bytedance/seedream/v4.5/edit",
292
+ arguments
293
+ )
294
+
295
+ subscribe_elapsed = time.time() - subscribe_start
296
+ print(f"[Nxdify] Subscribe call returned after {subscribe_elapsed:.2f} seconds")
297
+
298
+ if not result:
299
+ raise ValueError("No result returned from FAL AI API")
300
+
301
+ print(f"[Nxdify] Processing result (type: {type(result).__name__})...")
302
+
303
+ # Extract images from result (handle different response structures)
304
+ images = None
305
+ if isinstance(result, dict):
306
+ if "images" in result:
307
+ images = result["images"]
308
+ print(f"[Nxdify] Found {len(images)} image(s) in result['images']")
309
+ elif "output" in result and isinstance(result["output"], dict):
310
+ images = result["output"].get("images")
311
+ print(f"[Nxdify] Found {len(images) if images else 0} image(s) in result['output']['images']")
312
+
313
+ if not images or len(images) == 0:
314
+ raise ValueError("No images returned from FAL AI API")
315
+
316
+ # Handle both dict and string image URLs
317
+ if isinstance(images[0], dict):
318
+ image_url = images[0].get("url") or images[0].get("image_url")
319
+ else:
320
+ image_url = images[0]
321
+
322
+ if not image_url:
323
+ raise ValueError("No image URL in result")
324
+
325
+ print(f"[Nxdify] Image URL extracted: {image_url[:80]}...")
326
+ print(f"[Nxdify] Downloading generated image...")
327
+ download_start = time.time()
328
+
329
+ # Download image
330
+ async with aiohttp.ClientSession() as session:
331
+ async with session.get(image_url) as response:
332
+ if response.status != 200:
333
+ raise ValueError(f"Failed to download image: HTTP {response.status}")
334
+ image_bytes = await response.read()
335
+
336
+ download_elapsed = time.time() - download_start
337
+ print(f"[Nxdify] Image downloaded ({len(image_bytes)} bytes) in {download_elapsed:.2f} seconds")
338
+
339
+ # Convert to PIL Image
340
+ print(f"[Nxdify] Converting to PIL Image...")
341
+ img = Image.open(io.BytesIO(image_bytes))
342
+ final_img = img.convert("RGB")
343
+ print(f"[Nxdify] Image generation complete. Final size: {final_img.size}")
344
+ return final_img
345
+
346
+ def pil_to_tensor(self, img: Image.Image) -> torch.Tensor:
347
+ """Convert PIL Image to ComfyUI tensor format (BHWC)."""
348
+ # Convert to RGB if needed
349
+ if img.mode != "RGB":
350
+ img = img.convert("RGB")
351
+
352
+ # Convert to numpy array (HWC format: height, width, channels)
353
+ img_array = np.array(img).astype(np.float32) / 255.0
354
+
355
+ # Ensure shape is (height, width, channels)
356
+ if len(img_array.shape) == 2:
357
+ # Grayscale - add channel dimension
358
+ img_array = np.expand_dims(img_array, axis=2)
359
+ img_array = np.repeat(img_array, 3, axis=2) # Convert to RGB
360
+
361
+ # Add batch dimension to get BHWC format: (batch, height, width, channels)
362
+ tensor = torch.from_numpy(img_array)[None,]
363
+ return tensor
364
+
365
+ async def process_async(
366
+ self,
367
+ face_image: torch.Tensor,
368
+ body_image: torch.Tensor,
369
+ breasts_image: torch.Tensor,
370
+ dynamic_pose_image: torch.Tensor,
371
+ prompt: str,
372
+ fal_api_key: str,
373
+ quality: str
374
+ ) -> torch.Tensor:
375
+ """Async processing function."""
376
+ process_start = time.time()
377
+ print(f"[Nxdify] ===== Starting Nxdify image generation process =====")
378
+
379
+ if not fal_api_key:
380
+ raise ValueError("FAL API key is required")
381
+
382
+ if not prompt:
383
+ raise ValueError("Prompt is required")
384
+
385
+ print(f"[Nxdify] Converting input tensors to bytes...")
386
+ # Convert tensors to bytes
387
+ face_bytes = self.tensor_to_bytes(face_image)
388
+ body_bytes = self.tensor_to_bytes(body_image)
389
+ breasts_bytes = self.tensor_to_bytes(breasts_image)
390
+ dynamic_pose_bytes = self.tensor_to_bytes(dynamic_pose_image)
391
+ print(f"[Nxdify] Image sizes: face={len(face_bytes)} bytes, body={len(body_bytes)} bytes, breasts={len(breasts_bytes)} bytes, pose={len(dynamic_pose_bytes)} bytes")
392
+
393
+ # Set FAL API key as environment variable (FAL SDK reads from env)
394
+ os.environ["FAL_KEY"] = fal_api_key
395
+ print(f"[Nxdify] FAL API key configured")
396
+
397
+ print(f"[Nxdify] Uploading reference images...")
398
+ upload_start = time.time()
399
+
400
+ # Upload fixed references (cached - only upload if changed)
401
+ print(f"[Nxdify] Uploading face image (cached)...")
402
+ face_url = await self.upload_ref_with_retry(face_bytes, use_cache=True)
403
+ print(f"[Nxdify] Uploading body image (cached)...")
404
+ body_url = await self.upload_ref_with_retry(body_bytes, use_cache=True)
405
+ print(f"[Nxdify] Uploading breasts image (cached)...")
406
+ breasts_url = await self.upload_ref_with_retry(breasts_bytes, use_cache=True)
407
+
408
+ # Upload dynamic pose image (not cached - always upload)
409
+ print(f"[Nxdify] Uploading dynamic pose image (not cached)...")
410
+ dynamic_pose_url = await self.upload_ref_with_retry(dynamic_pose_bytes, use_cache=False)
411
+
412
+ upload_elapsed = time.time() - upload_start
413
+ print(f"[Nxdify] All images uploaded in {upload_elapsed:.2f} seconds")
414
+
415
+ # Generate image
416
+ print(f"[Nxdify] Starting image generation...")
417
+ generation_start = time.time()
418
+ generated_img = await self.generate_image(
419
+ face_url,
420
+ body_url,
421
+ breasts_url,
422
+ dynamic_pose_url,
423
+ prompt,
424
+ quality
425
+ )
426
+ generation_elapsed = time.time() - generation_start
427
+ print(f"[Nxdify] Image generation completed in {generation_elapsed:.2f} seconds")
428
+
429
+ # Convert to tensor
430
+ print(f"[Nxdify] Converting PIL image to tensor...")
431
+ result = self.pil_to_tensor(generated_img)
432
+
433
+ total_elapsed = time.time() - process_start
434
+ print(f"[Nxdify] ===== Total process time: {total_elapsed:.2f} seconds =====")
435
+ return result
436
+
437
+ def execute(
438
+ self,
439
+ face_image: torch.Tensor,
440
+ body_image: torch.Tensor,
441
+ breasts_image: torch.Tensor,
442
+ dynamic_pose_image: torch.Tensor,
443
+ prompt: str,
444
+ fal_api_key: str,
445
+ quality: str
446
+ ) -> Tuple[torch.Tensor]:
447
+ """Execute the node (synchronous wrapper for async processing)."""
448
+ # Handle event loop - check if one exists
449
+ try:
450
+ loop = asyncio.get_event_loop()
451
+ if loop.is_running():
452
+ # If loop is running, we need to use a different approach
453
+ # Create a new event loop in a thread
454
+ with concurrent.futures.ThreadPoolExecutor() as executor:
455
+ future = executor.submit(asyncio.run, self.process_async(
456
+ face_image,
457
+ body_image,
458
+ breasts_image,
459
+ dynamic_pose_image,
460
+ prompt,
461
+ fal_api_key,
462
+ quality
463
+ ))
464
+ result = future.result()
465
+ else:
466
+ # Loop exists but not running, use it
467
+ result = loop.run_until_complete(self.process_async(
468
+ face_image,
469
+ body_image,
470
+ breasts_image,
471
+ dynamic_pose_image,
472
+ prompt,
473
+ fal_api_key,
474
+ quality
475
+ ))
476
+ except RuntimeError:
477
+ # No event loop, create one
478
+ result = asyncio.run(self.process_async(
479
+ face_image,
480
+ body_image,
481
+ breasts_image,
482
+ dynamic_pose_image,
483
+ prompt,
484
+ fal_api_key,
485
+ quality
486
+ ))
487
+
488
+ return (result,)
489
+
490
+
491
+ # Node export
492
+ NODE_CLASS_MAPPINGS = {
493
+ "NxdifyNode": NxdifyNode
494
+ }
495
+
496
+ NODE_DISPLAY_NAME_MAPPINGS = {
497
+ "NxdifyNode": "Nxdify Image Generation"
498
+ }
499
+