|
|
""" |
|
|
E2E tests for Queue-specific Preview Method Override feature. |
|
|
|
|
|
Tests actual execution with different preview_method values. |
|
|
Requires a running ComfyUI server with models. |
|
|
|
|
|
Usage: |
|
|
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method |
|
|
|
|
|
Note: |
|
|
These tests execute actual image generation and wait for completion. |
|
|
Tests verify preview image transmission based on preview_method setting. |
|
|
""" |
|
|
import os |
|
|
import json |
|
|
import pytest |
|
|
import uuid |
|
|
import time |
|
|
import random |
|
|
import websocket |
|
|
import urllib.request |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
# Base URL of the ComfyUI server under test. A trailing slash is stripped so
# that f"{SERVER_URL}/..." (and the ws:// URL built from SERVER_HOST) never
# contains a double slash.
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988").rstrip("/")

# SERVER_URL without its leading "<scheme>://" prefix; the client prepends
# "http://" or "ws://" itself. Only the leading scheme is removed.
SERVER_HOST = SERVER_URL.split("://", 1)[-1]

# Workflow graph queued as the `prompt` payload by every test in this module.
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
|
|
|
|
|
|
|
|
def is_server_running() -> bool:
    """Probe the ComfyUI HTTP API.

    Returns:
        True if ``GET {SERVER_URL}/system_stats`` succeeds within 2 seconds,
        False if building or sending the request raises anything.
    """
    try:
        probe = urllib.request.Request(f"{SERVER_URL}/system_stats")
        with urllib.request.urlopen(probe, timeout=2.0):
            pass
    except Exception:
        # Best-effort availability check: any failure means "not running".
        return False
    return True
|
|
|
|
|
|
|
|
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
    """Return a copy of *graph* adjusted for a short test run.

    For every node whose ``inputs`` dict contains ``seed`` or ``noise_seed``,
    that value is replaced with a fresh random integer in ``[0, 2**32 - 1]``
    (presumably so repeated runs are not served from the server's cache).
    Every ``steps`` input is set to *steps*. Entries that are not dicts, or
    whose ``inputs`` is missing or not a dict, are copied unchanged. The
    argument itself is never modified.

    Args:
        graph: Mapping of node id -> node dict.
        steps: Value written into every ``steps`` input.

    Returns:
        The adapted copy.
    """
    # JSON round-trip produces an independent deep copy of the graph.
    adapted = json.loads(json.dumps(graph))

    for node in adapted.values():
        if not isinstance(node, dict):
            continue
        inputs = node.get("inputs")
        if not isinstance(inputs, dict):
            continue

        for seed_key in ("seed", "noise_seed"):
            if seed_key in inputs:
                inputs[seed_key] = random.randint(0, 2**32 - 1)

        if "steps" in inputs:
            inputs["steps"] = steps

    return adapted


# Alias used by load_graph() and the comparison/sequential tests. Note that it
# also overrides ``steps`` (default 5), not only the seeds.
randomize_seed = prepare_graph_for_test
|
|
|
|
|
|
|
|
class PreviewMethodClient:
    """Client for testing preview_method with WebSocket execution tracking.

    The same ``client_id`` is used for the WebSocket session and for every
    prompt queued through :meth:`queue_prompt`, so the server can associate
    the prompt's events with this connection.
    """

    def __init__(self, server_address: str):
        # host[:port] of the server, without scheme.
        self.server_address = server_address
        self.client_id = str(uuid.uuid4())
        # websocket.WebSocket after connect(); None before connect()/after close().
        self.ws = None

    def connect(self):
        """Open the WebSocket session for this client id (120 s socket timeout)."""
        self.ws = websocket.WebSocket()
        self.ws.settimeout(120)
        self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")

    def close(self):
        """Close the WebSocket connection if it is open; safe to call repeatedly."""
        if self.ws is not None:
            self.ws.close()
            self.ws = None

    def queue_prompt(self, prompt: dict, extra_data: "dict | None" = None) -> dict:
        """POST *prompt* to ``/prompt`` and return the decoded JSON reply.

        Args:
            prompt: Graph to execute, sent as the ``prompt`` field.
            extra_data: Per-prompt options sent as ``extra_data``
                (e.g. ``{"preview_method": "taesd"}``); ``None`` sends ``{}``.

        Returns:
            The server's JSON reply; the tests expect it to contain ``prompt_id``.

        Raises:
            urllib.error.HTTPError: if the server answers with an error status.
        """
        payload = {
            "prompt": prompt,
            "client_id": self.client_id,
            "extra_data": extra_data or {},
        }
        req = urllib.request.Request(
            f"http://{self.server_address}/prompt",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
        )
        # Context manager so the HTTP response is always closed.
        with urllib.request.urlopen(req) as resp:
            return json.loads(resp.read())

    def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
        """
        Wait for execution to complete via WebSocket.

        *timeout* bounds the total wait, not just a single ``recv()``. Text
        messages whose ``data.prompt_id`` differs from *prompt_id* are ignored.
        Every binary frame received while waiting is counted as a preview.

        Returns:
            dict with keys:
              completed: True once an ``executing`` message with ``node=None``
                  arrives for this prompt.
              error: the ``data`` of an ``execution_error`` or
                  ``execution_interrupted`` message, the string
                  "Timeout waiting for execution", or None.
              preview_count: number of binary frames received.
              execution_time: seconds from this call to the terminal event.
        """
        result = {
            "completed": False,
            "error": None,
            "preview_count": 0,
            "execution_time": 0.0
        }

        start_time = time.time()
        deadline = start_time + timeout

        try:
            while True:
                # Re-arm the socket timeout with the time left, so unrelated
                # traffic cannot stretch the wait beyond `timeout`.
                remaining = deadline - time.time()
                if remaining <= 0:
                    raise websocket.WebSocketTimeoutException("deadline exceeded")
                self.ws.settimeout(remaining)

                out = self.ws.recv()
                elapsed = time.time() - start_time

                if isinstance(out, bytes):
                    # NOTE(review): all binary frames are counted as previews;
                    # confirm the server sends no other binary event types here.
                    result["preview_count"] += 1
                    continue
                if not isinstance(out, str):
                    continue

                message = json.loads(out)
                msg_type = message.get("type")
                data = message.get("data") or {}

                if data.get("prompt_id") != prompt_id:
                    continue

                if msg_type == "executing" and data.get("node") is None:
                    # `executing` with no node marks the end of this prompt.
                    result["completed"] = True
                    result["execution_time"] = elapsed
                    break

                if msg_type in ("execution_error", "execution_interrupted"):
                    result["error"] = data
                    result["execution_time"] = elapsed
                    break

        except websocket.WebSocketTimeoutException:
            result["error"] = "Timeout waiting for execution"
            result["execution_time"] = time.time() - start_time

        return result
|
|
|
|
|
|
|
|
def load_graph() -> dict:
    """Load the SDXL graph fixture, prepared for a quick test run.

    The JSON file at ``GRAPH_FILE`` is parsed as UTF-8 and passed through
    ``randomize_seed``. That call randomizes seeds and also sets every ``steps``
    input to its default value (5).
    """
    with open(GRAPH_FILE, encoding="utf-8") as f:
        graph = json.load(f)
    return randomize_seed(graph)
|
|
|
|
|
|
|
|
|
|
|
# Skip the whole module when the server does not answer. The probe runs once,
# when this module is imported (i.e. at collection time).
_SKIP_WITHOUT_SERVER = pytest.mark.skipif(
    not is_server_running(),
    reason=f"ComfyUI server not running at {SERVER_URL}",
)

# Module-wide marks; the custom marks allow selection via `-m preview_method`.
pytestmark = [_SKIP_WITHOUT_SERVER, pytest.mark.preview_method, pytest.mark.execution]
|
|
|
|
|
|
|
|
@pytest.fixture
def client():
    """Yield a PreviewMethodClient already connected to SERVER_HOST.

    The WebSocket connection is closed during fixture teardown.
    """
    ws_client = PreviewMethodClient(SERVER_HOST)
    ws_client.connect()
    try:
        yield ws_client
    finally:
        ws_client.close()
|
|
|
|
|
|
|
|
@pytest.fixture
def graph():
    """Provide a freshly loaded copy of the test graph (via load_graph())."""
    fresh_graph = load_graph()
    return fresh_graph
|
|
|
|
|
|
|
|
class TestPreviewMethodExecution:
    """Test actual execution with different preview methods.

    Each test queues the fixture graph with one ``preview_method`` override
    (or none) and requires that the execution completes.
    """

    @staticmethod
    def _execute(client, graph, extra_data):
        """Queue *graph* with *extra_data*, wait for it, and return the wait result."""
        response = client.queue_prompt(graph, extra_data)
        assert "prompt_id" in response, f"Prompt was not queued: {response}"
        return client.wait_for_execution(response["prompt_id"])

    @staticmethod
    def _assert_completed(result, label):
        """Fail unless the execution described by *result* completed."""
        # wait_for_execution always sets either `completed` or `error`, so a
        # `completed or error is not None` check can never fail. Require real
        # completion, matching TestPreviewMethodComparison.
        assert result["completed"], f"'{label}' execution failed: {result['error']}"

    def test_execution_with_latent2rgb(self, client, graph):
        """
        Execute with preview_method=latent2rgb.
        Should complete and potentially receive preview images.
        """
        result = self._execute(client, graph, {"preview_method": "latent2rgb"})

        self._assert_completed(result, "latent2rgb")
        assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"

        print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s")

    def test_execution_with_taesd(self, client, graph):
        """
        Execute with preview_method=taesd.
        TAESD provides higher quality previews.
        """
        result = self._execute(client, graph, {"preview_method": "taesd"})

        self._assert_completed(result, "taesd")
        assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"

        print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s")

    def test_execution_with_none_preview(self, client, graph):
        """
        Execute with preview_method=none.
        No preview images should be generated.
        """
        result = self._execute(client, graph, {"preview_method": "none"})

        self._assert_completed(result, "none")
        assert result["preview_count"] == 0, \
            f"Expected no previews with 'none', got {result['preview_count']}"

        print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s")

    def test_execution_with_default(self, client, graph):
        """
        Execute with preview_method=default.
        Should use server's CLI default setting.
        """
        result = self._execute(client, graph, {"preview_method": "default"})

        self._assert_completed(result, "default")

        print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s")

    def test_execution_without_preview_method(self, client, graph):
        """
        Execute without preview_method in extra_data.
        Should use server's default preview method.
        """
        result = self._execute(client, graph, {})

        self._assert_completed(result, "(no override)")

        print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s")
|
|
|
|
|
|
|
|
class TestPreviewMethodComparison:
    """Contrast the preview output of the 'none' and 'latent2rgb' methods."""

    def test_none_vs_latent2rgb_preview_count(self, client, graph):
        """
        Key check that the override works: 'none' must stream zero preview
        frames, while 'latent2rgb' must stream at least one.
        """
        results = {}

        # Run the two methods back to back, each on a re-randomized copy.
        for method in ("none", "latent2rgb"):
            queued = client.queue_prompt(randomize_seed(graph), {"preview_method": method})
            results[method] = client.wait_for_execution(queued["prompt_id"])

        none_run = results["none"]
        rgb_run = results["latent2rgb"]

        # Both runs must have finished successfully ...
        assert none_run["completed"], f"'none' execution failed: {none_run['error']}"
        assert rgb_run["completed"], f"'latent2rgb' execution failed: {rgb_run['error']}"

        # ... 'none' must suppress previews entirely ...
        assert none_run["preview_count"] == 0, \
            f"'none' should have 0 previews, got {none_run['preview_count']}"

        # ... while latent2rgb must actually send some.
        assert rgb_run["preview_count"] > 0, \
            f"'latent2rgb' should have >0 previews, got {rgb_run['preview_count']}"

        print("\nPreview count comparison:")
        print(f" none: {none_run['preview_count']} previews")
        print(f" latent2rgb: {rgb_run['preview_count']} previews")
|
|
|
|
|
|
|
|
class TestPreviewMethodSequential:
    """Test sequential execution with different preview methods."""

    def test_sequential_different_methods(self, client, graph):
        """
        Execute multiple prompts sequentially with different preview methods.
        Each should complete independently with correct preview behavior.
        """
        methods = ["latent2rgb", "none", "default"]
        results = []

        for method in methods:
            # Re-randomize seeds for every run so each prompt differs.
            graph_run = randomize_seed(graph)
            response = client.queue_prompt(graph_run, {"preview_method": method})
            assert "prompt_id" in response, f"Method {method} was not queued: {response}"

            result = client.wait_for_execution(response["prompt_id"])
            results.append({"method": method, **result})

        # Report every run before asserting, so a failure still shows all results.
        print("\nSequential execution results:")
        for r in results:
            status = "✓" if r["completed"] else f"✗ ({r['error']})"
            print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s")

        # wait_for_execution always sets either `completed` or `error`, so only
        # `completed` is a meaningful success check.
        for r in results:
            assert r["completed"], f"Method {r['method']} failed: {r['error']}"

        none_result = next(r for r in results if r["method"] == "none")
        assert none_result["preview_count"] == 0, \
            f"'none' should have 0 previews, got {none_result['preview_count']}"
|
|
|