naicoi commited on
Commit
a353e2d
Β·
1 Parent(s): f5651ba

Squash merge opencode/gentle-cactus into pr/3

Browse files
Files changed (2) hide show
  1. app.py +5 -1
  2. latentsync/pipelines/lipsync_pipeline.py +199 -44
app.py CHANGED
@@ -3,9 +3,13 @@ OutofLipSync - Lipsync Only Application
3
  Main Gradio UI module
4
  """
5
 
 
 
 
 
 
6
  import logging
7
  import sys
8
- import os
9
  import shutil
10
 
11
  import gradio as gr
 
3
  Main Gradio UI module
4
  """
5
 
6
+ import os
7
+
8
+ # Optimize PyTorch memory allocation to reduce fragmentation
9
+ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
10
+
11
  import logging
12
  import sys
 
13
  import shutil
14
 
15
  import gradio as gr
latentsync/pipelines/lipsync_pipeline.py CHANGED
@@ -59,7 +59,10 @@ class LipsyncPipeline(DiffusionPipeline):
59
  ):
60
  super().__init__()
61
 
62
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
 
 
 
63
  deprecation_message = (
64
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
65
  f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -68,12 +71,17 @@ class LipsyncPipeline(DiffusionPipeline):
68
  " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
69
  " file"
70
  )
71
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
 
 
72
  new_config = dict(scheduler.config)
73
  new_config["steps_offset"] = 1
74
  scheduler._internal_dict = FrozenDict(new_config)
75
 
76
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
 
 
 
77
  deprecation_message = (
78
  f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
79
  " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -81,15 +89,21 @@ class LipsyncPipeline(DiffusionPipeline):
81
  " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
82
  " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
83
  )
84
- deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
 
 
85
  new_config = dict(scheduler.config)
86
  new_config["clip_sample"] = False
87
  scheduler._internal_dict = FrozenDict(new_config)
88
 
89
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
 
 
90
  version.parse(unet.config._diffusers_version).base_version
91
  ) < version.parse("0.9.0.dev0")
92
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
 
 
93
  if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
94
  deprecation_message = (
95
  "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -97,12 +111,14 @@ class LipsyncPipeline(DiffusionPipeline):
97
  " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
98
  " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
99
  " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
100
- " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
101
  " in the config might lead to incorrect results in future versions. If you have downloaded this"
102
  " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
103
  " the `unet/config.json` file"
104
  )
105
- deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
 
 
106
  new_config = dict(unet.config)
107
  new_config["sample_size"] = 64
108
  unet._internal_dict = FrozenDict(new_config)
@@ -138,7 +154,9 @@ class LipsyncPipeline(DiffusionPipeline):
138
  return self.device
139
 
140
  def decode_latents(self, latents):
141
- latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
 
 
142
  latents = rearrange(latents, "b c f h w -> (b f) c h w")
143
  decoded_latents = self.vae.decode(latents).sample
144
  return decoded_latents
@@ -149,13 +167,17 @@ class LipsyncPipeline(DiffusionPipeline):
149
  # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
150
  # and should be between [0, 1]
151
 
152
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
 
153
  extra_step_kwargs = {}
154
  if accepts_eta:
155
  extra_step_kwargs["eta"] = eta
156
 
157
  # check if the scheduler accepts generator
158
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
 
159
  if accepts_generator:
160
  extra_step_kwargs["generator"] = generator
161
  return extra_step_kwargs
@@ -164,17 +186,22 @@ class LipsyncPipeline(DiffusionPipeline):
164
  assert height == width, "Height and width must be equal"
165
 
166
  if height % 8 != 0 or width % 8 != 0:
167
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
 
168
 
169
  if (callback_steps is None) or (
170
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
 
171
  ):
172
  raise ValueError(
173
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
174
  f" {type(callback_steps)}."
175
  )
176
 
177
- def prepare_latents(self, num_frames, num_channels_latents, height, width, dtype, device, generator):
 
 
178
  shape = (
179
  1,
180
  num_channels_latents,
@@ -183,7 +210,9 @@ class LipsyncPipeline(DiffusionPipeline):
183
  width // self.vae_scale_factor,
184
  ) # (b, c, f, h, w)
185
  rand_device = "cpu" if device.type == "mps" else device
186
- latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
 
 
187
  latents = latents.repeat(1, 1, num_frames, 1, 1)
188
 
189
  # scale the initial noise by the standard deviation required by the scheduler
@@ -191,7 +220,15 @@ class LipsyncPipeline(DiffusionPipeline):
191
  return latents
192
 
193
  def prepare_mask_latents(
194
- self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance
 
 
 
 
 
 
 
 
195
  ):
196
  # resize the mask to latents shape as we concatenate the mask to the latents
197
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
@@ -202,8 +239,12 @@ class LipsyncPipeline(DiffusionPipeline):
202
  masked_image = masked_image.to(device=device, dtype=dtype)
203
 
204
  # encode the mask image into latents space so we can concatenate it to the latents
205
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
206
- masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
 
 
 
 
207
 
208
  # aligning device to prevent device errors when concating it with the latent model input
209
  masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
@@ -215,16 +256,26 @@ class LipsyncPipeline(DiffusionPipeline):
215
 
216
  mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
217
  masked_image_latents = (
218
- torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
 
 
219
  )
220
  return mask, masked_image_latents
221
 
222
- def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
 
 
223
  images = images.to(device=device, dtype=dtype)
224
  image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
225
- image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
 
 
226
  image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
227
- image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
 
 
 
 
228
 
229
  return image_latents
230
 
@@ -234,7 +285,9 @@ class LipsyncPipeline(DiffusionPipeline):
234
  self._progress_bar_config.update(kwargs)
235
 
236
  @staticmethod
237
- def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
 
 
238
  # Paste the surrounding pixels back, because we only want to change the mouth region
239
  pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
240
  masks = masks.to(device=device, dtype=weight_dtype)
@@ -263,7 +316,13 @@ class LipsyncPipeline(DiffusionPipeline):
263
  faces = torch.stack(faces)
264
  return faces, boxes, affine_matrices
265
 
266
- def restore_video(self, faces: torch.Tensor, video_frames: np.ndarray, boxes: list, affine_matrices: list):
 
 
 
 
 
 
267
  video_frames = video_frames[: len(faces)]
268
  out_frames = []
269
  print(f"Restoring {len(faces)} faces...")
@@ -272,12 +331,54 @@ class LipsyncPipeline(DiffusionPipeline):
272
  height = int(y2 - y1)
273
  width = int(x2 - x1)
274
  face = torchvision.transforms.functional.resize(
275
- face, size=(height, width), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
 
 
 
 
 
 
276
  )
277
- out_frame = self.image_processor.restorer.restore_img(video_frames[index], face, affine_matrices[index])
278
  out_frames.append(out_frame)
279
  return np.stack(out_frames, axis=0)
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def loop_video(self, whisper_chunks: list, video_frames: np.ndarray):
282
  # If the audio is longer than the video, we need to loop the video
283
  if len(whisper_chunks) > len(video_frames):
@@ -299,7 +400,9 @@ class LipsyncPipeline(DiffusionPipeline):
299
  loop_boxes += boxes[::-1]
300
  loop_affine_matrices += affine_matrices[::-1]
301
 
302
- video_frames = np.concatenate(loop_video_frames, axis=0)[: len(whisper_chunks)]
 
 
303
  faces = torch.cat(loop_faces, dim=0)[: len(whisper_chunks)]
304
  boxes = loop_boxes[: len(whisper_chunks)]
305
  affine_matrices = loop_affine_matrices[: len(whisper_chunks)]
@@ -339,7 +442,9 @@ class LipsyncPipeline(DiffusionPipeline):
339
  # 0. Define call parameters
340
  device = self._execution_device
341
  mask_image = load_fixed_mask(height, mask_image_path)
342
- self.image_processor = ImageProcessor(height, device="cuda", mask_image=mask_image)
 
 
343
  self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
344
 
345
  # 1. Default height and width to unet
@@ -362,12 +467,16 @@ class LipsyncPipeline(DiffusionPipeline):
362
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
363
 
364
  whisper_feature = self.audio_encoder.audio2feat(audio_path)
365
- whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)
 
 
366
 
367
  audio_samples = read_audio(audio_path)
368
  video_frames = read_video(video_path, use_decord=False)
369
 
370
- video_frames, faces, boxes, affine_matrices = self.loop_video(whisper_chunks, video_frames)
 
 
371
 
372
  synced_video_frames = []
373
 
@@ -387,7 +496,9 @@ class LipsyncPipeline(DiffusionPipeline):
387
  num_inferences = math.ceil(len(whisper_chunks) / num_frames)
388
  for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
389
  if self.unet.add_audio_layer:
390
- audio_embeds = torch.stack(whisper_chunks[i * num_frames : (i + 1) * num_frames])
 
 
391
  audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
392
  if do_classifier_free_guidance:
393
  null_audio_embeds = torch.zeros_like(audio_embeds)
@@ -396,8 +507,10 @@ class LipsyncPipeline(DiffusionPipeline):
396
  audio_embeds = None
397
  inference_faces = faces[i * num_frames : (i + 1) * num_frames]
398
  latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
399
- ref_pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
400
- inference_faces, affine_transform=False
 
 
401
  )
402
 
403
  # 7. Prepare mask latent variables
@@ -422,30 +535,48 @@ class LipsyncPipeline(DiffusionPipeline):
422
  )
423
 
424
  # 9. Denoising loop
425
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
 
426
  with self.progress_bar(total=num_inference_steps) as progress_bar:
427
  for j, t in enumerate(timesteps):
428
  # expand the latents if we are doing classifier free guidance
429
- unet_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
 
 
 
430
 
431
  unet_input = self.scheduler.scale_model_input(unet_input, t)
432
 
433
  # concat latents, mask, masked_image_latents in the channel dimension
434
- unet_input = torch.cat([unet_input, mask_latents, masked_image_latents, ref_latents], dim=1)
 
 
 
435
 
436
  # predict the noise residual
437
- noise_pred = self.unet(unet_input, t, encoder_hidden_states=audio_embeds).sample
 
 
438
 
439
  # perform guidance
440
  if do_classifier_free_guidance:
441
  noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
442
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)
 
 
443
 
444
  # compute the previous noisy sample x_t -> x_t-1
445
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
 
446
 
447
  # call the callback, if provided
448
- if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
 
 
 
449
  progress_bar.update()
450
  if callback is not None and j % callback_steps == 0:
451
  callback(j, t, latents)
@@ -455,11 +586,33 @@ class LipsyncPipeline(DiffusionPipeline):
455
  decoded_latents = self.paste_surrounding_pixels_back(
456
  decoded_latents, ref_pixel_values, 1 - masks, device, weight_dtype
457
  )
458
- synced_video_frames.append(decoded_latents)
459
 
460
- synced_video_frames = self.restore_video(torch.cat(synced_video_frames), video_frames, boxes, affine_matrices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
 
462
- audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
 
 
463
  audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()
464
 
465
  if is_train:
@@ -469,7 +622,9 @@ class LipsyncPipeline(DiffusionPipeline):
469
  shutil.rmtree(temp_dir)
470
  os.makedirs(temp_dir, exist_ok=True)
471
 
472
- write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=video_fps)
 
 
473
 
474
  sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)
475
 
 
59
  ):
60
  super().__init__()
61
 
62
+ if (
63
+ hasattr(scheduler.config, "steps_offset")
64
+ and scheduler.config.steps_offset != 1
65
+ ):
66
  deprecation_message = (
67
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
68
  f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
 
71
  " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
72
  " file"
73
  )
74
+ deprecate(
75
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
76
+ )
77
  new_config = dict(scheduler.config)
78
  new_config["steps_offset"] = 1
79
  scheduler._internal_dict = FrozenDict(new_config)
80
 
81
+ if (
82
+ hasattr(scheduler.config, "clip_sample")
83
+ and scheduler.config.clip_sample is True
84
+ ):
85
  deprecation_message = (
86
  f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
87
  " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
 
89
  " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
90
  " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
91
  )
92
+ deprecate(
93
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
94
+ )
95
  new_config = dict(scheduler.config)
96
  new_config["clip_sample"] = False
97
  scheduler._internal_dict = FrozenDict(new_config)
98
 
99
+ is_unet_version_less_0_9_0 = hasattr(
100
+ unet.config, "_diffusers_version"
101
+ ) and version.parse(
102
  version.parse(unet.config._diffusers_version).base_version
103
  ) < version.parse("0.9.0.dev0")
104
+ is_unet_sample_size_less_64 = (
105
+ hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
106
+ )
107
  if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
108
  deprecation_message = (
109
  "The configuration file of the unet has set the default `sample_size` to smaller than"
 
111
  " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
112
  " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
113
  " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
114
+ " configuration file. Please make sure to update the config accordingly as leaving 'sample_size=32'"
115
  " in the config might lead to incorrect results in future versions. If you have downloaded this"
116
  " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
117
  " the `unet/config.json` file"
118
  )
119
+ deprecate(
120
+ "sample_size<64", "1.0.0", deprecation_message, standard_warn=False
121
+ )
122
  new_config = dict(unet.config)
123
  new_config["sample_size"] = 64
124
  unet._internal_dict = FrozenDict(new_config)
 
154
  return self.device
155
 
156
  def decode_latents(self, latents):
157
+ latents = (
158
+ latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
159
+ )
160
  latents = rearrange(latents, "b c f h w -> (b f) c h w")
161
  decoded_latents = self.vae.decode(latents).sample
162
  return decoded_latents
 
167
  # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
168
  # and should be between [0, 1]
169
 
170
+ accepts_eta = "eta" in set(
171
+ inspect.signature(self.scheduler.step).parameters.keys()
172
+ )
173
  extra_step_kwargs = {}
174
  if accepts_eta:
175
  extra_step_kwargs["eta"] = eta
176
 
177
  # check if the scheduler accepts generator
178
+ accepts_generator = "generator" in set(
179
+ inspect.signature(self.scheduler.step).parameters.keys()
180
+ )
181
  if accepts_generator:
182
  extra_step_kwargs["generator"] = generator
183
  return extra_step_kwargs
 
186
  assert height == width, "Height and width must be equal"
187
 
188
  if height % 8 != 0 or width % 8 != 0:
189
+ raise ValueError(
190
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
191
+ )
192
 
193
  if (callback_steps is None) or (
194
+ callback_steps is not None
195
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
196
  ):
197
  raise ValueError(
198
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
199
  f" {type(callback_steps)}."
200
  )
201
 
202
+ def prepare_latents(
203
+ self, num_frames, num_channels_latents, height, width, dtype, device, generator
204
+ ):
205
  shape = (
206
  1,
207
  num_channels_latents,
 
210
  width // self.vae_scale_factor,
211
  ) # (b, c, f, h, w)
212
  rand_device = "cpu" if device.type == "mps" else device
213
+ latents = torch.randn(
214
+ shape, generator=generator, device=rand_device, dtype=dtype
215
+ ).to(device)
216
  latents = latents.repeat(1, 1, num_frames, 1, 1)
217
 
218
  # scale the initial noise by the standard deviation required by the scheduler
 
220
  return latents
221
 
222
  def prepare_mask_latents(
223
+ self,
224
+ mask,
225
+ masked_image,
226
+ height,
227
+ width,
228
+ dtype,
229
+ device,
230
+ generator,
231
+ do_classifier_free_guidance,
232
  ):
233
  # resize the mask to latents shape as we concatenate the mask to the latents
234
  # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
 
239
  masked_image = masked_image.to(device=device, dtype=dtype)
240
 
241
  # encode the mask image into latents space so we can concatenate it to the latents
242
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(
243
+ generator=generator
244
+ )
245
+ masked_image_latents = (
246
+ masked_image_latents - self.vae.config.shift_factor
247
+ ) * self.vae.config.scaling_factor
248
 
249
  # aligning device to prevent device errors when concating it with the latent model input
250
  masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
 
256
 
257
  mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
258
  masked_image_latents = (
259
+ torch.cat([masked_image_latents] * 2)
260
+ if do_classifier_free_guidance
261
+ else masked_image_latents
262
  )
263
  return mask, masked_image_latents
264
 
265
+ def prepare_image_latents(
266
+ self, images, device, dtype, generator, do_classifier_free_guidance
267
+ ):
268
  images = images.to(device=device, dtype=dtype)
269
  image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
270
+ image_latents = (
271
+ image_latents - self.vae.config.shift_factor
272
+ ) * self.vae.config.scaling_factor
273
  image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
274
+ image_latents = (
275
+ torch.cat([image_latents] * 2)
276
+ if do_classifier_free_guidance
277
+ else image_latents
278
+ )
279
 
280
  return image_latents
281
 
 
285
  self._progress_bar_config.update(kwargs)
286
 
287
  @staticmethod
288
+ def paste_surrounding_pixels_back(
289
+ decoded_latents, pixel_values, masks, device, weight_dtype
290
+ ):
291
  # Paste the surrounding pixels back, because we only want to change the mouth region
292
  pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
293
  masks = masks.to(device=device, dtype=weight_dtype)
 
316
  faces = torch.stack(faces)
317
  return faces, boxes, affine_matrices
318
 
319
+ def restore_video(
320
+ self,
321
+ faces: torch.Tensor,
322
+ video_frames: np.ndarray,
323
+ boxes: list,
324
+ affine_matrices: list,
325
+ ):
326
  video_frames = video_frames[: len(faces)]
327
  out_frames = []
328
  print(f"Restoring {len(faces)} faces...")
 
331
  height = int(y2 - y1)
332
  width = int(x2 - x1)
333
  face = torchvision.transforms.functional.resize(
334
+ face,
335
+ size=(height, width),
336
+ interpolation=transforms.InterpolationMode.BICUBIC,
337
+ antialias=True,
338
+ )
339
+ out_frame = self.image_processor.restorer.restore_img(
340
+ video_frames[index], face, affine_matrices[index]
341
  )
 
342
  out_frames.append(out_frame)
343
  return np.stack(out_frames, axis=0)
344
 
345
+ def restore_video_from_cpu(
346
+ self,
347
+ faces_list: List[torch.Tensor],
348
+ video_frames: np.ndarray,
349
+ boxes: list,
350
+ affine_matrices: list,
351
+ ):
352
+ """Restore video when faces are stored on CPU to save GPU memory"""
353
+ video_frames = video_frames[: len(faces_list)]
354
+ out_frames = []
355
+ device = self._execution_device
356
+ print(f"Restoring {len(faces_list)} faces from CPU to GPU {device}...")
357
+
358
+ for index, face_cpu in enumerate(tqdm.tqdm(faces_list)):
359
+ # Move frame to GPU only when needed for restoration
360
+ face = face_cpu.to(device=device, dtype=torch.float16)
361
+
362
+ x1, y1, x2, y2 = boxes[index]
363
+ height = int(y2 - y1)
364
+ width = int(x2 - x1)
365
+ face = torchvision.transforms.functional.resize(
366
+ face,
367
+ size=(height, width),
368
+ interpolation=transforms.InterpolationMode.BICUBIC,
369
+ antialias=True,
370
+ )
371
+ out_frame = self.image_processor.restorer.restore_img(
372
+ video_frames[index], face, affine_matrices[index]
373
+ )
374
+ out_frames.append(out_frame)
375
+
376
+ # Explicitly free GPU memory for this frame
377
+ del face
378
+ torch.cuda.empty_cache()
379
+
380
+ return np.stack(out_frames, axis=0)
381
+
382
  def loop_video(self, whisper_chunks: list, video_frames: np.ndarray):
383
  # If the audio is longer than the video, we need to loop the video
384
  if len(whisper_chunks) > len(video_frames):
 
400
  loop_boxes += boxes[::-1]
401
  loop_affine_matrices += affine_matrices[::-1]
402
 
403
+ video_frames = np.concatenate(loop_video_frames, axis=0)[
404
+ : len(whisper_chunks)
405
+ ]
406
  faces = torch.cat(loop_faces, dim=0)[: len(whisper_chunks)]
407
  boxes = loop_boxes[: len(whisper_chunks)]
408
  affine_matrices = loop_affine_matrices[: len(whisper_chunks)]
 
442
  # 0. Define call parameters
443
  device = self._execution_device
444
  mask_image = load_fixed_mask(height, mask_image_path)
445
+ self.image_processor = ImageProcessor(
446
+ height, device="cuda", mask_image=mask_image
447
+ )
448
  self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
449
 
450
  # 1. Default height and width to unet
 
467
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
468
 
469
  whisper_feature = self.audio_encoder.audio2feat(audio_path)
470
+ whisper_chunks = self.audio_encoder.feature2chunks(
471
+ feature_array=whisper_feature, fps=video_fps
472
+ )
473
 
474
  audio_samples = read_audio(audio_path)
475
  video_frames = read_video(video_path, use_decord=False)
476
 
477
+ video_frames, faces, boxes, affine_matrices = self.loop_video(
478
+ whisper_chunks, video_frames
479
+ )
480
 
481
  synced_video_frames = []
482
 
 
496
  num_inferences = math.ceil(len(whisper_chunks) / num_frames)
497
  for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
498
  if self.unet.add_audio_layer:
499
+ audio_embeds = torch.stack(
500
+ whisper_chunks[i * num_frames : (i + 1) * num_frames]
501
+ )
502
  audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
503
  if do_classifier_free_guidance:
504
  null_audio_embeds = torch.zeros_like(audio_embeds)
 
507
  audio_embeds = None
508
  inference_faces = faces[i * num_frames : (i + 1) * num_frames]
509
  latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
510
+ ref_pixel_values, masked_pixel_values, masks = (
511
+ self.image_processor.prepare_masks_and_masked_images(
512
+ inference_faces, affine_transform=False
513
+ )
514
  )
515
 
516
  # 7. Prepare mask latent variables
 
535
  )
536
 
537
  # 9. Denoising loop
538
+ num_warmup_steps = (
539
+ len(timesteps) - num_inference_steps * self.scheduler.order
540
+ )
541
  with self.progress_bar(total=num_inference_steps) as progress_bar:
542
  for j, t in enumerate(timesteps):
543
  # expand the latents if we are doing classifier free guidance
544
+ unet_input = (
545
+ torch.cat([latents] * 2)
546
+ if do_classifier_free_guidance
547
+ else latents
548
+ )
549
 
550
  unet_input = self.scheduler.scale_model_input(unet_input, t)
551
 
552
  # concat latents, mask, masked_image_latents in the channel dimension
553
+ unet_input = torch.cat(
554
+ [unet_input, mask_latents, masked_image_latents, ref_latents],
555
+ dim=1,
556
+ )
557
 
558
  # predict the noise residual
559
+ noise_pred = self.unet(
560
+ unet_input, t, encoder_hidden_states=audio_embeds
561
+ ).sample
562
 
563
  # perform guidance
564
  if do_classifier_free_guidance:
565
  noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
566
+ noise_pred = noise_pred_uncond + guidance_scale * (
567
+ noise_pred_audio - noise_pred_uncond
568
+ )
569
 
570
  # compute the previous noisy sample x_t -> x_t-1
571
+ latents = self.scheduler.step(
572
+ noise_pred, t, latents, **extra_step_kwargs
573
+ ).prev_sample
574
 
575
  # call the callback, if provided
576
+ if j == len(timesteps) - 1 or (
577
+ (j + 1) > num_warmup_steps
578
+ and (j + 1) % self.scheduler.order == 0
579
+ ):
580
  progress_bar.update()
581
  if callback is not None and j % callback_steps == 0:
582
  callback(j, t, latents)
 
586
  decoded_latents = self.paste_surrounding_pixels_back(
587
  decoded_latents, ref_pixel_values, 1 - masks, device, weight_dtype
588
  )
 
589
 
590
+ # Move decoded latents to CPU to save GPU memory
591
+ decoded_latents_cpu = decoded_latents.cpu()
592
+ synced_video_frames.append(decoded_latents_cpu)
593
+
594
+ # Explicitly clear GPU memory
595
+ del decoded_latents
596
+ del ref_pixel_values
597
+ del masked_pixel_values
598
+ del masks
599
+ del mask_latents
600
+ del masked_image_latents
601
+ del ref_latents
602
+ del latents
603
+ if do_classifier_free_guidance:
604
+ del noise_pred_uncond, noise_pred_audio
605
+ del noise_pred
606
+ torch.cuda.empty_cache()
607
+
608
+ # Restore video from CPU tensors to save GPU memory
609
+ synced_video_frames = self.restore_video_from_cpu(
610
+ synced_video_frames, video_frames, boxes, affine_matrices
611
+ )
612
 
613
+ audio_samples_remain_length = int(
614
+ synced_video_frames.shape[0] / video_fps * audio_sample_rate
615
+ )
616
  audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()
617
 
618
  if is_train:
 
622
  shutil.rmtree(temp_dir)
623
  os.makedirs(temp_dir, exist_ok=True)
624
 
625
+ write_video(
626
+ os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=video_fps
627
+ )
628
 
629
  sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)
630