faizhashmi614 committed
Commit 5f9062d · Parent: 013fa46

pushing docker files

Files changed (4):
  1. Dockerfile +52 -0
  2. README.md +0 -11
  3. app.py +570 -0
  4. requirements.txt +37 -0
Dockerfile ADDED
@@ -0,0 +1,52 @@
+ FROM python:3.9-slim
+
+ # Set working directory early
+ WORKDIR /app
+
+ # Install system dependencies with cleanup
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     ffmpeg \
+     git \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     build-essential \
+     python3-dev \
+     libjpeg-dev \
+     libpng-dev \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy only what’s needed early for layer caching
+ COPY requirements.txt .
+ COPY scripts ./scripts
+ COPY configs ./configs
+
+ # Upgrade pip and install Python deps
+ RUN pip install --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+ # Install LiveKit SDKs
+ RUN pip install --no-cache-dir \
+     livekit==1.0.7 \
+     livekit-api==1.0.2 \
+     omegaconf \
+     transformers==4.39.3 \
+     && pip uninstall -y protobuf && pip install --no-cache-dir protobuf==3.20.3
+
+ # Install pose dependencies (with caching minimized)
+ RUN pip install --no-cache-dir cython && \
+     pip install --no-cache-dir git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI
+
+ RUN pip install --no-cache-dir mmengine==0.10.7 mmcv==2.0.0rc4 && \
+     pip install --no-cache-dir openmim && \
+     mim install mmpose && \
+     mim install mmdet
+
+ # Copy the rest of the code
+ COPY . .
+
+ # Final cleanup (in case anything big remains)
+ RUN apt-get clean && \
+     find /root/.cache -type f -delete && \
+     rm -rf /root/.cache/pip
+
+ # Entrypoint (currently commented out)
+ # CMD ["python3", "-m", "scripts.realtime_inference", "--version", "v15", "--inference_config", "configs/inference/realtime.yaml"]
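With the CMD commented out, the image falls back to the base image's default command, so the run command must be supplied at deploy time. A minimal smoke test of the intended entrypoint, mirroring the disabled CMD above (a sketch, not part of the commit):

import subprocess

# Run the realtime inference module exactly as the commented-out CMD would.
subprocess.run(
    ["python3", "-m", "scripts.realtime_inference",
     "--version", "v15",
     "--inference_config", "configs/inference/realtime.yaml"],
    check=True,  # raise if the entrypoint exits non-zero
)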
README.md DELETED
@@ -1,11 +0,0 @@
- ---
- title: MusetalkLivekitSetup
- emoji: 😻
- colorFrom: gray
- colorTo: indigo
- sdk: docker
- pinned: false
- short_description: Testing purpose
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,570 @@
+ import os
+ import time
+ import pdb
+ import re
+
+ import gradio as gr
+ import numpy as np
+ import sys
+ import subprocess
+
+ from huggingface_hub import snapshot_download
+ import requests
+
+ import argparse
+ from omegaconf import OmegaConf
+ import cv2
+ import torch
+ import glob
+ import pickle
+ from tqdm import tqdm
+ import copy
+ from argparse import Namespace
+ import shutil
+ import gdown
+ import imageio
+ import ffmpeg
+ from moviepy.editor import *
+ from transformers import WhisperModel
+
+ ProjectDir = os.path.abspath(os.path.dirname(__file__))
+ CheckpointsDir = os.path.join(ProjectDir, "models")
+
+ @torch.no_grad()
+ def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
+                      left_cheek_width=90, right_cheek_width=90):
+     """Debug the inpainting parameters by processing only the first frame."""
+     # Set default parameters
+     args_dict = {
+         "result_dir": './results/debug',
+         "fps": 25,
+         "batch_size": 1,
+         "output_vid_name": '',
+         "use_saved_coord": False,
+         "audio_padding_length_left": 2,
+         "audio_padding_length_right": 2,
+         "version": "v15",
+         "extra_margin": extra_margin,
+         "parsing_mode": parsing_mode,
+         "left_cheek_width": left_cheek_width,
+         "right_cheek_width": right_cheek_width
+     }
+     args = Namespace(**args_dict)
+
+     # Create debug directory
+     os.makedirs(args.result_dir, exist_ok=True)
+
+     # Read first frame
+     if get_file_type(video_path) == "video":
+         reader = imageio.get_reader(video_path)
+         first_frame = reader.get_data(0)
+         reader.close()
+     else:
+         first_frame = cv2.imread(video_path)
+         first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
+
+     # Save first frame
+     debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
+     cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
+
+     # Get face coordinates
+     coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
+     bbox = coord_list[0]
+     frame = frame_list[0]
+
+     if bbox == coord_placeholder:
+         return None, "No face detected; please adjust the bbox_shift parameter"
+
+     # Initialize face parser
+     fp = FaceParsing(
+         left_cheek_width=args.left_cheek_width,
+         right_cheek_width=args.right_cheek_width
+     )
+
+     # Process first frame
+     x1, y1, x2, y2 = bbox
+     y2 = y2 + args.extra_margin
+     y2 = min(y2, frame.shape[0])
+     crop_frame = frame[y1:y2, x1:x2]
+     crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
+
+     # Generate random audio features
+     random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
+     audio_feature = pe(random_audio)
+
+     # Get latents
+     latents = vae.get_latents_for_unet(crop_frame)
+     latents = latents.to(dtype=weight_dtype)
+
+     # Generate prediction results
+     pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
+     recon = vae.decode_latents(pred_latents)
+
+     # Inpaint back into the original image
+     res_frame = recon[0]
+     res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
+     combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
+
+     # Save results (get_image already returns RGB, so no further color conversion is needed here)
+     debug_result_path = os.path.join(args.result_dir, "debug_result.png")
+     cv2.imwrite(debug_result_path, combine_frame)
+
+     # Create information text
+     info_text = (
+         f"Parameter information:\n"
+         f"bbox_shift: {bbox_shift}\n"
+         f"extra_margin: {extra_margin}\n"
+         f"parsing_mode: {parsing_mode}\n"
+         f"left_cheek_width: {left_cheek_width}\n"
+         f"right_cheek_width: {right_cheek_width}\n"
+         f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
+     )
+
+     return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
+
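The bbox-plus-margin crop above is the geometric core shared with inference() further down: the detected face box is extended downward by extra_margin, clamped to the frame height, and resized to the 256×256 input the UNet expects. Isolated with dummy values (a sketch; the real bbox comes from get_landmark_and_bbox):

import cv2
import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy frame (H, W, C)
x1, y1, x2, y2 = 200, 100, 360, 300              # dummy face bbox
extra_margin = 10
y2 = min(y2 + extra_margin, frame.shape[0])      # extra jaw room, clamped to the frame
crop = cv2.resize(frame[y1:y2, x1:x2], (256, 256), interpolation=cv2.INTER_LANCZOS4)
print(crop.shape)  # (256, 256, 3)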
+ def print_directory_contents(path):
+     for child in os.listdir(path):
+         child_path = os.path.join(path, child)
+         if os.path.isdir(child_path):
+             print(child_path)
+
+ def download_model():
+     # Check whether the required model files exist
+     required_models = {
+         "MuseTalk UNet": f"{CheckpointsDir}/musetalkV15/unet.pth",
+         "MuseTalk Config": f"{CheckpointsDir}/musetalkV15/musetalk.json",
+         "SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
+         "Whisper": f"{CheckpointsDir}/whisper/config.json",
+         "DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
+         "SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
+         "Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
+         "ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
+     }
+
+     missing_models = []
+     for model_name, model_path in required_models.items():
+         if not os.path.exists(model_path):
+             missing_models.append(model_name)
+
+     if missing_models:
+         # All messages in English
+         print("The following required model files are missing:")
+         for model in missing_models:
+             print(f"- {model}")
+         print("\nPlease run the download script to fetch the missing models:")
+         if sys.platform == "win32":
+             print("Windows: Run download_weights.bat")
+         else:
+             print("Linux/Mac: Run ./download_weights.sh")
+         sys.exit(1)
+     else:
+         print("All required model files exist.")
+
+
+ download_model()  # verify the required weights exist before startup (for Hugging Face deployment)
+
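Despite its name, download_model() only verifies that the weights are present and exits if any are missing; snapshot_download is imported at the top of the file but never used. A sketch of how the weights could actually be fetched with it (the repo id is an assumption, not taken from this commit):

from huggingface_hub import snapshot_download

# Assumption: the MuseTalk weights are mirrored at TMElyralab/MuseTalk.
snapshot_download(repo_id="TMElyralab/MuseTalk", local_dir=CheckpointsDir)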
+ from musetalk.utils.blending import get_image
+ from musetalk.utils.face_parsing import FaceParsing
+ from musetalk.utils.audio_processor import AudioProcessor
+ from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range
+
+
+ def fast_check_ffmpeg():
+     try:
+         subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+         return True
+     except (FileNotFoundError, subprocess.CalledProcessError):
+         return False
+
+
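fast_check_ffmpeg() spawns a process on every call. An equivalent, cheaper check (a sketch; shutil.which only confirms the binary is on PATH, not that it actually runs):

import shutil

def fast_check_ffmpeg_alt():
    # PATH lookup only -- avoids spawning `ffmpeg -version`.
    return shutil.which("ffmpeg") is not None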
+ @torch.no_grad()
+ def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
+               left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
+     # Set default parameters, aligned with inference.py
+     args_dict = {
+         "result_dir": './results/output',
+         "fps": 25,
+         "batch_size": 8,
+         "output_vid_name": '',
+         "use_saved_coord": False,
+         "audio_padding_length_left": 2,
+         "audio_padding_length_right": 2,
+         "version": "v15",  # fixed to the v15 version
+         "extra_margin": extra_margin,
+         "parsing_mode": parsing_mode,
+         "left_cheek_width": left_cheek_width,
+         "right_cheek_width": right_cheek_width
+     }
+     args = Namespace(**args_dict)
+
+     # Check ffmpeg
+     if not fast_check_ffmpeg():
+         print("Warning: Unable to find ffmpeg; please ensure ffmpeg is properly installed")
+
+     input_basename = os.path.basename(video_path).split('.')[0]
+     audio_basename = os.path.basename(audio_path).split('.')[0]
+     output_basename = f"{input_basename}_{audio_basename}"
+
+     # Create temporary directory
+     temp_dir = os.path.join(args.result_dir, f"{args.version}")
+     os.makedirs(temp_dir, exist_ok=True)
+
+     # Set result save paths
+     result_img_save_path = os.path.join(temp_dir, output_basename)
+     crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename + ".pkl")
+     os.makedirs(result_img_save_path, exist_ok=True)
+
+     if args.output_vid_name == "":
+         output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
+     else:
+         output_vid_name = os.path.join(temp_dir, args.output_vid_name)
+
+     ############################################## extract frames from source video ##############################################
+     if get_file_type(video_path) == "video":
+         save_dir_full = os.path.join(temp_dir, input_basename)
+         os.makedirs(save_dir_full, exist_ok=True)
+         # Read video
+         reader = imageio.get_reader(video_path)
+
+         # Save frames as images
+         for i, im in enumerate(reader):
+             imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
+         input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
+         fps = get_video_fps(video_path)
+     else:  # input image folder
+         input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
+         input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+         fps = args.fps
+
+     ############################################## extract audio features ##############################################
+     whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
+     whisper_chunks = audio_processor.get_whisper_chunk(
+         whisper_input_features,
+         device,
+         weight_dtype,
+         whisper,
+         librosa_length,
+         fps=fps,
+         audio_padding_length_left=args.audio_padding_length_left,
+         audio_padding_length_right=args.audio_padding_length_right,
+     )
+
+     ############################################## preprocess input images ##############################################
+     if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
+         print("using extracted coordinates")
+         with open(crop_coord_save_path, 'rb') as f:
+             coord_list = pickle.load(f)
+         frame_list = read_imgs(input_img_list)
+     else:
+         print("extracting landmarks... this is time-consuming")
+         coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
+         with open(crop_coord_save_path, 'wb') as f:
+             pickle.dump(coord_list, f)
+     bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)
+
+     # Initialize face parser
+     fp = FaceParsing(
+         left_cheek_width=args.left_cheek_width,
+         right_cheek_width=args.right_cheek_width
+     )
+
+     i = 0
+     input_latent_list = []
+     for bbox, frame in zip(coord_list, frame_list):
+         if bbox == coord_placeholder:
+             continue
+         x1, y1, x2, y2 = bbox
+         y2 = y2 + args.extra_margin
+         y2 = min(y2, frame.shape[0])
+         crop_frame = frame[y1:y2, x1:x2]
+         crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
+         latents = vae.get_latents_for_unet(crop_frame)
+         input_latent_list.append(latents)
+
+     # Cycle the lists to smooth the first and last frames
+     frame_list_cycle = frame_list + frame_list[::-1]
+     coord_list_cycle = coord_list + coord_list[::-1]
+     input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
+
+     ############################################## inference batch by batch ##############################################
+     print("start inference")
+     video_num = len(whisper_chunks)
+     batch_size = args.batch_size
+     gen = datagen(
+         whisper_chunks=whisper_chunks,
+         vae_encode_latents=input_latent_list_cycle,
+         batch_size=batch_size,
+         delay_frame=0,
+         device=device,
+     )
+     res_frame_list = []
+     for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / batch_size)))):
+         audio_feature_batch = pe(whisper_batch)
+         # Ensure latent_batch matches the model weight dtype
+         latent_batch = latent_batch.to(dtype=weight_dtype)
+
+         pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
+         recon = vae.decode_latents(pred_latents)
+         for res_frame in recon:
+             res_frame_list.append(res_frame)
+
+     ############################################## pad back to the full image ##############################################
+     print("pad talking image to original video")
+     for i, res_frame in enumerate(tqdm(res_frame_list)):
+         bbox = coord_list_cycle[i % (len(coord_list_cycle))]
+         ori_frame = copy.deepcopy(frame_list_cycle[i % (len(frame_list_cycle))])
+         x1, y1, x2, y2 = bbox
+         y2 = y2 + args.extra_margin
+         y2 = min(y2, ori_frame.shape[0])  # clamp to this frame (the original used the loop-leftover `frame`)
+         try:
+             res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
+         except Exception:
+             continue  # skip frames whose bbox is invalid
+
+         # Use the v15 blending
+         combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
+
+         cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
+
+     # Frame rate
+     fps = 25
+     # Output video path
+     output_video = 'temp.mp4'
+
+     # Read images
+     def is_valid_image(file):
+         pattern = re.compile(r'\d{8}\.png')
+         return pattern.match(file)
+
+     images = []
+     files = [file for file in os.listdir(result_img_save_path) if is_valid_image(file)]
+     files.sort(key=lambda x: int(x.split('.')[0]))
+
+     for file in files:
+         filename = os.path.join(result_img_save_path, file)
+         images.append(imageio.imread(filename))
+
+     # Save video
+     imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
+
+     input_video = './temp.mp4'
+     # Check that input_video and audio_path exist
+     if not os.path.exists(input_video):
+         raise FileNotFoundError(f"Input video file not found: {input_video}")
+     if not os.path.exists(audio_path):
+         raise FileNotFoundError(f"Audio file not found: {audio_path}")
+
+     # Read video metadata
+     reader = imageio.get_reader(input_video)
+     fps = reader.get_meta_data()['fps']  # original video frame rate
+     reader.close()  # close the reader; otherwise Windows 11 raises PermissionError: [WinError 32] because 'temp.mp4' is still in use
+     # Keep the frames in a list
+     frames = images
+
+     print(len(frames))
+
+     # Load the video
+     video_clip = VideoFileClip(input_video)
+
+     # Load the audio
+     audio_clip = AudioFileClip(audio_path)
+
+     # Set the audio on the video
+     video_clip = video_clip.set_audio(audio_clip)
+
+     # Write the output video
+     video_clip.write_videofile(output_vid_name, codec='libx264', audio_codec='aac', fps=25)
+
+     os.remove("temp.mp4")
+     # shutil.rmtree(result_img_save_path)
+     print(f"result is saved to {output_vid_name}")
+     return output_vid_name, bbox_shift_text
+
+
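Like snapshot_download, the `import ffmpeg` at the top of the file is unused; the final temp.mp4 + audio muxing in inference() is done with moviepy instead. The same step with the already-imported ffmpeg-python binding would look roughly like this (a sketch; audio_path and output_vid_name as inside the function):

import ffmpeg

video_in = ffmpeg.input("temp.mp4")
audio_in = ffmpeg.input(audio_path)  # the audio passed to inference()
ffmpeg.output(video_in.video, audio_in.audio, output_vid_name,
              vcodec="libx264", acodec="aac", r=25).run(overwrite_output=True)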
+ # Load model weights
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ vae, unet, pe = load_all_model(
+     unet_model_path="./models/musetalkV15/unet.pth",
+     vae_type="sd-vae",
+     unet_config="./models/musetalkV15/musetalk.json",
+     device=device
+ )
+
+ # Parse command line arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
+ parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
+ parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
+ parser.add_argument("--share", action="store_true", help="Create a public link")
+ parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
+ args = parser.parse_args()
+
+ # Set data type
+ if args.use_float16:
+     # Convert models to half precision for better performance
+     pe = pe.half()
+     vae.vae = vae.vae.half()
+     unet.model = unet.model.half()
+     weight_dtype = torch.float16
+ else:
+     weight_dtype = torch.float32
+
+ # Move models to the target device
+ pe = pe.to(device)
+ vae.vae = vae.vae.to(device)
+ unet.model = unet.model.to(device)
+
+ timesteps = torch.tensor([0], device=device)
+
+ # Initialize audio processor and Whisper model
+ audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
+ whisper = WhisperModel.from_pretrained("./models/whisper")
+ whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+ whisper.requires_grad_(False)
+
+
+ def check_video(video):
+     if not isinstance(video, str):
+         return video  # handles None
+     # Define the output video file name
+     dir_path, file_name = os.path.split(video)
+     if file_name.startswith("outputxxx_"):
+         return video
+     # Add the output prefix to the file name
+     output_file_name = "outputxxx_" + file_name
+
+     os.makedirs('./results', exist_ok=True)
+     os.makedirs('./results/output', exist_ok=True)
+     os.makedirs('./results/input', exist_ok=True)
+
+     # Combine the directory path and the new file name
+     output_video = os.path.join('./results/input', output_file_name)
+
+     # Read video
+     reader = imageio.get_reader(video)
+     fps = reader.get_meta_data()['fps']  # fps of the original video
+
+     # Convert to 25 fps
+     frames = [im for im in reader]
+     target_fps = 25
+
+     L = len(frames)
+     L_target = int(L / fps * target_fps)
+     original_t = [x / fps for x in range(1, L + 1)]
+     t_idx = 0
+     target_frames = []
+     for target_t in range(1, L_target + 1):
+         while target_t / target_fps > original_t[t_idx]:
+             t_idx += 1  # find the first t_idx such that target_t / target_fps <= original_t[t_idx]
+             if t_idx >= L:
+                 break
+         target_frames.append(frames[t_idx])
+
+     # Save video
+     imageio.mimwrite(output_video, target_frames, 'FFMPEG', fps=25, codec='libx264', quality=9, pixelformat='yuv420p')
+     return output_video
+
+
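The resampling loop in check_video() maps output frame k (1-based, at 25 fps) to the first source frame whose timestamp is not earlier, i.e. source index ceil(k * fps / 25) - 1. The same mapping as a standalone sketch for a 30 fps clip:

src_fps, dst_fps, n_src = 30, 25, 30
n_dst = n_src * dst_fps // src_fps  # 25 output frames per 30 input frames
idx = [-(-(k * src_fps) // dst_fps) - 1 for k in range(1, n_dst + 1)]  # ceil division
print(idx)  # [1, 2, 3, 4, 5, 7, 8, ...] -- one of every six source frames is skipped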
+ css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(
+         """<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
+         <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
+         </br>\
+         Yue Zhang <sup>*</sup>,\
+         Zhizhou Zhong <sup>*</sup>,\
+         Minhao Liu<sup>*</sup>,\
+         Zhaokang Chen,\
+         Bin Wu<sup>†</sup>,\
+         Yubin Zeng,\
+         Chao Zhang,\
+         Yingjie He,\
+         Junxin Huang,\
+         Wenjiang Zhou <br>\
+         (<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
+         Lyra Lab, Tencent Music Entertainment\
+         </h2> \
+         <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
+         <a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
+         <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
+     )
+
+     with gr.Row():
+         with gr.Column():
+             audio = gr.Audio(label="Driving Audio", type="filepath")
+             video = gr.Video(label="Reference Video", sources=['upload'])
+             bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
+             extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
+             parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
+             left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
+             right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
+             bbox_shift_scale = gr.Textbox(label="'left_cheek_width' and 'right_cheek_width' set how far the left and right cheeks may be edited when the parsing mode is 'jaw'; 'extra_margin' sets the movement range of the jaw. Adjust these three parameters freely to get better inpainting results.")
+
+             with gr.Row():
+                 debug_btn = gr.Button("1. Test Inpainting")
+                 btn = gr.Button("2. Generate")
+         with gr.Column():
+             debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
+             debug_info = gr.Textbox(label="Parameter Information", lines=5)
+             out1 = gr.Video()
+
+     video.change(
+         fn=check_video, inputs=[video], outputs=[video]
+     )
+     btn.click(
+         fn=inference,
+         inputs=[
+             audio,
+             video,
+             bbox_shift,
+             extra_margin,
+             parsing_mode,
+             left_cheek_width,
+             right_cheek_width
+         ],
+         outputs=[out1, bbox_shift_scale]
+     )
+     debug_btn.click(
+         fn=debug_inpainting,
+         inputs=[
+             video,
+             bbox_shift,
+             extra_margin,
+             parsing_mode,
+             left_cheek_width,
+             right_cheek_width
+         ],
+         outputs=[debug_image, debug_info]
+     )
+
+ # Check ffmpeg and add it to PATH if needed
+ if not fast_check_ffmpeg():
+     print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
+     # Choose the PATH separator for the current operating system
+     path_separator = ';' if sys.platform == 'win32' else ':'
+     os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+     if not fast_check_ffmpeg():
+         print("Warning: Unable to find ffmpeg; please ensure ffmpeg is properly installed")
+
+ # Work around asynchronous IO issues on Windows
+ if sys.platform == 'win32':
+     import asyncio
+     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+ # Start the Gradio application
+ demo.queue().launch(
+     share=args.share,
+     debug=True,
+     server_name=args.ip,
+     server_port=args.port
+ )
requirements.txt ADDED
@@ -0,0 +1,37 @@
+ # Core AI packages (compatible with PyTorch 2.0.1 CPU)
+ torch==2.0.1
+ torchvision==0.15.2
+ torchaudio==2.0.2
+
+ # For MuseTalk
+ diffusers==0.30.2
+ accelerate==0.28.0
+ transformers==4.39.2
+ huggingface_hub==0.30.2
+ einops==0.8.1
+ omegaconf==2.3.0
+
+ # For audio processing
+ librosa==0.11.0
+ soundfile==0.12.1
+
+ # For video/image processing
+ opencv-python==4.9.0.80
+ ffmpeg-python==0.2.0
+ moviepy==1.0.3
+ imageio[ffmpeg]
+
+ # For the Gradio demo
+ gradio==3.41.2
+
+ # TensorFlow for Whisper support (CPU only; avoids M1 GPU issues)
+ tensorflow==2.11.0
+ tensorboard==2.11.0
+
+ # Pose-related tools (mmpose and its dependencies are installed manually)
+ # Do NOT install mmcv/mmdet here; install them manually with the right version!
+ # mmcv and mmdet must match the build and platform
+
+ # Utilities
+ gdown
+ requests
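Note that the Dockerfile later overrides what this file resolves to, force-installing transformers==4.39.3 and protobuf==3.20.3, so the built image does not keep the transformers==4.39.2 pinned here. A quick way to see what actually landed in the image (a sketch; run inside the container, expected versions assumed from the Dockerfile's final install steps):

from importlib.metadata import version

for pkg, want in [("torch", "2.0.1"), ("transformers", "4.39.3"), ("protobuf", "3.20.3")]:
    got = version(pkg)
    print(f"{pkg}: {got}" + ("" if got == want else f" (expected {want})"))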