Spaces:
Running on Zero
Running on Zero
Ruining Li commited on
Commit ·
07ed33a
1
Parent(s): 656937d
Remove lambdas in model imple to enable pickling for Zero GPU
Browse files- app.py +1 -1
- particulate/models.py +18 -12
app.py
CHANGED
|
@@ -138,7 +138,7 @@ class ParticulateApp:
|
|
| 138 |
*[None] * 9
|
| 139 |
)
|
| 140 |
|
| 141 |
-
@spaces.GPU
|
| 142 |
def _predict_impl(
|
| 143 |
self,
|
| 144 |
mesh,
|
|
|
|
| 138 |
*[None] * 9
|
| 139 |
)
|
| 140 |
|
| 141 |
+
@spaces.GPU(duration=10)
|
| 142 |
def _predict_impl(
|
| 143 |
self,
|
| 144 |
mesh,
|
particulate/models.py
CHANGED
|
@@ -187,12 +187,7 @@ class Articulate3D(nn.Module):
|
|
| 187 |
nn.SiLU(),
|
| 188 |
nn.Linear(hidden_size * 4, 1)
|
| 189 |
)
|
| 190 |
-
self.point_mask_decoding_func =
|
| 191 |
-
self.point_mask_decoder(torch.cat([
|
| 192 |
-
p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D)
|
| 193 |
-
q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D)
|
| 194 |
-
], dim=-1)).squeeze(-1)
|
| 195 |
-
)]
|
| 196 |
else:
|
| 197 |
self.point_mask_decoder = nn.ModuleList([
|
| 198 |
nn.Sequential(
|
|
@@ -202,12 +197,7 @@ class Articulate3D(nn.Module):
|
|
| 202 |
)
|
| 203 |
for _ in range(self.num_mask_hypotheses)
|
| 204 |
])
|
| 205 |
-
self.point_mask_decoding_func =
|
| 206 |
-
self.point_mask_decoder[i](torch.cat([
|
| 207 |
-
p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D)
|
| 208 |
-
q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D)
|
| 209 |
-
], dim=-1)).squeeze(-1)
|
| 210 |
-
) for i in range(self.num_mask_hypotheses)]
|
| 211 |
|
| 212 |
self.use_point_features_for_motion_decoding = use_point_features_for_motion_decoding
|
| 213 |
self.point_feature_random_ratio = point_feature_random_ratio
|
|
@@ -261,6 +251,22 @@ class Articulate3D(nn.Module):
|
|
| 261 |
|
| 262 |
self.matcher = HungarianMatcher()
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
def forward_attn(
|
| 265 |
self,
|
| 266 |
xyz: torch.FloatTensor,
|
|
|
|
| 187 |
nn.SiLU(),
|
| 188 |
nn.Linear(hidden_size * 4, 1)
|
| 189 |
)
|
| 190 |
+
self.point_mask_decoding_func = self._point_mask_decoding_func_single
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
else:
|
| 192 |
self.point_mask_decoder = nn.ModuleList([
|
| 193 |
nn.Sequential(
|
|
|
|
| 197 |
)
|
| 198 |
for _ in range(self.num_mask_hypotheses)
|
| 199 |
])
|
| 200 |
+
self.point_mask_decoding_func = self._point_mask_decoding_func_multi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
self.use_point_features_for_motion_decoding = use_point_features_for_motion_decoding
|
| 203 |
self.point_feature_random_ratio = point_feature_random_ratio
|
|
|
|
| 251 |
|
| 252 |
self.matcher = HungarianMatcher()
|
| 253 |
|
| 254 |
+
def _point_mask_decoding_func_single(self, p, q):
|
| 255 |
+
return [(
|
| 256 |
+
self.point_mask_decoder(torch.cat([
|
| 257 |
+
p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D)
|
| 258 |
+
q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D)
|
| 259 |
+
], dim=-1)).squeeze(-1)
|
| 260 |
+
)]
|
| 261 |
+
|
| 262 |
+
def _point_mask_decoding_func_multi(self, p, q):
|
| 263 |
+
return [(
|
| 264 |
+
self.point_mask_decoder[i](torch.cat([
|
| 265 |
+
p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D)
|
| 266 |
+
q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D)
|
| 267 |
+
], dim=-1)).squeeze(-1)
|
| 268 |
+
) for i in range(self.num_mask_hypotheses)]
|
| 269 |
+
|
| 270 |
def forward_attn(
|
| 271 |
self,
|
| 272 |
xyz: torch.FloatTensor,
|