Delete triposplat.py

#4
Files changed (1) hide show
  1. triposplat.py +0 -598
triposplat.py DELETED
@@ -1,598 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
- import safetensors.torch
5
- from PIL import Image, ImageFilter
6
- from torchvision import transforms
7
- from tqdm.auto import tqdm
8
-
9
- from model import (
10
- DinoV3ViT, Flux2VAEEncoder, BiRefNet,
11
- OctreeProbabilityFixedlenDecoder, ElasticGaussianFixedlenDecoder,
12
- LatentSeqMMFlowModel, OctreeGaussianDecoder,
13
- )
14
-
15
-
16
- # ---------------------------------------------------------------------------
17
- # Gaussian
18
- # ---------------------------------------------------------------------------
19
-
20
- class Gaussian:
21
- def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
22
- scaling_bias: float = 0.01, opacity_bias: float = 0.1,
23
- scaling_activation: str = "exp", device='cuda'):
24
- self.sh_degree = sh_degree
25
- self.mininum_kernel_size = mininum_kernel_size
26
- self.scaling_bias = scaling_bias
27
- self.opacity_bias = opacity_bias
28
- self.device = device
29
- self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
30
-
31
- if scaling_activation == "exp":
32
- self._scaling_activation = torch.exp
33
- self._inverse_scaling_activation = torch.log
34
- elif scaling_activation == "softplus":
35
- self._scaling_activation = F.softplus
36
- self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
37
-
38
- self._opacity_activation = torch.sigmoid
39
- self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
40
-
41
- self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
42
- self.rots_bias = torch.zeros(4, device=self.device)
43
- self.rots_bias[0] = 1
44
- self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
45
-
46
- self._storage = {}
47
-
48
- def _get_store(self, name):
49
- return self._storage.get(name)
50
-
51
- def _set_store(self, name, value):
52
- self._storage[name] = value
53
-
54
- @property
55
- def _xyz(self):
56
- return self._get_store("_xyz")
57
- @_xyz.setter
58
- def _xyz(self, value):
59
- if value is None:
60
- self._set_store("_xyz", None); self._set_store("xyz", None); return
61
- self._set_store("_xyz", value)
62
- self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
63
-
64
- @property
65
- def get_xyz(self):
66
- return self._get_store("xyz")
67
-
68
- @property
69
- def _features_dc(self):
70
- return self._get_store("_features_dc")
71
- @_features_dc.setter
72
- def _features_dc(self, value):
73
- self._set_store("_features_dc", value)
74
-
75
- @property
76
- def _opacity(self):
77
- return self._get_store("_opacity")
78
- @_opacity.setter
79
- def _opacity(self, value):
80
- if value is None:
81
- self._set_store("_opacity", None); self._set_store("opacity", None); return
82
- self._set_store("_opacity", value)
83
- self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
84
-
85
- @property
86
- def get_opacity(self):
87
- return self._get_store("opacity")
88
-
89
- @property
90
- def _scaling(self):
91
- return self._get_store("_scaling")
92
- @_scaling.setter
93
- def _scaling(self, value):
94
- if value is None:
95
- self._set_store("_scaling", None); self._set_store("scaling", None); return
96
- self._set_store("_scaling", value)
97
- s = self._scaling_activation(value + self.scale_bias)
98
- s = torch.square(s) + self.mininum_kernel_size ** 2
99
- self._set_store("scaling", torch.sqrt(s))
100
-
101
- @property
102
- def get_scaling(self):
103
- return self._get_store("scaling")
104
-
105
- @property
106
- def _rotation(self):
107
- return self._get_store("_rotation")
108
- @_rotation.setter
109
- def _rotation(self, value):
110
- self._set_store("_rotation", value)
111
-
112
- def construct_list_of_attributes(self):
113
- l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
114
- dc = self._features_dc
115
- for i in range(dc.shape[1] * dc.shape[2]):
116
- l.append(f'f_dc_{i}')
117
- l.append('opacity')
118
- for i in range(self._scaling.shape[1]):
119
- l.append(f'scale_{i}')
120
- for i in range(self._rotation.shape[1]):
121
- l.append(f'rot_{i}')
122
- return l
123
-
124
- _DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
125
-
126
- def _get_ply_data(self, transform=None):
127
- xyz = self.get_xyz.detach().cpu().numpy()
128
- normals = np.zeros_like(xyz)
129
- f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
130
- opacities = self._inverse_opacity_activation(self.get_opacity).detach().cpu().numpy()
131
- scale = torch.log(self.get_scaling).detach().cpu().numpy()
132
- rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
133
- if transform is not None:
134
- transform = np.array(transform)
135
- xyz = np.matmul(xyz, transform.T)
136
- R_mat = _quat_to_matrix(rotation)
137
- R_mat = np.matmul(transform, R_mat)
138
- rotation = _matrix_to_quat(R_mat)
139
- return xyz, normals, f_dc, opacities, scale, rotation
140
-
141
- def _transformed_xyz_rot(self, transform=None):
142
- if transform is None:
143
- transform = self._DEFAULT_TRANSFORM
144
- transform = np.array(transform, dtype=np.float32)
145
- xyz = self.get_xyz.detach().cpu().numpy().astype(np.float32)
146
- rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
147
- xyz = np.matmul(xyz, transform.T)
148
- R_mat = _quat_to_matrix(rotation)
149
- R_mat = np.matmul(transform, R_mat)
150
- rotation = _matrix_to_quat(R_mat)
151
- return xyz, rotation
152
-
153
- def to_ply_bytes(self, transform=None) -> bytes:
154
- if transform is None:
155
- transform = self._DEFAULT_TRANSFORM
156
- xyz, normals, f_dc, opacities, scale, rotation = self._get_ply_data(transform=transform)
157
- dtype_full = [(attr, 'f4') for attr in self.construct_list_of_attributes()]
158
- elements = np.empty(xyz.shape[0], dtype=dtype_full)
159
- elements[:] = list(map(tuple, np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)))
160
- return _binary_ply_bytes(elements, dtype_full)
161
-
162
- def to_splat_bytes(self, transform=None) -> bytes:
163
- if transform is None:
164
- transform = self._DEFAULT_TRANSFORM
165
- xyz, rotation = self._transformed_xyz_rot(transform=transform)
166
- scale = self.get_scaling.detach().cpu().numpy().astype(np.float32)
167
- opacity = self.get_opacity.detach().cpu().numpy()
168
- f_dc = self._features_dc.detach().cpu().numpy()
169
- C0 = 0.28209479177387814
170
- # .splat packs color as 4 bytes RGBA: RGB from the SH DC term, A from opacity.
171
- rgb = np.clip((f_dc[:, 0, :] * C0 + 0.5) * 255, 0, 255).astype(np.uint8)
172
- alpha = np.clip(opacity[:, 0:1] * 255, 0, 255).astype(np.uint8)
173
- rgba = np.concatenate([rgb, alpha], axis=1)
174
- rot = rotation / np.linalg.norm(rotation, axis=-1, keepdims=True)
175
- rot_u8 = np.clip(rot * 128 + 128, 0, 255).astype(np.uint8)
176
- order = np.argsort(-opacity[:, 0] * np.prod(scale, axis=-1))
177
- xyz, scale, rgba, rot_u8 = xyz[order], scale[order], rgba[order], rot_u8[order]
178
- # Per-splat record is exactly 32 bytes: xyz(12) + scale(12) + rgba(4) + rot(4).
179
- data = np.concatenate([
180
- xyz.astype(np.float32).view(np.uint8).reshape(-1, 12),
181
- scale.astype(np.float32).view(np.uint8).reshape(-1, 12),
182
- rgba.reshape(-1, 4),
183
- rot_u8.reshape(-1, 4),
184
- ], axis=1).reshape(-1)
185
- return data.tobytes()
186
-
187
- def save_ply(self, path, transform=None):
188
- with open(path, 'wb') as f:
189
- f.write(self.to_ply_bytes(transform=transform))
190
-
191
- def save_splat(self, path, transform=None):
192
- with open(path, 'wb') as f:
193
- f.write(self.to_splat_bytes(transform=transform))
194
-
195
-
196
- def _binary_ply_bytes(elements, dtype_full) -> bytes:
197
- num_vertices = len(elements)
198
- header = "ply\nformat binary_little_endian 1.0\n"
199
- header += f"element vertex {num_vertices}\n"
200
- type_map = {'f4': 'float', 'u1': 'uchar', 'i4': 'int'}
201
- for name, t in dtype_full:
202
- header += f"property {type_map.get(t, t)} {name}\n"
203
- header += "end_header\n"
204
- return header.encode('ascii') + elements.tobytes()
205
-
206
-
207
- def _quat_to_matrix(q):
208
- q = q / np.linalg.norm(q, axis=-1, keepdims=True)
209
- w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
210
- R = np.stack([
211
- 1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
212
- 2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
213
- 2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
214
- ], axis=-1).reshape(-1, 3, 3)
215
- return R
216
-
217
-
218
- def _matrix_to_quat(R):
219
- trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
220
- q = np.zeros((R.shape[0], 4), dtype=R.dtype)
221
- s = np.sqrt(np.maximum(trace + 1, 0)) * 2
222
- q[:, 0] = 0.25 * s
223
- q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / np.where(s != 0, s, 1)
224
- q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / np.where(s != 0, s, 1)
225
- q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / np.where(s != 0, s, 1)
226
- m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
227
- s1 = np.sqrt(np.maximum(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], 0)) * 2
228
- q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
229
- q[m01, 1] = 0.25 * s1[m01]
230
- q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
231
- q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
232
- m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
233
- s2 = np.sqrt(np.maximum(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], 0)) * 2
234
- q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
235
- q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
236
- q[m11, 2] = 0.25 * s2[m11]
237
- q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
238
- m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
239
- s3 = np.sqrt(np.maximum(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], 0)) * 2
240
- q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
241
- q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
242
- q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
243
- q[m21, 3] = 0.25 * s3[m21]
244
- return q / np.linalg.norm(q, axis=-1, keepdims=True)
245
-
246
-
247
- def _build_gaussians(decoder: ElasticGaussianFixedlenDecoder, points_pred: dict, pred: dict):
248
- x = points_pred
249
- offset = decoder._get_offset(pred['features'])
250
- h = pred["features"]
251
- ret = []
252
- for i in range(h.shape[0]):
253
- g = Gaussian(
254
- sh_degree=0,
255
- aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
256
- mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
257
- scaling_bias=decoder.rep_config['scaling_bias'],
258
- opacity_bias=decoder.rep_config['opacity_bias'],
259
- scaling_activation=decoder.rep_config['scaling_activation'],
260
- )
261
- _x = x["points"][i, :, None, :]
262
- for k, v in decoder.layout.items():
263
- if k == '_xyz':
264
- setattr(g, k, (offset[i] + _x).flatten(0, 1))
265
- elif k in ('_xyz_center', '_offset_scale'):
266
- continue
267
- else:
268
- feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
269
- setattr(g, k, feats * decoder.rep_config['lr'][k])
270
- ret.append(g)
271
- return ret
272
-
273
-
274
- # ---------------------------------------------------------------------------
275
- # Euler flow sampler
276
- # ---------------------------------------------------------------------------
277
-
278
- class FlowEulerCfgSampler:
279
- def __init__(self, sigma_min: float = 1e-5):
280
- self.sigma_min = sigma_min
281
-
282
- def _get_batch_size(self, x_t):
283
- return next(iter(x_t.values())).shape[0] if isinstance(x_t, dict) else x_t.shape[0]
284
-
285
- def _get_device(self, x_t):
286
- return next(iter(x_t.values())).device if isinstance(x_t, dict) else x_t.device
287
-
288
- def _inference_model(self, model, x_t, t, cond=None):
289
- batch = self._get_batch_size(x_t)
290
- device = self._get_device(x_t)
291
- t_scaled = torch.tensor([1000 * t] * batch, device=device, dtype=torch.float32)
292
- if isinstance(cond, dict):
293
- for k, v in cond.items():
294
- if isinstance(v, torch.Tensor) and v.shape[0] == 1 and batch > 1:
295
- cond[k] = v.repeat(batch, *([1] * (len(v.shape) - 1)))
296
- elif cond is not None and cond.shape[0] == 1 and batch > 1:
297
- cond = cond.repeat(batch, *([1] * (len(cond.shape) - 1)))
298
- return model(x_t, t_scaled, cond)
299
-
300
- def _cfg_prediction(self, model, x_t, t, cond, neg_cond, guidance_scale):
301
- # Diffusers-style convention: guidance_scale == 1 (or <= 1, or None) means no CFG —
302
- # only the conditional pass runs, halving the per-step cost. > 1 enables CFG and
303
- # blends as `pred = s * cond + (1 - s) * uncond = s * cond - (s - 1) * uncond`.
304
- pred_v = self._inference_model(model, x_t, t, cond)
305
- if isinstance(guidance_scale, dict):
306
- if not any(s > 1 for s in guidance_scale.values()):
307
- return pred_v
308
- neg_pred_v = self._inference_model(model, x_t, t, neg_cond)
309
- for key in pred_v:
310
- s = guidance_scale.get(key, 1.0)
311
- if s > 1:
312
- pred_v[key] = s * pred_v[key] - (s - 1) * neg_pred_v[key]
313
- return pred_v
314
- if guidance_scale is None or guidance_scale <= 1:
315
- return pred_v
316
- neg_pred_v = self._inference_model(model, x_t, t, neg_cond)
317
- for key in pred_v:
318
- pred_v[key] = guidance_scale * pred_v[key] - (guidance_scale - 1) * neg_pred_v[key]
319
- return pred_v
320
-
321
- @torch.no_grad()
322
- def sample(self, model, noise, cond, neg_cond, steps=50, shift=1.0,
323
- guidance_scale=None, show_progress=False, callback=None):
324
- sample = noise
325
- t_seq = shift * np.linspace(1, 0, steps + 1) / (1 + (shift - 1) * np.linspace(1, 0, steps + 1))
326
- t_pairs = list(zip(t_seq[:-1], t_seq[1:]))
327
- iterator = tqdm(t_pairs, desc="Sampling", total=steps) if show_progress else t_pairs
328
- for i, (t, t_prev) in enumerate(iterator):
329
- x_t = {k: v.clone() for k, v in sample.items()} if isinstance(sample, dict) else sample.clone()
330
- pred_v = self._cfg_prediction(model, x_t, t, cond, neg_cond, guidance_scale)
331
- dt = t - t_prev
332
- if isinstance(sample, dict):
333
- for key in sample:
334
- sample[key] = sample[key] - pred_v[key] * dt
335
- else:
336
- sample = sample - pred_v * dt
337
- if callback is not None:
338
- callback(i + 1, steps)
339
- return sample
340
-
341
-
342
- # ---------------------------------------------------------------------------
343
- # Component loaders
344
- # ---------------------------------------------------------------------------
345
-
346
- def _place(m, device, dtype):
347
- if device is not None or dtype is not None:
348
- m = m.to(device=device, dtype=dtype)
349
- return m.eval()
350
-
351
-
352
- def load_dinov3(path: str, device=None, dtype=None) -> DinoV3ViT:
353
- m = DinoV3ViT()
354
- m.load_safetensors(path)
355
- return _place(m, device, dtype)
356
-
357
-
358
- def load_vae_encoder(path: str, device=None, dtype=None) -> Flux2VAEEncoder:
359
- m = Flux2VAEEncoder()
360
- m.load_safetensors(path)
361
- return _place(m, device, dtype)
362
-
363
-
364
- def load_rmbg(path: str, device=None, dtype=None) -> BiRefNet:
365
- m = BiRefNet()
366
- m.load_safetensors(path)
367
- return _place(m, device, dtype)
368
-
369
-
370
- FLOW_MODEL_ARGS = dict(
371
- q_token_length=8192, in_channels=16, cam_channels=5, out_channels=16,
372
- model_channels=1024, cond_channels=1280, cond2_channels=128,
373
- num_refiner_blocks=2, num_blocks=24, num_heads=16, mlp_ratio=4,
374
- qk_rms_norm=True, share_mod=True, use_shift_table=True,
375
- )
376
-
377
-
378
- def load_flow_model(path: str, device=None, dtype=None) -> LatentSeqMMFlowModel:
379
- m = LatentSeqMMFlowModel(**FLOW_MODEL_ARGS)
380
- m.load_safetensors(path)
381
- return _place(m, device, dtype)
382
-
383
-
384
- OCTREE_DECODER_ARGS = dict(
385
- model_channels=1024, cond_channels=16,
386
- num_blocks=4, num_heads=16, mlp_ratio=4, share_mod=True,
387
- )
388
-
389
- GS_DECODER_ARGS = dict(
390
- in_channels=3, model_channels=1024, cond_channels=16,
391
- attn_mode="full", num_blocks=16, num_heads=16, mlp_ratio=4,
392
- use_learned_offset_scale=True, use_per_offset=True,
393
- representation_config=dict(
394
- lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
395
- perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
396
- filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
397
- scaling_activation="softplus",
398
- ),
399
- )
400
-
401
-
402
- def load_decoder(path: str, device=None, dtype=None) -> OctreeGaussianDecoder:
403
- m = OctreeGaussianDecoder(OCTREE_DECODER_ARGS, GS_DECODER_ARGS)
404
- m.load_safetensors(path)
405
- return _place(m, device, dtype)
406
-
407
-
408
- # ---------------------------------------------------------------------------
409
- # Pipeline stages
410
- # ---------------------------------------------------------------------------
411
-
412
- _CANVAS_SIZE = 1024
413
-
414
-
415
- def _image_to_pil(image) -> Image.Image:
416
- if isinstance(image, Image.Image):
417
- return image
418
- if isinstance(image, (str, bytes)) or hasattr(image, "__fspath__"):
419
- return Image.open(image)
420
- if isinstance(image, torch.Tensor):
421
- t = image.detach().cpu()
422
- if t.ndim == 4:
423
- assert t.shape[0] == 1, (
424
- f"batched image input is not supported (got B={t.shape[0]}); "
425
- "pass one image at a time"
426
- )
427
- t = t[0]
428
- arr = (t.clamp(0, 1) * 255).to(torch.uint8).numpy()
429
- mode = "RGBA" if arr.shape[-1] == 4 else "RGB"
430
- return Image.fromarray(arr, mode=mode)
431
- raise TypeError(f"unsupported image type: {type(image)}")
432
-
433
-
434
- def preprocess_image(image, rmbg: BiRefNet, erode_radius: int = 1) -> Image.Image:
435
- image = _image_to_pil(image)
436
- size = _CANVAS_SIZE
437
- w, h = image.size
438
- s = size / min(w, h)
439
- image = image.resize((max(1, int(round(w * s))), max(1, int(round(h * s)))), Image.LANCZOS)
440
- has_real_alpha = (image.mode == "RGBA"
441
- and np.array(image.getchannel(3), dtype=np.int32).min() < 255)
442
- if not has_real_alpha:
443
- image = rmbg.remove_background(image.convert("RGB"))
444
- if erode_radius > 0:
445
- image.putalpha(image.getchannel(3).filter(ImageFilter.MinFilter(2 * erode_radius + 1)))
446
- alpha = np.array(image.getchannel(3))
447
- ys, xs = np.nonzero(alpha)
448
- bbox = [xs.min(), ys.min(), xs.max(), ys.max()]
449
- cx, cy = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
450
- half = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 * 1.2
451
- image = image.crop([int(cx - half), int(cy - half), int(cx + half), int(cy + half)])
452
- image = image.resize((size, size), Image.LANCZOS)
453
- bg = Image.new("RGB", (size, size), (0, 0, 0))
454
- bg.paste(image, mask=image.split()[3])
455
- return bg
456
-
457
-
458
- _DINOV3_NORMALIZE = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
459
-
460
-
461
- @torch.no_grad()
462
- def encode_image(image: Image.Image, dinov3: DinoV3ViT, vae_encoder: Flux2VAEEncoder,
463
- generator: torch.Generator = None) -> dict:
464
- device = next(dinov3.parameters()).device
465
- img_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device=device, dtype=torch.float32)
466
- img_normed = _DINOV3_NORMALIZE(img_tensor)
467
- dinov3_dtype = next(dinov3.parameters()).dtype
468
- vae_dtype = next(vae_encoder.parameters()).dtype
469
- dinov3_feat = dinov3(pixel_values=img_normed.to(dinov3_dtype))
470
- dinov3_feat = F.layer_norm(dinov3_feat.float(), dinov3_feat.shape[-1:])
471
- vae_feat = vae_encoder.encode(img_tensor.to(vae_dtype) * 2 - 1,
472
- deterministic=False, generator=generator)
473
- # pad 5 zero tokens so feature2's token length matches feature1's (cls + 4 registers + patches)
474
- zero_reg = torch.zeros(vae_feat.shape[0], 5, vae_feat.shape[2],
475
- dtype=vae_feat.dtype, device=vae_feat.device)
476
- vae_feat = torch.cat([zero_reg, vae_feat], dim=1)
477
- return {'feature1': dinov3_feat, 'feature2': vae_feat}
478
-
479
-
480
- @torch.no_grad()
481
- def sample_latent(flow_model: LatentSeqMMFlowModel, cond: dict,
482
- steps: int = 50, guidance_scale: float = 7.0, shift: float = 3.0,
483
- generator: torch.Generator = None,
484
- show_progress: bool = False, callback=None) -> dict:
485
- device = flow_model.device
486
- neg_cond = {k: torch.zeros_like(v) for k, v in cond.items()}
487
- noise = {'latent': torch.randn(1, flow_model.q_token_length, flow_model.in_channels,
488
- device=device, generator=generator)}
489
- if flow_model.cam_channels is not None:
490
- noise['camera'] = torch.randn(1, 1, flow_model.cam_channels,
491
- device=device, generator=generator)
492
- sampler = FlowEulerCfgSampler()
493
- return sampler.sample(flow_model, noise, cond=cond, neg_cond=neg_cond,
494
- steps=steps, guidance_scale=guidance_scale, shift=shift,
495
- show_progress=show_progress, callback=callback)
496
-
497
-
498
- # ---------------------------------------------------------------------------
499
- # Pipeline
500
- # ---------------------------------------------------------------------------
501
-
502
- class TripoSplatPipeline:
503
- def __init__(self, ckpt_path: str, decoder_path: str, dinov3_path: str,
504
- flux2_vae_encoder_path: str, rmbg_path: str, device: str = "cuda"):
505
- self._device = torch.device(device)
506
- self.dinov3 = load_dinov3 (dinov3_path, device=self._device, dtype=torch.bfloat16)
507
- self.vae_encoder = load_vae_encoder (flux2_vae_encoder_path, device=self._device, dtype=torch.bfloat16)
508
- self.rmbg = load_rmbg (rmbg_path, device=self._device, dtype=torch.float16)
509
- self.flow_model = load_flow_model (ckpt_path, device=self._device, dtype=torch.float16)
510
- self.decoder = load_decoder (decoder_path, device=self._device, dtype=torch.float16)
511
-
512
- def preprocess_image(self, image, erode_radius: int = 1) -> Image.Image:
513
- return preprocess_image(image, self.rmbg, erode_radius=erode_radius)
514
-
515
- def encode_image(self, image: Image.Image, generator: torch.Generator = None) -> dict:
516
- return encode_image(image, self.dinov3, self.vae_encoder, generator=generator)
517
-
518
- def sample_latent(self, cond: dict, steps: int = 50, guidance_scale: float = 7.0,
519
- shift: float = 3.0, generator: torch.Generator = None,
520
- show_progress: bool = False, callback=None) -> dict:
521
- return sample_latent(self.flow_model, cond, steps=steps, guidance_scale=guidance_scale,
522
- shift=shift, generator=generator,
523
- show_progress=show_progress, callback=callback)
524
-
525
- def decode_latent(self, latent: torch.Tensor, num_gaussians: int = 262144):
526
- return self.decoder.decode(latent, num_gaussians=num_gaussians)
527
-
528
- _NUM_GAUSSIANS_MIN = 32768
529
- _NUM_GAUSSIANS_MAX = 262144
530
-
531
- def _validate_num_gaussians(self, n: int) -> int:
532
- assert self._NUM_GAUSSIANS_MIN <= n <= self._NUM_GAUSSIANS_MAX, (
533
- f"num_gaussians must be in [{self._NUM_GAUSSIANS_MIN}, {self._NUM_GAUSSIANS_MAX}], got {n}"
534
- )
535
- gpp = self.decoder.gaussians_per_point
536
- if n % gpp == 0:
537
- return n
538
- rounded = round(n / gpp) * gpp
539
- print(f"[TripoSplatPipeline] num_gaussians={n} is not a multiple of {gpp}; rounding to {rounded}")
540
- return rounded
541
-
542
- @torch.no_grad()
543
- def run(self, image, seed: int = 42, steps: int = 20, guidance_scale: float = 3.0,
544
- shift: float = 3.0, num_gaussians=262144, erode_radius: int = 1,
545
- show_progress: bool = False, callback=None):
546
- """
547
- Args:
548
- image: Input image. Accepts a file path / PIL.Image / torch.Tensor
549
- (`[1,H,W,C]` or `[H,W,C]`, float in `[0, 1]`, optional alpha
550
- channel as the 4th channel).
551
- seed: RNG seed for the VAE encoder's stochastic latent sampling and
552
- the initial flow-matching noise. Same seed → same output.
553
- steps: Number of Euler integrator steps in the flow-matching sampler.
554
- More steps → better fidelity, linear runtime cost.
555
- Recommend: 10~20.
556
- guidance_scale: Classifier-free-guidance strength (diffusers
557
- convention). `≤ 1.0` disables CFG. Higher → more detail,
558
- stronger adherence to the input image; too high can cause color
559
- oversaturation.
560
- Recommend: 3.0.
561
- shift: Flow-matching timestep schedule shift. `1.0` gives a uniform
562
- schedule; `>1.0` allocates more steps to the early/high-noise end.
563
- Recommend: 3.0.
564
- num_gaussians: Target Gaussian-splat count. An `int` returns a
565
- single `Gaussian`. A `list` / `tuple` of ints returns a
566
- `list[Gaussian]`. Each count is rounded to the nearest multiple
567
- of 32. More gaussians → more detail but higher rendering and
568
- storage cost.
569
- Recommend: 32768~262144.
570
- erode_radius: Pixel radius used to erode the alpha matte after
571
- background removal, to avoid segmentation-border bleed before
572
- compositing on black. `0` disables; `1` is a 3×3 minimum filter.
573
- Recommend: 1.
574
- show_progress: Print a `tqdm` progress bar over sampler steps.
575
- callback: Optional `fn(step, total)` invoked after each sampler step.
576
- Useful for external progress UIs (e.g. ComfyUI's
577
- `ProgressBar.update`).
578
-
579
- Returns:
580
- `(gaussian, prepared_image)` for an `int` `num_gaussians`, or
581
- `(list_of_gaussians, prepared_image)` for a `list` / `tuple`. The
582
- second element is the RGB composite the encoders actually saw —
583
- useful for display / debugging.
584
- """
585
- if isinstance(num_gaussians, (list, tuple)):
586
- counts = [self._validate_num_gaussians(n) for n in num_gaussians]
587
- else:
588
- counts = [self._validate_num_gaussians(num_gaussians)]
589
-
590
- gen = torch.Generator(device=self._device).manual_seed(seed)
591
- prepared = self.preprocess_image(image, erode_radius=erode_radius)
592
- cond = self.encode_image(prepared, generator=gen)
593
- out = self.sample_latent(cond, steps=steps, guidance_scale=guidance_scale, shift=shift,
594
- generator=gen, show_progress=show_progress, callback=callback)
595
- gaussians = [self.decode_latent(out['latent'], num_gaussians=n) for n in counts]
596
- if isinstance(num_gaussians, (list, tuple)):
597
- return gaussians, prepared
598
- return gaussians[0], prepared