# VoiceGate/scripts/validate_workflow_connections.py
# Commit 359a8b1 (YanTianlong): "Add full VoiceGate workflow test" - 4.87 kB
"""Validate the API workflow connections against the ComfyUI UI workflow.
ComfyUI API exports collapse some UI-only SetNode/GetNode routing pairs. This
validator resolves those pairs before comparing links so the check reflects the
actual graph semantics instead of the visual routing helpers.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
# Parent of the directory that contains this script (after resolving the
# script's own path to an absolute one).
ROOT = Path(__file__).resolve().parents[1]

# Default API-format workflow file; used when --api is not given.
DEFAULT_API = ROOT.joinpath("workflows", "voicegate_api.json")

# Default UI-format workflow file; used when --ui is not given.
DEFAULT_UI = ROOT.joinpath("VoiceGate", "workflows", "VoiceGate-Workflow.json")
def load_json(path: Path) -> Any:
    """Return the JSON document stored at *path*.

    The file is read as UTF-8 text and decoded with the standard ``json``
    module, so the result is whatever Python value the document encodes.

    Args:
        path: Location of the JSON file to read.

    Returns:
        The decoded JSON value.
    """
    text = path.read_text(encoding="utf-8")
    return json.loads(text)
def is_connection(value: Any) -> bool:
    """Tell whether an API input value has the shape of a node connection.

    A connection is a two-item ``list`` whose first item is a ``str`` and
    whose second item is an ``int``. Anything else (tuples, longer lists,
    scalars, ...) is treated as a plain value.

    Args:
        value: The raw value found under a node's ``inputs`` mapping.

    Returns:
        ``True`` when *value* matches the connection shape, else ``False``.
    """
    if not isinstance(value, list) or len(value) != 2:
        return False
    source_id, output_slot = value
    # ``bool`` is a subclass of ``int``, so a boolean slot passes this check.
    return isinstance(source_id, str) and isinstance(output_slot, int)
def build_expected_connections(ui_workflow: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Derive the connections an API export should contain from a UI workflow.

    Every node input that carries a known link yields one entry keyed by
    ``"<node id>.<input name>"``. When the link originates from a ``GetNode``
    the source is replaced by whatever feeds the ``SetNode`` registered under
    the same name; chains of Get/Set pairs are followed until a node that is
    not a resolvable ``GetNode`` is reached.

    Args:
        ui_workflow: Parsed UI workflow. It must provide ``"nodes"`` (dicts
            with an ``"id"`` and optionally ``"type"``, ``"inputs"`` and
            ``"widgets_values"``) and ``"links"`` (sequences laid out as
            ``[link id, source node, source slot, target node, target slot,
            type]``).

    Returns:
        Mapping from ``"<node id>.<input name>"`` to a dict with ``"from"``
        (source node id as ``str``), ``"slot"`` and ``"type"``. Entries that
        were routed through a Get/Set pair additionally carry ``"via_set"``,
        ``"via_get"`` and ``"name"``, describing the pair closest to the
        consuming input.

    Raises:
        KeyError: If ``"nodes"`` or ``"links"`` is missing, a node has no
            ``"id"``, or a linked input has no ``"name"``.
        IndexError: If a link has fewer than six items.
    """
    nodes = ui_workflow["nodes"]
    links = ui_workflow["links"]

    # Node ids are normalised to strings on both sides of every lookup.
    nodes_by_id = {str(node["id"]): node for node in nodes}
    links_by_id = {
        link[0]: {
            "from": str(link[1]),
            "slot": link[2],
            "to": str(link[3]),
            "to_slot": link[4],
            "type": link[5],
        }
        for link in links
    }

    # Variable name -> the output that feeds the SetNode publishing that name.
    set_by_name: dict[str, dict[str, Any]] = {}
    for node in nodes:
        if node.get("type") != "SetNode":
            continue
        values = node.get("widgets_values") or []
        inputs = node.get("inputs") or []
        if not values or not inputs:
            continue
        link = links_by_id.get(inputs[0].get("link"))
        if link:
            set_by_name[values[0]] = {
                "from": link["from"],
                "slot": link["slot"],
                "via_set": str(node["id"]),
            }

    def resolve_source(from_node: str, slot: int) -> dict[str, Any]:
        """Follow GetNode -> SetNode hops starting at ``from_node``."""
        resolved: dict[str, Any] = {"from": from_node, "slot": slot}
        # Remember visited sources so a Get/Set loop cannot spin forever.
        seen: set[str] = set()
        while resolved["from"] not in seen:
            seen.add(resolved["from"])
            node = nodes_by_id.get(resolved["from"])
            if not node or node.get("type") != "GetNode":
                break
            values = node.get("widgets_values") or []
            name = values[0] if values else None
            if name not in set_by_name:
                break
            hop = set_by_name[name]
            if "via_get" not in resolved:
                # First hop: build a fresh dict (never mutate set_by_name) and
                # record the routing pair nearest to the consumer.
                resolved = {**hop, "via_get": resolved["from"], "name": name}
            else:
                # Later hops only move the real source further upstream.
                resolved["from"] = hop["from"]
                resolved["slot"] = hop["slot"]
        return resolved

    expected: dict[str, dict[str, Any]] = {}
    for node in nodes:
        for input_def in node.get("inputs") or []:
            link_id = input_def.get("link")
            if link_id is None:
                continue
            link = links_by_id.get(link_id)
            if not link:
                # Input references a link id absent from the links table.
                continue
            key = f"{node['id']}.{input_def['name']}"
            expected[key] = {
                **resolve_source(link["from"], link["slot"]),
                "type": link["type"],
            }
    return expected
def validate(api_workflow: dict[str, Any], ui_workflow: dict[str, Any]) -> list[dict[str, Any]]:
    """Compare every connection in the API workflow with the UI workflow.

    Only inputs of ``api_workflow`` whose value looks like a connection are
    examined; connections that exist solely in the UI workflow are not
    reported.

    Args:
        api_workflow: Mapping of node id to API node dict (``"inputs"`` and
            ``"class_type"`` are read).
        ui_workflow: Parsed UI workflow used to compute expected sources.

    Returns:
        One dict per disagreeing connection with the keys ``"node"``,
        ``"class_type"``, ``"input"``, ``"actual"`` and ``"expected"``
        (``None`` when the UI workflow has no such connection). Empty when
        everything matches.
    """
    expected_by_key = build_expected_connections(ui_workflow)
    problems: list[dict[str, Any]] = []

    for node_id, api_node in api_workflow.items():
        node_inputs = api_node.get("inputs") or {}
        for input_name, raw_value in node_inputs.items():
            if not is_connection(raw_value):
                continue

            source_id, source_slot = raw_value
            wanted = expected_by_key.get(f"{node_id}.{input_name}")
            if wanted and wanted["from"] == source_id and wanted["slot"] == source_slot:
                continue

            problems.append(
                {
                    "node": node_id,
                    "class_type": api_node.get("class_type"),
                    "input": input_name,
                    "actual": {"from": source_id, "slot": source_slot},
                    "expected": wanted,
                }
            )

    return problems
def parse_args() -> argparse.Namespace:
    """Parse the command line.

    Returns:
        Namespace with ``api`` and ``ui`` attributes, both ``Path`` objects,
        defaulting to ``DEFAULT_API`` and ``DEFAULT_UI`` respectively.
    """
    cli = argparse.ArgumentParser()
    for flag, fallback in (("--api", DEFAULT_API), ("--ui", DEFAULT_UI)):
        cli.add_argument(flag, type=Path, default=fallback)
    return cli.parse_args()
def main() -> None:
    """Run the validation from the command line and report the outcome.

    Prints the number of API connections checked and the number of
    mismatches. When there are mismatches they are printed as indented JSON
    and the process exits with status 1.

    Raises:
        SystemExit: With code 1 if at least one mismatch was found.
    """
    options = parse_args()
    api_workflow = load_json(options.api)
    ui_workflow = load_json(options.ui)
    mismatches = validate(api_workflow, ui_workflow)

    # Count every API input that has the connection shape, matched or not.
    checked = 0
    for node in api_workflow.values():
        for value in (node.get("inputs") or {}).values():
            if is_connection(value):
                checked += 1

    print(f"checked_connections={checked}")
    print(f"mismatches={len(mismatches)}")
    if not mismatches:
        return

    print(json.dumps(mismatches, ensure_ascii=False, indent=2))
    raise SystemExit(1)


# Only run the CLI when executed as a script, not when imported.
if __name__ == "__main__":
    main()