wlyu-adobe committed on
Commit
540ef6e
·
1 Parent(s): 5ca8dc1

Add CUDA memory management and reduce resolution for ZeroGPU compatibility

Browse files
Files changed (1) hide show
  1. app.py +19 -6
app.py CHANGED
@@ -93,7 +93,7 @@ class FaceLiftPipeline:
93
 
94
  # Parameters
95
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
96
- self.image_size = 512
97
  self.camera_indices = [2, 1, 0, 5, 4, 3]
98
 
99
  # Load models (keep on CPU for ZeroGPU compatibility)
@@ -147,8 +147,10 @@ class FaceLiftPipeline:
147
  self.mvdiffusion_pipeline.to(self.device)
148
  self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
149
  self.gs_lrm_model.to(self.device)
 
150
  self.color_prompt_embedding = self.color_prompt_embedding.to(self.device)
151
  self._models_on_gpu = True
 
152
  print("Models on GPU, xformers enabled!")
153
 
154
  @spaces.GPU(duration=120)
@@ -198,6 +200,9 @@ class FaceLiftPipeline:
198
  multiview_path = output_dir / "multiview.png"
199
  multiview_image.save(multiview_path)
200
 
 
 
 
201
  # Prepare 3D reconstruction input
202
  view_arrays = [np.array(view) for view in selected_views]
203
  lrm_input = torch.from_numpy(np.stack(view_arrays, axis=0)).float()
@@ -227,12 +232,15 @@ class FaceLiftPipeline:
227
  })
228
 
229
  # Run 3D reconstruction
230
- with torch.autocast(enabled=True, device_type="cuda", dtype=torch.float16):
231
  result = self.gs_lrm_model.forward(batch, create_visual=False, split_data=True)
232
 
233
  comp_image = result.render[0].unsqueeze(0).detach()
234
  gaussians = result.gaussians[0]
235
 
 
 
 
236
  # Save filtered gaussians
237
  filtered_gaussians = gaussians.apply_all_filters(
238
  cam_origins=None,
@@ -252,15 +260,20 @@ class FaceLiftPipeline:
252
  output_path = output_dir / "output.png"
253
  Image.fromarray(comp_image).save(output_path)
254
 
255
- # Generate turntable video
256
- turntable_frames = render_turntable(gaussians, rendering_resolution=self.image_size,
257
- num_views=180)
258
- turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=180)
 
 
259
  turntable_frames = np.ascontiguousarray(turntable_frames)
260
 
261
  turntable_path = output_dir / "turntable.mp4"
262
  imageseq2video(turntable_frames, str(turntable_path), fps=30)
263
 
 
 
 
264
  return str(input_path), str(multiview_path), str(output_path), \
265
  str(turntable_path), str(ply_path)
266
 
 
93
 
94
  # Parameters
95
  self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
96
+ self.image_size = 384 # Reduced from 512 for ZeroGPU memory constraints
97
  self.camera_indices = [2, 1, 0, 5, 4, 3]
98
 
99
  # Load models (keep on CPU for ZeroGPU compatibility)
 
147
  self.mvdiffusion_pipeline.to(self.device)
148
  self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
149
  self.gs_lrm_model.to(self.device)
150
+ self.gs_lrm_model.eval() # Set to eval mode
151
  self.color_prompt_embedding = self.color_prompt_embedding.to(self.device)
152
  self._models_on_gpu = True
153
+ torch.cuda.empty_cache() # Clear cache after moving models
154
  print("Models on GPU, xformers enabled!")
155
 
156
  @spaces.GPU(duration=120)
 
200
  multiview_path = output_dir / "multiview.png"
201
  multiview_image.save(multiview_path)
202
 
203
+ # Clear CUDA cache after diffusion to free memory
204
+ torch.cuda.empty_cache()
205
+
206
  # Prepare 3D reconstruction input
207
  view_arrays = [np.array(view) for view in selected_views]
208
  lrm_input = torch.from_numpy(np.stack(view_arrays, axis=0)).float()
 
232
  })
233
 
234
  # Run 3D reconstruction
235
+ with torch.no_grad(), torch.autocast(enabled=True, device_type="cuda", dtype=torch.float16):
236
  result = self.gs_lrm_model.forward(batch, create_visual=False, split_data=True)
237
 
238
  comp_image = result.render[0].unsqueeze(0).detach()
239
  gaussians = result.gaussians[0]
240
 
241
+ # Clear CUDA cache after reconstruction
242
+ torch.cuda.empty_cache()
243
+
244
  # Save filtered gaussians
245
  filtered_gaussians = gaussians.apply_all_filters(
246
  cam_origins=None,
 
260
  output_path = output_dir / "output.png"
261
  Image.fromarray(comp_image).save(output_path)
262
 
263
+ # Generate turntable video (reduced resolution and frames for ZeroGPU memory limits)
264
+ turntable_resolution = 256 # Lower resolution for turntable to save memory
265
+ num_turntable_views = 120 # Reduced from 180
266
+ turntable_frames = render_turntable(gaussians, rendering_resolution=turntable_resolution,
267
+ num_views=num_turntable_views)
268
+ turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=num_turntable_views)
269
  turntable_frames = np.ascontiguousarray(turntable_frames)
270
 
271
  turntable_path = output_dir / "turntable.mp4"
272
  imageseq2video(turntable_frames, str(turntable_path), fps=30)
273
 
274
+ # Final CUDA cache clear
275
+ torch.cuda.empty_cache()
276
+
277
  return str(input_path), str(multiview_path), str(output_path), \
278
  str(turntable_path), str(ply_path)
279