File size: 29,209 Bytes
ba96580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
#!/usr/bin/env python3
"""基于 tensor 名称的通用 ONNX 子图切分工具。

相较于 split_quant_onnx_by_subconfigs.py,本脚本额外提供:
1. 为每个子模型执行 onnx checker 与 shape inference(可关闭)。
2. 支持从 .npz/.npy(包含 dict 或单数组)加载验证数据。
3. 可用 onnxruntime 串联执行全部子模型,既可校验精度,也可单独输出流水线结果。
"""
from __future__ import annotations

import argparse
import json
import logging
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Set

import numpy as np
import onnx
from onnx import TensorProto, checker, helper, shape_inference, utils as onnx_utils

try:  # pragma: no cover - 可选依赖
    import onnxruntime as ort
except ImportError:  # pragma: no cover
    ort = None


@dataclass
class SubGraphSpec:
    """One sub-model to split out: its boundary tensors and the nodes it owns."""
    label: str  # human-readable identifier, e.g. "cfg_00" or "auto_01"
    start: List[str]  # boundary input tensor names of the sub-graph
    end: List[str]  # boundary output tensor names of the sub-graph
    node_names: Set[str]  # names of the nodes contained in this sub-graph
    source: str  # origin of the spec: "config" (from JSON) or "auto" (leftover component)
    output_path: Optional[Path] = None  # set once the sub-model file has been written


@dataclass
class GraphIndex:
    """Precomputed lookup tables over an ONNX graph (built by build_graph_index)."""
    tensor_to_producer: Dict[str, str]  # tensor name -> name of the node that produces it
    tensor_to_consumers: Dict[str, List[str]]  # tensor name -> names of nodes consuming it
    node_inputs: Dict[str, List[str]]  # node name -> its non-empty input tensor names
    node_outputs: Dict[str, List[str]]  # node name -> its non-empty output tensor names
    graph_inputs: Set[str]  # top-level graph input names
    graph_outputs: Set[str]  # top-level graph output names
    initializer_names: Set[str]  # names of initializer (weight/constant) tensors
    node_order: List[str]  # deduplicated node names in original graph order


def sanitize(name: str) -> str:
    """Turn *name* into a filename-friendly token.

    Non-alphanumeric characters become underscores; leading/trailing
    underscores are trimmed. Falls back to "anon" for an empty name and
    "tensor" when nothing survives the trimming.
    """
    if not name:
        return "anon"
    cleaned = "".join(ch if ch.isalnum() else "_" for ch in name).strip("_")
    return cleaned or "tensor"


def build_graph_index(model: onnx.ModelProto) -> GraphIndex:
    """Scan the model once and build producer/consumer/IO lookup tables.

    Nodes with empty or duplicate names are given unique synthetic names;
    this naming rule is duplicated in extract_model_file and the two must
    stay in sync.
    """
    producers: Dict[str, str] = {}
    consumers: Dict[str, List[str]] = defaultdict(list)
    inputs_by_node: Dict[str, List[str]] = {}
    outputs_by_node: Dict[str, List[str]] = {}
    ordering: List[str] = []

    taken: Set[str] = set()
    for pos, node in enumerate(model.graph.node):
        raw = node.name.strip() if node.name else ""
        unique = raw or f"node_{pos}"
        # Append the index (repeatedly, if needed) until the name is unique.
        while unique in taken:
            unique = f"{unique}_{pos}"
        taken.add(unique)
        ordering.append(unique)
        inputs_by_node[unique] = [t for t in node.input if t]
        outputs_by_node[unique] = [t for t in node.output if t]
        for produced in outputs_by_node[unique]:
            producers[produced] = unique
        for consumed in inputs_by_node[unique]:
            consumers[consumed].append(unique)

    return GraphIndex(
        tensor_to_producer=producers,
        tensor_to_consumers=consumers,
        node_inputs=inputs_by_node,
        node_outputs=outputs_by_node,
        graph_inputs={vi.name for vi in model.graph.input},
        graph_outputs={vi.name for vi in model.graph.output},
        initializer_names={init.name for init in model.graph.initializer},
        node_order=ordering,
    )


def ensure_value_infos(model: onnx.ModelProto, tensor_names: Iterable[str]) -> None:
    """Ensure every tensor in *tensor_names* has a ValueInfo entry on the graph.

    Existing input/value_info/output entries are copied as templates; unknown
    tensors get an UNDEFINED-typed placeholder so later shape inference can
    fill them in.
    """
    known = {vi.name for vi in model.graph.value_info}
    templates = {}
    for vi in list(model.graph.input) + list(model.graph.value_info) + list(model.graph.output):
        templates[vi.name] = vi

    newly_added: List[str] = []
    for tensor in tensor_names:
        if tensor in known:
            continue
        template = templates.get(tensor)
        if template is None:
            info = helper.make_tensor_value_info(tensor, TensorProto.UNDEFINED, None)
        else:
            info = onnx.ValueInfoProto()
            info.CopyFrom(template)
        model.graph.value_info.append(info)
        known.add(tensor)
        newly_added.append(tensor)
    if newly_added:
        logging.debug("已为以下 tensor 补充 ValueInfo: %s", newly_added)


def ensure_extractor_value_infos(
    extractor: onnx_utils.Extractor,
    tensor_names: Iterable[str],
    source_model: onnx.ModelProto,
) -> None:
    """Extractor-side counterpart of ensure_value_infos.

    Registers a ValueInfo on the extractor's graph (and its vimap) for every
    requested tensor not already known as an input, output, or value_info.
    """
    known_io = {vi.name for vi in extractor.graph.input} | {vi.name for vi in extractor.graph.output}
    known_vi = {vi.name for vi in extractor.graph.value_info}
    templates = {}
    for vi in (
        list(source_model.graph.input)
        + list(source_model.graph.value_info)
        + list(source_model.graph.output)
    ):
        templates[vi.name] = vi

    registered: List[str] = []
    for name in tensor_names:
        if name in known_io or name in known_vi:
            continue
        template = templates.get(name)
        if template is None:
            info = helper.make_tensor_value_info(name, TensorProto.UNDEFINED, None)
        else:
            info = onnx.ValueInfoProto()
            info.CopyFrom(template)
        extractor.graph.value_info.append(info)
        extractor.vimap[name] = info
        known_vi.add(name)
        registered.append(name)
    if registered:
        logging.debug("Extractor 侧补充 ValueInfo: %s", registered)


def trace_nodes_between(spec: SubGraphSpec, index: GraphIndex) -> Set[str]:
    """Collect the nodes that compute ``spec.end`` from ``spec.start``.

    Walks producer edges backwards from the end tensors, stopping at the
    start tensors, graph inputs, and initializers. Returns the names of all
    nodes encountered on the way.
    """
    stop_at = set(spec.start) | index.graph_inputs | index.initializer_names
    seen: Set[str] = set()
    found: Set[str] = set()
    frontier = list(spec.end)

    while frontier:
        tensor = frontier.pop()
        if tensor in seen:
            continue
        seen.add(tensor)
        if tensor in stop_at:
            continue
        origin = index.tensor_to_producer.get(tensor)
        if not origin or origin in found:
            continue
        found.add(origin)
        frontier.extend(
            upstream
            for upstream in index.node_inputs.get(origin, [])
            if upstream and upstream not in stop_at
        )
    return found


def untouched_components(
    all_nodes: Sequence[str],
    covered_nodes: Set[str],
    index: GraphIndex,
) -> List[Set[str]]:
    """Group nodes not covered by any spec into connected components.

    Two leftover nodes are connected when one consumes a tensor the other
    produces (in either direction). Returns one set of node names per
    component, seeded in original node order.
    """
    leftover = [n for n in all_nodes if n not in covered_nodes]
    if not leftover:
        return []
    leftover_set = set(leftover)
    neighbors: Dict[str, Set[str]] = {name: set() for name in leftover}

    def _link(a: str, b: str) -> None:
        # Undirected edge between two leftover nodes.
        neighbors[a].add(b)
        neighbors[b].add(a)

    for name in leftover:
        for tensor in index.node_outputs.get(name, []):
            for peer in index.tensor_to_consumers.get(tensor, []):
                if peer in leftover_set:
                    _link(name, peer)
        for tensor in index.node_inputs.get(name, []):
            src = index.tensor_to_producer.get(tensor)
            if src in leftover_set:
                _link(name, src)

    groups: List[Set[str]] = []
    done: Set[str] = set()
    for seed in leftover:
        if seed in done:
            continue
        group: Set[str] = set()
        todo = [seed]
        while todo:
            current = todo.pop()
            if current in done:
                continue
            done.add(current)
            group.add(current)
            todo.extend(neighbors[current] - done)
        groups.append(group)
    return groups


def derive_interface(nodes: Set[str], index: GraphIndex) -> tuple[List[str], List[str]]:
    """Compute the (start, end) tensor interface of a set of nodes.

    Start tensors are non-initializer inputs produced outside the set (or
    not produced at all, i.e. graph inputs). End tensors are outputs consumed
    outside the set, plus any graph outputs the set produces. If no end
    tensor is found, every produced tensor is exposed as a fallback.

    NOTE: the original return annotation was the invalid tuple literal
    ``(List[str], List[str])``; fixed to a proper ``tuple[...]`` hint (safe
    at runtime because the module uses ``from __future__ import annotations``).
    """
    produced: Set[str] = set()
    for node in nodes:
        produced.update(index.node_outputs.get(node, []))

    # Inputs whose producer is missing or outside the set become start tensors;
    # initializers are weights, never part of the runtime interface.
    start: Set[str] = set()
    for node in nodes:
        for inp in index.node_inputs.get(node, []):
            if inp in index.initializer_names:
                continue
            if index.tensor_to_producer.get(inp) not in nodes:
                start.add(inp)

    end: Set[str] = set()
    for node in nodes:
        for out in index.node_outputs.get(node, []):
            consumers = index.tensor_to_consumers.get(out, [])
            if not consumers:
                # Dangling outputs only matter when the graph itself exposes them.
                if out in index.graph_outputs:
                    end.add(out)
                continue
            if any(consumer not in nodes for consumer in consumers):
                end.add(out)
    end.update(index.graph_outputs & produced)

    # Degenerate fallback: expose everything so the component is not silently lost.
    if not end and produced:
        end = produced.copy()

    return sorted(start), sorted(end)


def ordered_specs(specs: Sequence[SubGraphSpec], index: GraphIndex) -> List[SubGraphSpec]:
    """Topologically order the specs so each one's start tensors are available.

    Greedy passes: within a pass, tensors produced by an earlier-scheduled spec
    immediately unlock later specs in list order. Raises RuntimeError when no
    progress can be made (unsatisfiable dependencies).
    """
    ready = set(index.graph_inputs) | index.initializer_names
    waiting = list(specs)
    result: List[SubGraphSpec] = []
    while waiting:
        advanced = False
        for candidate in tuple(waiting):
            if not set(candidate.start).issubset(ready):
                continue
            result.append(candidate)
            ready.update(candidate.end)
            waiting.remove(candidate)
            advanced = True
        if not advanced:
            missing = {c.label: sorted(set(c.start) - ready) for c in waiting}
            raise RuntimeError(f"无法解析子图拓扑,缺少以下张量: {missing}")
    return result


def is_valid_onnx_model(model_path: Path) -> bool:
    """Return True when *model_path* holds a loadable ONNX model with opset info.

    Checks, in order: non-empty file, loadable proto, presence of a graph,
    and at least one opset_import entry. Any load failure is logged and
    treated as invalid rather than raised.
    """
    try:
        # Reject empty files before attempting a parse.
        if model_path.stat().st_size == 0:
            logging.warning(f"模型 {model_path.name} 是空文件")
            return False

        loaded = onnx.load(model_path.as_posix(), load_external_data=False)

        if loaded is None:
            logging.warning(f"模型 {model_path.name} 加载后为 None")
            return False

        if not hasattr(loaded, 'graph') or loaded.graph is None:
            logging.warning(f"模型 {model_path.name} 缺少 graph")
            return False

        # A model without opset_import cannot be executed by onnxruntime.
        if len(loaded.opset_import) == 0:
            logging.warning(f"模型 {model_path.name} 缺少 opset_import 信息")
            return False

        return True
    except Exception as e:
        logging.warning(f"无法加载模型 {model_path.name}: {e}")
        return False


def extract_model_file(
    source_model: onnx.ModelProto,
    spec: SubGraphSpec,
    output_dir: Path,
    suffix: str,
    run_checker: bool,
    run_shape_infer: bool,
    skip_existing: bool = True,
) -> Path:
    """Write the sub-graph described by ``spec`` to ``output_dir`` as a standalone ONNX file.

    The sub-model is assembled by hand (nodes, inputs, initializers, and
    outputs are copied from ``source_model``) instead of using
    ``onnx.utils.Extractor``. Shape inference and the checker are run
    best-effort when requested; their failures are logged, not raised.

    Returns the path of the saved (or reused, when ``skip_existing``) file.
    """
    head = sanitize(spec.start[0]) if spec.start else "const"
    tail = sanitize(spec.end[0]) if spec.end else "out"
    filename = f"{spec.label}_{head}_to_{tail}_{suffix}.onnx"
    destination = output_dir / filename

    # If the target file exists and skip_existing=True, reuse it only when it
    # passes the validity check; otherwise fall through and regenerate.
    if skip_existing and destination.exists():
        if is_valid_onnx_model(destination):
            logging.info("跳过已存在的子图 %s (文件: %s)", spec.label, destination.name)
            return destination
        else:
            logging.warning("子图 %s 的文件无效,将重新生成", spec.label)

    # Build the sub-graph manually, without onnx.utils.Extractor.
    # 1. Create a fresh empty graph.
    sub_graph = helper.make_graph(
        nodes=[],
        name=f"{spec.label}_subgraph",
        inputs=[],
        outputs=[],
        initializer=[]
    )

    # 2. Copy the required nodes from the source model.
    #    NOTE: GraphIndex generated unique candidate names for duplicate/empty
    #    node names; the exact same rule must be replicated here so that
    #    spec.node_names resolve — keep this in sync with build_graph_index.
    node_map = {}
    used_names: Set[str] = set()
    for idx, node in enumerate(source_model.graph.node):
        base = node.name.strip() if node.name else ""
        candidate = base or f"node_{idx}"
        while candidate in used_names:
            candidate = f"{candidate}_{idx}"
        used_names.add(candidate)
        node_map[candidate] = node

    missing_nodes: List[str] = []
    for node_name in spec.node_names:
        target = node_map.get(node_name)
        if target is None:
            missing_nodes.append(node_name)
            continue
        new_node = onnx.NodeProto()
        new_node.CopyFrom(target)
        sub_graph.node.append(new_node)

    if missing_nodes:
        logging.warning("子图 %s: 有 %d 个节点未匹配到源模型,将被跳过: %s", spec.label, len(missing_nodes), missing_nodes[:5])

    # 3. Collect every tensor name used by the copied nodes.
    node_inputs = set()
    node_outputs = set()
    for node in sub_graph.node:
        for inp in node.input:
            if inp:
                node_inputs.add(inp)
        for out in node.output:
            if out:
                node_outputs.add(out)

    # 4. Collect ValueInfo templates from the source model.
    source_value_info_map = {}
    for vi in list(source_model.graph.input) + list(source_model.graph.value_info) + list(source_model.graph.output):
        source_value_info_map[vi.name] = vi

    # 5. Collect initializers from the source model.
    source_init_map = {init.name: init for init in source_model.graph.initializer}

    # 6. Add graph inputs: spec.start plus every consumed tensor that is
    #    neither produced inside the sub-graph nor an initializer.
    input_tensor_names = set(spec.start)
    for tensor_name in node_inputs:
        if tensor_name not in node_outputs and tensor_name not in source_init_map:
            input_tensor_names.add(tensor_name)

    for tensor_name in sorted(input_tensor_names):
        if tensor_name in source_value_info_map:
            vi = onnx.ValueInfoProto()
            vi.CopyFrom(source_value_info_map[tensor_name])
            sub_graph.input.append(vi)
        else:
            # Shape/dtype unknown here; UNDEFINED lets shape inference fill it in.
            vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None)
            sub_graph.input.append(vi)

    # 7. Add the initializers the copied nodes consume.
    for tensor_name in node_inputs:
        if tensor_name in source_init_map:
            init = onnx.TensorProto()
            init.CopyFrom(source_init_map[tensor_name])
            sub_graph.initializer.append(init)

    # 8. Add graph outputs from spec.end.
    for tensor_name in spec.end:
        if tensor_name in source_value_info_map:
            vi = onnx.ValueInfoProto()
            vi.CopyFrom(source_value_info_map[tensor_name])
            sub_graph.output.append(vi)
        else:
            vi = helper.make_tensor_value_info(tensor_name, TensorProto.UNDEFINED, None)
            sub_graph.output.append(vi)

    # 9. Every declared output must be produced by some node; add an Identity
    #    passthrough for outputs that are only inputs/initializers.
    #    NOTE(review): this Identity uses the SAME name for its input and
    #    output, which looks like it violates ONNX's single-assignment rule
    #    (the tensor would be defined both as a graph input/initializer and as
    #    a node output) — confirm against the checker on a model that hits
    #    this path.
    for out_name in spec.end:
        if out_name not in node_outputs:
            # This output is not produced by any copied node.
            if out_name in input_tensor_names or out_name in source_init_map:
                # Insert an Identity passthrough node.
                identity_node = helper.make_node(
                    'Identity',
                    inputs=[out_name],
                    outputs=[out_name],
                    name=f'passthrough_{sanitize(out_name)}'
                )
                sub_graph.node.append(identity_node)
                logging.info(f"子图 {spec.label}: 为输出 {out_name} 添加 Identity 节点")
            else:
                logging.error(f"子图 {spec.label}: 输出 {out_name} 无法产生(不在输入/initializer/节点输出中)")

    # 10. Wrap the graph in a model.
    sub_model = helper.make_model(sub_graph)

    # 11. Copy model-level metadata from the source.
    sub_model.ir_version = source_model.ir_version
    sub_model.producer_name = source_model.producer_name
    sub_model.producer_version = source_model.producer_version
    sub_model.domain = source_model.domain
    sub_model.model_version = source_model.model_version
    sub_model.doc_string = source_model.doc_string

    # 12. Replace the opset imports with the source model's (make_model may
    #     have added a default entry, so clear first).
    while len(sub_model.opset_import) > 0:
        sub_model.opset_import.pop()

    if len(source_model.opset_import) > 0:
        for opset in source_model.opset_import:
            opset_import = sub_model.opset_import.add()
            opset_import.CopyFrom(opset)
    else:
        # Source model carries no opset_import; fall back to a default opset.
        logging.warning(f"源模型缺少 opset_import,为子图 {spec.label} 添加默认 opset 17")
        opset_import = sub_model.opset_import.add()
        opset_import.domain = ""
        opset_import.version = 17  # default to ONNX opset 17

    # Sanity check: the sub-model must now carry opset information.
    if len(sub_model.opset_import) == 0:
        raise RuntimeError(f"子图 {spec.label} 缺少 opset_import 信息")

    # 13. Best-effort shape inference and checker — failures only log warnings.
    if run_shape_infer:
        try:
            sub_model = shape_inference.infer_shapes(sub_model)
        except Exception as e:
            logging.warning(f"子图 {spec.label} shape inference 失败: {e}")

    if run_checker:
        try:
            checker.check_model(sub_model)
        except Exception as e:
            logging.warning(f"子图 {spec.label} checker 验证失败: {e}")

    # 14. Persist to disk.
    onnx.save(sub_model, destination.as_posix())
    logging.info(
        "保存子图 %s (start=%s, end=%s, 节点数=%d, checker=%s, infer_shape=%s)",
        destination.name,
        spec.start,
        spec.end,
        len(sub_graph.node),
        bool(run_checker),
        bool(run_shape_infer),
    )
    return destination



def load_numpy_inputs(path: Path, expected_inputs: Iterable[str]) -> Dict[str, np.ndarray]:
    """Load a feed dict from a .npz archive or a .npy file.

    Supported .npy payloads: a pickled dict, a structured array (one field
    per input), or — for single-input models only — a bare array bound to
    the sole expected input name. Raises ValueError otherwise.
    """
    names = list(expected_inputs)
    ext = path.suffix.lower()
    if ext == ".npz":
        archive = np.load(path, allow_pickle=False)
        return {key: archive[key] for key in archive.files}
    if ext == ".npy":
        payload = np.load(path, allow_pickle=True)
        # 0-d object array wrapping a plain dict (result of np.save on a dict).
        if isinstance(payload, np.ndarray) and payload.shape == () and isinstance(payload.item(), dict):
            return {str(k): np.asarray(v) for k, v in payload.item().items()}
        # Structured array: each named field feeds one input.
        if isinstance(payload, np.ndarray) and payload.dtype.names:
            return {field: payload[field] for field in payload.dtype.names}
        if len(names) == 1:
            return {names[0]: np.asarray(payload)}
        raise ValueError("多输入模型需要字典格式的 .npy/.npz 数据。")
    raise ValueError("仅支持 .npz 或 .npy 输入数据。")


def run_full_model(model_path: Path, feed_dict: Dict[str, np.ndarray], providers: List[str]):
    """Run the original model once and return {output_name: array}."""
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    sess = ort.InferenceSession(model_path.as_posix(), providers=providers)
    values = sess.run(None, feed_dict)
    return {meta.name: value for meta, value in zip(sess.get_outputs(), values)}


def run_split_pipeline(
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
) -> Dict[str, np.ndarray]:
    """Execute the sub-models in topological order, threading tensors through.

    Every sub-model's outputs are accumulated into one store (seeded from
    ``feed_dict``) so downstream sub-models can consume them. Returns the
    full store, including all intermediate tensors.
    """
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    store: Dict[str, np.ndarray] = dict(feed_dict)
    for spec in ordered_subgraphs:
        if spec.output_path is None:
            raise RuntimeError(f"子图 {spec.label} 尚未生成 ONNX 文件。")
        if not spec.output_path.exists():
            raise RuntimeError(f"子图 {spec.label} 的输出文件不存在: {spec.output_path}")

        logging.info("运行子图 %s (输入: %s, 输出: %s)", spec.label, spec.start, spec.end)
        session = ort.InferenceSession(spec.output_path.as_posix(), providers=providers)

        # Feed by the session's actual input names, not spec.start — shape
        # inference or extraction may have adjusted the interface.
        input_names = [meta.name for meta in session.get_inputs()]
        logging.debug("子图 %s 实际输入: %s", spec.label, input_names)

        feeds: Dict[str, np.ndarray] = {}
        for name in input_names:
            if name not in store:
                if name in spec.start:
                    logging.warning(f"子图 {spec.label} 缺少输入张量 {name},尝试从 feed_dict 查找")
                    if name in feed_dict:
                        store[name] = feed_dict[name]
                    else:
                        available = list(store.keys())
                        raise KeyError(f"子图 {spec.label} 缺少输入张量 {name}。当前可用: {available}")
                else:
                    available = list(store.keys())
                    raise KeyError(f"子图 {spec.label} 缺少输入张量 {name}。当前可用: {available}")
            feeds[name] = store[name]

        produced = session.run(None, feeds)
        for meta, value in zip(session.get_outputs(), produced):
            store[meta.name] = value
            logging.debug("子图 %s 产生输出: %s (shape=%s)", spec.label, meta.name, value.shape)
    return store


def verify(
    model_path: Path,
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
    rtol: float,
    atol: float,
) -> None:
    """Check that the split pipeline reproduces the original model's outputs.

    Steps:
    1. Run the full original model to get reference outputs.
    2. Run all sub-graphs in topological order (each sub-graph consumes the
       outputs of its predecessors).
    3. Compare every original output against the pipeline's value with
       np.allclose(rtol, atol); raise AssertionError on any mismatch or
       missing output.
    """
    logging.info("开始验证:运行原始模型...")
    reference = run_full_model(model_path, feed_dict, providers)
    logging.info("原始模型运行完成,产生 %d 个输出", len(reference))

    logging.info("开始验证:运行子图流水线...")
    pipeline = run_split_pipeline(ordered_subgraphs, feed_dict, providers)
    logging.info("子图流水线运行完成")

    logging.info("比较输出...")
    for name, expected in reference.items():
        actual = pipeline.get(name)
        if actual is None:
            raise AssertionError(f"切分流水线未产生模型输出 {name}")
        if not np.allclose(expected, actual, rtol=rtol, atol=atol):
            worst = float(np.max(np.abs(expected - actual)))
            raise AssertionError(f"输出 {name} 不匹配,最大偏差 {worst:.3e}")
        logging.info("✓ 输出 %s 验证通过 (shape=%s)", name, expected.shape)
    logging.info("✓ 切分模型与原始模型输出一致 (rtol=%g, atol=%g)。", rtol, atol)


def save_outputs(outputs: Dict[str, np.ndarray], destination: Optional[Path]) -> None:
    """Persist pipeline outputs as an .npz archive; no-op when destination is None."""
    if destination is not None:
        destination.parent.mkdir(parents=True, exist_ok=True)
        np.savez(destination, **outputs)
        logging.info("流水线输出已保存至 %s", destination)


def load_sub_configs(config_path: Path) -> List[dict]:
    """Read sub_configs from a JSON file.

    Prefers ``compiler.sub_configs``; falls back to a top-level
    ``sub_configs`` key. Raises ValueError when neither is present or the
    list is empty.
    """
    with config_path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)
    entries = payload.get("compiler", {}).get("sub_configs") or payload.get("sub_configs")
    if not entries:
        raise ValueError("配置文件中未找到 sub_configs。")
    return entries


def main() -> None:
    """CLI entry point: split an ONNX model into sub-models per config.

    Workflow: load the full model, trace configured sub-graphs, auto-cover
    leftover nodes, topologically order the specs, write each sub-model,
    then optionally verify against the original model or just run the
    split pipeline.
    """
    parser = argparse.ArgumentParser(description="根据 tensor 名切分 ONNX 子图")
    parser.add_argument("--model", required=True, type=Path, help="原始 ONNX 文件")
    parser.add_argument("--config", required=True, type=Path, help="包含 sub_configs 的 JSON")
    parser.add_argument("--output-dir", default="./split-onnx", type=Path, help="子模型输出目录")
    parser.add_argument("--providers", nargs="*", default=["CPUExecutionProvider"], help="onnxruntime providers 顺序")
    parser.add_argument("--verify", action="store_true", help="比较原始模型与流水线输出")
    parser.add_argument("--run-pipeline", action="store_true", help="只运行切分流水线并输出结果")
    parser.add_argument("--input-data", type=Path, help=".npz/.npy 格式的输入数据")
    parser.add_argument("--pipeline-output", type=Path, help="保存流水线输出为 npz")
    parser.add_argument("--rtol", type=float, default=1e-4, help="验证 rtol")
    parser.add_argument("--atol", type=float, default=1e-5, help="验证 atol")
    parser.add_argument("--skip-checker", action="store_true", help="跳过 onnx checker")
    parser.add_argument("--skip-shape-infer", action="store_true", help="跳过 shape inference")
    parser.add_argument("--skip-existing", action="store_true", help="跳过已存在的子图文件")
    parser.add_argument("--log", default="INFO", help="日志等级")
    args = parser.parse_args()

    logging.basicConfig(level=getattr(logging, args.log.upper(), logging.INFO))

    # The full model (including external data files) must be loaded here: the
    # weights have to be embedded into the sub-models, otherwise they would
    # keep referencing the original external-data paths and fail to resolve
    # once the sub-models live in a different directory. Can take significant
    # time and memory for large models.
    logging.info("加载完整模型(包括外部数据文件)...这可能需要一些时间和内存")
    model = onnx.load(args.model.as_posix(), load_external_data=True)

    # Skip the checker on the (huge) original model; only the extracted
    # sub-models are checked.
    # if not args.skip_checker:
    #     checker.check_model(args.model.as_posix())
    logging.info("跳过对原始模型的 checker(模型过大),将只对切分后的子模型进行验证")

    graph_index = build_graph_index(model)

    sub_configs = load_sub_configs(args.config)
    specs: List[SubGraphSpec] = []
    covered_nodes: Set[str] = set()

    # One spec per configured (start, end) tensor pair; trace the nodes between.
    for idx, entry in enumerate(sub_configs):
        start = [name for name in entry.get("start_tensor_names", []) if name]
        end = [name for name in entry.get("end_tensor_names", []) if name]
        if not start or not end:
            raise ValueError(f"sub_config[{idx}] 缺少 start/end tensor name。")
        spec = SubGraphSpec(
            label=f"cfg_{idx:02d}",
            start=start,
            end=end,
            node_names=set(),
            source="config",
        )
        nodes = trace_nodes_between(spec, graph_index)
        spec.node_names = nodes
        covered_nodes.update(nodes)
        specs.append(spec)

    # Wrap nodes not covered by any config entry into auto-derived specs so the
    # union of all sub-graphs covers the full model.
    leftovers = untouched_components(graph_index.node_order, covered_nodes, graph_index)
    for idx, component in enumerate(leftovers):
        start, end = derive_interface(component, graph_index)
        if not end:
            continue
        spec = SubGraphSpec(
            label=f"auto_{idx:02d}",
            start=start,
            end=end,
            node_names=component,
            source="auto",
        )
        specs.append(spec)
        logging.info(
            "自动补充子图 %s: start=%s end=%s (节点数=%d)",
            spec.label,
            spec.start,
            spec.end,
            len(component),
        )

    ordered = ordered_specs(specs, graph_index)

    # Make sure every boundary tensor has a ValueInfo before extraction.
    required_tensors: Set[str] = set()
    for spec in ordered:
        required_tensors.update(spec.start)
        required_tensors.update(spec.end)
    ensure_value_infos(model, required_tensors)

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Scan existing sub-model files first and report the ones that must be
    # regenerated (the filename scheme mirrors extract_model_file).
    corrupted_specs = []
    for spec in ordered:
        head = sanitize(spec.start[0]) if spec.start else "const"
        tail = sanitize(spec.end[0]) if spec.end else "out"
        filename = f"{spec.label}_{head}_to_{tail}_{spec.source}.onnx"
        potential_path = args.output_dir / filename

        if potential_path.exists() and not is_valid_onnx_model(potential_path):
            corrupted_specs.append(spec.label)
            logging.warning(f"检测到损坏的子图: {spec.label}, 将重新生成")

    if corrupted_specs:
        logging.info(f"发现 {len(corrupted_specs)} 个损坏的子图,将重新生成: {corrupted_specs}")

    # Generate (or refresh) every sub-model.
    for spec in ordered:
        spec.output_path = extract_model_file(
            model,
            spec,
            args.output_dir,
            spec.source,
            run_checker=not args.skip_checker,
            run_shape_infer=not args.skip_shape_infer,
            skip_existing=args.skip_existing,
        )

    need_inputs = args.verify or args.run_pipeline
    if need_inputs:
        if args.input_data is None:
            raise ValueError("verify/run-pipeline 模式需要 --input-data 提供输入。")
        feed = load_numpy_inputs(args.input_data, graph_index.graph_inputs)
        missing_inputs = graph_index.graph_inputs - feed.keys()
        if missing_inputs:
            raise ValueError(f"输入数据缺少以下张量: {sorted(missing_inputs)}")
    else:
        feed = {}

    if args.verify:
        verify(args.model, ordered, feed, args.providers, args.rtol, args.atol)
    elif args.run_pipeline:
        outputs = run_split_pipeline(ordered, feed, args.providers)
        save_outputs(outputs, args.pipeline_output)
    else:
        logging.info("子模型已生成,如需验证请添加 --verify 或 --run-pipeline。")


if __name__ == "__main__":
    # Example invocation (kept as a no-op string literal for reference):
    """
    python ./scripts/split_onnx_by_subconfig.py \
        --model ./onnx-models/z_image_transformer_body_only_simp_slim.onnx \
        --config ./pulsar2_configs/transformers_subgraph.json \
        --output-dir ./transformers_body_only_split_onnx \
        --verify \
        --input-data ./onnx-calibration-no-controlnet/transformer_inputs_prompt000_step00.npy \
        --providers CPUExecutionProvider
    """
    main()