gonefishin1 committed on
Commit
03c94e3
·
verified ·
1 Parent(s): 433b0f3

fix: resolve NameError, align model paths to models/, remove @spaces.GPU

Browse files
Files changed (1) hide show
  1. app.py +164 -175
app.py CHANGED
@@ -1,294 +1,283 @@
1
  import os
2
  import time
3
- import pdb
4
-
5
- import gradio as gr
6
- import spaces
7
- import numpy as np
8
  import sys
9
  import subprocess
 
 
 
 
10
 
11
- from huggingface_hub import snapshot_download
12
- import requests
13
-
14
- import argparse
15
- import os
16
- from omegaconf import OmegaConf
17
  import numpy as np
18
  import cv2
19
  import torch
20
- import glob
21
- import pickle
22
  from tqdm import tqdm
23
- import copy
24
  from argparse import Namespace
25
- import shutil
26
- import gdown
 
 
 
27
 
28
 
29
  def download_model():
30
- if not os.path.exists(CheckpointsDir):
31
- os.makedirs(CheckpointsDir)
32
- print("Checkpoint Not Downloaded, start downloading...")
33
- tic = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  snapshot_download(
35
  repo_id="TMElyralab/MuseTalk",
36
- local_dir=CheckpointsDir,
37
  max_workers=8,
38
  local_dir_use_symlinks=True,
 
39
  )
40
- # weight
 
 
 
41
  snapshot_download(
42
  repo_id="stabilityai/sd-vae-ft-mse",
43
- local_dir=CheckpointsDir,
44
  max_workers=8,
45
  local_dir_use_symlinks=True,
 
46
  )
47
- #dwpose
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  snapshot_download(
49
  repo_id="yzd-v/DWPose",
50
- local_dir=CheckpointsDir,
51
  max_workers=8,
52
  local_dir_use_symlinks=True,
 
53
  )
54
- #vae
55
- url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
56
- response = requests.get(url)
57
- # 确保请求成功
58
- if response.status_code == 200:
59
- # 指定文件保存的位置
60
- file_path = f"{CheckpointsDir}/whisper/tiny.pt"
61
- os.makedirs(f"{CheckpointsDir}/whisper/")
62
- # 将文件内容写入指定位置
63
- with open(file_path, "wb") as f:
64
- f.write(response.content)
65
- else:
66
- print(f"请求失败,状态码:{response.status_code}")
67
- #gdown face parse
68
- url = "https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812"
69
- os.makedirs(f"{CheckpointsDir}/face-parse-bisent/")
70
- file_path = f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth"
71
- gdown.download(url, output, quiet=False)
72
- #resnet
73
- url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
74
- response = requests.get(url)
75
- # 确保请求成功
76
- if response.status_code == 200:
77
- # 指定文件保存的位置
78
- file_path = f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
79
- # 将文件内容写入指定位置
80
- with open(file_path, "wb") as f:
81
- f.write(response.content)
82
- else:
83
- print(f"请求失败,状态码:{response.status_code}")
84
-
85
-
86
- toc = time.time()
87
-
88
- print(f"download cost {toc-tic} seconds")
89
- else:
90
- print("Already download the model.")
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
 
 
93
 
94
- download_model() # for huggingface deployment.
95
 
 
96
 
97
- from musetalk.utils.utils import get_file_type,get_video_fps,datagen
98
- from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
99
  from musetalk.utils.blending import get_image
100
  from musetalk.utils.utils import load_all_model
101
 
102
 
103
-
104
- ProjectDir = os.path.abspath(os.path.dirname(__file__))
105
- CheckpointsDir = os.path.join(ProjectDir, "checkpoints")
106
 
107
 
108
- @spaces.GPU(duration=600)
109
  @torch.no_grad()
110
- def inference(audio_path,video_path,bbox_shift,progress=gr.Progress(track_tqdm=True)):
111
- args_dict={"result_dir":'./results', "fps":25, "batch_size":8, "output_vid_name":'', "use_saved_coord":False}#same with inferenece script
 
 
 
 
 
 
112
  args = Namespace(**args_dict)
113
 
114
- input_basename = os.path.basename(video_path).split('.')[0]
115
- audio_basename = os.path.basename(audio_path).split('.')[0]
116
  output_basename = f"{input_basename}_{audio_basename}"
117
- result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
118
- crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
119
- os.makedirs(result_img_save_path,exist_ok =True)
120
 
121
- if args.output_vid_name=="":
122
- output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
123
  else:
124
  output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
125
- ############################################## extract frames from source video ##############################################
126
- if get_file_type(video_path)=="video":
127
  save_dir_full = os.path.join(args.result_dir, input_basename)
128
- os.makedirs(save_dir_full,exist_ok = True)
129
  cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
130
  os.system(cmd)
131
- input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
132
  fps = get_video_fps(video_path)
133
- else: # input img folder
134
- input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
135
  input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
136
  fps = args.fps
137
- #print(input_img_list)
138
- ############################################## extract audio feature ##############################################
139
  whisper_feature = audio_processor.audio2feat(audio_path)
140
- whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
141
- ############################################## preprocess input image ##############################################
142
  if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
143
- print("using extracted coordinates")
144
- with open(crop_coord_save_path,'rb') as f:
145
  coord_list = pickle.load(f)
146
  frame_list = read_imgs(input_img_list)
147
  else:
148
- print("extracting landmarks...time consuming")
149
  coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
150
- with open(crop_coord_save_path, 'wb') as f:
151
  pickle.dump(coord_list, f)
152
-
153
- i = 0
154
  input_latent_list = []
155
  for bbox, frame in zip(coord_list, frame_list):
156
  if bbox == coord_placeholder:
157
  continue
158
  x1, y1, x2, y2 = bbox
159
  crop_frame = frame[y1:y2, x1:x2]
160
- crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
161
  latents = vae.get_latents_for_unet(crop_frame)
162
  input_latent_list.append(latents)
163
 
164
- # to smooth the first and the last frame
165
  frame_list_cycle = frame_list + frame_list[::-1]
166
  coord_list_cycle = coord_list + coord_list[::-1]
167
  input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
168
- ############################################## inference batch by batch ##############################################
169
- print("start inference")
170
  video_num = len(whisper_chunks)
171
  batch_size = args.batch_size
172
- gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
173
  res_frame_list = []
174
- for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
175
-
 
 
176
  tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
177
- audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
178
  audio_feature_batch = pe(audio_feature_batch)
179
-
180
- pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
 
181
  recon = vae.decode_latents(pred_latents)
182
  for res_frame in recon:
183
  res_frame_list.append(res_frame)
184
-
185
- ############################################## pad to full image ##############################################
186
- print("pad talking image to original video")
187
  for i, res_frame in enumerate(tqdm(res_frame_list)):
188
- bbox = coord_list_cycle[i%(len(coord_list_cycle))]
189
- ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
190
  x1, y1, x2, y2 = bbox
191
  try:
192
- res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
193
- except:
194
- # print(bbox)
195
  continue
196
-
197
- combine_frame = get_image(ori_frame,res_frame,bbox)
198
- cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
199
-
200
- cmd_img2video = f"ffmpeg -y -v fatal -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
201
- print(cmd_img2video)
 
 
202
  os.system(cmd_img2video)
203
 
204
  cmd_combine_audio = f"ffmpeg -y -v fatal -i {audio_path} -i temp.mp4 {output_vid_name}"
205
- print(cmd_combine_audio)
206
  os.system(cmd_combine_audio)
207
 
208
- os.remove("temp.mp4")
209
- shutil.rmtree(result_img_save_path)
210
- print(f"result is save to {output_vid_name}")
211
- return output_vid_name
212
-
213
-
214
-
215
- # load model weights
216
- audio_processor,vae,unet,pe = load_all_model()
217
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
218
- timesteps = torch.tensor([0], device=device)
219
-
220
 
 
 
221
 
222
 
223
  def check_video(video):
224
- # Define the output video file name
225
  dir_path, file_name = os.path.split(video)
 
226
  if file_name.startswith("outputxxx_"):
227
  return video
228
- # Add the output prefix to the file name
229
- output_file_name = "outputxxx_" + file_name
230
 
231
- # Combine the directory path and the new file name
232
  output_video = os.path.join(dir_path, output_file_name)
233
-
234
-
235
- # Run the ffmpeg command to change the frame rate to 25fps
236
- command = f"ffmpeg -i {video} -r 25 {output_video} -y"
237
  subprocess.run(command, shell=True, check=True)
238
- return output_video
239
-
240
 
 
241
 
242
 
243
  css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""
244
 
245
  with gr.Blocks(css=css) as demo:
246
  gr.Markdown(
247
- "<div align='center'> <h1>MuseTalk: Real-Time High Quality Lip Synchronization with Latent Space Inpainting </span> </h1> \
248
- <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
249
- </br>\
250
- Yue Zhang <sup>\*</sup>,\
251
- Minhao Liu<sup>\*</sup>,\
252
- Zhaokang Chen,\
253
- Bin Wu<sup>†</sup>,\
254
- Yingjie He,\
255
- Chao Zhan,\
256
- Wenjiang Zhou\
257
- (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
258
- Lyra Lab, Tencent Music Entertainment\
259
- </h2> \
260
- <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
261
- <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
262
- <a style='font-size:18px;color: #000000' href=''> [Technical report(Coming Soon)] </a>\
263
- <a style='font-size:18px;color: #000000' href=''> [Project Page(Coming Soon)] </a> </div>"
264
  )
265
 
266
  with gr.Row():
267
  with gr.Column():
268
- audio = gr.Audio(label="Driven Audio",type="filepath")
269
  video = gr.Video(label="Reference Video")
270
- bbox_shift = gr.Number(label="BBox_shift,[-9,9]", value=-1)
271
  btn = gr.Button("Generate")
272
  out1 = gr.Video()
273
-
274
- video.change(
275
- fn=check_video, inputs=[video], outputs=[video]
276
- )
277
  btn.click(
278
  fn=inference,
279
- inputs=[
280
- audio,
281
- video,
282
- bbox_shift,
283
- ],
284
  outputs=out1,
285
  )
286
 
287
- # Set the IP and port
288
- ip_address = "0.0.0.0" # Replace with your desired IP address
289
- port_number = 7860 # Replace with your desired port number
290
-
291
-
292
  demo.queue().launch(
293
- share=False , debug=True, server_name=ip_address, server_port=port_number
 
 
 
294
  )
 
1
  import os
2
  import time
 
 
 
 
 
3
  import sys
4
  import subprocess
5
+ import glob
6
+ import copy
7
+ import pickle
8
+ import shutil
9
 
10
+ import gradio as gr
 
 
 
 
 
11
  import numpy as np
12
  import cv2
13
  import torch
 
 
14
  from tqdm import tqdm
 
15
  from argparse import Namespace
16
+ from huggingface_hub import snapshot_download
17
+ import requests
18
+
19
# Absolute directory that contains this script; model weights are kept in
# the "models" sub-directory next to it.
ProjectDir = os.path.dirname(os.path.abspath(__file__))
ModelsDir = os.path.join(ProjectDir, "models")
21
 
22
 
23
def download_model():
    """Download any model weights that are missing under ``ModelsDir``.

    One sentinel file is checked per model (MuseTalk UNet, SD VAE, Whisper,
    DWPose, the face-parsing net and its ResNet-18 backbone). If all of them
    exist nothing is fetched (in Docker, entrypoint.sh is expected to have
    handled the download already). Otherwise every source is attempted
    independently: a failed download is reported as a warning and the
    remaining ones still run. Files that are still missing afterwards are
    listed so the failure is visible before model loading.
    """
    face_parse_dir = os.path.join(ModelsDir, "face-parse-bisent")
    face_parse_path = os.path.join(face_parse_dir, "79999_iter.pth")
    resnet_path = os.path.join(face_parse_dir, "resnet18-5c106cde.pth")

    required_files = [
        os.path.join(ModelsDir, "musetalkV15", "unet.pth"),
        os.path.join(ModelsDir, "sd-vae", "diffusion_pytorch_model.safetensors"),
        os.path.join(ModelsDir, "whisper", "config.json"),
        os.path.join(ModelsDir, "dwpose", "dw-ll_ucoco_384.pth"),
        # The face-parse weights are part of the check too; otherwise a
        # partial install (HF models present, these absent) is never repaired.
        face_parse_path,
        resnet_path,
    ]

    if all(os.path.exists(path) for path in required_files):
        print("All model files present — skipping download.")
        return

    print("Some model files missing, attempting download...")
    tic = time.time()

    # Creates ModelsDir as well as the face-parse sub-directory.
    os.makedirs(face_parse_dir, exist_ok=True)

    # (label used in warnings, Hub repo id, target directory, files to fetch)
    hub_snapshots = [
        ("MuseTalk model", "TMElyralab/MuseTalk", ModelsDir,
         ["musetalk/*", "musetalkV15/*"]),
        ("SD VAE", "stabilityai/sd-vae-ft-mse", os.path.join(ModelsDir, "sd-vae"),
         ["config.json", "diffusion_pytorch_model.*"]),
        ("Whisper", "openai/whisper-tiny", os.path.join(ModelsDir, "whisper"),
         ["config.json", "pytorch_model.bin", "preprocessor_config.json"]),
        ("DWPose", "yzd-v/DWPose", os.path.join(ModelsDir, "dwpose"),
         ["dw-ll_ucoco_384.pth"]),
    ]
    for label, repo_id, local_dir, patterns in hub_snapshots:
        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                max_workers=8,
                local_dir_use_symlinks=True,
                allow_patterns=patterns,
            )
        except Exception as e:
            # Best effort: keep going so one unreachable repo does not block the rest.
            print(f"Warning: {label} download failed: {e}")

    if not os.path.exists(face_parse_path):
        try:
            # Imported lazily so gdown is only required when this file is missing.
            import gdown

            gdown.download(
                id="154JgKpzCPW82qINcVieuPH3fZ2e0P812",
                output=face_parse_path,
                quiet=False,
            )
        except Exception as e:
            print(f"Warning: Face parse download failed: {e}")

    if not os.path.exists(resnet_path):
        # Stream into a side file and rename on success, so an interrupted
        # transfer never leaves a truncated checkpoint that passes the
        # existence check on the next start.
        partial_path = resnet_path + ".part"
        try:
            with requests.get(
                "https://download.pytorch.org/models/resnet18-5c106cde.pth",
                stream=True,
                timeout=(10, 120),
            ) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(partial_path, resnet_path)
        except Exception as e:
            print(f"Warning: ResNet download failed: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)

    still_missing = [path for path in required_files if not os.path.exists(path)]
    if still_missing:
        print(f"Warning: model files still missing after download: {still_missing}")

    toc = time.time()
    print(f"Download completed in {toc - tic:.1f}s")
114
 
 
115
 
116
# Fetch weights first; load_all_model() below presumably reads them from disk,
# so this call has to stay ahead of the model loading.
download_model()

from musetalk.utils.utils import get_file_type, get_video_fps, datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model


# Models are loaded once at import time; inference() reads these module-level
# globals (audio_processor, vae, unet, pe, timesteps).
audio_processor, vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Constant timestep tensor passed to the UNet on every inference call.
timesteps = torch.tensor([0], device=device)
127
 
128
 
 
129
def _run_ffmpeg(ffmpeg_args):
    """Run ``ffmpeg`` with an argument list (no shell) and raise if it fails."""
    subprocess.run(["ffmpeg", *ffmpeg_args], check=True)


@torch.no_grad()
def inference(audio_path, video_path, bbox_shift, progress=gr.Progress(track_tqdm=True)):
    """Lip-sync ``video_path`` to ``audio_path`` and return the output video path.

    Args:
        audio_path: Path of the driving audio file.
        video_path: Path of the reference video, or of a directory of
            numerically named frames.
        bbox_shift: Forwarded unchanged to ``get_landmark_and_bbox``.
        progress: Gradio progress tracker; not read here, declared so Gradio
            mirrors the tqdm bars of this function in the UI.

    Returns:
        Path of the generated ``.mp4`` inside ``./results``.

    Raises:
        gr.Error: If either input path is missing.
        subprocess.CalledProcessError: If an ffmpeg invocation fails.
    """
    if not audio_path or not video_path:
        raise gr.Error("Please provide both a driving audio file and a reference video.")

    args_dict = {
        "result_dir": "./results",
        "fps": 25,
        "batch_size": 8,
        "output_vid_name": "",
        "use_saved_coord": False,
    }
    args = Namespace(**args_dict)

    # splitext keeps dotted names intact ("take.01.mp4" -> "take.01"), which
    # avoids collisions between uploads that differ only after the first dot.
    input_basename = os.path.splitext(os.path.basename(video_path))[0]
    audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
    output_basename = f"{input_basename}_{audio_basename}"
    result_img_save_path = os.path.join(args.result_dir, output_basename)
    crop_coord_save_path = os.path.join(result_img_save_path, input_basename + ".pkl")
    os.makedirs(result_img_save_path, exist_ok=True)
    # Frames left behind by an interrupted run would be appended to the new
    # image sequence, so drop them before writing new ones.
    for stale_frame in glob.glob(os.path.join(result_img_save_path, "*.png")):
        os.remove(stale_frame)

    if args.output_vid_name == "":
        output_vid_name = os.path.join(args.result_dir, output_basename + ".mp4")
    else:
        output_vid_name = os.path.join(args.result_dir, args.output_vid_name)

    # Intermediate video lives in the per-run directory instead of a shared
    # "temp.mp4" in the working directory, and is removed with that directory.
    temp_video = os.path.join(result_img_save_path, "temp.mp4")
    extracted_frames_dir = None

    try:
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            # Start from an empty directory: frames of an earlier, longer video
            # with the same basename would otherwise be picked up by the glob.
            shutil.rmtree(save_dir_full, ignore_errors=True)
            os.makedirs(save_dir_full, exist_ok=True)
            extracted_frames_dir = save_dir_full
            _run_ffmpeg([
                "-v", "fatal",
                "-i", video_path,
                "-start_number", "0",
                os.path.join(save_dir_full, "%08d.png"),
            ])
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, "*.[jpJP][pnPN]*[gG]")))
            fps = get_video_fps(video_path)
        else:
            input_img_list = glob.glob(os.path.join(video_path, "*.[jpJP][pnPN]*[gG]"))
            input_img_list = sorted(
                input_img_list,
                key=lambda x: int(os.path.splitext(os.path.basename(x))[0]),
            )
            fps = args.fps

        whisper_feature = audio_processor.audio2feat(audio_path)
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature, fps=fps)

        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("Using extracted coordinates")
            # The pickle is one this function wrote itself on an earlier run.
            with open(crop_coord_save_path, "rb") as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("Extracting landmarks...")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, "wb") as f:
                pickle.dump(coord_list, f)

        input_latent_list = []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
            latents = vae.get_latents_for_unet(crop_frame)
            input_latent_list.append(latents)

        # Forward + reversed copies, so indexing modulo the length plays the
        # clip back and forth instead of jumping from last frame to first.
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        print("Starting inference...")
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks, input_latent_list_cycle, batch_size)
        res_frame_list = []

        total_batches = int(np.ceil(float(video_num) / batch_size))
        for whisper_batch, latent_batch in tqdm(gen, total=total_batches):
            tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
            audio_feature_batch = torch.stack(tensor_list).to(unet.device)
            audio_feature_batch = pe(audio_feature_batch)
            pred_latents = unet.model(
                latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
            ).sample
            recon = vae.decode_latents(pred_latents)
            res_frame_list.extend(recon)

        print("Compositing frames...")
        for i, res_frame in enumerate(tqdm(res_frame_list)):
            bbox = coord_list_cycle[i % len(coord_list_cycle)]
            ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
            x1, y1, x2, y2 = bbox
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except Exception:
                # Kept as best effort: presumably hit for degenerate boxes such
                # as coord_placeholder; the frame is skipped rather than
                # aborting the run.
                continue
            combine_frame = get_image(ori_frame, res_frame, bbox)
            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

        _run_ffmpeg([
            "-y", "-v", "fatal",
            "-r", str(fps),
            "-f", "image2",
            "-i", os.path.join(result_img_save_path, "%08d.png"),
            "-vcodec", "libx264",
            "-vf", "format=rgb24,scale=out_color_matrix=bt709,format=yuv420p",
            "-crf", "18",
            temp_video,
        ])
        _run_ffmpeg(["-y", "-v", "fatal", "-i", audio_path, "-i", temp_video, output_vid_name])
    finally:
        # Always drop the intermediates, including when a step above raised.
        shutil.rmtree(result_img_save_path, ignore_errors=True)
        if extracted_frames_dir is not None:
            shutil.rmtree(extracted_frames_dir, ignore_errors=True)

    print(f"Result saved to {output_vid_name}")
    return output_vid_name
239
 
240
 
241
def check_video(video):
    """Re-encode an uploaded video to 25 fps and return the new file's path.

    Args:
        video: Path of the uploaded video, or ``None``/empty when the
            component has been cleared.

    Returns:
        ``video`` unchanged when it is empty or already carries the
        ``outputxxx_`` prefix (i.e. it is a file this function produced);
        otherwise the path of the converted copy written next to the input.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # The change event also fires when the video is cleared; nothing to convert.
    if not video:
        return video

    dir_path, file_name = os.path.split(video)

    # Writing the result back into the component triggers this handler again;
    # the prefix marks files that are already converted and stops the loop.
    if file_name.startswith("outputxxx_"):
        return video

    output_file_name = "outputxxx_" + file_name
    output_video = os.path.join(dir_path, output_file_name)
    # Argument list instead of a shell string: upload names may contain spaces
    # or shell metacharacters.
    command = ["ffmpeg", "-i", video, "-r", "25", output_video, "-y"]
    subprocess.run(command, check=True)

    return output_video
253
 
254
 
255
# Custom CSS handed to gr.Blocks below.
css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""

# NOTE(review): the nesting below is reconstructed from a whitespace-stripped
# diff view — confirm that out1 belongs beside the input column in the row.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        "<b>MuseTalk: Real-Time High Quality Lip Synchronization "
        "with Latent Space Inpainting</b>"
    )

    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Driven Audio", type="filepath")
            video = gr.Video(label="Reference Video")
            bbox_shift = gr.Number(label="BBox shift [-9, 9]", value=-1)
            btn = gr.Button("Generate")
        out1 = gr.Video()

    # Every change of the reference video is passed through check_video and
    # the returned path is written back into the same component.
    video.change(fn=check_video, inputs=[video], outputs=[video])
    btn.click(
        fn=inference,
        inputs=[audio, video, bbox_shift],
        outputs=out1,
    )

# Listen on all interfaces, port 7860, with request queueing enabled.
demo.queue().launch(
    share=False,
    debug=True,
    server_name="0.0.0.0",
    server_port=7860,
)