|
|
|
|
|
"""为每个子图ONNX生成量化配置JSON文件""" |
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Parent of the directory that contains this file. The path is resolved first
# so the result is absolute even when ``__file__`` is relative; without that,
# ``.parent.parent`` of a bare file name collapses to "." instead of moving up.
REPO_ROOT = Path(__file__).resolve().parent.parent

# Default locations, all relative to REPO_ROOT; each can be overridden on the
# command line (see parse_args).
DEFAULT_TAR_LIST_FILE = REPO_ROOT / "onnx-calibration-subgraphs" / "subgraph_calibration_paths.txt"
DEFAULT_OUTPUT_CONFIG_DIR = REPO_ROOT / "pulsar2_configs" / "subgraphs"
DEFAULT_TEMPLATE_CONFIG = REPO_ROOT / "pulsar2_configs" / "transformers.json"
|
|
|
|
|
|
|
|
# Built-in fallback template: load_template_config() returns it when the
# requested template file does not exist. Key order is kept as-is because it
# determines the key order of the generated JSON files.
CONFIG_TEMPLATE = {
    "model_type": "ONNX",
    "npu_mode": "NPU3",
    "quant": {
        "input_configs": [
            {
                "tensor_name": "DEFAULT",
                # Placeholder: create_config_for_subgraph() replaces this
                # with the tar path of the subgraph being configured.
                "calibration_dataset": "",
                "calibration_size": -1,
                "calibration_format": "NumpyObject",
            },
        ],
        "calibration_method": "MinMax",
        "precision_analysis": True,
        "precision_analysis_method": "PerLayer",
        "enable_smooth_quant": True,
        "conv_bias_data_type": "FP32",
        "layer_configs": [
            {
                "start_tensor_names": ["DEFAULT"],
                "end_tensor_names": ["DEFAULT"],
                "data_type": "U16",
            },
        ],
    },
    "input_processors": [
        {"tensor_name": "DEFAULT"},
    ],
    "compiler": {"check": 0},
}
|
|
|
|
|
|
|
|
def load_template_config(template_path: Path) -> dict:
    """Load the quantization config template.

    Args:
        template_path: JSON template file to read (decoded as UTF-8).

    Returns:
        The parsed JSON content when ``template_path`` exists; otherwise an
        independent copy of the built-in ``CONFIG_TEMPLATE``.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if template_path.exists():
        with open(template_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    # Return a copy rather than the module-level dict itself, so a caller that
    # edits the result cannot change the default seen by later calls. The
    # template is plain JSON data, so a JSON round trip is a full deep copy.
    return json.loads(json.dumps(CONFIG_TEMPLATE))
|
|
|
|
|
|
|
|
def extract_subgraph_name(tar_path: str) -> str:
    """Return the subgraph name encoded in a tar file path.

    The name is the final path component with its last suffix removed,
    e.g. ``/path/to/cfg_00.tar`` -> ``cfg_00``.
    """
    tar_file = Path(tar_path)
    return tar_file.stem
|
|
|
|
|
|
|
|
def create_config_for_subgraph(tar_path: str, template: dict, output_dir: Path) -> Path:
    """Write the quantization config JSON for one subgraph and return its path.

    The template is copied (the caller's dict is never modified), the first
    entry of ``quant.input_configs`` gets ``tar_path`` as its
    ``calibration_dataset``, and the result is saved as
    ``<subgraph name>.json`` inside ``output_dir``.

    Args:
        tar_path: Path of the calibration tar; also determines the file name.
        template: Config template; must contain ``quant.input_configs[0]``.
        output_dir: Existing directory that receives the JSON file.

    Returns:
        Path of the config file that was written.
    """
    name = extract_subgraph_name(tar_path)

    # Round-tripping through JSON yields an independent deep copy.
    subgraph_config = json.loads(json.dumps(template))
    first_input = subgraph_config["quant"]["input_configs"][0]
    first_input["calibration_dataset"] = tar_path

    destination = output_dir / f"{name}.json"
    serialized = json.dumps(subgraph_config, indent=2, ensure_ascii=False)
    destination.write_text(serialized, encoding='utf-8')
    return destination
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Options (registered in this order):
        --tar-list-file: file listing tar paths, defaults to
            ``DEFAULT_TAR_LIST_FILE``.
        --tar: repeatable; values are collected in ``tar_paths``.
        --output-config-dir: defaults to ``DEFAULT_OUTPUT_CONFIG_DIR``.
        --template-config: defaults to ``DEFAULT_TEMPLATE_CONFIG``.

    Returns:
        The namespace produced by ``argparse`` for the process arguments.
    """
    option_specs = (
        (
            "--tar-list-file",
            {
                "type": Path,
                "default": DEFAULT_TAR_LIST_FILE,
                "help": "包含tar路径的列表文件,一行一个;如果同时提供 --tar,将忽略此文件",
            },
        ),
        (
            "--tar",
            {
                "dest": "tar_paths",
                "action": "append",
                "help": "直接指定tar文件路径,可重复",
            },
        ),
        (
            "--output-config-dir",
            {
                "type": Path,
                "default": DEFAULT_OUTPUT_CONFIG_DIR,
                "help": "生成的配置文件输出目录",
            },
        ),
        (
            "--template-config",
            {
                "type": Path,
                "default": DEFAULT_TEMPLATE_CONFIG,
                "help": "配置模板文件路径,默认使用 pulsar2_configs/transformers.json",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="为子图ONNX生成量化配置JSON文件")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
|
|
|
def main() -> int:
    """Generate a quantization config for every calibration tar and index them.

    Tar paths come from the repeated ``--tar`` options when given, otherwise
    from ``--tar-list-file`` (one path per line, blank lines skipped). For
    each tar that exists on disk a ``<subgraph>.json`` config is written to
    the output directory. The absolute paths of all generated configs are then
    saved, sorted, to ``subgraph_configs_list.txt`` in that directory, and the
    first generated config is printed as an example.

    Returns:
        ``0`` once the index has been written, ``1`` when the tar list file is
        needed but does not exist.
    """
    args = parse_args()

    print("=" * 80)
    print("为子图ONNX生成量化配置文件")
    print("=" * 80)

    if args.tar_paths:
        tar_paths = args.tar_paths
        print(f"\n使用命令行提供的 {len(tar_paths)} 个tar文件路径")
    else:
        if not args.tar_list_file.exists():
            print(f"错误: 找不到tar列表文件: {args.tar_list_file}")
            print("请先运行 collect_subgraph_inputs.py 生成校准数据或使用 --tar 指定tar路径")
            # Report the failure through the return value so the process can
            # exit with a non-zero status.
            return 1

        print(f"\n读取tar文件列表: {args.tar_list_file}")
        # NOTE(review): read with the platform default encoding, as before;
        # confirm which encoding the producer of this list uses.
        with open(args.tar_list_file, 'r') as f:
            tar_paths = [line.strip() for line in f if line.strip()]
        print(f"找到 {len(tar_paths)} 个tar文件")

    print(f"\n加载配置模板: {args.template_config}")
    template = load_template_config(args.template_config)

    args.output_config_dir.mkdir(parents=True, exist_ok=True)
    print(f"输出目录: {args.output_config_dir}")

    print("\n生成配置文件...")
    created_configs = []
    for tar_path in tar_paths:
        if not os.path.exists(tar_path):
            print(f" 警告: tar文件不存在: {tar_path}")
            continue

        # Best effort: a failure for one subgraph is reported and the
        # remaining tars are still processed.
        try:
            config_file = create_config_for_subgraph(tar_path, template, args.output_config_dir)
        except Exception as e:
            print(f" ✗ 生成配置失败 ({extract_subgraph_name(tar_path)}): {e}")
        else:
            created_configs.append(config_file)
            print(f" ✓ {config_file.name}")

    index_file = args.output_config_dir / "subgraph_configs_list.txt"
    # NOTE(review): written with the platform default encoding, as before;
    # confirm what the consumers of this list expect before changing it.
    with open(index_file, 'w') as f:
        f.writelines(f"{config_file.absolute()}\n" for config_file in sorted(created_configs))

    print(f"\n配置文件索引已保存: {index_file}")

    print("\n" + "=" * 80)
    print(f"完成! 共生成 {len(created_configs)} 个配置文件")
    print(f"配置文件目录: {args.output_config_dir}")
    print(f"配置文件列表: {index_file}")
    print("=" * 80)

    if created_configs:
        print(f"\n示例配置 ({created_configs[0].name}):")
        print("-" * 80)
        # The configs are written as UTF-8 with non-ASCII characters kept
        # verbatim, so they must be read back as UTF-8 too.
        with open(created_configs[0], 'r', encoding='utf-8') as f:
            print(f.read())

    return 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit
    # status; a return value of None still exits with status 0.
    raise SystemExit(main())
|
|
|