Yihua7 commited on
Commit
c008121
·
1 Parent(s): 0350eb5

Remove PyTorch3D runtime dependency

Browse files
anigen/models/structured_latent_vae/anigen_decoder.py CHANGED
@@ -8,7 +8,7 @@ from ...modules import sparse as sp
8
  from ...representations import MeshExtractResult
9
  from ...representations.mesh import AniGenSparseFeatures2Mesh, AniGenSklFeatures2Skeleton
10
  from ..sparse_elastic_mixin import SparseTransformerElasticMixin
11
- from pytorch3d.ops import knn_points
12
  from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
13
  from .skin_models import SKIN_MODEL_DICT
14
  import torch.nn.functional as F
 
8
  from ...representations import MeshExtractResult
9
  from ...representations.mesh import AniGenSparseFeatures2Mesh, AniGenSklFeatures2Skeleton
10
  from ..sparse_elastic_mixin import SparseTransformerElasticMixin
11
+ from ...utils.pytorch3d_compat import knn_points
12
  from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
13
  from .skin_models import SKIN_MODEL_DICT
14
  import torch.nn.functional as F
anigen/models/structured_latent_vae/anigen_encoder.py CHANGED
@@ -5,7 +5,7 @@ import torch.nn.functional as F
5
  from ...modules import sparse as sp
6
  from ..sparse_elastic_mixin import SparseTransformerElasticMixin
7
  from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
8
- from pytorch3d.ops import knn_points
9
  from .skin_models import SkinEncoder
10
 
11
 
 
5
  from ...modules import sparse as sp
6
  from ..sparse_elastic_mixin import SparseTransformerElasticMixin
7
  from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
8
+ from ...utils.pytorch3d_compat import knn_points
9
  from .skin_models import SkinEncoder
10
 
11
 
anigen/representations/mesh/cube2mesh_skeleton.py CHANGED
@@ -5,7 +5,7 @@ from ...modules.sparse import SparseTensor
5
  from easydict import EasyDict as edict
6
  from .utils_cube import *
7
  from .flexicubes.flexicubes import FlexiCubes
8
- from pytorch3d.ops import knn_points
9
 
10
 
11
  class AniGenMeshExtractResult:
 
5
  from easydict import EasyDict as edict
6
  from .utils_cube import *
7
  from .flexicubes.flexicubes import FlexiCubes
8
+ from ...utils.pytorch3d_compat import knn_points
9
 
10
 
11
  class AniGenMeshExtractResult:
anigen/representations/skeleton/grouping.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import numpy as np
3
- from pytorch3d.ops import ball_query, knn_points
4
 
5
 
6
  def disjoint_set_unioin_find(N, pairs):
@@ -178,4 +178,3 @@ GROUPING_STRATEGIES = {
178
  "threshold": threshold_grouping,
179
  "mean_shift": mean_shift_grouping,
180
  }
181
-
 
1
  import torch
2
  import numpy as np
3
+ from ...utils.pytorch3d_compat import ball_query, knn_points
4
 
5
 
6
  def disjoint_set_unioin_find(N, pairs):
 
178
  "threshold": threshold_grouping,
179
  "mean_shift": mean_shift_grouping,
180
  }
 
anigen/trainers/flow_matching/anigen_sparse_flow_matching.py CHANGED
@@ -19,7 +19,7 @@ from .mixins.text_conditioned import TextConditionedMixin
19
  from .mixins.image_conditioned import ImageConditionedMixin
20
  from ...representations import MeshExtractResult
21
  from ...utils.skin_utils import get_transform
22
- from pytorch3d.ops import knn_points
23
  from ...renderers import MeshRenderer
24
  from ...utils.data_utils import recursive_to_device
25
  import copy
 
19
  from .mixins.image_conditioned import ImageConditionedMixin
20
  from ...representations import MeshExtractResult
21
  from ...utils.skin_utils import get_transform
22
+ from ...utils.pytorch3d_compat import knn_points
23
  from ...renderers import MeshRenderer
24
  from ...utils.data_utils import recursive_to_device
25
  import copy
anigen/trainers/vae/anigen_skin_ae.py CHANGED
@@ -19,7 +19,7 @@ from ...renderers import OctreeRenderer
19
 
20
  from ...modules.sparse import SparseTensor
21
  from ...utils.loss_utils import l1_loss, smooth_l1_loss, l2_loss, ssim, lpips
22
- from pytorch3d.ops import knn_points
23
  import torch.nn.functional as F
24
 
25
 
 
19
 
20
  from ...modules.sparse import SparseTensor
21
  from ...utils.loss_utils import l1_loss, smooth_l1_loss, l2_loss, ssim, lpips
22
+ from ...utils.pytorch3d_compat import knn_points
23
  import torch.nn.functional as F
24
 
25
 
anigen/trainers/vae/anigen_slat_mesh_vae.py CHANGED
@@ -19,7 +19,7 @@ from ...renderers import OctreeRenderer
19
 
20
  from ...modules.sparse import SparseTensor
21
  from ...utils.loss_utils import l1_loss, smooth_l1_loss, l2_loss, ssim, lpips
22
- from pytorch3d.ops import knn_points
23
  import torch.nn.functional as F
24
 
25
 
 
19
 
20
  from ...modules.sparse import SparseTensor
21
  from ...utils.loss_utils import l1_loss, smooth_l1_loss, l2_loss, ssim, lpips
22
+ from ...utils.pytorch3d_compat import knn_points
23
  import torch.nn.functional as F
24
 
25
 
anigen/utils/export_utils.py CHANGED
@@ -472,10 +472,10 @@ def transfer_vertex_colors_nearest(src_vertices, src_colors, dst_vertices):
472
  dst_v = torch.from_numpy(dst_vertices).to(device=device, dtype=torch.float32)
473
  src_c = torch.from_numpy(src_colors).to(device=device, dtype=torch.float32)
474
 
475
- # Prefer pytorch3d for KNN if available.
476
  idx = None
477
  try:
478
- from pytorch3d.ops import knn_points
479
  _, nn_idx, _ = knn_points(dst_v[None], src_v[None], K=1, return_nn=False)
480
  idx = nn_idx[0, :, 0]
481
  except Exception:
 
472
  dst_v = torch.from_numpy(dst_vertices).to(device=device, dtype=torch.float32)
473
  src_c = torch.from_numpy(src_colors).to(device=device, dtype=torch.float32)
474
 
475
+ # Prefer the local PyTorch3D-compatible helper for KNN.
476
  idx = None
477
  try:
478
+ from .pytorch3d_compat import knn_points
479
  _, nn_idx, _ = knn_points(dst_v[None], src_v[None], K=1, return_nn=False)
480
  idx = nn_idx[0, :, 0]
481
  except Exception:
anigen/utils/pytorch3d_compat.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ import torch
4
+
5
+
6
+ KNNResult = namedtuple("KNNResult", ["dists", "idx", "knn"])
7
+ BallQueryResult = namedtuple("BallQueryResult", ["dists", "idx", "knn"])
8
+
9
+
10
+ def knn_points(p1, p2, lengths1=None, lengths2=None, K=1, norm=2, return_nn=False, return_sorted=True):
11
+ if p1.dim() != 3 or p2.dim() != 3:
12
+ raise ValueError("p1 and p2 must have shape (N, P, D)")
13
+ if p1.shape[0] != p2.shape[0] or p1.shape[2] != p2.shape[2]:
14
+ raise ValueError("p1 and p2 must have matching batch and point dimensions")
15
+
16
+ batch, p1_count, _ = p1.shape
17
+ p2_count = p2.shape[1]
18
+ k = min(K, p2_count)
19
+ if k <= 0:
20
+ empty_dists = p1.new_empty((batch, p1_count, 0))
21
+ empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
22
+ empty_nn = p2.new_empty((batch, p1_count, 0, p2.shape[2])) if return_nn else None
23
+ return KNNResult(empty_dists, empty_idx, empty_nn)
24
+
25
+ dists = torch.cdist(p1.float(), p2.float(), p=norm)
26
+ if norm == 2:
27
+ dists = dists.square()
28
+
29
+ if lengths2 is not None:
30
+ arange = torch.arange(p2_count, device=p2.device).view(1, 1, p2_count)
31
+ valid = arange < lengths2.to(device=p2.device).view(batch, 1, 1)
32
+ dists = dists.masked_fill(~valid, torch.inf)
33
+
34
+ dists_k, idx = torch.topk(dists, k=k, dim=-1, largest=False, sorted=return_sorted)
35
+
36
+ if K > k:
37
+ pad = K - k
38
+ dists_k = torch.cat([dists_k, dists_k.new_full((batch, p1_count, pad), torch.inf)], dim=-1)
39
+ idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)
40
+
41
+ if lengths1 is not None:
42
+ arange = torch.arange(p1_count, device=p1.device).view(1, p1_count, 1)
43
+ valid = arange < lengths1.to(device=p1.device).view(batch, 1, 1)
44
+ dists_k = dists_k.masked_fill(~valid, torch.inf)
45
+ idx = idx.masked_fill(~valid, -1)
46
+
47
+ knn = None
48
+ if return_nn:
49
+ safe_idx = idx.clamp_min(0)
50
+ gather_idx = safe_idx.unsqueeze(-1).expand(-1, -1, -1, p2.shape[2])
51
+ points = p2.unsqueeze(1).expand(-1, p1_count, -1, -1)
52
+ knn = torch.gather(points, 2, gather_idx)
53
+ knn = knn.masked_fill(idx.unsqueeze(-1) < 0, 0)
54
+
55
+ return KNNResult(dists_k.to(dtype=p1.dtype), idx, knn)
56
+
57
+
58
+ def ball_query(p1, p2, lengths1=None, lengths2=None, K=1, radius=0.2, return_nn=False):
59
+ if p1.dim() != 3 or p2.dim() != 3:
60
+ raise ValueError("p1 and p2 must have shape (N, P, D)")
61
+ if p1.shape[0] != p2.shape[0] or p1.shape[2] != p2.shape[2]:
62
+ raise ValueError("p1 and p2 must have matching batch and point dimensions")
63
+
64
+ batch, p1_count, _ = p1.shape
65
+ p2_count = p2.shape[1]
66
+ k = max(int(K), 0)
67
+ if k == 0:
68
+ empty_dists = p1.new_empty((batch, p1_count, 0))
69
+ empty_idx = torch.empty((batch, p1_count, 0), dtype=torch.long, device=p1.device)
70
+ empty_nn = p2.new_empty((batch, p1_count, 0, p2.shape[2])) if return_nn else None
71
+ return BallQueryResult(empty_dists, empty_idx, empty_nn)
72
+
73
+ dists = torch.cdist(p1.float(), p2.float(), p=2).square()
74
+ max_dist = float(radius) * float(radius)
75
+ valid = dists <= max_dist
76
+
77
+ if lengths2 is not None:
78
+ arange = torch.arange(p2_count, device=p2.device).view(1, 1, p2_count)
79
+ valid = valid & (arange < lengths2.to(device=p2.device).view(batch, 1, 1))
80
+
81
+ masked = dists.masked_fill(~valid, torch.inf)
82
+ take = min(k, p2_count)
83
+ dists_k, idx = torch.topk(masked, k=take, dim=-1, largest=False, sorted=True)
84
+ invalid = torch.isinf(dists_k)
85
+ idx = idx.masked_fill(invalid, -1)
86
+ dists_k = dists_k.masked_fill(invalid, 0)
87
+
88
+ if k > take:
89
+ pad = k - take
90
+ dists_k = torch.cat([dists_k, dists_k.new_zeros((batch, p1_count, pad))], dim=-1)
91
+ idx = torch.cat([idx, idx.new_full((batch, p1_count, pad), -1)], dim=-1)
92
+
93
+ if lengths1 is not None:
94
+ arange = torch.arange(p1_count, device=p1.device).view(1, p1_count, 1)
95
+ valid_p1 = arange < lengths1.to(device=p1.device).view(batch, 1, 1)
96
+ dists_k = dists_k.masked_fill(~valid_p1, 0)
97
+ idx = idx.masked_fill(~valid_p1, -1)
98
+
99
+ knn = None
100
+ if return_nn:
101
+ safe_idx = idx.clamp_min(0)
102
+ gather_idx = safe_idx.unsqueeze(-1).expand(-1, -1, -1, p2.shape[2])
103
+ points = p2.unsqueeze(1).expand(-1, p1_count, -1, -1)
104
+ knn = torch.gather(points, 2, gather_idx)
105
+ knn = knn.masked_fill(idx.unsqueeze(-1) < 0, 0)
106
+
107
+ return BallQueryResult(dists_k.to(dtype=p1.dtype), idx, knn)