app.py CHANGED
@@ -26,11 +26,19 @@ sys.modules["torchvision.transforms.functional_tensor"] = _F
26
  os.environ["PROCESSED_RESULTS"] = os.path.join(os.getcwd(), "processed_results")
27
  os.makedirs(os.environ["PROCESSED_RESULTS"], exist_ok=True)
28
 
29
- src = "checkpoints"
30
  dst = os.path.expanduser("~/.cache/torch/hub/checkpoints")
31
 
32
  os.makedirs(dst, exist_ok=True)
33
 
 
 
 
 
 
 
 
 
34
  print("Done copying checkpoints!")
35
 
36
  print("Loading LatentSync models...")
 
26
  os.environ["PROCESSED_RESULTS"] = os.path.join(os.getcwd(), "processed_results")
27
  os.makedirs(os.environ["PROCESSED_RESULTS"], exist_ok=True)
28
 
29
# Copy model files from the models directory into
# ~/.cache/torch/hub/checkpoints, skipping anything already present there.
import shutil  # imported here because the copy below is its only visible user

# Same environment override and default as MODELS_DIR in config.py.
src = os.getenv("MODELS_DIR", "/models")
dst = os.path.expanduser("~/.cache/torch/hub/checkpoints")

os.makedirs(dst, exist_ok=True)

# isdir (not exists) so a stray regular file at `src` never reaches listdir.
if os.path.isdir(src):
    for item in os.listdir(src):
        src_path = os.path.join(src, item)
        dst_path = os.path.join(dst, item)
        # Only plain files are copied; sub-directories are ignored and
        # files already in the destination are never overwritten.
        if os.path.isfile(src_path) and not os.path.exists(dst_path):
            shutil.copy2(src_path, dst_path)
            print(f"Copied {item} to {dst}")

print("Done copying checkpoints!")
43
 
44
  print("Loading LatentSync models...")
config.py CHANGED
@@ -1,5 +1,10 @@
1
  """Configuration constants and global settings for OutofLipSync"""
2
 
 
 
 
 
 
3
  # Video settings
4
  DEFAULT_DURATION = 10
5
  MIN_DURATION = 5
 
1
  """Configuration constants and global settings for OutofLipSync"""
2
 
3
import os

# Directory that holds model files. The MODELS_DIR environment variable
# takes precedence; "/models" is used when it is not set.
MODELS_DIR = os.environ.get("MODELS_DIR", "/models")
7
+
8
  # Video settings
9
  DEFAULT_DURATION = 10
10
  MIN_DURATION = 5
eval/eval_syncnet_acc.py CHANGED
@@ -13,6 +13,8 @@
13
  # limitations under the License.
14
 
15
  import argparse
 
 
16
  from tqdm.auto import tqdm
17
  import torch
18
  import torch.nn as nn
@@ -23,6 +25,9 @@ from diffusers import AutoencoderKL
23
  from omegaconf import OmegaConf
24
  from accelerate.utils import set_seed
25
 
 
 
 
26
 
27
  def main(config):
28
  set_seed(config.run.seed)
@@ -31,13 +36,19 @@ def main(config):
31
 
32
  if config.data.latent_space:
33
  vae = AutoencoderKL.from_pretrained(
34
- "runwayml/stable-diffusion-inpainting", subfolder="vae", revision="fp16", torch_dtype=torch.float16
 
 
 
 
35
  )
36
  vae.requires_grad_(False)
37
  vae.to(device)
38
 
39
  # Dataset and Dataloader setup
40
- dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)
 
 
41
 
42
  test_dataloader = torch.utils.data.DataLoader(
43
  dataset,
@@ -52,7 +63,9 @@ def main(config):
52
  syncnet = StableSyncNet(OmegaConf.to_container(config.model)).to(device)
53
 
54
  print(f"Load checkpoint from: {config.ckpt.inference_ckpt_path}")
55
- checkpoint = torch.load(config.ckpt.inference_ckpt_path, map_location=device, weights_only=True)
 
 
56
 
57
  syncnet.load_state_dict(checkpoint["state_dict"])
58
  syncnet.to(dtype=torch.float16)
@@ -80,7 +93,9 @@ def main(config):
80
  with torch.no_grad():
81
  frames = vae.encode(frames).latent_dist.sample() * 0.18215
82
 
83
- frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
 
 
84
  else:
85
  frames = rearrange(frames, "b f c h w -> b (f c) h w")
86
 
@@ -102,14 +117,18 @@ def main(config):
102
 
103
  if global_step >= num_val_batches:
104
  progress_bar.close()
105
- print(f"SyncNet Accuracy: {num_correct_preds / num_total_preds*100:.2f}%")
 
 
106
  return
107
 
108
 
109
  if __name__ == "__main__":
110
  parser = argparse.ArgumentParser(description="Code to test the accuracy of SyncNet")
111
 
112
- parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_latent.yaml")
 
 
113
  args = parser.parse_args()
114
 
115
  # Load a configuration file
 
13
  # limitations under the License.
14
 
15
  import argparse
16
+ import os
17
+ import sys
18
  from tqdm.auto import tqdm
19
  import torch
20
  import torch.nn as nn
 
25
  from omegaconf import OmegaConf
26
  from accelerate.utils import set_seed
27
 
28
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
29
+ from config import MODELS_DIR
30
+
31
 
32
  def main(config):
33
  set_seed(config.run.seed)
 
36
 
37
  if config.data.latent_space:
38
  vae = AutoencoderKL.from_pretrained(
39
+ "runwayml/stable-diffusion-inpainting",
40
+ subfolder="vae",
41
+ revision="fp16",
42
+ torch_dtype=torch.float16,
43
+ cache_dir=MODELS_DIR,
44
  )
45
  vae.requires_grad_(False)
46
  vae.to(device)
47
 
48
  # Dataset and Dataloader setup
49
+ dataset = SyncNetDataset(
50
+ config.data.val_data_dir, config.data.val_fileslist, config
51
+ )
52
 
53
  test_dataloader = torch.utils.data.DataLoader(
54
  dataset,
 
63
  syncnet = StableSyncNet(OmegaConf.to_container(config.model)).to(device)
64
 
65
  print(f"Load checkpoint from: {config.ckpt.inference_ckpt_path}")
66
+ checkpoint = torch.load(
67
+ config.ckpt.inference_ckpt_path, map_location=device, weights_only=True
68
+ )
69
 
70
  syncnet.load_state_dict(checkpoint["state_dict"])
71
  syncnet.to(dtype=torch.float16)
 
93
  with torch.no_grad():
94
  frames = vae.encode(frames).latent_dist.sample() * 0.18215
95
 
96
+ frames = rearrange(
97
+ frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames
98
+ )
99
  else:
100
  frames = rearrange(frames, "b f c h w -> b (f c) h w")
101
 
 
117
 
118
  if global_step >= num_val_batches:
119
  progress_bar.close()
120
+ print(
121
+ f"SyncNet Accuracy: {num_correct_preds / num_total_preds * 100:.2f}%"
122
+ )
123
  return
124
 
125
 
126
  if __name__ == "__main__":
127
  parser = argparse.ArgumentParser(description="Code to test the accuracy of SyncNet")
128
 
129
+ parser.add_argument(
130
+ "--config_path", type=str, default="configs/syncnet/syncnet_16_latent.yaml"
131
+ )
132
  args = parser.parse_args()
133
 
134
  # Load a configuration file
latentsync/utils/util.py CHANGED
@@ -49,9 +49,7 @@ def read_video(video_path: str, change_fps=True, use_decord=True):
49
  if os.path.exists(temp_dir):
50
  shutil.rmtree(temp_dir)
51
  os.makedirs(temp_dir, exist_ok=True)
52
- command = (
53
- f"ffmpeg -loglevel error -y -nostdin -i {video_path} -r 25 -crf 18 {os.path.join(temp_dir, 'video.mp4')}"
54
- )
55
  subprocess.run(command, shell=True)
56
  target_video_path = os.path.join(temp_dir, "video.mp4")
57
  else:
@@ -127,7 +125,9 @@ def write_video(video_output_path: str, video_frames: np.ndarray, fps: int):
127
 
128
  def write_video_cv2(video_output_path: str, video_frames: np.ndarray, fps: int):
129
  height, width = video_frames[0].shape[:2]
130
- out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
 
 
131
  # out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"vp09"), fps, (width, height))
132
  for frame in video_frames:
133
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
@@ -162,7 +162,9 @@ def check_video_fps(video_path: str):
162
  cam = cv2.VideoCapture(video_path)
163
  fps = cam.get(cv2.CAP_PROP_FPS)
164
  if fps != 25:
165
- raise ValueError(f"Video FPS is not 25, it is {fps}. Please convert the video to 25 FPS.")
 
 
166
 
167
 
168
  def one_step_sampling(ddim_scheduler, pred_noise, timesteps, x_t):
@@ -175,7 +177,9 @@ def one_step_sampling(ddim_scheduler, pred_noise, timesteps, x_t):
175
  if ddim_scheduler.config.prediction_type == "epsilon":
176
  beta_prod_t = beta_prod_t[:, None, None, None, None]
177
  alpha_prod_t = alpha_prod_t[:, None, None, None, None]
178
- pred_original_sample = (x_t - beta_prod_t ** (0.5) * pred_noise) / alpha_prod_t ** (0.5)
 
 
179
  else:
180
  raise NotImplementedError("This prediction type is not implemented yet")
181
 
@@ -269,16 +273,25 @@ def count_video_time(video_path):
269
 
270
  def check_ffmpeg_installed():
271
  # Run the ffmpeg command with the -version argument to check if it's installed
272
- result = subprocess.run("ffmpeg -version", stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
 
 
273
  if not result.returncode == 0:
274
- raise FileNotFoundError("ffmpeg not found, please install it by:\n $ conda install -c conda-forge ffmpeg")
 
 
275
 
276
 
277
- def check_model_and_download(ckpt_path: str, huggingface_model_id: str = "ByteDance/LatentSync-1.5"):
 
 
278
  if not os.path.exists(ckpt_path):
279
  ckpt_path_obj = Path(ckpt_path)
280
  download_cmd = f"huggingface-cli download {huggingface_model_id} {Path(*ckpt_path_obj.parts[1:])} --local-dir {Path(ckpt_path_obj.parts[0])}"
281
  subprocess.run(download_cmd, shell=True)
 
 
 
282
 
283
 
284
  class dummy_context:
 
49
  if os.path.exists(temp_dir):
50
  shutil.rmtree(temp_dir)
51
  os.makedirs(temp_dir, exist_ok=True)
52
+ command = f"ffmpeg -loglevel error -y -nostdin -i {video_path} -r 25 -crf 18 {os.path.join(temp_dir, 'video.mp4')}"
 
 
53
  subprocess.run(command, shell=True)
54
  target_video_path = os.path.join(temp_dir, "video.mp4")
55
  else:
 
125
 
126
  def write_video_cv2(video_output_path: str, video_frames: np.ndarray, fps: int):
127
  height, width = video_frames[0].shape[:2]
128
+ out = cv2.VideoWriter(
129
+ video_output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
130
+ )
131
  # out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"vp09"), fps, (width, height))
132
  for frame in video_frames:
133
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
 
162
  cam = cv2.VideoCapture(video_path)
163
  fps = cam.get(cv2.CAP_PROP_FPS)
164
  if fps != 25:
165
+ raise ValueError(
166
+ f"Video FPS is not 25, it is {fps}. Please convert the video to 25 FPS."
167
+ )
168
 
169
 
170
  def one_step_sampling(ddim_scheduler, pred_noise, timesteps, x_t):
 
177
  if ddim_scheduler.config.prediction_type == "epsilon":
178
  beta_prod_t = beta_prod_t[:, None, None, None, None]
179
  alpha_prod_t = alpha_prod_t[:, None, None, None, None]
180
+ pred_original_sample = (
181
+ x_t - beta_prod_t ** (0.5) * pred_noise
182
+ ) / alpha_prod_t ** (0.5)
183
  else:
184
  raise NotImplementedError("This prediction type is not implemented yet")
185
 
 
273
 
274
def check_ffmpeg_installed():
    """Verify that ``ffmpeg -version`` can be run through the shell.

    Raises:
        FileNotFoundError: if the command exits with a non-zero status.
    """
    # Output is captured only to keep it off the console; it is not inspected.
    probe = subprocess.run("ffmpeg -version", shell=True, capture_output=True)
    if probe.returncode != 0:
        raise FileNotFoundError(
            "ffmpeg not found, please install it by:\n $ conda install -c conda-forge ffmpeg"
        )
283
 
284
 
285
def _split_checkpoint_path(ckpt_path, local_dir=None):
    """Split ``ckpt_path`` into (download root, path relative to that root)."""
    ckpt_path_obj = Path(ckpt_path)

    if local_dir is not None:
        root = Path(local_dir)
        # Raises ValueError when ckpt_path does not live under local_dir.
        return root, ckpt_path_obj.relative_to(root)

    # Paths built from the configured models directory (e.g.
    # "/models/auxiliary/x.pth") must be split at that directory; splitting
    # at the first component would yield root "/" and a repo filename that
    # wrongly starts with "models/".
    models_root = Path(os.getenv("MODELS_DIR", "/models"))
    try:
        return models_root, ckpt_path_obj.relative_to(models_root)
    except ValueError:
        pass

    # Legacy layout ("checkpoints/auxiliary/x.pth"): the first path component
    # is the local directory, the remainder is the filename inside the repo.
    return Path(ckpt_path_obj.parts[0]), Path(*ckpt_path_obj.parts[1:])


def check_model_and_download(
    ckpt_path: str,
    huggingface_model_id: str = "ByteDance/LatentSync-1.5",
    local_dir: "str | None" = None,
):
    """Ensure ``ckpt_path`` exists, fetching it with ``huggingface-cli`` if missing.

    Args:
        ckpt_path: Local path where the checkpoint file is expected.
        huggingface_model_id: Hugging Face repository to download from.
        local_dir: Directory passed to ``--local-dir``; the repository
            filename is ``ckpt_path`` relative to it. When ``None`` it is
            the models directory (``MODELS_DIR`` environment variable,
            default ``/models``) if ``ckpt_path`` lies inside it, otherwise
            the first component of ``ckpt_path``.

    Raises:
        ValueError: if ``local_dir`` is given and ``ckpt_path`` is not inside it.

    A failed download is reported on stdout rather than raised.
    """
    if os.path.exists(ckpt_path):
        print(f"Model already exists: {ckpt_path}")
        return

    root, relative_file = _split_checkpoint_path(ckpt_path, local_dir)
    # Argument list instead of a shell string, so paths containing spaces or
    # shell metacharacters are passed through verbatim.
    download_cmd = [
        "huggingface-cli",
        "download",
        huggingface_model_id,
        relative_file.as_posix(),
        "--local-dir",
        str(root),
    ]
    try:
        result = subprocess.run(download_cmd)
    except OSError as err:  # e.g. huggingface-cli is not installed
        print(f"Failed to download model to {ckpt_path}: {err}")
        return

    # Only claim success when the CLI succeeded and the file is really there.
    if result.returncode == 0 and os.path.exists(ckpt_path):
        print(f"Downloaded model to {ckpt_path}")
    else:
        print(
            f"Failed to download model to {ckpt_path} "
            f"(huggingface-cli exit code {result.returncode})"
        )
295
 
296
 
297
  class dummy_context:
latentsync/whisper/audio2feature.py CHANGED
@@ -15,8 +15,9 @@ class Audio2Feature:
15
  audio_embeds_cache_dir=None,
16
  num_frames=16,
17
  audio_feat_length=[2, 2],
 
18
  ):
19
- self.model = load_model(model_path, device)
20
  self.audio_embeds_cache_dir = audio_embeds_cache_dir
21
  if audio_embeds_cache_dir is not None and audio_embeds_cache_dir != "":
22
  Path(audio_embeds_cache_dir).mkdir(parents=True, exist_ok=True)
 
15
  audio_embeds_cache_dir=None,
16
  num_frames=16,
17
  audio_feat_length=[2, 2],
18
+ download_root=None,
19
  ):
20
+ self.model = load_model(model_path, device, download_root=download_root)
21
  self.audio_embeds_cache_dir = audio_embeds_cache_dir
22
  if audio_embeds_cache_dir is not None and audio_embeds_cache_dir != "":
23
  Path(audio_embeds_cache_dir).mkdir(parents=True, exist_ok=True)
lipsync.py CHANGED
@@ -8,11 +8,12 @@ import torch
8
  from DeepCache import DeepCacheSDHelper
9
  from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
10
  from shared.model_manager import ModelManager
 
11
 
12
  torch.backends.cudnn.benchmark = True
13
  torch.backends.cudnn.deterministic = False
14
 
15
- os.makedirs("checkpoints", exist_ok=True)
16
 
17
 
18
  def get_quality_params(level: str) -> tuple:
 
8
  from DeepCache import DeepCacheSDHelper
9
  from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
10
  from shared.model_manager import ModelManager
11
+ from config import MODELS_DIR
12
 
13
  torch.backends.cudnn.benchmark = True
14
  torch.backends.cudnn.deterministic = False
15
 
16
+ os.makedirs(MODELS_DIR, exist_ok=True)
17
 
18
 
19
  def get_quality_params(level: str) -> tuple:
preprocess/data_processing_pipeline.py CHANGED
@@ -14,6 +14,7 @@
14
 
15
  import argparse
16
  import os
 
17
  from preprocess.affine_transform import affine_transform_multi_gpus
18
  from preprocess.remove_broken_videos import remove_broken_videos_multiprocessing
19
  from preprocess.detect_shot import detect_shot_multiprocessing
@@ -25,14 +26,22 @@ from preprocess.filter_visual_quality import filter_visual_quality_multi_gpus
25
  from preprocess.remove_incorrect_affined import remove_incorrect_affined_multiprocessing
26
  from latentsync.utils.util import check_model_and_download
27
 
 
 
 
28
 
29
  def data_processing_pipeline(
30
- total_num_workers, per_gpu_num_workers, resolution, sync_conf_threshold, temp_dir, input_dir
 
 
 
 
 
31
  ):
32
  print("Checking models are downloaded...")
33
- check_model_and_download("checkpoints/auxiliary/syncnet_v2.model")
34
- check_model_and_download("checkpoints/auxiliary/sfd_face.pth")
35
- check_model_and_download("checkpoints/auxiliary/koniq_pretrained.pkl")
36
 
37
  print("Removing broken videos...")
38
  remove_broken_videos_multiprocessing(input_dir, total_num_workers)
@@ -55,19 +64,39 @@ def data_processing_pipeline(
55
  # filter_high_resolution_multiprocessing(segmented_dir, high_resolution_dir, resolution, total_num_workers)
56
 
57
  print("Affine transforming videos...")
58
- affine_transformed_dir = os.path.join(os.path.dirname(input_dir), "affine_transformed")
59
- affine_transform_multi_gpus(segmented_dir, affine_transformed_dir, temp_dir, resolution, per_gpu_num_workers // 2)
 
 
 
 
 
 
 
 
60
 
61
  # print("Removing incorrect affined videos...")
62
  # remove_incorrect_affined_multiprocessing(affine_transformed_dir, total_num_workers)
63
 
64
  print("Syncing audio and video...")
65
- av_synced_dir = os.path.join(os.path.dirname(input_dir), f"av_synced_{sync_conf_threshold}")
66
- sync_av_multi_gpus(affine_transformed_dir, av_synced_dir, temp_dir, per_gpu_num_workers, sync_conf_threshold)
 
 
 
 
 
 
 
 
67
 
68
  print("Filtering visual quality...")
69
- high_visual_quality_dir = os.path.join(os.path.dirname(input_dir), "high_visual_quality")
70
- filter_visual_quality_multi_gpus(av_synced_dir, high_visual_quality_dir, per_gpu_num_workers)
 
 
 
 
71
 
72
 
73
  if __name__ == "__main__":
 
14
 
15
  import argparse
16
  import os
17
+ import sys
18
  from preprocess.affine_transform import affine_transform_multi_gpus
19
  from preprocess.remove_broken_videos import remove_broken_videos_multiprocessing
20
  from preprocess.detect_shot import detect_shot_multiprocessing
 
26
  from preprocess.remove_incorrect_affined import remove_incorrect_affined_multiprocessing
27
  from latentsync.utils.util import check_model_and_download
28
 
29
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
30
+ from config import MODELS_DIR
31
+
32
 
33
  def data_processing_pipeline(
34
+ total_num_workers,
35
+ per_gpu_num_workers,
36
+ resolution,
37
+ sync_conf_threshold,
38
+ temp_dir,
39
+ input_dir,
40
  ):
41
  print("Checking models are downloaded...")
42
+ check_model_and_download(f"{MODELS_DIR}/auxiliary/syncnet_v2.model")
43
+ check_model_and_download(f"{MODELS_DIR}/auxiliary/sfd_face.pth")
44
+ check_model_and_download(f"{MODELS_DIR}/auxiliary/koniq_pretrained.pkl")
45
 
46
  print("Removing broken videos...")
47
  remove_broken_videos_multiprocessing(input_dir, total_num_workers)
 
64
  # filter_high_resolution_multiprocessing(segmented_dir, high_resolution_dir, resolution, total_num_workers)
65
 
66
  print("Affine transforming videos...")
67
+ affine_transformed_dir = os.path.join(
68
+ os.path.dirname(input_dir), "affine_transformed"
69
+ )
70
+ affine_transform_multi_gpus(
71
+ segmented_dir,
72
+ affine_transformed_dir,
73
+ temp_dir,
74
+ resolution,
75
+ per_gpu_num_workers // 2,
76
+ )
77
 
78
  # print("Removing incorrect affined videos...")
79
  # remove_incorrect_affined_multiprocessing(affine_transformed_dir, total_num_workers)
80
 
81
  print("Syncing audio and video...")
82
+ av_synced_dir = os.path.join(
83
+ os.path.dirname(input_dir), f"av_synced_{sync_conf_threshold}"
84
+ )
85
+ sync_av_multi_gpus(
86
+ affine_transformed_dir,
87
+ av_synced_dir,
88
+ temp_dir,
89
+ per_gpu_num_workers,
90
+ sync_conf_threshold,
91
+ )
92
 
93
  print("Filtering visual quality...")
94
+ high_visual_quality_dir = os.path.join(
95
+ os.path.dirname(input_dir), "high_visual_quality"
96
+ )
97
+ filter_visual_quality_multi_gpus(
98
+ av_synced_dir, high_visual_quality_dir, per_gpu_num_workers
99
+ )
100
 
101
 
102
  if __name__ == "__main__":
preprocess/filter_visual_quality.py CHANGED
@@ -13,6 +13,7 @@
13
  # limitations under the License.
14
 
15
  import os
 
16
  import tqdm
17
  import torch
18
  import torchvision
@@ -23,6 +24,9 @@ from decord import VideoReader
23
  from einops import rearrange
24
  from eval.hyper_iqa import HyperNet, TargetNet
25
 
 
 
 
26
 
27
  paths = []
28
 
@@ -38,7 +42,9 @@ def gather_paths(input_dir, output_dir):
38
  continue
39
  paths.append((video_input, video_output))
40
  elif os.path.isdir(os.path.join(input_dir, video)):
41
- gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video))
 
 
42
 
43
 
44
  def read_video(video_path: str):
@@ -61,13 +67,17 @@ def func(paths, device_id):
61
 
62
  # load the pre-trained model on the koniq-10k dataset
63
  model_hyper.load_state_dict(
64
- (torch.load("checkpoints/auxiliary/koniq_pretrained.pkl", map_location=device, weights_only=True))
 
 
65
  )
66
 
67
  transforms = torchvision.transforms.Compose(
68
  [
69
  torchvision.transforms.CenterCrop(size=224),
70
- torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
 
 
71
  ]
72
  )
73
 
@@ -76,7 +86,9 @@ def func(paths, device_id):
76
  video_frames = read_video(video_input)
77
  video_frames = transforms(video_frames)
78
  video_frames = video_frames.clone().detach().to(device)
79
- paras = model_hyper(video_frames) # 'paras' contains the network weights conveyed to target network
 
 
80
 
81
  # Building target network
82
  model_target = TargetNet(paras).to(device)
@@ -84,11 +96,15 @@ def func(paths, device_id):
84
  param.requires_grad = False
85
 
86
  # Quality prediction
87
- pred = model_target(paras["target_in_vec"]) # 'paras['target_in_vec']' is the input to target net
 
 
88
 
89
  # quality score ranges from 0-100, a higher score indicates a better quality
90
  quality_score = pred.mean().item()
91
- print(f"Input video: {video_input}\nVisual quality score: {quality_score:.2f}")
 
 
92
 
93
  if quality_score >= 40:
94
  os.makedirs(os.path.dirname(video_output), exist_ok=True)
 
13
  # limitations under the License.
14
 
15
  import os
16
+ import sys
17
  import tqdm
18
  import torch
19
  import torchvision
 
24
  from einops import rearrange
25
  from eval.hyper_iqa import HyperNet, TargetNet
26
 
27
# Put the parent of this file's directory on sys.path so the `config`
# module imported below can be found.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import MODELS_DIR
29
+
30
 
31
  paths = []
32
 
 
42
  continue
43
  paths.append((video_input, video_output))
44
  elif os.path.isdir(os.path.join(input_dir, video)):
45
+ gather_paths(
46
+ os.path.join(input_dir, video), os.path.join(output_dir, video)
47
+ )
48
 
49
 
50
  def read_video(video_path: str):
 
67
 
68
  # load the pre-trained model on the koniq-10k dataset
69
model_hyper.load_state_dict(
    torch.load(
        f"{MODELS_DIR}/auxiliary/koniq_pretrained.pkl",
        map_location=device,
        weights_only=True,
    )
)
74
 
75
  transforms = torchvision.transforms.Compose(
76
  [
77
  torchvision.transforms.CenterCrop(size=224),
78
+ torchvision.transforms.Normalize(
79
+ mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
80
+ ),
81
  ]
82
  )
83
 
 
86
  video_frames = read_video(video_input)
87
  video_frames = transforms(video_frames)
88
  video_frames = video_frames.clone().detach().to(device)
89
+ paras = model_hyper(
90
+ video_frames
91
+ ) # 'paras' contains the network weights conveyed to target network
92
 
93
  # Building target network
94
  model_target = TargetNet(paras).to(device)
 
96
  param.requires_grad = False
97
 
98
  # Quality prediction
99
+ pred = model_target(
100
+ paras["target_in_vec"]
101
+ ) # 'paras['target_in_vec']' is the input to target net
102
 
103
  # quality score ranges from 0-100, a higher score indicates a better quality
104
  quality_score = pred.mean().item()
105
+ print(
106
+ f"Input video: {video_input}\nVisual quality score: {quality_score:.2f}"
107
+ )
108
 
109
  if quality_score >= 40:
110
  os.makedirs(os.path.dirname(video_output), exist_ok=True)
preprocess/sync_av.py CHANGED
@@ -13,6 +13,7 @@
13
  # limitations under the License.
14
 
15
  import os
 
16
  import tqdm
17
  from eval.syncnet import SyncNetEval
18
  from eval.syncnet_detect import SyncNetDetector
@@ -22,6 +23,10 @@ import subprocess
22
  import shutil
23
  from multiprocessing import Process
24
 
 
 
 
 
25
  paths = []
26
 
27
 
@@ -36,11 +41,13 @@ def gather_paths(input_dir, output_dir):
36
  continue
37
  paths.append((video_input, video_output))
38
  elif os.path.isdir(os.path.join(input_dir, video)):
39
- gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video))
 
 
40
 
41
 
42
  def adjust_offset(video_input: str, video_output: str, av_offset: int, fps: int = 25):
43
- command = f"ffmpeg -loglevel error -y -i {video_input} -itsoffset {av_offset/fps} -i {video_input} -map 0:v -map 1:a -c copy -q:v 0 -q:a 0 {video_output}"
44
  subprocess.run(command, shell=True)
45
 
46
 
@@ -49,17 +56,23 @@ def func(sync_conf_threshold, paths, device_id, process_temp_dir):
49
  device = f"cuda:{device_id}"
50
 
51
  syncnet = SyncNetEval(device=device)
52
- syncnet.loadParameters("checkpoints/auxiliary/syncnet_v2.model")
53
 
54
  detect_results_dir = os.path.join(process_temp_dir, "detect_results")
55
  syncnet_eval_results_dir = os.path.join(process_temp_dir, "syncnet_eval_results")
56
 
57
- syncnet_detector = SyncNetDetector(device=device, detect_results_dir=detect_results_dir)
 
 
58
 
59
  for video_input, video_output in paths:
60
  try:
61
  av_offset, conf = syncnet_eval(
62
- syncnet, syncnet_detector, video_input, syncnet_eval_results_dir, detect_results_dir
 
 
 
 
63
  )
64
 
65
  if conf >= sync_conf_threshold and abs(av_offset) <= 6:
@@ -77,7 +90,9 @@ def split(a, n):
77
  return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))
78
 
79
 
80
- def sync_av_multi_gpus(input_dir, output_dir, temp_dir, num_workers, sync_conf_threshold):
 
 
81
  gather_paths(input_dir, output_dir)
82
  num_devices = torch.cuda.device_count()
83
  if num_devices == 0:
@@ -111,4 +126,6 @@ if __name__ == "__main__":
111
  num_workers = 20 # How many processes per device
112
  sync_conf_threshold = 3
113
 
114
- sync_av_multi_gpus(input_dir, output_dir, temp_dir, num_workers, sync_conf_threshold)
 
 
 
13
  # limitations under the License.
14
 
15
  import os
16
+ import sys
17
  import tqdm
18
  from eval.syncnet import SyncNetEval
19
  from eval.syncnet_detect import SyncNetDetector
 
23
  import shutil
24
  from multiprocessing import Process
25
 
26
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
27
+ from config import MODELS_DIR
28
+
29
+
30
  paths = []
31
 
32
 
 
41
  continue
42
  paths.append((video_input, video_output))
43
  elif os.path.isdir(os.path.join(input_dir, video)):
44
+ gather_paths(
45
+ os.path.join(input_dir, video), os.path.join(output_dir, video)
46
+ )
47
 
48
 
49
def adjust_offset(video_input: str, video_output: str, av_offset: int, fps: int = 25):
    """Run ffmpeg to write ``video_output`` from ``video_input`` with an offset.

    ``video_input`` is given to ffmpeg twice: video is mapped from the first
    input and audio from the second, which is preceded by
    ``-itsoffset av_offset / fps``. Streams are copied (``-c copy``).

    Args:
        video_input: Path of the source video.
        video_output: Path of the file ffmpeg writes (``-y`` overwrites).
        av_offset: Offset in frames; divided by ``fps`` for ``-itsoffset``.
        fps: Frame rate used for that division.
    """
    # Argument list instead of a shell string, so paths containing spaces or
    # shell metacharacters reach ffmpeg unchanged.
    command = [
        "ffmpeg",
        "-loglevel", "error",
        "-y",
        "-i", video_input,
        "-itsoffset", str(av_offset / fps),
        "-i", video_input,
        "-map", "0:v",
        "-map", "1:a",
        "-c", "copy",
        "-q:v", "0",
        "-q:a", "0",
        video_output,
    ]
    subprocess.run(command)
52
 
53
 
 
56
  device = f"cuda:{device_id}"
57
 
58
  syncnet = SyncNetEval(device=device)
59
+ syncnet.loadParameters(f"{MODELS_DIR}/auxiliary/syncnet_v2.model")
60
 
61
  detect_results_dir = os.path.join(process_temp_dir, "detect_results")
62
  syncnet_eval_results_dir = os.path.join(process_temp_dir, "syncnet_eval_results")
63
 
64
+ syncnet_detector = SyncNetDetector(
65
+ device=device, detect_results_dir=detect_results_dir
66
+ )
67
 
68
  for video_input, video_output in paths:
69
  try:
70
  av_offset, conf = syncnet_eval(
71
+ syncnet,
72
+ syncnet_detector,
73
+ video_input,
74
+ syncnet_eval_results_dir,
75
+ detect_results_dir,
76
  )
77
 
78
  if conf >= sync_conf_threshold and abs(av_offset) <= 6:
 
90
  return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))
91
 
92
 
93
+ def sync_av_multi_gpus(
94
+ input_dir, output_dir, temp_dir, num_workers, sync_conf_threshold
95
+ ):
96
  gather_paths(input_dir, output_dir)
97
  num_devices = torch.cuda.device_count()
98
  if num_devices == 0:
 
126
  num_workers = 20 # How many processes per device
127
  sync_conf_threshold = 3
128
 
129
+ sync_av_multi_gpus(
130
+ input_dir, output_dir, temp_dir, num_workers, sync_conf_threshold
131
+ )
scripts/inference.py CHANGED
@@ -14,10 +14,14 @@
14
 
15
  import argparse
16
  import os
 
17
  from omegaconf import OmegaConf
18
  import torch
19
  from diffusers import AutoencoderKL, DDIMScheduler
20
  from latentsync.models.unet import UNet3DConditionModel
 
 
 
21
  from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
22
  from accelerate.utils import set_seed
23
  from latentsync.whisper.audio2feature import Audio2Feature
@@ -56,7 +60,9 @@ def main(config, args):
56
  audio_feat_length=config.data.audio_feat_length,
57
  )
58
 
59
- vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
 
 
60
  vae.config.scaling_factor = 0.18215
61
  vae.config.shift_factor = 0
62
 
 
14
 
15
  import argparse
16
  import os
17
+ import sys
18
  from omegaconf import OmegaConf
19
  import torch
20
  from diffusers import AutoencoderKL, DDIMScheduler
21
  from latentsync.models.unet import UNet3DConditionModel
22
+
23
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
+ from config import MODELS_DIR
25
  from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
26
  from accelerate.utils import set_seed
27
  from latentsync.whisper.audio2feature import Audio2Feature
 
60
  audio_feat_length=config.data.audio_feat_length,
61
  )
62
 
63
+ vae = AutoencoderKL.from_pretrained(
64
+ "stabilityai/sd-vae-ft-mse", torch_dtype=dtype, cache_dir=MODELS_DIR
65
+ )
66
  vae.config.scaling_factor = 0.18215
67
  vae.config.shift_factor = 0
68
 
scripts/train_syncnet.py CHANGED
@@ -13,11 +13,18 @@
13
  # limitations under the License.
14
 
15
  from tqdm.auto import tqdm
16
- import os, argparse, datetime, math
 
 
 
 
17
  import logging
18
  from omegaconf import OmegaConf
19
  import shutil
20
 
 
 
 
21
  from latentsync.data.syncnet_dataset import SyncNetDataset
22
  from latentsync.models.stable_syncnet import StableSyncNet
23
  from latentsync.models.wav2lip_syncnet import Wav2LipSyncNet
@@ -67,15 +74,21 @@ def main(config):
67
  device = torch.device(local_rank)
68
 
69
  if config.data.latent_space:
70
- vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
 
 
71
  vae.requires_grad_(False)
72
  vae.to(device)
73
  else:
74
  vae = None
75
 
76
  # Dataset and Dataloader setup
77
- train_dataset = SyncNetDataset(config.data.train_data_dir, config.data.train_fileslist, config)
78
- val_dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)
 
 
 
 
79
 
80
  train_distributed_sampler = DistributedSampler(
81
  train_dataset,
@@ -118,7 +131,8 @@ def main(config):
118
  # syncnet = Wav2LipSyncNet().to(device)
119
 
120
  optimizer = torch.optim.AdamW(
121
- list(filter(lambda p: p.requires_grad, syncnet.parameters())), lr=config.optimizer.lr
 
122
  )
123
 
124
  global_step = 0
@@ -130,7 +144,9 @@ def main(config):
130
  if config.ckpt.resume_ckpt_path != "":
131
  if is_main_process:
132
  logger.info(f"Load checkpoint from: {config.ckpt.resume_ckpt_path}")
133
- ckpt = torch.load(config.ckpt.resume_ckpt_path, map_location=device, weights_only=True)
 
 
134
 
135
  syncnet.load_state_dict(ckpt["state_dict"])
136
 
@@ -145,7 +161,9 @@ def main(config):
145
  syncnet = DDP(syncnet, device_ids=[local_rank], output_device=local_rank)
146
 
147
  num_update_steps_per_epoch = math.ceil(len(train_dataloader))
148
- num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)
 
 
149
 
150
  if is_main_process:
151
  logger.info("***** Running training *****")
@@ -158,15 +176,22 @@ def main(config):
158
  logger.info(f" Total optimization steps = {config.run.max_train_steps}")
159
 
160
  first_epoch = global_step // num_update_steps_per_epoch
161
- num_val_batches = config.data.num_val_samples // (num_processes * config.data.batch_size)
 
 
162
 
163
  # Only show the progress bar once on each machine.
164
  progress_bar = tqdm(
165
- range(0, config.run.max_train_steps), initial=global_step, desc="Steps", disable=not is_main_process
 
 
 
166
  )
167
 
168
  # Support mixed-precision training
169
- scaler = torch.amp.GradScaler("cuda") if config.run.mixed_precision_training else None
 
 
170
 
171
  for epoch in range(first_epoch, num_train_epochs):
172
  train_dataloader.sampler.set_epoch(epoch)
@@ -186,15 +211,17 @@ def main(config):
186
  num_samples_limit // config.data.num_frames
187
  ) # due to the limited cuda memory, we split the input frames into parts
188
  if frames.shape[0] > max_batch_size:
189
- assert (
190
- frames.shape[0] % max_batch_size == 0
191
- ), f"max_batch_size {max_batch_size} should be divisible by batch_size {frames.shape[0]}"
192
  frames_part_results = []
193
  for i in range(0, frames.shape[0], max_batch_size):
194
  frames_part = frames[i : i + max_batch_size]
195
  frames_part = rearrange(frames_part, "b f c h w -> (b f) c h w")
196
  with torch.no_grad():
197
- frames_part = vae.encode(frames_part).latent_dist.sample() * 0.18215
 
 
198
  frames_part_results.append(frames_part)
199
  frames = torch.cat(frames_part_results, dim=0)
200
  else:
@@ -202,7 +229,9 @@ def main(config):
202
  with torch.no_grad():
203
  frames = vae.encode(frames).latent_dist.sample() * 0.18215
204
 
205
- frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
 
 
206
  else:
207
  frames = rearrange(frames, "b f c h w -> b (f c) h w")
208
 
@@ -211,14 +240,22 @@ def main(config):
211
  frames = frames[:, :, height // 2 :, :]
212
 
213
  # Disable gradient sync for the first N-1 steps, enable sync on the final step
214
- with syncnet.no_sync() if (index + 1) % config.data.gradient_accumulation_steps != 0 else dummy_context():
 
 
 
 
215
  # Mixed-precision training
216
  with torch.autocast(
217
- device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training
 
 
218
  ):
219
  vision_embeds, audio_embeds = syncnet(frames, audio_samples)
220
 
221
- loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
 
 
222
  loss = loss / config.data.gradient_accumulation_steps
223
 
224
  # Backpropagate
@@ -230,7 +267,9 @@ def main(config):
230
  if (index + 1) % config.data.gradient_accumulation_steps == 0:
231
  """>>> gradient clipping >>>"""
232
  scaler.unscale_(optimizer)
233
- torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
 
 
234
  """ <<< gradient clipping <<< """
235
  scaler.step(optimizer)
236
  scaler.update()
@@ -255,15 +294,21 @@ def main(config):
255
  )
256
  val_step_list.append(global_step)
257
  val_loss_list.append(val_loss)
258
- logger.info(f"Validation loss at step {global_step} is {val_loss:0.3f}")
 
 
259
  plot_loss_chart(
260
- os.path.join(output_dir, f"loss_charts/loss_chart-{global_step}.png"),
 
 
261
  ("Train loss", train_step_list, train_loss_list),
262
  ("Val loss", val_step_list, val_loss_list),
263
  )
264
 
265
  if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0:
266
- checkpoint_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
 
 
267
  torch.save(
268
  {
269
  "state_dict": syncnet.module.state_dict(), # to unwrap DDP
@@ -288,7 +333,9 @@ def main(config):
288
 
289
 
290
  @torch.no_grad()
291
- def validation(val_dataloader, device, syncnet, latent_space, lower_half, vae, num_val_batches):
 
 
292
  syncnet.eval()
293
 
294
  losses = []
@@ -330,7 +377,9 @@ def validation(val_dataloader, device, syncnet, latent_space, lower_half, vae, n
330
 
331
  if __name__ == "__main__":
332
  parser = argparse.ArgumentParser(description="Code to train the SyncNet")
333
- parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_pixel.yaml")
 
 
334
  args = parser.parse_args()
335
 
336
  # Load a configuration file
 
13
  # limitations under the License.
14
 
15
  from tqdm.auto import tqdm
16
+ import os
17
+ import sys
18
+ import argparse
19
+ import datetime
20
+ import math
21
  import logging
22
  from omegaconf import OmegaConf
23
  import shutil
24
 
25
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
26
+ from config import MODELS_DIR
27
+
28
  from latentsync.data.syncnet_dataset import SyncNetDataset
29
  from latentsync.models.stable_syncnet import StableSyncNet
30
  from latentsync.models.wav2lip_syncnet import Wav2LipSyncNet
 
74
  device = torch.device(local_rank)
75
 
76
  if config.data.latent_space:
77
+ vae = AutoencoderKL.from_pretrained(
78
+ "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16, cache_dir=MODELS_DIR
79
+ )
80
  vae.requires_grad_(False)
81
  vae.to(device)
82
  else:
83
  vae = None
84
 
85
  # Dataset and Dataloader setup
86
+ train_dataset = SyncNetDataset(
87
+ config.data.train_data_dir, config.data.train_fileslist, config
88
+ )
89
+ val_dataset = SyncNetDataset(
90
+ config.data.val_data_dir, config.data.val_fileslist, config
91
+ )
92
 
93
  train_distributed_sampler = DistributedSampler(
94
  train_dataset,
 
131
  # syncnet = Wav2LipSyncNet().to(device)
132
 
133
  optimizer = torch.optim.AdamW(
134
+ list(filter(lambda p: p.requires_grad, syncnet.parameters())),
135
+ lr=config.optimizer.lr,
136
  )
137
 
138
  global_step = 0
 
144
  if config.ckpt.resume_ckpt_path != "":
145
  if is_main_process:
146
  logger.info(f"Load checkpoint from: {config.ckpt.resume_ckpt_path}")
147
+ ckpt = torch.load(
148
+ config.ckpt.resume_ckpt_path, map_location=device, weights_only=True
149
+ )
150
 
151
  syncnet.load_state_dict(ckpt["state_dict"])
152
 
 
161
  syncnet = DDP(syncnet, device_ids=[local_rank], output_device=local_rank)
162
 
163
  num_update_steps_per_epoch = math.ceil(len(train_dataloader))
164
+ num_train_epochs = math.ceil(
165
+ config.run.max_train_steps / num_update_steps_per_epoch
166
+ )
167
 
168
  if is_main_process:
169
  logger.info("***** Running training *****")
 
176
  logger.info(f" Total optimization steps = {config.run.max_train_steps}")
177
 
178
  first_epoch = global_step // num_update_steps_per_epoch
179
+ num_val_batches = config.data.num_val_samples // (
180
+ num_processes * config.data.batch_size
181
+ )
182
 
183
  # Only show the progress bar once on each machine.
184
  progress_bar = tqdm(
185
+ range(0, config.run.max_train_steps),
186
+ initial=global_step,
187
+ desc="Steps",
188
+ disable=not is_main_process,
189
  )
190
 
191
  # Support mixed-precision training
192
+ scaler = (
193
+ torch.amp.GradScaler("cuda") if config.run.mixed_precision_training else None
194
+ )
195
 
196
  for epoch in range(first_epoch, num_train_epochs):
197
  train_dataloader.sampler.set_epoch(epoch)
 
211
  num_samples_limit // config.data.num_frames
212
  ) # due to the limited cuda memory, we split the input frames into parts
213
  if frames.shape[0] > max_batch_size:
214
+ assert frames.shape[0] % max_batch_size == 0, (
215
+ f"max_batch_size {max_batch_size} should be divisible by batch_size {frames.shape[0]}"
216
+ )
217
  frames_part_results = []
218
  for i in range(0, frames.shape[0], max_batch_size):
219
  frames_part = frames[i : i + max_batch_size]
220
  frames_part = rearrange(frames_part, "b f c h w -> (b f) c h w")
221
  with torch.no_grad():
222
+ frames_part = (
223
+ vae.encode(frames_part).latent_dist.sample() * 0.18215
224
+ )
225
  frames_part_results.append(frames_part)
226
  frames = torch.cat(frames_part_results, dim=0)
227
  else:
 
229
  with torch.no_grad():
230
  frames = vae.encode(frames).latent_dist.sample() * 0.18215
231
 
232
+ frames = rearrange(
233
+ frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames
234
+ )
235
  else:
236
  frames = rearrange(frames, "b f c h w -> b (f c) h w")
237
 
 
240
  frames = frames[:, :, height // 2 :, :]
241
 
242
  # Disable gradient sync for the first N-1 steps, enable sync on the final step
243
+ with (
244
+ syncnet.no_sync()
245
+ if (index + 1) % config.data.gradient_accumulation_steps != 0
246
+ else dummy_context()
247
+ ):
248
  # Mixed-precision training
249
  with torch.autocast(
250
+ device_type="cuda",
251
+ dtype=torch.float16,
252
+ enabled=config.run.mixed_precision_training,
253
  ):
254
  vision_embeds, audio_embeds = syncnet(frames, audio_samples)
255
 
256
+ loss = cosine_loss(
257
+ vision_embeds.float(), audio_embeds.float(), y
258
+ ).mean()
259
  loss = loss / config.data.gradient_accumulation_steps
260
 
261
  # Backpropagate
 
267
  if (index + 1) % config.data.gradient_accumulation_steps == 0:
268
  """>>> gradient clipping >>>"""
269
  scaler.unscale_(optimizer)
270
+ torch.nn.utils.clip_grad_norm_(
271
+ syncnet.parameters(), config.optimizer.max_grad_norm
272
+ )
273
  """ <<< gradient clipping <<< """
274
  scaler.step(optimizer)
275
  scaler.update()
 
294
  )
295
  val_step_list.append(global_step)
296
  val_loss_list.append(val_loss)
297
+ logger.info(
298
+ f"Validation loss at step {global_step} is {val_loss:0.3f}"
299
+ )
300
  plot_loss_chart(
301
+ os.path.join(
302
+ output_dir, f"loss_charts/loss_chart-{global_step}.png"
303
+ ),
304
  ("Train loss", train_step_list, train_loss_list),
305
  ("Val loss", val_step_list, val_loss_list),
306
  )
307
 
308
  if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0:
309
+ checkpoint_save_path = os.path.join(
310
+ output_dir, f"checkpoints/checkpoint-{global_step}.pt"
311
+ )
312
  torch.save(
313
  {
314
  "state_dict": syncnet.module.state_dict(), # to unwrap DDP
 
333
 
334
 
335
  @torch.no_grad()
336
+ def validation(
337
+ val_dataloader, device, syncnet, latent_space, lower_half, vae, num_val_batches
338
+ ):
339
  syncnet.eval()
340
 
341
  losses = []
 
377
 
378
  if __name__ == "__main__":
379
  parser = argparse.ArgumentParser(description="Code to train the SyncNet")
380
+ parser.add_argument(
381
+ "--config_path", type=str, default="configs/syncnet/syncnet_16_pixel.yaml"
382
+ )
383
  args = parser.parse_args()
384
 
385
  # Load a configuration file
scripts/train_unet.py CHANGED
@@ -13,11 +13,15 @@
13
  # limitations under the License.
14
 
15
  import os
 
16
  import math
17
  import argparse
18
  import shutil
19
  import datetime
20
  import logging
 
 
 
21
  from omegaconf import OmegaConf
22
 
23
  from tqdm.auto import tqdm
@@ -93,7 +97,7 @@ def main(config):
93
  noise_scheduler = DDIMScheduler.from_pretrained("configs")
94
 
95
  vae = AutoencoderKL.from_pretrained(
96
- "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
97
  )
98
  vae.config.scaling_factor = 0.18215
99
  vae.config.shift_factor = 0
 
13
  # limitations under the License.
14
 
15
  import os
16
+ import sys
17
  import math
18
  import argparse
19
  import shutil
20
  import datetime
21
  import logging
22
+
23
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
+ from config import MODELS_DIR
25
  from omegaconf import OmegaConf
26
 
27
  from tqdm.auto import tqdm
 
97
  noise_scheduler = DDIMScheduler.from_pretrained("configs")
98
 
99
  vae = AutoencoderKL.from_pretrained(
100
+ "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16, cache_dir=MODELS_DIR
101
  )
102
  vae.config.scaling_factor = 0.18215
103
  vae.config.shift_factor = 0
shared/face_detection/detector.py CHANGED
@@ -1,6 +1,11 @@
 
 
1
  import numpy as np
2
  import torch
3
 
 
 
 
4
  INSIGHTFACE_DETECT_SIZE = 512
5
 
6
 
@@ -56,7 +61,7 @@ class FaceDetector:
56
 
57
  self.app = FaceAnalysis(
58
  allowed_modules=["detection", "landmark_2d_106"],
59
- root="checkpoints/auxiliary",
60
  providers=["CUDAExecutionProvider"],
61
  )
62
  self.app.prepare(
 
1
+ import os
2
+ import sys
3
  import numpy as np
4
  import torch
5
 
6
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
7
+ from config import MODELS_DIR
8
+
9
  INSIGHTFACE_DETECT_SIZE = 512
10
 
11
 
 
61
 
62
  self.app = FaceAnalysis(
63
  allowed_modules=["detection", "landmark_2d_106"],
64
+ root=f"{MODELS_DIR}/auxiliary",
65
  providers=["CUDAExecutionProvider"],
66
  )
67
  self.app.prepare(
shared/model_manager.py CHANGED
@@ -39,10 +39,14 @@ class ModelManager:
39
  """Load Whisper audio encoder (lazy loaded)"""
40
  if self._whisper_encoder is None:
41
  from latentsync.whisper.audio2feature import Audio2Feature
 
42
 
43
  logger.info(f"Loading Whisper encoder from {model_path}...")
44
  self._whisper_encoder = Audio2Feature(
45
- model_path=model_path, device=device, num_frames=num_frames
 
 
 
46
  )
47
  logger.info("Whisper encoder loaded")
48
  return self._whisper_encoder
@@ -53,8 +57,12 @@ class ModelManager:
53
  from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
54
 
55
  logger.info("Loading VAE...")
 
 
56
  vae = AutoencoderKL.from_pretrained(
57
- "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
 
 
58
  )
59
  vae.config.scaling_factor = 0.18215
60
  vae.config.shift_factor = 0
@@ -74,20 +82,23 @@ class ModelManager:
74
  """Load LatentSync UNet (lazy loaded)"""
75
  if self._latentsync_unet is None:
76
  from latentsync.models.unet import UNet3DConditionModel
 
77
 
78
- logger.info("Downloading LatentSync-1.6 models...")
79
- os.makedirs("checkpoints", exist_ok=True)
80
- snapshot_download(
81
- repo_id="ByteDance/LatentSync-1.6", local_dir="./checkpoints"
82
- )
 
 
 
83
 
84
  logger.info("Loading LatentSync UNet...")
85
  config = self.get_latentsync_config()
86
 
87
- inference_ckpt_path = "checkpoints/latentsync_unet.pt"
88
  unet, _ = UNet3DConditionModel.from_pretrained(
89
  OmegaConf.to_container(config.model),
90
- inference_ckpt_path,
91
  device="cpu",
92
  )
93
  unet = unet.to(dtype=torch.float16).to(device)
 
39
  """Load Whisper audio encoder (lazy loaded)"""
40
  if self._whisper_encoder is None:
41
  from latentsync.whisper.audio2feature import Audio2Feature
42
+ from config import MODELS_DIR
43
 
44
  logger.info(f"Loading Whisper encoder from {model_path}...")
45
  self._whisper_encoder = Audio2Feature(
46
+ model_path=model_path,
47
+ device=device,
48
+ num_frames=num_frames,
49
+ download_root=f"{MODELS_DIR}/whisper",
50
  )
51
  logger.info("Whisper encoder loaded")
52
  return self._whisper_encoder
 
57
  from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
58
 
59
  logger.info("Loading VAE...")
60
+ from config import MODELS_DIR
61
+
62
  vae = AutoencoderKL.from_pretrained(
63
+ "stabilityai/sd-vae-ft-mse",
64
+ torch_dtype=torch.float16,
65
+ cache_dir=MODELS_DIR,
66
  )
67
  vae.config.scaling_factor = 0.18215
68
  vae.config.shift_factor = 0
 
82
  """Load LatentSync UNet (lazy loaded)"""
83
  if self._latentsync_unet is None:
84
  from latentsync.models.unet import UNet3DConditionModel
85
+ from config import MODELS_DIR
86
 
87
+ unet_path = f"{MODELS_DIR}/latentsync_unet.pt"
88
+
89
+ if not os.path.exists(unet_path):
90
+ logger.info("Downloading LatentSync-1.6 models...")
91
+ os.makedirs(MODELS_DIR, exist_ok=True)
92
+ snapshot_download(
93
+ repo_id="ByteDance/LatentSync-1.6", local_dir=MODELS_DIR
94
+ )
95
 
96
  logger.info("Loading LatentSync UNet...")
97
  config = self.get_latentsync_config()
98
 
 
99
  unet, _ = UNet3DConditionModel.from_pretrained(
100
  OmegaConf.to_container(config.model),
101
+ unet_path,
102
  device="cpu",
103
  )
104
  unet = unet.to(dtype=torch.float16).to(device)
shared/vae/loader.py CHANGED
@@ -1,8 +1,13 @@
1
  """VAE loader (placeholder - actual loading handled by ModelManager)"""
2
 
 
 
3
  import torch
4
  from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
5
 
 
 
 
6
 
7
  def load_vae(device: str = "cuda"):
8
  """Load VAE from HuggingFace
@@ -14,7 +19,7 @@ def load_vae(device: str = "cuda"):
14
  VAE model
15
  """
16
  vae = AutoencoderKL.from_pretrained(
17
- "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16
18
  )
19
  vae.config.scaling_factor = 0.18215
20
  vae.config.shift_factor = 0
 
1
  """VAE loader (placeholder - actual loading handled by ModelManager)"""
2
 
3
+ import os
4
+ import sys
5
  import torch
6
  from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
7
 
8
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
9
+ from config import MODELS_DIR
10
+
11
 
12
  def load_vae(device: str = "cuda"):
13
  """Load VAE from HuggingFace
 
19
  VAE model
20
  """
21
  vae = AutoencoderKL.from_pretrained(
22
+ "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16, cache_dir=MODELS_DIR
23
  )
24
  vae.config.scaling_factor = 0.18215
25
  vae.config.shift_factor = 0