File size: 3,491 Bytes
ba96580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python3
"""
验证并修复损坏的子图文件
"""
import argparse
import logging
from pathlib import Path
import onnx


def is_valid_onnx_model(model_path: Path) -> tuple[bool, str]:
    """检查 ONNX 模型文件是否有效,返回 (是否有效, 错误信息)"""
    try:
        # 检查文件大小
        if model_path.stat().st_size == 0:
            return False, "文件为空"
            
        model = onnx.load(model_path.as_posix(), load_external_data=False)
        
        # 检查模型是否为 None
        if model is None:
            return False, "模型加载后为 None"
            
        # 检查是否有 graph
        if not hasattr(model, 'graph') or model.graph is None:
            return False, "缺少 graph"
            
        # 检查是否有 opset_import
        if len(model.opset_import) == 0:
            return False, "缺少 opset_import 信息"
            
        return True, "OK"
    except Exception as e:
        return False, str(e)


def main():
    parser = argparse.ArgumentParser(description="验证子图文件的有效性")
    parser.add_argument("--dir", type=Path, required=True, help="包含子图的目录")
    parser.add_argument("--remove-corrupted", action="store_true", help="删除损坏的文件")
    parser.add_argument("--log", default="INFO", help="日志等级")
    args = parser.parse_args()
    
    logging.basicConfig(
        level=getattr(logging, args.log.upper(), logging.INFO),
        format='%(levelname)s: %(message)s'
    )
    
    if not args.dir.exists():
        logging.error(f"目录不存在: {args.dir}")
        return 1
    
    # 获取所有 .onnx 文件
    onnx_files = sorted(args.dir.glob("*.onnx"))
    
    if not onnx_files:
        logging.warning(f"目录 {args.dir} 中没有找到 ONNX 文件")
        return 0
    
    logging.info(f"找到 {len(onnx_files)} 个 ONNX 文件")
    logging.info("=" * 80)
    
    valid_count = 0
    corrupted_files = []
    
    for idx, onnx_file in enumerate(onnx_files, 1):
        is_valid, error_msg = is_valid_onnx_model(onnx_file)
        
        if is_valid:
            logging.info(f"[{idx}/{len(onnx_files)}] ✓ {onnx_file.name}")
            valid_count += 1
        else:
            logging.error(f"[{idx}/{len(onnx_files)}] ✗ {onnx_file.name}: {error_msg}")
            corrupted_files.append(onnx_file)
    
    # 总结
    logging.info("=" * 80)
    logging.info(f"总计: {len(onnx_files)} 个文件")
    logging.info(f"有效: {valid_count} 个文件")
    logging.info(f"损坏: {len(corrupted_files)} 个文件")
    
    if corrupted_files:
        logging.info("\n损坏的文件列表:")
        for f in corrupted_files:
            logging.info(f"  - {f.name}")
        
        if args.remove_corrupted:
            logging.info("\n删除损坏的文件...")
            for f in corrupted_files:
                f.unlink()
                logging.info(f"  已删除: {f.name}")
            logging.info(f"已删除 {len(corrupted_files)} 个损坏的文件")
        else:
            logging.info("\n使用 --remove-corrupted 选项来删除这些损坏的文件")
            logging.info("然后重新运行 split_onnx_by_subconfig.py 来重新生成")
    
    return 0 if len(corrupted_files) == 0 else 1


if __name__ == "__main__":
    """
    python ./scripts/verify_and_fix_subgraphs.py \
        --dir ./transformers_body_only_1728_992_split_onnx
    """
    exit(main())