| import json |
| import sys |
| import zlib |
| from collections import defaultdict, deque |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Optional |
|
|
|
|
# Value of properties["cnr_id"] that marks a node as a WD14 tagger.
WD14_TAGGER_CNR_ID = "comfyui-wd14-tagger"
# Node type given to the generated string nodes; also the type searched for
# when looking for an existing node to copy settings from.
MULTISTRING_TYPE = "String Literal"
# Title shown on every generated string node.
WD14_TAGS_TITLE = "WD14 Tags"
# Property key recording which tagger node a generated node was created from.
WD14_SOURCE_PROP = "wd14_source_id"
# Fallback (width, height) for generated nodes when no template size is usable.
DEFAULT_STRING_NODE_SIZE = (600.0, 120.0)
# The 8 bytes every PNG file must start with.
PNG_SIG = b"\x89PNG\r\n\x1a\n"
|
|
|
|
| def _decode_text(b: bytes) -> str: |
| try: |
| return b.decode("utf-8") |
| except UnicodeDecodeError: |
| return b.decode("latin-1", "replace") |
|
|
|
|
| def _read_exact(f, n: int) -> bytes: |
| b = f.read(n) |
| if len(b) != n: |
| raise EOFError("Unexpected EOF while reading PNG") |
| return b |
|
|
|
|
def read_png_text_chunks(path: Path) -> dict[str, str]:
    """Collect the textual metadata stored in a PNG file.

    Walks the chunk stream and gathers every ``tEXt``, ``zTXt`` and ``iTXt``
    entry into a ``keyword -> text`` mapping; a later chunk with the same
    keyword overwrites an earlier one.  Text chunks that are malformed, use a
    compression method other than 0, or fail to decompress are skipped
    silently.  Chunk CRCs are read but not verified.  Scanning stops at the
    ``IEND`` chunk or at a clean end of file.

    Raises:
        ValueError: the file does not start with the PNG signature.
        EOFError: the file ends in the middle of the signature or of a chunk.
    """
    texts: dict[str, str] = {}
    with path.open("rb") as stream:
        if _read_exact(stream, 8) != PNG_SIG:
            raise ValueError("Not a PNG file")

        # Zero bytes where a length field should start means a clean end of file.
        while size_field := stream.read(4):
            if len(size_field) < 4:
                raise EOFError("Unexpected EOF while reading PNG chunk length")

            payload_len = int.from_bytes(size_field, "big", signed=False)
            chunk_type = _read_exact(stream, 4)
            payload = _read_exact(stream, payload_len)
            _read_exact(stream, 4)  # CRC: consumed, not checked

            if chunk_type == b"IEND":
                break
            if chunk_type not in (b"tEXt", b"zTXt", b"iTXt"):
                continue

            # All three text chunk types start with "<keyword> NUL".
            keyword, nul, body = payload.partition(b"\x00")
            if not nul:
                continue

            if chunk_type == b"zTXt":
                # body = one compression-method byte + compressed text
                if not body or body[0] != 0:
                    continue
                try:
                    body = zlib.decompress(body[1:])
                except zlib.error:
                    continue
            elif chunk_type == b"iTXt":
                # body = flag byte, method byte, "<field> NUL", "<field> NUL", text
                if len(body) < 2:
                    continue
                compressed, method = body[0], body[1]
                _, nul, body = body[2:].partition(b"\x00")
                if not nul:
                    continue
                _, nul, body = body.partition(b"\x00")
                if not nul:
                    continue
                if compressed == 1:
                    if method != 0:
                        continue
                    try:
                        body = zlib.decompress(body)
                    except zlib.error:
                        continue

            texts[_decode_text(keyword)] = _decode_text(body)

    return texts
|
|
|
|
| def _looks_like_workflow(obj: Any) -> bool: |
| if not isinstance(obj, dict): |
| return False |
| return isinstance(obj.get("nodes"), list) and isinstance(obj.get("links"), list) |
|
|
|
|
def extract_workflow_from_metadata(meta: dict[str, str]) -> dict[str, Any]:
    """Find the workflow graph inside PNG text metadata.

    Candidates are tried in this order and the first one that looks like a
    workflow (see ``_looks_like_workflow``) is returned:

    1. the JSON stored under ``meta["workflow"]``;
    2. from the JSON stored under ``meta["prompt"]`` (when it is an object):
       its ``"workflow"`` entry, then ``"extra_pnginfo"["workflow"]``, then the
       prompt object itself.

    Entries that are not valid JSON are ignored.

    Raises:
        ValueError: none of the candidates is a workflow.
    """
    if "workflow" in meta:
        try:
            candidate = json.loads(meta["workflow"])
        except json.JSONDecodeError:
            pass
        else:
            if _looks_like_workflow(candidate):
                return candidate

    if "prompt" in meta:
        try:
            prompt = json.loads(meta["prompt"])
        except json.JSONDecodeError:
            prompt = None

        if isinstance(prompt, dict):
            embedded = prompt.get("workflow")
            if _looks_like_workflow(embedded):
                return embedded

            png_info = prompt.get("extra_pnginfo")
            if isinstance(png_info, dict):
                nested = png_info.get("workflow")
                if _looks_like_workflow(nested):
                    return nested

            if _looks_like_workflow(prompt):
                return prompt

    raise ValueError("Could not find a ComfyUI workflow in PNG metadata (looked for 'workflow' and 'prompt').")
|
|
|
|
| @dataclass(frozen=True) |
| class Size: |
| w: float |
| h: float |
|
|
|
|
| def _get_size(node: dict[str, Any]) -> Size: |
| size = node.get("size") |
| if isinstance(size, list) and len(size) == 2 and all(isinstance(x, (int, float)) for x in size): |
| w, h = float(size[0]), float(size[1]) |
| else: |
| w, h = 320.0, 120.0 |
|
|
| |
| w = max(w, 260.0) |
| h = max(h, 120.0) |
| return Size(w=w, h=h) |
|
|
|
|
| def _build_graph(links: list[Any]) -> tuple[dict[int, list[int]], dict[int, int]]: |
| adj: dict[int, list[int]] = defaultdict(list) |
| indeg: dict[int, int] = defaultdict(int) |
| for link in links: |
| if not isinstance(link, list) or len(link) < 4: |
| continue |
| try: |
| src = int(link[1]) |
| dst = int(link[3]) |
| except (ValueError, TypeError): |
| continue |
| adj[src].append(dst) |
| indeg[dst] += 1 |
| indeg.setdefault(src, indeg.get(src, 0)) |
| return adj, indeg |
|
|
|
|
| def _kahn_order(node_ids: list[int], adj: dict[int, list[int]], indeg: dict[int, int]) -> list[int]: |
| q = deque([n for n in node_ids if indeg.get(n, 0) == 0]) |
| seen: set[int] = set(q) |
| out: list[int] = [] |
|
|
| indeg_local = dict(indeg) |
| while q: |
| n = q.popleft() |
| out.append(n) |
| for m in adj.get(n, []): |
| indeg_local[m] = indeg_local.get(m, 0) - 1 |
| if indeg_local[m] == 0 and m not in seen: |
| seen.add(m) |
| q.append(m) |
|
|
| remaining = [n for n in node_ids if n not in set(out)] |
| remaining.sort() |
| out.extend(remaining) |
| return out |
|
|
|
|
| def _force_uncollapse(nodes: list[dict[str, Any]]) -> None: |
| for n in nodes: |
| flags = n.get("flags") |
| if isinstance(flags, dict): |
| flags["collapsed"] = False |
| n["flags"] = flags |
| else: |
| n["flags"] = {"collapsed": False} |
|
|
|
|
| def _is_wd14_tagger_node(node: dict[str, Any]) -> bool: |
| props = node.get("properties") |
| if not isinstance(props, dict): |
| return False |
| return props.get("cnr_id") == WD14_TAGGER_CNR_ID |
|
|
|
|
| def _extract_wd14_tags(node: dict[str, Any]) -> Optional[str]: |
| widgets = node.get("widgets_values") |
| if not isinstance(widgets, list) or not widgets: |
| return None |
| last = widgets[-1] |
| if isinstance(last, str): |
| return last |
| return None |
|
|
|
|
| def _pick_multistring_template(nodes: list[dict[str, Any]]) -> Optional[dict[str, Any]]: |
| for n in nodes: |
| if not isinstance(n, dict): |
| continue |
| if n.get("type") == MULTISTRING_TYPE: |
| return n |
| return None |
|
|
|
|
| def _max_node_id(nodes: list[dict[str, Any]]) -> int: |
| max_id = 0 |
| for n in nodes: |
| if not isinstance(n, dict): |
| continue |
| nid = n.get("id") |
| if isinstance(nid, int) and nid > max_id: |
| max_id = nid |
| return max_id |
|
|
|
|
| def _string_output_from_template(template: Optional[dict[str, Any]]) -> list[dict[str, Any]]: |
| out: dict[str, Any] = {"label": "STRING", "name": "STRING", "type": "STRING", "links": []} |
| if template is not None: |
| outputs = template.get("outputs") |
| if isinstance(outputs, list) and outputs: |
| first = outputs[0] |
| if isinstance(first, dict): |
| for key in ("label", "name", "type", "shape", "slot_index"): |
| if key in first: |
| out[key] = first[key] |
| return [out] |
|
|
|
|
| def _size_from_template(template: Optional[dict[str, Any]]) -> list[float]: |
| if template is not None: |
| size = template.get("size") |
| if isinstance(size, list) and len(size) == 2 and all(isinstance(x, (int, float)) for x in size): |
| return [float(size[0]), float(size[1])] |
| return [DEFAULT_STRING_NODE_SIZE[0], DEFAULT_STRING_NODE_SIZE[1]] |
|
|
|
|
| def _widget_values_from_template(template: Optional[dict[str, Any]], text: str) -> list[Any]: |
| speak = True |
| if template is not None: |
| widgets = template.get("widgets_values") |
| if isinstance(widgets, list) and len(widgets) > 1 and isinstance(widgets[1], bool): |
| speak = widgets[1] |
| return [text, speak] |
|
|
|
|
| def _flags_from_template(template: Optional[dict[str, Any]]) -> dict[str, Any]: |
| if template is not None: |
| flags = template.get("flags") |
| if isinstance(flags, dict): |
| out_flags = dict(flags) |
| out_flags["collapsed"] = False |
| return out_flags |
| return {"collapsed": False} |
|
|
|
|
def _properties_from_template(template: Optional[dict[str, Any]], source_id: int) -> dict[str, Any]:
    """Return properties for a generated node.

    Copies (shallowly) the template's ``properties`` when they are a dict, then
    records *source_id* under ``WD14_SOURCE_PROP``, overriding any value the
    template had for that key.
    """
    inherited = template.get("properties") if template is not None else None
    merged: dict[str, Any] = dict(inherited) if isinstance(inherited, dict) else {}
    merged[WD14_SOURCE_PROP] = source_id
    return merged
|
|
|
|
def _make_multistring_node(
    node_id: int, text: str, template: Optional[dict[str, Any]], source_id: int
) -> dict[str, Any]:
    """Create the string node that holds *text* for tagger node *source_id*.

    The node gets id *node_id*, type ``MULTISTRING_TYPE``, title
    ``WD14_TAGS_TITLE``, position ``[0.0, 0.0]`` and no inputs.  Size, flags,
    output slot, properties and widget values come from the ``*_from_template``
    helpers.  ``order`` and ``mode`` are copied from *template* when they are
    integers (0 otherwise), and ``color`` / ``bgcolor`` are copied when present.
    """
    donor: dict[str, Any] = {} if template is None else template

    def _int_or_zero(key: str) -> int:
        value = donor.get(key)
        return value if isinstance(value, int) else 0

    node: dict[str, Any] = {
        "id": node_id,
        "type": MULTISTRING_TYPE,
        "pos": [0.0, 0.0],
        "size": _size_from_template(template),
        "flags": _flags_from_template(template),
        "order": _int_or_zero("order"),
        "mode": _int_or_zero("mode"),
        "inputs": [],
        "outputs": _string_output_from_template(template),
        "title": WD14_TAGS_TITLE,
        "properties": _properties_from_template(template, source_id),
        "widgets_values": _widget_values_from_template(template, text),
    }
    for tint in ("color", "bgcolor"):
        if tint in donor:
            node[tint] = donor[tint]
    return node
|
|
|
|
| def _existing_wd14_sources(nodes: list[dict[str, Any]]) -> set[int]: |
| sources: set[int] = set() |
| for n in nodes: |
| if not isinstance(n, dict): |
| continue |
| props = n.get("properties") |
| if not isinstance(props, dict): |
| continue |
| source_id = props.get(WD14_SOURCE_PROP) |
| if isinstance(source_id, int): |
| sources.add(source_id) |
| return sources |
|
|
|
|
| def _add_wd14_multistring_nodes(wf: dict[str, Any]) -> dict[int, int]: |
| nodes = wf.get("nodes") |
| if not isinstance(nodes, list): |
| return {} |
| existing_sources = _existing_wd14_sources(nodes) |
| template = _pick_multistring_template(nodes) |
| next_id = _max_node_id(nodes) + 1 |
| added: dict[int, int] = {} |
| for n in list(nodes): |
| if not isinstance(n, dict): |
| continue |
| if not _is_wd14_tagger_node(n): |
| continue |
| nid = n.get("id") |
| if not isinstance(nid, int): |
| continue |
| if nid in existing_sources: |
| continue |
| text = _extract_wd14_tags(n) |
| if text is None: |
| continue |
| nodes.append(_make_multistring_node(next_id, text, template, nid)) |
| added[next_id] = nid |
| next_id += 1 |
| if added: |
| last_node_id = wf.get("last_node_id") |
| new_last = next_id - 1 |
| if isinstance(last_node_id, int): |
| wf["last_node_id"] = max(last_node_id, new_last) |
| else: |
| wf["last_node_id"] = new_last |
| return added |
|
|
|
|
| def _order_with_multistrings(order: list[int], mapping: dict[int, int]) -> list[int]: |
| if not mapping: |
| return order |
| new_ids = set(mapping) |
| base = [n for n in order if n not in new_ids] |
| children: dict[int, list[int]] = defaultdict(list) |
| for new_id, source_id in mapping.items(): |
| children[source_id].append(new_id) |
| for lst in children.values(): |
| lst.sort() |
| out: list[int] = [] |
| appended: set[int] = set() |
| for n in base: |
| out.append(n) |
| appended.add(n) |
| kids = children.get(n) |
| if kids: |
| for kid in kids: |
| out.append(kid) |
| appended.add(kid) |
| if len(appended) < len(base) + len(new_ids): |
| for new_id in sorted(new_ids): |
| if new_id not in appended: |
| out.append(new_id) |
| return out |
|
|
|
|
def relayout_workflow(wf: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *wf* with its nodes arranged in left-to-right columns.

    Steps, all performed on a deep copy (the input is never modified):

    1. add a string node next to each WD14 tagger node
       (``_add_wd14_multistring_nodes``) and expand every node;
    2. order the nodes topologically along the links and give each node a
       column: one more than the deepest of its predecessors.  A generated
       string node shares the column of the tagger it came from;
    3. stack the nodes of each column top to bottom, columns being as wide as
       their widest node, with fixed margins in between;
    4. reset the view (``extra.ds`` scale 1.0, offset ``[0, 0]``) and make sure
       ``groups`` and ``config`` exist.

    Malformed input is tolerated: non-dict node entries and nodes without an
    integer id are left alone, links that reference ids with no node are
    ignored for placement, and a repeated node id is laid out once (the last
    node carrying it is the one positioned).

    Raises:
        ValueError: ``wf["nodes"]`` is missing or not a list.
    """
    # A JSON round-trip gives a deep copy, so the caller's workflow stays untouched.
    wf2 = json.loads(json.dumps(wf))
    nodes = wf2.get("nodes")
    if not isinstance(nodes, list):
        raise ValueError("Workflow missing 'nodes' list")

    tagger_map = _add_wd14_multistring_nodes(wf2)
    # Hand over real nodes only; the dicts are mutated in place.
    _force_uncollapse([n for n in nodes if isinstance(n, dict)])

    node_by_id: dict[int, dict[str, Any]] = {}
    for n in nodes:
        if not isinstance(n, dict):
            continue
        nid = n.get("id")
        if isinstance(nid, int):
            node_by_id[nid] = n
    # Each id once, even if the workflow repeats it.
    node_ids: list[int] = sorted(node_by_id)

    links = wf2.get("links", [])
    if not isinstance(links, list):
        links = []

    adj, indeg = _build_graph(links)
    order = _kahn_order(node_ids, adj, indeg)
    if tagger_map:
        order = _order_with_multistrings(order, tagger_map)
    # Links may mention ids that have no node; those cannot be placed.
    order = [nid for nid in order if nid in node_by_id]

    # Longest-path layering: a node sits one column right of its deepest source.
    layer: dict[int, int] = {nid: 0 for nid in node_ids}
    for n in order:
        base = layer.get(n, 0)
        for m in adj.get(n, []):
            layer[m] = max(layer.get(m, 0), base + 1)
    if tagger_map:
        for new_id, source_id in tagger_map.items():
            layer[new_id] = layer.get(source_id, layer.get(new_id, 0))

    layers: dict[int, list[int]] = defaultdict(list)
    for nid in order:
        layers[layer.get(nid, 0)].append(nid)

    # Column width = widest node in the column.
    layer_ids = sorted(layers.keys())
    max_w: dict[int, float] = {}
    for layer_id in layer_ids:
        max_w[layer_id] = max((_get_size(node_by_id[nid]).w for nid in layers[layer_id]), default=320.0)

    x_margin = 140.0
    y_margin = 60.0

    x_pos: dict[int, float] = {}
    cursor = 0.0
    for layer_id in layer_ids:
        x_pos[layer_id] = cursor
        cursor += max_w[layer_id] + x_margin

    # Stack each column from the top, in topological order.
    for layer_id in layer_ids:
        y_cursor = 0.0
        for nid in layers[layer_id]:
            s = _get_size(node_by_id[nid])
            node_by_id[nid]["pos"] = [x_pos[layer_id], y_cursor]
            y_cursor += s.h + y_margin

    # Reset the stored view so the graph opens unzoomed at the origin.
    extra = wf2.get("extra")
    if not isinstance(extra, dict):
        extra = {}
        wf2["extra"] = extra
    ds = extra.get("ds")
    if not isinstance(ds, dict):
        ds = {}
        extra["ds"] = ds
    ds["scale"] = 1.0
    ds["offset"] = [0, 0]

    if "groups" not in wf2:
        wf2["groups"] = []
    if "config" not in wf2:
        wf2["config"] = {}

    return wf2
|
|
|
|
def process_png(png_path: Path) -> Path:
    """Extract the workflow embedded in *png_path*, clean it up and save it as JSON.

    The output is written next to the image as ``<name without suffix>_clean_workflow.json``
    (UTF-8, indented, non-ASCII characters kept as-is).

    Returns:
        Path of the JSON file that was written.

    Raises:
        Whatever reading the PNG, locating the workflow, re-laying it out or
        writing the file raises.
    """
    metadata = read_png_text_chunks(png_path)
    cleaned = relayout_workflow(extract_workflow_from_metadata(metadata))

    stripped = png_path.with_suffix("")
    target = stripped.with_name(f"{stripped.name}_clean_workflow.json")
    target.write_text(json.dumps(cleaned, ensure_ascii=False, indent=2), encoding="utf-8")
    return target
|
|
|
|
def main(argv: list[str]) -> int:
    """Command-line entry point: convert every PNG named in ``argv[1:]``.

    Each argument has surrounding double quotes stripped and ``~`` expanded.
    One status line is printed per file; a failure on one file does not stop
    the others.

    Returns:
        2 when no file was given (usage is printed), 0 when at least one file
        was converted, 1 when every file failed.
    """
    targets = argv[1:]
    if not targets:
        print("Drag-and-drop one or more .png files onto this script, or run:")
        print(" python png_to_clean_workflow.py <image.png> [more.png ...]")
        return 2

    converted = 0
    for raw in targets:
        png = Path(raw.strip('"')).expanduser()
        try:
            written = process_png(png)
            print(f"[OK] {png.name} -> {written.name}")
            converted += 1
        except Exception as err:
            # Top-level boundary: report and carry on with the next file.
            print(f"[ERR] {png}: {err}")

    return 1 if converted == 0 else 0
|
|
|
|
if __name__ == "__main__":
    # Run as a script: the return value of main() becomes the process exit status.
    exit_status = main(sys.argv)
    raise SystemExit(exit_status)
|
|