| import pytest |
| import time |
| import torch |
| import urllib.error |
| import numpy as np |
| import subprocess |
|
|
| from pytest import fixture |
| from comfy_execution.graph_utils import GraphBuilder |
| from tests.execution.test_execution import ComfyClient, run_warmup |
|
|
|
|
@pytest.mark.execution
class TestAsyncNodes:
    """End-to-end tests for async node execution against a server subprocess.

    The class-scoped ``_server`` fixture is parametrized, so every test in this
    class runs once per server configuration: without ``--cache-lru``, with
    ``--cache-lru 0`` and with ``--cache-lru 100``.
    """

    @fixture(scope="class", autouse=True, params=[
        # (use_lru, lru_size)
        (False, 0),
        (True, 0),
        (True, 100),
    ])
    def _server(self, args_pytest, request):
        """Start ``main.py`` in a subprocess for the class and stop it afterwards.

        Args:
            args_pytest: Mapping providing ``output_dir``, ``listen`` and ``port``.
            request: pytest request; ``request.param`` is ``(use_lru, lru_size)``.
        """
        pargs = [
            'python', 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
            '--cpu',
        ]
        use_lru, lru_size = request.param
        if use_lru:
            pargs += ['--cache-lru', str(lru_size)]

        p = subprocess.Popen(pargs)
        try:
            yield
        finally:
            p.kill()
            # Reap the child so it does not linger as a zombie and so the next
            # parametrized server is only started after this one has exited.
            p.wait()
            torch.cuda.empty_cache()

    @fixture(scope="class", autouse=True)
    def shared_client(self, args_pytest, _server):
        """Yield one ``ComfyClient`` connected to the server, shared by the class.

        Connection is attempted up to five times, sleeping four seconds before
        each attempt so the freshly spawned server has time to come up.

        Raises:
            ConnectionRefusedError: If the last attempt is still refused.
        """
        client = ComfyClient()
        n_tries = 5
        for attempt in range(1, n_tries + 1):
            time.sleep(4)
            try:
                client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
            except ConnectionRefusedError:
                # Retry, but fail fast on the final attempt instead of handing
                # the tests a client that never connected.
                if attempt == n_tries:
                    raise
            else:
                break
        yield client
        del client
        torch.cuda.empty_cache()

    @fixture
    def client(self, shared_client, request):
        """Yield the shared client, tagged with the current test's name."""
        shared_client.set_test_name(f"async_nodes[{request.node.name}]")
        yield shared_client

    @fixture
    def builder(self, request):
        """Yield a fresh ``GraphBuilder`` whose prefix is the current test's name."""
        yield GraphBuilder(prefix=request.node.name)

    # ------------------------------------------------------------------
    # Basic execution
    # ------------------------------------------------------------------

    def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that a basic async node executes correctly."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
        output = g.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g)

        # Both the async node and the node consuming its output must have run.
        assert result.did_run(sleep_node), "Async sleep node should have executed"
        assert result.did_run(output), "Output node should have executed"

        # The image must come through unchanged (all pixels zero).
        result_images = result.get_images(output)
        assert len(result_images) == 1, "Should have 1 image"
        pixels = np.array(result_images[0])
        assert pixels.min() == 0 and pixels.max() == 0, "Image should be black"

    def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
        """Test that multiple async nodes execute in parallel."""
        # Warm up first so the timed run below is not skewed by first-run costs.
        run_warmup(client)

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Three independent sleeps fed by the same image.
        sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
        sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Each sleep needs an output node so that it is actually executed.
        _output1 = g.node("PreviewImage", images=sleep1.out(0))
        _output2 = g.node("PreviewImage", images=sleep2.out(0))
        _output3 = g.node("PreviewImage", images=sleep3.out(0))

        # perf_counter is monotonic, so the measurement cannot be distorted by
        # wall-clock adjustments.
        start_time = time.perf_counter()
        result = client.run(g)
        elapsed_time = time.perf_counter() - start_time

        # The requested sleeps add up to 1.2s, so finishing under 0.8s is only
        # possible if they overlap.
        if not skip_timing_checks:
            assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"

        assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)

    def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with proper dependency handling."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Two async branches ...
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # ... joined by a node that depends on both of them.
        average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
        output = g.node("SaveImage", images=average.out(0))

        result = client.run(g)

        assert result.did_run(sleep1) and result.did_run(sleep2)
        assert result.did_run(average) and result.did_run(output)

        # Averaging a black and a white image should give mid-gray.
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"

    def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
        """Test async VALIDATE_INPUTS function."""
        g = builder
        validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        # value=5.0 with threshold=10.0 is accepted.
        result = client.run(g)
        assert result.did_run(validation_node)

        # Lowering the threshold to 3.0 must make the server reject the prompt.
        validation_node.inputs['threshold'] = 3.0
        with pytest.raises(urllib.error.HTTPError):
            client.run(g)

    def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
        """Test async nodes with lazy evaluation."""
        # Warm up first so the timed run below is not skewed by first-run costs.
        run_warmup(client, prefix="warmup_lazy")

        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)

        sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)

        # With an all-zero mask the mix is expected to need only image1, so the
        # branch feeding image2 should never be evaluated.
        lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        start_time = time.perf_counter()
        result = client.run(g)
        elapsed_time = time.perf_counter() - start_time

        if not skip_timing_checks:
            assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
        assert result.did_run(sleep1), "Sleep1 should have executed"
        assert not result.did_run(sleep2), "Sleep2 should have been skipped"

    def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
        """Test async check_lazy_status function."""
        g = builder
        lazy_node = g.node("TestAsyncLazyCheck",
                           input1="value1",
                           input2="value2",
                           condition=True)
        g.node("SaveImage", images=lazy_node.out(0))

        result = client.run(g)
        assert result.did_run(lazy_node)

    # ------------------------------------------------------------------
    # Error handling
    # ------------------------------------------------------------------

    def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async execution errors are properly handled."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        # pytest.raises fails the test by itself when nothing is raised, so a
        # "should have raised" assertion can no longer end up being inspected
        # as if it were the server's error.
        with pytest.raises(Exception) as exc_info:
            client.run(g)

        error = exc_info.value
        assert 'prompt_id' in error.args[0], f"Did not get proper error message: {error}"
        assert error.args[0]['node_id'] == error_node.id, "Error should be from async error node"

    def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test async validation error handling."""
        g = builder
        # value=15.0 exceeds max_value=10.0, so the prompt must be rejected.
        validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        with pytest.raises(urllib.error.HTTPError) as exc_info:
            client.run(g)

        assert exc_info.value.code == 400

    def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling of async operations that timeout."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # The requested operation_time (2.0) is longer than the timeout (0.5).
        timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
        g.node("SaveImage", images=timeout_node.out(0))

        # Must use pytest.raises here: a try/except with an in-body
        # ``assert False, "... timeout error"`` would be caught by the handler,
        # and because that message itself contains "timeout" the test would
        # pass even when no error was raised.
        with pytest.raises(Exception) as exc_info:
            client.run(g)

        error = exc_info.value
        assert 'timeout' in str(error).lower(), f"Expected timeout error, got: {error}"

    def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
        """Test that workflow can recover after async errors."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # First workflow: expected to fail.
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        try:
            client.run(g)
        except Exception:
            # Deliberately ignored: only the follow-up run below is under test.
            pass

        # Second workflow: must run normally on the same server and client.
        g2 = GraphBuilder(prefix="recovery_test")
        image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
        g2.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g2)
        assert result.did_run(sleep_node), "Should be able to run after error"

    def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling when sync node errors while async node is executing."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Async branch that takes a while ...
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # ... alongside a sync branch that errors.
        error_node = g.node("TestSyncError", value=image.out(0))

        # Both branches need an output node so that they are executed.
        g.node("PreviewImage", images=sleep_node.out(0))
        g.node("PreviewImage", images=error_node.out(0))

        with pytest.raises(Exception) as exc_info:
            client.run(g)

        # The failure must be reported as a structured execution error.
        assert 'prompt_id' in exc_info.value.args[0]

    # ------------------------------------------------------------------
    # Interaction with other execution features
    # ------------------------------------------------------------------

    def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with execution blockers."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Async sources.
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # List of the two async results.
        image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))

        # Parallel list [1, 2] used to derive the per-item block flag.
        int1 = g.node("StubInt", value=1)
        int2 = g.node("StubInt", value=2)
        block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))

        # The comparison is true only for the second item (2 == 2).
        compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")

        # Block the items whose flag is true.
        blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)

        output = g.node("PreviewImage", images=blocker.out(0))

        result = client.run(g)
        images = result.get_images(output)
        assert len(images) == 1, "Should have blocked second image"

    def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
        """Test that async nodes are properly cached."""
        # Warm up first so the timed run below is not skewed by first-run costs.
        run_warmup(client, prefix="warmup_cache")

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
        g.node("SaveImage", images=sleep_node.out(0))

        # First run executes the node.
        result1 = client.run(g)
        assert result1.did_run(sleep_node), "Should run first time"

        # Second run of the identical graph should be served from the cache.
        start_time = time.perf_counter()
        result2 = client.run(g)
        elapsed_time = time.perf_counter() - start_time

        assert not result2.did_run(sleep_node), "Should be cached"
        if not skip_timing_checks:
            assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"

    def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
        """Test async nodes within dynamically generated prompts."""
        # Warm up first so the timed run below is not skewed by first-run costs.
        run_warmup(client, prefix="warmup_dynamic")

        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Ask for five async nodes of 0.4s each; presumably the node expands
        # into that many sleeps at execution time (verify against the node's
        # implementation).
        dynamic_async = g.node("TestDynamicAsyncGeneration",
                               image1=image1.out(0),
                               image2=image2.out(0),
                               num_async_nodes=5,
                               sleep_duration=0.4)
        g.node("SaveImage", images=dynamic_async.out(0))

        start_time = time.perf_counter()
        result = client.run(g)
        elapsed_time = time.perf_counter() - start_time

        # 5 x 0.4s one after another would take 2.0s; under 1.0s implies overlap.
        if not skip_timing_checks:
            assert elapsed_time < 1.0, f"Dynamic async execution took {elapsed_time}s"
        assert result.did_run(dynamic_async)

    def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async resources are properly cleaned up."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Five resource-using async nodes, each with its own output node.
        resource_nodes = []
        for i in range(5):
            node = g.node("TestAsyncResourceUser",
                          value=image.out(0),
                          resource_id=f"resource_{i}",
                          duration=0.1)
            resource_nodes.append(node)
            g.node("PreviewImage", images=node.out(0))

        result = client.run(g)

        # Every node must have executed on the first run.
        for node in resource_nodes:
            assert result.did_run(node)

        # Re-running the same graph must not execute any of them again.
        result2 = client.run(g)

        for node in resource_nodes:
            assert not result2.did_run(node), "Should be cached"

    def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
        """Test cancellation of async operations."""
        # The test has no body yet; report it as skipped rather than letting an
        # empty test show up as a pass.
        pytest.skip("Async cancellation test is not implemented yet")

    def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test workflows with both sync and async nodes."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        # Chain: sync -> async -> sync -> async.
        sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
        async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
        sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
        async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)

        output = g.node("SaveImage", images=async_op2.out(0))

        result = client.run(g)

        # Every stage of the chain must have executed.
        assert result.did_run(sync_op1)
        assert result.did_run(async_op1)
        assert result.did_run(sync_op2)
        assert result.did_run(async_op2)

        # Expected mean is 63.75; presumably the 50/50 black/white mix (127.5)
        # is then scaled by input2=0.5 (verify against TestCustomValidation1).
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"
|
|