|
|
|
|
|
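"""
Verify the ONNX subgraph files in a directory and optionally delete corrupted
ones, so they can be regenerated by re-running split_onnx_by_subconfig.py.
"""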
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
import onnx |
|
|
|
|
|
|
|
|
def is_valid_onnx_model(model_path: Path) -> tuple[bool, str]:
    """Check whether an ONNX model file can be loaded and is structurally sane.

    Only the protobuf itself is parsed (``load_external_data=False``), so any
    external tensor data files are neither read nor validated.

    Args:
        model_path: Path to the ``.onnx`` file to check.

    Returns:
        ``(True, "OK")`` when the file passes all checks, otherwise
        ``(False, reason)`` with a short human-readable reason.
    """
    try:
        # An empty file cannot contain a model; report it explicitly instead
        # of surfacing a generic parse error.
        if model_path.stat().st_size == 0:
            return False, "文件为空"

        model = onnx.load(model_path.as_posix(), load_external_data=False)

        # Protobuf sub-messages are never None, so test for presence in the
        # serialized data instead.
        if not model.HasField("graph"):
            return False, "缺少 graph"

        if not model.opset_import:
            return False, "缺少 opset_import 信息"

        return True, "OK"
    except Exception as e:  # any OS/parse failure means "not a valid model"
        # Prefix the exception type: some exceptions stringify to "".
        return False, f"{type(e).__name__}: {e}"
|
|
|
|
|
|
|
|
def _resolve_log_level(name: str) -> int:
    """Map a level name such as "debug" to a logging level, defaulting to INFO."""
    level = getattr(logging, name.upper(), None)
    # getattr() may return non-level module attributes (e.g. BASIC_FORMAT is a
    # str), which basicConfig() would reject; accept only integer levels.
    return level if isinstance(level, int) else logging.INFO


def _delete_files(files: list[Path]) -> int:
    """Delete each path in *files*, logging failures; return how many were removed."""
    removed = 0
    for f in files:
        try:
            f.unlink()
        except OSError as e:
            # Keep going so one undeletable file does not abort the rest.
            logging.error(f" 删除失败: {f.name}: {e}")
            continue
        logging.info(f" 已删除: {f.name}")
        removed += 1
    return removed


def main(argv: "list[str] | None" = None) -> int:
    """Validate every ``*.onnx`` file directly inside ``--dir``.

    Args:
        argv: Command-line arguments without the program name. ``None`` lets
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        Exit code: 0 when every file is valid or no ONNX files were found;
        1 when ``--dir`` is missing or not a directory, or when any corrupted
        file was found (even if ``--remove-corrupted`` deleted it).
    """
    parser = argparse.ArgumentParser(description="验证子图文件的有效性")
    parser.add_argument("--dir", type=Path, required=True, help="包含子图的目录")
    parser.add_argument("--remove-corrupted", action="store_true", help="删除损坏的文件")
    parser.add_argument("--log", default="INFO", help="日志等级")
    args = parser.parse_args(argv)

    logging.basicConfig(
        level=_resolve_log_level(args.log),
        format='%(levelname)s: %(message)s'
    )

    if not args.dir.exists():
        logging.error(f"目录不存在: {args.dir}")
        return 1
    # glob() on a regular file yields nothing, which would wrongly look like
    # "no ONNX files, success".
    if not args.dir.is_dir():
        logging.error(f"不是目录: {args.dir}")
        return 1

    onnx_files = sorted(args.dir.glob("*.onnx"))
    if not onnx_files:
        logging.warning(f"目录 {args.dir} 中没有找到 ONNX 文件")
        return 0

    total = len(onnx_files)
    logging.info(f"找到 {total} 个 ONNX 文件")
    logging.info("=" * 80)

    valid_count = 0
    corrupted_files: list[Path] = []
    for idx, onnx_file in enumerate(onnx_files, 1):
        is_valid, error_msg = is_valid_onnx_model(onnx_file)
        if is_valid:
            logging.info(f"[{idx}/{total}] ✓ {onnx_file.name}")
            valid_count += 1
        else:
            logging.error(f"[{idx}/{total}] ✗ {onnx_file.name}: {error_msg}")
            corrupted_files.append(onnx_file)

    logging.info("=" * 80)
    logging.info(f"总计: {total} 个文件")
    logging.info(f"有效: {valid_count} 个文件")
    logging.info(f"损坏: {len(corrupted_files)} 个文件")

    if corrupted_files:
        logging.info("\n损坏的文件列表:")
        for f in corrupted_files:
            logging.info(f" - {f.name}")

        if args.remove_corrupted:
            logging.info("\n删除损坏的文件...")
            # Report what was actually removed, not what was attempted.
            removed = _delete_files(corrupted_files)
            logging.info(f"已删除 {removed} 个损坏的文件")
        else:
            logging.info("\n使用 --remove-corrupted 选项来删除这些损坏的文件")
            logging.info("然后重新运行 split_onnx_by_subconfig.py 来重新生成")

    return 1 if corrupted_files else 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Example:
    #   python ./scripts/verify_and_fix_subgraphs.py \
    #       --dir ./transformers_body_only_1728_992_split_onnx
    #
    # SystemExit rather than the site-provided exit(), which is intended for
    # the interactive shell and is missing when Python runs with -S.
    raise SystemExit(main())
|
|
|