0xZohar commited on
Commit
d1768a6
·
verified ·
1 Parent(s): 9fb53ab

Add code/cube3d/model/autoencoder/one_d_autoencoder.py

Browse files
code/cube3d/model/autoencoder/one_d_autoencoder.py ADDED
@@ -0,0 +1,695 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sys
3
+ from dataclasses import dataclass, field
4
+ from functools import partial
5
+ from typing import List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from skimage import measure
11
+ from torch.nn import functional as F
12
+ from tqdm import tqdm
13
+
14
+ from cube3d.model.autoencoder.embedder import PhaseModulatedFourierEmbedder
15
+ from cube3d.model.autoencoder.grid import (
16
+ generate_dense_grid_points,
17
+ marching_cubes_with_warp,
18
+ )
19
+ from cube3d.model.autoencoder.spherical_vq import SphericalVectorQuantizer
20
+ from cube3d.model.transformers.attention import (
21
+ EncoderCrossAttentionLayer,
22
+ EncoderLayer,
23
+ init_linear,
24
+ init_tfixup,
25
+ )
26
+ from cube3d.model.transformers.norm import LayerNorm
27
+
28
+
29
def init_sort(x):
    """
    Sort the rows of `x` by ascending distance to its first row.

    Computes pairwise distances between every row of `x` and `x[0]`, then
    reorders the rows by that distance (the first row always sorts first,
    at distance 0).

    Args:
        x (torch.Tensor): A 2D tensor where each row represents a data point.

    Returns:
        torch.Tensor: The rows of `x` sorted by their distance to `x[0]`;
        same shape as the input.
    """
    # (N, 1) distances from every row to the first row.
    distances = torch.cdist(x, x[:1])
    # squeeze(-1) (not squeeze()) keeps the result 1D even when N == 1,
    # so sorting and fancy-indexing below preserve the (N, D) shape.
    _, indices = torch.sort(distances.squeeze(-1), dim=0)
    return x[indices]
46
+
47
+
48
class MLPEmbedder(nn.Module):
    """Two-layer MLP (linear -> SiLU -> linear) mapping inputs to the model width."""

    def __init__(self, in_dim: int, embed_dim: int, bias: bool = True):
        """
        Args:
            in_dim (int): Dimensionality of the input features.
            embed_dim (int): Width of both the hidden and output layers.
            bias (bool, optional): Whether the linear layers carry a bias. Defaults to True.
        """
        super().__init__()
        self.in_layer = nn.Linear(in_dim, embed_dim, bias=bias)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Project-wide linear initialization, scaled by the embedding width.
        self.apply(partial(init_linear, embed_dim=embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed `x` through the two-layer MLP."""
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)
59
+
60
+
61
class OneDEncoder(nn.Module):
    # Encodes point clouds (plus optional per-point features) into a fixed-size
    # sequence of latent tokens via learned queries and cross-attention.
    def __init__(
        self,
        embedder,
        num_latents: int,
        point_feats: int,
        embed_point_feats: bool,
        width: int,
        num_heads: int,
        num_layers: int,
        with_cls_token: bool = False,
        cross_attention_levels: Optional[List[int]] = None,
        eps: float = 1e-6,
    ) -> None:
        """
        Initializes the OneDEncoder model.

        Args:
            embedder: An embedding module that provides the input embedding functionality
                (must expose an `out_dim` attribute).
            num_latents (int): The number of latent query tokens.
            point_feats (int): The number of raw per-point features.
            embed_point_feats (bool): Whether to pass point features through `embedder`
                instead of concatenating them raw.
            width (int): The width of the embedding dimension.
            num_heads (int): The number of attention heads.
            num_layers (int): The number of encoder layers.
            with_cls_token (bool, optional): Whether to prepend a classification token
                like in Vision Transformers (ViT). Defaults to False.
            cross_attention_levels (Optional[List[int]], optional): Indices of layers
                that cross-attend to the input data; all other layers are plain
                self-attention. Defaults to None, which becomes [0].
            eps (float, optional): A small value added for numerical stability in
                normalization layers. Defaults to 1e-6.
        """
        super().__init__()

        self.embedder = embedder

        # add cls token like ViT: one extra learned query slot at position 0
        self.with_cls_token = with_cls_token
        if self.with_cls_token:
            query = torch.empty((1 + num_latents, width))
        else:
            query = torch.empty((num_latents, width))

        # initialize then sort query to potentially get better (distance-based) ordering
        query.uniform_(-1.0, 1.0)
        query = init_sort(query)

        # set parameter
        self.query = nn.Parameter(query)

        self.embed_point_feats = embed_point_feats
        # Input width of the feature MLP: embedded points concatenated with either
        # embedded features (same out_dim) or the raw point features.
        in_dim = (
            self.embedder.out_dim * 2
            if self.embed_point_feats
            else self.embedder.out_dim + point_feats
        )
        self.feat_in = MLPEmbedder(in_dim, embed_dim=width)

        if cross_attention_levels is None:
            cross_attention_levels = [0]

        self.blocks = nn.ModuleList()
        for i in range(num_layers):
            if i in cross_attention_levels:
                self.blocks.append(
                    EncoderCrossAttentionLayer(
                        embed_dim=width,
                        num_heads=num_heads,
                        eps=eps,
                    )
                )
            else:
                self.blocks.append(
                    EncoderLayer(embed_dim=width, num_heads=num_heads, eps=eps)
                )
        self.ln_f = LayerNorm(width, eps=eps)

        init_tfixup(self, num_layers)

    def _forward(self, h, data, attn_mask=None):
        """
        Run the query tokens through the encoder stack.

        Args:
            h (torch.Tensor): The query tokens (batch, num_latents[, +1], width).
            data (torch.Tensor): Embedded input data, projected by `feat_in` and
                consumed by the cross-attention layers.
            attn_mask (torch.Tensor, optional): Attention mask for the
                self-attention layers. Defaults to None.

        Returns:
            torch.Tensor: The encoded tokens after the final layer norm.
        """

        data = self.feat_in(data)

        for block in self.blocks:
            # Cross-attention layers read from `data`; plain layers self-attend.
            if isinstance(block, EncoderCrossAttentionLayer):
                h = block(h, data)
            else:
                h = block(h, attn_mask=attn_mask)

        h = self.ln_f(h)
        return h

    def forward(self, pts: torch.Tensor, feats: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of point clouds into latent tokens.

        Args:
            pts (torch.Tensor): Points with shape (batch_size, num_points, point_dim).
            feats (torch.Tensor): Per-point features with shape
                (batch_size, num_points, feature_dim). May be None.

        Returns:
            torch.Tensor: Encoded latent tokens of shape
            (batch_size, num_latents [+1 if cls token], width).
        """

        b = pts.shape[0]
        data = self.embedder(pts)

        if feats is not None:
            if self.embed_point_feats:
                feats = self.embedder(feats)
            data = torch.cat([data, feats], dim=-1)

        # prepare query and data: broadcast the learned queries across the batch
        h = self.query.unsqueeze(0).expand(b, -1, -1)
        return self._forward(h, data, attn_mask=None)
191
+
192
+
193
class OneDBottleNeck(nn.Module):
    """Thin wrapper around an optional bottleneck block (e.g. a vector quantizer)."""

    def __init__(
        self,
        block,
    ) -> None:
        """
        Args:
            block: The inner bottleneck module, or None for a pass-through.
        """
        super().__init__()

        self.block = block

    def forward(self, h: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Pass `h` through the inner block, if one is configured.

        Args:
            h (torch.Tensor): Input tensor.

        Returns:
            Tuple[torch.Tensor, dict]: The (possibly transformed) tensor and a
            dictionary that carries "indices" and/or "z_q" when the inner block
            reports them (under its own keys "q" and "z_q").
        """
        ret_dict = {}
        z = h
        if self.block is None:
            return z, ret_dict

        z, block_out = self.block(z)
        # Translate the block's output keys to this module's public names.
        if "q" in block_out:
            ret_dict["indices"] = block_out["q"]
        if "z_q" in block_out:
            ret_dict["z_q"] = block_out["z_q"]
        return z, ret_dict
235
+
236
+
237
class OneDDecoder(nn.Module):
    # Decodes a (possibly shortened) sequence of latent tokens back to the full
    # latent sequence via a stack of self-attention layers.
    def __init__(
        self,
        num_latents: int,
        width: int,
        num_heads: int,
        num_layers: int,
        eps: float = 1e-6,
    ) -> None:
        """
        Initializes the OneDDecoder class.

        Args:
            num_latents (int): The number of latent positions (positional encodings).
            width (int): The width of the embedding dimension.
            num_heads (int): The number of attention heads in each encoder layer.
            num_layers (int): The number of encoder layers.
            eps (float, optional): A small value added for numerical stability. Defaults to 1e-6.
        """
        super().__init__()

        # Empty non-persistent placeholder; forward() only pads when this buffer
        # is non-empty, so by default no padding happens.
        # NOTE(review): presumably this buffer is resized/filled externally
        # (e.g. from a checkpoint or a subclass) — confirm against callers.
        self.register_buffer("query", torch.empty([0, width]), persistent=False)
        # Learned per-position encodings, unit-normalized then distance-sorted.
        self.positional_encodings = nn.Parameter(
            init_sort(F.normalize(torch.empty(num_latents, width).normal_()))
        )
        self.blocks = nn.ModuleList(
            [
                EncoderLayer(embed_dim=width, num_heads=num_heads, eps=eps)
                for _ in range(num_layers)
            ]
        )

        init_tfixup(self, num_layers)

    def _forward(self, h):
        """
        Apply the self-attention blocks sequentially.

        Args:
            h (torch.Tensor): The input tensor to be processed by the blocks.

        Returns:
            torch.Tensor: The output tensor after applying all blocks sequentially.
        """

        for block in self.blocks:
            h = block(h)
        return h

    def forward(self, z):
        """
        Pad `z` to a fixed length (when `query` is non-empty), add positional
        encodings, and run the decoder stack.

        Args:
            z (torch.Tensor): Input tensor of shape (batch, seq, width).

        Returns:
            torch.Tensor: Output tensor after processing through the decoder.

        Notes:
            - If the `query` buffer is non-empty, `z` is left-padded with slices
              of `query` broadcast over the batch.
            - Positional encodings for the first `z.shape[1]` positions are added
              before `_forward`.
        """

        # pad input to fixed length
        if self.query.shape[0] > 0:
            # NOTE(review): the `+ 1` makes the padded length query_len + 1;
            # looks intentional (room for a cls-style slot?) but confirm — an
            # off-by-one here silently changes the sequence length.
            pad_len = self.query.shape[0] + 1 - z.shape[1]
            paddings = self.query[:pad_len, ...].unsqueeze(0).expand(z.shape[0], -1, -1)
            z = torch.cat([paddings, z], dim=1)
        h = z + self.positional_encodings[: z.shape[1], :].unsqueeze(0).expand(
            z.shape[0], -1, -1
        )

        return self._forward(h)
310
+
311
+
312
class OneDOccupancyDecoder(nn.Module):
    """Predicts occupancy values at query points by cross-attending to latent tokens."""

    def __init__(
        self, embedder, out_features: int, width: int, num_heads: int, eps=1e-6
    ) -> None:
        """
        Args:
            embedder: Embedding module for raw query coordinates (must expose `out_dim`).
            out_features (int): Output feature count of the final linear projection.
            width (int): Width of the intermediate layers.
            num_heads (int): Number of heads in the cross-attention layer.
            eps (float, optional): Layer-norm stability epsilon. Defaults to 1e-6.
        """
        super().__init__()

        self.embedder = embedder
        self.query_in = MLPEmbedder(self.embedder.out_dim, width)

        self.attn_out = EncoderCrossAttentionLayer(embed_dim=width, num_heads=num_heads)
        self.ln_f = LayerNorm(width, eps=eps)
        self.c_head = nn.Linear(width, out_features)

    def query(self, queries: torch.Tensor):
        """
        Embed raw query points and project them to the model width.

        Args:
            queries (torch.Tensor): Raw query coordinates.

        Returns:
            torch.Tensor: Width-dimensional query embeddings.
        """
        embedded = self.embedder(queries)
        return self.query_in(embedded)

    def forward(self, queries: torch.Tensor, latents: torch.Tensor):
        """
        Cross-attend embedded queries to `latents` and project to occupancy outputs.

        Args:
            queries (torch.Tensor): Raw query coordinates.
            latents (torch.Tensor): Latent token sequence to attend over.

        Returns:
            torch.Tensor: Projected outputs of shape (..., out_features).
        """
        q = self.query(queries)
        attended = self.attn_out(q, latents)
        normed = self.ln_f(attended)
        return self.c_head(normed)
359
+
360
+
361
class OneDAutoEncoder(nn.Module):
    # End-to-end 1D autoencoder: point-cloud encoder -> spherical VQ bottleneck
    # -> latent decoder -> occupancy decoder, plus marching-cubes mesh extraction.
    @dataclass
    class Config:
        checkpoint_path: str = ""

        # network params
        num_encoder_latents: int = 256
        num_decoder_latents: int = 256
        embed_dim: int = 12
        width: int = 768
        num_heads: int = 12
        out_dim: int = 1
        eps: float = 1e-6

        # grid features embedding
        num_freqs: int = 128
        point_feats: int = 0
        embed_point_feats: bool = False

        num_encoder_layers: int = 1
        encoder_cross_attention_levels: list[int] = field(default_factory=list)
        num_decoder_layers: int = 23

        encoder_with_cls_token: bool = True
        num_codes: int = 16384

    def __init__(self, cfg: Config) -> None:
        """
        Initializes the OneDAutoencoder model.

        Args:
            cfg (Config): Configuration object containing the parameters for the model.

        Attributes:
            cfg (Config): Stores the configuration object.
            embedder (PhaseModulatedFourierEmbedder): Embeds input data using phase-modulated Fourier features.
            encoder (OneDEncoder): Encodes the input data into latent representations.
            bottleneck (OneDBottleNeck): Bottleneck layer containing a spherical vector quantizer.
            decoder (OneDDecoder): Decodes latent representations back into the original data space.
            occupancy_decoder (OneDOccupancyDecoder): Decodes occupancy information from latent representations.
        """

        super().__init__()

        self.cfg = cfg

        self.embedder = PhaseModulatedFourierEmbedder(
            num_freqs=self.cfg.num_freqs, input_dim=3
        )

        self.encoder = OneDEncoder(
            embedder=self.embedder,
            num_latents=self.cfg.num_encoder_latents,
            with_cls_token=self.cfg.encoder_with_cls_token,
            point_feats=self.cfg.point_feats,
            embed_point_feats=self.cfg.embed_point_feats,
            width=self.cfg.width,
            num_heads=self.cfg.num_heads,
            num_layers=self.cfg.num_encoder_layers,
            cross_attention_levels=self.cfg.encoder_cross_attention_levels,
            eps=self.cfg.eps,
        )

        block = SphericalVectorQuantizer(
            self.cfg.embed_dim,
            self.cfg.num_codes,
            self.cfg.width,
            codebook_regularization="kl",
        )
        self.bottleneck = OneDBottleNeck(block=block)

        # NOTE(review): the decoder is sized with num_encoder_latents, not
        # num_decoder_latents — confirm whether Config.num_decoder_latents is
        # intentionally unused here.
        self.decoder = OneDDecoder(
            num_latents=self.cfg.num_encoder_latents,
            width=self.cfg.width,
            num_heads=self.cfg.num_heads,
            num_layers=self.cfg.num_decoder_layers,
            eps=self.cfg.eps,
        )

        self.occupancy_decoder = OneDOccupancyDecoder(
            embedder=self.embedder,
            out_features=self.cfg.out_dim,
            width=self.cfg.width,
            num_heads=self.cfg.num_heads,
            eps=self.cfg.eps,
        )

    @torch.no_grad()
    def decode_indices(self, shape_ids: torch.Tensor):
        """
        Decodes the given shape indices into latent representations.

        Args:
            shape_ids (torch.Tensor): A tensor containing the shape indices to be decoded.

        Returns:
            torch.Tensor: The decoded latent representations corresponding to the input shape indices.
        """

        z_q = self.bottleneck.block.lookup_codebook(shape_ids)
        latents = self.decode(z_q)
        return latents

    @torch.no_grad()
    def query_embeds(self, shape_ids: torch.Tensor):
        """
        Retrieves the latent embeddings corresponding to the given shape IDs.

        Args:
            shape_ids (torch.Tensor): A tensor containing the IDs of the shapes
                for which the latent embeddings are to be queried.

        Returns:
            torch.Tensor: A tensor containing the latent embeddings retrieved
            from the codebook for the provided shape IDs.
        """

        z_q = self.bottleneck.block.lookup_codebook_latents(shape_ids)
        return z_q

    @torch.no_grad()
    def query_indices(self, shape_embs: torch.Tensor):
        """
        Queries the indices of the quantized embeddings from the bottleneck layer.

        Args:
            shape_embs (torch.Tensor): The input tensor containing shape embeddings
                to be quantized.

        Returns:
            torch.Tensor: A tensor containing the quantized indices.
        """

        _, ret_dict = self.bottleneck.block.quantize(shape_embs)
        return ret_dict["q"]

    def encode(self, x: torch.Tensor, **kwargs):
        """
        Encodes the input tensor using the encoder and bottleneck layers.

        Args:
            x (torch.Tensor): Input tensor with shape (..., N), where the first 3
                channels are point coordinates (pts) and the remaining channels
                are per-point features (feats).
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, None, dict]: A tuple containing:
                - z_e (torch.Tensor): Encoded tensor before bottleneck processing
                  (class token stripped if present).
                - z (torch.Tensor): Encoded tensor after bottleneck processing.
                - None: Placeholder for compatibility with other methods.
                - d (dict): Dictionary containing additional information, including:
                    - "z_cls" (torch.Tensor, optional): Class token if
                      `self.cfg.encoder_with_cls_token` is True.
        """

        pts, feats = x[..., :3], x[..., 3:]
        z_e = self.encoder(pts, feats)

        # split class token (position 0) from the latent tokens
        if self.cfg.encoder_with_cls_token:
            z_cls = z_e[:, 0, ...]
            z_e = z_e[:, 1:, ...]

        # quantize or kl
        z, d = self.bottleneck(z_e)

        if self.cfg.encoder_with_cls_token:
            d["z_cls"] = z_cls
        return z_e, z, None, d

    def decode(self, z: torch.Tensor):
        """
        Decodes the latent representation `z` using the decoder network.

        Args:
            z (torch.Tensor): The latent representation tensor to be decoded.

        Returns:
            torch.Tensor: The decoded output tensor.
        """

        h = self.decoder(z)
        return h

    def query(self, queries: torch.Tensor, latents: torch.Tensor):
        """
        Computes occupancy logits for query points against decoded latents.

        Args:
            queries (torch.Tensor): A tensor containing the query points to be decoded.
            latents (torch.Tensor): A tensor containing the latent representations.

        Returns:
            torch.Tensor: Logits with the trailing feature axis squeezed away.
        """

        logits = self.occupancy_decoder(queries, latents).squeeze(-1)
        return logits

    def forward(self, surface, queries, **kwargs):
        """
        Perform a forward pass through the autoencoder model.

        Args:
            surface (torch.Tensor): The input surface tensor to be encoded.
            queries (torch.Tensor): The query tensor used for generating logits.
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            tuple: A tuple containing:
                - z (torch.Tensor): The latent representation of the input surface.
                - latents (torch.Tensor): The decoded output from the latent representation.
                - None: Placeholder for a potential future return value.
                - logits (torch.Tensor): The logits generated from the queries and latents.
                - d (dict): Additional output from the encoding process.
        """

        _, z, _, d = self.encode(surface)

        latents = self.decode(z)
        logits = self.query(queries, latents)

        return z, latents, None, logits, d

    @torch.no_grad()
    def extract_geometry(
        self,
        latents: torch.FloatTensor,
        bounds: list[float] = [
            -1.05,
            -1.05,
            -1.05,
            1.05,
            1.05,
            1.05,
        ],
        resolution_base: float = 9.0,
        chunk_size: int = 2_000_000,
        use_warp: bool = False,
    ):
        """
        Extracts 3D geometry from latent representations using dense grid sampling
        and the marching cubes algorithm.

        Args:
            latents (torch.FloatTensor): Latent representations, shape (batch_size, ...).
            bounds (list[float], optional): Bounding box as
                [xmin, ymin, zmin, xmax, ymax, zmax].
                Defaults to [-1.05, -1.05, -1.05, 1.05, 1.05, 1.05].
            resolution_base (float, optional): Base grid resolution; higher is finer.
                Defaults to 9.0.
            chunk_size (int, optional): Grid points queried per chunk. Defaults to 2,000,000.
            use_warp (bool, optional): Use the GPU (warp) marching cubes; falls back to
                skimage on CPU on failure. Defaults to False.

        Returns:
            tuple:
                - mesh_v_f (list[tuple]): Per-batch (vertices, faces) pairs, with
                  vertices as float32 (num_vertices, 3) and faces as (num_faces, 3);
                  (None, None) when extraction fails for that element.
                - has_surface (np.ndarray): Boolean flag per batch element indicating
                  whether a surface was extracted.
        """
        bbox_min = np.array(bounds[0:3])
        bbox_max = np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min

        # `length` (grid spacing) is returned by the helper but unused here.
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            resolution_base=resolution_base,
            indexing="ij",
        )
        xyz_samples = torch.FloatTensor(xyz_samples)
        batch_size = latents.shape[0]

        batch_logits = []

        progress_bar = tqdm(
            range(0, xyz_samples.shape[0], chunk_size),
            desc=f"extracting geometry",
            unit="chunk",
        )
        for start in progress_bar:
            queries = xyz_samples[start : start + chunk_size, :]

            # Pad the final partial chunk to a full chunk so every query batch
            # has the same shape; the padded rows are sliced off the logits.
            num_queries = queries.shape[0]
            if start > 0 and num_queries < chunk_size:
                queries = F.pad(queries, [0, 0, 0, chunk_size - num_queries])
            batch_queries = queries.unsqueeze(0).expand(batch_size, -1, -1).to(latents)

            logits = self.query(batch_queries, latents)[:, :num_queries]
            batch_logits.append(logits)

        grid_logits = (
            torch.cat(batch_logits, dim=1)
            .detach()
            .view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
            .float()
        )

        mesh_v_f = []
        has_surface = np.zeros((batch_size,), dtype=np.bool_)
        for i in range(batch_size):
            try:
                warp_success = False
                if use_warp:
                    try:
                        vertices, faces = marching_cubes_with_warp(
                            grid_logits[i],
                            level=0.0,
                            device=grid_logits.device,
                        )
                        warp_success = True
                    except Exception as e:
                        logging.warning(
                            f"Warning: error in marching cubes with warp: {e}"
                        )
                        warp_success = False  # Fall back to CPU version

                if not warp_success:
                    logging.warning(
                        "Warning: falling back to CPU version of marching cubes using skimage measure"
                    )
                    vertices, faces, _, _ = measure.marching_cubes(
                        grid_logits[i].cpu().numpy(), 0, method="lewiner"
                    )

                # Map voxel coordinates back into the bounding box.
                # NOTE(review): this assumes vertex coords span [0, grid_size];
                # marching-cubes output usually spans [0, grid_size - 1] — confirm
                # against generate_dense_grid_points or expect a slight scale bias.
                vertices = vertices / grid_size * bbox_size + bbox_min
                # Reverse winding order of each triangle.
                faces = faces[:, [2, 1, 0]]
                mesh_v_f.append(
                    (vertices.astype(np.float32), np.ascontiguousarray(faces))
                )
                has_surface[i] = True
            except Exception as e:
                logging.error(f"Error: error in extract_geometry: {e}")
                mesh_v_f.append((None, None))
                has_surface[i] = False

        return mesh_v_f, has_surface