cbsjtu01 committed
Commit 795f990 · 1 Parent(s): 461edea
Files changed (2)
  1. README.md +1 -1
  2. app.py +62 -612
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📉
 colorFrom: gray
 colorTo: red
 sdk: gradio
-sdk_version: 3.50.2
+sdk_version: 6.0.0
 app_file: app.py
 pinned: false
 license: mit
app.py CHANGED
@@ -1,463 +1,3 @@
-<<<<<<< HEAD
-import os
-import sys
-import tempfile
-import subprocess
-import numpy as np
-import cv2
-import torch
-import torchvision
-import librosa
-import face_alignment
-import gradio as gr
-from PIL import Image
-import torchvision.transforms as transforms
-from transformers import Wav2Vec2FeatureExtractor
-from tqdm import tqdm
-import random
-from huggingface_hub import hf_hub_download
-
-# Import spaces for ZeroGPU support
-import spaces
-
-sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-
-# Try to import the local modules
-try:
-    from generator.FM import FMGenerator
-    from renderer.models import IMTRenderer
-except ImportError as e:
-    print(f"Import Error: {e}")
-    print("Please ensure 'generator' and 'renderer' folders are in the same directory.")
-    exit(1)
-
-# ==========================================
-# Logic for auto-downloading model weights
-# ==========================================
-def ensure_checkpoints():
-    print("Checking model checkpoints...")
-
-    REPO_ID = "cbsjtu01/IMTalker"
-    REPO_TYPE = "model"
-
-    files_to_download = [
-        "renderer.ckpt",
-        "generator.ckpt",
-        "wav2vec2-base-960h/config.json",
-        "wav2vec2-base-960h/pytorch_model.bin",
-        "wav2vec2-base-960h/preprocessor_config.json",
-        "wav2vec2-base-960h/feature_extractor_config.json",
-    ]
-
-    TARGET_DIR = "checkpoints"
-    os.makedirs(TARGET_DIR, exist_ok=True)
-
-    for remote_filename in files_to_download:
-        local_file_path = os.path.join(TARGET_DIR, remote_filename)
-
-        # Check that the file exists and has a sane size (over 1 KB)
-        if not os.path.exists(local_file_path) or os.path.getsize(local_file_path) < 1024:
-            print(f"Downloading {remote_filename} to {TARGET_DIR}...")
-            try:
-                hf_hub_download(
-                    repo_id=REPO_ID,
-                    filename=remote_filename,
-                    repo_type=REPO_TYPE,
-                    local_dir=TARGET_DIR,
-                    local_dir_use_symlinks=False
-                )
-            except Exception as e:
-                print(f"Failed to download {remote_filename}: {e}")
-                pass
-        else:
-            print(f"File {local_file_path} already exists. Skipping download.")
-
-ensure_checkpoints()
-
-class AppConfig:
-    def __init__(self):
-        # Key: in the ZeroGPU environment the app must start on CPU; it must not grab GPU memory up front or the process gets killed
-        self.device = "cpu"
-        self.seed = 42
-        self.fix_noise_seed = False
-        self.renderer_path = "./checkpoints/renderer.ckpt"
-        self.generator_path = "./checkpoints/generator.ckpt"
-        self.wav2vec_model_path = "./checkpoints/wav2vec2-base-960h"
-        self.input_size = 256
-        self.input_nc = 3
-        self.fps = 25.0
-        self.rank = "cuda"
-        self.sampling_rate = 16000
-        self.audio_marcing = 2
-        self.wav2vec_sec = 2.0
-        self.attention_window = 5
-        self.only_last_features = True
-        self.audio_dropout_prob = 0.1
-        self.style_dim = 512
-        self.dim_a = 512
-        self.dim_h = 512
-        self.dim_e = 7
-        self.dim_motion = 32
-        self.dim_c = 32
-        self.dim_w = 32
-        self.fmt_depth = 8
-        self.num_heads = 8
-        self.mlp_ratio = 4.0
-        self.no_learned_pe = False
-        self.num_prev_frames = 10
-        self.max_grad_norm = 1.0
-        self.ode_atol = 1e-5
-        self.ode_rtol = 1e-5
-        self.nfe = 10
-        self.torchdiffeq_ode_method = 'euler'
-        self.a_cfg_scale = 3.0
-        self.swin_res_threshold = 128
-        self.window_size = 8
-        self.ref_path = None
-        self.pose_path = None
-        self.gaze_path = None
-        self.aud_path = None
-        self.crop = True
-        self.source_path = None
-        self.driving_path = None
-
-class DataProcessor:
-    def __init__(self, opt):
-        self.opt = opt
-        self.fps = opt.fps
-        self.sampling_rate = opt.sampling_rate
-        print(f"Loading Face Alignment (CPU first)...")
-        # Force FaceAlignment onto CPU so initialization does not occupy the GPU
-        self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device='cpu', flip_input=False)
-        print("Loading Wav2Vec2...")
-        local_path = opt.wav2vec_model_path
-        if os.path.exists(local_path) and os.path.exists(os.path.join(local_path, "config.json")):
-            print(f"Loading local wav2vec from {local_path}")
-            self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(local_path, local_files_only=True)
-        else:
-            print("Local wav2vec model not found, downloading from 'facebook/wav2vec2-base-960h'...")
-            self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
-        self.transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
-
-    def process_img(self, img: Image.Image) -> Image.Image:
-        img_arr = np.array(img)
-        if img_arr.ndim == 2:
-            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_GRAY2RGB)
-        elif img_arr.shape[2] == 4:
-            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2RGB)
-        h, w = img_arr.shape[:2]
-        mult = 360.0 / h
-        resized_img = cv2.resize(img_arr, dsize=(0, 0), fx=mult, fy=mult, interpolation=cv2.INTER_AREA if mult < 1 else cv2.INTER_CUBIC)
-        try:
-            bboxes = self.fa.face_detector.detect_from_image(resized_img)
-            if bboxes is None or len(bboxes) == 0:
-                bboxes = self.fa.face_detector.detect_from_image(img_arr)
-        except Exception as e:
-            print(f"Face detection failed: {e}")
-            bboxes = None
-        valid_bboxes = []
-        if bboxes is not None:
-            valid_bboxes = [(int(x1 / mult), int(y1 / mult), int(x2 / mult), int(y2 / mult), score) for (x1, y1, x2, y2, score) in bboxes if score > 0.5]
-        if not valid_bboxes:
-            print("Warning: No face detected. Using center crop.")
-            cx, cy = w // 2, h // 2
-            half = min(w, h) // 2
-            x1_new, x2_new = cx - half, cx + half
-            y1_new, y2_new = cy - half, cy + half
-        else:
-            x1, y1, x2, y2, _ = valid_bboxes[0]
-            cx = (x1 + x2) // 2
-            cy = (y1 + y2) // 2
-            w_face = x2 - x1
-            h_face = y2 - y1
-            half_side = int(max(w_face, h_face) * 0.8)
-            x1_new = cx - half_side
-            y1_new = cy - half_side
-            x2_new = cx + half_side
-            y2_new = cy + half_side
-        if x1_new < 0: x2_new += (0 - x1_new); x1_new = 0
-        if y1_new < 0: y2_new += (0 - y1_new); y1_new = 0
-        if x2_new > w: x1_new -= (x2_new - w); x2_new = w
-        if y2_new > h: y1_new -= (y2_new - h); y2_new = h
-        x1_new = max(0, x1_new); y1_new = max(0, y1_new); x2_new = min(w, x2_new); y2_new = min(h, y2_new)
-        curr_w = x2_new - x1_new; curr_h = y2_new - y1_new
-        min_side = min(curr_w, curr_h)
-        x2_new = x1_new + min_side; y2_new = y1_new + min_side
-        crop_img = img_arr[int(y1_new):int(y2_new), int(x1_new):int(x2_new)]
-        crop_pil = Image.fromarray(crop_img)
-        return crop_pil.resize((self.opt.input_size, self.opt.input_size))
-
-    def process_audio(self, path: str) -> torch.Tensor:
-        speech_array, sampling_rate = librosa.load(path, sr=self.sampling_rate)
-        return self.wav2vec_preprocessor(speech_array, sampling_rate=sampling_rate, return_tensors='pt').input_values[0]
-
-    def crop_video_stable(self, from_mp4_file_path, to_mp4_file_path, expanded_ratio=0.6, skip_per_frame=1):
-        if os.path.exists(to_mp4_file_path): os.remove(to_mp4_file_path)
-        video = cv2.VideoCapture(from_mp4_file_path)
-        index = 0
-        bboxes_lists = []
-        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
-        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
-        print(f"Analyzing video for stable cropping: {from_mp4_file_path}")
-        while video.isOpened():
-            success = video.grab()
-            if not success: break
-            if index % skip_per_frame == 0:
-                success, frame = video.retrieve()
-                if not success: break
-                h, w = frame.shape[:2]
-                mult = 360.0 / h
-                resized_frame = cv2.resize(frame, dsize=(0, 0), fx=mult, fy=mult, interpolation=cv2.INTER_AREA if mult < 1 else cv2.INTER_CUBIC)
-                try: detected_bboxes = self.fa.face_detector.detect_from_image(resized_frame)
-                except: detected_bboxes = None
-                current_frame_bboxes = []
-                if detected_bboxes is not None:
-                    for d_box in detected_bboxes:
-                        bx1, by1, bx2, by2, score = d_box
-                        if score > 0.5: current_frame_bboxes.append([int(bx1 / mult), int(by1 / mult), int(bx2 / mult), int(by2 / mult), score])
-                if len(current_frame_bboxes) > 0:
-                    max_bboxes = max(current_frame_bboxes, key=lambda bbox: bbox[2] - bbox[0])
-                    bboxes_lists.append(max_bboxes)
-            index += 1
-        video.release()
-        x_center_lists, y_center_lists, width_lists, height_lists = [], [], [], []
-        for bbox in bboxes_lists:
-            x1, y1, x2, y2 = bbox[:4]
-            x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2
-            x_center_lists.append(x_center)
-            y_center_lists.append(y_center)
-            width_lists.append(x2 - x1)
-            height_lists.append(y2 - y1)
-        if not (x_center_lists and y_center_lists and width_lists and height_lists):
-            import shutil
-            shutil.copy(from_mp4_file_path, to_mp4_file_path)
-            return
-        x_center = sorted(x_center_lists)[len(x_center_lists) // 2]
-        y_center = sorted(y_center_lists)[len(y_center_lists) // 2]
-        median_width = sorted(width_lists)[len(width_lists) // 2]
-        median_height = sorted(height_lists)[len(height_lists) // 2]
-        expanded_width = int(median_width * (1 + expanded_ratio))
-        expanded_height = int(median_height * (1 + expanded_ratio))
-        fixed_cropped_width = min(max(expanded_width, expanded_height), width, height)
-        x1, y1 = int(x_center - fixed_cropped_width / 2), int(y_center - fixed_cropped_width / 2)
-        x1 = max(0, x1); y1 = max(0, y1)
-        if x1 + fixed_cropped_width > width: x1 = width - fixed_cropped_width
-        if y1 + fixed_cropped_width > height: y1 = height - fixed_cropped_width
-        target_size = self.opt.input_size
-        cmd = (f'ffmpeg -i "{from_mp4_file_path}" -filter:v "crop={fixed_cropped_width}:{fixed_cropped_width}:{x1}:{y1},scale={target_size}:{target_size}:flags=lanczos" -c:v libx264 -crf 18 -preset slow -c:a aac -b:a 128k "{to_mp4_file_path}" -y -loglevel error')
-        if os.system(cmd) != 0:
-            import shutil
-            shutil.copy(from_mp4_file_path, to_mp4_file_path)
-
-class InferenceAgent:
-    def __init__(self, opt):
-        torch.cuda.empty_cache()
-        self.opt = opt
-        self.device = opt.device  # Defaults to cpu, to avoid crashing at startup
-        self.data_processor = DataProcessor(opt)
-        print("Loading Models...")
-        self.renderer = IMTRenderer(self.opt).to(self.device)
-        self.generator = FMGenerator(self.opt).to(self.device)
-        if not os.path.exists(self.opt.renderer_path) or not os.path.exists(self.opt.generator_path):
-            raise FileNotFoundError("Checkpoints not found even after download attempt.")
-        self._load_ckpt(self.renderer, self.opt.renderer_path, "gen.")
-        self._load_fm_ckpt(self.generator, self.opt.generator_path)
-        self.renderer.eval()
-        self.generator.eval()
-
-    # Key: ZeroGPU requires moving the models to CUDA dynamically inside the handler
-    def to(self, device):
-        if self.device != device:
-            print(f"Moving models to {device}...")
-            self.device = device
-            self.renderer = self.renderer.to(device)
-            self.generator = self.generator.to(device)
-
-    def _load_ckpt(self, model, path, prefix="gen."):
-        if not os.path.exists(path):
-            print(f"Warning: Checkpoint {path} not found.")
-            return
-        checkpoint = torch.load(path, map_location="cpu")
-        state_dict = checkpoint.get("state_dict", checkpoint)
-        clean_state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if k.startswith(prefix)}
-        model.load_state_dict(clean_state_dict, strict=False)
-
-    def _load_fm_ckpt(self, model, path):
-        if not os.path.exists(path): return
-        checkpoint = torch.load(path, map_location='cpu')
-        state_dict = checkpoint.get('state_dict', checkpoint)
-        if 'model' in state_dict: state_dict = state_dict['model']
-        prefix = 'model.'
-        clean_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
-        with torch.no_grad():
-            for name, param in model.named_parameters():
-                if name in clean_dict:
-                    param.copy_(clean_dict[name].to(self.device))
-
-    def save_video(self, vid_tensor, fps, audio_path=None):
-        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
-            raw_path = tmp.name
-        if vid_tensor.dim() == 4:
-            vid = vid_tensor.permute(0, 2, 3, 1).detach().cpu().numpy()
-        if vid.min() < 0:
-            vid = (vid + 1) / 2
-        vid = np.clip(vid, 0, 1)
-        vid = (vid * 255).astype(np.uint8)
-        height, width = vid.shape[1], vid.shape[2]
-        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-        writer = cv2.VideoWriter(raw_path, fourcc, fps, (width, height))
-        for frame in vid:
-            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
-        writer.release()
-        if audio_path:
-            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_out:
-                final_path = tmp_out.name
-            cmd = f"ffmpeg -y -i {raw_path} -i {audio_path} -c:v copy -c:a aac -shortest {final_path}"
-            subprocess.call(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-            if os.path.exists(raw_path): os.remove(raw_path)
-            return final_path
-        else:
-            return raw_path
-
-    @torch.no_grad()
-    def run_audio_inference(self, img_pil, aud_path, crop, seed, nfe, cfg_scale):
-        s_pil = self.data_processor.process_img(img_pil) if crop else img_pil.resize((self.opt.input_size, self.opt.input_size))
-        s_tensor = self.data_processor.transform(s_pil).unsqueeze(0).to(self.device)
-        a_tensor = self.data_processor.process_audio(aud_path).unsqueeze(0).to(self.device)
-        data = {'s': s_tensor, 'a': a_tensor, 'pose': None, 'cam': None, 'gaze': None, 'ref_x': None}
-        f_r, g_r = self.renderer.dense_feature_encoder(s_tensor)
-        t_lat = self.renderer.latent_token_encoder(s_tensor)
-        if isinstance(t_lat, tuple): t_lat = t_lat[0]
-        data['ref_x'] = t_lat
-        torch.manual_seed(seed)
-        sample = self.generator.sample(data, a_cfg_scale=cfg_scale, nfe=nfe, seed=seed)
-        d_hat = []
-        T = sample.shape[1]
-        ta_r = self.renderer.adapt(t_lat, g_r)
-        m_r = self.renderer.latent_token_decoder(ta_r)
-        for t in range(T):
-            ta_c = self.renderer.adapt(sample[:, t, ...], g_r)
-            m_c = self.renderer.latent_token_decoder(ta_c)
-            out_frame = self.renderer.decode(m_c, m_r, f_r)
-            d_hat.append(out_frame)
-        vid_tensor = torch.stack(d_hat, dim=1).squeeze(0)
-        return self.save_video(vid_tensor, self.opt.fps, aud_path)
-
-    @torch.no_grad()
-    def run_video_inference(self, source_img_pil, driving_video_path, crop):
-        s_pil = self.data_processor.process_img(source_img_pil) if crop else source_img_pil.resize((self.opt.input_size, self.opt.input_size))
-        s_tensor = self.data_processor.transform(s_pil).unsqueeze(0).to(self.device)
-        f_r, i_r = self.renderer.app_encode(s_tensor)
-        t_r = self.renderer.mot_encode(s_tensor)
-        ta_r = self.renderer.adapt(t_r, i_r)
-        ma_r = self.renderer.mot_decode(ta_r)
-        final_driving_path = driving_video_path
-        temp_crop_video = None
-        if crop:
-            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp: temp_crop_video = tmp.name
-            self.data_processor.crop_video_stable(driving_video_path, temp_crop_video)
-            final_driving_path = temp_crop_video
-        cap = cv2.VideoCapture(final_driving_path)
-        fps = cap.get(cv2.CAP_PROP_FPS)
-        vid_results = []
-        while True:
-            ret, frame = cap.read()
-            if not ret: break
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            frame_pil = Image.fromarray(frame).resize((self.opt.input_size, self.opt.input_size))
-            d_tensor = self.data_processor.transform(frame_pil).unsqueeze(0).to(self.device)
-            t_c = self.renderer.mot_encode(d_tensor)
-            ta_c = self.renderer.adapt(t_c, i_r)
-            ma_c = self.renderer.mot_decode(ta_c)
-            out = self.renderer.decode(ma_c, ma_r, f_r)
-            vid_results.append(out.cpu())
-        cap.release()
-        if temp_crop_video and os.path.exists(temp_crop_video): os.remove(temp_crop_video)
-        if not vid_results: raise Exception("Driving video reading failed.")
-        vid_tensor = torch.cat(vid_results, dim=0)
-        return self.save_video(vid_tensor, fps=fps, audio_path=None)
-
-print("Initializing Configuration...")
-cfg = AppConfig()
-agent = None
-
-try:
-    if os.path.exists(cfg.renderer_path) and os.path.exists(cfg.generator_path):
-        agent = InferenceAgent(cfg)
-    else:
-        print("Error: Checkpoints not found. Please upload 'renderer.ckpt' and 'generator.ckpt' via the Files tab.")
-except Exception as e:
-    print(f"Initialization Error: {e}")
-    import traceback
-    traceback.print_exc()
-
-# The @spaces.GPU decorator must be added!
-@spaces.GPU
-def fn_audio_driven(image, audio, crop, seed, nfe, cfg_scale, progress=gr.Progress()):
-    if agent is None: raise gr.Error("Models not loaded properly. Check logs.")
-    if image is None or audio is None: raise gr.Error("Missing image or audio.")
-
-    # Move the models to the GPU on demand
-    if torch.cuda.is_available():
-        agent.to("cuda")
-
-    img_pil = Image.fromarray(image).convert('RGB')
-    try:
-        return agent.run_audio_inference(img_pil, audio, crop, int(seed), int(nfe), float(cfg_scale))
-    except Exception as e:
-        raise gr.Error(f"Error: {e}")
-
-# The @spaces.GPU decorator must be added!
-@spaces.GPU
-def fn_video_driven(source_image, driving_video, crop, progress=gr.Progress()):
-    if agent is None: raise gr.Error("Models not loaded properly. Check logs.")
-    if source_image is None or driving_video is None: raise gr.Error("Missing inputs.")
-
-    # Move the models to the GPU on demand
-    if torch.cuda.is_available():
-        agent.to("cuda")
-
-    img_pil = Image.fromarray(source_image).convert('RGB')
-    try:
-        return agent.run_video_inference(img_pil, driving_video, crop)
-    except Exception as e:
-        import traceback
-        traceback.print_exc()
-        raise gr.Error(f"Error: {e}")
-
-# Gradio 4.x syntax: removed css, use sources=["upload"]
-with gr.Blocks(title="IMTalker Demo") as demo:
-    gr.Markdown("# 🗣️ IMTalker: Efficient Audio-driven Talking Face Generation")
-    with gr.Tabs():
-        with gr.TabItem("Audio Driven"):
-            with gr.Row():
-                with gr.Column():
-                    a_img = gr.Image(label="Source Image", type="numpy")
-                    a_aud = gr.Audio(label="Driving Audio", type="filepath")
-                    with gr.Accordion("Settings", open=True):
-                        a_crop = gr.Checkbox(label="Auto Crop Face", value=True)
-                        a_seed = gr.Number(label="Seed", value=42)
-                        a_nfe = gr.Slider(5, 50, value=10, step=1, label="Steps (NFE)")
-                        a_cfg = gr.Slider(1.0, 5.0, value=3.0, label="CFG Scale")
-                    a_btn = gr.Button("Generate (Audio Driven)", variant="primary")
-                with gr.Column():
-                    a_out = gr.Video(label="Result")
-            a_btn.click(fn_audio_driven, [a_img, a_aud, a_crop, a_seed, a_nfe, a_cfg], a_out)
-
-        with gr.TabItem("Video Driven"):
-            with gr.Row():
-                with gr.Column():
-                    v_img = gr.Image(label="Source Image", type="numpy")
-                    # Gradio 4.x syntax
-                    v_vid = gr.Video(label="Driving Video", sources=["upload"])
-                    v_crop = gr.Checkbox(label="Auto Crop (Both Source & Driving)", value=True)
-                    v_btn = gr.Button("Generate (Video Driven)", variant="primary")
-                with gr.Column():
-                    v_out = gr.Video(label="Result")
-            v_btn.click(fn_video_driven, [v_img, v_vid, v_crop], v_out)
-
-if __name__ == "__main__":
-=======
 import os
 import sys
 import tempfile
@@ -476,9 +16,12 @@ from tqdm import tqdm
 import random
 from huggingface_hub import hf_hub_download
 
+# Import spaces for ZeroGPU support
+import spaces
+
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 
-# Try to import the local modules; warn on failure
+# Try to import the local modules
 try:
     from generator.FM import FMGenerator
     from renderer.models import IMTRenderer
@@ -491,18 +34,11 @@ except ImportError as e:
 # Logic for auto-downloading model weights
 # ==========================================
 def ensure_checkpoints():
-    """
-    Download the model files from the specified repo.
-    """
     print("Checking model checkpoints...")
 
-    # Changed to the new repo ID provided by the user
     REPO_ID = "cbsjtu01/IMTalker"
-    # This is a model repo (the URL has no /spaces/), so the type is "model"
     REPO_TYPE = "model"
 
-    # Only the remote file names need to be listed; everything is downloaded into the checkpoints folder,
-    # and the directory structure is preserved automatically (e.g. wav2vec2/config.json creates its own folder)
     files_to_download = [
         "renderer.ckpt",
         "generator.ckpt",
@@ -512,61 +48,50 @@ def ensure_checkpoints():
         "wav2vec2-base-960h/feature_extractor_config.json",
     ]
 
-    # Target root directory
     TARGET_DIR = "checkpoints"
     os.makedirs(TARGET_DIR, exist_ok=True)
 
     for remote_filename in files_to_download:
-        # Compute the expected full local path
         local_file_path = os.path.join(TARGET_DIR, remote_filename)
 
-        # Check that the file exists and has a sane size (over 1 KB, to guard against empty files)
+        # Check that the file exists and has a sane size (over 1 KB)
         if not os.path.exists(local_file_path) or os.path.getsize(local_file_path) < 1024:
             print(f"Downloading {remote_filename} to {TARGET_DIR}...")
             try:
-                # Key change: point local_dir directly at checkpoints
-                # hf_hub_download handles any subdirectories in remote_filename automatically
                 hf_hub_download(
                     repo_id=REPO_ID,
                     filename=remote_filename,
                     repo_type=REPO_TYPE,
-                    local_dir=TARGET_DIR,  # all files are rooted under checkpoints
+                    local_dir=TARGET_DIR,
                     local_dir_use_symlinks=False
                 )
             except Exception as e:
                 print(f"Failed to download {remote_filename}: {e}")
-                # A failed wav2vec2 download can be ignored; there is a fallback later
                 pass
         else:
            print(f"File {local_file_path} already exists. Skipping download.")
 
-# Run the check before the configuration is initialized
 ensure_checkpoints()
 
 class AppConfig:
     def __init__(self):
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        # Key: in the ZeroGPU environment the app must start on CPU; it must not grab GPU memory up front or the process gets killed
+        self.device = "cpu"
         self.seed = 42
         self.fix_noise_seed = False
-
         self.renderer_path = "./checkpoints/renderer.ckpt"
         self.generator_path = "./checkpoints/generator.ckpt"
-
-        # The path here was changed to None to let DataProcessor decide between local and remote loading
         self.wav2vec_model_path = "./checkpoints/wav2vec2-base-960h"
-
         self.input_size = 256
         self.input_nc = 3
         self.fps = 25.0
         self.rank = "cuda"
-
         self.sampling_rate = 16000
         self.audio_marcing = 2
         self.wav2vec_sec = 2.0
         self.attention_window = 5
         self.only_last_features = True
         self.audio_dropout_prob = 0.1
-
         self.style_dim = 512
         self.dim_a = 512
         self.dim_h = 512
@@ -574,23 +99,19 @@ class AppConfig:
         self.dim_motion = 32
         self.dim_c = 32
         self.dim_w = 32
-
         self.fmt_depth = 8
         self.num_heads = 8
         self.mlp_ratio = 4.0
         self.no_learned_pe = False
         self.num_prev_frames = 10
         self.max_grad_norm = 1.0
-
         self.ode_atol = 1e-5
         self.ode_rtol = 1e-5
         self.nfe = 10
         self.torchdiffeq_ode_method = 'euler'
         self.a_cfg_scale = 3.0
-
         self.swin_res_threshold = 128
         self.window_size = 8
-
         self.ref_path = None
         self.pose_path = None
         self.gaze_path = None
@@ -604,44 +125,28 @@ class DataProcessor:
         self.opt = opt
         self.fps = opt.fps
         self.sampling_rate = opt.sampling_rate
-
-        print(f"Loading Face Alignment to {opt.device}...")
-        # Fault tolerance note: if device is cuda but VRAM runs out we might need a fallback; kept as-is here
-        self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device=opt.device, flip_input=False)
-
-        # Reworked wav2vec loading logic
+        print(f"Loading Face Alignment (CPU first)...")
+        # Force FaceAlignment onto CPU so initialization does not occupy the GPU
+        self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device='cpu', flip_input=False)
         print("Loading Wav2Vec2...")
         local_path = opt.wav2vec_model_path
         if os.path.exists(local_path) and os.path.exists(os.path.join(local_path, "config.json")):
             print(f"Loading local wav2vec from {local_path}")
-            self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
-                local_path, local_files_only=True
-            )
+            self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(local_path, local_files_only=True)
         else:
             print("Local wav2vec model not found, downloading from 'facebook/wav2vec2-base-960h'...")
             self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
-
-        self.transform = transforms.Compose([
-            transforms.Resize((256, 256)),
-            transforms.ToTensor(),
-        ])
+        self.transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
 
     def process_img(self, img: Image.Image) -> Image.Image:
         img_arr = np.array(img)
-        # Handle grayscale images and alpha channels
         if img_arr.ndim == 2:
             img_arr = cv2.cvtColor(img_arr, cv2.COLOR_GRAY2RGB)
         elif img_arr.shape[2] == 4:
             img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2RGB)
-
         h, w = img_arr.shape[:2]
         mult = 360.0 / h
-        resized_img = cv2.resize(
-            img_arr, dsize=(0, 0), fx=mult, fy=mult,
-            interpolation=cv2.INTER_AREA if mult < 1 else cv2.INTER_CUBIC
-        )
-
-        # Try to detect a face
+        resized_img = cv2.resize(img_arr, dsize=(0, 0), fx=mult, fy=mult, interpolation=cv2.INTER_AREA if mult < 1 else cv2.INTER_CUBIC)
         try:
             bboxes = self.fa.face_detector.detect_from_image(resized_img)
             if bboxes is None or len(bboxes) == 0:
@@ -649,15 +154,9 @@ class DataProcessor:
         except Exception as e:
             print(f"Face detection failed: {e}")
             bboxes = None
-
         valid_bboxes = []
         if bboxes is not None:
-            valid_bboxes = [
-                (int(x1 / mult), int(y1 / mult), int(x2 / mult), int(y2 / mult), score)
-                for (x1, y1, x2, y2, score) in bboxes if score > 0.5
-            ]
-
-        # If no face was detected, use a center crop
+            valid_bboxes = [(int(x1 / mult), int(y1 / mult), int(x2 / mult), int(y2 / mult), score) for (x1, y1, x2, y2, score) in bboxes if score > 0.5]
         if not valid_bboxes:
             print("Warning: No face detected. Using center crop.")
             cx, cy = w // 2, h // 2
@@ -675,92 +174,51 @@ class DataProcessor:
             y1_new = cy - half_side
             x2_new = cx + half_side
             y2_new = cy + half_side
-
-        # Boundary handling
-        if x1_new < 0:
-            x2_new += (0 - x1_new)
-            x1_new = 0
-        if y1_new < 0:
-            y2_new += (0 - y1_new)
-            y1_new = 0
-        if x2_new > w:
-            x1_new -= (x2_new - w)
-            x2_new = w
-        if y2_new > h:
-            y1_new -= (y2_new - h)
-            y2_new = h
-
-        x1_new = max(0, x1_new)
-        y1_new = max(0, y1_new)
-        x2_new = min(w, x2_new)
-        y2_new = min(h, y2_new)
-
-        # Keep the crop square
-        curr_w = x2_new - x1_new
-        curr_h = y2_new - y1_new
+        if x1_new < 0: x2_new += (0 - x1_new); x1_new = 0
+        if y1_new < 0: y2_new += (0 - y1_new); y1_new = 0
+        if x2_new > w: x1_new -= (x2_new - w); x2_new = w
+        if y2_new > h: y1_new -= (y2_new - h); y2_new = h
+        x1_new = max(0, x1_new); y1_new = max(0, y1_new); x2_new = min(w, x2_new); y2_new = min(h, y2_new)
+        curr_w = x2_new - x1_new; curr_h = y2_new - y1_new
         min_side = min(curr_w, curr_h)
-        x2_new = x1_new + min_side
-        y2_new = y1_new + min_side
-
+        x2_new = x1_new + min_side; y2_new = y1_new + min_side
         crop_img = img_arr[int(y1_new):int(y2_new), int(x1_new):int(x2_new)]
         crop_pil = Image.fromarray(crop_img)
         return crop_pil.resize((self.opt.input_size, self.opt.input_size))
 
     def process_audio(self, path: str) -> torch.Tensor:
         speech_array, sampling_rate = librosa.load(path, sr=self.sampling_rate)
-        return self.wav2vec_preprocessor(
-            speech_array,
-            sampling_rate=sampling_rate,
-            return_tensors='pt'
-        ).input_values[0]
+        return self.wav2vec_preprocessor(speech_array, sampling_rate=sampling_rate, return_tensors='pt').input_values[0]
 
     def crop_video_stable(self, from_mp4_file_path, to_mp4_file_path, expanded_ratio=0.6, skip_per_frame=1):
-        if os.path.exists(to_mp4_file_path):
-            os.remove(to_mp4_file_path)
-
+        if os.path.exists(to_mp4_file_path): os.remove(to_mp4_file_path)
         video = cv2.VideoCapture(from_mp4_file_path)
         index = 0
         bboxes_lists = []
         width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
         height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
         print(f"Analyzing video for stable cropping: {from_mp4_file_path}")
-
         while video.isOpened():
             success = video.grab()
-            if not success:
-                break
+            if not success: break
             if index % skip_per_frame == 0:
                 success, frame = video.retrieve()
-                if not success:
-                    break
+                if not success: break
                 h, w = frame.shape[:2]
                 mult = 360.0 / h
-                resized_frame = cv2.resize(
-                    frame, dsize=(0, 0), fx=mult, fy=mult,
-                    interpolation=cv2.INTER_AREA if mult < 1 else cv2.INTER_CUBIC
-                )
-                try:
-                    detected_bboxes = self.fa.face_detector.detect_from_image(resized_frame)
-                except:
-                    detected_bboxes = None
-
+                resized_frame = cv2.resize(frame, dsize=(0, 0), fx=mult, fy=mult, interpolation=cv2.INTER_AREA if mult < 1 else cv2.INTER_CUBIC)
+                try: detected_bboxes = self.fa.face_detector.detect_from_image(resized_frame)
+                except: detected_bboxes = None
                 current_frame_bboxes = []
                 if detected_bboxes is not None:
                     for d_box in detected_bboxes:
                         bx1, by1, bx2, by2, score = d_box
-                        if score > 0.5:
-                            current_frame_bboxes.append([
-                                int(bx1 / mult), int(by1 / mult),
-                                int(bx2 / mult), int(by2 / mult),
-                                score
-                            ])
+                        if score > 0.5: current_frame_bboxes.append([int(bx1 / mult), int(by1 / mult), int(bx2 / mult), int(by2 / mult), score])
                 if len(current_frame_bboxes) > 0:
                     max_bboxes = max(current_frame_bboxes, key=lambda bbox: bbox[2] - bbox[0])
                     bboxes_lists.append(max_bboxes)
             index += 1
         video.release()
-
         x_center_lists, y_center_lists, width_lists, height_lists = [], [], [], []
         for bbox in bboxes_lists:
             x1, y1, x2, y2 = bbox[:4]
@@ -769,37 +227,23 @@ class DataProcessor:
             y_center_lists.append(y_center)
             width_lists.append(x2 - x1)
             height_lists.append(y2 - y1)
-
         if not (x_center_lists and y_center_lists and width_lists and height_lists):
             import shutil
             shutil.copy(from_mp4_file_path, to_mp4_file_path)
             return
-
         x_center = sorted(x_center_lists)[len(x_center_lists) // 2]
         y_center = sorted(y_center_lists)[len(y_center_lists) // 2]
         median_width = sorted(width_lists)[len(width_lists) // 2]
         median_height = sorted(height_lists)[len(height_lists) // 2]
-
         expanded_width = int(median_width * (1 + expanded_ratio))
         expanded_height = int(median_height * (1 + expanded_ratio))
         fixed_cropped_width = min(max(expanded_width, expanded_height), width, height)
-
         x1, y1 = int(x_center - fixed_cropped_width / 2), int(y_center - fixed_cropped_width / 2)
-        x1 = max(0, x1)
-        y1 = max(0, y1)
+        x1 = max(0, x1); y1 = max(0, y1)
         if x1 + fixed_cropped_width > width: x1 = width - fixed_cropped_width
         if y1 + fixed_cropped_width > height: y1 = height - fixed_cropped_width
-
         target_size = self.opt.input_size
-
-        cmd = (
-            f'ffmpeg -i "{from_mp4_file_path}" '
-            f'-filter:v "crop={fixed_cropped_width}:{fixed_cropped_width}:{x1}:{y1},'
-            f'scale={target_size}:{target_size}:flags=lanczos" '
-            f'-c:v libx264 -crf 18 -preset slow '
-            f'-c:a aac -b:a 128k "{to_mp4_file_path}" -y -loglevel error'
-        )
-
+        cmd = (f'ffmpeg -i "{from_mp4_file_path}" -filter:v "crop={fixed_cropped_width}:{fixed_cropped_width}:{x1}:{y1},scale={target_size}:{target_size}:flags=lanczos" -c:v libx264 -crf 18 -preset slow -c:a aac -b:a 128k "{to_mp4_file_path}" -y -loglevel error')
         if os.system(cmd) != 0:
             import shutil
             shutil.copy(from_mp4_file_path, to_mp4_file_path)
@@ -808,23 +252,26 @@ class InferenceAgent:
     def __init__(self, opt):
         torch.cuda.empty_cache()
         self.opt = opt
-        self.device = opt.device
+        self.device = opt.device  # Defaults to cpu, to avoid crashing at startup
         self.data_processor = DataProcessor(opt)
-
         print("Loading Models...")
         self.renderer = IMTRenderer(self.opt).to(self.device)
         self.generator = FMGenerator(self.opt).to(self.device)
-
-        # Add a path check to prevent crashes
         if not os.path.exists(self.opt.renderer_path) or not os.path.exists(self.opt.generator_path):
             raise FileNotFoundError("Checkpoints not found even after download attempt.")
-
         self._load_ckpt(self.renderer, self.opt.renderer_path, "gen.")
         self._load_fm_ckpt(self.generator, self.opt.generator_path)
-
         self.renderer.eval()
         self.generator.eval()
 
+    # Key: ZeroGPU requires moving the models to CUDA dynamically inside the handler
+    def to(self, device):
+        if self.device != device:
+            print(f"Moving models to {device}...")
+            self.device = device
+            self.renderer = self.renderer.to(device)
+            self.generator = self.generator.to(device)
+
     def _load_ckpt(self, model, path, prefix="gen."):
         if not os.path.exists(path):
             print(f"Warning: Checkpoint {path} not found.")
@@ -855,22 +302,18 @@ class InferenceAgent:
             vid = (vid + 1) / 2
         vid = np.clip(vid, 0, 1)
         vid = (vid * 255).astype(np.uint8)
-
         height, width = vid.shape[1], vid.shape[2]
         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
         writer = cv2.VideoWriter(raw_path, fourcc, fps, (width, height))
         for frame in vid:
             writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
         writer.release()
-
         if audio_path:
             with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_out:
                 final_path = tmp_out.name
-            # Mux the audio with ffmpeg; -shortest guards against mismatched lengths
             cmd = f"ffmpeg -y -i {raw_path} -i {audio_path} -c:v copy -c:a aac -shortest {final_path}"
             subprocess.call(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-            if os.path.exists(raw_path):
-                os.remove(raw_path)
+            if os.path.exists(raw_path): os.remove(raw_path)
             return final_path
         else:
             return raw_path
@@ -907,15 +350,12 @@ class InferenceAgent:
         t_r = self.renderer.mot_encode(s_tensor)
         ta_r = self.renderer.adapt(t_r, i_r)
         ma_r = self.renderer.mot_decode(ta_r)
-
         final_driving_path = driving_video_path
         temp_crop_video = None
         if crop:
-            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
-                temp_crop_video = tmp.name
+            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp: temp_crop_video = tmp.name
             self.data_processor.crop_video_stable(driving_video_path, temp_crop_video)
             final_driving_path = temp_crop_video
-
         cap = cv2.VideoCapture(final_driving_path)
         fps = cap.get(cv2.CAP_PROP_FPS)
         vid_results = []
@@ -931,10 +371,8 @@ class InferenceAgent:
             out = self.renderer.decode(ma_c, ma_r, f_r)
             vid_results.append(out.cpu())
         cap.release()
-        if temp_crop_video and os.path.exists(temp_crop_video):
-            os.remove(temp_crop_video)
-        if not vid_results:
-            raise Exception("Driving video reading failed.")
+        if temp_crop_video and os.path.exists(temp_crop_video): os.remove(temp_crop_video)
+        if not vid_results: raise Exception("Driving video reading failed.")
         vid_tensor = torch.cat(vid_results, dim=0)
         return self.save_video(vid_tensor, fps=fps, audio_path=None)
@@ -943,7 +381,6 @@ cfg = AppConfig()
 agent = None
 
 try:
-    # Check again that the files exist; if they do not, the agent is not instantiated
     if os.path.exists(cfg.renderer_path) and os.path.exists(cfg.generator_path):
         agent = InferenceAgent(cfg)
     else:
@@ -953,18 +390,32 @@ except Exception as e:
     import traceback
     traceback.print_exc()
 
+# The @spaces.GPU decorator must be added!
+@spaces.GPU
 def fn_audio_driven(image, audio, crop, seed, nfe, cfg_scale, progress=gr.Progress()):
     if agent is None: raise gr.Error("Models not loaded properly. Check logs.")
     if image is None or audio is None: raise gr.Error("Missing image or audio.")
+
+    # Move the models to the GPU on demand
+    if torch.cuda.is_available():
+        agent.to("cuda")
+
     img_pil = Image.fromarray(image).convert('RGB')
     try:
         return agent.run_audio_inference(img_pil, audio, crop, int(seed), int(nfe), float(cfg_scale))
     except Exception as e:
         raise gr.Error(f"Error: {e}")
 
+# The @spaces.GPU decorator must be added!
+@spaces.GPU
 def fn_video_driven(source_image, driving_video, crop, progress=gr.Progress()):
     if agent is None: raise gr.Error("Models not loaded properly. Check logs.")
     if source_image is None or driving_video is None: raise gr.Error("Missing inputs.")
+
+    # Move the models to the GPU on demand
+    if torch.cuda.is_available():
+        agent.to("cuda")
+
     img_pil = Image.fromarray(source_image).convert('RGB')
     try:
         return agent.run_video_inference(img_pil, driving_video, crop)
@@ -973,10 +424,9 @@ def fn_video_driven(source_image, driving_video, crop, progress=gr.Progress()):
         traceback.print_exc()
         raise gr.Error(f"Error: {e}")
 
-# Removed 'css' argument to prevent TypeError in certain Gradio versions
+# Gradio 4.x syntax: removed css, use sources=["upload"]
 with gr.Blocks(title="IMTalker Demo") as demo:
     gr.Markdown("# 🗣️ IMTalker: Efficient Audio-driven Talking Face Generation")
-
     with gr.Tabs():
         with gr.TabItem("Audio Driven"):
             with gr.Row():
@@ -997,7 +447,8 @@ with gr.Blocks(title="IMTalker Demo") as demo:
             with gr.Row():
                 with gr.Column():
                     v_img = gr.Image(label="Source Image", type="numpy")
-                    v_vid = gr.Video(label="Driving Video", source="upload")
+                    # Gradio 4.x syntax
+                    v_vid = gr.Video(label="Driving Video", sources=["upload"])
                     v_crop = gr.Checkbox(label="Auto Crop (Both Source & Driving)", value=True)
                     v_btn = gr.Button("Generate (Video Driven)", variant="primary")
                 with gr.Column():
@@ -1005,5 +456,4 @@ with gr.Blocks(title="IMTalker Demo") as demo:
             v_btn.click(fn_video_driven, [v_img, v_vid, v_crop], v_out)
 
 if __name__ == "__main__":
->>>>>>> 1a44ff1967f2f89ddf2b5accfcc8a1d4119aa529
     demo.queue().launch()