import json
import os
import sys
import traceback

import torch
from safetensors.torch import load_file
| |
|
| | |
# Make the DA-2 source tree importable. The path is built from the current
# working directory, so the script has to be launched from the folder that
# contains DA-2-repo/.
repo_src_dir = os.path.join(os.getcwd(), 'DA-2-repo/src')
sys.path.append(repo_src_dir)

# Fail fast with a readable message if the model code cannot be found.
try:
    from da2.model.spherevit import SphereViT
except ImportError as import_err:
    print(f"Error importing SphereViT: {import_err}")
    sys.exit(1)
| |
|
| | |
# Load the inference configuration shipped with the DA-2 repository
# (path is relative to the current working directory). JSON is UTF-8 by
# specification, so decode it explicitly instead of relying on the
# platform's locale encoding.
config_path = 'DA-2-repo/configs/infer.json'
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

# Fixed export resolution. W is exactly 2 * H (presumably the 2:1
# equirectangular panorama aspect ratio -- TODO confirm), and both are
# multiples of 14 (546 = 39 * 14, 1092 = 78 * 14), presumably to match a ViT
# patch size of 14 -- confirm against SphereViT.
H, W = 546, 1092
# Pin the minimum and maximum pixel budget to the same value; presumably this
# keeps the model from rescaling inputs away from H x W -- verify how
# SphereViT consumes these keys.
config['inference']['min_pixels'] = H * W
config['inference']['max_pixels'] = H * W
| |
|
print(f"Initializing model with input size {W}x{H}...")
# Build the network from the patched inference config; checkpoint weights
# are loaded separately below.
model = SphereViT(config)

print("Loading weights from model.safetensors...")
try:
    weights = load_file('model.safetensors')
    # strict=False tolerates partial mismatches between checkpoint and model,
    # so show a few offending names: a bare count cannot reveal, for example,
    # a systematic key-prefix mismatch.
    missing, unexpected = model.load_state_dict(weights, strict=False)
    if missing:
        print(f"Missing keys: {len(missing)}")
        print(f"  first few: {', '.join(missing[:5])}")
    if unexpected:
        print(f"Unexpected keys: {len(unexpected)}")
        print(f"  first few: {', '.join(unexpected[:5])}")
    # If every checkpoint tensor is "unexpected", nothing was loaded at all
    # and the export below would ship the parameters SphereViT(config) was
    # constructed with instead of the checkpoint's. Route this through the
    # same error path as any other loading failure.
    if len(unexpected) == len(weights):
        raise RuntimeError(
            "none of the tensors in model.safetensors match SphereViT parameters"
        )
except Exception as e:
    print(f"Error loading weights: {e}")
    sys.exit(1)
| |
|
print("Exporting model in FP32 (full precision)...")
# train(False) is PyTorch's documented equivalent of eval(): it puts
# mode-dependent layers (e.g. dropout, batch norm) into inference behavior
# before the graph is traced.
model.train(False)

# The tracer only needs a tensor of the right shape; values are random.
# Layout: batch of 1, 3 channels, H x W pixels (default float32 dtype).
input_shape = (1, 3, H, W)
dummy_input = torch.randn(*input_shape)
| |
|
| | |
output_file = "onnx/model.onnx"
print(f"Exporting to {output_file}...")
try:
    # torch.onnx.export does not create missing parent directories, so make
    # sure the target folder exists (e.g. on a fresh checkout).
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        opset_version=17,
        input_names=["pixel_values"],
        output_names=["predicted_depth"],
        # Only the batch dimension is declared dynamic; height and width stay
        # fixed at the traced H x W.
        dynamic_axes={
            "pixel_values": {0: "batch_size"},
            "predicted_depth": {0: "batch_size"},
        },
        export_params=True,  # write the loaded parameters into the .onnx file
        do_constant_folding=True,
        verbose=False,
    )
except Exception as e:
    # Exit non-zero, like the import and weight-loading failures, so callers
    # (shell scripts, CI) can detect that no model was produced.
    print(f"Error exporting to ONNX: {e}")
    traceback.print_exc()
    sys.exit(1)

print(f"Successfully exported to {output_file}")

# Dynamic INT8 weight quantization is best-effort: onnxruntime is imported
# lazily and any failure is reported without invalidating the FP32 export.
quantized_output_file = "onnx/model_quantized.onnx"
try:
    from onnxruntime.quantization import quantize_dynamic, QuantType
    print(f"Quantizing model to {quantized_output_file}...")
    quantize_dynamic(
        output_file,
        quantized_output_file,
        weight_type=QuantType.QInt8,
    )
    print(f"Successfully quantized to {quantized_output_file}")
except Exception as qe:
    print(f"Error during quantization: {qe}")
    traceback.print_exc()
| |
|