|
|
|
|
|
"""基于 tensor 名称的通用 ONNX 子图切分工具。 |
|
|
|
|
|
相较于 split_quant_onnx_by_subconfigs.py,本脚本额外提供: |
|
|
1. 为每个子模型执行 onnx checker 与 shape inference(可关闭)。 |
|
|
2. 支持从 .npz/.npy(包含 dict 或单数组)加载验证数据。 |
|
|
3. 可用 onnxruntime 串联执行全部子模型,既可校验精度,也可单独输出流水线结果。 |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
from collections import defaultdict |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Dict, Iterable, List, Optional, Sequence, Set |
|
|
|
|
|
import numpy as np |
|
|
import onnx |
|
|
from onnx import TensorProto, checker, helper, shape_inference, utils as onnx_utils |
|
|
|
|
|
try: |
|
|
import onnxruntime as ort |
|
|
except ImportError: |
|
|
ort = None |
|
|
|
|
|
|
|
|
@dataclass
class SubGraphSpec:
    """Describes one sub-graph to be cut out of the source ONNX model."""

    label: str  # unique identifier used in log messages and output filenames
    start: List[str]  # boundary input tensor names (become sub-model inputs)
    end: List[str]  # boundary output tensor names (become sub-model outputs)
    node_names: Set[str]  # names of the source-model nodes contained in this sub-graph
    source: str  # origin of the spec: "config" (from JSON) or "auto" (leftover components)
    output_path: Optional[Path] = None  # path of the generated .onnx file, set after extraction
|
|
|
|
|
|
|
|
@dataclass
class GraphIndex:
    """Pre-computed lookup tables over an ONNX graph for fast traversal."""

    tensor_to_producer: Dict[str, str]  # tensor name -> name of the node that produces it
    tensor_to_consumers: Dict[str, List[str]]  # tensor name -> names of nodes that consume it
    node_inputs: Dict[str, List[str]]  # node name -> its non-empty input tensor names
    node_outputs: Dict[str, List[str]]  # node name -> its non-empty output tensor names
    graph_inputs: Set[str]  # names of the model's graph-level inputs
    graph_outputs: Set[str]  # names of the model's graph-level outputs
    initializer_names: Set[str]  # names of initializer tensors (weights/constants)
    node_order: List[str]  # deduplicated node names in original graph order
|
|
|
|
|
|
|
|
def sanitize(name: str) -> str:
    """Turn *name* into a filesystem-friendly token (alphanumerics and underscores)."""
    if not name:
        cleaned = "anon"
    else:
        cleaned = "".join(ch if ch.isalnum() else "_" for ch in name)
    trimmed = cleaned.strip("_")
    # An all-punctuation name strips down to nothing; fall back to a placeholder.
    return trimmed if trimmed else "tensor"
|
|
|
|
|
|
|
|
def build_graph_index(model: onnx.ModelProto) -> GraphIndex:
    """Scan *model* once and build producer/consumer lookup tables.

    Node names are deduplicated: unnamed nodes become ``node_<idx>`` and
    clashing names get ``_<idx>`` appended until unique.
    """
    producers: Dict[str, str] = {}
    consumers: Dict[str, List[str]] = defaultdict(list)
    inputs_by_node: Dict[str, List[str]] = {}
    outputs_by_node: Dict[str, List[str]] = {}
    ordered_names: List[str] = []
    seen: Set[str] = set()

    for position, node in enumerate(model.graph.node):
        raw = node.name.strip() if node.name else ""
        unique = raw or f"node_{position}"
        # Disambiguate duplicates by appending the node index until unique.
        while unique in seen:
            unique = f"{unique}_{position}"
        seen.add(unique)
        ordered_names.append(unique)
        inputs_by_node[unique] = [t for t in node.input if t]
        outputs_by_node[unique] = [t for t in node.output if t]
        for produced in outputs_by_node[unique]:
            producers[produced] = unique
        for consumed in inputs_by_node[unique]:
            consumers[consumed].append(unique)

    return GraphIndex(
        tensor_to_producer=producers,
        tensor_to_consumers=consumers,
        node_inputs=inputs_by_node,
        node_outputs=outputs_by_node,
        graph_inputs={vi.name for vi in model.graph.input},
        graph_outputs={vi.name for vi in model.graph.output},
        initializer_names={init.name for init in model.graph.initializer},
        node_order=ordered_names,
    )
|
|
|
|
|
|
|
|
def ensure_value_infos(model: onnx.ModelProto, tensor_names: Iterable[str]) -> None:
    """Make sure every tensor in *tensor_names* has a ValueInfo entry on the graph.

    Known type/shape info is copied from the model's inputs/outputs/value_info;
    unknown tensors get a placeholder entry with UNDEFINED element type.
    """
    templates: Dict[str, onnx.ValueInfoProto] = {}
    for candidate in list(model.graph.input) + list(model.graph.value_info) + list(model.graph.output):
        templates[candidate.name] = candidate

    present = {vi.name for vi in model.graph.value_info}
    appended: List[str] = []
    for tensor in tensor_names:
        if tensor in present:
            continue
        template = templates.get(tensor)
        if template is None:
            info = helper.make_tensor_value_info(tensor, TensorProto.UNDEFINED, None)
        else:
            info = onnx.ValueInfoProto()
            info.CopyFrom(template)
        model.graph.value_info.append(info)
        present.add(tensor)
        appended.append(tensor)
    if appended:
        logging.debug("已为以下 tensor 补充 ValueInfo: %s", appended)
|
|
|
|
|
|
|
|
def ensure_extractor_value_infos(
    extractor: onnx_utils.Extractor,
    tensor_names: Iterable[str],
    source_model: onnx.ModelProto,
) -> None:
    """Counterpart of ensure_value_infos for an onnx Extractor instance.

    Registers missing ValueInfo entries both on the extractor's graph and in
    its ``vimap`` so later extraction calls can resolve boundary tensors.
    """
    bound_names = (
        {vi.name for vi in extractor.graph.input}
        | {vi.name for vi in extractor.graph.output}
    )
    registered = {vi.name for vi in extractor.graph.value_info}

    templates: Dict[str, onnx.ValueInfoProto] = {}
    for candidate in (
        list(source_model.graph.input)
        + list(source_model.graph.value_info)
        + list(source_model.graph.output)
    ):
        templates[candidate.name] = candidate

    appended: List[str] = []
    for tensor in tensor_names:
        if tensor in bound_names or tensor in registered:
            continue
        template = templates.get(tensor)
        if template is None:
            info = helper.make_tensor_value_info(tensor, TensorProto.UNDEFINED, None)
        else:
            info = onnx.ValueInfoProto()
            info.CopyFrom(template)
        extractor.graph.value_info.append(info)
        extractor.vimap[tensor] = info
        registered.add(tensor)
        appended.append(tensor)
    if appended:
        logging.debug("Extractor 侧补充 ValueInfo: %s", appended)
|
|
|
|
|
|
|
|
def trace_nodes_between(spec: SubGraphSpec, index: GraphIndex) -> Set[str]:
    """Walk backwards from ``spec.end`` and collect every producing node.

    Traversal stops at ``spec.start`` tensors, graph inputs and initializers,
    so the returned node set covers exactly the region between start and end.
    """
    stop_tensors = set(spec.start) | index.graph_inputs | index.initializer_names
    seen_tensors: Set[str] = set()
    found: Set[str] = set()
    pending = list(spec.end)

    while pending:
        current = pending.pop()
        if current in seen_tensors:
            continue
        seen_tensors.add(current)
        if current in stop_tensors:
            continue
        origin = index.tensor_to_producer.get(current)
        # Unknown producer means a dangling tensor; already-found producers
        # have had their inputs queued on first discovery.
        if not origin or origin in found:
            continue
        found.add(origin)
        pending.extend(
            feeder
            for feeder in index.node_inputs.get(origin, [])
            if feeder and feeder not in stop_tensors
        )
    return found
|
|
|
|
|
|
|
|
def untouched_components(
    all_nodes: Sequence[str],
    covered_nodes: Set[str],
    index: GraphIndex,
) -> List[Set[str]]:
    """Group the nodes NOT covered by any sub-graph into connected components.

    Connectivity is undirected: two leftover nodes are linked whenever one
    consumes a tensor the other produces. Each component can then be wrapped
    into its own auto-generated sub-graph.
    """
    leftover = [name for name in all_nodes if name not in covered_nodes]
    if not leftover:
        return []

    leftover_set = set(leftover)
    neighbors: Dict[str, Set[str]] = {name: set() for name in leftover}

    def _link(a: str, b: str) -> None:
        # Record an undirected edge between two leftover nodes.
        neighbors[a].add(b)
        neighbors[b].add(a)

    for name in leftover:
        for produced in index.node_outputs.get(name, []):
            for user in index.tensor_to_consumers.get(produced, []):
                if user in leftover_set:
                    _link(name, user)
        for consumed in index.node_inputs.get(name, []):
            maker = index.tensor_to_producer.get(consumed)
            if maker in leftover_set:
                _link(name, maker)

    groups: List[Set[str]] = []
    done: Set[str] = set()
    for name in leftover:
        if name in done:
            continue
        group: Set[str] = set()
        frontier = [name]
        while frontier:
            current = frontier.pop()
            if current in done:
                continue
            done.add(current)
            group.add(current)
            frontier.extend(neighbors[current] - done)
        groups.append(group)
    return groups
|
|
|
|
|
|
|
|
def derive_interface(nodes: Set[str], index: GraphIndex) -> tuple[List[str], List[str]]:
    """Compute the boundary ``(start, end)`` tensor lists for a set of nodes.

    start: tensors a node consumes whose producer lies outside *nodes*
    (initializers are excluded — they travel with the sub-graph instead).
    end: tensors a node produces that are consumed outside *nodes* or are
    model-level graph outputs.

    Note: the original return annotation ``(List[str], List[str])`` was a
    plain tuple of types, which is not a valid annotation form; it is fixed
    to ``tuple[List[str], List[str]]`` here (behavior unchanged).
    """
    produced = set()
    for node in nodes:
        produced.update(index.node_outputs.get(node, []))

    start: Set[str] = set()
    for node in nodes:
        for inp in index.node_inputs.get(node, []):
            if inp in index.initializer_names:
                continue
            producer = index.tensor_to_producer.get(inp)
            # External tensor: either nothing produces it (graph input or
            # dangling) or its producer is outside the node set.
            if producer is None or producer not in nodes:
                start.add(inp)

    end: Set[str] = set()
    for node in nodes:
        for out in index.node_outputs.get(node, []):
            consumers = index.tensor_to_consumers.get(out, [])
            if not consumers:
                # Unconsumed tensors only matter if they are model outputs.
                if out in index.graph_outputs:
                    end.add(out)
                continue
            if any(consumer not in nodes for consumer in consumers):
                end.add(out)
    # Produced model-level outputs are always part of the interface.
    end.update(index.graph_outputs & produced)

    # Degenerate fallback: expose everything when nothing qualifies as an end.
    if not end and produced:
        end = produced.copy()

    return sorted(start), sorted(end)
|
|
|
|
|
|
|
|
def ordered_specs(specs: Sequence[SubGraphSpec], index: GraphIndex) -> List[SubGraphSpec]:
    """Topologically sort sub-graph specs so each one's inputs are ready.

    A spec becomes schedulable once all its start tensors are available
    (graph inputs, initializers, or end tensors of already-scheduled specs).
    Raises RuntimeError when the specs cannot be linearized.
    """
    ready_tensors = set(index.graph_inputs) | index.initializer_names
    remaining = list(specs)
    schedule: List[SubGraphSpec] = []

    while remaining:
        scheduled_any = False
        # Iterate over a copy so we can remove scheduled specs in place.
        for candidate in list(remaining):
            if not set(candidate.start).issubset(ready_tensors):
                continue
            schedule.append(candidate)
            ready_tensors.update(candidate.end)
            remaining.remove(candidate)
            scheduled_any = True
        if not scheduled_any:
            missing = {spec.label: sorted(set(spec.start) - ready_tensors) for spec in remaining}
            raise RuntimeError(f"无法解析子图拓扑,缺少以下张量: {missing}")
    return schedule
|
|
|
|
|
|
|
|
def is_valid_onnx_model(model_path: Path) -> bool:
    """Check whether an ONNX model file is valid (has the required opset_import).

    Returns True only when the file is non-empty, loads as an ONNX model
    (external data is NOT loaded), carries a graph, and declares at least one
    opset_import entry. Any load error is logged and reported as invalid.

    Fix: logging calls used eager f-strings; converted to lazy %-style
    arguments for consistency with the rest of the file (messages unchanged).
    """
    try:
        # An empty file would make onnx.load fail with a cryptic protobuf error.
        if model_path.stat().st_size == 0:
            logging.warning("模型 %s 是空文件", model_path.name)
            return False

        # Structural checks only — skip loading (possibly huge) external data.
        model = onnx.load(model_path.as_posix(), load_external_data=False)

        if model is None:
            logging.warning("模型 %s 加载后为 None", model_path.name)
            return False

        if not hasattr(model, 'graph') or model.graph is None:
            logging.warning("模型 %s 缺少 graph", model_path.name)
            return False

        # Models without opset_import are rejected by onnxruntime later on.
        if len(model.opset_import) == 0:
            logging.warning("模型 %s 缺少 opset_import 信息", model_path.name)
            return False

        return True
    except Exception as e:
        # Covers missing files (stat) and malformed protobufs (onnx.load).
        logging.warning("无法加载模型 %s: %s", model_path.name, e)
        return False
|
|
|
|
|
|
|
|
def extract_model_file(
    source_model: onnx.ModelProto,
    spec: SubGraphSpec,
    output_dir: Path,
    suffix: str,
    run_checker: bool,
    run_shape_infer: bool,
    skip_existing: bool = True,
) -> Path:
    """Materialize one sub-graph spec as a standalone .onnx file.

    Builds a new graph containing copies of the spec's nodes, wires up
    boundary inputs/outputs, carries over the initializers the nodes need,
    copies model metadata and opset imports from the source, then optionally
    runs shape inference and the onnx checker before saving.

    Returns the path of the written (or reused) model file.
    Raises RuntimeError if the resulting model would have no opset_import.
    """
    # Output filename encodes the first start/end tensor for readability.
    head = sanitize(spec.start[0]) if spec.start else "const"
    tail = sanitize(spec.end[0]) if spec.end else "out"
    filename = f"{spec.label}_{head}_to_{tail}_{suffix}.onnx"
    destination = output_dir / filename

    # Reuse an existing file only if it passes the validity check.
    if skip_existing and destination.exists():
        if is_valid_onnx_model(destination):
            logging.info("跳过已存在的子图 %s (文件: %s)", spec.label, destination.name)
            return destination
        else:
            logging.warning("子图 %s 的文件无效,将重新生成", spec.label)

    # Start from an empty graph; nodes/inputs/outputs are appended below.
    sub_graph = helper.make_graph(
        nodes=[],
        name=f"{spec.label}_subgraph",
        inputs=[],
        outputs=[],
        initializer=[]
    )

    # Rebuild the same deduplicated node naming scheme as build_graph_index
    # so spec.node_names (produced from that index) matches this map.
    node_map = {}
    used_names: Set[str] = set()
    for idx, node in enumerate(source_model.graph.node):
        base = node.name.strip() if node.name else ""
        candidate = base or f"node_{idx}"
        while candidate in used_names:
            candidate = f"{candidate}_{idx}"
        used_names.add(candidate)
        node_map[candidate] = node

    # Deep-copy the selected nodes into the sub-graph.
    missing_nodes: List[str] = []
    for node_name in spec.node_names:
        target = node_map.get(node_name)
        if target is None:
            missing_nodes.append(node_name)
            continue
        new_node = onnx.NodeProto()
        new_node.CopyFrom(target)
        sub_graph.node.append(new_node)

    if missing_nodes:
        logging.warning("子图 %s: 有 %d 个节点未匹配到源模型,将被跳过: %s", spec.label, len(missing_nodes), missing_nodes[:5])

    # Collect every tensor the copied nodes touch.
    node_inputs = set()
    node_outputs = set()
    for node in sub_graph.node:
        for inp in node.input:
            if inp:
                node_inputs.add(inp)
        for out in node.output:
            if out:
                node_outputs.add(out)

    # Type/shape templates for boundary tensors, when the source model has them.
    source_value_info_map = {}
    for vi in list(source_model.graph.input) + list(source_model.graph.value_info) + list(source_model.graph.output):
        source_value_info_map[vi.name] = vi

    source_init_map = {init.name: init for init in source_model.graph.initializer}

    # Graph inputs = declared start tensors plus any consumed tensor that is
    # neither produced inside the sub-graph nor an initializer.
    input_tensor_names = set(spec.start)
    for tensor_name in node_inputs:
        if tensor_name not in node_outputs and tensor_name not in source_init_map:
            input_tensor_names.add(tensor_name)

    for tensor_name in sorted(input_tensor_names):
        if tensor_name in source_value_info_map:
            vi = onnx.ValueInfoProto()
            vi.CopyFrom(source_value_info_map[tensor_name])
            sub_graph.input.append(vi)
        else:
            # No known type/shape — emit an UNDEFINED placeholder.
            vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None)
            sub_graph.input.append(vi)

    # Copy over the initializers (weights/constants) the nodes reference.
    for tensor_name in node_inputs:
        if tensor_name in source_init_map:
            init = onnx.TensorProto()
            init.CopyFrom(source_init_map[tensor_name])
            sub_graph.initializer.append(init)

    # Declare the spec's end tensors as the sub-model outputs.
    for tensor_name in spec.end:
        if tensor_name in source_value_info_map:
            vi = onnx.ValueInfoProto()
            vi.CopyFrom(source_value_info_map[tensor_name])
            sub_graph.output.append(vi)
        else:
            vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None)
            sub_graph.output.append(vi)

    # An end tensor that no copied node produces must be passed through from
    # an input or initializer, otherwise the output cannot be satisfied.
    for out_name in spec.end:
        if out_name not in node_outputs:
            if out_name in input_tensor_names or out_name in source_init_map:
                # NOTE(review): the Identity node's output reuses the input's
                # tensor name; ONNX validators may reject a node output that
                # shadows a graph input/initializer name — verify.
                identity_node = helper.make_node(
                    'Identity',
                    inputs=[out_name],
                    outputs=[out_name],
                    name=f'passthrough_{sanitize(out_name)}'
                )
                sub_graph.node.append(identity_node)
                logging.info(f"子图 {spec.label}: 为输出 {out_name} 添加 Identity 节点")
            else:
                logging.error(f"子图 {spec.label}: 输出 {out_name} 无法产生(不在输入/initializer/节点输出中)")

    sub_model = helper.make_model(sub_graph)

    # Carry the source model's metadata over to the sub-model.
    sub_model.ir_version = source_model.ir_version
    sub_model.producer_name = source_model.producer_name
    sub_model.producer_version = source_model.producer_version
    sub_model.domain = source_model.domain
    sub_model.model_version = source_model.model_version
    sub_model.doc_string = source_model.doc_string

    # Replace whatever default opsets make_model added with the source's.
    while len(sub_model.opset_import) > 0:
        sub_model.opset_import.pop()

    if len(source_model.opset_import) > 0:
        for opset in source_model.opset_import:
            opset_import = sub_model.opset_import.add()
            opset_import.CopyFrom(opset)
    else:
        logging.warning(f"源模型缺少 opset_import,为子图 {spec.label} 添加默认 opset 17")
        opset_import = sub_model.opset_import.add()
        opset_import.domain = ""
        opset_import.version = 17

    if len(sub_model.opset_import) == 0:
        raise RuntimeError(f"子图 {spec.label} 缺少 opset_import 信息")

    # Shape inference and checker failures are logged, not fatal — the
    # sub-model may still run under onnxruntime.
    if run_shape_infer:
        try:
            sub_model = shape_inference.infer_shapes(sub_model)
        except Exception as e:
            logging.warning(f"子图 {spec.label} shape inference 失败: {e}")

    if run_checker:
        try:
            checker.check_model(sub_model)
        except Exception as e:
            logging.warning(f"子图 {spec.label} checker 验证失败: {e}")

    onnx.save(sub_model, destination.as_posix())
    logging.info(
        "保存子图 %s (start=%s, end=%s, 节点数=%d, checker=%s, infer_shape=%s)",
        destination.name,
        spec.start,
        spec.end,
        len(sub_graph.node),
        bool(run_checker),
        bool(run_shape_infer),
    )
    return destination
|
|
|
|
|
|
|
|
|
|
|
def load_numpy_inputs(path: Path, expected_inputs: Iterable[str]) -> Dict[str, np.ndarray]:
    """Load validation inputs from a .npz or .npy file into a name->array dict.

    Supported layouts: .npz archives, 0-d object .npy files wrapping a dict,
    structured arrays, or a single plain array (single-input models only).
    """
    expected_names = list(expected_inputs)
    kind = path.suffix.lower()

    if kind == ".npz":
        archive = np.load(path, allow_pickle=False)
        return {entry: archive[entry] for entry in archive.files}

    if kind == ".npy":
        payload = np.load(path, allow_pickle=True)
        is_array = isinstance(payload, np.ndarray)
        if is_array and payload.shape == () and isinstance(payload.item(), dict):
            # 0-d object array wrapping a {name: array} dict.
            return {str(key): np.asarray(value) for key, value in payload.item().items()}
        if is_array and payload.dtype.names:
            # Structured array: one field per model input.
            return {field: payload[field] for field in payload.dtype.names}
        if len(expected_names) == 1:
            # Bare array: usable only when the model has exactly one input.
            return {expected_names[0]: np.asarray(payload)}
        raise ValueError("多输入模型需要字典格式的 .npy/.npz 数据。")

    raise ValueError("仅支持 .npz 或 .npy 输入数据。")
|
|
|
|
|
|
|
|
def run_full_model(model_path: Path, feed_dict: Dict[str, np.ndarray], providers: List[str]):
    """Run the complete original model once and return ``{output_name: array}``."""
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    session = ort.InferenceSession(model_path.as_posix(), providers=providers)
    values = session.run(None, feed_dict)
    output_names = [meta.name for meta in session.get_outputs()]
    return dict(zip(output_names, values))
|
|
|
|
|
|
|
|
def run_split_pipeline(
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
) -> Dict[str, np.ndarray]:
    """Execute the sub-models in topological order, chaining outputs to inputs.

    Every tensor ever produced (plus the initial feed) is kept in
    ``tensor_store`` so downstream sub-graphs can consume upstream results.
    Returns the full tensor store, including all intermediate boundary tensors.

    Raises:
        RuntimeError: onnxruntime is unavailable, or a sub-model file is missing.
        KeyError: a sub-graph requires a tensor no earlier stage produced.
    """
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    tensor_store = dict(feed_dict)
    for spec in ordered_subgraphs:
        if spec.output_path is None:
            raise RuntimeError(f"子图 {spec.label} 尚未生成 ONNX 文件。")
        if not spec.output_path.exists():
            raise RuntimeError(f"子图 {spec.label} 的输出文件不存在: {spec.output_path}")

        logging.info("运行子图 %s (输入: %s, 输出: %s)", spec.label, spec.start, spec.end)
        session = ort.InferenceSession(spec.output_path.as_posix(), providers=providers)

        # Feed exactly what the session asks for: the extracted model may
        # expose a different input set than spec.start.
        actual_inputs = [inp.name for inp in session.get_inputs()]
        logging.debug("子图 %s 实际输入: %s", spec.label, actual_inputs)

        fetch_inputs = {}
        for name in actual_inputs:
            if name not in tensor_store:
                # FIX: tensor_store starts as a copy of feed_dict and only
                # gains keys, so the old fallback that re-checked feed_dict
                # here could never succeed — it was dead code and is removed.
                # The warning is kept for declared-start tensors.
                if name in spec.start:
                    logging.warning(f"子图 {spec.label} 缺少输入张量 {name},尝试从 feed_dict 查找")
                available = list(tensor_store.keys())
                raise KeyError(f"子图 {spec.label} 缺少输入张量 {name}。当前可用: {available}")
            fetch_inputs[name] = tensor_store[name]

        results = session.run(None, fetch_inputs)
        for meta, value in zip(session.get_outputs(), results):
            tensor_store[meta.name] = value
            logging.debug("子图 %s 产生输出: %s (shape=%s)", spec.label, meta.name, value.shape)
    return tensor_store
|
|
|
|
|
|
|
|
def verify(
    model_path: Path,
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
    rtol: float,
    atol: float,
) -> None:
    """Verify that the split sub-graph pipeline reproduces the original model.

    Procedure:
    1. Run the complete original model and collect all of its outputs.
    2. Run every sub-graph in topological order (each stage's outputs feed
       the next stage's inputs).
    3. Compare the original model's outputs against the pipeline's results.
    4. Verification passes when all outputs agree within rtol/atol.

    Raises AssertionError when an output is missing or exceeds the tolerance.
    """
    logging.info("开始验证:运行原始模型...")
    full_outputs = run_full_model(model_path, feed_dict, providers)
    logging.info("原始模型运行完成,产生 %d 个输出", len(full_outputs))

    logging.info("开始验证:运行子图流水线...")
    split_store = run_split_pipeline(ordered_subgraphs, feed_dict, providers)
    logging.info("子图流水线运行完成")

    logging.info("比较输出...")
    for name, ref in full_outputs.items():
        cand = split_store.get(name)
        if cand is None:
            raise AssertionError(f"切分流水线未产生模型输出 {name}")
        if not np.allclose(ref, cand, rtol=rtol, atol=atol):
            # Report the worst-case absolute deviation to aid debugging.
            diff = float(np.max(np.abs(ref - cand)))
            raise AssertionError(f"输出 {name} 不匹配,最大偏差 {diff:.3e}")
        logging.info("✓ 输出 %s 验证通过 (shape=%s)", name, ref.shape)
    logging.info("✓ 切分模型与原始模型输出一致 (rtol=%g, atol=%g)。", rtol, atol)
|
|
|
|
|
|
|
|
def save_outputs(outputs: Dict[str, np.ndarray], destination: Optional[Path]) -> None:
    """Persist pipeline outputs as an .npz archive; no-op when destination is None."""
    if destination is not None:
        destination.parent.mkdir(parents=True, exist_ok=True)
        np.savez(destination, **outputs)
        logging.info("流水线输出已保存至 %s", destination)
|
|
|
|
|
|
|
|
def load_sub_configs(config_path: Path) -> List[dict]:
    """Read the sub_configs list from a JSON config file.

    Looks under ``compiler.sub_configs`` first, then at the top level.
    Raises ValueError when neither location has a non-empty list.
    """
    with config_path.open("r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    found = parsed.get("compiler", {}).get("sub_configs") or parsed.get("sub_configs")
    if not found:
        raise ValueError("配置文件中未找到 sub_configs。")
    return found
|
|
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: split an ONNX model per config, then optionally verify.

    Pipeline: parse args -> load model -> index graph -> build sub-graph specs
    from the JSON config -> auto-cover leftover nodes -> order specs
    topologically -> extract each sub-model -> optionally verify against the
    original or run the split pipeline standalone.
    """
    parser = argparse.ArgumentParser(description="根据 tensor 名切分 ONNX 子图")
    parser.add_argument("--model", required=True, type=Path, help="原始 ONNX 文件")
    parser.add_argument("--config", required=True, type=Path, help="包含 sub_configs 的 JSON")
    parser.add_argument("--output-dir", default="./split-onnx", type=Path, help="子模型输出目录")
    parser.add_argument("--providers", nargs="*", default=["CPUExecutionProvider"], help="onnxruntime providers 顺序")
    parser.add_argument("--verify", action="store_true", help="比较原始模型与流水线输出")
    parser.add_argument("--run-pipeline", action="store_true", help="只运行切分流水线并输出结果")
    parser.add_argument("--input-data", type=Path, help=".npz/.npy 格式的输入数据")
    parser.add_argument("--pipeline-output", type=Path, help="保存流水线输出为 npz")
    parser.add_argument("--rtol", type=float, default=1e-4, help="验证 rtol")
    parser.add_argument("--atol", type=float, default=1e-5, help="验证 atol")
    parser.add_argument("--skip-checker", action="store_true", help="跳过 onnx checker")
    parser.add_argument("--skip-shape-infer", action="store_true", help="跳过 shape inference")
    parser.add_argument("--skip-existing", action="store_true", help="跳过已存在的子图文件")
    parser.add_argument("--log", default="INFO", help="日志等级")
    args = parser.parse_args()

    # Unknown log levels silently fall back to INFO.
    logging.basicConfig(level=getattr(logging, args.log.upper(), logging.INFO))

    # Load the full model including external weight files.
    logging.info("加载完整模型(包括外部数据文件)...这可能需要一些时间和内存")
    model = onnx.load(args.model.as_posix(), load_external_data=True)

    # The original model is deliberately not checked; sub-models are.
    logging.info("跳过对原始模型的 checker(模型过大),将只对切分后的子模型进行验证")

    graph_index = build_graph_index(model)

    sub_configs = load_sub_configs(args.config)
    specs: List[SubGraphSpec] = []
    covered_nodes: Set[str] = set()

    # Build one spec per configured start/end tensor pair and trace the
    # nodes each spec covers.
    for idx, entry in enumerate(sub_configs):
        start = [name for name in entry.get("start_tensor_names", []) if name]
        end = [name for name in entry.get("end_tensor_names", []) if name]
        if not start or not end:
            raise ValueError(f"sub_config[{idx}] 缺少 start/end tensor name。")
        spec = SubGraphSpec(
            label=f"cfg_{idx:02d}",
            start=start,
            end=end,
            node_names=set(),
            source="config",
        )
        nodes = trace_nodes_between(spec, graph_index)
        spec.node_names = nodes
        covered_nodes.update(nodes)
        specs.append(spec)

    # Wrap nodes that no configured spec covers into auto-generated specs,
    # one per connected component; components without outputs are dropped.
    leftovers = untouched_components(graph_index.node_order, covered_nodes, graph_index)
    for idx, component in enumerate(leftovers):
        start, end = derive_interface(component, graph_index)
        if not end:
            continue
        spec = SubGraphSpec(
            label=f"auto_{idx:02d}",
            start=start,
            end=end,
            node_names=component,
            source="auto",
        )
        specs.append(spec)
        logging.info(
            "自动补充子图 %s: start=%s end=%s (节点数=%d)",
            spec.label,
            spec.start,
            spec.end,
            len(component),
        )

    ordered = ordered_specs(specs, graph_index)

    # Every boundary tensor needs a ValueInfo so extraction can type it.
    required_tensors: Set[str] = set()
    for spec in ordered:
        required_tensors.update(spec.start)
        required_tensors.update(spec.end)
    ensure_value_infos(model, required_tensors)

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Pre-scan for corrupted existing files (reported here; regeneration is
    # handled inside extract_model_file).
    corrupted_specs = []
    for spec in ordered:
        head = sanitize(spec.start[0]) if spec.start else "const"
        tail = sanitize(spec.end[0]) if spec.end else "out"
        filename = f"{spec.label}_{head}_to_{tail}_{spec.source}.onnx"
        potential_path = args.output_dir / filename
        if potential_path.exists() and not is_valid_onnx_model(potential_path):
            corrupted_specs.append(spec.label)
            logging.warning(f"检测到损坏的子图: {spec.label}, 将重新生成")

    if corrupted_specs:
        logging.info(f"发现 {len(corrupted_specs)} 个损坏的子图,将重新生成: {corrupted_specs}")

    # Extract every sub-model and remember where it was written.
    for spec in ordered:
        spec.output_path = extract_model_file(
            model,
            spec,
            args.output_dir,
            spec.source,
            run_checker=not args.skip_checker,
            run_shape_infer=not args.skip_shape_infer,
            skip_existing=args.skip_existing,
        )

    # verify / run-pipeline both need actual input data.
    need_inputs = args.verify or args.run_pipeline
    if need_inputs:
        if args.input_data is None:
            raise ValueError("verify/run-pipeline 模式需要 --input-data 提供输入。")
        feed = load_numpy_inputs(args.input_data, graph_index.graph_inputs)
        missing_inputs = graph_index.graph_inputs - feed.keys()
        if missing_inputs:
            raise ValueError(f"输入数据缺少以下张量: {sorted(missing_inputs)}")
    else:
        feed = {}

    if args.verify:
        verify(args.model, ordered, feed, args.providers, args.rtol, args.atol)
    elif args.run_pipeline:
        outputs = run_split_pipeline(ordered, feed, args.providers)
        save_outputs(outputs, args.pipeline_output)
    else:
        logging.info("子模型已生成,如需验证请添加 --verify 或 --run-pipeline。")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The string below is a no-op literal kept as a copy/paste usage example.
    """
    python ./scripts/split_onnx_by_subconfig.py \
    --model ./onnx-models/z_image_transformer_body_only_simp_slim.onnx \
    --config ./pulsar2_configs/transformers_subgraph.json \
    --output-dir ./transformers_body_only_split_onnx \
    --verify \
    --input-data ./onnx-calibration-no-controlnet/transformer_inputs_prompt000_step00.npy \
    --providers CPUExecutionProvider
    """
    main()
|
|
|