#!/usr/bin/env python3
"""Verify ONNX subgraph files and optionally remove the corrupted ones.

Every ``*.onnx`` file directly inside ``--dir`` is loaded (without its
external data) and given a few structural sanity checks. Files that fail are
reported; with ``--remove-corrupted`` they are deleted so they can be
regenerated, e.g. by re-running ``split_onnx_by_subconfig.py``.

Example::

    python ./scripts/verify_and_fix_subgraphs.py \\
        --dir ./transformers_body_only_1728_992_split_onnx
"""
import argparse
import logging
import sys
from pathlib import Path

import onnx


def is_valid_onnx_model(model_path: Path) -> tuple[bool, str]:
    """Run basic structural sanity checks on one ONNX model file.

    Only the protobuf stored in ``model_path`` is parsed
    (``load_external_data=False``). External tensor files are not read and
    ``onnx.checker`` is not run, so this is a structural check, not a full
    model validation.

    Args:
        model_path: Path to the ``.onnx`` file.

    Returns:
        ``(True, "OK")`` if the file passes every check, otherwise
        ``(False, reason)`` where ``reason`` describes the first failed check
        or is the text of the exception raised while reading/parsing it.
    """
    try:
        # An empty file is the most common result of an interrupted write.
        if model_path.stat().st_size == 0:
            return False, "文件为空"

        model = onnx.load(model_path.as_posix(), load_external_data=False)
    except Exception as e:
        # Any read or parse failure (missing file, truncated/garbled protobuf,
        # permission error, ...) means the file is unusable: report, don't crash.
        return False, str(e)

    # Defensive only: onnx.load is not expected to return None.
    if model is None:
        return False, "模型加载后为 None"

    # A protobuf sub-message attribute is never None (an unset field reads back
    # as an empty default message), so presence must be tested with HasField.
    if not model.HasField("graph"):
        return False, "缺少 graph"

    if not model.opset_import:
        return False, "缺少 opset_import 信息"

    return True, "OK"


def main() -> int:
    """Validate every ONNX file in ``--dir`` and report (or delete) failures.

    Returns:
        Process exit code: 0 if no corrupted file was found (including when
        the directory contains no ``*.onnx`` files); 1 if ``--dir`` is missing
        or not a directory, or if any file failed validation. The code is
        still 1 after ``--remove-corrupted`` has deleted the failures.
    """
    parser = argparse.ArgumentParser(description="验证子图文件的有效性")
    parser.add_argument("--dir", type=Path, required=True, help="包含子图的目录")
    parser.add_argument("--remove-corrupted", action="store_true", help="删除损坏的文件")
    parser.add_argument("--log", default="INFO", help="日志等级")
    args = parser.parse_args()

    # getattr alone can return a non-level attribute of the logging module
    # (e.g. logging.BASIC_FORMAT for "--log basic_format"), which basicConfig
    # would reject; only integer levels are accepted, anything else -> INFO.
    level = getattr(logging, args.log.upper(), None)
    if not isinstance(level, int):
        level = logging.INFO
    logging.basicConfig(level=level, format='%(levelname)s: %(message)s')

    if not args.dir.exists():
        logging.error(f"目录不存在: {args.dir}")
        return 1
    if not args.dir.is_dir():
        logging.error(f"路径不是目录: {args.dir}")
        return 1

    # Regular files only: a directory whose name matches *.onnx would otherwise
    # be reported as corrupted and then fail to unlink.
    onnx_files = sorted(p for p in args.dir.glob("*.onnx") if p.is_file())
    if not onnx_files:
        logging.warning(f"目录 {args.dir} 中没有找到 ONNX 文件")
        return 0

    total = len(onnx_files)
    logging.info(f"找到 {total} 个 ONNX 文件")
    logging.info("=" * 80)

    valid_count = 0
    corrupted_files = []
    for idx, onnx_file in enumerate(onnx_files, 1):
        is_valid, error_msg = is_valid_onnx_model(onnx_file)
        if is_valid:
            logging.info(f"[{idx}/{total}] ✓ {onnx_file.name}")
            valid_count += 1
        else:
            logging.error(f"[{idx}/{total}] ✗ {onnx_file.name}: {error_msg}")
            corrupted_files.append(onnx_file)

    # Summary
    logging.info("=" * 80)
    logging.info(f"总计: {total} 个文件")
    logging.info(f"有效: {valid_count} 个文件")
    logging.info(f"损坏: {len(corrupted_files)} 个文件")

    if corrupted_files:
        logging.info("\n损坏的文件列表:")
        for f in corrupted_files:
            logging.info(f" - {f.name}")

        if args.remove_corrupted:
            logging.info("\n删除损坏的文件...")
            removed = 0
            for f in corrupted_files:
                try:
                    f.unlink()
                except OSError as e:
                    # Keep going so one undeletable file does not abort the
                    # rest; the final count reflects only successful deletions.
                    logging.error(f" 删除失败: {f.name}: {e}")
                    continue
                removed += 1
                logging.info(f" 已删除: {f.name}")
            logging.info(f"已删除 {removed} 个损坏的文件")
        else:
            logging.info("\n使用 --remove-corrupted 选项来删除这些损坏的文件")
            logging.info("然后重新运行 split_onnx_by_subconfig.py 来重新生成")

    return 0 if not corrupted_files else 1


if __name__ == "__main__":
    # The builtin `exit` is injected by the site module for interactive use
    # and may be missing (e.g. `python -S`); sys.exit is always available.
    sys.exit(main())