|
|
|
|
|
"""Split an ONNX graph into smaller sub-models driven by sub_config rules. |
|
|
|
|
|
This script reads a JSON config file (matching the pulsar2 sub_config layout), |
|
|
extracts the requested subgraphs, and optionally emits any leftover parts of the |
|
|
model as independent ONNX graphs. A verification utility can run the original |
|
|
model and the stitched micro-model pipeline to make sure their outputs match. |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
from collections import defaultdict, deque |
|
|
from dataclasses import dataclass, field |
|
|
from pathlib import Path |
|
|
from typing import Dict, Iterable, List, Optional, Sequence, Set |
|
|
|
|
|
import numpy as np |
|
|
import onnx |
|
|
from onnx import utils as onnx_utils |
|
|
|
|
|
try: |
|
|
import onnxruntime as ort |
|
|
except ImportError: |
|
|
ort = None |
|
|
|
|
|
|
|
|
@dataclass
class SubGraphSpec:
    """Describes a single subgraph to extract from the full model."""

    # Human-readable identifier, e.g. "cfg_00" (from config) or "auto_01".
    label: str
    # Tensor names forming the subgraph's inputs (the upstream cut points).
    start: List[str]
    # Tensor names forming the subgraph's outputs.
    end: List[str]
    # Names of the ONNX nodes that belong to this subgraph (filled after tracing).
    node_names: Set[str]
    # Provenance: "config" for user-specified entries, "auto" for leftover parts.
    source: str
    # Path of the extracted .onnx file; set once the sub-model is written to disk.
    output_path: Optional[Path] = None
|
|
|
|
|
|
|
|
@dataclass
class GraphIndex:
    """Caches helpful lookups for traversing an ONNX graph."""

    # tensor name -> name of the (unique) node that produces it.
    tensor_to_producer: Dict[str, str]
    # tensor name -> names of all nodes consuming it (may be empty).
    tensor_to_consumers: Dict[str, List[str]]
    # node name -> its non-empty input tensor names.
    node_inputs: Dict[str, List[str]]
    # node name -> its non-empty output tensor names.
    node_outputs: Dict[str, List[str]]
    # Names of the model's declared graph inputs.
    graph_inputs: Set[str]
    # Names of the model's declared graph outputs.
    graph_outputs: Set[str]
    # Names of initializers (weights/constants baked into the graph).
    initializer_names: Set[str]
    # Node names in original topological (file) order, after deduplication.
    node_order: List[str]
|
|
|
|
|
|
|
|
def sanitize(name: str) -> str:
    """Turn a tensor name into a filesystem-friendly token.

    Every non-alphanumeric character becomes an underscore, surrounding
    underscores are trimmed, and empty results fall back to placeholders.
    """
    if not name:
        chars = ["anon"]
    else:
        chars = []
        for ch in name:
            chars.append(ch if ch.isalnum() else "_")
    cleaned = "".join(chars).strip("_")
    return cleaned if cleaned else "tensor"
|
|
|
|
|
|
|
|
def build_graph_index(model: onnx.ModelProto) -> GraphIndex:
    """Scan *model* once and cache producer/consumer lookups for traversal.

    Node names are made unique and non-empty so they can safely serve as
    dictionary keys throughout the rest of the pipeline.
    """
    producers: Dict[str, str] = {}
    consumers: Dict[str, List[str]] = defaultdict(list)
    inputs_of: Dict[str, List[str]] = {}
    outputs_of: Dict[str, List[str]] = {}
    order: List[str] = []
    taken: Set[str] = set()

    for position, node in enumerate(model.graph.node):
        # Guarantee a unique, non-empty name for every node: anonymous nodes
        # get a positional name, clashes get the position appended until unique.
        raw = node.name.strip() if node.name else ""
        unique = raw or f"node_{position}"
        while unique in taken:
            unique = f"{unique}_{position}"
        taken.add(unique)

        order.append(unique)
        inputs_of[unique] = [tensor for tensor in node.input if tensor]
        outputs_of[unique] = [tensor for tensor in node.output if tensor]
        for produced in outputs_of[unique]:
            producers[produced] = unique
        for consumed in inputs_of[unique]:
            consumers[consumed].append(unique)

    return GraphIndex(
        tensor_to_producer=producers,
        tensor_to_consumers=consumers,
        node_inputs=inputs_of,
        node_outputs=outputs_of,
        graph_inputs={vi.name for vi in model.graph.input},
        graph_outputs={vi.name for vi in model.graph.output},
        initializer_names={init.name for init in model.graph.initializer},
        node_order=order,
    )
|
|
|
|
|
|
|
|
def trace_nodes_between(
    spec: SubGraphSpec,
    index: GraphIndex,
) -> Set[str]:
    """Walk backwards from ``spec.end`` and collect every node feeding it.

    The walk stops at ``spec.start`` tensors, the graph's declared inputs,
    and initializers, so the returned node names form the subgraph body.
    """
    stop_tensors = set(spec.start) | index.graph_inputs | index.initializer_names
    seen_tensors: Set[str] = set()
    found_nodes: Set[str] = set()
    pending = list(spec.end)

    while pending:
        tensor = pending.pop()
        if tensor in seen_tensors:
            continue
        seen_tensors.add(tensor)
        if tensor in stop_tensors:
            continue
        node = index.tensor_to_producer.get(tensor)
        if not node or node in found_nodes:
            continue
        found_nodes.add(node)
        # Expand the producer's own inputs unless they sit on the boundary.
        pending.extend(
            upstream
            for upstream in index.node_inputs.get(node, [])
            if upstream and upstream not in stop_tensors
        )
    return found_nodes
|
|
|
|
|
|
|
|
def untouched_components(
    all_nodes: Sequence[str],
    covered_nodes: Set[str],
    index: GraphIndex,
) -> List[Set[str]]:
    """Group the nodes NOT claimed by any spec into connected components.

    Edges are undirected dataflow links (producer <-> consumer) restricted
    to the leftover nodes; each component can later become its own subgraph.
    """
    leftovers = [name for name in all_nodes if name not in covered_nodes]
    if not leftovers:
        return []

    leftover_set = set(leftovers)
    neighbours: Dict[str, Set[str]] = {name: set() for name in leftovers}
    for name in leftovers:
        # Link through outputs consumed by other leftover nodes.
        for tensor in index.node_outputs.get(name, []):
            for downstream in index.tensor_to_consumers.get(tensor, []):
                if downstream in leftover_set:
                    neighbours[name].add(downstream)
                    neighbours[downstream].add(name)
        # Link through inputs produced by other leftover nodes.
        for tensor in index.node_inputs.get(name, []):
            upstream = index.tensor_to_producer.get(tensor)
            if upstream in leftover_set:
                neighbours[name].add(upstream)
                neighbours[upstream].add(name)

    groups: List[Set[str]] = []
    done: Set[str] = set()
    for name in leftovers:
        if name in done:
            continue
        group: Set[str] = set()
        frontier = [name]
        while frontier:
            current = frontier.pop()
            if current in done:
                continue
            done.add(current)
            group.add(current)
            frontier.extend(neighbours[current] - done)
        groups.append(group)
    return groups
|
|
|
|
|
|
|
|
def derive_interface(
    nodes: Set[str],
    index: GraphIndex,
) -> tuple[List[str], List[str]]:
    """Compute the (inputs, outputs) tensor interface of a node set.

    A tensor is an input when its producer lies outside ``nodes`` (or it has
    no producer at all, i.e. it is a graph input) and it is not an
    initializer; it is an output when it is consumed outside ``nodes`` or is
    one of the graph's declared outputs.

    Note: the original return annotation was the tuple literal
    ``(List[str], List[str])``, which is not a valid type expression; it is
    fixed to ``tuple[List[str], List[str]]`` here (safe at runtime because
    this file uses ``from __future__ import annotations``).

    Returns:
        A ``(start, end)`` pair of sorted tensor-name lists.
    """
    produced: Set[str] = set()
    for node in nodes:
        produced.update(index.node_outputs.get(node, []))

    start: Set[str] = set()
    for node in nodes:
        for inp in index.node_inputs.get(node, []):
            producer = index.tensor_to_producer.get(inp)
            # A missing producer (graph input) and an external producer are
            # treated alike; initializers never join the interface.
            if producer not in nodes and inp not in index.initializer_names:
                start.add(inp)

    end: Set[str] = set()
    for node in nodes:
        for out in index.node_outputs.get(node, []):
            consumers = index.tensor_to_consumers.get(out, [])
            if not consumers:
                # Dead-end tensor: only expose it if the model declares it.
                if out in index.graph_outputs:
                    end.add(out)
                continue
            if any(consumer not in nodes for consumer in consumers):
                end.add(out)
    end.update(index.graph_outputs & produced)

    # Degenerate fallback: a component with no externally visible tensor
    # still needs outputs so it can be extracted — expose everything.
    if not end and produced:
        end = produced.copy()

    return sorted(start), sorted(end)
|
|
|
|
|
|
|
|
def extract_model_file(
    model_path: Path,
    spec: SubGraphSpec,
    output_dir: Path,
    suffix: str,
) -> Path:
    """Write the sub-model described by *spec* into *output_dir*.

    The file name encodes the first start/end tensor names (sanitized) plus
    *suffix* so the pieces are easy to identify on disk.
    """
    if spec.start:
        head = sanitize(spec.start[0])
    else:
        head = "const"
    if spec.end:
        tail = sanitize(spec.end[0])
    else:
        tail = "out"
    destination = output_dir / f"{spec.label}_{head}_to_{tail}_{suffix}.onnx"
    # check_model=False: boundary tensors may lack value_info, which would
    # make the strict checker reject an otherwise usable sub-model.
    onnx_utils.extract_model(
        model_path.as_posix(),
        destination.as_posix(),
        input_names=spec.start,
        output_names=spec.end,
        check_model=False,
    )
    logging.info("Saved %s (start=%s, end=%s)", destination.name, spec.start, spec.end)
    return destination
|
|
|
|
|
|
|
|
def ordered_specs(
    specs: Sequence[SubGraphSpec],
    index: GraphIndex,
) -> List[SubGraphSpec]:
    """Sort *specs* so that each spec's start tensors are produced before it runs.

    Graph inputs and initializers are available from the outset; every
    scheduled spec then contributes its end tensors to the available pool.

    Raises:
        RuntimeError: when some spec's start tensors can never be satisfied.
    """
    available = set(index.graph_inputs) | index.initializer_names
    ordered: List[SubGraphSpec] = []
    pending = list(specs)

    while pending:
        deferred: List[SubGraphSpec] = []
        for spec in pending:
            if set(spec.start).issubset(available):
                # Runnable now; its outputs may unlock later specs this pass.
                ordered.append(spec)
                available.update(spec.end)
            else:
                deferred.append(spec)
        if len(deferred) == len(pending):
            missing = {spec.label: sorted(set(spec.start) - available) for spec in pending}
            raise RuntimeError(
                "无法解析子图拓扑,缺少以下张量: %s" % missing
            )
        pending = deferred
    return ordered
|
|
|
|
|
|
|
|
def run_full_model(model_path: Path, feed_dict: Dict[str, np.ndarray], providers: List[str]):
    """Run the unsplit model once and return ``{output_name: value}``."""
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    session = ort.InferenceSession(model_path.as_posix(), providers=providers)
    values = session.run(None, feed_dict)
    return {meta.name: value for meta, value in zip(session.get_outputs(), values)}
|
|
|
|
|
|
|
|
def run_split_pipeline(
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
) -> Dict[str, np.ndarray]:
    """Execute the extracted sub-models in order, piping tensors between them.

    Returns the accumulated tensor store: the initial feeds plus every
    intermediate/final output produced along the way.
    """
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    store: Dict[str, np.ndarray] = dict(feed_dict)
    for spec in ordered_subgraphs:
        if spec.output_path is None:
            raise RuntimeError(f"子图 {spec.label} 尚未生成 ONNX 文件。")
        session = ort.InferenceSession(spec.output_path.as_posix(), providers=providers)
        feeds = {}
        for name in spec.start:
            if name not in store:
                raise KeyError(
                    f"子图 {spec.label} 缺少输入张量 {name},请确认切分顺序。"
                )
            feeds[name] = store[name]
        outputs = session.run(None, feeds)
        # Publish this sub-model's outputs so downstream specs can feed on them.
        for meta, value in zip(session.get_outputs(), outputs):
            store[meta.name] = value
    return store
|
|
|
|
|
|
|
|
def verify(
    model_path: Path,
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
    rtol: float,
    atol: float,
) -> None:
    """Check that the stitched pipeline reproduces the full model's outputs.

    Raises:
        AssertionError: when an output is missing from the pipeline's tensor
            store or deviates beyond the given rtol/atol tolerances.
    """
    reference = run_full_model(model_path, feed_dict, providers)
    pipeline = run_split_pipeline(ordered_subgraphs, feed_dict, providers)
    for name, expected in reference.items():
        actual = pipeline.get(name)
        if actual is None:
            raise AssertionError(f"切分流水线未产生模型输出 {name}")
        if not np.allclose(expected, actual, rtol=rtol, atol=atol):
            diff = np.max(np.abs(expected - actual))
            raise AssertionError(
                f"输出 {name} 不匹配,最大偏差 {diff:.3e}"
            )
    logging.info("切分模型与原始模型输出一致 (rtol=%g, atol=%g)。", rtol, atol)
|
|
|
|
|
|
|
|
def load_npz_inputs(npz_path: Path) -> Dict[str, np.ndarray]:
    """Load every array from an ``.npz`` archive into a plain dict.

    Fix: ``np.load`` on an .npz keeps the underlying zip file open until
    garbage collection; the context manager closes it eagerly once all
    arrays have been materialized.
    """
    with np.load(npz_path, allow_pickle=False) as archive:
        return {key: archive[key] for key in archive.files}
|
|
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: split an ONNX model per sub_config, optionally verify.

    Pipeline: parse args -> index the graph -> build specs from the config ->
    auto-cover leftover nodes -> topologically order -> extract sub-models ->
    (optionally) run both the original and stitched pipeline and compare.
    """
    parser = argparse.ArgumentParser(description="根据 sub_config 切分 ONNX 模型。")
    parser.add_argument("--model", required=True, type=Path, help="原始 ONNX 路径")
    parser.add_argument("--config", required=True, type=Path, help="pulsar2 配置 JSON")
    parser.add_argument("--output-dir", required=False, default="./split-onnx", type=Path, help="保存子模型的目录")
    parser.add_argument(
        "--verify",
        action="store_true",
        help="生成后立即用 onnxruntime 校验输出是否一致",
    )
    parser.add_argument(
        "--input-npz",
        type=Path,
        help="包含模型所有输入张量的 npz 文件 (verify 模式需要)",
    )
    parser.add_argument(
        "--providers",
        nargs="*",
        default=["CPUExecutionProvider"],
        help="onnxruntime 推理后端顺序",
    )
    parser.add_argument("--rtol", type=float, default=1e-4, help="验证 rtol")
    parser.add_argument("--atol", type=float, default=1e-5, help="验证 atol")
    parser.add_argument("--log", default="INFO", help="日志等级")
    args = parser.parse_args()

    # Unknown log level names silently fall back to INFO.
    logging.basicConfig(level=getattr(logging, args.log.upper(), logging.INFO))

    model = onnx.load(args.model.as_posix())
    graph_index = build_graph_index(model)

    with args.config.open("r", encoding="utf-8") as f:
        config = json.load(f)

    # Config layout follows the pulsar2 convention: compiler.sub_configs[].
    sub_configs = config.get("compiler", {}).get("sub_configs", [])
    if not sub_configs:
        raise ValueError("配置文件中未找到 compiler.sub_configs。")

    specs: List[SubGraphSpec] = []
    covered_nodes: Set[str] = set()

    # Phase 1: build a spec per config entry and trace which nodes it covers.
    for idx, entry in enumerate(sub_configs):
        start = [name for name in entry.get("start_tensor_names", []) if name]
        end = [name for name in entry.get("end_tensor_names", []) if name]
        if not start or not end:
            raise ValueError(f"sub_config[{idx}] 缺少 start/end tensor name。")
        spec = SubGraphSpec(
            label=f"cfg_{idx:02d}",
            start=start,
            end=end,
            node_names=set(),
            source="config",
        )
        nodes = trace_nodes_between(spec, graph_index)
        spec.node_names = nodes
        covered_nodes.update(nodes)
        specs.append(spec)

    # Phase 2: wrap every leftover connected component in an auto-generated spec
    # so the union of all sub-models covers the whole graph.
    leftovers = untouched_components(graph_index.node_order, covered_nodes, graph_index)
    for idx, component in enumerate(leftovers):
        start, end = derive_interface(component, graph_index)
        if not end:
            logging.warning("自动发现的剩余子图 %d 没有输出,跳过。", idx)
            continue
        spec = SubGraphSpec(
            label=f"auto_{idx:02d}",
            start=start,
            end=end,
            node_names=component,
            source="auto",
        )
        specs.append(spec)
        logging.info(
            "自动补充子图 %s: start=%s end=%s (节点数=%d)",
            spec.label,
            spec.start,
            spec.end,
            len(component),
        )

    # Phase 3: order the specs so each one's inputs are available when it runs.
    ordered = ordered_specs(specs, graph_index)

    # Phase 4: write each sub-model to disk (records output_path on the spec).
    args.output_dir.mkdir(parents=True, exist_ok=True)
    for spec in ordered:
        spec.output_path = extract_model_file(args.model, spec, args.output_dir, spec.source)

    # Phase 5 (optional): compare original vs. stitched pipeline outputs.
    if args.verify:
        if args.input_npz is None:
            raise ValueError("verify 模式需要 --input-npz 提供输入数据。")
        feed = load_npz_inputs(args.input_npz)
        missing_inputs = graph_index.graph_inputs - feed.keys()
        if missing_inputs:
            raise ValueError(f"npz 中缺少以下模型输入: {sorted(missing_inputs)}")
        verify(args.model, ordered, feed, args.providers, args.rtol, args.atol)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Example usage:
    #   python python/VideoX-Fun/scripts/split_onnx_by_subconfigs.py \
    #       --model /path/to/full.onnx \
    #       --config python/VideoX-Fun/pulsar2_configs/transformers_subgraph.json \
    #       --output-dir /tmp/sliced_models \
    #       --verify \
    #       --input-npz /path/to/inputs.npz
    main()
|
|
|