#!/usr/bin/env python3 """基于 tensor 名称的通用 ONNX 子图切分工具。 相较于 split_quant_onnx_by_subconfigs.py,本脚本额外提供: 1. 为每个子模型执行 onnx checker 与 shape inference(可关闭)。 2. 支持从 .npz/.npy(包含 dict 或单数组)加载验证数据。 3. 可用 onnxruntime 串联执行全部子模型,既可校验精度,也可单独输出流水线结果。 """ from __future__ import annotations import argparse import json import logging from collections import defaultdict from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterable, List, Optional, Sequence, Set import numpy as np import onnx from onnx import TensorProto, checker, helper, shape_inference, utils as onnx_utils try: # pragma: no cover - 可选依赖 import onnxruntime as ort except ImportError: # pragma: no cover ort = None @dataclass class SubGraphSpec: label: str start: List[str] end: List[str] node_names: Set[str] source: str output_path: Optional[Path] = None @dataclass class GraphIndex: tensor_to_producer: Dict[str, str] tensor_to_consumers: Dict[str, List[str]] node_inputs: Dict[str, List[str]] node_outputs: Dict[str, List[str]] graph_inputs: Set[str] graph_outputs: Set[str] initializer_names: Set[str] node_order: List[str] def sanitize(name: str) -> str: keep = [c if c.isalnum() else "_" for c in name] if name else ["anon"] sanitized = "".join(keep).strip("_") return sanitized or "tensor" def build_graph_index(model: onnx.ModelProto) -> GraphIndex: tensor_to_producer: Dict[str, str] = {} tensor_to_consumers: Dict[str, List[str]] = defaultdict(list) node_inputs: Dict[str, List[str]] = {} node_outputs: Dict[str, List[str]] = {} node_order: List[str] = [] used_names: Set[str] = set() for idx, node in enumerate(model.graph.node): base = node.name.strip() if node.name else "" candidate = base or f"node_{idx}" while candidate in used_names: candidate = f"{candidate}_{idx}" used_names.add(candidate) node_order.append(candidate) node_inputs[candidate] = [x for x in node.input if x] node_outputs[candidate] = [y for y in node.output if y] for out_name in node_outputs[candidate]: tensor_to_producer[out_name] = candidate for inp_name in node_inputs[candidate]: tensor_to_consumers[inp_name].append(candidate) graph_inputs = {vi.name for vi in model.graph.input} graph_outputs = {vi.name for vi in model.graph.output} initializer_names = {init.name for init in model.graph.initializer} return GraphIndex( tensor_to_producer=tensor_to_producer, tensor_to_consumers=tensor_to_consumers, node_inputs=node_inputs, node_outputs=node_outputs, graph_inputs=graph_inputs, graph_outputs=graph_outputs, initializer_names=initializer_names, node_order=node_order, ) def ensure_value_infos(model: onnx.ModelProto, tensor_names: Iterable[str]) -> None: existing = {vi.name for vi in model.graph.value_info} source_map = {} for vi in list(model.graph.input) + list(model.graph.value_info) + list(model.graph.output): source_map[vi.name] = vi added: List[str] = [] for name in tensor_names: if name in existing: continue src = source_map.get(name) if src is not None: vi = onnx.ValueInfoProto() vi.CopyFrom(src) else: vi = helper.make_tensor_value_info(name, TensorProto.UNDEFINED, None) model.graph.value_info.append(vi) existing.add(name) added.append(name) if added: logging.debug("已为以下 tensor 补充 ValueInfo: %s", added) def ensure_extractor_value_infos( extractor: onnx_utils.Extractor, tensor_names: Iterable[str], source_model: onnx.ModelProto, ) -> None: existing_inputs = {vi.name for vi in extractor.graph.input} existing_outputs = {vi.name for vi in extractor.graph.output} existing_vi = {vi.name for vi in extractor.graph.value_info} source_map = {} for vi in ( list(source_model.graph.input) + list(source_model.graph.value_info) + list(source_model.graph.output) ): source_map[vi.name] = vi added: List[str] = [] for name in tensor_names: if name in existing_inputs or name in existing_outputs or name in existing_vi: continue src = source_map.get(name) if src is not None: vi = onnx.ValueInfoProto() vi.CopyFrom(src) else: vi = helper.make_tensor_value_info(name, TensorProto.UNDEFINED, None) extractor.graph.value_info.append(vi) extractor.vimap[name] = vi existing_vi.add(name) added.append(name) if added: logging.debug("Extractor 侧补充 ValueInfo: %s", added) def trace_nodes_between(spec: SubGraphSpec, index: GraphIndex) -> Set[str]: boundary = set(spec.start) | index.graph_inputs | index.initializer_names visited_tensors: Set[str] = set() stack = list(spec.end) discovered_nodes: Set[str] = set() while stack: tensor = stack.pop() if tensor in visited_tensors: continue visited_tensors.add(tensor) if tensor in boundary: continue producer = index.tensor_to_producer.get(tensor) if not producer or producer in discovered_nodes: continue discovered_nodes.add(producer) for upstream in index.node_inputs.get(producer, []): if upstream and upstream not in boundary: stack.append(upstream) return discovered_nodes def untouched_components( all_nodes: Sequence[str], covered_nodes: Set[str], index: GraphIndex, ) -> List[Set[str]]: remaining = [n for n in all_nodes if n not in covered_nodes] if not remaining: return [] adjacency: Dict[str, Set[str]] = {name: set() for name in remaining} rem_set = set(remaining) for node in remaining: for out_name in index.node_outputs.get(node, []): for consumer in index.tensor_to_consumers.get(out_name, []): if consumer in rem_set: adjacency[node].add(consumer) adjacency[consumer].add(node) for inp_name in index.node_inputs.get(node, []): producer = index.tensor_to_producer.get(inp_name) if producer in rem_set: adjacency[node].add(producer) adjacency[producer].add(node) components: List[Set[str]] = [] visited: Set[str] = set() for node in remaining: if node in visited: continue stack = [node] comp: Set[str] = set() while stack: cur = stack.pop() if cur in visited: continue visited.add(cur) comp.add(cur) stack.extend(adjacency[cur] - visited) components.append(comp) return components def derive_interface(nodes: Set[str], index: GraphIndex) -> (List[str], List[str]): produced = set() for node in nodes: produced.update(index.node_outputs.get(node, [])) start: Set[str] = set() for node in nodes: for inp in index.node_inputs.get(node, []): producer = index.tensor_to_producer.get(inp) if producer is None and inp not in index.initializer_names: start.add(inp) elif producer not in nodes and inp not in index.initializer_names: start.add(inp) end: Set[str] = set() for node in nodes: for out in index.node_outputs.get(node, []): consumers = index.tensor_to_consumers.get(out, []) if not consumers: if out in index.graph_outputs: end.add(out) continue if any(consumer not in nodes for consumer in consumers): end.add(out) end.update(index.graph_outputs & produced) if not end and produced: end = produced.copy() return sorted(start), sorted(end) def ordered_specs(specs: Sequence[SubGraphSpec], index: GraphIndex) -> List[SubGraphSpec]: available = set(index.graph_inputs) | index.initializer_names pending = list(specs) ordered: List[SubGraphSpec] = [] while pending: progressed = False for spec in list(pending): if set(spec.start).issubset(available): ordered.append(spec) available.update(spec.end) pending.remove(spec) progressed = True if not progressed: missing = {spec.label: sorted(set(spec.start) - available) for spec in pending} raise RuntimeError(f"无法解析子图拓扑,缺少以下张量: {missing}") return ordered def extract_model_file( source_model: onnx.ModelProto, spec: SubGraphSpec, output_dir: Path, suffix: str, run_checker: bool, run_shape_infer: bool, ) -> Path: head = sanitize(spec.start[0]) if spec.start else "const" tail = sanitize(spec.end[0]) if spec.end else "out" filename = f"{spec.label}_{head}_to_{tail}_{suffix}.onnx" destination = output_dir / filename # 手动构建子图,不使用 Extractor # 1. 创建新的空图 sub_graph = helper.make_graph( nodes=[], name=f"{spec.label}_subgraph", inputs=[], outputs=[], initializer=[] ) # 2. 从原始模型复制需要的节点 node_map = {node.name or f"node_{i}": node for i, node in enumerate(source_model.graph.node)} for node_name in spec.node_names: if node_name in node_map: new_node = onnx.NodeProto() new_node.CopyFrom(node_map[node_name]) sub_graph.node.append(new_node) # 3. 收集所有需要的张量名称 node_inputs = set() node_outputs = set() for node in sub_graph.node: for inp in node.input: if inp: node_inputs.add(inp) for out in node.output: if out: node_outputs.add(out) # 4. 从原始模型收集 value_info source_value_info_map = {} for vi in list(source_model.graph.input) + list(source_model.graph.value_info) + list(source_model.graph.output): source_value_info_map[vi.name] = vi # 5. 从原始模型收集 initializers source_init_map = {init.name: init for init in source_model.graph.initializer} # 6. 添加输入:从 spec.start 和需要但不是节点输出的张量 input_tensor_names = set(spec.start) for tensor_name in node_inputs: if tensor_name not in node_outputs and tensor_name not in source_init_map: input_tensor_names.add(tensor_name) for tensor_name in sorted(input_tensor_names): if tensor_name in source_value_info_map: vi = onnx.ValueInfoProto() vi.CopyFrom(source_value_info_map[tensor_name]) sub_graph.input.append(vi) else: vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None) sub_graph.input.append(vi) # 7. 添加 initializers for tensor_name in node_inputs: if tensor_name in source_init_map: init = onnx.TensorProto() init.CopyFrom(source_init_map[tensor_name]) sub_graph.initializer.append(init) # 8. 添加输出:从 spec.end for tensor_name in spec.end: if tensor_name in source_value_info_map: vi = onnx.ValueInfoProto() vi.CopyFrom(source_value_info_map[tensor_name]) sub_graph.output.append(vi) else: vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None) sub_graph.output.append(vi) # 9. 检查输出是否都有对应的产生节点,如果没有则添加 Identity for out_name in spec.end: if out_name not in node_outputs: # 这个输出没有被任何节点产生 if out_name in input_tensor_names or out_name in source_init_map: # 添加 Identity 节点 identity_node = helper.make_node( 'Identity', inputs=[out_name], outputs=[out_name], name=f'passthrough_{sanitize(out_name)}' ) sub_graph.node.append(identity_node) logging.info(f"子图 {spec.label}: 为输出 {out_name} 添加 Identity 节点") else: logging.error(f"子图 {spec.label}: 输出 {out_name} 无法产生(不在输入/initializer/节点输出中)") # 10. 创建模型 sub_model = helper.make_model(sub_graph) # 11. 复制元数据 sub_model.ir_version = source_model.ir_version sub_model.producer_name = source_model.producer_name sub_model.producer_version = source_model.producer_version sub_model.domain = source_model.domain sub_model.model_version = source_model.model_version sub_model.doc_string = source_model.doc_string # 12. 复制 opset imports while len(sub_model.opset_import) > 0: sub_model.opset_import.pop() for opset in source_model.opset_import: opset_import = sub_model.opset_import.add() opset_import.CopyFrom(opset) # 13. Shape inference 和 checker if run_shape_infer: try: sub_model = shape_inference.infer_shapes(sub_model) except Exception as e: logging.warning(f"子图 {spec.label} shape inference 失败: {e}") if run_checker: try: checker.check_model(sub_model) except Exception as e: logging.warning(f"子图 {spec.label} checker 验证失败: {e}") # 14. 保存 onnx.save(sub_model, destination.as_posix()) logging.info( "保存子图 %s (start=%s, end=%s, 节点数=%d, checker=%s, infer_shape=%s)", destination.name, spec.start, spec.end, len(sub_graph.node), bool(run_checker), bool(run_shape_infer), ) return destination def load_numpy_inputs(path: Path, expected_inputs: Iterable[str]) -> Dict[str, np.ndarray]: suffix = path.suffix.lower() expected = list(expected_inputs) if suffix == ".npz": data = np.load(path, allow_pickle=False) return {key: data[key] for key in data.files} if suffix == ".npy": arr = np.load(path, allow_pickle=True) if isinstance(arr, np.ndarray) and arr.shape == () and isinstance(arr.item(), dict): return {str(k): np.asarray(v) for k, v in arr.item().items()} if isinstance(arr, np.ndarray) and arr.dtype.names: return {name: arr[name] for name in arr.dtype.names} if len(expected) == 1: return {expected[0]: np.asarray(arr)} raise ValueError("多输入模型需要字典格式的 .npy/.npz 数据。") raise ValueError("仅支持 .npz 或 .npy 输入数据。") def run_full_model(model_path: Path, feed_dict: Dict[str, np.ndarray], providers: List[str]): if ort is None: raise RuntimeError("需要 onnxruntime 才能执行验证。") session = ort.InferenceSession(model_path.as_posix(), providers=providers) outputs = session.run(None, feed_dict) names = [meta.name for meta in session.get_outputs()] return dict(zip(names, outputs)) def run_split_pipeline( ordered_subgraphs: Sequence[SubGraphSpec], feed_dict: Dict[str, np.ndarray], providers: List[str], ) -> Dict[str, np.ndarray]: if ort is None: raise RuntimeError("需要 onnxruntime 才能执行验证。") tensor_store = dict(feed_dict) for spec in ordered_subgraphs: if spec.output_path is None: raise RuntimeError(f"子图 {spec.label} 尚未生成 ONNX 文件。") session = ort.InferenceSession(spec.output_path.as_posix(), providers=providers) fetch_inputs = {} for name in spec.start: if name not in tensor_store: raise KeyError(f"子图 {spec.label} 缺少输入张量 {name}。") fetch_inputs[name] = tensor_store[name] results = session.run(None, fetch_inputs) for meta, value in zip(session.get_outputs(), results): tensor_store[meta.name] = value return tensor_store def verify( model_path: Path, ordered_subgraphs: Sequence[SubGraphSpec], feed_dict: Dict[str, np.ndarray], providers: List[str], rtol: float, atol: float, ) -> None: full_outputs = run_full_model(model_path, feed_dict, providers) split_store = run_split_pipeline(ordered_subgraphs, feed_dict, providers) for name, ref in full_outputs.items(): cand = split_store.get(name) if cand is None: raise AssertionError(f"切分流水线未产生模型输出 {name}") if not np.allclose(ref, cand, rtol=rtol, atol=atol): diff = float(np.max(np.abs(ref - cand))) raise AssertionError(f"输出 {name} 不匹配,最大偏差 {diff:.3e}") logging.info("切分模型与原始模型输出一致 (rtol=%g, atol=%g)。", rtol, atol) def save_outputs(outputs: Dict[str, np.ndarray], destination: Optional[Path]) -> None: if destination is None: return destination.parent.mkdir(parents=True, exist_ok=True) np.savez(destination, **outputs) logging.info("流水线输出已保存至 %s", destination) def load_sub_configs(config_path: Path) -> List[dict]: with config_path.open("r", encoding="utf-8") as f: config = json.load(f) sub_configs = config.get("compiler", {}).get("sub_configs") if not sub_configs: sub_configs = config.get("sub_configs") if not sub_configs: raise ValueError("配置文件中未找到 sub_configs。") return sub_configs def main() -> None: parser = argparse.ArgumentParser(description="根据 tensor 名切分 ONNX 子图") parser.add_argument("--model", required=True, type=Path, help="原始 ONNX 文件") parser.add_argument("--config", required=True, type=Path, help="包含 sub_configs 的 JSON") parser.add_argument("--output-dir", default="./split-onnx", type=Path, help="子模型输出目录") parser.add_argument("--providers", nargs="*", default=["CPUExecutionProvider"], help="onnxruntime providers 顺序") parser.add_argument("--verify", action="store_true", help="比较原始模型与流水线输出") parser.add_argument("--run-pipeline", action="store_true", help="只运行切分流水线并输出结果") parser.add_argument("--input-data", type=Path, help=".npz/.npy 格式的输入数据") parser.add_argument("--pipeline-output", type=Path, help="保存流水线输出为 npz") parser.add_argument("--rtol", type=float, default=1e-4, help="验证 rtol") parser.add_argument("--atol", type=float, default=1e-5, help="验证 atol") parser.add_argument("--skip-checker", action="store_true", help="跳过 onnx checker") parser.add_argument("--skip-shape-infer", action="store_true", help="跳过 shape inference") parser.add_argument("--log", default="INFO", help="日志等级") args = parser.parse_args() logging.basicConfig(level=getattr(logging, args.log.upper(), logging.INFO)) # 必须加载完整数据(包括外部数据文件) # 虽然模型很大,但这是必须的,因为我们需要将权重完全内嵌到子模型中 # 否则子模型会引用原始的外部数据路径,导致在新目录下无法找到数据文件 logging.info("加载完整模型(包括外部数据文件)...这可能需要一些时间和内存") model = onnx.load(args.model.as_posix(), load_external_data=True) # 跳过对原始巨大模型的 checker,只对切分后的小模型进行 checker # if not args.skip_checker: # checker.check_model(args.model.as_posix()) logging.info("跳过对原始模型的 checker(模型过大),将只对切分后的子模型进行验证") graph_index = build_graph_index(model) sub_configs = load_sub_configs(args.config) specs: List[SubGraphSpec] = [] covered_nodes: Set[str] = set() for idx, entry in enumerate(sub_configs): start = [name for name in entry.get("start_tensor_names", []) if name] end = [name for name in entry.get("end_tensor_names", []) if name] if not start or not end: raise ValueError(f"sub_config[{idx}] 缺少 start/end tensor name。") spec = SubGraphSpec( label=f"cfg_{idx:02d}", start=start, end=end, node_names=set(), source="config", ) nodes = trace_nodes_between(spec, graph_index) spec.node_names = nodes covered_nodes.update(nodes) specs.append(spec) leftovers = untouched_components(graph_index.node_order, covered_nodes, graph_index) for idx, component in enumerate(leftovers): start, end = derive_interface(component, graph_index) if not end: continue spec = SubGraphSpec( label=f"auto_{idx:02d}", start=start, end=end, node_names=component, source="auto", ) specs.append(spec) logging.info( "自动补充子图 %s: start=%s end=%s (节点数=%d)", spec.label, spec.start, spec.end, len(component), ) ordered = ordered_specs(specs, graph_index) required_tensors: Set[str] = set() for spec in ordered: required_tensors.update(spec.start) required_tensors.update(spec.end) ensure_value_infos(model, required_tensors) args.output_dir.mkdir(parents=True, exist_ok=True) for spec in ordered: spec.output_path = extract_model_file( model, spec, args.output_dir, spec.source, run_checker=not args.skip_checker, run_shape_infer=not args.skip_shape_infer, ) need_inputs = args.verify or args.run_pipeline if need_inputs: if args.input_data is None: raise ValueError("verify/run-pipeline 模式需要 --input-data 提供输入。") feed = load_numpy_inputs(args.input_data, graph_index.graph_inputs) missing_inputs = graph_index.graph_inputs - feed.keys() if missing_inputs: raise ValueError(f"输入数据缺少以下张量: {sorted(missing_inputs)}") else: feed = {} if args.verify: verify(args.model, ordered, feed, args.providers, args.rtol, args.atol) elif args.run_pipeline: outputs = run_split_pipeline(ordered, feed, args.providers) save_outputs(outputs, args.pipeline_output) else: logging.info("子模型已生成,如需验证请添加 --verify 或 --run-pipeline。") if __name__ == "__main__": """ python ./scripts/split_onnx_by_subconfig.py \ --model ./onnx-models/z_image_transformer_body_only_simp_slim.onnx \ --config ./pulsar2_configs/transformers_subgraph.json \ --output-dir ./transformers_body_only_split_onnx \ --verify \ --input-data ./onnx-calibration-no-controlnet/transformer_inputs_prompt000_step00.npy \ --providers CPUExecutionProvider """ main()