#!/usr/bin/env python3 """Split an ONNX graph into smaller sub-models driven by sub_config rules. This script reads a JSON config file (matching the pulsar2 sub_config layout), extracts the requested subgraphs, and optionally emits any leftover parts of the model as independent ONNX graphs. A verification utility can run the original model and the stitched micro-model pipeline to make sure their outputs match. """ from __future__ import annotations import argparse import json import logging from collections import defaultdict, deque from dataclasses import dataclass, field from pathlib import Path from typing import Dict, Iterable, List, Optional, Sequence, Set import numpy as np import onnx from onnx import utils as onnx_utils try: import onnxruntime as ort except ImportError: # pragma: no cover - optional dependency ort = None @dataclass class SubGraphSpec: """Describes a single subgraph to extract from the full model.""" label: str start: List[str] end: List[str] node_names: Set[str] source: str output_path: Optional[Path] = None @dataclass class GraphIndex: """Caches helpful lookups for traversing an ONNX graph.""" tensor_to_producer: Dict[str, str] tensor_to_consumers: Dict[str, List[str]] node_inputs: Dict[str, List[str]] node_outputs: Dict[str, List[str]] graph_inputs: Set[str] graph_outputs: Set[str] initializer_names: Set[str] node_order: List[str] def sanitize(name: str) -> str: keep = [c if c.isalnum() else "_" for c in name] if name else ["anon"] sanitized = "".join(keep).strip("_") return sanitized or "tensor" def build_graph_index(model: onnx.ModelProto) -> GraphIndex: tensor_to_producer: Dict[str, str] = {} tensor_to_consumers: Dict[str, List[str]] = defaultdict(list) node_inputs: Dict[str, List[str]] = {} node_outputs: Dict[str, List[str]] = {} node_order: List[str] = [] used_names: Set[str] = set() for idx, node in enumerate(model.graph.node): base = node.name.strip() if node.name else "" candidate = base or f"node_{idx}" while candidate in used_names: candidate = f"{candidate}_{idx}" used_names.add(candidate) node_name = candidate node_order.append(node_name) node_inputs[node_name] = [x for x in node.input if x] node_outputs[node_name] = [y for y in node.output if y] for out_name in node_outputs[node_name]: tensor_to_producer[out_name] = node_name for inp_name in node_inputs[node_name]: tensor_to_consumers[inp_name].append(node_name) graph_inputs = {vi.name for vi in model.graph.input} graph_outputs = {vi.name for vi in model.graph.output} initializer_names = {init.name for init in model.graph.initializer} return GraphIndex( tensor_to_producer=tensor_to_producer, tensor_to_consumers=tensor_to_consumers, node_inputs=node_inputs, node_outputs=node_outputs, graph_inputs=graph_inputs, graph_outputs=graph_outputs, initializer_names=initializer_names, node_order=node_order, ) def trace_nodes_between( spec: SubGraphSpec, index: GraphIndex, ) -> Set[str]: boundary = set(spec.start) | index.graph_inputs | index.initializer_names visited_tensors: Set[str] = set() stack = list(spec.end) discovered_nodes: Set[str] = set() while stack: tensor = stack.pop() if tensor in visited_tensors: continue visited_tensors.add(tensor) if tensor in boundary: continue producer = index.tensor_to_producer.get(tensor) if not producer: continue if producer in discovered_nodes: continue discovered_nodes.add(producer) for upstream in index.node_inputs.get(producer, []): if upstream and upstream not in boundary: stack.append(upstream) return discovered_nodes def untouched_components( all_nodes: Sequence[str], covered_nodes: Set[str], index: GraphIndex, ) -> List[Set[str]]: remaining = [n for n in all_nodes if n not in covered_nodes] if not remaining: return [] adjacency: Dict[str, Set[str]] = {name: set() for name in remaining} rem_set = set(remaining) for node in remaining: for out_name in index.node_outputs.get(node, []): for consumer in index.tensor_to_consumers.get(out_name, []): if consumer in rem_set: adjacency[node].add(consumer) adjacency[consumer].add(node) for inp_name in index.node_inputs.get(node, []): producer = index.tensor_to_producer.get(inp_name) if producer in rem_set: adjacency[node].add(producer) adjacency[producer].add(node) components: List[Set[str]] = [] visited: Set[str] = set() for node in remaining: if node in visited: continue stack = [node] comp: Set[str] = set() while stack: cur = stack.pop() if cur in visited: continue visited.add(cur) comp.add(cur) stack.extend(adjacency[cur] - visited) components.append(comp) return components def derive_interface( nodes: Set[str], index: GraphIndex, ) -> (List[str], List[str]): produced = set() for node in nodes: produced.update(index.node_outputs.get(node, [])) start: Set[str] = set() for node in nodes: for inp in index.node_inputs.get(node, []): producer = index.tensor_to_producer.get(inp) if producer is None and inp not in index.initializer_names: start.add(inp) elif producer not in nodes and inp not in index.initializer_names: start.add(inp) end: Set[str] = set() for node in nodes: for out in index.node_outputs.get(node, []): consumers = index.tensor_to_consumers.get(out, []) if not consumers: if out in index.graph_outputs: end.add(out) continue if any(consumer not in nodes for consumer in consumers): end.add(out) end.update(index.graph_outputs & produced) if not end and produced: end = produced.copy() return sorted(start), sorted(end) def extract_model_file( model_path: Path, spec: SubGraphSpec, output_dir: Path, suffix: str, ) -> Path: head = sanitize(spec.start[0]) if spec.start else "const" tail = sanitize(spec.end[0]) if spec.end else "out" filename = f"{spec.label}_{head}_to_{tail}_{suffix}.onnx" destination = output_dir / filename onnx_utils.extract_model( model_path.as_posix(), destination.as_posix(), input_names=spec.start, output_names=spec.end, check_model=False, ) logging.info("Saved %s (start=%s, end=%s)", destination.name, spec.start, spec.end) return destination def ordered_specs( specs: Sequence[SubGraphSpec], index: GraphIndex, ) -> List[SubGraphSpec]: available = set(index.graph_inputs) | index.initializer_names pending = list(specs) ordered: List[SubGraphSpec] = [] while pending: progressed = False for spec in list(pending): if set(spec.start).issubset(available): ordered.append(spec) available.update(spec.end) pending.remove(spec) progressed = True if not progressed: missing = {spec.label: sorted(set(spec.start) - available) for spec in pending} raise RuntimeError( "无法解析子图拓扑,缺少以下张量: %s" % missing ) return ordered def run_full_model(model_path: Path, feed_dict: Dict[str, np.ndarray], providers: List[str]): if ort is None: raise RuntimeError("需要 onnxruntime 才能执行验证。") session = ort.InferenceSession(model_path.as_posix(), providers=providers) outputs = session.run(None, feed_dict) names = [meta.name for meta in session.get_outputs()] return dict(zip(names, outputs)) def run_split_pipeline( ordered_subgraphs: Sequence[SubGraphSpec], feed_dict: Dict[str, np.ndarray], providers: List[str], ) -> Dict[str, np.ndarray]: if ort is None: raise RuntimeError("需要 onnxruntime 才能执行验证。") tensor_store = dict(feed_dict) for spec in ordered_subgraphs: if spec.output_path is None: raise RuntimeError(f"子图 {spec.label} 尚未生成 ONNX 文件。") session = ort.InferenceSession(spec.output_path.as_posix(), providers=providers) fetch_inputs = {} for name in spec.start: if name not in tensor_store: raise KeyError( f"子图 {spec.label} 缺少输入张量 {name},请确认切分顺序。" ) fetch_inputs[name] = tensor_store[name] results = session.run(None, fetch_inputs) for meta, value in zip(session.get_outputs(), results): tensor_store[meta.name] = value return tensor_store def verify( model_path: Path, ordered_subgraphs: Sequence[SubGraphSpec], feed_dict: Dict[str, np.ndarray], providers: List[str], rtol: float, atol: float, ) -> None: full_outputs = run_full_model(model_path, feed_dict, providers) split_store = run_split_pipeline(ordered_subgraphs, feed_dict, providers) for name, ref in full_outputs.items(): cand = split_store.get(name) if cand is None: raise AssertionError(f"切分流水线未产生模型输出 {name}") if not np.allclose(ref, cand, rtol=rtol, atol=atol): diff = np.max(np.abs(ref - cand)) raise AssertionError( f"输出 {name} 不匹配,最大偏差 {diff:.3e}" ) logging.info("切分模型与原始模型输出一致 (rtol=%g, atol=%g)。", rtol, atol) def load_npz_inputs(npz_path: Path) -> Dict[str, np.ndarray]: data = np.load(npz_path, allow_pickle=False) return {key: data[key] for key in data.files} def main() -> None: parser = argparse.ArgumentParser(description="根据 sub_config 切分 ONNX 模型。") parser.add_argument("--model", required=True, type=Path, help="原始 ONNX 路径") parser.add_argument("--config", required=True, type=Path, help="pulsar2 配置 JSON") parser.add_argument("--output-dir", required=False, default="./split-onnx", type=Path, help="保存子模型的目录") parser.add_argument( "--verify", action="store_true", help="生成后立即用 onnxruntime 校验输出是否一致", ) parser.add_argument( "--input-npz", type=Path, help="包含模型所有输入张量的 npz 文件 (verify 模式需要)", ) parser.add_argument( "--providers", nargs="*", default=["CPUExecutionProvider"], help="onnxruntime 推理后端顺序", ) parser.add_argument("--rtol", type=float, default=1e-4, help="验证 rtol") parser.add_argument("--atol", type=float, default=1e-5, help="验证 atol") parser.add_argument("--log", default="INFO", help="日志等级") args = parser.parse_args() logging.basicConfig(level=getattr(logging, args.log.upper(), logging.INFO)) model = onnx.load(args.model.as_posix()) graph_index = build_graph_index(model) with args.config.open("r", encoding="utf-8") as f: config = json.load(f) sub_configs = config.get("compiler", {}).get("sub_configs", []) if not sub_configs: raise ValueError("配置文件中未找到 compiler.sub_configs。") specs: List[SubGraphSpec] = [] covered_nodes: Set[str] = set() for idx, entry in enumerate(sub_configs): start = [name for name in entry.get("start_tensor_names", []) if name] end = [name for name in entry.get("end_tensor_names", []) if name] if not start or not end: raise ValueError(f"sub_config[{idx}] 缺少 start/end tensor name。") spec = SubGraphSpec( label=f"cfg_{idx:02d}", start=start, end=end, node_names=set(), source="config", ) nodes = trace_nodes_between(spec, graph_index) spec.node_names = nodes covered_nodes.update(nodes) specs.append(spec) leftovers = untouched_components(graph_index.node_order, covered_nodes, graph_index) for idx, component in enumerate(leftovers): start, end = derive_interface(component, graph_index) if not end: logging.warning("自动发现的剩余子图 %d 没有输出,跳过。", idx) continue spec = SubGraphSpec( label=f"auto_{idx:02d}", start=start, end=end, node_names=component, source="auto", ) specs.append(spec) logging.info( "自动补充子图 %s: start=%s end=%s (节点数=%d)", spec.label, spec.start, spec.end, len(component), ) ordered = ordered_specs(specs, graph_index) args.output_dir.mkdir(parents=True, exist_ok=True) for spec in ordered: spec.output_path = extract_model_file(args.model, spec, args.output_dir, spec.source) if args.verify: if args.input_npz is None: raise ValueError("verify 模式需要 --input-npz 提供输入数据。") feed = load_npz_inputs(args.input_npz) missing_inputs = graph_index.graph_inputs - feed.keys() if missing_inputs: raise ValueError(f"npz 中缺少以下模型输入: {sorted(missing_inputs)}") verify(args.model, ordered, feed, args.providers, args.rtol, args.atol) if __name__ == "__main__": """ 用法示例: python python/VideoX-Fun/scripts/split_onnx_by_subconfigs.py \ --model /path/to/full.onnx \ --config python/VideoX-Fun/pulsar2_configs/transformers_subgraph.json \ --output-dir /tmp/sliced_models \ --verify \ --input-npz /path/to/inputs.npz """ main()