| import torch |
| import asyncio |
| from typing import Dict |
| from comfy.utils import ProgressBar |
| from comfy_execution.graph_utils import GraphBuilder |
| from comfy.comfy_types.node_typing import ComfyNodeABC |
| from comfy.comfy_types import IO |
|
|
|
|
| class TestAsyncValidation(ComfyNodeABC): |
| """Test node with async VALIDATE_INPUTS.""" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "value": ("FLOAT", {"default": 5.0}), |
| "threshold": ("FLOAT", {"default": 10.0}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("IMAGE",) |
| FUNCTION = "process" |
| CATEGORY = "_for_testing/async" |
|
|
| @classmethod |
| async def VALIDATE_INPUTS(cls, value, threshold): |
| |
| await asyncio.sleep(0.05) |
|
|
| if value > threshold: |
| return f"Value {value} exceeds threshold {threshold}" |
| return True |
|
|
| def process(self, value, threshold): |
| |
| intensity = value / 10.0 |
| image = torch.ones([1, 512, 512, 3]) * intensity |
| return (image,) |
|
|
|
|
| class TestAsyncError(ComfyNodeABC): |
| """Test node that errors during async execution.""" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "value": (IO.ANY, {}), |
| "error_after": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0}), |
| }, |
| } |
|
|
| RETURN_TYPES = (IO.ANY,) |
| FUNCTION = "error_execution" |
| CATEGORY = "_for_testing/async" |
|
|
| async def error_execution(self, value, error_after): |
| await asyncio.sleep(error_after) |
| raise RuntimeError("Intentional async execution error for testing") |
|
|
|
|
| class TestAsyncValidationError(ComfyNodeABC): |
| """Test node with async validation that always fails.""" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "value": ("FLOAT", {"default": 5.0}), |
| "max_value": ("FLOAT", {"default": 10.0}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("IMAGE",) |
| FUNCTION = "process" |
| CATEGORY = "_for_testing/async" |
|
|
| @classmethod |
| async def VALIDATE_INPUTS(cls, value, max_value): |
| await asyncio.sleep(0.05) |
| |
| if value > max_value: |
| return f"Async validation failed: {value} > {max_value}" |
| return True |
|
|
| def process(self, value, max_value): |
| |
| image = torch.ones([1, 512, 512, 3]) * (value / max_value) |
| return (image,) |
|
|
|
|
| class TestAsyncTimeout(ComfyNodeABC): |
| """Test node that simulates timeout scenarios.""" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "value": (IO.ANY, {}), |
| "timeout": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0}), |
| "operation_time": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 10.0}), |
| }, |
| } |
|
|
| RETURN_TYPES = (IO.ANY,) |
| FUNCTION = "timeout_execution" |
| CATEGORY = "_for_testing/async" |
|
|
| async def timeout_execution(self, value, timeout, operation_time): |
| try: |
| |
| await asyncio.wait_for(asyncio.sleep(operation_time), timeout=timeout) |
| return (value,) |
| except asyncio.TimeoutError: |
| raise RuntimeError(f"Operation timed out after {timeout} seconds") |
|
|
|
|
| class TestSyncError(ComfyNodeABC): |
| """Test node that errors synchronously (for mixed sync/async testing).""" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "value": (IO.ANY, {}), |
| }, |
| } |
|
|
| RETURN_TYPES = (IO.ANY,) |
| FUNCTION = "sync_error" |
| CATEGORY = "_for_testing/async" |
|
|
| def sync_error(self, value): |
| raise RuntimeError("Intentional sync execution error for testing") |
|
|
|
|
| class TestAsyncLazyCheck(ComfyNodeABC): |
| """Test node with async check_lazy_status.""" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "input1": (IO.ANY, {"lazy": True}), |
| "input2": (IO.ANY, {"lazy": True}), |
| "condition": ("BOOLEAN", {"default": True}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("IMAGE",) |
| FUNCTION = "process" |
| CATEGORY = "_for_testing/async" |
|
|
| async def check_lazy_status(self, condition, input1, input2): |
| |
| await asyncio.sleep(0.05) |
|
|
| needed = [] |
| if condition and input1 is None: |
| needed.append("input1") |
| if not condition and input2 is None: |
| needed.append("input2") |
| return needed |
|
|
| def process(self, input1, input2, condition): |
| |
| return (torch.ones([1, 512, 512, 3]),) |
|
|
|
|
| class TestDynamicAsyncGeneration(ComfyNodeABC): |
| """Test node that dynamically generates async nodes.""" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "image1": ("IMAGE",), |
| "image2": ("IMAGE",), |
| "num_async_nodes": ("INT", {"default": 3, "min": 1, "max": 10}), |
| "sleep_duration": ("FLOAT", {"default": 0.2, "min": 0.1, "max": 1.0}), |
| }, |
| } |
|
|
| RETURN_TYPES = ("IMAGE",) |
| FUNCTION = "generate_async_workflow" |
| CATEGORY = "_for_testing/async" |
|
|
| def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration): |
| g = GraphBuilder() |
|
|
| |
| sleep_nodes = [] |
| for i in range(num_async_nodes): |
| image = image1 if i % 2 == 0 else image2 |
| sleep_node = g.node("TestSleep", value=image, seconds=sleep_duration) |
| sleep_nodes.append(sleep_node) |
|
|
| |
| if len(sleep_nodes) == 1: |
| final_node = sleep_nodes[0] |
| else: |
| avg_inputs = {"input1": sleep_nodes[0].out(0)} |
| for i, node in enumerate(sleep_nodes[1:], 2): |
| avg_inputs[f"input{i}"] = node.out(0) |
| final_node = g.node("TestVariadicAverage", **avg_inputs) |
|
|
| return { |
| "result": (final_node.out(0),), |
| "expand": g.finalize(), |
| } |
|
|
|
|
| class TestAsyncResourceUser(ComfyNodeABC): |
| """Test node that uses resources during async execution.""" |
|
|
| |
| _active_resources: Dict[str, bool] = {} |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "value": (IO.ANY, {}), |
| "resource_id": ("STRING", {"default": "resource_0"}), |
| "duration": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}), |
| }, |
| } |
|
|
| RETURN_TYPES = (IO.ANY,) |
| FUNCTION = "use_resource" |
| CATEGORY = "_for_testing/async" |
|
|
| async def use_resource(self, value, resource_id, duration): |
| |
| if self._active_resources.get(resource_id, False): |
| raise RuntimeError(f"Resource {resource_id} is already in use!") |
|
|
| |
| self._active_resources[resource_id] = True |
|
|
| try: |
| |
| await asyncio.sleep(duration) |
| return (value,) |
| finally: |
| |
| self._active_resources[resource_id] = False |
|
|
|
|
| class TestAsyncBatchProcessing(ComfyNodeABC): |
| """Test async processing of batched inputs.""" |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "images": ("IMAGE",), |
| "process_time_per_item": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 1.0}), |
| }, |
| "hidden": { |
| "unique_id": "UNIQUE_ID", |
| }, |
| } |
|
|
| RETURN_TYPES = ("IMAGE",) |
| FUNCTION = "process_batch" |
| CATEGORY = "_for_testing/async" |
|
|
| async def process_batch(self, images, process_time_per_item, unique_id): |
| batch_size = images.shape[0] |
| pbar = ProgressBar(batch_size, node_id=unique_id) |
|
|
| |
| processed = [] |
| for i in range(batch_size): |
| |
| await asyncio.sleep(process_time_per_item) |
|
|
| |
| processed_image = 1.0 - images[i:i+1] |
| processed.append(processed_image) |
|
|
| pbar.update(1) |
|
|
| |
| result = torch.cat(processed, dim=0) |
| return (result,) |
|
|
|
|
| class TestAsyncConcurrentLimit(ComfyNodeABC): |
| """Test concurrent execution limits for async nodes.""" |
|
|
| _semaphore = asyncio.Semaphore(2) |
|
|
| @classmethod |
| def INPUT_TYPES(cls): |
| return { |
| "required": { |
| "value": (IO.ANY, {}), |
| "duration": ("FLOAT", {"default": 0.5, "min": 0.1, "max": 2.0}), |
| "node_id": ("INT", {"default": 0}), |
| }, |
| } |
|
|
| RETURN_TYPES = (IO.ANY,) |
| FUNCTION = "limited_execution" |
| CATEGORY = "_for_testing/async" |
|
|
| async def limited_execution(self, value, duration, node_id): |
| async with self._semaphore: |
| |
| await asyncio.sleep(duration) |
| |
| return (value,) |
|
|
|
|
| |
| ASYNC_TEST_NODE_CLASS_MAPPINGS = { |
| "TestAsyncValidation": TestAsyncValidation, |
| "TestAsyncError": TestAsyncError, |
| "TestAsyncValidationError": TestAsyncValidationError, |
| "TestAsyncTimeout": TestAsyncTimeout, |
| "TestSyncError": TestSyncError, |
| "TestAsyncLazyCheck": TestAsyncLazyCheck, |
| "TestDynamicAsyncGeneration": TestDynamicAsyncGeneration, |
| "TestAsyncResourceUser": TestAsyncResourceUser, |
| "TestAsyncBatchProcessing": TestAsyncBatchProcessing, |
| "TestAsyncConcurrentLimit": TestAsyncConcurrentLimit, |
| } |
|
|
| ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS = { |
| "TestAsyncValidation": "Test Async Validation", |
| "TestAsyncError": "Test Async Error", |
| "TestAsyncValidationError": "Test Async Validation Error", |
| "TestAsyncTimeout": "Test Async Timeout", |
| "TestSyncError": "Test Sync Error", |
| "TestAsyncLazyCheck": "Test Async Lazy Check", |
| "TestDynamicAsyncGeneration": "Test Dynamic Async Generation", |
| "TestAsyncResourceUser": "Test Async Resource User", |
| "TestAsyncBatchProcessing": "Test Async Batch Processing", |
| "TestAsyncConcurrentLimit": "Test Async Concurrent Limit", |
| } |
|
|