File size: 3,512 Bytes
d65b6b1
3ea399a
d65b6b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81f5ac0
 
fee06dd
81f5ac0
 
fee06dd
 
 
 
81f5ac0
 
 
fee06dd
81f5ac0
 
 
fee06dd
 
 
81f5ac0
 
fee06dd
 
 
 
 
 
 
 
 
 
 
 
 
 
81f5ac0
 
fee06dd
 
 
 
 
 
 
 
 
 
 
 
 
81f5ac0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Pure functions over LTX 2.3 mode workflow JSON templates."""

from __future__ import annotations

import copy
import json
import pathlib
from typing import Any

# Directory expected to hold the "<mode>.json" workflow templates read by
# load_template(); resolved relative to this module's own file location.
WORKFLOWS_DIR = pathlib.Path(__file__).parent / "workflows"

# Modes accepted by load_template(); each one is loaded from WORKFLOWS_DIR / f"{mode}.json".
VALID_MODES: tuple[str, ...] = ("t2v", "a2v", "i2v", "lipsync", "keyframe", "style")


def load_template(mode: str) -> dict[str, Any]:
    """Load a fresh, independent copy of the named mode's workflow template.

    The template file ``WORKFLOWS_DIR / f"{mode}.json"`` is read and parsed on every
    call, so each caller receives its own dict that it may mutate freely (for example
    with ``set_input``) without affecting other callers.

    Args:
        mode: One of ``VALID_MODES``.

    Returns:
        The parsed workflow JSON.

    Raises:
        ValueError: If ``mode`` is not in ``VALID_MODES`` (``json.JSONDecodeError``,
            a ValueError subclass, is also raised if the template is not valid JSON).
        OSError: If the template file cannot be read (e.g. FileNotFoundError).
    """
    if mode not in VALID_MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {VALID_MODES}")
    path = WORKFLOWS_DIR / f"{mode}.json"
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    # json.loads builds brand-new objects on every call, so no extra deep copy is needed.
    return json.loads(path.read_text(encoding="utf-8"))


def set_input(workflow: dict[str, Any], node_id: int, widget_index: int | str, value: Any) -> None:
    """Patch a node's widgets_values in place.

    Supports both list-style widgets_values (most ComfyUI nodes — patch by integer index,
    auto-extending with None) and dict-style widgets_values (VHS_LoadAudioUpload and
    similar — patch by string key, raising KeyError if the key doesn't exist).

    Args:
        workflow: A workflow dict (must have a "nodes" list).
        node_id: The id of the node to patch.
        widget_index: Integer index (for list widgets) or string key (for dict widgets).
        value: New value.

    Raises:
        KeyError: If no node with the given id exists, or for dict widgets, if the key
            doesn't already exist on the target dict (we don't add new keys).
        TypeError: If widget_index type doesn't match the node's widgets_values type.
    """
    for node in workflow["nodes"]:
        if node.get("id") != node_id:
            continue
        widgets = node.get("widgets_values")
        if isinstance(widgets, dict):
            if not isinstance(widget_index, str):
                raise TypeError(
                    f"node {node_id} has dict widgets_values; widget_index must be str, "
                    f"got {type(widget_index).__name__}"
                )
            if widget_index not in widgets:
                raise KeyError(
                    f"node {node_id} dict widgets_values has no key {widget_index!r}; "
                    f"available keys: {list(widgets.keys())}"
                )
            widgets[widget_index] = value
            return
        # List/None case — preserve existing list-extension behavior.
        if not isinstance(widget_index, int):
            raise TypeError(
                f"node {node_id} has list widgets_values; widget_index must be int, "
                f"got {type(widget_index).__name__}"
            )
        if widgets is None:
            widgets = []
            node["widgets_values"] = widgets
        while len(widgets) <= widget_index:
            widgets.append(None)
        widgets[widget_index] = value
        return
    raise KeyError(f"node id {node_id} not found in workflow")


def validate(workflow: dict[str, Any]) -> None:
    """Static schema validation. Raises ValueError on the first problem found.

    Checks that ``workflow["nodes"]`` is a non-empty list of dicts, and that every
    entry of ``workflow["links"]`` (missing or null means no links) is a list of at
    least six elements whose source node id (index 1) and destination node id
    (index 3) both match the "id" of some node.

    Args:
        workflow: The workflow dict to check. It is not modified.

    Raises:
        ValueError: On the first problem found.
    """
    nodes = workflow.get("nodes")
    if not isinstance(nodes, list) or len(nodes) == 0:
        raise ValueError("workflow has no nodes")
    for n in nodes:
        # Non-dict nodes would otherwise surface as TypeError/AttributeError below.
        if not isinstance(n, dict):
            raise ValueError(f"malformed node {n!r}")

    links = workflow.get("links")
    if links is None:
        links = []
    elif not isinstance(links, list):
        raise ValueError(f"links must be a list, got {type(links).__name__}")

    node_ids = {n.get("id") for n in nodes if "id" in n}
    for link in links:
        if not isinstance(link, list) or len(link) < 6:
            raise ValueError(f"malformed link {link}")
        # Index rather than unpack: links longer than six elements pass the check
        # above and must not crash with an unrelated "too many values" error.
        src, dst = link[1], link[3]
        if src not in node_ids or dst not in node_ids:
            raise ValueError(f"orphan link {link}")