"""Extract a ComfyUI workflow from PNG text metadata and write a tidied copy.

For each PNG given on the command line the script reads the PNG text chunks,
locates the embedded workflow JSON, adds a "String Literal" node holding the
tags of every WD14 tagger node, un-collapses all nodes, lays the graph out in
left-to-right layers and writes the result next to the image as
``<name>_clean_workflow.json``.
"""

import json
import sys
import zlib
from collections import defaultdict, deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

WD14_TAGGER_CNR_ID = "comfyui-wd14-tagger"
MULTISTRING_TYPE = "String Literal"
WD14_TAGS_TITLE = "WD14 Tags"
WD14_SOURCE_PROP = "wd14_source_id"
DEFAULT_STRING_NODE_SIZE = (600.0, 120.0)

PNG_SIG = b"\x89PNG\r\n\x1a\n"


def _decode_text(b: bytes) -> str:
    """Decode *b* as UTF-8, falling back to Latin-1 when that fails."""
    try:
        return b.decode("utf-8")
    except UnicodeDecodeError:
        return b.decode("latin-1", "replace")


def _read_exact(f, n: int) -> bytes:
    """Read exactly *n* bytes from *f* or raise ``EOFError``."""
    b = f.read(n)
    if len(b) != n:
        raise EOFError("Unexpected EOF while reading PNG")
    return b


def _parse_text_chunk(data: bytes) -> Optional[tuple[str, str]]:
    """Parse a ``tEXt`` payload: ``keyword NUL text``."""
    key, sep, value = data.partition(b"\x00")
    if not sep:
        return None
    return _decode_text(key), _decode_text(value)


def _parse_ztxt_chunk(data: bytes) -> Optional[tuple[str, str]]:
    """Parse a ``zTXt`` payload: ``keyword NUL method compressed-text``.

    Returns ``None`` for an unknown compression method or corrupt data.
    """
    key, sep, rest = data.partition(b"\x00")
    if not sep or not rest:
        return None
    if rest[0] != 0:  # only method 0 (zlib/deflate) is handled
        return None
    try:
        value = zlib.decompress(rest[1:])
    except zlib.error:
        return None
    return _decode_text(key), _decode_text(value)


def _parse_itxt_chunk(data: bytes) -> Optional[tuple[str, str]]:
    """Parse an ``iTXt`` payload.

    Layout read here: ``keyword NUL flag method field NUL field NUL text``.
    The two NUL-terminated fields after the method byte are skipped.
    Returns ``None`` when the payload is malformed or cannot be decompressed.
    """
    key, sep, rest = data.partition(b"\x00")
    if not sep or len(rest) < 2:
        return None
    comp_flag, comp_method = rest[0], rest[1]
    _, sep, rest = rest[2:].partition(b"\x00")
    if not sep:
        return None
    _, sep, text_bytes = rest.partition(b"\x00")
    if not sep:
        return None
    if comp_flag == 1:
        if comp_method != 0:
            return None
        try:
            text_bytes = zlib.decompress(text_bytes)
        except zlib.error:
            return None
    return _decode_text(key), _decode_text(text_bytes)


_TEXT_CHUNK_PARSERS = {
    b"tEXt": _parse_text_chunk,
    b"zTXt": _parse_ztxt_chunk,
    b"iTXt": _parse_itxt_chunk,
}


def read_png_text_chunks(path: Path) -> dict[str, str]:
    """Return the keyword -> text mapping of all text chunks in a PNG file.

    ``tEXt``, ``zTXt`` and ``iTXt`` chunks are read; malformed text chunks are
    skipped silently. A later chunk with the same keyword overwrites an
    earlier one. Reading stops at ``IEND`` or at a clean end of file.

    Raises:
        ValueError: the file does not start with the PNG signature.
        EOFError: the file ends in the middle of a chunk.
    """
    out: dict[str, str] = {}
    with path.open("rb") as f:
        if _read_exact(f, 8) != PNG_SIG:
            raise ValueError("Not a PNG file")
        while True:
            length_b = f.read(4)
            if not length_b:
                break
            if len(length_b) != 4:
                raise EOFError("Unexpected EOF while reading PNG chunk length")
            length = int.from_bytes(length_b, "big", signed=False)
            ctype = _read_exact(f, 4)
            data = _read_exact(f, length)
            _read_exact(f, 4)  # CRC: consumed but not verified

            parser = _TEXT_CHUNK_PARSERS.get(ctype)
            if parser is not None:
                item = parser(data)
                if item is not None:
                    key, value = item
                    out[key] = value
            elif ctype == b"IEND":
                break
    return out


def _looks_like_workflow(obj: Any) -> bool:
    """Return True when *obj* is a dict with list-valued ``nodes`` and ``links``."""
    if not isinstance(obj, dict):
        return False
    return isinstance(obj.get("nodes"), list) and isinstance(obj.get("links"), list)


def extract_workflow_from_metadata(meta: dict[str, str]) -> dict[str, Any]:
    """Find the workflow dict inside PNG text metadata.

    The ``workflow`` entry is tried first. Otherwise the ``prompt`` entry is
    searched at ``prompt["workflow"]``, ``prompt["extra_pnginfo"]["workflow"]``
    and finally the prompt object itself.

    Raises:
        ValueError: no candidate has the shape of a workflow.
    """
    if "workflow" in meta:
        try:
            wf = json.loads(meta["workflow"])
        except json.JSONDecodeError:
            wf = None
        if _looks_like_workflow(wf):
            return wf

    if "prompt" in meta:
        try:
            p = json.loads(meta["prompt"])
        except json.JSONDecodeError:
            p = None
        if isinstance(p, dict):
            if _looks_like_workflow(p.get("workflow")):
                return p["workflow"]
            extra = p.get("extra_pnginfo")
            if isinstance(extra, dict) and _looks_like_workflow(extra.get("workflow")):
                return extra["workflow"]
            if _looks_like_workflow(p):
                return p

    raise ValueError("Could not find a ComfyUI workflow in PNG metadata (looked for 'workflow' and 'prompt').")


@dataclass(frozen=True)
class Size:
    """Width and height of a node, as used by the layout code."""

    w: float
    h: float


def _get_size(node: dict[str, Any]) -> Size:
    """Return the layout size of *node*, never smaller than 260 x 120.

    A missing or malformed ``size`` entry is treated as 320 x 120.
    """
    size = node.get("size")
    if isinstance(size, list) and len(size) == 2 and all(isinstance(x, (int, float)) for x in size):
        w, h = float(size[0]), float(size[1])
    else:
        w, h = 320.0, 120.0
    # ComfyUI nodes can be taller than their "size" when uncollapsed; give margin.
    w = max(w, 260.0)
    h = max(h, 120.0)
    return Size(w=w, h=h)


def _build_graph(
    links: list[Any], known_ids: Optional[set[int]] = None
) -> tuple[dict[int, list[int]], dict[int, int]]:
    """Build adjacency lists and in-degree counts from workflow links.

    Each usable link is a list of at least four items; item 1 is read as the
    source node id and item 3 as the destination node id. Other entries are
    ignored.

    Args:
        links: raw ``links`` list of the workflow.
        known_ids: when given, links with an endpoint outside this set are
            dropped so that dangling links cannot influence the layout.

    Returns:
        ``(adj, indeg)`` where ``adj[src]`` lists destinations and
        ``indeg[node]`` counts incoming links.
    """
    adj: dict[int, list[int]] = defaultdict(list)
    indeg: dict[int, int] = defaultdict(int)
    for link in links:
        if not isinstance(link, list) or len(link) < 4:
            continue
        try:
            src = int(link[1])
            dst = int(link[3])
        except (ValueError, TypeError):
            continue
        if known_ids is not None and (src not in known_ids or dst not in known_ids):
            continue
        adj[src].append(dst)
        indeg[dst] += 1
        indeg.setdefault(src, 0)
    return adj, indeg


def _kahn_order(node_ids: list[int], adj: dict[int, list[int]], indeg: dict[int, int]) -> list[int]:
    """Return *node_ids* in topological order (Kahn's algorithm).

    Nodes that can never reach in-degree zero (for example members of a cycle)
    are appended at the end in ascending id order. Ids that appear only in
    *adj* and not in *node_ids* are never emitted.
    """
    known = set(node_ids)
    q = deque(n for n in node_ids if indeg.get(n, 0) == 0)
    seen: set[int] = set(q)
    out: list[int] = []
    indeg_local = dict(indeg)  # work on a copy; the caller's counts stay intact
    while q:
        n = q.popleft()
        out.append(n)
        for m in adj.get(n, []):
            indeg_local[m] = indeg_local.get(m, 0) - 1
            if indeg_local[m] == 0 and m in known and m not in seen:
                seen.add(m)
                q.append(m)
    emitted = set(out)
    out.extend(sorted(n for n in node_ids if n not in emitted))
    return out


def _force_uncollapse(nodes: list[dict[str, Any]]) -> None:
    """Set ``flags["collapsed"] = False`` on every dict node, in place."""
    for n in nodes:
        if not isinstance(n, dict):
            continue
        flags = n.get("flags")
        if isinstance(flags, dict):
            flags["collapsed"] = False
        else:
            n["flags"] = {"collapsed": False}


def _is_wd14_tagger_node(node: dict[str, Any]) -> bool:
    """Return True when the node's ``properties.cnr_id`` is the WD14 tagger id."""
    props = node.get("properties")
    if not isinstance(props, dict):
        return False
    return props.get("cnr_id") == WD14_TAGGER_CNR_ID


def _extract_wd14_tags(node: dict[str, Any]) -> Optional[str]:
    """Return the last widget value of *node* if it is a string, else None."""
    widgets = node.get("widgets_values")
    if not isinstance(widgets, list) or not widgets:
        return None
    last = widgets[-1]
    return last if isinstance(last, str) else None


def _pick_multistring_template(nodes: list[dict[str, Any]]) -> Optional[dict[str, Any]]:
    """Return the first existing node of type ``MULTISTRING_TYPE``, if any."""
    for n in nodes:
        if isinstance(n, dict) and n.get("type") == MULTISTRING_TYPE:
            return n
    return None


def _max_node_id(nodes: list[dict[str, Any]]) -> int:
    """Return the largest integer node id, or 0 when there is none above 0."""
    max_id = 0
    for n in nodes:
        if not isinstance(n, dict):
            continue
        nid = n.get("id")
        if isinstance(nid, int) and nid > max_id:
            max_id = nid
    return max_id


def _string_output_from_template(template: Optional[dict[str, Any]]) -> list[dict[str, Any]]:
    """Build the single STRING output slot, copying cosmetics from *template*.

    The new slot always starts with an empty ``links`` list.
    """
    out: dict[str, Any] = {"label": "STRING", "name": "STRING", "type": "STRING", "links": []}
    if template is not None:
        outputs = template.get("outputs")
        if isinstance(outputs, list) and outputs:
            first = outputs[0]
            if isinstance(first, dict):
                for key in ("label", "name", "type", "shape", "slot_index"):
                    if key in first:
                        out[key] = first[key]
    return [out]


def _size_from_template(template: Optional[dict[str, Any]]) -> list[float]:
    """Return the template's ``size`` as floats, or the default string-node size."""
    if template is not None:
        size = template.get("size")
        if isinstance(size, list) and len(size) == 2 and all(isinstance(x, (int, float)) for x in size):
            return [float(size[0]), float(size[1])]
    return [DEFAULT_STRING_NODE_SIZE[0], DEFAULT_STRING_NODE_SIZE[1]]


def _widget_values_from_template(template: Optional[dict[str, Any]], text: str) -> list[Any]:
    """Return ``[text, flag]``; the flag is the template's second widget if boolean, else True."""
    speak = True
    if template is not None:
        widgets = template.get("widgets_values")
        if isinstance(widgets, list) and len(widgets) > 1 and isinstance(widgets[1], bool):
            speak = widgets[1]
    return [text, speak]


def _flags_from_template(template: Optional[dict[str, Any]]) -> dict[str, Any]:
    """Return a copy of the template's flags with ``collapsed`` forced to False."""
    if template is not None:
        flags = template.get("flags")
        if isinstance(flags, dict):
            out_flags = dict(flags)
            out_flags["collapsed"] = False
            return out_flags
    return {"collapsed": False}


def _properties_from_template(template: Optional[dict[str, Any]], source_id: int) -> dict[str, Any]:
    """Copy the template's properties and record *source_id* under ``WD14_SOURCE_PROP``."""
    props: dict[str, Any] = {}
    if template is not None:
        template_props = template.get("properties")
        if isinstance(template_props, dict):
            props.update(template_props)
    props[WD14_SOURCE_PROP] = source_id
    return props


def _make_multistring_node(
    node_id: int, text: str, template: Optional[dict[str, Any]], source_id: int
) -> dict[str, Any]:
    """Create a ``MULTISTRING_TYPE`` node holding *text*.

    Appearance (size, flags, order, mode, colours, output slot, properties) is
    copied from *template* when one is available. The node is created at
    position ``[0.0, 0.0]`` and tagged with *source_id*.
    """
    order = 0
    mode = 0
    if template is not None:
        if isinstance(template.get("order"), int):
            order = template["order"]
        if isinstance(template.get("mode"), int):
            mode = template["mode"]

    node: dict[str, Any] = {
        "id": node_id,
        "type": MULTISTRING_TYPE,
        "pos": [0.0, 0.0],
        "size": _size_from_template(template),
        "flags": _flags_from_template(template),
        "order": order,
        "mode": mode,
        "inputs": [],
        "outputs": _string_output_from_template(template),
        "title": WD14_TAGS_TITLE,
        "properties": _properties_from_template(template, source_id),
        "widgets_values": _widget_values_from_template(template, text),
    }
    if template is not None:
        for key in ("color", "bgcolor"):
            if key in template:
                node[key] = template[key]
    return node


def _existing_wd14_sources(nodes: list[dict[str, Any]]) -> set[int]:
    """Return the source ids already recorded under ``WD14_SOURCE_PROP``."""
    sources: set[int] = set()
    for n in nodes:
        if not isinstance(n, dict):
            continue
        props = n.get("properties")
        if not isinstance(props, dict):
            continue
        source_id = props.get(WD14_SOURCE_PROP)
        if isinstance(source_id, int):
            sources.add(source_id)
    return sources


def _add_wd14_multistring_nodes(wf: dict[str, Any]) -> dict[int, int]:
    """Append a string node for every WD14 tagger node that lacks one.

    Mutates ``wf["nodes"]`` and, when nodes were added, ``wf["last_node_id"]``.
    Tagger nodes already referenced by an existing ``WD14_SOURCE_PROP`` are
    skipped, so running the function twice adds nothing the second time.

    Returns:
        Mapping of new node id -> id of the tagger node it was created for.
    """
    nodes = wf.get("nodes")
    if not isinstance(nodes, list):
        return {}

    existing_sources = _existing_wd14_sources(nodes)
    template = _pick_multistring_template(nodes)
    next_id = _max_node_id(nodes) + 1
    added: dict[int, int] = {}

    # Iterate over a snapshot: new nodes are appended to `nodes` inside the loop.
    for n in list(nodes):
        if not isinstance(n, dict) or not _is_wd14_tagger_node(n):
            continue
        nid = n.get("id")
        if not isinstance(nid, int) or nid in existing_sources:
            continue
        text = _extract_wd14_tags(n)
        if text is None:
            continue
        nodes.append(_make_multistring_node(next_id, text, template, nid))
        added[next_id] = nid
        next_id += 1

    if added:
        last_node_id = wf.get("last_node_id")
        new_last = next_id - 1
        if isinstance(last_node_id, int):
            wf["last_node_id"] = max(last_node_id, new_last)
        else:
            wf["last_node_id"] = new_last
    return added


def _order_with_multistrings(order: list[int], mapping: dict[int, int]) -> list[int]:
    """Move each newly added node directly after its source node in *order*.

    *mapping* is new node id -> source node id. New nodes whose source does
    not occur in *order* are appended at the end in ascending id order.
    """
    if not mapping:
        return order

    new_ids = set(mapping)
    base = [n for n in order if n not in new_ids]
    children: dict[int, list[int]] = defaultdict(list)
    for new_id, source_id in mapping.items():
        children[source_id].append(new_id)
    for lst in children.values():
        lst.sort()

    out: list[int] = []
    appended: set[int] = set()
    for n in base:
        out.append(n)
        appended.add(n)
        for kid in children.get(n, ()):
            out.append(kid)
            appended.add(kid)

    out.extend(new_id for new_id in sorted(new_ids) if new_id not in appended)
    return out


def relayout_workflow(wf: dict[str, Any]) -> dict[str, Any]:
    """Return a deep copy of *wf* with WD14 tag nodes added and a fresh layout.

    Nodes are un-collapsed, assigned to layers by link depth, placed in
    columns wide enough for the widest node of each layer and stacked
    vertically without overlap. The viewport (``extra.ds``) is reset and
    ``groups`` / ``config`` are created when missing. The input is not
    modified. Links that reference node ids absent from ``nodes`` are ignored
    for layout purposes but left in the output untouched.

    Raises:
        ValueError: *wf* has no ``nodes`` list.
    """
    wf2 = json.loads(json.dumps(wf))  # deep copy via JSON round-trip
    nodes = wf2.get("nodes")
    if not isinstance(nodes, list):
        raise ValueError("Workflow missing 'nodes' list")

    tagger_map = _add_wd14_multistring_nodes(wf2)
    _force_uncollapse(nodes)

    node_by_id: dict[int, dict[str, Any]] = {}
    node_ids: list[int] = []
    for n in nodes:
        if not isinstance(n, dict):
            continue
        nid = n.get("id")
        if isinstance(nid, int):
            node_by_id[nid] = n
            node_ids.append(nid)

    links = wf2.get("links", [])
    if not isinstance(links, list):
        links = []

    # Restrict the graph to known nodes so a dangling link cannot inject an id
    # that has no entry in node_by_id.
    adj, indeg = _build_graph(links, set(node_by_id))
    order = _kahn_order(sorted(node_ids), adj, indeg)
    if tagger_map:
        order = _order_with_multistrings(order, tagger_map)

    # Longest-path layering: a node sits one layer right of its deepest parent.
    layer: dict[int, int] = {nid: 0 for nid in node_ids}
    for n in order:
        base = layer.get(n, 0)
        for m in adj.get(n, []):
            layer[m] = max(layer.get(m, 0), base + 1)

    # Tag nodes share the layer of the tagger they were generated from.
    for new_id, source_id in tagger_map.items():
        layer[new_id] = layer.get(source_id, layer.get(new_id, 0))

    layers: dict[int, list[int]] = defaultdict(list)
    for nid in order:
        layers[layer.get(nid, 0)].append(nid)

    # Compute per-layer max widths so x-spacing never overlaps.
    layer_ids = sorted(layers)
    max_w: dict[int, float] = {
        layer_id: max((_get_size(node_by_id[nid]).w for nid in layers[layer_id]), default=320.0)
        for layer_id in layer_ids
    }

    x_margin = 140.0
    y_margin = 60.0

    x_pos: dict[int, float] = {}
    cursor = 0.0
    for layer_id in layer_ids:
        x_pos[layer_id] = cursor
        cursor += max_w[layer_id] + x_margin

    # Place nodes in each layer using cumulative heights (so tall nodes don't overlap).
    for layer_id in layer_ids:
        y_cursor = 0.0
        for nid in layers[layer_id]:
            node = node_by_id[nid]
            node["pos"] = [x_pos[layer_id], y_cursor]
            y_cursor += _get_size(node).h + y_margin

    # Reset viewport so it opens nicely.
    extra = wf2.get("extra")
    if not isinstance(extra, dict):
        extra = {}
        wf2["extra"] = extra
    ds = extra.get("ds")
    if not isinstance(ds, dict):
        ds = {}
        extra["ds"] = ds
    ds["scale"] = 1.0
    ds["offset"] = [0, 0]

    wf2.setdefault("groups", [])
    wf2.setdefault("config", {})
    return wf2


def process_png(png_path: Path) -> Path:
    """Write the cleaned workflow of *png_path* next to it and return its path.

    The output file is named after the image with its last suffix removed and
    ``_clean_workflow.json`` appended.
    """
    meta = read_png_text_chunks(png_path)
    wf = extract_workflow_from_metadata(meta)
    cleaned = relayout_workflow(wf)

    stem_path = png_path.with_suffix("")
    out_path = stem_path.with_name(stem_path.name + "_clean_workflow.json")
    out_path.write_text(json.dumps(cleaned, ensure_ascii=False, indent=2), encoding="utf-8")
    return out_path


def main(argv: list[str]) -> int:
    """Process every path in ``argv[1:]`` and report one line per file.

    Returns 2 when no file was given, 0 when at least one file was processed
    successfully and 1 when every file failed.
    """
    if len(argv) < 2:
        print("Drag-and-drop one or more .png files onto this script, or run:")
        print(" python png_to_clean_workflow.py <file.png> [more.png ...]")
        return 2

    ok = 0
    for a in argv[1:]:
        p = Path(a.strip('"')).expanduser()
        try:
            out = process_png(p)
        except Exception as e:  # top-level boundary: report and go on to the next file
            print(f"[ERR] {p}: {e}")
        else:
            print(f"[OK] {p.name} -> {out.name}")
            ok += 1

    return 0 if ok else 1


if __name__ == "__main__":
    raise SystemExit(main(sys.argv))