| """Test that progress updates are properly isolated between WebSocket clients.""" |
|
|
| import json |
| import pytest |
| import time |
| import threading |
| import uuid |
| import websocket |
| from typing import List, Dict, Any |
| from comfy_execution.graph_utils import GraphBuilder |
| from tests.execution.test_execution import ComfyClient |
|
|
|
|
class ProgressTracker:
    """Collects the progress messages received by one WebSocket client.

    Every read or write of ``progress_messages`` happens while holding
    ``lock``, so one thread can append messages while another inspects them.
    """

    def __init__(self, client_id: str):
        self.lock = threading.Lock()
        self.client_id = client_id
        self.progress_messages: List[Dict[str, Any]] = []

    def add_message(self, message: Dict[str, Any]):
        """Append ``message`` to the stored messages while holding the lock."""
        with self.lock:
            self.progress_messages.append(message)

    def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]:
        """Return the stored messages whose ``data.prompt_id`` equals ``prompt_id``.

        Messages without a ``data`` entry, or without a ``prompt_id`` inside
        it, are treated as having a prompt id of ``None``.
        """
        matching: List[Dict[str, Any]] = []
        with self.lock:
            for entry in self.progress_messages:
                payload = entry.get('data', {})
                if payload.get('prompt_id') == prompt_id:
                    matching.append(entry)
        return matching

    def has_cross_contamination(self, own_prompt_id: str) -> bool:
        """Return True if any stored message names a prompt other than ``own_prompt_id``.

        Messages with a missing or falsy prompt id are ignored.
        """
        with self.lock:
            seen_ids = (
                entry.get('data', {}).get('prompt_id')
                for entry in self.progress_messages
            )
            # Consume the generator inside the lock so the list cannot change
            # underneath the scan.
            return any(pid and pid != own_prompt_id for pid in seen_ids)
|
|
|
|
class IsolatedClient(ComfyClient):
    """ComfyClient variant that records the JSON WebSocket messages it receives.

    ``all_messages`` keeps every decoded text frame; ``progress_tracker``
    keeps only the frames whose ``type`` is ``'progress_state'``.
    """

    def __init__(self):
        super().__init__()
        # Created in connect(), once the client_id is known.
        self.progress_tracker = None
        self.all_messages: List[Dict[str, Any]] = []

    def connect(self, listen='127.0.0.1', port=8188, client_id=None):
        """Connect under ``client_id`` and attach a fresh ProgressTracker.

        When ``client_id`` is None a random UUID4 string is used instead.
        """
        if client_id is None:
            client_id = str(uuid.uuid4())
        super().connect(listen, port, client_id)
        self.progress_tracker = ProgressTracker(client_id)

    def listen_for_messages(self, duration: float = 5.0):
        """Receive and record WebSocket messages for up to ``duration`` seconds.

        Text frames are parsed as JSON and appended to ``all_messages``;
        those of type ``'progress_state'`` are also handed to
        ``progress_tracker``. Non-text frames are skipped.

        Listening is best-effort: any error other than a receive timeout
        stops the loop early. The error is printed so that a later
        "no progress updates" assertion failure can be traced to its cause.

        Note: ``self.ws`` is presumably set up by ``ComfyClient.connect`` -
        confirm against the base class. Its timeout stays at 0.5s afterwards.
        """
        # Monotonic clock: the deadline must not move if the system
        # wall-clock is adjusted while the test is running.
        deadline = time.monotonic() + duration
        # Short receive timeout so the deadline is re-checked regularly.
        self.ws.settimeout(0.5)

        while time.monotonic() < deadline:
            try:
                out = self.ws.recv()
                if isinstance(out, str):
                    message = json.loads(out)
                    self.all_messages.append(message)

                    if message.get('type') == 'progress_state':
                        self.progress_tracker.add_message(message)
            except websocket.WebSocketTimeoutException:
                # Nothing arrived within the timeout; keep waiting.
                continue
            except Exception as exc:
                # Connection closed or another failure: stop listening, but
                # report why instead of swallowing the error silently.
                print(f"IsolatedClient stopped listening early: {exc!r}")
                break
|
|
|
|
@pytest.mark.execution
class TestProgressIsolation:
    """Test suite for verifying progress update isolation between clients."""

    @pytest.fixture(scope="class", autouse=True)
    def _server(self, args_pytest):
        """Run a ComfyUI server subprocess for the lifetime of this test class.

        The server is launched with the output directory, listen address and
        port taken from ``args_pytest``, and is killed during teardown.
        """
        import subprocess
        pargs = [
            'python', 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
            '--cpu',
        ]
        p = subprocess.Popen(pargs)
        yield
        p.kill()
        # Reap the killed child so it is not left behind as a zombie process.
        p.wait()

    def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
        """Create an IsolatedClient and connect it, retrying on refusal.

        Makes up to five attempts, sleeping 4 seconds before each one.

        Raises:
            ConnectionRefusedError: if every attempt was refused.
        """
        client = IsolatedClient()

        n_tries = 5
        for i in range(n_tries):
            # Sleep before every attempt (including the first).
            time.sleep(4)
            try:
                client.connect(listen, port, client_id)
                return client
            except ConnectionRefusedError as e:
                print(e)
                print(f"({i+1}/{n_tries}) Retrying...")
        raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts")

    @staticmethod
    def _close_client(client):
        """Close ``client``'s WebSocket if the client exists and has one."""
        if client is not None and hasattr(client, 'ws'):
            client.ws.close()

    def test_progress_isolation_between_clients(self, args_pytest):
        """Test that progress updates are isolated between different clients."""
        listen = args_pytest["listen"]
        port = args_pytest["port"]

        # Distinct, recognisable client ids for the two connections.
        client_a_id = "client_a_" + str(uuid.uuid4())
        client_b_id = "client_b_" + str(uuid.uuid4())

        # Bind both names before the try block so the cleanup in ``finally``
        # cannot raise UnboundLocalError (and mask the real error) when a
        # connection attempt fails.
        client_a = None
        client_b = None

        try:
            client_a = self.start_client_with_retry(listen, port, client_a_id)
            client_b = self.start_client_with_retry(listen, port, client_b_id)

            # One small StubImage -> PreviewImage workflow per client.
            graph_a = GraphBuilder(prefix="client_a")
            image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
            graph_a.node("PreviewImage", images=image_a.out(0))

            graph_b = GraphBuilder(prefix="client_b")
            image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
            graph_b.node("PreviewImage", images=image_b.out(0))

            # Queue both workflows.
            prompt_a = graph_a.finalize()
            prompt_b = graph_b.finalize()

            response_a = client_a.queue_prompt(prompt_a)
            prompt_id_a = response_a['prompt_id']

            response_b = client_b.queue_prompt(prompt_b)
            prompt_id_b = response_b['prompt_id']

            # Listen on both sockets concurrently, each in its own thread.
            thread_a = threading.Thread(
                target=client_a.listen_for_messages, kwargs={"duration": 10.0}
            )
            thread_b = threading.Thread(
                target=client_b.listen_for_messages, kwargs={"duration": 10.0}
            )

            thread_a.start()
            thread_b.start()

            # Wait for both listeners to finish.
            thread_a.join()
            thread_b.join()

            # Neither client may have seen progress for a foreign prompt.
            assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \
                f"Client A received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_a}, but got messages for multiple prompts."

            assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \
                f"Client B received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_b}, but got messages for multiple prompts."

            # Each client must have seen progress for its own prompt.
            client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a)
            client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b)

            assert len(client_a_messages) > 0, \
                "Client A did not receive any progress updates for its own workflow"
            assert len(client_b_messages) > 0, \
                "Client B did not receive any progress updates for its own workflow"

            # Explicitly check the other client's prompt id as well.
            client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b)
            client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a)

            assert len(client_a_other) == 0, \
                f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow"
            assert len(client_b_other) == 0, \
                f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow"

        finally:
            # Close B even if closing A raises.
            try:
                self._close_client(client_a)
            finally:
                self._close_client(client_b)

    def test_progress_with_missing_client_id(self, args_pytest):
        """Test that progress updates handle missing client_id gracefully."""
        listen = args_pytest["listen"]
        port = args_pytest["port"]

        # Bound up front so ``finally`` is safe if the connection fails.
        client = None

        try:
            # No client_id is passed; IsolatedClient.connect generates one.
            client = self.start_client_with_retry(listen, port)

            # Small StubImage -> PreviewImage workflow.
            graph = GraphBuilder(prefix="test_missing_id")
            image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1)
            graph.node("PreviewImage", images=image.out(0))

            # Queue the workflow.
            prompt = graph.finalize()
            response = client.queue_prompt(prompt)
            prompt_id = response['prompt_id']

            # Collect messages on the calling thread.
            client.listen_for_messages(duration=5.0)

            # The client must have seen progress for the prompt it queued.
            messages = client.progress_tracker.get_messages_for_prompt(prompt_id)
            assert len(messages) > 0, \
                "Client did not receive progress updates even though it initiated the workflow"

        finally:
            self._close_client(client)
|
|
|
|