Z-Image-Turbo / VideoX-Fun / scripts / split_quant_onnx_by_subconfigs.py
yongqiang
initialize this repo
ba96580
#!/usr/bin/env python3
"""Split an ONNX graph into smaller sub-models driven by sub_config rules.
This script reads a JSON config file (matching the pulsar2 sub_config layout),
extracts the requested subgraphs, and optionally emits any leftover parts of the
model as independent ONNX graphs. A verification utility can run the original
model and the stitched micro-model pipeline to make sure their outputs match.
"""
from __future__ import annotations
import argparse
import json
import logging
from collections import defaultdict, deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Set
import numpy as np
import onnx
from onnx import utils as onnx_utils
try:
import onnxruntime as ort
except ImportError: # pragma: no cover - optional dependency
ort = None
@dataclass
class SubGraphSpec:
    """Describes a single subgraph to extract from the full model."""
    # Human-readable identifier, used in output filenames and log messages.
    label: str
    # Tensor names forming the subgraph's input boundary.
    start: List[str]
    # Tensor names forming the subgraph's output boundary.
    end: List[str]
    # Names of the graph nodes contained in this subgraph.
    node_names: Set[str]
    # Provenance marker: "config" (from sub_configs) or "auto" (leftover component).
    source: str
    # Set once the subgraph has been written to disk as an ONNX file.
    output_path: Optional[Path] = None
@dataclass
class GraphIndex:
    """Caches helpful lookups for traversing an ONNX graph."""
    # Maps a tensor name to the node name that produces it.
    tensor_to_producer: Dict[str, str]
    # Maps a tensor name to every node name that consumes it.
    tensor_to_consumers: Dict[str, List[str]]
    # Non-empty input tensor names of each node, keyed by node name.
    node_inputs: Dict[str, List[str]]
    # Non-empty output tensor names of each node, keyed by node name.
    node_outputs: Dict[str, List[str]]
    # Names of the model's declared graph inputs.
    graph_inputs: Set[str]
    # Names of the model's declared graph outputs.
    graph_outputs: Set[str]
    # Names of all initializer (weight/constant) tensors.
    initializer_names: Set[str]
    # Node names in the order they appear in the model file.
    node_order: List[str]
def sanitize(name: str) -> str:
    """Turn an arbitrary tensor name into a filesystem-friendly token.

    Non-alphanumeric characters become underscores, surrounding underscores
    are trimmed; falls back to "anon"/"tensor" for empty inputs.
    """
    basis = name or "anon"
    cleaned = "".join(ch if ch.isalnum() else "_" for ch in basis).strip("_")
    return cleaned if cleaned else "tensor"
def build_graph_index(model: onnx.ModelProto) -> GraphIndex:
    """Scan *model* once and build producer/consumer lookup tables."""
    producers: Dict[str, str] = {}
    consumers: Dict[str, List[str]] = defaultdict(list)
    inputs_by_node: Dict[str, List[str]] = {}
    outputs_by_node: Dict[str, List[str]] = {}
    order: List[str] = []
    seen_names: Set[str] = set()
    for position, node in enumerate(model.graph.node):
        raw = node.name.strip() if node.name else ""
        unique = raw or f"node_{position}"
        # Disambiguate duplicate node names by appending the index until unique.
        while unique in seen_names:
            unique = f"{unique}_{position}"
        seen_names.add(unique)
        order.append(unique)
        inputs_by_node[unique] = [t for t in node.input if t]
        outputs_by_node[unique] = [t for t in node.output if t]
        for produced in outputs_by_node[unique]:
            producers[produced] = unique
        for consumed in inputs_by_node[unique]:
            consumers[consumed].append(unique)
    return GraphIndex(
        tensor_to_producer=producers,
        tensor_to_consumers=consumers,
        node_inputs=inputs_by_node,
        node_outputs=outputs_by_node,
        graph_inputs={vi.name for vi in model.graph.input},
        graph_outputs={vi.name for vi in model.graph.output},
        initializer_names={init.name for init in model.graph.initializer},
        node_order=order,
    )
def trace_nodes_between(
    spec: SubGraphSpec,
    index: GraphIndex,
) -> Set[str]:
    """Collect every node required to compute ``spec.end`` from ``spec.start``.

    Walks backwards from the end tensors and stops at the start tensors,
    the graph inputs, and initializers.
    """
    stop_tensors = set(spec.start) | index.graph_inputs | index.initializer_names
    seen_tensors: Set[str] = set()
    found_nodes: Set[str] = set()
    frontier = deque(spec.end)
    while frontier:
        current = frontier.popleft()
        if current in seen_tensors:
            continue
        seen_tensors.add(current)
        if current in stop_tensors:
            continue
        source_node = index.tensor_to_producer.get(current)
        if not source_node or source_node in found_nodes:
            continue
        found_nodes.add(source_node)
        for feeder in index.node_inputs.get(source_node, []):
            if feeder and feeder not in stop_tensors:
                frontier.append(feeder)
    return found_nodes
def untouched_components(
    all_nodes: Sequence[str],
    covered_nodes: Set[str],
    index: GraphIndex,
) -> List[Set[str]]:
    """Group nodes not claimed by any spec into connected components.

    Connectivity is treated as undirected: two leftover nodes are linked
    whenever one consumes a tensor the other produces.
    """
    leftover = [name for name in all_nodes if name not in covered_nodes]
    if not leftover:
        return []
    leftover_set = set(leftover)
    neighbours: Dict[str, Set[str]] = {name: set() for name in leftover}
    for name in leftover:
        for produced in index.node_outputs.get(name, []):
            for downstream in index.tensor_to_consumers.get(produced, []):
                if downstream in leftover_set:
                    neighbours[name].add(downstream)
                    neighbours[downstream].add(name)
        for consumed in index.node_inputs.get(name, []):
            upstream = index.tensor_to_producer.get(consumed)
            if upstream in leftover_set:
                neighbours[name].add(upstream)
                neighbours[upstream].add(name)
    groups: List[Set[str]] = []
    done: Set[str] = set()
    # BFS over the undirected adjacency; iteration order of `leftover`
    # fixes the order in which components are emitted.
    for name in leftover:
        if name in done:
            continue
        group: Set[str] = set()
        frontier = deque([name])
        while frontier:
            node = frontier.popleft()
            if node in done:
                continue
            done.add(node)
            group.add(node)
            frontier.extend(neighbours[node] - done)
        groups.append(group)
    return groups
def derive_interface(
    nodes: Set[str],
    index: GraphIndex,
) -> tuple[List[str], List[str]]:
    """Compute the input/output tensor interface of a node set.

    Args:
        nodes: Names of the nodes that make up the candidate subgraph.
        index: Lookup tables for the full graph.

    Returns:
        A ``(start, end)`` pair of sorted tensor-name lists. ``start`` holds
        tensors the subgraph consumes but does not produce itself (excluding
        initializers); ``end`` holds tensors it produces that are consumed
        outside the set or are declared graph outputs.
    """
    produced: Set[str] = set()
    for node in nodes:
        produced.update(index.node_outputs.get(node, []))
    start: Set[str] = set()
    for node in nodes:
        for inp in index.node_inputs.get(node, []):
            producer = index.tensor_to_producer.get(inp)
            # External input: produced outside the set (or not produced at
            # all, e.g. a graph input) and not a weight/constant initializer.
            # Note: `producer is None` implies `producer not in nodes`, so the
            # two original branches collapse into one condition.
            if producer not in nodes and inp not in index.initializer_names:
                start.add(inp)
    end: Set[str] = set()
    for node in nodes:
        for out in index.node_outputs.get(node, []):
            consumers = index.tensor_to_consumers.get(out, [])
            if not consumers:
                # Dead-end tensor: keep it only if it is a declared output.
                if out in index.graph_outputs:
                    end.add(out)
                continue
            if any(consumer not in nodes for consumer in consumers):
                end.add(out)
    end.update(index.graph_outputs & produced)
    if not end and produced:
        # Fully internal component: expose everything it produces so the
        # extracted model still has at least one output.
        end = produced.copy()
    return sorted(start), sorted(end)
def extract_model_file(
    model_path: Path,
    spec: SubGraphSpec,
    output_dir: Path,
    suffix: str,
) -> Path:
    """Extract ``spec``'s subgraph from ``model_path`` into ``output_dir``.

    The filename encodes the spec label, the first start/end tensor names,
    and a suffix identifying the spec's source. Returns the written path.
    """
    first_in = sanitize(spec.start[0]) if spec.start else "const"
    first_out = sanitize(spec.end[0]) if spec.end else "out"
    target = output_dir / f"{spec.label}_{first_in}_to_{first_out}_{suffix}.onnx"
    onnx_utils.extract_model(
        model_path.as_posix(),
        target.as_posix(),
        input_names=spec.start,
        output_names=spec.end,
        check_model=False,
    )
    logging.info("Saved %s (start=%s, end=%s)", target.name, spec.start, spec.end)
    return target
def ordered_specs(
    specs: Sequence[SubGraphSpec],
    index: GraphIndex,
) -> List[SubGraphSpec]:
    """Topologically order *specs* so each one's inputs are already available.

    Raises RuntimeError when no remaining spec can be scheduled, listing the
    tensors each blocked spec is still missing.
    """
    ready = set(index.graph_inputs) | index.initializer_names
    waiting = list(specs)
    result: List[SubGraphSpec] = []
    while waiting:
        placed_any = False
        for candidate in list(waiting):
            if not set(candidate.start).issubset(ready):
                continue
            result.append(candidate)
            ready.update(candidate.end)
            waiting.remove(candidate)
            placed_any = True
        if not placed_any:
            missing = {spec.label: sorted(set(spec.start) - ready) for spec in waiting}
            raise RuntimeError(
                "无法解析子图拓扑,缺少以下张量: %s" % missing
            )
    return result
def run_full_model(model_path: Path, feed_dict: Dict[str, np.ndarray], providers: List[str]):
    """Run the original model once and return ``{output_name: array}``."""
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    sess = ort.InferenceSession(model_path.as_posix(), providers=providers)
    values = sess.run(None, feed_dict)
    output_names = [info.name for info in sess.get_outputs()]
    return {name: value for name, value in zip(output_names, values)}
def run_split_pipeline(
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
) -> Dict[str, np.ndarray]:
    """Execute the extracted subgraphs in order, piping tensors through a store.

    Returns the tensor store: the original feeds plus every output produced
    by any subgraph along the way.
    """
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    store: Dict[str, np.ndarray] = dict(feed_dict)
    for spec in ordered_subgraphs:
        if spec.output_path is None:
            raise RuntimeError(f"子图 {spec.label} 尚未生成 ONNX 文件。")
        sess = ort.InferenceSession(spec.output_path.as_posix(), providers=providers)
        feeds = {}
        for name in spec.start:
            if name not in store:
                raise KeyError(
                    f"子图 {spec.label} 缺少输入张量 {name},请确认切分顺序。"
                )
            feeds[name] = store[name]
        values = sess.run(None, feeds)
        for info, value in zip(sess.get_outputs(), values):
            store[info.name] = value
    return store
def verify(
    model_path: Path,
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
    rtol: float,
    atol: float,
) -> None:
    """Check the stitched subgraph pipeline reproduces the full model's outputs.

    Raises AssertionError when an output is missing from the pipeline store
    or differs beyond the given tolerances.
    """
    reference = run_full_model(model_path, feed_dict, providers)
    pipeline_store = run_split_pipeline(ordered_subgraphs, feed_dict, providers)
    for name, expected in reference.items():
        actual = pipeline_store.get(name)
        if actual is None:
            raise AssertionError(f"切分流水线未产生模型输出 {name}")
        if not np.allclose(expected, actual, rtol=rtol, atol=atol):
            diff = np.max(np.abs(expected - actual))
            raise AssertionError(
                f"输出 {name} 不匹配,最大偏差 {diff:.3e}"
            )
    logging.info("切分模型与原始模型输出一致 (rtol=%g, atol=%g)。", rtol, atol)
def load_npz_inputs(npz_path: Path) -> Dict[str, np.ndarray]:
    """Load every array from an ``.npz`` archive into a plain dict."""
    archive = np.load(npz_path, allow_pickle=False)
    result: Dict[str, np.ndarray] = {}
    for key in archive.files:
        result[key] = archive[key]
    return result
def main() -> None:
    """CLI entry point: split an ONNX model per sub_config and optionally verify."""
    parser = argparse.ArgumentParser(description="根据 sub_config 切分 ONNX 模型。")
    parser.add_argument("--model", required=True, type=Path, help="原始 ONNX 路径")
    parser.add_argument("--config", required=True, type=Path, help="pulsar2 配置 JSON")
    parser.add_argument("--output-dir", required=False, default="./split-onnx", type=Path, help="保存子模型的目录")
    parser.add_argument(
        "--verify",
        action="store_true",
        help="生成后立即用 onnxruntime 校验输出是否一致",
    )
    parser.add_argument(
        "--input-npz",
        type=Path,
        help="包含模型所有输入张量的 npz 文件 (verify 模式需要)",
    )
    parser.add_argument(
        "--providers",
        nargs="*",
        default=["CPUExecutionProvider"],
        help="onnxruntime 推理后端顺序",
    )
    parser.add_argument("--rtol", type=float, default=1e-4, help="验证 rtol")
    parser.add_argument("--atol", type=float, default=1e-5, help="验证 atol")
    parser.add_argument("--log", default="INFO", help="日志等级")
    args = parser.parse_args()
    logging.basicConfig(level=getattr(logging, args.log.upper(), logging.INFO))
    model = onnx.load(args.model.as_posix())
    graph_index = build_graph_index(model)
    # Config layout follows pulsar2: {"compiler": {"sub_configs": [...]}}.
    with args.config.open("r", encoding="utf-8") as f:
        config = json.load(f)
    sub_configs = config.get("compiler", {}).get("sub_configs", [])
    if not sub_configs:
        raise ValueError("配置文件中未找到 compiler.sub_configs。")
    specs: List[SubGraphSpec] = []
    covered_nodes: Set[str] = set()
    # One spec per config entry; trace which nodes each start/end pair covers.
    for idx, entry in enumerate(sub_configs):
        start = [name for name in entry.get("start_tensor_names", []) if name]
        end = [name for name in entry.get("end_tensor_names", []) if name]
        if not start or not end:
            raise ValueError(f"sub_config[{idx}] 缺少 start/end tensor name。")
        spec = SubGraphSpec(
            label=f"cfg_{idx:02d}",
            start=start,
            end=end,
            node_names=set(),
            source="config",
        )
        nodes = trace_nodes_between(spec, graph_index)
        spec.node_names = nodes
        covered_nodes.update(nodes)
        specs.append(spec)
    # Nodes not claimed by any config entry become auto-derived specs, one
    # per connected component, so the whole model stays covered.
    leftovers = untouched_components(graph_index.node_order, covered_nodes, graph_index)
    for idx, component in enumerate(leftovers):
        start, end = derive_interface(component, graph_index)
        if not end:
            logging.warning("自动发现的剩余子图 %d 没有输出,跳过。", idx)
            continue
        spec = SubGraphSpec(
            label=f"auto_{idx:02d}",
            start=start,
            end=end,
            node_names=component,
            source="auto",
        )
        specs.append(spec)
        logging.info(
            "自动补充子图 %s: start=%s end=%s (节点数=%d)",
            spec.label,
            spec.start,
            spec.end,
            len(component),
        )
    # Order specs so each one's inputs are produced before it runs, then
    # extract each spec as a standalone ONNX file.
    ordered = ordered_specs(specs, graph_index)
    args.output_dir.mkdir(parents=True, exist_ok=True)
    for spec in ordered:
        spec.output_path = extract_model_file(args.model, spec, args.output_dir, spec.source)
    if args.verify:
        if args.input_npz is None:
            raise ValueError("verify 模式需要 --input-npz 提供输入数据。")
        feed = load_npz_inputs(args.input_npz)
        # NOTE(review): graph_inputs may include tensors that also have
        # initializers in older ONNX exports; those would be reported as
        # missing here even though they have defaults — confirm if relevant.
        missing_inputs = graph_index.graph_inputs - feed.keys()
        if missing_inputs:
            raise ValueError(f"npz 中缺少以下模型输入: {sorted(missing_inputs)}")
        verify(args.model, ordered, feed, args.providers, args.rtol, args.atol)
if __name__ == "__main__":
    """
    用法示例:
    python python/VideoX-Fun/scripts/split_onnx_by_subconfigs.py \
        --model /path/to/full.onnx \
        --config python/VideoX-Fun/pulsar2_configs/transformers_subgraph.json \
        --output-dir /tmp/sliced_models \
        --verify \
        --input-npz /path/to/inputs.npz
    """
    # The bare string above is an inline usage example (no runtime effect);
    # "用法示例" means "usage example". NOTE(review): the example's script
    # name differs from this file's name — confirm which is current.
    main()