#!/usr/bin/env python3 """基于 tensor 名称的通用 ONNX 子图切分工具。 相较于 split_quant_onnx_by_subconfigs.py,本脚本额外提供: 1. 为每个子模型执行 onnx checker 与 shape inference(可关闭)。 2. 支持从 .npz/.npy(包含 dict 或单数组)加载验证数据。 3. 可用 onnxruntime 串联执行全部子模型,既可校验精度,也可单独输出流水线结果。 """ from __future__ import annotations import argparse import json import logging from collections import defaultdict from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterable, List, Optional, Sequence, Set import numpy as np import onnx from onnx import TensorProto, checker, helper, shape_inference, utils as onnx_utils try: # pragma: no cover - 可选依赖 import onnxruntime as ort except ImportError: # pragma: no cover ort = None @dataclass class SubGraphSpec: label: str start: List[str] end: List[str] node_names: Set[str] source: str output_path: Optional[Path] = None @dataclass class GraphIndex: tensor_to_producer: Dict[str, str] tensor_to_consumers: Dict[str, List[str]] node_inputs: Dict[str, List[str]] node_outputs: Dict[str, List[str]] graph_inputs: Set[str] graph_outputs: Set[str] initializer_names: Set[str] node_order: List[str] def sanitize(name: str) -> str: keep = [c if c.isalnum() else "_" for c in name] if name else ["anon"] sanitized = "".join(keep).strip("_") return sanitized or "tensor" def build_graph_index(model: onnx.ModelProto) -> GraphIndex: tensor_to_producer: Dict[str, str] = {} tensor_to_consumers: Dict[str, List[str]] = defaultdict(list) node_inputs: Dict[str, List[str]] = {} node_outputs: Dict[str, List[str]] = {} node_order: List[str] = [] used_names: Set[str] = set() for idx, node in enumerate(model.graph.node): base = node.name.strip() if node.name else "" candidate = base or f"node_{idx}" while candidate in used_names: candidate = f"{candidate}_{idx}" used_names.add(candidate) node_order.append(candidate) node_inputs[candidate] = [x for x in node.input if x] node_outputs[candidate] = [y for y in node.output if y] for out_name in node_outputs[candidate]: tensor_to_producer[out_name] = candidate for inp_name in node_inputs[candidate]: tensor_to_consumers[inp_name].append(candidate) graph_inputs = {vi.name for vi in model.graph.input} graph_outputs = {vi.name for vi in model.graph.output} initializer_names = {init.name for init in model.graph.initializer} return GraphIndex( tensor_to_producer=tensor_to_producer, tensor_to_consumers=tensor_to_consumers, node_inputs=node_inputs, node_outputs=node_outputs, graph_inputs=graph_inputs, graph_outputs=graph_outputs, initializer_names=initializer_names, node_order=node_order, ) def ensure_value_infos(model: onnx.ModelProto, tensor_names: Iterable[str]) -> None: existing = {vi.name for vi in model.graph.value_info} source_map = {} for vi in list(model.graph.input) + list(model.graph.value_info) + list(model.graph.output): source_map[vi.name] = vi added: List[str] = [] for name in tensor_names: if name in existing: continue src = source_map.get(name) if src is not None: vi = onnx.ValueInfoProto() vi.CopyFrom(src) else: vi = helper.make_tensor_value_info(name, TensorProto.UNDEFINED, None) model.graph.value_info.append(vi) existing.add(name) added.append(name) if added: logging.debug("已为以下 tensor 补充 ValueInfo: %s", added) def ensure_extractor_value_infos( extractor: onnx_utils.Extractor, tensor_names: Iterable[str], source_model: onnx.ModelProto, ) -> None: existing_inputs = {vi.name for vi in extractor.graph.input} existing_outputs = {vi.name for vi in extractor.graph.output} existing_vi = {vi.name for vi in extractor.graph.value_info} source_map = {} for vi in ( list(source_model.graph.input) + list(source_model.graph.value_info) + list(source_model.graph.output) ): source_map[vi.name] = vi added: List[str] = [] for name in tensor_names: if name in existing_inputs or name in existing_outputs or name in existing_vi: continue src = source_map.get(name) if src is not None: vi = onnx.ValueInfoProto() vi.CopyFrom(src) else: vi = helper.make_tensor_value_info(name, TensorProto.UNDEFINED, None) extractor.graph.value_info.append(vi) extractor.vimap[name] = vi existing_vi.add(name) added.append(name) if added: logging.debug("Extractor 侧补充 ValueInfo: %s", added) def trace_nodes_between(spec: SubGraphSpec, index: GraphIndex) -> Set[str]: boundary = set(spec.start) | index.graph_inputs | index.initializer_names visited_tensors: Set[str] = set() stack = list(spec.end) discovered_nodes: Set[str] = set() while stack: tensor = stack.pop() if tensor in visited_tensors: continue visited_tensors.add(tensor) if tensor in boundary: continue producer = index.tensor_to_producer.get(tensor) if not producer or producer in discovered_nodes: continue discovered_nodes.add(producer) for upstream in index.node_inputs.get(producer, []): if upstream and upstream not in boundary: stack.append(upstream) return discovered_nodes def untouched_components( all_nodes: Sequence[str], covered_nodes: Set[str], index: GraphIndex, ) -> List[Set[str]]: remaining = [n for n in all_nodes if n not in covered_nodes] if not remaining: return [] adjacency: Dict[str, Set[str]] = {name: set() for name in remaining} rem_set = set(remaining) for node in remaining: for out_name in index.node_outputs.get(node, []): for consumer in index.tensor_to_consumers.get(out_name, []): if consumer in rem_set: adjacency[node].add(consumer) adjacency[consumer].add(node) for inp_name in index.node_inputs.get(node, []): producer = index.tensor_to_producer.get(inp_name) if producer in rem_set: adjacency[node].add(producer) adjacency[producer].add(node) components: List[Set[str]] = [] visited: Set[str] = set() for node in remaining: if node in visited: continue stack = [node] comp: Set[str] = set() while stack: cur = stack.pop() if cur in visited: continue visited.add(cur) comp.add(cur) stack.extend(adjacency[cur] - visited) components.append(comp) return components def derive_interface(nodes: Set[str], index: GraphIndex) -> (List[str], List[str]): produced = set() for node in nodes: produced.update(index.node_outputs.get(node, [])) start: Set[str] = set() for node in nodes: for inp in index.node_inputs.get(node, []): producer = index.tensor_to_producer.get(inp) if producer is None and inp not in index.initializer_names: start.add(inp) elif producer not in nodes and inp not in index.initializer_names: start.add(inp) end: Set[str] = set() for node in nodes: for out in index.node_outputs.get(node, []): consumers = index.tensor_to_consumers.get(out, []) if not consumers: if out in index.graph_outputs: end.add(out) continue if any(consumer not in nodes for consumer in consumers): end.add(out) end.update(index.graph_outputs & produced) if not end and produced: end = produced.copy() return sorted(start), sorted(end) def ordered_specs(specs: Sequence[SubGraphSpec], index: GraphIndex) -> List[SubGraphSpec]: available = set(index.graph_inputs) | index.initializer_names pending = list(specs) ordered: List[SubGraphSpec] = [] while pending: progressed = False for spec in list(pending): if set(spec.start).issubset(available): ordered.append(spec) available.update(spec.end) pending.remove(spec) progressed = True if not progressed: missing = {spec.label: sorted(set(spec.start) - available) for spec in pending} raise RuntimeError(f"无法解析子图拓扑,缺少以下张量: {missing}") return ordered def is_valid_onnx_model(model_path: Path) -> bool: """检查 ONNX 模型文件是否有效(包含必需的 opset_import)""" try: # 检查文件大小 if model_path.stat().st_size == 0: logging.warning(f"模型 {model_path.name} 是空文件") return False model = onnx.load(model_path.as_posix(), load_external_data=False) # 检查模型是否为 None if model is None: logging.warning(f"模型 {model_path.name} 加载后为 None") return False # 检查是否有 graph if not hasattr(model, 'graph') or model.graph is None: logging.warning(f"模型 {model_path.name} 缺少 graph") return False # 检查是否有 opset_import if len(model.opset_import) == 0: logging.warning(f"模型 {model_path.name} 缺少 opset_import 信息") return False return True except Exception as e: logging.warning(f"无法加载模型 {model_path.name}: {e}") return False def extract_model_file( source_model: onnx.ModelProto, spec: SubGraphSpec, output_dir: Path, suffix: str, run_checker: bool, run_shape_infer: bool, skip_existing: bool = True, ) -> Path: head = sanitize(spec.start[0]) if spec.start else "const" tail = sanitize(spec.end[0]) if spec.end else "out" filename = f"{spec.label}_{head}_to_{tail}_{suffix}.onnx" destination = output_dir / filename # 如果文件已存在且 skip_existing=True,检查文件是否有效 if skip_existing and destination.exists(): if is_valid_onnx_model(destination): logging.info("跳过已存在的子图 %s (文件: %s)", spec.label, destination.name) return destination else: logging.warning("子图 %s 的文件无效,将重新生成", spec.label) # 手动构建子图,不使用 Extractor # 1. 创建新的空图 sub_graph = helper.make_graph( nodes=[], name=f"{spec.label}_subgraph", inputs=[], outputs=[], initializer=[] ) # 2. 从原始模型复制需要的节点 # 注意:GraphIndex 为重复/空名字节点生成了唯一的候选名,这里必须使用同样的规则 node_map = {} used_names: Set[str] = set() for idx, node in enumerate(source_model.graph.node): base = node.name.strip() if node.name else "" candidate = base or f"node_{idx}" while candidate in used_names: candidate = f"{candidate}_{idx}" used_names.add(candidate) node_map[candidate] = node missing_nodes: List[str] = [] for node_name in spec.node_names: target = node_map.get(node_name) if target is None: missing_nodes.append(node_name) continue new_node = onnx.NodeProto() new_node.CopyFrom(target) sub_graph.node.append(new_node) if missing_nodes: logging.warning("子图 %s: 有 %d 个节点未匹配到源模型,将被跳过: %s", spec.label, len(missing_nodes), missing_nodes[:5]) # 3. 收集所有需要的张量名称 node_inputs = set() node_outputs = set() for node in sub_graph.node: for inp in node.input: if inp: node_inputs.add(inp) for out in node.output: if out: node_outputs.add(out) # 4. 从原始模型收集 value_info source_value_info_map = {} for vi in list(source_model.graph.input) + list(source_model.graph.value_info) + list(source_model.graph.output): source_value_info_map[vi.name] = vi # 5. 从原始模型收集 initializers source_init_map = {init.name: init for init in source_model.graph.initializer} # 6. 添加输入:从 spec.start 和需要但不是节点输出的张量 input_tensor_names = set(spec.start) for tensor_name in node_inputs: if tensor_name not in node_outputs and tensor_name not in source_init_map: input_tensor_names.add(tensor_name) for tensor_name in sorted(input_tensor_names): if tensor_name in source_value_info_map: vi = onnx.ValueInfoProto() vi.CopyFrom(source_value_info_map[tensor_name]) sub_graph.input.append(vi) else: vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None) sub_graph.input.append(vi) # 7. 添加 initializers for tensor_name in node_inputs: if tensor_name in source_init_map: init = onnx.TensorProto() init.CopyFrom(source_init_map[tensor_name]) sub_graph.initializer.append(init) # 8. 添加输出:从 spec.end for tensor_name in spec.end: if tensor_name in source_value_info_map: vi = onnx.ValueInfoProto() vi.CopyFrom(source_value_info_map[tensor_name]) sub_graph.output.append(vi) else: vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None) sub_graph.output.append(vi) # 9. 检查输出是否都有对应的产生节点,如果没有则添加 Identity for out_name in spec.end: if out_name not in node_outputs: # 这个输出没有被任何节点产生 if out_name in input_tensor_names or out_name in source_init_map: # 添加 Identity 节点 identity_node = helper.make_node( 'Identity', inputs=[out_name], outputs=[out_name], name=f'passthrough_{sanitize(out_name)}' ) sub_graph.node.append(identity_node) logging.info(f"子图 {spec.label}: 为输出 {out_name} 添加 Identity 节点") else: logging.error(f"子图 {spec.label}: 输出 {out_name} 无法产生(不在输入/initializer/节点输出中)") # 10. 创建模型 sub_model = helper.make_model(sub_graph) # 11. 复制元数据 sub_model.ir_version = source_model.ir_version sub_model.producer_name = source_model.producer_name sub_model.producer_version = source_model.producer_version sub_model.domain = source_model.domain sub_model.model_version = source_model.model_version sub_model.doc_string = source_model.doc_string # 12. 复制 opset imports while len(sub_model.opset_import) > 0: sub_model.opset_import.pop() if len(source_model.opset_import) > 0: for opset in source_model.opset_import: opset_import = sub_model.opset_import.add() opset_import.CopyFrom(opset) else: # 如果源模型没有 opset_import,添加默认的 opset logging.warning(f"源模型缺少 opset_import,为子图 {spec.label} 添加默认 opset 17") opset_import = sub_model.opset_import.add() opset_import.domain = "" opset_import.version = 17 # 使用 ONNX opset 17 作为默认值 # 验证 opset_import 是否正确设置 if len(sub_model.opset_import) == 0: raise RuntimeError(f"子图 {spec.label} 缺少 opset_import 信息") # 13. Shape inference 和 checker if run_shape_infer: try: sub_model = shape_inference.infer_shapes(sub_model) except Exception as e: logging.warning(f"子图 {spec.label} shape inference 失败: {e}") if run_checker: try: checker.check_model(sub_model) except Exception as e: logging.warning(f"子图 {spec.label} checker 验证失败: {e}") # 14. 保存 onnx.save(sub_model, destination.as_posix()) logging.info( "保存子图 %s (start=%s, end=%s, 节点数=%d, checker=%s, infer_shape=%s)", destination.name, spec.start, spec.end, len(sub_graph.node), bool(run_checker), bool(run_shape_infer), ) return destination def load_numpy_inputs(path: Path, expected_inputs: Iterable[str]) -> Dict[str, np.ndarray]: suffix = path.suffix.lower() expected = list(expected_inputs) if suffix == ".npz": data = np.load(path, allow_pickle=False) return {key: data[key] for key in data.files} if suffix == ".npy": arr = np.load(path, allow_pickle=True) if isinstance(arr, np.ndarray) and arr.shape == () and isinstance(arr.item(), dict): return {str(k): np.asarray(v) for k, v in arr.item().items()} if isinstance(arr, np.ndarray) and arr.dtype.names: return {name: arr[name] for name in arr.dtype.names} if len(expected) == 1: return {expected[0]: np.asarray(arr)} raise ValueError("多输入模型需要字典格式的 .npy/.npz 数据。") raise ValueError("仅支持 .npz 或 .npy 输入数据。") def run_full_model(model_path: Path, feed_dict: Dict[str, np.ndarray], providers: List[str]): if ort is None: raise RuntimeError("需要 onnxruntime 才能执行验证。") session = ort.InferenceSession(model_path.as_posix(), providers=providers) outputs = session.run(None, feed_dict) names = [meta.name for meta in session.get_outputs()] return dict(zip(names, outputs)) def run_split_pipeline( ordered_subgraphs: Sequence[SubGraphSpec], feed_dict: Dict[str, np.ndarray], providers: List[str], ) -> Dict[str, np.ndarray]: if ort is None: raise RuntimeError("需要 onnxruntime 才能执行验证。") tensor_store = dict(feed_dict) for spec in ordered_subgraphs: if spec.output_path is None: raise RuntimeError(f"子图 {spec.label} 尚未生成 ONNX 文件。") if not spec.output_path.exists(): raise RuntimeError(f"子图 {spec.label} 的输出文件不存在: {spec.output_path}") logging.info("运行子图 %s (输入: %s, 输出: %s)", spec.label, spec.start, spec.end) session = ort.InferenceSession(spec.output_path.as_posix(), providers=providers) # 获取实际的输入名称 actual_inputs = [inp.name for inp in session.get_inputs()] logging.debug("子图 %s 实际输入: %s", spec.label, actual_inputs) fetch_inputs = {} for name in actual_inputs: if name not in tensor_store: # 尝试从 spec.start 中查找 if name in spec.start: logging.warning(f"子图 {spec.label} 缺少输入张量 {name},尝试从 feed_dict 查找") if name in feed_dict: tensor_store[name] = feed_dict[name] else: available = list(tensor_store.keys()) raise KeyError(f"子图 {spec.label} 缺少输入张量 {name}。当前可用: {available}") else: available = list(tensor_store.keys()) raise KeyError(f"子图 {spec.label} 缺少输入张量 {name}。当前可用: {available}") fetch_inputs[name] = tensor_store[name] results = session.run(None, fetch_inputs) for meta, value in zip(session.get_outputs(), results): tensor_store[meta.name] = value logging.debug("子图 %s 产生输出: %s (shape=%s)", spec.label, meta.name, value.shape) return tensor_store def verify( model_path: Path, ordered_subgraphs: Sequence[SubGraphSpec], feed_dict: Dict[str, np.ndarray], providers: List[str], rtol: float, atol: float, ) -> None: """验证切分后的子图流水线与原始模型输出是否一致。 验证流程: 1. 运行完整的原始模型,获得所有输出 2. 按拓扑顺序依次运行所有子图(前一个子图的输出作为后一个子图的输入) 3. 比较原始模型的最终输出与子图流水线的最终输出 4. 如果所有输出在指定的误差范围内一致,则验证通过 """ logging.info("开始验证:运行原始模型...") full_outputs = run_full_model(model_path, feed_dict, providers) logging.info("原始模型运行完成,产生 %d 个输出", len(full_outputs)) logging.info("开始验证:运行子图流水线...") split_store = run_split_pipeline(ordered_subgraphs, feed_dict, providers) logging.info("子图流水线运行完成") logging.info("比较输出...") for name, ref in full_outputs.items(): cand = split_store.get(name) if cand is None: raise AssertionError(f"切分流水线未产生模型输出 {name}") if not np.allclose(ref, cand, rtol=rtol, atol=atol): diff = float(np.max(np.abs(ref - cand))) raise AssertionError(f"输出 {name} 不匹配,最大偏差 {diff:.3e}") logging.info("✓ 输出 %s 验证通过 (shape=%s)", name, ref.shape) logging.info("✓ 切分模型与原始模型输出一致 (rtol=%g, atol=%g)。", rtol, atol) def save_outputs(outputs: Dict[str, np.ndarray], destination: Optional[Path]) -> None: if destination is None: return destination.parent.mkdir(parents=True, exist_ok=True) np.savez(destination, **outputs) logging.info("流水线输出已保存至 %s", destination) def load_sub_configs(config_path: Path) -> List[dict]: with config_path.open("r", encoding="utf-8") as f: config = json.load(f) sub_configs = config.get("compiler", {}).get("sub_configs") if not sub_configs: sub_configs = config.get("sub_configs") if not sub_configs: raise ValueError("配置文件中未找到 sub_configs。") return sub_configs def main() -> None: parser = argparse.ArgumentParser(description="根据 tensor 名切分 ONNX 子图") parser.add_argument("--model", required=True, type=Path, help="原始 ONNX 文件") parser.add_argument("--config", required=True, type=Path, help="包含 sub_configs 的 JSON") parser.add_argument("--output-dir", default="./split-onnx", type=Path, help="子模型输出目录") parser.add_argument("--providers", nargs="*", default=["CPUExecutionProvider"], help="onnxruntime providers 顺序") parser.add_argument("--verify", action="store_true", help="比较原始模型与流水线输出") parser.add_argument("--run-pipeline", action="store_true", help="只运行切分流水线并输出结果") parser.add_argument("--input-data", type=Path, help=".npz/.npy 格式的输入数据") parser.add_argument("--pipeline-output", type=Path, help="保存流水线输出为 npz") parser.add_argument("--rtol", type=float, default=1e-4, help="验证 rtol") parser.add_argument("--atol", type=float, default=1e-5, help="验证 atol") parser.add_argument("--skip-checker", action="store_true", help="跳过 onnx checker") parser.add_argument("--skip-shape-infer", action="store_true", help="跳过 shape inference") parser.add_argument("--skip-existing", action="store_true", help="跳过已存在的子图文件") parser.add_argument("--log", default="INFO", help="日志等级") args = parser.parse_args() logging.basicConfig(level=getattr(logging, args.log.upper(), logging.INFO)) # 必须加载完整数据(包括外部数据文件) # 虽然模型很大,但这是必须的,因为我们需要将权重完全内嵌到子模型中 # 否则子模型会引用原始的外部数据路径,导致在新目录下无法找到数据文件 logging.info("加载完整模型(包括外部数据文件)...这可能需要一些时间和内存") model = onnx.load(args.model.as_posix(), load_external_data=True) # 跳过对原始巨大模型的 checker,只对切分后的小模型进行 checker # if not args.skip_checker: # checker.check_model(args.model.as_posix()) logging.info("跳过对原始模型的 checker(模型过大),将只对切分后的子模型进行验证") graph_index = build_graph_index(model) sub_configs = load_sub_configs(args.config) specs: List[SubGraphSpec] = [] covered_nodes: Set[str] = set() for idx, entry in enumerate(sub_configs): start = [name for name in entry.get("start_tensor_names", []) if name] end = [name for name in entry.get("end_tensor_names", []) if name] if not start or not end: raise ValueError(f"sub_config[{idx}] 缺少 start/end tensor name。") spec = SubGraphSpec( label=f"cfg_{idx:02d}", start=start, end=end, node_names=set(), source="config", ) nodes = trace_nodes_between(spec, graph_index) spec.node_names = nodes covered_nodes.update(nodes) specs.append(spec) leftovers = untouched_components(graph_index.node_order, covered_nodes, graph_index) for idx, component in enumerate(leftovers): start, end = derive_interface(component, graph_index) if not end: continue spec = SubGraphSpec( label=f"auto_{idx:02d}", start=start, end=end, node_names=component, source="auto", ) specs.append(spec) logging.info( "自动补充子图 %s: start=%s end=%s (节点数=%d)", spec.label, spec.start, spec.end, len(component), ) ordered = ordered_specs(specs, graph_index) required_tensors: Set[str] = set() for spec in ordered: required_tensors.update(spec.start) required_tensors.update(spec.end) ensure_value_infos(model, required_tensors) args.output_dir.mkdir(parents=True, exist_ok=True) # 首先检查所有现有的子图文件,收集需要重新生成的 corrupted_specs = [] for spec in ordered: head = sanitize(spec.start[0]) if spec.start else "const" tail = sanitize(spec.end[0]) if spec.end else "out" filename = f"{spec.label}_{head}_to_{tail}_{spec.source}.onnx" potential_path = args.output_dir / filename if potential_path.exists() and not is_valid_onnx_model(potential_path): corrupted_specs.append(spec.label) logging.warning(f"检测到损坏的子图: {spec.label}, 将重新生成") if corrupted_specs: logging.info(f"发现 {len(corrupted_specs)} 个损坏的子图,将重新生成: {corrupted_specs}") # 生成或更新子图 for spec in ordered: spec.output_path = extract_model_file( model, spec, args.output_dir, spec.source, run_checker=not args.skip_checker, run_shape_infer=not args.skip_shape_infer, skip_existing=args.skip_existing, ) need_inputs = args.verify or args.run_pipeline if need_inputs: if args.input_data is None: raise ValueError("verify/run-pipeline 模式需要 --input-data 提供输入。") feed = load_numpy_inputs(args.input_data, graph_index.graph_inputs) missing_inputs = graph_index.graph_inputs - feed.keys() if missing_inputs: raise ValueError(f"输入数据缺少以下张量: {sorted(missing_inputs)}") else: feed = {} if args.verify: verify(args.model, ordered, feed, args.providers, args.rtol, args.atol) elif args.run_pipeline: outputs = run_split_pipeline(ordered, feed, args.providers) save_outputs(outputs, args.pipeline_output) else: logging.info("子模型已生成,如需验证请添加 --verify 或 --run-pipeline。") if __name__ == "__main__": """ python ./scripts/split_onnx_by_subconfig.py \ --model ./onnx-models/z_image_transformer_body_only_simp_slim.onnx \ --config ./pulsar2_configs/transformers_subgraph.json \ --output-dir ./transformers_body_only_split_onnx \ --verify \ --input-data ./onnx-calibration-no-controlnet/transformer_inputs_prompt000_step00.npy \ --providers CPUExecutionProvider """ main()