#!/usr/bin/env python3
"""Generate a quantization config JSON file for every sub-graph ONNX model.

For each calibration tar (one per sub-graph) a copy of the template config is
written whose ``quant.input_configs[0].calibration_dataset`` points at that
tar. An index file listing the absolute paths of all generated configs is
written next to them.
"""
import argparse
import copy
import json
import os
import sys
from pathlib import Path

# Default paths, resolved relative to the repository root (the parent of the
# directory that contains this script).
REPO_ROOT = Path(__file__).parent.parent
DEFAULT_TAR_LIST_FILE = REPO_ROOT / "onnx-calibration-subgraphs" / "subgraph_calibration_paths.txt"
DEFAULT_OUTPUT_CONFIG_DIR = REPO_ROOT / "pulsar2_configs" / "subgraphs"
DEFAULT_TEMPLATE_CONFIG = REPO_ROOT / "pulsar2_configs" / "transformers.json"

# Built-in fallback JSON template, used when the template file does not exist.
CONFIG_TEMPLATE = {
    "model_type": "ONNX",
    "npu_mode": "NPU3",
    "quant": {
        "input_configs": [
            {
                "tensor_name": "DEFAULT",
                "calibration_dataset": "",  # replaced per sub-graph
                "calibration_size": -1,
                "calibration_format": "NumpyObject"
            }
        ],
        "calibration_method": "MinMax",
        "precision_analysis": True,
        "precision_analysis_method": "PerLayer",
        "enable_smooth_quant": True,
        "conv_bias_data_type": "FP32",
        "layer_configs": [
            {
                "start_tensor_names": ["DEFAULT"],
                "end_tensor_names": ["DEFAULT"],
                "data_type": "U16"
            }
        ]
    },
    "input_processors": [
        {
            "tensor_name": "DEFAULT"
        }
    ],
    "compiler": {
        "check": 0
    }
}


def load_template_config(template_path: Path) -> dict:
    """Load the config template.

    Args:
        template_path: Path of a JSON template file.

    Returns:
        The parsed template if the file exists. Otherwise a deep copy of the
        built-in ``CONFIG_TEMPLATE``; it is a copy so callers cannot mutate
        the module-level default.
    """
    template_path = Path(template_path)
    if template_path.exists():
        with open(template_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return copy.deepcopy(CONFIG_TEMPLATE)


def extract_subgraph_name(tar_path: str) -> str:
    """Return the sub-graph name for a tar path (its file stem).

    Example: ``/path/to/cfg_00.tar`` -> ``cfg_00``.
    """
    return Path(tar_path).stem


def create_config_for_subgraph(tar_path: str, template: dict, output_dir: Path) -> Path:
    """Write the config JSON for one sub-graph.

    Args:
        tar_path: Calibration tar for the sub-graph. Stored verbatim (as str)
            in ``quant.input_configs[0].calibration_dataset``.
        template: Template config; it is deep-copied and never modified.
        output_dir: Existing directory to write ``<subgraph_name>.json`` into.

    Returns:
        Path of the written config file.

    Raises:
        KeyError / IndexError: if the template lacks ``quant.input_configs[0]``.
        OSError: if the file cannot be written.
    """
    subgraph_name = extract_subgraph_name(tar_path)

    # Deep copy so the shared template stays untouched across sub-graphs.
    config = copy.deepcopy(template)
    # str() so a Path argument still serializes to JSON.
    config["quant"]["input_configs"][0]["calibration_dataset"] = str(tar_path)

    config_file = Path(output_dir) / f"{subgraph_name}.json"
    with open(config_file, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    return config_file


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="为子图ONNX生成量化配置JSON文件")
    parser.add_argument(
        "--tar-list-file",
        type=Path,
        default=DEFAULT_TAR_LIST_FILE,
        help="包含tar路径的列表文件,一行一个;如果同时提供 --tar,将忽略此文件",
    )
    parser.add_argument(
        "--tar",
        dest="tar_paths",
        action="append",
        help="直接指定tar文件路径,可重复",
    )
    parser.add_argument(
        "--output-config-dir",
        type=Path,
        default=DEFAULT_OUTPUT_CONFIG_DIR,
        help="生成的配置文件输出目录",
    )
    parser.add_argument(
        "--template-config",
        type=Path,
        default=DEFAULT_TEMPLATE_CONFIG,
        help="配置模板文件路径,默认使用 pulsar2_configs/transformers.json",
    )
    return parser.parse_args(argv)


def main(argv=None) -> int:
    """Generate one config per sub-graph tar plus an index file.

    Args:
        argv: Argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if the tar list file is missing.
    """
    args = parse_args(argv)

    print("=" * 80)
    print("为子图ONNX生成量化配置文件")
    print("=" * 80)

    # Collect tar paths: --tar takes precedence over the list file.
    if args.tar_paths:
        tar_paths = args.tar_paths
        print(f"\n使用命令行提供的 {len(tar_paths)} 个tar文件路径")
    else:
        if not args.tar_list_file.exists():
            print(f"错误: 找不到tar列表文件: {args.tar_list_file}")
            print("请先运行 collect_subgraph_inputs.py 生成校准数据或使用 --tar 指定tar路径")
            return 1  # non-zero so callers/shell pipelines see the failure

        print(f"\n读取tar文件列表: {args.tar_list_file}")
        with open(args.tar_list_file, 'r', encoding='utf-8') as f:
            tar_paths = [line.strip() for line in f if line.strip()]
        print(f"找到 {len(tar_paths)} 个tar文件")

    # Load the template, reporting when the built-in fallback is used.
    print(f"\n加载配置模板: {args.template_config}")
    if not args.template_config.exists():
        print("  模板文件不存在, 使用内置默认模板")
    template = load_template_config(args.template_config)

    args.output_config_dir.mkdir(parents=True, exist_ok=True)
    print(f"输出目录: {args.output_config_dir}")

    print("\n生成配置文件...")
    created_configs = []
    source_by_name = {}  # sub-graph name -> tar path whose config was written
    for tar_path in tar_paths:
        if not os.path.exists(tar_path):
            print(f"  警告: tar文件不存在: {tar_path}")
            continue

        subgraph_name = extract_subgraph_name(tar_path)
        if subgraph_name in source_by_name:
            # Same stem from another directory maps to the same JSON file.
            print(f"  警告: 子图名称重复 ({subgraph_name}), 将覆盖来自 {source_by_name[subgraph_name]} 的配置")

        try:
            config_file = create_config_for_subgraph(tar_path, template, args.output_config_dir)
        except Exception as e:  # report per sub-graph and keep going
            print(f"  ✗ 生成配置失败 ({subgraph_name}): {e}")
            continue

        # List each config file once in the index even if it was overwritten.
        if subgraph_name not in source_by_name:
            created_configs.append(config_file)
        source_by_name[subgraph_name] = tar_path
        print(f"  ✓ {config_file.name}")

    # Index file: one absolute config path per line, sorted.
    index_file = args.output_config_dir / "subgraph_configs_list.txt"
    with open(index_file, 'w', encoding='utf-8') as f:
        f.writelines(str(p.absolute()) + '\n' for p in sorted(created_configs))
    print(f"\n配置文件索引已保存: {index_file}")

    print("\n" + "=" * 80)
    print(f"完成! 共生成 {len(created_configs)} 个配置文件")
    print(f"配置文件目录: {args.output_config_dir}")
    print(f"配置文件列表: {index_file}")
    print("=" * 80)

    # Show the first generated config as an example.
    if created_configs:
        print(f"\n示例配置 ({created_configs[0].name}):")
        print("-" * 80)
        # Read back as UTF-8: the file was written with ensure_ascii=False.
        with open(created_configs[0], 'r', encoding='utf-8') as f:
            print(f.read())

    return 0


if __name__ == "__main__":
    sys.exit(main())