# Z-Image-Turbo / VideoX-Fun / scripts / split_onnx_by_subconfig.py.old
# Author: yongqiang — "initialize this repo" (commit ba96580)
#!/usr/bin/env python3
"""基于 tensor 名称的通用 ONNX 子图切分工具。
相较于 split_quant_onnx_by_subconfigs.py,本脚本额外提供:
1. 为每个子模型执行 onnx checker 与 shape inference(可关闭)。
2. 支持从 .npz/.npy(包含 dict 或单数组)加载验证数据。
3. 可用 onnxruntime 串联执行全部子模型,既可校验精度,也可单独输出流水线结果。
"""
from __future__ import annotations
import argparse
import json
import logging
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Set
import numpy as np
import onnx
from onnx import TensorProto, checker, helper, shape_inference, utils as onnx_utils
try: # pragma: no cover - 可选依赖
import onnxruntime as ort
except ImportError: # pragma: no cover
ort = None
@dataclass
class SubGraphSpec:
    """One sub-model to carve out of the full ONNX graph.

    Built either from an explicit sub_config JSON entry ("config") or
    auto-generated to cover leftover nodes ("auto").
    """
    label: str  # unique tag used in filenames and log messages (cfg_XX / auto_XX)
    start: List[str]  # boundary tensors fed into this sub-graph
    end: List[str]  # tensors this sub-graph must expose as outputs
    node_names: Set[str]  # names of the included nodes (filled by tracing)
    source: str  # "config" (from JSON) or "auto" (leftover coverage)
    output_path: Optional[Path] = None  # set once the sub-model file is written
@dataclass
class GraphIndex:
    """Precomputed lookup tables over one ONNX graph (see build_graph_index)."""
    tensor_to_producer: Dict[str, str]  # tensor name -> name of the node that produces it
    tensor_to_consumers: Dict[str, List[str]]  # tensor name -> names of consuming nodes
    node_inputs: Dict[str, List[str]]  # node name -> its non-empty input tensor names
    node_outputs: Dict[str, List[str]]  # node name -> its non-empty output tensor names
    graph_inputs: Set[str]  # declared graph input names
    graph_outputs: Set[str]  # declared graph output names
    initializer_names: Set[str]  # names of weight/constant initializers
    node_order: List[str]  # de-duplicated node names in original graph order
def sanitize(name: str) -> str:
    """Return *name* with every non-alphanumeric character replaced by "_".

    Empty input maps to "anon" before cleaning; if stripping leading and
    trailing underscores leaves nothing, the fallback "tensor" is returned,
    so the result is always usable as a filename fragment.
    """
    source = name if name else "anon"
    cleaned = "".join(ch if ch.isalnum() else "_" for ch in source)
    cleaned = cleaned.strip("_")
    return cleaned if cleaned else "tensor"
def build_graph_index(model: onnx.ModelProto) -> GraphIndex:
    """Index the graph: producer/consumer maps plus boundary name sets.

    Node names are made unique here: anonymous nodes become ``node_{idx}``
    and clashing names get ``_{idx}`` suffixes appended until unused.  Any
    code that later maps these names back to NodeProtos must replicate this
    scheme exactly.
    """
    tensor_to_producer: Dict[str, str] = {}
    tensor_to_consumers: Dict[str, List[str]] = defaultdict(list)
    node_inputs: Dict[str, List[str]] = {}
    node_outputs: Dict[str, List[str]] = {}
    node_order: List[str] = []
    used_names: Set[str] = set()
    for idx, node in enumerate(model.graph.node):
        base = node.name.strip() if node.name else ""
        candidate = base or f"node_{idx}"
        # De-duplicate by appending the node index until the name is unused.
        while candidate in used_names:
            candidate = f"{candidate}_{idx}"
        used_names.add(candidate)
        node_order.append(candidate)
        # Drop empty-string slots (ONNX uses "" for optional inputs/outputs).
        node_inputs[candidate] = [x for x in node.input if x]
        node_outputs[candidate] = [y for y in node.output if y]
        for out_name in node_outputs[candidate]:
            tensor_to_producer[out_name] = candidate
        for inp_name in node_inputs[candidate]:
            tensor_to_consumers[inp_name].append(candidate)
    graph_inputs = {vi.name for vi in model.graph.input}
    graph_outputs = {vi.name for vi in model.graph.output}
    initializer_names = {init.name for init in model.graph.initializer}
    return GraphIndex(
        tensor_to_producer=tensor_to_producer,
        tensor_to_consumers=tensor_to_consumers,
        node_inputs=node_inputs,
        node_outputs=node_outputs,
        graph_inputs=graph_inputs,
        graph_outputs=graph_outputs,
        initializer_names=initializer_names,
        node_order=node_order,
    )
def ensure_value_infos(model: onnx.ModelProto, tensor_names: Iterable[str]) -> None:
    """Make sure every tensor in *tensor_names* has a ValueInfo entry.

    Tensors already listed in ``graph.value_info`` are left alone.  For the
    rest, a copy of the matching input/value_info/output entry is appended
    when one exists; otherwise a placeholder with UNDEFINED element type is
    created.  Mutates *model* in place.
    """
    known_value_infos = {vi.name for vi in model.graph.value_info}
    lookup = {}
    for proto in list(model.graph.input) + list(model.graph.value_info) + list(model.graph.output):
        lookup[proto.name] = proto
    newly_added: List[str] = []
    for tensor_name in tensor_names:
        if tensor_name in known_value_infos:
            continue
        template = lookup.get(tensor_name)
        if template is None:
            entry = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None)
        else:
            entry = onnx.ValueInfoProto()
            entry.CopyFrom(template)
        model.graph.value_info.append(entry)
        known_value_infos.add(tensor_name)
        newly_added.append(tensor_name)
    if newly_added:
        logging.debug("已为以下 tensor 补充 ValueInfo: %s", newly_added)
def ensure_extractor_value_infos(
    extractor: onnx_utils.Extractor,
    tensor_names: Iterable[str],
    source_model: onnx.ModelProto,
) -> None:
    """Register ValueInfo entries on an Extractor's graph for *tensor_names*.

    Names already present as graph inputs, outputs, or value_info entries are
    skipped.  Missing ones are copied from the source model when possible,
    otherwise created with an UNDEFINED element type.  Both the extractor's
    graph ``value_info`` list and its ``vimap`` cache are updated.
    """
    known_io = {vi.name for vi in extractor.graph.input} | {
        vi.name for vi in extractor.graph.output
    }
    known_value_infos = {vi.name for vi in extractor.graph.value_info}
    template_map = {}
    for proto in (
        list(source_model.graph.input)
        + list(source_model.graph.value_info)
        + list(source_model.graph.output)
    ):
        template_map[proto.name] = proto
    registered: List[str] = []
    for tensor_name in tensor_names:
        if tensor_name in known_io or tensor_name in known_value_infos:
            continue
        template = template_map.get(tensor_name)
        if template is None:
            entry = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None)
        else:
            entry = onnx.ValueInfoProto()
            entry.CopyFrom(template)
        extractor.graph.value_info.append(entry)
        extractor.vimap[tensor_name] = entry
        known_value_infos.add(tensor_name)
        registered.append(tensor_name)
    if registered:
        logging.debug("Extractor 侧补充 ValueInfo: %s", registered)
def trace_nodes_between(spec: SubGraphSpec, index: GraphIndex) -> Set[str]:
    """Collect the node names lying between ``spec.start`` and ``spec.end``.

    Walks producer edges backwards from every end tensor and stops at the
    boundary formed by the declared start tensors, the graph inputs, and the
    initializers.  Returns the set of node names reached.
    """
    stop_at = set(spec.start) | index.graph_inputs | index.initializer_names
    seen_tensors: Set[str] = set()
    found_nodes: Set[str] = set()
    pending = list(spec.end)
    while pending:
        current = pending.pop()
        if current in seen_tensors:
            continue
        seen_tensors.add(current)
        if current in stop_at:
            continue
        maker = index.tensor_to_producer.get(current)
        if maker is None or maker in found_nodes:
            continue
        found_nodes.add(maker)
        # Keep walking upstream, but never past the boundary.
        for feeder in index.node_inputs.get(maker, []):
            if feeder and feeder not in stop_at:
                pending.append(feeder)
    return found_nodes
def untouched_components(
    all_nodes: Sequence[str],
    covered_nodes: Set[str],
    index: GraphIndex,
) -> List[Set[str]]:
    """Group the nodes missed by every sub-config into connected components.

    Data-flow edges are treated as undirected, so each returned set is one
    weakly-connected island of leftover nodes.  Returns an empty list when
    everything is already covered.
    """
    leftovers = [name for name in all_nodes if name not in covered_nodes]
    if not leftovers:
        return []
    leftover_set = set(leftovers)
    neighbors: Dict[str, Set[str]] = {name: set() for name in leftovers}

    def _link(a: str, b: str) -> None:
        # Record one undirected edge between two leftover nodes.
        neighbors[a].add(b)
        neighbors[b].add(a)

    for name in leftovers:
        for produced in index.node_outputs.get(name, []):
            for downstream in index.tensor_to_consumers.get(produced, []):
                if downstream in leftover_set:
                    _link(name, downstream)
        for consumed in index.node_inputs.get(name, []):
            upstream = index.tensor_to_producer.get(consumed)
            if upstream in leftover_set:
                _link(name, upstream)

    groups: List[Set[str]] = []
    done: Set[str] = set()
    for name in leftovers:
        if name in done:
            continue
        # Flood-fill one component starting from this node.
        frontier = [name]
        group: Set[str] = set()
        while frontier:
            node = frontier.pop()
            if node in done:
                continue
            done.add(node)
            group.add(node)
            frontier.extend(neighbors[node] - done)
        groups.append(group)
    return groups
def derive_interface(nodes: Set[str], index: GraphIndex) -> tuple[List[str], List[str]]:
    """Compute the (start, end) tensor lists for an auto-generated component.

    start: tensors consumed inside *nodes* but produced outside (or not
    produced at all), excluding initializers, which are embedded rather
    than fed.
    end: tensors produced inside *nodes* that are consumed outside, are
    graph outputs, or — as a fallback when nothing else qualifies — every
    tensor the component produces.

    Both lists are sorted for deterministic ordering/file naming.

    Fix: the previous return annotation ``(List[str], List[str])`` was a
    tuple of types, not a valid type expression; a tuple type is spelled
    ``tuple[...]`` (lazy via ``from __future__ import annotations``).
    """
    produced = set()
    for node in nodes:
        produced.update(index.node_outputs.get(node, []))
    start: Set[str] = set()
    for node in nodes:
        for inp in index.node_inputs.get(node, []):
            producer = index.tensor_to_producer.get(inp)
            # External dependency: nothing produces it, or its producer lives
            # outside this component.
            if producer is None and inp not in index.initializer_names:
                start.add(inp)
            elif producer not in nodes and inp not in index.initializer_names:
                start.add(inp)
    end: Set[str] = set()
    for node in nodes:
        for out in index.node_outputs.get(node, []):
            consumers = index.tensor_to_consumers.get(out, [])
            if not consumers:
                # Dead-end tensor: only expose it when it is a model output.
                if out in index.graph_outputs:
                    end.add(out)
                continue
            if any(consumer not in nodes for consumer in consumers):
                end.add(out)
    end.update(index.graph_outputs & produced)
    if not end and produced:
        # Degenerate component: expose everything so the sub-model is valid.
        end = produced.copy()
    return sorted(start), sorted(end)
def ordered_specs(specs: Sequence[SubGraphSpec], index: GraphIndex) -> List[SubGraphSpec]:
    """Topologically order sub-graph specs by tensor availability.

    A spec becomes schedulable once all of its start tensors are available
    (graph inputs, initializers, or outputs of already-scheduled specs).
    Raises RuntimeError when no progress can be made (cycle or missing
    producer).
    """
    ready_tensors = set(index.graph_inputs) | index.initializer_names
    waiting = list(specs)
    result: List[SubGraphSpec] = []
    while waiting:
        scheduled_any = False
        for candidate in list(waiting):
            if not set(candidate.start).issubset(ready_tensors):
                continue
            result.append(candidate)
            ready_tensors.update(candidate.end)
            waiting.remove(candidate)
            scheduled_any = True
        if not scheduled_any:
            missing = {spec.label: sorted(set(spec.start) - ready_tensors) for spec in waiting}
            raise RuntimeError(f"无法解析子图拓扑,缺少以下张量: {missing}")
    return result
def extract_model_file(
    source_model: onnx.ModelProto,
    spec: SubGraphSpec,
    output_dir: Path,
    suffix: str,
    run_checker: bool,
    run_shape_infer: bool,
) -> Path:
    """Materialise one sub-graph spec as a standalone ONNX file.

    The sub-model is assembled manually (without onnx.utils.Extractor): the
    nodes named in ``spec.node_names`` are copied, external inputs become
    graph inputs, referenced initializers are embedded, and the ``spec.end``
    tensors become graph outputs.  Shape inference and the checker run
    best-effort (failures are logged, not raised).

    Returns the path of the saved sub-model.
    """
    head = sanitize(spec.start[0]) if spec.start else "const"
    tail = sanitize(spec.end[0]) if spec.end else "out"
    filename = f"{spec.label}_{head}_to_{tail}_{suffix}.onnx"
    destination = output_dir / filename
    # 1. Map the unique node names from build_graph_index back to NodeProtos.
    #    This must replicate build_graph_index's de-duplication exactly
    #    (anonymous -> node_{idx}, name clash -> trailing _{idx}); the previous
    #    `node.name or f"node_{i}"` mapping silently dropped nodes whenever two
    #    nodes shared a name.
    node_map: Dict[str, onnx.NodeProto] = {}
    used_names: Set[str] = set()
    for i, node in enumerate(source_model.graph.node):
        base = node.name.strip() if node.name else ""
        candidate = base or f"node_{i}"
        while candidate in used_names:
            candidate = f"{candidate}_{i}"
        used_names.add(candidate)
        node_map[candidate] = node
    # 2. Build an empty graph and copy the selected nodes into it.
    sub_graph = helper.make_graph(
        nodes=[],
        name=f"{spec.label}_subgraph",
        inputs=[],
        outputs=[],
        initializer=[],
    )
    for node_name in spec.node_names:
        source_node = node_map.get(node_name)
        if source_node is None:
            logging.error("子图 %s: 找不到节点 %s", spec.label, node_name)
            continue
        new_node = onnx.NodeProto()
        new_node.CopyFrom(source_node)
        sub_graph.node.append(new_node)
    # 3. Collect every tensor name the copied nodes touch.
    node_inputs: Set[str] = set()
    node_outputs: Set[str] = set()
    for node in sub_graph.node:
        node_inputs.update(inp for inp in node.input if inp)
        node_outputs.update(out for out in node.output if out)
    # 4. Type/shape templates from the source model's declared value infos.
    source_value_info_map = {}
    for vi in list(source_model.graph.input) + list(source_model.graph.value_info) + list(source_model.graph.output):
        source_value_info_map[vi.name] = vi
    # 5. Initializers available in the source model.
    source_init_map = {init.name: init for init in source_model.graph.initializer}
    # 6. Graph inputs: declared start tensors plus any consumed tensor that is
    #    neither produced inside the sub-graph nor an initializer.
    input_tensor_names = set(spec.start)
    for tensor_name in node_inputs:
        if tensor_name not in node_outputs and tensor_name not in source_init_map:
            input_tensor_names.add(tensor_name)
    for tensor_name in sorted(input_tensor_names):
        if tensor_name in source_value_info_map:
            vi = onnx.ValueInfoProto()
            vi.CopyFrom(source_value_info_map[tensor_name])
        else:
            # No shape information available; leave the element type undefined.
            vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None)
        sub_graph.input.append(vi)
    # 7. Embed the initializers the copied nodes reference.
    for tensor_name in node_inputs:
        if tensor_name in source_init_map:
            init = onnx.TensorProto()
            init.CopyFrom(source_init_map[tensor_name])
            sub_graph.initializer.append(init)
    # 8. Graph outputs: exactly the declared end tensors.
    for tensor_name in spec.end:
        if tensor_name in source_value_info_map:
            vi = onnx.ValueInfoProto()
            vi.CopyFrom(source_value_info_map[tensor_name])
        else:
            vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None)
        sub_graph.output.append(vi)
    # 9. Validate pass-through outputs.  An output no node produces is still
    #    legal ONNX when the name refers to a graph input or an initializer —
    #    graph outputs may reference those directly.  (The previous version
    #    inserted Identity(x -> x), which redefines the name and violates the
    #    single-assignment rule the checker enforces.)
    for out_name in spec.end:
        if out_name in node_outputs:
            continue
        if out_name in input_tensor_names or out_name in source_init_map:
            logging.info("子图 %s: 输出 %s 直接引用图输入/initializer", spec.label, out_name)
        else:
            logging.error(f"子图 {spec.label}: 输出 {out_name} 无法产生(不在输入/initializer/节点输出中)")
    # 10. Wrap the graph in a model and copy the source model's metadata.
    sub_model = helper.make_model(sub_graph)
    sub_model.ir_version = source_model.ir_version
    sub_model.producer_name = source_model.producer_name
    sub_model.producer_version = source_model.producer_version
    sub_model.domain = source_model.domain
    sub_model.model_version = source_model.model_version
    sub_model.doc_string = source_model.doc_string
    # 11. Mirror the opset imports so operator versions match the source.
    del sub_model.opset_import[:]
    for opset in source_model.opset_import:
        sub_model.opset_import.add().CopyFrom(opset)
    # 12. Best-effort shape inference and validation.
    if run_shape_infer:
        try:
            sub_model = shape_inference.infer_shapes(sub_model)
        except Exception as e:
            logging.warning(f"子图 {spec.label} shape inference 失败: {e}")
    if run_checker:
        try:
            checker.check_model(sub_model)
        except Exception as e:
            logging.warning(f"子图 {spec.label} checker 验证失败: {e}")
    # 13. Save and log a summary.
    onnx.save(sub_model, destination.as_posix())
    logging.info(
        "保存子图 %s (start=%s, end=%s, 节点数=%d, checker=%s, infer_shape=%s)",
        destination.name,
        spec.start,
        spec.end,
        len(sub_graph.node),
        bool(run_checker),
        bool(run_shape_infer),
    )
    return destination
def load_numpy_inputs(path: Path, expected_inputs: Iterable[str]) -> Dict[str, np.ndarray]:
    """Load a feed dict from a ``.npz`` or ``.npy`` file.

    ``.npz`` archives are returned key-by-key.  ``.npy`` files may hold a
    pickled dict, a structured array, or — for single-input models only —
    a plain array that is bound to the sole expected input name.

    Raises ValueError for unsupported extensions or when a multi-input
    model is given a bare array.

    Fix: the NpzFile returned by ``np.load`` keeps an open file handle; it
    is now closed via the context manager (the dict comprehension loads all
    arrays into memory before the context exits).
    """
    suffix = path.suffix.lower()
    expected = list(expected_inputs)
    if suffix == ".npz":
        with np.load(path, allow_pickle=False) as data:
            return {key: data[key] for key in data.files}
    if suffix == ".npy":
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # file is untrusted — only load calibration data you control.
        arr = np.load(path, allow_pickle=True)
        if isinstance(arr, np.ndarray) and arr.shape == () and isinstance(arr.item(), dict):
            # 0-d object array wrapping a {name: array} dict.
            return {str(k): np.asarray(v) for k, v in arr.item().items()}
        if isinstance(arr, np.ndarray) and arr.dtype.names:
            # Structured array: one field per input name.
            return {name: arr[name] for name in arr.dtype.names}
        if len(expected) == 1:
            return {expected[0]: np.asarray(arr)}
        raise ValueError("多输入模型需要字典格式的 .npy/.npz 数据。")
    raise ValueError("仅支持 .npz 或 .npy 输入数据。")
def run_full_model(model_path: Path, feed_dict: Dict[str, np.ndarray], providers: List[str]):
    """Run the original ONNX model once and return {output_name: array}.

    Raises RuntimeError when onnxruntime is not installed.
    """
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    session = ort.InferenceSession(model_path.as_posix(), providers=providers)
    output_names = [meta.name for meta in session.get_outputs()]
    values = session.run(None, feed_dict)
    return dict(zip(output_names, values))
def run_split_pipeline(
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
) -> Dict[str, np.ndarray]:
    """Execute the split models in topological order, chaining tensors.

    Each sub-model is fed from an accumulating tensor store seeded with
    *feed_dict*; its outputs are merged back in.  Returns the full store
    (original feed plus every sub-model output, keyed by tensor name).

    Raises RuntimeError when onnxruntime is missing or a spec has no file,
    and KeyError when a required input tensor is not yet available.
    """
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    tensor_store: Dict[str, np.ndarray] = dict(feed_dict)
    for spec in ordered_subgraphs:
        if spec.output_path is None:
            raise RuntimeError(f"子图 {spec.label} 尚未生成 ONNX 文件。")
        session = ort.InferenceSession(spec.output_path.as_posix(), providers=providers)
        feed = {}
        for name in spec.start:
            if name not in tensor_store:
                raise KeyError(f"子图 {spec.label} 缺少输入张量 {name}。")
            feed[name] = tensor_store[name]
        outputs = session.run(None, feed)
        for meta, value in zip(session.get_outputs(), outputs):
            tensor_store[meta.name] = value
    return tensor_store
def verify(
    model_path: Path,
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
    rtol: float,
    atol: float,
) -> None:
    """Compare every original-model output against the split pipeline.

    Raises AssertionError when an output is missing from the pipeline store
    or deviates beyond the given tolerances.
    """
    reference = run_full_model(model_path, feed_dict, providers)
    pipeline_store = run_split_pipeline(ordered_subgraphs, feed_dict, providers)
    for name, expected in reference.items():
        actual = pipeline_store.get(name)
        if actual is None:
            raise AssertionError(f"切分流水线未产生模型输出 {name}")
        if np.allclose(expected, actual, rtol=rtol, atol=atol):
            continue
        diff = float(np.max(np.abs(expected - actual)))
        raise AssertionError(f"输出 {name} 不匹配,最大偏差 {diff:.3e}")
    logging.info("切分模型与原始模型输出一致 (rtol=%g, atol=%g)。", rtol, atol)
def save_outputs(outputs: Dict[str, np.ndarray], destination: Optional[Path]) -> None:
    """Persist the pipeline tensor store as a ``.npz`` archive.

    Silently does nothing when *destination* is None; parent directories are
    created as needed.
    """
    if destination is not None:
        destination.parent.mkdir(parents=True, exist_ok=True)
        np.savez(destination, **outputs)
        logging.info("流水线输出已保存至 %s", destination)
def load_sub_configs(config_path: Path) -> List[dict]:
    """Read the sub_configs list from a JSON config file.

    Looks under ``compiler.sub_configs`` first, then at the top level;
    raises ValueError when neither location holds a non-empty list.
    """
    with config_path.open("r", encoding="utf-8") as handle:
        config = json.load(handle)
    sub_configs = (
        config.get("compiler", {}).get("sub_configs")
        or config.get("sub_configs")
    )
    if not sub_configs:
        raise ValueError("配置文件中未找到 sub_configs。")
    return sub_configs
def main() -> None:
    """CLI entry point: split an ONNX model along tensor-name boundaries.

    Pipeline: parse args -> load the full model -> index the graph ->
    trace each configured sub-graph -> auto-cover leftover nodes ->
    topologically order the specs -> export every sub-model -> optionally
    verify against the original model or run the chained pipeline.
    """
    parser = argparse.ArgumentParser(description="根据 tensor 名切分 ONNX 子图")
    parser.add_argument("--model", required=True, type=Path, help="原始 ONNX 文件")
    parser.add_argument("--config", required=True, type=Path, help="包含 sub_configs 的 JSON")
    parser.add_argument("--output-dir", default="./split-onnx", type=Path, help="子模型输出目录")
    parser.add_argument("--providers", nargs="*", default=["CPUExecutionProvider"], help="onnxruntime providers 顺序")
    parser.add_argument("--verify", action="store_true", help="比较原始模型与流水线输出")
    parser.add_argument("--run-pipeline", action="store_true", help="只运行切分流水线并输出结果")
    parser.add_argument("--input-data", type=Path, help=".npz/.npy 格式的输入数据")
    parser.add_argument("--pipeline-output", type=Path, help="保存流水线输出为 npz")
    parser.add_argument("--rtol", type=float, default=1e-4, help="验证 rtol")
    parser.add_argument("--atol", type=float, default=1e-5, help="验证 atol")
    parser.add_argument("--skip-checker", action="store_true", help="跳过 onnx checker")
    parser.add_argument("--skip-shape-infer", action="store_true", help="跳过 shape inference")
    parser.add_argument("--log", default="INFO", help="日志等级")
    args = parser.parse_args()
    logging.basicConfig(level=getattr(logging, args.log.upper(), logging.INFO))
    # The full model (including external data files) must be loaded: the
    # weights have to be embedded into each sub-model, otherwise the split
    # files would still reference the original external-data paths, which
    # cannot be resolved from the new output directory.
    logging.info("加载完整模型(包括外部数据文件)...这可能需要一些时间和内存")
    model = onnx.load(args.model.as_posix(), load_external_data=True)
    # The checker is deliberately skipped for the huge original model; only
    # the much smaller split sub-models are checked (in extract_model_file).
    # if not args.skip_checker:
    #     checker.check_model(args.model.as_posix())
    logging.info("跳过对原始模型的 checker(模型过大),将只对切分后的子模型进行验证")
    graph_index = build_graph_index(model)
    sub_configs = load_sub_configs(args.config)
    specs: List[SubGraphSpec] = []
    covered_nodes: Set[str] = set()
    # Pass 1: one spec per explicit sub_config entry; trace its node set.
    for idx, entry in enumerate(sub_configs):
        start = [name for name in entry.get("start_tensor_names", []) if name]
        end = [name for name in entry.get("end_tensor_names", []) if name]
        if not start or not end:
            raise ValueError(f"sub_config[{idx}] 缺少 start/end tensor name。")
        spec = SubGraphSpec(
            label=f"cfg_{idx:02d}",
            start=start,
            end=end,
            node_names=set(),
            source="config",
        )
        nodes = trace_nodes_between(spec, graph_index)
        spec.node_names = nodes
        covered_nodes.update(nodes)
        specs.append(spec)
    # Pass 2: auto-generate specs for nodes no config entry covered.
    leftovers = untouched_components(graph_index.node_order, covered_nodes, graph_index)
    for idx, component in enumerate(leftovers):
        start, end = derive_interface(component, graph_index)
        if not end:
            # A component with no exposable output cannot form a sub-model.
            continue
        spec = SubGraphSpec(
            label=f"auto_{idx:02d}",
            start=start,
            end=end,
            node_names=component,
            source="auto",
        )
        specs.append(spec)
        logging.info(
            "自动补充子图 %s: start=%s end=%s (节点数=%d)",
            spec.label,
            spec.start,
            spec.end,
            len(component),
        )
    ordered = ordered_specs(specs, graph_index)
    # Make sure every boundary tensor has a ValueInfo before extraction.
    required_tensors: Set[str] = set()
    for spec in ordered:
        required_tensors.update(spec.start)
        required_tensors.update(spec.end)
    ensure_value_infos(model, required_tensors)
    args.output_dir.mkdir(parents=True, exist_ok=True)
    for spec in ordered:
        spec.output_path = extract_model_file(
            model,
            spec,
            args.output_dir,
            spec.source,
            run_checker=not args.skip_checker,
            run_shape_infer=not args.skip_shape_infer,
        )
    # Input data is only required when something is actually executed.
    need_inputs = args.verify or args.run_pipeline
    if need_inputs:
        if args.input_data is None:
            raise ValueError("verify/run-pipeline 模式需要 --input-data 提供输入。")
        feed = load_numpy_inputs(args.input_data, graph_index.graph_inputs)
        missing_inputs = graph_index.graph_inputs - feed.keys()
        if missing_inputs:
            raise ValueError(f"输入数据缺少以下张量: {sorted(missing_inputs)}")
    else:
        feed = {}
    if args.verify:
        verify(args.model, ordered, feed, args.providers, args.rtol, args.atol)
    elif args.run_pipeline:
        outputs = run_split_pipeline(ordered, feed, args.providers)
        save_outputs(outputs, args.pipeline_output)
    else:
        logging.info("子模型已生成,如需验证请添加 --verify 或 --run-pipeline。")
if __name__ == "__main__":
    # Example invocation, kept as an inert string literal (no runtime effect):
    """
    python ./scripts/split_onnx_by_subconfig.py \
        --model ./onnx-models/z_image_transformer_body_only_simp_slim.onnx \
        --config ./pulsar2_configs/transformers_subgraph.json \
        --output-dir ./transformers_body_only_split_onnx \
        --verify \
        --input-data ./onnx-calibration-no-controlnet/transformer_inputs_prompt000_step00.npy \
        --providers CPUExecutionProvider
    """
    main()