Kyle Pearson committed on
Commit
2cda5f8
·
1 Parent(s): 9bef2af

inference script

Browse files
Files changed (1) hide show
  1. inference_onnx.py +302 -0
inference_onnx.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """ONNX Inference Script for SHARP Model.
3
+
4
+ Loads an ONNX model (fp32 or fp16), runs inference on an input image,
5
+ and exports the result as a PLY file.
6
+
7
+ Usage:
8
+ python inference_onnx.py -m sharp.onnx -i test.png -o output.ply
9
+ python inference_onnx.py -m sharp_inline_fp16.onnx -i test.png -o output.ply -d 0.5
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import logging
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import onnxruntime as ort
20
+ from PIL import Image
21
+ from plyfile import PlyData, PlyElement
22
+
23
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
24
+ LOGGER = logging.getLogger(__name__)
25
+
26
+ DEFAULT_HEIGHT = 1536
27
+ DEFAULT_WIDTH = 1536
28
+
29
+
30
def linear_to_srgb(linear: float) -> float:
    """Encode a linear-light color component into sRGB.

    Applies the standard sRGB transfer function: a linear segment below
    the 0.0031308 knee, and a 1/2.4 gamma curve with offset above it.
    """
    knee = 0.0031308
    if linear > knee:
        return 1.055 * linear ** (1.0 / 2.4) - 0.055
    return 12.92 * linear
34
+
35
+
36
def rgb_to_sh(rgb: float) -> float:
    """Convert a [0, 1] color value into a degree-0 spherical-harmonic coefficient.

    The DC term of the SH basis is Y_0^0 = 1 / sqrt(4 * pi); the color is
    centered at 0.5 before dividing by that constant.
    """
    y00 = 1.0 / np.sqrt(4.0 * np.pi)
    centered = rgb - 0.5
    return centered / y00
39
+
40
+
41
def inverse_sigmoid(x: float) -> float:
    """Logit: the inverse of the sigmoid function, clamped for numerical safety.

    Input is clipped to [1e-6, 1 - 1e-6] so values at or beyond the
    (0, 1) boundaries cannot produce infinities. Works element-wise on
    NumPy arrays as well as scalars.
    """
    safe = np.clip(x, 1e-6, 1.0 - 1e-6)
    odds = safe / (1.0 - safe)
    return np.log(odds)
44
+
45
+
46
def preprocess_image(image_path: str | Path, target_size: tuple[int, int] = (DEFAULT_HEIGHT, DEFAULT_WIDTH)):
    """Load an image and prepare it as an NCHW float32 batch for ONNX inference.

    Args:
        image_path: Path to the input image file.
        target_size: (height, width) the network expects; the image is
            bilinearly resized to this size if it differs.

    Returns:
        Tuple of:
        - image tensor of shape (1, 3, H, W), float32, values in [0, 1]
        - focal length in pixels (heuristic: the original image width)
        - original (width, height) size as reported by PIL
    """
    image_path = Path(image_path)
    target_h, target_w = target_size

    img = Image.open(image_path)
    original_size = img.size  # PIL reports (width, height)
    focal_length_px = original_size[0]

    # Normalize to 3-channel RGB up front so grayscale, palette, and RGBA
    # inputs all work. (The previous version only sliced off an alpha
    # channel and crashed on 2-D grayscale arrays at `img_np.shape[2]`.)
    if img.mode != "RGB":
        img = img.convert("RGB")

    if img.size != (target_w, target_h):
        img = img.resize((target_w, target_h), Image.BILINEAR)

    img_np = np.array(img, dtype=np.float32) / 255.0  # HWC in [0, 1]
    img_np = np.transpose(img_np, (2, 0, 1))          # HWC -> CHW
    img_np = np.expand_dims(img_np, axis=0)           # CHW -> NCHW (batch of 1)

    LOGGER.info(f"Loaded image: {image_path}, original size: {original_size}")
    LOGGER.info(f"Preprocessed shape: {img_np.shape}, range: [{img_np.min():.4f}, {img_np.max():.4f}]")

    return img_np, float(focal_length_px), original_size
70
+
71
+
72
def run_inference(onnx_path: str | Path, image: np.ndarray, disparity_factor: float = 1.0) -> dict[str, np.ndarray]:
    """Run ONNX inference on a preprocessed NCHW image.

    Args:
        onnx_path: Path to the .onnx model (fp32 or fp16).
        image: Preprocessed batch of shape (1, 3, H, W); cast to float32.
        disparity_factor: Scalar forwarded to the model's
            "disparity_factor" input.

    Returns:
        Dict mapping Gaussian attribute names to arrays. A single
        concatenated model output is split channel-wise as 3+3+4+3+1;
        exactly five outputs are mapped positionally; any other layout
        falls back to the model's own declared output names.

    Raises:
        RuntimeError: If the model cannot be loaded by any provider, or
            appears to be an incorrectly converted mixed-precision model.
    """
    onnx_path = Path(onnx_path)

    LOGGER.info(f"Loading ONNX model: {onnx_path}")

    # Try the default execution providers first, then fall back to CPU only.
    try:
        session = ort.InferenceSession(str(onnx_path))
    except Exception as e:
        error_msg = str(e)
        if "tensor(float16)" in error_msg and "tensor(float)" in error_msg:
            LOGGER.error("FP16 model has mixed float16/float32 types. This model was converted incorrectly.")
            LOGGER.error("For FP16 inference on Apple Silicon, use the Core ML model (sharp.mlpackage) instead.")
            LOGGER.error("Or regenerate the ONNX model with proper FP16 conversion.")
            # Chain the original load error for debuggability.
            raise RuntimeError(f"Invalid FP16 model: {error_msg}") from e
        try:
            session = ort.InferenceSession(str(onnx_path), providers=['CPUExecutionProvider'])
        except Exception as cpu_e:
            raise RuntimeError(f"Failed to load ONNX model: {cpu_e}") from cpu_e

    input_names = [inp.name for inp in session.get_inputs()]
    output_names = [out.name for out in session.get_outputs()]

    LOGGER.info(f"Input names: {input_names}")
    LOGGER.info(f"Output names: {output_names}")

    inputs = {
        "image": image.astype(np.float32),
        "disparity_factor": np.array([disparity_factor], dtype=np.float32)
    }

    LOGGER.info("Running inference...")
    raw_outputs = session.run(None, inputs)

    # Canonical names for the five Gaussian parameter groups (defined once;
    # the previous version duplicated this list in two branches).
    names = [
        "mean_vectors_3d_positions",
        "singular_values_scales",
        "quaternions_rotations",
        "colors_rgb_linear",
        "opacities_alpha_channel",
    ]

    outputs: dict[str, np.ndarray] = {}
    if len(raw_outputs) == 1:
        # Single concatenated tensor: split along the last (channel) axis.
        concat = raw_outputs[0]
        sizes = [3, 3, 4, 3, 1]
        start = 0
        for name, size in zip(names, sizes):
            outputs[name] = concat[:, :, start:start + size]
            start += size
    elif len(raw_outputs) == 5:
        outputs = dict(zip(names, raw_outputs))
    else:
        # Unknown layout: keep the model's declared output names.
        outputs = dict(zip(output_names, raw_outputs))

    for name, arr in outputs.items():
        LOGGER.info(f"  {name}: shape {arr.shape}")

    return outputs
142
+
143
+
144
def export_ply(outputs: dict[str, np.ndarray], output_path: str | Path,
               focal_length_px: float, image_shape: tuple[int, int],
               decimation: float = 1.0) -> None:
    """Export predicted Gaussians to a binary PLY file.

    Args:
        outputs: Inference outputs keyed by attribute name (batched).
        output_path: Destination .ply path.
        focal_length_px: Focal length in pixels for the intrinsic matrix.
        image_shape: Original image size; unpacked as (h, w) below, but it
            comes from PIL's ``Image.size`` which is (width, height) —
            NOTE(review): confirm intended order against the caller.
        decimation: Fraction of Gaussians to keep, in (0, 1]; 1.0 keeps
            all. Gaussians are ranked by (scale product * opacity).
    """
    output_path = Path(output_path)

    # Strip the batch dimension from every attribute.
    mean_vectors = outputs["mean_vectors_3d_positions"][0]
    singular_values = outputs["singular_values_scales"][0]
    quaternions = outputs["quaternions_rotations"][0]
    colors = outputs["colors_rgb_linear"][0]
    # Flatten opacities to 1-D. The concat-split path yields (N, 1); left
    # as-is, `scale_product * opacities` broadcasts (N,) x (N, 1) into an
    # (N, N) matrix, corrupting the decimation index selection, and the
    # structured-array 'opacity' field assignment below fails.
    opacities = outputs["opacities_alpha_channel"][0].reshape(-1)

    num_gaussians = mean_vectors.shape[0]
    LOGGER.info(f"Exporting {num_gaussians} Gaussians to PLY")

    if decimation < 1.0:
        # Importance = product of the three scales (volume proxy) * opacity.
        log_scales = np.log(np.maximum(singular_values, 1e-10))
        scale_product = np.exp(np.sum(log_scales, axis=1))
        importance = scale_product * opacities

        indices = np.argsort(-importance)
        keep_count = max(1, int(num_gaussians * decimation))
        keep_indices = np.sort(indices[:keep_count])  # restore original order

        LOGGER.info(f"Decimating: keeping {keep_count} of {num_gaussians} ({decimation * 100:.1f}%)")

        mean_vectors = mean_vectors[keep_indices]
        singular_values = singular_values[keep_indices]
        quaternions = quaternions[keep_indices]
        colors = colors[keep_indices]
        opacities = opacities[keep_indices]
        num_gaussians = keep_count

    vertex_data = np.zeros(num_gaussians, dtype=[
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('f_dc_0', 'f4'), ('f_dc_1', 'f4'), ('f_dc_2', 'f4'),
        ('opacity', 'f4'),
        ('scale_0', 'f4'), ('scale_1', 'f4'), ('scale_2', 'f4'),
        ('rot_0', 'f4'), ('rot_1', 'f4'), ('rot_2', 'f4'), ('rot_3', 'f4')
    ])

    vertex_data['x'] = mean_vectors[:, 0]
    vertex_data['y'] = mean_vectors[:, 1]
    vertex_data['z'] = mean_vectors[:, 2]

    # Linear -> sRGB -> SH degree-0, vectorized over all Gaussians
    # (replaces a per-Gaussian Python loop that cost millions of
    # interpreter calls). The maximum() guard keeps np.power NaN-free if
    # the model ever emits a slightly negative linear color.
    lin = np.asarray(colors, dtype=np.float64)
    srgb = np.where(lin <= 0.0031308,
                    lin * 12.92,
                    1.055 * np.power(np.maximum(lin, 0.0), 1.0 / 2.4) - 0.055)
    sh = (srgb - 0.5) * np.sqrt(4.0 * np.pi)  # divide by Y00 = 1/sqrt(4*pi)
    vertex_data['f_dc_0'] = sh[:, 0]
    vertex_data['f_dc_1'] = sh[:, 1]
    vertex_data['f_dc_2'] = sh[:, 2]

    # inverse_sigmoid is element-wise (np.clip/np.log), so this vectorizes.
    vertex_data['opacity'] = inverse_sigmoid(opacities)

    vertex_data['scale_0'] = np.log(np.maximum(singular_values[:, 0], 1e-10))
    vertex_data['scale_1'] = np.log(np.maximum(singular_values[:, 1], 1e-10))
    vertex_data['scale_2'] = np.log(np.maximum(singular_values[:, 2], 1e-10))

    vertex_data['rot_0'] = quaternions[:, 0]
    vertex_data['rot_1'] = quaternions[:, 1]
    vertex_data['rot_2'] = quaternions[:, 2]
    vertex_data['rot_3'] = quaternions[:, 3]

    vertex_element = PlyElement.describe(vertex_data, 'vertex')

    # Extrinsic: 4x4 identity matrix stored as a flat list property.
    extrinsic_data = np.zeros(1, dtype=[('extrinsic', 'f4', (16,))])
    extrinsic_data['extrinsic'][0] = [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]
    extrinsic_element = PlyElement.describe(extrinsic_data, 'extrinsic')

    img_h, img_w = image_shape
    # Intrinsic: row-major 3x3 pinhole matrix with the principal point at
    # the image center.
    intrinsic_data = np.zeros(1, dtype=[('intrinsic', 'f4', (9,))])
    intrinsic_data['intrinsic'][0] = [focal_length_px, 0, img_w / 2, 0, focal_length_px, img_h / 2, 0, 0, 1]
    intrinsic_element = PlyElement.describe(intrinsic_data, 'intrinsic')

    # Image size: (width, height) as uint32.
    image_size_data = np.zeros(1, dtype=[('image_size', 'u4', (2,))])
    image_size_data['image_size'][0] = [img_w, img_h]
    image_size_element = PlyElement.describe(image_size_data, 'image_size')

    # Frame: (frame count, Gaussians per frame) as int32.
    frame_data = np.zeros(1, dtype=[('frame', 'i4', (2,))])
    frame_data['frame'][0] = [1, num_gaussians]
    frame_element = PlyElement.describe(frame_data, 'frame')

    # 10th/90th percentile of disparity (1/z) for viewer depth scaling.
    z_safe = np.maximum(mean_vectors[:, 2], 1e-6)
    disparities = np.sort(1.0 / z_safe)
    disparity_10 = disparities[int(len(disparities) * 0.1)] if len(disparities) > 0 else 0.0
    disparity_90 = disparities[int(len(disparities) * 0.9)] if len(disparities) > 0 else 1.0
    disparity_data = np.zeros(1, dtype=[('disparity', 'f4', (2,))])
    disparity_data['disparity'][0] = [disparity_10, disparity_90]
    disparity_element = PlyElement.describe(disparity_data, 'disparity')

    # Color space flag: single uchar (1 — meaning defined by the consumer).
    color_space_data = np.zeros(1, dtype=[('color_space', 'u1')])
    color_space_data['color_space'][0] = 1
    color_space_element = PlyElement.describe(color_space_data, 'color_space')

    # Format version triple.
    version_data = np.zeros(1, dtype=[('version', 'u1', (3,))])
    version_data['version'][0] = [1, 5, 0]
    version_element = PlyElement.describe(version_data, 'version')

    PlyData([
        vertex_element,
        extrinsic_element,
        intrinsic_element,
        image_size_element,
        frame_element,
        disparity_element,
        color_space_element,
        version_element
    ], text=False).write(str(output_path))

    LOGGER.info(f"Saved PLY with {num_gaussians} Gaussians to {output_path}")
272
+
273
+
274
def main():
    """Command-line entry point: preprocess an image, run ONNX inference,
    and export the resulting Gaussians as a PLY file."""
    parser = argparse.ArgumentParser(
        description="ONNX Inference for SHARP - Generate 3D Gaussians from an image"
    )
    parser.add_argument("-m", "--model", type=str, required=True, help="Path to ONNX model file")
    parser.add_argument("-i", "--input", type=str, required=True, help="Path to input image")
    parser.add_argument("-o", "--output", type=str, required=True, help="Path to output file (.ply)")
    parser.add_argument("-d", "--decimate", type=float, default=1.0,
                        help="Decimation ratio 0.0-1.0 (default: 1.0 = keep all)")
    parser.add_argument("--disparity-factor", type=float, default=1.0,
                        help="Disparity factor for depth conversion (default: 1.0)")
    cli = parser.parse_args()

    # Image file -> NCHW tensor plus camera metadata.
    image, focal_length_px, image_shape = preprocess_image(cli.input)

    # Tensor -> dict of per-Gaussian attribute arrays.
    gaussians = run_inference(cli.model, image, cli.disparity_factor)

    # Gaussians -> binary PLY on disk.
    export_ply(gaussians, cli.output, focal_length_px, image_shape, cli.decimate)


if __name__ == "__main__":
    main()