File size: 5,883 Bytes
ba96580 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
#!/usr/bin/env python3
"""为每个子图ONNX生成量化配置JSON文件"""
import argparse
import json
import os
from pathlib import Path
# Default path configuration.
# REPO_ROOT is the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).parent.parent
# Default list file read by main(): one calibration tar path per line (see --tar-list-file).
DEFAULT_TAR_LIST_FILE = REPO_ROOT / "onnx-calibration-subgraphs" / "subgraph_calibration_paths.txt"
# Default directory that receives the generated per-subgraph JSON configs.
DEFAULT_OUTPUT_CONFIG_DIR = REPO_ROOT / "pulsar2_configs" / "subgraphs"
# Default template JSON; load_template_config() falls back to CONFIG_TEMPLATE when this path does not exist.
DEFAULT_TEMPLATE_CONFIG = REPO_ROOT / "pulsar2_configs" / "transformers.json"
# Built-in JSON config template, returned by load_template_config() when the
# template file does not exist.
CONFIG_TEMPLATE = {
    "model_type": "ONNX",
    "npu_mode": "NPU3",
    "quant": {
        "input_configs": [
            {
                "tensor_name": "DEFAULT",
                # Placeholder: create_config_for_subgraph() overwrites this
                # with the tar path of the subgraph being processed.
                "calibration_dataset": "",
                "calibration_size": -1,
                "calibration_format": "NumpyObject",
            },
        ],
        "calibration_method": "MinMax",
        "precision_analysis": True,
        "precision_analysis_method": "PerLayer",
        "enable_smooth_quant": True,
        "conv_bias_data_type": "FP32",
        "layer_configs": [
            {
                "start_tensor_names": ["DEFAULT"],
                "end_tensor_names": ["DEFAULT"],
                "data_type": "U16",
            },
        ],
    },
    "input_processors": [
        {"tensor_name": "DEFAULT"},
    ],
    "compiler": {"check": 0},
}
def load_template_config(template_path: Path) -> dict:
    """Load the JSON configuration template.

    Args:
        template_path: Location of a UTF-8 encoded JSON template file.

    Returns:
        The parsed JSON content when ``template_path`` exists; otherwise the
        module-level ``CONFIG_TEMPLATE`` object itself (not a copy).
    """
    # Missing template file: fall back to the built-in template.
    if not template_path.exists():
        return CONFIG_TEMPLATE

    with template_path.open('r', encoding='utf-8') as stream:
        return json.load(stream)
def extract_subgraph_name(tar_path: str) -> str:
    """Derive the subgraph name from a calibration archive path.

    The name is the final path component with its archive extension removed.
    Compound extensions of compressed tarballs (``.tar.gz``, ``.tar.bz2``,
    ``.tar.xz``, ``.tar.zst``, matched case-insensitively) are removed as a
    whole, so the result does not keep a dangling ``.tar``. Every other input
    yields ``Path(tar_path).stem``.

    Examples:
        /path/to/cfg_00.tar    -> cfg_00
        /path/to/cfg_00.tar.gz -> cfg_00

    Args:
        tar_path: Path of the calibration archive.

    Returns:
        The subgraph name used as the base name of the generated config file.
    """
    file_name = Path(tar_path).name
    lowered = file_name.lower()
    for suffix in (".tar.gz", ".tar.bz2", ".tar.xz", ".tar.zst"):
        # Require a non-empty remainder so a file literally named ".tar.gz"
        # does not collapse to an empty subgraph name.
        if lowered.endswith(suffix) and len(file_name) > len(suffix):
            return file_name[: -len(suffix)]
    return Path(tar_path).stem
def create_config_for_subgraph(tar_path: str, template: dict, output_dir: Path) -> Path:
    """Write the quantization config JSON for one subgraph.

    Args:
        tar_path: Calibration archive path; stored verbatim in the config and
            used to derive the output file name.
        template: Template config; it is cloned, never modified.
        output_dir: Directory that receives the file. It is not created here.

    Returns:
        Path of the written ``<subgraph name>.json`` file.
    """
    base_name = extract_subgraph_name(tar_path)

    # A JSON round trip clones the template so the caller's dict stays intact.
    subgraph_config = json.loads(json.dumps(template))

    # Point the first input config at this subgraph's calibration archive.
    first_input = subgraph_config["quant"]["input_configs"][0]
    first_input["calibration_dataset"] = tar_path

    target = output_dir / f"{base_name}.json"
    with target.open('w', encoding='utf-8') as out:
        json.dump(subgraph_config, out, indent=2, ensure_ascii=False)

    return target
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Recognized options: ``--tar-list-file``, ``--tar`` (repeatable, collected
    into ``tar_paths``), ``--output-config-dir`` and ``--template-config``.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        (
            "--tar-list-file",
            {
                "type": Path,
                "default": DEFAULT_TAR_LIST_FILE,
                "help": "包含tar路径的列表文件,一行一个;如果同时提供 --tar,将忽略此文件",
            },
        ),
        (
            "--tar",
            {
                "dest": "tar_paths",
                "action": "append",
                "help": "直接指定tar文件路径,可重复",
            },
        ),
        (
            "--output-config-dir",
            {
                "type": Path,
                "default": DEFAULT_OUTPUT_CONFIG_DIR,
                "help": "生成的配置文件输出目录",
            },
        ),
        (
            "--template-config",
            {
                "type": Path,
                "default": DEFAULT_TEMPLATE_CONFIG,
                "help": "配置模板文件路径,默认使用 pulsar2_configs/transformers.json",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="为子图ONNX生成量化配置JSON文件")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def _read_tar_paths(args):
    """Return the tar paths to process, or None when no source is available.

    Paths given with --tar take precedence; otherwise the list file is read,
    one path per line, with blank lines skipped.
    """
    if args.tar_paths:
        tar_paths = args.tar_paths
        print(f"\n使用命令行提供的 {len(tar_paths)} 个tar文件路径")
        return tar_paths

    if not args.tar_list_file.exists():
        print(f"错误: 找不到tar列表文件: {args.tar_list_file}")
        print("请先运行 collect_subgraph_inputs.py 生成校准数据或使用 --tar 指定tar路径")
        return None

    print(f"\n读取tar文件列表: {args.tar_list_file}")
    # NOTE(review): the list file is read with the platform default encoding,
    # as before; the encoding used by its producer is not visible here.
    with open(args.tar_list_file, 'r') as f:
        tar_paths = [entry for entry in map(str.strip, f) if entry]
    print(f"找到 {len(tar_paths)} 个tar文件")
    return tar_paths


def _generate_configs(tar_paths, template, output_dir):
    """Create one config per existing tar file and return the written paths.

    Missing tar files are reported and skipped; a failure for one subgraph is
    reported and does not stop the remaining ones.
    """
    created_configs = []
    for tar_path in tar_paths:
        if not os.path.exists(tar_path):
            print(f" 警告: tar文件不存在: {tar_path}")
            continue
        try:
            config_file = create_config_for_subgraph(tar_path, template, output_dir)
            created_configs.append(config_file)
            print(f" ✓ {config_file.name}")
        except Exception as e:
            # Deliberate best-effort boundary: keep going with the other subgraphs.
            print(f" ✗ 生成配置失败 ({extract_subgraph_name(tar_path)}): {e}")
    return created_configs


def main():
    """Generate a quantization config JSON for every subgraph calibration tar.

    Steps: parse the command line, collect the tar paths, load the template,
    write one config per existing tar file into the output directory, write an
    index file listing the absolute paths of the created configs, and print
    the first created config as a sample.

    Returns early (after printing an error) when no --tar was given and the
    tar list file does not exist.
    """
    args = parse_args()
    print("=" * 80)
    print("为子图ONNX生成量化配置文件")
    print("=" * 80)

    # Collect the tar file paths.
    tar_paths = _read_tar_paths(args)
    if tar_paths is None:
        return

    # Load the template config.
    print(f"\n加载配置模板: {args.template_config}")
    template = load_template_config(args.template_config)

    # Make sure the output directory exists.
    args.output_config_dir.mkdir(parents=True, exist_ok=True)
    print(f"输出目录: {args.output_config_dir}")

    # Generate one config per tar file.
    print("\n生成配置文件...")
    created_configs = _generate_configs(tar_paths, template, args.output_config_dir)

    # Write an index file listing every created config (absolute paths, sorted).
    index_file = args.output_config_dir / "subgraph_configs_list.txt"
    # NOTE(review): written with the platform default encoding, as before;
    # confirm what the consumer of this list expects.
    with open(index_file, 'w') as f:
        for config_file in sorted(created_configs):
            f.write(str(config_file.absolute()) + '\n')
    print(f"\n配置文件索引已保存: {index_file}")

    print("\n" + "=" * 80)
    print(f"完成! 共生成 {len(created_configs)} 个配置文件")
    print(f"配置文件目录: {args.output_config_dir}")
    print(f"配置文件列表: {index_file}")
    print("=" * 80)

    # Show the first created config as a sample.
    if created_configs:
        print(f"\n示例配置 ({created_configs[0].name}):")
        print("-" * 80)
        # The config was written as UTF-8 with ensure_ascii=False, so it must
        # be read back as UTF-8 rather than with the locale default encoding.
        print(created_configs[0].read_text(encoding='utf-8'))
# Run the generator only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|