import argparse
import importlib
import logging
import os
import warnings
from collections import Counter
from pathlib import Path
from typing import Optional

import torch

from sharp.models import PredictorParams, create_predictor

LOGGER = logging.getLogger(__name__)

# Default checkpoint file name; also used to build the download URL below.
DEFAULT_CHECKPOINT_NAME = "sharp_2572gikvuh.pt"
MODEL_URL = f"https://ml-site.cdn-apple.com/models/sharp/{DEFAULT_CHECKPOINT_NAME}"

# Graph input/output names, shared by the export call and its dynamic_axes mapping.
_INPUT_NAMES = ["image", "disparity_factor"]
_OUTPUT_NAMES = ["mean_vectors", "singular_values", "quaternions", "colors", "opacities"]

# Block size passed to every 4-bit weight-only quantizer below.
_INT4_BLOCK_SIZE = 32

# 4-bit weight-only quantizers to try, best first, as (module path, class name).
# The first two are ONNX Runtime's MatMulNBits quantizers (newer and older module names);
# they produce blockwise integer INT4 weights. MatMulBnb4Quantizer produces NF4, which is
# not an integer format, so it is only used as a clearly-logged last resort.
_INT4_QUANTIZER_CANDIDATES = (
    ("onnxruntime.quantization.matmul_nbits_quantizer", "MatMulNBitsQuantizer"),
    ("onnxruntime.quantization.matmul_4bits_quantizer", "MatMul4BitsQuantizer"),
    ("onnxruntime.quantization.matmul_bnb4_quantizer", "MatMulBnb4Quantizer"),
)


def verify_model_precision(model_path: Path):
    """Log the dtype distribution of an ONNX model's initializers and its on-disk size.

    This is a static check used to confirm that a quantization pass really changed the
    stored weight types. It never raises: a missing ``onnx`` module or any inspection
    error is logged as a warning.

    Args:
        model_path: Path to the ``.onnx`` graph file. External-data files referenced by
            its initializers are resolved relative to this file's directory.
    """
    try:
        import onnx
    except ImportError:
        LOGGER.warning("未安装 'onnx' 模块,跳过权重精度验证。")
        return

    try:
        # Only tensor metadata (data_type / external_data entries) is needed, so skip
        # reading the external weight blobs, which can be several GB.
        model = onnx.load(str(model_path), load_external_data=False)

        # TensorProto.DataType values: FLOAT=1, UINT8=2, INT8=3, FLOAT16=10,
        # UINT4=21, INT4=22, FLOAT4E2M1=23.
        counter = Counter(
            onnx.TensorProto.DataType.Name(tensor.data_type)
            for tensor in model.graph.initializer
        )
        LOGGER.info(f"--- 权重类型分布报告 [{model_path.name}] ---")
        for dtype, count in counter.items():
            LOGGER.info(f" - {dtype}: {count} 个张量")

        file_size = os.path.getsize(model_path)
        LOGGER.info(f" [文件信息] 路径: {model_path}")
        LOGGER.info(f" [文件信息] 体积: {file_size / (1024*1024):.2f} MB")

        # The graph file alone under-reports models whose weights live in external data,
        # so also sum every distinct external file the initializers point to.
        base_dir = model_path.parent
        external_files = {
            entry.value
            for tensor in model.graph.initializer
            if tensor.data_location == onnx.TensorProto.EXTERNAL
            for entry in tensor.external_data
            if entry.key == "location"
        }
        if external_files:
            external_size = sum(
                os.path.getsize(base_dir / location)
                for location in external_files
                if (base_dir / location).is_file()
            )
            total_mb = (file_size + external_size) / (1024 * 1024)
            LOGGER.info(f" [文件信息] 外部数据: {len(external_files)} 个文件, 合计体积: {total_mb:.2f} MB")
        LOGGER.info("------------------------------------------------")
    except Exception as e:
        LOGGER.warning(f"验证模型精度时发生异常: {e}")


def _load_state_dict(checkpoint_path: Optional[str]):
    """Return the predictor state dict from an explicit path, a local copy, or MODEL_URL.

    When ``checkpoint_path`` is None, the default checkpoint is looked up in the working
    directory, then ``./data``, then the torch hub cache, and downloaded only if none exist.
    """
    if checkpoint_path is not None:
        return torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    search_paths = (
        Path(DEFAULT_CHECKPOINT_NAME),
        Path("data") / DEFAULT_CHECKPOINT_NAME,
        Path(torch.hub.get_dir()) / "checkpoints" / DEFAULT_CHECKPOINT_NAME,
    )
    found_path = next((p for p in search_paths if p.exists()), None)
    if found_path is None:
        return torch.hub.load_state_dict_from_url(MODEL_URL, progress=True, map_location="cpu")

    LOGGER.info(f"在本地找到模型权重: {found_path},加载本地权重...")
    return torch.load(found_path, map_location="cpu", weights_only=True)


def export_onnx(output_path: Path, checkpoint_path: Optional[str] = None, internal_shape: int = 1536, is_fp16: bool = False):
    """Build the SHARP predictor, load its weights and export it to ONNX (opset 19).

    Args:
        output_path: Destination ``.onnx`` file.
        checkpoint_path: Local ``.pt`` state dict. If None, the default checkpoint is
            searched for locally and downloaded from ``MODEL_URL`` as a last resort.
        internal_shape: Side length of the square dummy image used for tracing.
        is_fp16: Export with half-precision parameters and inputs.
    """
    device = torch.device("cpu")
    LOGGER.info("正在创建预测器模型...")
    predictor = create_predictor(PredictorParams())

    # The state dict is a temporary here, so it can be freed as soon as it is copied in.
    predictor.load_state_dict(_load_state_dict(checkpoint_path))

    # Cast after the weights are in place; FP16 halves the size of the exported initializers.
    if is_fp16:
        LOGGER.info("正在将模型转换为 FP16 (半精度) 以直接将其缩减进单一 < 2GB 的文件中...")
        predictor = predictor.half()
    predictor.eval()
    predictor.to(device)

    # Dummy inputs must match the parameter dtype or tracing hits a dtype mismatch.
    input_dtype = torch.float16 if is_fp16 else torch.float32
    dummy_image = torch.randn(1, 3, internal_shape, internal_shape, device=device).to(input_dtype)
    dummy_disparity = torch.tensor([1.0], device=device).to(input_dtype)
    dummy_inputs = (dummy_image, dummy_disparity)

    LOGGER.info("正在导出为 ONNX 格式...")
    torch.onnx.export(
        predictor,
        dummy_inputs,
        str(output_path),
        export_params=True,
        opset_version=19,
        do_constant_folding=True,
        input_names=_INPUT_NAMES,
        output_names=_OUTPUT_NAMES,
        keep_initializers_as_inputs=False,
        dynamic_axes={name: {0: "batch_size"} for name in _INPUT_NAMES + _OUTPUT_NAMES},
    )
    precision_label = "FP16 (半精度)" if is_fp16 else "全精度"
    LOGGER.info(f"成功将{precision_label} ONNX 模型导出至: {output_path}")


def _prepare_quant_source(output_path: Path) -> Path:
    """Return the model file that the quantizers should read.

    Re-saves the base model with its weights moved into one external-data file
    (``<stem>_pre.onnx`` + ``<stem>_pre.onnx.data``). That copy is reused on later runs
    unless the base model is newer, in which case it is regenerated. If the re-save
    fails, the base model itself is returned.
    """
    import onnx

    preprocessed_path = output_path.with_name(f"{output_path.stem}_pre.onnx")
    data_path = preprocessed_path.with_name(preprocessed_path.name + ".data")

    is_fresh = (
        preprocessed_path.exists()
        and preprocessed_path.stat().st_mtime >= output_path.stat().st_mtime
    )
    if is_fresh:
        return preprocessed_path

    LOGGER.info("正在将基础模型转换为外部数据格式 (已跳过导致崩溃的形状推导模式)...")
    try:
        # onnx appends tensors to an existing external-data file rather than truncating it,
        # so a leftover file from a stale or interrupted run must be removed first.
        data_path.unlink(missing_ok=True)
        model_raw = onnx.load(str(output_path))
        onnx.save(
            model_raw,
            str(preprocessed_path),
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=data_path.name,
            size_threshold=1024,
            convert_attribute=True,
        )
        return preprocessed_path
    except Exception as e:
        LOGGER.warning(f"预分流模型保存失败: {e}")
        return output_path


def _quantize_int8(model_source: Path, output_path: Path) -> None:
    """Write ``<stem>_int8.onnx`` using ONNX Runtime dynamic quantization (QUInt8 weights)."""
    from onnxruntime.quantization import QuantType, quantize_dynamic

    int8_path = output_path.with_name(f"{output_path.stem}_int8.onnx")
    LOGGER.info("正在应用动态 INT8 量化...")
    quantize_dynamic(str(model_source), str(int8_path), weight_type=QuantType.QUInt8)
    verify_model_precision(int8_path)


def _extract_model_proto(quantizer):
    """Return the ModelProto held by a quantizer after ``process()``.

    ``quantizer.model`` may be a wrapper exposing ``.model`` or the proto itself;
    both shapes are handled, and the quantizer is returned as-is if it has neither.
    """
    wrapped = getattr(quantizer, "model", quantizer)
    return getattr(wrapped, "model", wrapped)


def _quantize_int4(model_source: Path, output_path: Path) -> None:
    """Write ``<stem>_int4.onnx`` with the best available 4-bit weight-only quantizer.

    Candidates from ``_INT4_QUANTIZER_CANDIDATES`` are tried in order. A candidate
    that is missing or fails is skipped, and the first success stops the search.
    """
    import onnx
    import onnxruntime.quantization as oq

    LOGGER.info("正在准备进行真正的 INT4 权重压缩 (WOQ)...")
    quantized_output = output_path.with_name(f"{output_path.stem}_int4.onnx")

    found_any = False
    for module_name, class_name in _INT4_QUANTIZER_CANDIDATES:
        try:
            quantizer_cls = getattr(importlib.import_module(module_name), class_name)
        except (ImportError, AttributeError):
            continue
        found_any = True

        if class_name == "MatMulBnb4Quantizer":
            LOGGER.warning("[INT4] 回退到 BnB 4-bit 量化 (注意: 输出为 NF4 格式, 并非整数 INT4)。")
            # NF4 is exposed as a class constant; 1 is its value if the constant is absent.
            kwargs = {"quant_type": getattr(quantizer_cls, "NF4", 1), "block_size": _INT4_BLOCK_SIZE}
        else:
            kwargs = {"block_size": _INT4_BLOCK_SIZE, "is_symmetric": True}

        LOGGER.info(f"[INT4] 找到量化类/函数: {module_name}.{class_name}")
        try:
            # Reload for every attempt: quantizers rewrite the graph in place, so a
            # failed attempt could leave a half-modified model behind.
            model = onnx.load(str(model_source))
            quantizer = quantizer_cls(model, **kwargs)
            quantizer.process()
            LOGGER.info("正在保存单体 INT4 模型...")
            onnx.save(_extract_model_proto(quantizer), str(quantized_output))
        except Exception as e:
            LOGGER.warning(f"[INT4] {class_name} 失败: {e}")
            continue

        LOGGER.info("[INT4] 导出成功!")
        verify_model_precision(quantized_output)
        return

    if found_any:
        LOGGER.error("[INT4] 执行失败: 所有可用量化器均未成功。")
    else:
        LOGGER.error(f"[INT4] 未找到有效量化器。所有属性: {dir(oq)}")


def main():
    """CLI entry point: export the base ONNX model, then optionally quantize it.

    With ``--int8`` / ``--int4``, an existing base model at ``--output`` is reused
    instead of being re-exported.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
    parser = argparse.ArgumentParser(description="将 ML-SHARP 模型导出为多种精度的 ONNX 格式。")
    parser.add_argument("--output", type=str, default="sharp.onnx", help="导出的基础 ONNX 模型保存路径。")
    parser.add_argument("--checkpoint", type=str, default=None, help="本地 .pt 权重文件路径。")
    parser.add_argument("--internal-shape", type=int, default=1536, help="导出时使用的正方形输入图像边长 (像素)。")
    parser.add_argument("--fp16", action="store_true", help="启用 FP16 半精度导出。")
    parser.add_argument("--int8", action="store_true", help="应用动态 INT8 量化。")
    parser.add_argument("--int4", action="store_true", help="应用真正的 INT4 仅权重量化 (WOQ)。")
    args = parser.parse_args()

    # With --fp16 and no explicit --output, write to a distinct default file name.
    if args.fp16 and args.output == "sharp.onnx":
        args.output = "sharp_fp16.onnx"
    output_path = Path(args.output)

    is_quant_requested = args.int8 or args.int4
    if is_quant_requested and output_path.exists():
        LOGGER.info(f"检测到基础模型 '{output_path}' 已存在,跳过导出并开始验证。")
    else:
        export_onnx(output_path, args.checkpoint, internal_shape=args.internal_shape, is_fp16=args.fp16)
    verify_model_precision(output_path)

    if not is_quant_requested:
        return

    try:
        import onnx  # noqa: F401  (availability check only)
        import onnxruntime.quantization  # noqa: F401  (availability check only)
    except ImportError:
        LOGGER.error("缺失依赖项:请安装 onnx 和 onnxruntime。")
        return

    model_source = _prepare_quant_source(output_path)

    # 1. INT8 dynamic quantization
    if args.int8:
        _quantize_int8(model_source, output_path)

    # 2. INT4 weight-only quantization (WOQ)
    if args.int4:
        _quantize_int4(model_source, output_path)


if __name__ == "__main__":
    main()