Ruining Li commited on
Commit
07ed33a
·
1 Parent(s): 656937d

Remove lambdas in model imple to enable pickling for Zero GPU

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. particulate/models.py +18 -12
app.py CHANGED
@@ -138,7 +138,7 @@ class ParticulateApp:
138
  *[None] * 9
139
  )
140
 
141
- @spaces.GPU
142
  def _predict_impl(
143
  self,
144
  mesh,
 
138
  *[None] * 9
139
  )
140
 
141
+ @spaces.GPU(duration=10)
142
  def _predict_impl(
143
  self,
144
  mesh,
particulate/models.py CHANGED
@@ -187,12 +187,7 @@ class Articulate3D(nn.Module):
187
  nn.SiLU(),
188
  nn.Linear(hidden_size * 4, 1)
189
  )
190
- self.point_mask_decoding_func = lambda p, q: [(
191
- self.point_mask_decoder(torch.cat([
192
- p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D)
193
- q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D)
194
- ], dim=-1)).squeeze(-1)
195
- )]
196
  else:
197
  self.point_mask_decoder = nn.ModuleList([
198
  nn.Sequential(
@@ -202,12 +197,7 @@ class Articulate3D(nn.Module):
202
  )
203
  for _ in range(self.num_mask_hypotheses)
204
  ])
205
- self.point_mask_decoding_func = lambda p, q: [(
206
- self.point_mask_decoder[i](torch.cat([
207
- p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D)
208
- q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D)
209
- ], dim=-1)).squeeze(-1)
210
- ) for i in range(self.num_mask_hypotheses)]
211
 
212
  self.use_point_features_for_motion_decoding = use_point_features_for_motion_decoding
213
  self.point_feature_random_ratio = point_feature_random_ratio
@@ -261,6 +251,22 @@ class Articulate3D(nn.Module):
261
 
262
  self.matcher = HungarianMatcher()
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def forward_attn(
265
  self,
266
  xyz: torch.FloatTensor,
 
187
  nn.SiLU(),
188
  nn.Linear(hidden_size * 4, 1)
189
  )
190
+ self.point_mask_decoding_func = self._point_mask_decoding_func_single
 
 
 
 
 
191
  else:
192
  self.point_mask_decoder = nn.ModuleList([
193
  nn.Sequential(
 
197
  )
198
  for _ in range(self.num_mask_hypotheses)
199
  ])
200
+ self.point_mask_decoding_func = self._point_mask_decoding_func_multi
 
 
 
 
 
201
 
202
  self.use_point_features_for_motion_decoding = use_point_features_for_motion_decoding
203
  self.point_feature_random_ratio = point_feature_random_ratio
 
251
 
252
  self.matcher = HungarianMatcher()
253
 
254
+ def _point_mask_decoding_func_single(self, p, q):
255
+ return [(
256
+ self.point_mask_decoder(torch.cat([
257
+ p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D)
258
+ q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D)
259
+ ], dim=-1)).squeeze(-1)
260
+ )]
261
+
262
+ def _point_mask_decoding_func_multi(self, p, q):
263
+ return [(
264
+ self.point_mask_decoder[i](torch.cat([
265
+ p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D)
266
+ q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D)
267
+ ], dim=-1)).squeeze(-1)
268
+ ) for i in range(self.num_mask_hypotheses)]
269
+
270
  def forward_attn(
271
  self,
272
  xyz: torch.FloatTensor,