| """ | |
| E2E tests for Queue-specific Preview Method Override feature. | |
| Tests actual execution with different preview_method values. | |
| Requires a running ComfyUI server with models. | |
| Usage: | |
| COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method | |
| Note: | |
| These tests execute actual image generation and wait for completion. | |
| Tests verify preview image transmission based on preview_method setting. | |
| """ | |
import os
import json
import pytest
import uuid
import time
import random
import websocket
import urllib.request
from pathlib import Path

# Server configuration
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")

# Use existing inference graph fixture
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"

def is_server_running() -> bool:
    """Check if the ComfyUI server is running."""
    try:
        request = urllib.request.Request(f"{SERVER_URL}/system_stats")
        with urllib.request.urlopen(request, timeout=2.0):
            return True
    except Exception:
        return False

def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
    """Prepare a graph for testing: randomize seeds and reduce steps."""
    adapted = json.loads(json.dumps(graph))  # Deep copy
    for node_id, node in adapted.items():
        inputs = node.get("inputs", {})
        # Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
        if "seed" in inputs:
            inputs["seed"] = random.randint(0, 2**32 - 1)
        if "noise_seed" in inputs:
            inputs["noise_seed"] = random.randint(0, 2**32 - 1)
        # Reduce steps for faster testing (default 20 -> 5)
        if "steps" in inputs:
            inputs["steps"] = steps
    return adapted


# Alias for backward compatibility
randomize_seed = prepare_graph_for_test
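# Shape of an API-format node entry this helper mutates (illustrative sketch
# based on the SDXL fixture, not an exhaustive schema):
#   {"3": {"class_type": "KSampler", "inputs": {"seed": 1234, "steps": 20, ...}}}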

class PreviewMethodClient:
    """Client for testing preview_method with WebSocket execution tracking."""

    def __init__(self, server_address: str):
        self.server_address = server_address
        self.client_id = str(uuid.uuid4())
        self.ws = None

    def connect(self):
        """Connect to the WebSocket."""
        self.ws = websocket.WebSocket()
        self.ws.settimeout(120)  # 2-minute timeout for sampling
        self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")

    def close(self):
        """Close the WebSocket connection."""
        if self.ws:
            self.ws.close()
    def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
        """Queue a prompt and return the response containing prompt_id."""
        data = {
            "prompt": prompt,
            "client_id": self.client_id,
            "extra_data": extra_data or {}
        }
        req = urllib.request.Request(
            f"http://{self.server_address}/prompt",
            data=json.dumps(data).encode("utf-8"),
            headers={"Content-Type": "application/json"}
        )
        return json.loads(urllib.request.urlopen(req).read())
    def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
        """
        Wait for execution to complete via WebSocket.

        Returns:
            dict with keys: completed, error, preview_count, execution_time
        """
        result = {
            "completed": False,
            "error": None,
            "preview_count": 0,
            "execution_time": 0.0
        }
        start_time = time.time()
        self.ws.settimeout(timeout)
        try:
            while True:
                out = self.ws.recv()
                elapsed = time.time() - start_time
                if isinstance(out, str):
                    message = json.loads(out)
                    msg_type = message.get("type")
                    data = message.get("data", {})
                    if data.get("prompt_id") != prompt_id:
                        continue
                    if msg_type == "executing":
                        if data.get("node") is None:
                            # Execution complete
                            result["completed"] = True
                            result["execution_time"] = elapsed
                            break
                    elif msg_type == "execution_error":
                        result["error"] = data
                        result["execution_time"] = elapsed
                        break
                    elif msg_type == "progress":
                        # Progress update during sampling
                        pass
                elif isinstance(out, bytes):
                    # Binary data = preview image
                    result["preview_count"] += 1
        except websocket.WebSocketTimeoutException:
            result["error"] = "Timeout waiting for execution"
            result["execution_time"] = time.time() - start_time
        return result

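# Minimal standalone usage sketch (assumes a server at localhost:8988; the
# pytest fixtures below do the equivalent automatically):
#   client = PreviewMethodClient("localhost:8988")
#   client.connect()
#   resp = client.queue_prompt(load_graph(), {"preview_method": "none"})
#   result = client.wait_for_execution(resp["prompt_id"])
#   client.close()
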
def load_graph() -> dict:
    """Load the SDXL graph fixture with a randomized seed."""
    with open(GRAPH_FILE) as f:
        graph = json.load(f)
    return randomize_seed(graph)  # Avoid caching

# Skip all tests if the server is not running
pytestmark = [
    pytest.mark.skipif(
        not is_server_running(),
        reason=f"ComfyUI server not running at {SERVER_URL}"
    ),
    pytest.mark.preview_method,
    pytest.mark.execution,
]

@pytest.fixture
def client():
    """Create and connect a test client."""
    c = PreviewMethodClient(SERVER_HOST)
    c.connect()
    yield c
    c.close()


@pytest.fixture
def graph():
    """Load the test graph."""
    return load_graph()

class TestPreviewMethodExecution:
    """Test actual execution with different preview methods."""

    def test_execution_with_latent2rgb(self, client, graph):
        """
        Execute with preview_method=latent2rgb.
        Should complete and potentially receive preview images.
        """
        extra_data = {"preview_method": "latent2rgb"}
        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        # Should complete (may error if a model is missing, but that's a separate issue)
        assert result["completed"] or result["error"] is not None

        # Execution should take some time (sampling)
        if result["completed"]:
            assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
            # latent2rgb should produce previews
            print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201
    def test_execution_with_taesd(self, client, graph):
        """
        Execute with preview_method=taesd.
        TAESD provides higher-quality previews.
        """
        extra_data = {"preview_method": "taesd"}
        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        assert result["completed"] or result["error"] is not None
        if result["completed"]:
            assert result["execution_time"] > 0.5
            # taesd should also produce previews
            print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201
    def test_execution_with_none_preview(self, client, graph):
        """
        Execute with preview_method=none.
        No preview images should be generated.
        """
        extra_data = {"preview_method": "none"}
        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        assert result["completed"] or result["error"] is not None
        if result["completed"]:
            # With "none", should receive no preview images
            assert result["preview_count"] == 0, \
                f"Expected no previews with 'none', got {result['preview_count']}"
            print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201
    def test_execution_with_default(self, client, graph):
        """
        Execute with preview_method=default.
        Should use the server's CLI default setting.
        """
        extra_data = {"preview_method": "default"}
        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        assert result["completed"] or result["error"] is not None
        if result["completed"]:
            print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201
    def test_execution_without_preview_method(self, client, graph):
        """
        Execute without preview_method in extra_data.
        Should use the server's default preview method.
        """
        extra_data = {}  # No preview_method
        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response

        result = client.wait_for_execution(response["prompt_id"])

        assert result["completed"] or result["error"] is not None
        if result["completed"]:
            print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s")  # noqa: T201

class TestPreviewMethodComparison:
    """Compare preview behavior between different methods."""

    def test_none_vs_latent2rgb_preview_count(self, client, graph):
        """
        Compare preview counts: 'none' should have 0, others should have >0.
        This is the key verification that preview_method actually works.
        """
        results = {}

        # Run with none (randomize seed to avoid caching)
        graph_none = randomize_seed(graph)
        extra_data_none = {"preview_method": "none"}
        response = client.queue_prompt(graph_none, extra_data_none)
        results["none"] = client.wait_for_execution(response["prompt_id"])

        # Run with latent2rgb (randomize seed again)
        graph_rgb = randomize_seed(graph)
        extra_data_rgb = {"preview_method": "latent2rgb"}
        response = client.queue_prompt(graph_rgb, extra_data_rgb)
        results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])

        # Verify both completed
        assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
        assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"

        # Key assertion: 'none' should have 0 previews
        assert results["none"]["preview_count"] == 0, \
            f"'none' should have 0 previews, got {results['none']['preview_count']}"

        # 'latent2rgb' should have at least 1 preview (depends on steps)
        assert results["latent2rgb"]["preview_count"] > 0, \
            f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"

        print("\nPreview count comparison:")  # noqa: T201
        print(f"  none: {results['none']['preview_count']} previews")  # noqa: T201
        print(f"  latent2rgb: {results['latent2rgb']['preview_count']} previews")  # noqa: T201

class TestPreviewMethodSequential:
    """Test sequential execution with different preview methods."""

    def test_sequential_different_methods(self, client, graph):
        """
        Execute multiple prompts sequentially with different preview methods.
        Each should complete independently with the correct preview behavior.
        """
        methods = ["latent2rgb", "none", "default"]
        results = []

        for method in methods:
            # Randomize seed for each execution to avoid caching
            graph_run = randomize_seed(graph)
            extra_data = {"preview_method": method}
            response = client.queue_prompt(graph_run, extra_data)
            result = client.wait_for_execution(response["prompt_id"])
            results.append({
                "method": method,
                "completed": result["completed"],
                "preview_count": result["preview_count"],
                "execution_time": result["execution_time"],
                "error": result["error"]
            })

        # All should complete or report a clear error
        for r in results:
            assert r["completed"] or r["error"] is not None, \
                f"Method {r['method']} neither completed nor errored"

        # "none" should have zero previews if completed
        none_result = next(r for r in results if r["method"] == "none")
        if none_result["completed"]:
            assert none_result["preview_count"] == 0, \
                f"'none' should have 0 previews, got {none_result['preview_count']}"

        print("\nSequential execution results:")  # noqa: T201
        for r in results:
            status = "✓" if r["completed"] else f"✗ ({r['error']})"
            print(f"  {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s")  # noqa: T201