File size: 5,883 Bytes
ba96580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#!/usr/bin/env python3
"""为每个子图ONNX生成量化配置JSON文件"""
import argparse
import json
import os
from pathlib import Path

# 路径配置默认值
REPO_ROOT = Path(__file__).parent.parent
DEFAULT_TAR_LIST_FILE = REPO_ROOT / "onnx-calibration-subgraphs" / "subgraph_calibration_paths.txt"
DEFAULT_OUTPUT_CONFIG_DIR = REPO_ROOT / "pulsar2_configs" / "subgraphs"
DEFAULT_TEMPLATE_CONFIG = REPO_ROOT / "pulsar2_configs" / "transformers.json"

# JSON模板配置
CONFIG_TEMPLATE = {
    "model_type": "ONNX",
    "npu_mode": "NPU3",
    "quant": {
        "input_configs": [
            {
                "tensor_name": "DEFAULT",
                "calibration_dataset": "",  # 将被替换
                "calibration_size": -1,
                "calibration_format": "NumpyObject"
            }
        ],
        "calibration_method": "MinMax",
        "precision_analysis": True,
        "precision_analysis_method": "PerLayer",
        "enable_smooth_quant": True,
        "conv_bias_data_type": "FP32",
        "layer_configs": [
            {
                "start_tensor_names": ["DEFAULT"],
                "end_tensor_names": ["DEFAULT"],
                "data_type": "U16"
            }
        ]
    },
    "input_processors": [
        {
            "tensor_name": "DEFAULT"
        }
    ],
    "compiler": {
        "check": 0
    }
}


def load_template_config(template_path: Path) -> dict:
    """加载模板配置文件"""
    if template_path.exists():
        with open(template_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return CONFIG_TEMPLATE


def extract_subgraph_name(tar_path: str) -> str:
    """从tar文件路径提取子图名称
    例如: /path/to/cfg_00.tar -> cfg_00
    """
    return Path(tar_path).stem


def create_config_for_subgraph(tar_path: str, template: dict, output_dir: Path) -> Path:
    """为单个子图创建配置文件"""
    subgraph_name = extract_subgraph_name(tar_path)
    
    # 深拷贝模板
    config = json.loads(json.dumps(template))
    
    # 修改 calibration_dataset 字段
    config["quant"]["input_configs"][0]["calibration_dataset"] = tar_path
    
    # 生成输出文件路径
    config_file = output_dir / f"{subgraph_name}.json"
    
    # 保存配置文件
    with open(config_file, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    
    return config_file


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="为子图ONNX生成量化配置JSON文件")
    parser.add_argument(
        "--tar-list-file",
        type=Path,
        default=DEFAULT_TAR_LIST_FILE,
        help="包含tar路径的列表文件,一行一个;如果同时提供 --tar,将忽略此文件",
    )
    parser.add_argument(
        "--tar",
        dest="tar_paths",
        action="append",
        help="直接指定tar文件路径,可重复",
    )
    parser.add_argument(
        "--output-config-dir",
        type=Path,
        default=DEFAULT_OUTPUT_CONFIG_DIR,
        help="生成的配置文件输出目录",
    )
    parser.add_argument(
        "--template-config",
        type=Path,
        default=DEFAULT_TEMPLATE_CONFIG,
        help="配置模板文件路径,默认使用 pulsar2_configs/transformers.json",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    print("=" * 80)
    print("为子图ONNX生成量化配置文件")
    print("=" * 80)
    
    # 读取tar文件路径
    if args.tar_paths:
        tar_paths = args.tar_paths
        print(f"\n使用命令行提供的 {len(tar_paths)} 个tar文件路径")
    else:
        if not args.tar_list_file.exists():
            print(f"错误: 找不到tar列表文件: {args.tar_list_file}")
            print("请先运行 collect_subgraph_inputs.py 生成校准数据或使用 --tar 指定tar路径")
            return

        print(f"\n读取tar文件列表: {args.tar_list_file}")
        with open(args.tar_list_file, 'r') as f:
            tar_paths = [line.strip() for line in f if line.strip()]
        print(f"找到 {len(tar_paths)} 个tar文件")
    
    # 加载模板配置
    print(f"\n加载配置模板: {args.template_config}")
    template = load_template_config(args.template_config)
    
    # 创建输出目录
    args.output_config_dir.mkdir(parents=True, exist_ok=True)
    print(f"输出目录: {args.output_config_dir}")
    
    # 为每个tar文件生成配置
    print(f"\n生成配置文件...")
    created_configs = []
    
    for tar_path in tar_paths:
        if not os.path.exists(tar_path):
            print(f"  警告: tar文件不存在: {tar_path}")
            continue
        
        try:
            config_file = create_config_for_subgraph(tar_path, template, args.output_config_dir)
            created_configs.append(config_file)
            print(f"  ✓ {config_file.name}")
        except Exception as e:
            print(f"  ✗ 生成配置失败 ({extract_subgraph_name(tar_path)}): {e}")
    
    # 生成一个索引文件,列出所有配置文件路径
    index_file = args.output_config_dir / "subgraph_configs_list.txt"
    with open(index_file, 'w') as f:
        for config_file in sorted(created_configs):
            f.write(str(config_file.absolute()) + '\n')
    
    print(f"\n配置文件索引已保存: {index_file}")
    
    print("\n" + "=" * 80)
    print(f"完成! 共生成 {len(created_configs)} 个配置文件")
    print(f"配置文件目录: {args.output_config_dir}")
    print(f"配置文件列表: {index_file}")
    print("=" * 80)
    
    # 显示示例配置
    if created_configs:
        print(f"\n示例配置 ({created_configs[0].name}):")
        print("-" * 80)
        with open(created_configs[0], 'r') as f:
            print(f.read())


if __name__ == "__main__":
    main()