Z-Image-Turbo / VideoX-Fun /scripts /generate_subgraph_configs.py
yongqiang
initialize this repo
ba96580
#!/usr/bin/env python3
"""为每个子图ONNX生成量化配置JSON文件"""
import argparse
import copy
import json
import os
import sys
from pathlib import Path
# Default locations, all resolved relative to the repository root
# (the parent of the directory that contains this script).
REPO_ROOT = Path(__file__).parent.parent
DEFAULT_TAR_LIST_FILE = REPO_ROOT.joinpath("onnx-calibration-subgraphs", "subgraph_calibration_paths.txt")
DEFAULT_OUTPUT_CONFIG_DIR = REPO_ROOT.joinpath("pulsar2_configs", "subgraphs")
DEFAULT_TEMPLATE_CONFIG = REPO_ROOT.joinpath("pulsar2_configs", "transformers.json")
# Built-in fallback quantization config. load_template_config() uses it when
# the template file does not exist; per-subgraph configs are copies of the
# template with quant.input_configs[0].calibration_dataset set to a tar path.
CONFIG_TEMPLATE = dict(
    model_type="ONNX",
    npu_mode="NPU3",
    quant=dict(
        input_configs=[
            dict(
                tensor_name="DEFAULT",
                calibration_dataset="",  # placeholder, overwritten per subgraph
                calibration_size=-1,
                calibration_format="NumpyObject",
            ),
        ],
        calibration_method="MinMax",
        precision_analysis=True,
        precision_analysis_method="PerLayer",
        enable_smooth_quant=True,
        conv_bias_data_type="FP32",
        layer_configs=[
            dict(
                start_tensor_names=["DEFAULT"],
                end_tensor_names=["DEFAULT"],
                data_type="U16",
            ),
        ],
    ),
    input_processors=[
        dict(tensor_name="DEFAULT"),
    ],
    compiler=dict(check=0),
)
def load_template_config(template_path: Path) -> dict:
    """Return the quantization config template used to build per-subgraph configs.

    Args:
        template_path: JSON template file (``Path`` or ``str``). When it does
            not exist, the built-in ``CONFIG_TEMPLATE`` is used instead.

    Returns:
        The parsed JSON when *template_path* exists, otherwise a deep copy of
        ``CONFIG_TEMPLATE``. A copy is returned so a caller that mutates the
        result cannot corrupt the module-level default for later calls.

    Raises:
        OSError: if the file exists but cannot be opened or read.
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    template_path = Path(template_path)
    if template_path.exists():
        with open(template_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    # Never hand out the shared module-level object itself.
    return copy.deepcopy(CONFIG_TEMPLATE)
def extract_subgraph_name(tar_path: str) -> str:
    """Return the subgraph name encoded in a calibration tar path.

    The name is the file name with its last suffix removed, e.g.
    ``/path/to/cfg_00.tar`` -> ``cfg_00``. Only the final suffix is dropped,
    so ``x.tar.gz`` gives ``x.tar``.
    """
    archive = Path(tar_path)
    return archive.stem
def create_config_for_subgraph(tar_path: str, template: dict, output_dir: Path) -> Path:
    """Write ``<output_dir>/<subgraph>.json`` for one calibration tar file.

    The config is a deep copy of *template* whose first
    ``quant.input_configs`` entry has ``calibration_dataset`` set to
    *tar_path* (stored exactly as given). An existing file with the same name
    is overwritten.

    Args:
        tar_path: Calibration tar path; its file stem names the output file.
        template: JSON-serializable config template; it is not modified.
        output_dir: Existing directory that receives the config file.

    Returns:
        Path of the written config file.

    Raises:
        KeyError / IndexError: if *template* lacks ``quant.input_configs[0]``.
        TypeError: if *template* is not JSON-serializable.
    """
    subgraph = extract_subgraph_name(tar_path)
    # A JSON round-trip yields an independent deep copy of the template.
    config = json.loads(json.dumps(template))
    first_input = config["quant"]["input_configs"][0]
    first_input["calibration_dataset"] = tar_path
    destination = output_dir / (subgraph + ".json")
    with destination.open('w', encoding='utf-8') as handle:
        json.dump(config, handle, ensure_ascii=False, indent=2)
    return destination
def parse_args(argv=None) -> argparse.Namespace:
    """Parse this script's command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` callers behave as
            before. Passing an explicit list allows programmatic use and tests.

    Returns:
        Namespace with:

        - ``tar_list_file``
        - ``tar_paths``: a list, or ``None`` unless ``--tar`` was given at
          least once.
        - ``output_config_dir``
        - ``template_config``
    """
    parser = argparse.ArgumentParser(description="为子图ONNX生成量化配置JSON文件")
    parser.add_argument(
        "--tar-list-file",
        type=Path,
        default=DEFAULT_TAR_LIST_FILE,
        help="包含tar路径的列表文件,一行一个;如果同时提供 --tar,将忽略此文件",
    )
    parser.add_argument(
        "--tar",
        dest="tar_paths",
        action="append",  # repeatable: each occurrence appends one path
        help="直接指定tar文件路径,可重复",
    )
    parser.add_argument(
        "--output-config-dir",
        type=Path,
        default=DEFAULT_OUTPUT_CONFIG_DIR,
        help="生成的配置文件输出目录",
    )
    parser.add_argument(
        "--template-config",
        type=Path,
        default=DEFAULT_TEMPLATE_CONFIG,
        help="配置模板文件路径,默认使用 pulsar2_configs/transformers.json",
    )
    return parser.parse_args(argv)
def _read_tar_paths(args: argparse.Namespace) -> list:
    """Return tar paths from ``--tar`` or, if none were given, from the tar list file.

    Paths given via ``--tar`` take precedence, and the list file is then not
    read. In the list file, surrounding whitespace is stripped and blank lines
    are skipped.

    Raises:
        SystemExit: with status 1 when no ``--tar`` was given and the list
            file does not exist.
    """
    if args.tar_paths:
        tar_paths = list(args.tar_paths)
        print(f"\n使用命令行提供的 {len(tar_paths)} 个tar文件路径")
        return tar_paths
    if not args.tar_list_file.exists():
        print(f"错误: 找不到tar列表文件: {args.tar_list_file}", file=sys.stderr)
        print("请先运行 collect_subgraph_inputs.py 生成校准数据或使用 --tar 指定tar路径", file=sys.stderr)
        # Fail with a non-zero status so calling pipelines notice the problem.
        raise SystemExit(1)
    print(f"\n读取tar文件列表: {args.tar_list_file}")
    with open(args.tar_list_file, 'r', encoding='utf-8') as f:
        tar_paths = [line.strip() for line in f if line.strip()]
    print(f"找到 {len(tar_paths)} 个tar文件")
    return tar_paths


def _generate_configs(tar_paths, template: dict, output_dir: Path) -> list:
    """Write one config per existing tar file; return the written paths in input order.

    Missing tar files and per-file failures are reported and skipped. If two
    tar files map to the same subgraph name (same file stem), only the first is
    used. This stops a later entry from silently overwriting an earlier config
    and from being listed twice in the index.
    """
    created_configs = []
    source_by_name = {}  # subgraph name -> tar path whose config was written
    for tar_path in tar_paths:
        if not os.path.exists(tar_path):
            print(f" 警告: tar文件不存在: {tar_path}")
            continue
        subgraph_name = extract_subgraph_name(tar_path)
        if subgraph_name in source_by_name:
            print(f" 警告: 子图名称重复 ({subgraph_name}), 已跳过: {tar_path} (保留: {source_by_name[subgraph_name]})")
            continue
        try:
            config_file = create_config_for_subgraph(tar_path, template, output_dir)
        except Exception as e:  # best effort: one bad entry must not abort the batch
            print(f" ✗ 生成配置失败 ({subgraph_name}): {e}")
            continue
        source_by_name[subgraph_name] = tar_path
        created_configs.append(config_file)
        print(f" ✓ {config_file.name}")
    return created_configs


def main():
    """Generate one quantization config per calibration tar file, plus an index file.

    Steps:

    1. Resolve the tar paths.
    2. Load the template.
    3. Write one ``<subgraph>.json`` per tar file into the output directory.
    4. Write ``subgraph_configs_list.txt`` listing the sorted absolute config
       paths.
    5. Print the first generated config as an example.

    All files are read and written as UTF-8. Exits with status 1 when no
    ``--tar`` is given and the tar list file is missing.
    """
    args = parse_args()
    print("=" * 80)
    print("为子图ONNX生成量化配置文件")
    print("=" * 80)

    tar_paths = _read_tar_paths(args)

    print(f"\n加载配置模板: {args.template_config}")
    template = load_template_config(args.template_config)

    args.output_config_dir.mkdir(parents=True, exist_ok=True)
    print(f"输出目录: {args.output_config_dir}")

    print("\n生成配置文件...")
    created_configs = _generate_configs(tar_paths, template, args.output_config_dir)

    # Index file: one absolute config path per line, sorted.
    index_file = args.output_config_dir / "subgraph_configs_list.txt"
    with open(index_file, 'w', encoding='utf-8') as f:
        f.writelines(f"{config_file.absolute()}\n" for config_file in sorted(created_configs))
    print(f"\n配置文件索引已保存: {index_file}")

    print("\n" + "=" * 80)
    print(f"完成! 共生成 {len(created_configs)} 个配置文件")
    print(f"配置文件目录: {args.output_config_dir}")
    print(f"配置文件列表: {index_file}")
    print("=" * 80)

    if created_configs:
        print(f"\n示例配置 ({created_configs[0].name}):")
        print("-" * 80)
        # Configs are written as UTF-8 with ensure_ascii=False, so read them back the same way.
        with open(created_configs[0], 'r', encoding='utf-8') as f:
            print(f.read())
if __name__ == "__main__":
    # Entry point: run main() only when executed as a script, not when imported.
    main()