Z-Image-Turbo / VideoX-Fun /scripts /verify_and_fix_subgraphs.py
yongqiang
initialize this repo
ba96580
#!/usr/bin/env python3
"""
验证并修复损坏的子图文件
"""
import argparse
import logging
from pathlib import Path
import onnx
def is_valid_onnx_model(model_path: Path) -> tuple[bool, str]:
    """Check whether an ONNX model file parses and carries the minimum structure.

    Only the protobuf container is loaded (``load_external_data=False``), so
    tensors stored in external data files are NOT read or validated here.

    Args:
        model_path: Path to the ``.onnx`` file to inspect.

    Returns:
        ``(True, "OK")`` when the file is non-empty, parses as a ModelProto,
        has its ``graph`` field set and at least one ``opset_import`` entry.
        Otherwise ``(False, reason)``, where *reason* is a short description
        or the text (or, if empty, the type name) of the exception raised
        while checking.
    """
    try:
        # A zero-byte file is the typical residue of an interrupted write.
        if model_path.stat().st_size == 0:
            return False, "文件为空"
        model = onnx.load(model_path.as_posix(), load_external_data=False)
        # Protobuf sub-message fields are never None: an absent ``graph`` reads
        # back as an empty default GraphProto, so presence must be tested with
        # HasField() rather than hasattr()/is None.
        if not model.HasField("graph"):
            return False, "缺少 graph"
        if not model.opset_import:
            return False, "缺少 opset_import 信息"
        return True, "OK"
    except Exception as e:
        # Deliberately broad: any failure to stat or parse means the file is
        # unusable. Some exceptions stringify to "", so fall back to the type.
        return False, str(e) or type(e).__name__
def main():
parser = argparse.ArgumentParser(description="验证子图文件的有效性")
parser.add_argument("--dir", type=Path, required=True, help="包含子图的目录")
parser.add_argument("--remove-corrupted", action="store_true", help="删除损坏的文件")
parser.add_argument("--log", default="INFO", help="日志等级")
args = parser.parse_args()
logging.basicConfig(
level=getattr(logging, args.log.upper(), logging.INFO),
format='%(levelname)s: %(message)s'
)
if not args.dir.exists():
logging.error(f"目录不存在: {args.dir}")
return 1
# 获取所有 .onnx 文件
onnx_files = sorted(args.dir.glob("*.onnx"))
if not onnx_files:
logging.warning(f"目录 {args.dir} 中没有找到 ONNX 文件")
return 0
logging.info(f"找到 {len(onnx_files)} 个 ONNX 文件")
logging.info("=" * 80)
valid_count = 0
corrupted_files = []
for idx, onnx_file in enumerate(onnx_files, 1):
is_valid, error_msg = is_valid_onnx_model(onnx_file)
if is_valid:
logging.info(f"[{idx}/{len(onnx_files)}] ✓ {onnx_file.name}")
valid_count += 1
else:
logging.error(f"[{idx}/{len(onnx_files)}] ✗ {onnx_file.name}: {error_msg}")
corrupted_files.append(onnx_file)
# 总结
logging.info("=" * 80)
logging.info(f"总计: {len(onnx_files)} 个文件")
logging.info(f"有效: {valid_count} 个文件")
logging.info(f"损坏: {len(corrupted_files)} 个文件")
if corrupted_files:
logging.info("\n损坏的文件列表:")
for f in corrupted_files:
logging.info(f" - {f.name}")
if args.remove_corrupted:
logging.info("\n删除损坏的文件...")
for f in corrupted_files:
f.unlink()
logging.info(f" 已删除: {f.name}")
logging.info(f"已删除 {len(corrupted_files)} 个损坏的文件")
else:
logging.info("\n使用 --remove-corrupted 选项来删除这些损坏的文件")
logging.info("然后重新运行 split_onnx_by_subconfig.py 来重新生成")
return 0 if len(corrupted_files) == 0 else 1
if __name__ == "__main__":
    # Usage:
    #   python ./scripts/verify_and_fix_subgraphs.py \
    #       --dir ./transformers_body_only_1728_992_split_onnx
    import sys

    # sys.exit instead of the site-provided exit(), which is intended for the
    # interactive shell and is unavailable when Python runs with -S.
    sys.exit(main())