File size: 4,867 Bytes
359a8b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Validate the API workflow connections against the ComfyUI UI workflow.

ComfyUI API exports collapse some UI-only SetNode/GetNode routing pairs. This
validator resolves those pairs before comparing links so the check reflects the
actual graph semantics instead of the visual routing helpers.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any


# Presumed repository root: ``parents[1]`` of this script's resolved path,
# i.e. the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# Default API-format workflow to validate (overridable with ``--api``).
DEFAULT_API = ROOT / "workflows" / "voicegate_api.json"
# Default UI-format workflow used as the reference (overridable with ``--ui``).
DEFAULT_UI = ROOT / "VoiceGate" / "workflows" / "VoiceGate-Workflow.json"


def load_json(path: Path) -> Any:
    """Read *path* as UTF-8 text and return the decoded JSON document."""
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)


def is_connection(value: Any) -> bool:
    """Return True when *value* has the shape of an API-format link.

    A link is treated as a two-element list ``[source_node_id, output_slot]``
    where the node id is a string and the slot an integer. ``bool`` is
    rejected explicitly: it subclasses ``int`` in Python, so without the check
    a widget value such as ``["name", true]`` would be counted as a link.
    """
    if not isinstance(value, list) or len(value) != 2:
        return False
    node_id, slot = value
    return (
        isinstance(node_id, str)
        and isinstance(slot, int)
        and not isinstance(slot, bool)
    )


def build_expected_connections(ui_workflow: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Map every linked UI input to the source it should have in the API export.

    ``links`` entries are read positionally as
    ``[link_id, from_node, from_slot, to_node, to_slot, type]``. An input fed by
    a ``GetNode`` is traced back through the ``SetNode`` that shares its name
    (``widgets_values[0]``) to the node that actually produces the value,
    because the API export collapses these routing pairs. When a ``SetNode`` is
    itself fed by a ``GetNode``, the chain is followed until a non-Get
    producer is reached. A visited set stops the walk on cyclic Set/Get wiring.

    Args:
        ui_workflow: Parsed UI workflow with ``nodes`` and ``links`` lists.

    Returns:
        A dict keyed by ``"<node_id>.<input_name>"``. Each value has ``from``
        (source node id as str), ``slot`` and ``type``. Inputs resolved
        through a Get/Set pair also carry ``via_set``, ``via_get`` and
        ``name``, which describe the first routing hop.

    Raises:
        KeyError: if ``nodes``/``links``, a node ``id`` or a linked input's
            ``name`` is missing.
    """
    nodes = ui_workflow["nodes"]
    links = ui_workflow["links"]
    nodes_by_id = {str(node["id"]): node for node in nodes}
    links_by_id = {
        link[0]: {
            "from": str(link[1]),
            "slot": link[2],
            "to": str(link[3]),
            "to_slot": link[4],
            "type": link[5],
        }
        for link in links
    }

    # Variable name -> the source feeding the SetNode's first input.
    set_by_name: dict[str, dict[str, Any]] = {}
    for node in nodes:
        if node.get("type") != "SetNode":
            continue
        values = node.get("widgets_values") or []
        inputs = node.get("inputs") or []
        if not values or not inputs:
            continue
        link = links_by_id.get(inputs[0].get("link"))
        if link:
            set_by_name[values[0]] = {
                "from": link["from"],
                "slot": link["slot"],
                "via_set": str(node["id"]),
            }

    def getter_name(node_id: str) -> Any:
        """Return the variable name of GetNode *node_id*, or None if not a named GetNode."""
        node = nodes_by_id.get(node_id)
        if not node or node.get("type") != "GetNode":
            return None
        values = node.get("widgets_values") or []
        return values[0] if values else None

    def resolve_source(from_node: str, slot: int) -> dict[str, Any]:
        """Follow Get/Set routing from (*from_node*, *slot*) to the real producer."""
        name = getter_name(from_node)
        # Explicit None check: a SetNode with widgets_values == [None] must not
        # capture GetNodes that simply have no name.
        if name is None or name not in set_by_name:
            return {"from": from_node, "slot": slot}
        resolved = {**set_by_name[name], "via_get": from_node, "name": name}
        visited = {from_node}
        # A SetNode may itself be fed by another GetNode. Keep hopping until
        # a real producer is reached or the wiring loops back on itself.
        while resolved["from"] not in visited:
            visited.add(resolved["from"])
            upstream = getter_name(resolved["from"])
            if upstream is None or upstream not in set_by_name:
                break
            resolved["from"] = set_by_name[upstream]["from"]
            resolved["slot"] = set_by_name[upstream]["slot"]
        return resolved

    expected: dict[str, dict[str, Any]] = {}
    for node in nodes:
        for input_def in node.get("inputs") or []:
            link_id = input_def.get("link")
            if link_id is None:
                continue
            link = links_by_id.get(link_id)
            if not link:
                continue
            key = f"{node['id']}.{input_def['name']}"
            expected[key] = {
                **resolve_source(link["from"], link["slot"]),
                "type": link["type"],
            }
    return expected


def validate(api_workflow: dict[str, Any], ui_workflow: dict[str, Any]) -> list[dict[str, Any]]:
    """Check each linked input of *api_workflow* against *ui_workflow*.

    Only inputs whose value passes ``is_connection`` are compared. The
    reference comes from ``build_expected_connections(ui_workflow)``.

    Returns:
        One record per disagreement, or an empty list when everything
        matches. A record holds the API node id, its ``class_type``, the
        input name, the ``actual`` ``{"from", "slot"}`` pair and the
        ``expected`` UI connection (``None`` when the UI has no link there).
    """
    reference = build_expected_connections(ui_workflow)
    problems: list[dict[str, Any]] = []
    for api_id, api_node in api_workflow.items():
        node_inputs = api_node.get("inputs") or {}
        for name, raw in node_inputs.items():
            if not is_connection(raw):
                continue
            source_id, source_slot = raw
            wanted = reference.get(f"{api_id}.{name}")
            agrees = (
                bool(wanted)
                and wanted["from"] == source_id
                and wanted["slot"] == source_slot
            )
            if agrees:
                continue
            problems.append(
                {
                    "node": api_id,
                    "class_type": api_node.get("class_type"),
                    "input": name,
                    "actual": {"from": source_id, "slot": source_slot},
                    "expected": wanted,
                }
            )
    return problems


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the validator.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing no-argument callers are unaffected.
            Passing a list allows driving the CLI from tests.

    Returns:
        Namespace whose ``api`` and ``ui`` attributes are ``Path`` objects,
        defaulting to ``DEFAULT_API`` and ``DEFAULT_UI``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Check that every linked input in a ComfyUI API workflow matches "
            "the corresponding link in the UI workflow, resolving "
            "SetNode/GetNode routing pairs."
        ),
    )
    parser.add_argument(
        "--api",
        type=Path,
        default=DEFAULT_API,
        help="API-format workflow JSON to validate (default: %(default)s)",
    )
    parser.add_argument(
        "--ui",
        type=Path,
        default=DEFAULT_UI,
        help="UI-format workflow JSON used as reference (default: %(default)s)",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Command-line entry point.

    Loads both workflows and runs ``validate``. It then prints the number of
    API connections checked and the mismatch count. If any mismatch exists,
    it dumps the mismatches as indented JSON and exits with status 1.
    """
    options = parse_args()
    api_doc = load_json(options.api)
    ui_doc = load_json(options.ui)
    problems = validate(api_doc, ui_doc)

    total_links = 0
    for api_node in api_doc.values():
        node_inputs = api_node.get("inputs") or {}
        total_links += sum(1 for raw in node_inputs.values() if is_connection(raw))

    print(f"checked_connections={total_links}")
    print(f"mismatches={len(problems)}")
    if not problems:
        return
    print(json.dumps(problems, ensure_ascii=False, indent=2))
    raise SystemExit(1)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()