p / Client /Scripts /comfy.py
q6's picture
comfy
7697d8f
import json
import sys
import zlib
from collections import defaultdict, deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
# Value of a node's "cnr_id" property that marks it as a WD14 tagger node.
WD14_TAGGER_CNR_ID = "comfyui-wd14-tagger"
# Node "type" given to generated string nodes; an existing node of this type
# is also used as the styling template for them.
MULTISTRING_TYPE = "String Literal"
# Title assigned to every generated string node.
WD14_TAGS_TITLE = "WD14 Tags"
# Property key on a generated node that records the id of the tagger node it
# was created from (used to avoid generating the same node twice).
WD14_SOURCE_PROP = "wd14_source_id"
# Fallback (width, height) for generated nodes when no template supplies one.
DEFAULT_STRING_NODE_SIZE = (600.0, 120.0)
# The 8-byte signature a file must start with to be accepted as a PNG.
PNG_SIG = b"\x89PNG\r\n\x1a\n"
def _decode_text(b: bytes) -> str:
    """Decode *b* as UTF-8, falling back to Latin-1 if it is not valid UTF-8.

    The fallback maps every byte to a character, so this never raises for
    bytes input.
    """
    try:
        text = str(b, "utf-8")
    except UnicodeDecodeError:
        text = str(b, "latin-1", "replace")
    return text
def _read_exact(f, n: int) -> bytes:
    """Read exactly *n* bytes from the binary stream *f*.

    Raises:
        EOFError: if the stream yields fewer (or more) than *n* bytes.
    """
    chunk = f.read(n)
    if len(chunk) == n:
        return chunk
    raise EOFError("Unexpected EOF while reading PNG")
def read_png_text_chunks(path: Path) -> dict[str, str]:
    """Collect the textual metadata (tEXt / zTXt / iTXt chunks) of a PNG file.

    Chunks are read sequentially until an IEND chunk or a clean end of file.
    The 4-byte CRC of each chunk is consumed but not verified.  Malformed or
    unsupported text chunks are skipped silently; when a keyword occurs more
    than once the last occurrence wins.

    Args:
        path: Location of the PNG file.

    Returns:
        Mapping of decoded keyword to decoded text.

    Raises:
        ValueError: if the file does not start with the PNG signature.
        EOFError: if the file ends in the middle of a chunk.
    """
    text_chunk_types = (b"tEXt", b"zTXt", b"iTXt")
    texts: dict[str, str] = {}
    with path.open("rb") as fh:
        if _read_exact(fh, 8) != PNG_SIG:
            raise ValueError("Not a PNG file")
        while True:
            header = fh.read(4)
            if not header:
                # Clean EOF between chunks (file without IEND).
                break
            if len(header) != 4:
                raise EOFError("Unexpected EOF while reading PNG chunk length")
            size = int.from_bytes(header, "big", signed=False)
            chunk_type = _read_exact(fh, 4)
            payload = _read_exact(fh, size)
            _read_exact(fh, 4)  # CRC: consumed, not checked.

            if chunk_type == b"IEND":
                break
            if chunk_type not in text_chunk_types:
                continue

            # Every text chunk starts with "<keyword>\0".
            keyword, separator, body = payload.partition(b"\x00")
            if not separator:
                continue

            if chunk_type == b"tEXt":
                value = body
            elif chunk_type == b"zTXt":
                # body = <compression method byte><compressed text>;
                # only method 0 (zlib) is handled.
                if not body or body[0] != 0:
                    continue
                try:
                    value = zlib.decompress(body[1:])
                except zlib.error:
                    continue
            else:  # iTXt
                # body = <compression flag><compression method>
                #        <field>\0<field>\0<text>
                # The two NUL-terminated fields (language tag and translated
                # keyword in the PNG specification) are discarded.
                if len(body) < 2:
                    continue
                compressed, method = body[0], body[1]
                fields = body[2:].split(b"\x00", 2)
                if len(fields) < 3:
                    continue
                value = fields[2]
                if compressed == 1:
                    if method != 0:
                        continue
                    try:
                        value = zlib.decompress(value)
                    except zlib.error:
                        continue

            texts[_decode_text(keyword)] = _decode_text(value)
    return texts
def _looks_like_workflow(obj: Any) -> bool:
    """Return True if *obj* is a dict whose "nodes" and "links" are both lists."""
    return isinstance(obj, dict) and all(
        isinstance(obj.get(key), list) for key in ("nodes", "links")
    )
def extract_workflow_from_metadata(meta: dict[str, str]) -> dict[str, Any]:
    """Find a workflow graph inside PNG text metadata.

    Candidates are tried in this order and the first one that passes
    ``_looks_like_workflow`` is returned:

    1. the JSON under ``meta["workflow"]``;
    2. inside the JSON under ``meta["prompt"]`` (when that is a dict):
       its ``"workflow"`` entry, then ``"extra_pnginfo" -> "workflow"``,
       then the prompt object itself.

    Entries that are not valid JSON are ignored.

    Raises:
        ValueError: if no candidate looks like a workflow.
    """

    def _parse(key: str) -> Any:
        # Decoded JSON stored under *key*, or None when absent or invalid.
        if key not in meta:
            return None
        try:
            return json.loads(meta[key])
        except json.JSONDecodeError:
            return None

    direct = _parse("workflow")
    if _looks_like_workflow(direct):
        return direct

    prompt = _parse("prompt")
    if isinstance(prompt, dict):
        nested = prompt.get("workflow")
        if _looks_like_workflow(nested):
            return nested
        png_info = prompt.get("extra_pnginfo")
        if isinstance(png_info, dict):
            nested = png_info.get("workflow")
            if _looks_like_workflow(nested):
                return nested
        if _looks_like_workflow(prompt):
            return prompt

    raise ValueError("Could not find a ComfyUI workflow in PNG metadata (looked for 'workflow' and 'prompt').")
@dataclass(frozen=True)
class Size:
    """Immutable width/height pair used when laying out nodes."""

    w: float  # width
    h: float  # height
def _get_size(node: dict[str, Any]) -> Size:
    """Return the layout footprint of *node*.

    The node's "size" is used when it is a two-element list of numbers;
    otherwise 320x120 is assumed.  The result is never smaller than 260x120.
    """
    raw = node.get("size")
    is_pair = (
        isinstance(raw, list)
        and len(raw) == 2
        and all(isinstance(value, (int, float)) for value in raw)
    )
    width, height = (float(raw[0]), float(raw[1])) if is_pair else (320.0, 120.0)
    # ComfyUI nodes can be taller than their "size" when uncollapsed, so
    # enforce a minimum footprint to leave some margin.
    return Size(w=max(width, 260.0), h=max(height, 120.0))
def _build_graph(links: list[Any]) -> tuple[dict[int, list[int]], dict[int, int]]:
    """Turn workflow link records into an adjacency map and in-degree counts.

    A link is used only if it is a list of at least four items whose items at
    index 1 (source node id) and index 3 (destination node id) convert to
    ``int``; everything else is ignored.

    Returns:
        ``(adj, indeg)`` where ``adj[src]`` lists destination ids (one entry
        per link) and ``indeg[node]`` counts incoming links.  Every source id
        is present in ``indeg``, with 0 if nothing points at it.
    """
    successors: dict[int, list[int]] = defaultdict(list)
    in_degree: dict[int, int] = defaultdict(int)
    for record in links:
        if not (isinstance(record, list) and len(record) >= 4):
            continue
        try:
            origin, target = int(record[1]), int(record[3])
        except (ValueError, TypeError):
            continue
        successors[origin].append(target)
        in_degree[target] += 1
        # Make sure pure sources also show up in the in-degree table.
        in_degree.setdefault(origin, 0)
    return successors, in_degree
def _kahn_order(node_ids: list[int], adj: dict[int, list[int]], indeg: dict[int, int]) -> list[int]:
    """Order *node_ids* topologically with Kahn's algorithm.

    Only ids contained in *node_ids* are ever returned.  Links that touch an
    id outside *node_ids* (dangling links) are ignored: such targets are not
    emitted, and such sources do not block the nodes they point at.

    Nodes that never become ready (members of a cycle, or downstream of one)
    are appended at the end in ascending id order.

    Args:
        node_ids: Ids to order; ready nodes are emitted in this sequence.
        adj: Source id -> destination ids, one entry per link.
        indeg: Node id -> number of incoming links (missing ids count as 0).
            Not modified.

    Returns:
        A new list containing the ids of *node_ids* in layout order.
    """
    known = set(node_ids)

    # Work on a copy and discount links that originate from unknown nodes;
    # those sources are never processed, so their targets would otherwise
    # never reach an in-degree of zero.
    pending = dict(indeg)
    for source, targets in adj.items():
        if source in known:
            continue
        for target in targets:
            pending[target] = pending.get(target, 0) - 1

    queue = deque(n for n in node_ids if pending.get(n, 0) == 0)
    seen: set[int] = set(queue)
    out: list[int] = []
    while queue:
        current = queue.popleft()
        out.append(current)
        for target in adj.get(current, []):
            if target not in known:
                # Dangling link: never emit an id the caller did not ask for.
                continue
            pending[target] = pending.get(target, 0) - 1
            if pending[target] == 0 and target not in seen:
                seen.add(target)
                queue.append(target)

    # Build the membership set once instead of once per candidate.
    emitted = set(out)
    out.extend(sorted(n for n in node_ids if n not in emitted))
    return out
def _force_uncollapse(nodes: list[dict[str, Any]]) -> None:
    """Set ``flags["collapsed"] = False`` on every node, in place.

    An existing flags dict keeps its other entries; a missing or non-dict
    "flags" value is replaced by ``{"collapsed": False}``.  Entries of
    *nodes* that are not dicts are skipped, matching how the other helpers
    in this file treat malformed node entries.
    """
    for n in nodes:
        if not isinstance(n, dict):
            continue
        flags = n.get("flags")
        if isinstance(flags, dict):
            flags["collapsed"] = False
        else:
            n["flags"] = {"collapsed": False}
def _is_wd14_tagger_node(node: dict[str, Any]) -> bool:
    """Return True if the node's "properties" dict has the WD14 tagger "cnr_id"."""
    properties = node.get("properties")
    return isinstance(properties, dict) and properties.get("cnr_id") == WD14_TAGGER_CNR_ID
def _extract_wd14_tags(node: dict[str, Any]) -> Optional[str]:
    """Return the last entry of the node's "widgets_values" if it is a string.

    Returns None when "widgets_values" is missing, not a list, empty, or its
    last entry is not a string.
    """
    # NOTE(review): presumably the tagger stores its tag text as the final
    # widget value - confirm against the tagger node's widget layout.
    values = node.get("widgets_values")
    if isinstance(values, list) and values:
        tail = values[-1]
        return tail if isinstance(tail, str) else None
    return None
def _pick_multistring_template(nodes: list[dict[str, Any]]) -> Optional[dict[str, Any]]:
    """Return the first dict node whose "type" is MULTISTRING_TYPE, or None."""
    matches = (
        candidate
        for candidate in nodes
        if isinstance(candidate, dict) and candidate.get("type") == MULTISTRING_TYPE
    )
    return next(matches, None)
def _max_node_id(nodes: list[dict[str, Any]]) -> int:
    """Return the largest integer "id" among dict nodes, never less than 0.

    Non-dict entries and non-integer ids are ignored.
    """
    raw_ids = [entry.get("id") for entry in nodes if isinstance(entry, dict)]
    # Seeding with 0 gives the floor for empty input and all-negative ids.
    return max([0] + [value for value in raw_ids if isinstance(value, int)])
def _string_output_from_template(template: Optional[dict[str, Any]]) -> list[dict[str, Any]]:
    """Build the single-entry "outputs" list for a generated string node.

    Starts from a default STRING output with no links.  If the template's
    first output is a dict, its "label", "name", "type", "shape" and
    "slot_index" entries (those that exist) override or extend the default.
    The template's links are never copied.
    """
    slot: dict[str, Any] = {"label": "STRING", "name": "STRING", "type": "STRING", "links": []}
    template_outputs = template.get("outputs") if template is not None else None
    primary = template_outputs[0] if isinstance(template_outputs, list) and template_outputs else None
    if isinstance(primary, dict):
        copied_keys = ("label", "name", "type", "shape", "slot_index")
        slot.update((key, primary[key]) for key in copied_keys if key in primary)
    return [slot]
def _size_from_template(template: Optional[dict[str, Any]]) -> list[float]:
    """Return ``[width, height]`` for a generated node.

    Uses the template's "size" when it is a two-element list of numbers,
    otherwise DEFAULT_STRING_NODE_SIZE.  A new list is returned each call.
    """
    raw = None if template is None else template.get("size")
    usable = (
        isinstance(raw, list)
        and len(raw) == 2
        and all(isinstance(value, (int, float)) for value in raw)
    )
    if usable:
        return [float(raw[0]), float(raw[1])]
    return [DEFAULT_STRING_NODE_SIZE[0], DEFAULT_STRING_NODE_SIZE[1]]
def _widget_values_from_template(template: Optional[dict[str, Any]], text: str) -> list[Any]:
    """Return the "widgets_values" list ``[text, flag]`` for a generated node.

    *flag* is copied from the template's second widget value when that value
    is a bool; otherwise it is True.
    """
    # NOTE(review): the second widget presumably is a "speak" toggle on the
    # string node - confirm against the node definition.
    flag = True
    template_values = template.get("widgets_values") if template is not None else None
    if isinstance(template_values, list) and len(template_values) > 1:
        candidate = template_values[1]
        if isinstance(candidate, bool):
            flag = candidate
    return [text, flag]
def _flags_from_template(template: Optional[dict[str, Any]]) -> dict[str, Any]:
    """Return a copy of the template's flags with "collapsed" forced to False.

    Without a template, or when its "flags" is not a dict, the result is just
    ``{"collapsed": False}``.  The template itself is not modified.
    """
    inherited = template.get("flags") if template is not None else None
    if not isinstance(inherited, dict):
        return {"collapsed": False}
    return {**inherited, "collapsed": False}
def _properties_from_template(template: Optional[dict[str, Any]], source_id: int) -> dict[str, Any]:
    """Return node properties copied from the template plus the source marker.

    The result is a shallow copy of the template's "properties" dict (empty
    if there is none) in which WD14_SOURCE_PROP is set to *source_id*,
    overriding any value the template had under that key.
    """
    inherited = template.get("properties") if template is not None else None
    props: dict[str, Any] = dict(inherited) if isinstance(inherited, dict) else {}
    props[WD14_SOURCE_PROP] = source_id
    return props
def _make_multistring_node(
    node_id: int, text: str, template: Optional[dict[str, Any]], source_id: int
) -> dict[str, Any]:
    """Create the string node that carries the tags of one tagger node.

    Args:
        node_id: Id for the new node.
        text: Tag text stored as the node's first widget value.
        template: Optional existing string node whose size, flags, order,
            mode, first output, properties and colours are reused.
        source_id: Id of the tagger node, recorded in the node's properties.

    Returns:
        A new node dict positioned at (0, 0) with no inputs.
    """

    def _int_setting(key: str) -> int:
        # Integer value of template[key], or 0 when absent or not an int.
        value = template.get(key) if template is not None else None
        return value if isinstance(value, int) else 0

    node: dict[str, Any] = {
        "id": node_id,
        "type": MULTISTRING_TYPE,
        "pos": [0.0, 0.0],
        "size": _size_from_template(template),
        "flags": _flags_from_template(template),
        "order": _int_setting("order"),
        "mode": _int_setting("mode"),
        "inputs": [],
        "outputs": _string_output_from_template(template),
        "title": WD14_TAGS_TITLE,
        "properties": _properties_from_template(template, source_id),
        "widgets_values": _widget_values_from_template(template, text),
    }
    if template is not None:
        # Colours are optional and only carried over when the template has them.
        for colour_key in ("color", "bgcolor"):
            if colour_key in template:
                node[colour_key] = template[colour_key]
    return node
def _existing_wd14_sources(nodes: list[dict[str, Any]]) -> set[int]:
    """Collect the integer WD14_SOURCE_PROP values found in node properties.

    Each value is the id of a tagger node that already has a generated
    string node.  Non-dict nodes and non-dict "properties" are ignored.
    """
    found: set[int] = set()
    for entry in nodes:
        properties = entry.get("properties") if isinstance(entry, dict) else None
        marker = properties.get(WD14_SOURCE_PROP) if isinstance(properties, dict) else None
        if isinstance(marker, int):
            found.add(marker)
    return found
def _add_wd14_multistring_nodes(wf: dict[str, Any]) -> dict[int, int]:
    """Append one string node per WD14 tagger node that does not have one yet.

    A tagger node is skipped if its id is not an int, if a node carrying its
    id under WD14_SOURCE_PROP already existed before this call, or if no tag
    text can be extracted from it.  New nodes receive consecutive ids above
    the current maximum node id.  *wf* is modified in place: nodes are
    appended to ``wf["nodes"]`` and ``wf["last_node_id"]`` is raised to cover
    the new ids when anything was added.

    Returns:
        Mapping of new node id -> id of the tagger node it was created from
        (empty if nothing was added or ``wf["nodes"]`` is not a list).
    """
    nodes = wf.get("nodes")
    if not isinstance(nodes, list):
        return {}

    already_covered = _existing_wd14_sources(nodes)
    template = _pick_multistring_template(nodes)
    first_new_id = _max_node_id(nodes) + 1
    created: dict[int, int] = {}

    # Iterate over a snapshot because the list grows while we scan it.
    for candidate in tuple(nodes):
        if not (isinstance(candidate, dict) and _is_wd14_tagger_node(candidate)):
            continue
        tagger_id = candidate.get("id")
        if not isinstance(tagger_id, int) or tagger_id in already_covered:
            continue
        tags = _extract_wd14_tags(candidate)
        if tags is None:
            continue
        new_id = first_new_id + len(created)
        nodes.append(_make_multistring_node(new_id, tags, template, tagger_id))
        created[new_id] = tagger_id

    if created:
        highest_new_id = first_new_id + len(created) - 1
        recorded = wf.get("last_node_id")
        if isinstance(recorded, int):
            wf["last_node_id"] = max(recorded, highest_new_id)
        else:
            wf["last_node_id"] = highest_new_id
    return created
def _order_with_multistrings(order: list[int], mapping: dict[int, int]) -> list[int]:
    """Re-order so each generated node directly follows its source node.

    Args:
        order: Node ids in layout order (may already contain generated ids).
        mapping: Generated node id -> source node id.

    Returns:
        *order* itself when *mapping* is empty.  Otherwise a new list: the
        non-generated ids in their original sequence, each followed by its
        generated ids in ascending order; generated ids whose source does not
        occur are appended at the end, ascending.
    """
    if not mapping:
        return order

    generated_sorted = sorted(mapping)
    # Grouping in ascending id order keeps every per-source list sorted.
    by_source: dict[int, list[int]] = defaultdict(list)
    for generated_id in generated_sorted:
        by_source[mapping[generated_id]].append(generated_id)

    result: list[int] = []
    placed: set[int] = set()
    for node_id in order:
        if node_id in mapping:
            # Generated ids are re-inserted next to their source below.
            continue
        result.append(node_id)
        followers = by_source.get(node_id, ())
        result.extend(followers)
        placed.update(followers)

    result.extend(gid for gid in generated_sorted if gid not in placed)
    return result
def relayout_workflow(wf: dict[str, Any]) -> dict[str, Any]:
    """Return a cleaned-up copy of *wf* with every node repositioned.

    Steps, all performed on a JSON round-trip copy (the input is untouched):

    1. add a string node for each WD14 tagger node that lacks one;
    2. uncollapse every node;
    3. order the nodes along the links and assign each to a layer, so that a
       node sits at least one layer to the right of the nodes processed
       before it that feed it; generated string nodes share the layer of
       their tagger node;
    4. lay the layers out as columns left to right and stack the nodes of a
       layer top to bottom using their sizes plus fixed margins;
    5. reset the viewport (``extra.ds``) to scale 1.0 / offset [0, 0] and
       make sure "groups" and "config" exist.

    Links that reference node ids not present in "nodes" are ignored for
    placement instead of aborting the layout.

    Raises:
        ValueError: if the workflow has no "nodes" list.
    """
    wf2 = json.loads(json.dumps(wf))
    nodes = wf2.get("nodes")
    if not isinstance(nodes, list):
        raise ValueError("Workflow missing 'nodes' list")

    tagger_map = _add_wd14_multistring_nodes(wf2)
    _force_uncollapse(nodes)

    node_by_id: dict[int, dict[str, Any]] = {}
    node_ids: list[int] = []
    for n in nodes:
        if not isinstance(n, dict):
            continue
        nid = n.get("id")
        if isinstance(nid, int):
            node_by_id[nid] = n
            node_ids.append(nid)

    links = wf2.get("links", [])
    if not isinstance(links, list):
        links = []
    adj, indeg = _build_graph(links)

    order = _kahn_order(sorted(node_ids), adj, indeg)
    if tagger_map:
        order = _order_with_multistrings(order, tagger_map)
    # The ordering can contain ids that only occur in links (stale links to
    # deleted nodes); looking those up in node_by_id would raise KeyError.
    # Keep known ids only, each once, preserving the sequence.
    order = list(dict.fromkeys(nid for nid in order if nid in node_by_id))

    layer: dict[int, int] = {nid: 0 for nid in node_by_id}
    for n in order:
        base = layer[n]
        for m in adj.get(n, []):
            if m in node_by_id:
                layer[m] = max(layer[m], base + 1)
    if tagger_map:
        # Generated string nodes have no links; park them in their tagger's column.
        for new_id, source_id in tagger_map.items():
            layer[new_id] = layer.get(source_id, layer.get(new_id, 0))

    layers: dict[int, list[int]] = defaultdict(list)
    for nid in order:
        layers[layer.get(nid, 0)].append(nid)

    sizes = {nid: _get_size(node_by_id[nid]) for nid in order}

    # Compute per-layer max widths so x-spacing never overlaps.
    layer_ids = sorted(layers)
    x_margin = 140.0
    y_margin = 60.0
    x_pos: dict[int, float] = {}
    cursor = 0.0
    for layer_id in layer_ids:
        x_pos[layer_id] = cursor
        widest = max((sizes[nid].w for nid in layers[layer_id]), default=320.0)
        cursor += widest + x_margin

    # Place nodes in each layer using cumulative heights (so tall nodes don't overlap).
    for layer_id in layer_ids:
        y_cursor = 0.0
        for nid in layers[layer_id]:
            node_by_id[nid]["pos"] = [x_pos[layer_id], y_cursor]
            y_cursor += sizes[nid].h + y_margin

    # Reset viewport so it opens nicely.
    extra = wf2.get("extra")
    if not isinstance(extra, dict):
        extra = {}
        wf2["extra"] = extra
    ds = extra.get("ds")
    if not isinstance(ds, dict):
        ds = {}
        extra["ds"] = ds
    ds["scale"] = 1.0
    ds["offset"] = [0, 0]

    if "groups" not in wf2:
        wf2["groups"] = []
    if "config" not in wf2:
        wf2["config"] = {}
    return wf2
def process_png(png_path: Path) -> Path:
    """Extract, re-layout and save the workflow embedded in one PNG file.

    The result is written as UTF-8 JSON next to the image, named after the
    image without its final suffix plus "_clean_workflow.json".

    Returns:
        Path of the JSON file that was written.

    Raises:
        Whatever reading the PNG, locating the workflow, re-laying it out or
        writing the file raises; nothing is caught here.
    """
    metadata = read_png_text_chunks(png_path)
    workflow = relayout_workflow(extract_workflow_from_metadata(metadata))
    destination = png_path.with_name(png_path.stem + "_clean_workflow.json")
    serialized = json.dumps(workflow, ensure_ascii=False, indent=2)
    destination.write_text(serialized, encoding="utf-8")
    return destination
def main(argv: list[str]) -> int:
    """Command-line entry point: convert every PNG named in ``argv[1:]``.

    Surrounding double quotes are stripped from each argument and ``~`` is
    expanded.  A failure on one file is reported and does not stop the rest.

    Returns:
        2 if no file was given (usage is printed), 0 if at least one file was
        converted, otherwise 1.
    """
    if len(argv) < 2:
        print("Drag-and-drop one or more .png files onto this script, or run:")
        print(" python png_to_clean_workflow.py <image.png> [more.png ...]")
        return 2

    def _convert(raw: str) -> bool:
        # Convert one argument; report the outcome and return whether it worked.
        target = Path(raw.strip('"')).expanduser()
        try:
            written = process_png(target)
            print(f"[OK] {target.name} -> {written.name}")
        except Exception as exc:
            # Top-level boundary: report and carry on with the next file.
            print(f"[ERR] {target}: {exc}")
            return False
        return True

    successes = sum([_convert(raw) for raw in argv[1:]])
    return 0 if successes else 1
# Run the converter only when executed as a script; main()'s return value
# becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main(sys.argv))