import os
import sys
import shutil
import tempfile
import subprocess
import random

import numpy as np
import cv2
import torch
import torchvision
import librosa
import face_alignment
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
from transformers import Wav2Vec2FeatureExtractor
from tqdm import tqdm
from huggingface_hub import hf_hub_download

# spaces enables ZeroGPU support
import spaces

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Try to import the local modules
try:
    from generator.FM import FMGenerator
    from renderer.models import IMTRenderer
except ImportError as e:
    print(f"Import Error: {e}")
    print("Please ensure 'generator' and 'renderer' folders are in the same directory.")
    sys.exit(1)


# ==========================================
# Automatically download model checkpoints
# ==========================================
def ensure_checkpoints():
    print("Checking model checkpoints...")
    REPO_ID = "cbsjtu01/IMTalker"
    REPO_TYPE = "model"
    files_to_download = [
        "renderer.ckpt",
        "generator.ckpt",
        "wav2vec2-base-960h/config.json",
        "wav2vec2-base-960h/pytorch_model.bin",
        "wav2vec2-base-960h/preprocessor_config.json",
        "wav2vec2-base-960h/feature_extractor_config.json",
    ]
    TARGET_DIR = "checkpoints"
    os.makedirs(TARGET_DIR, exist_ok=True)
    for remote_filename in files_to_download:
        local_file_path = os.path.join(TARGET_DIR, remote_filename)
        # Re-download if the file is missing or suspiciously small (< 1 KB)
        if not os.path.exists(local_file_path) or os.path.getsize(local_file_path) < 1024:
            print(f"Downloading {remote_filename} to {TARGET_DIR}...")
            try:
                hf_hub_download(
                    repo_id=REPO_ID,
                    filename=remote_filename,
                    repo_type=REPO_TYPE,
                    local_dir=TARGET_DIR,
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                print(f"Failed to download {remote_filename}: {e}")
        else:
            print(f"File {local_file_path} already exists. Skipping download.")


ensure_checkpoints()


class AppConfig:
    def __init__(self):
        # Important: in a ZeroGPU environment the app must start on the CPU.
        # Claiming GPU memory at startup gets the process killed.
        self.device = "cpu"
        self.seed = 42
        self.fix_noise_seed = False
        self.renderer_path = "./checkpoints/renderer.ckpt"
        self.generator_path = "./checkpoints/generator.ckpt"
        self.wav2vec_model_path = "./checkpoints/wav2vec2-base-960h"
        self.input_size = 512
        self.input_nc = 3
        self.fps = 25.0
        self.rank = "cuda"
        self.sampling_rate = 16000
        self.audio_marcing = 2
        self.wav2vec_sec = 2.0
        self.attention_window = 5
        self.only_last_features = True
        self.audio_dropout_prob = 0.1
        self.style_dim = 512
        self.dim_a = 512
        self.dim_h = 512
        self.dim_e = 7
        self.dim_motion = 32
        self.dim_c = 32
        self.dim_w = 32
        self.fmt_depth = 8
        self.num_heads = 8
        self.mlp_ratio = 4.0
        self.no_learned_pe = False
        self.num_prev_frames = 10
        self.max_grad_norm = 1.0
        self.ode_atol = 1e-5
        self.ode_rtol = 1e-5
        self.nfe = 10
        self.torchdiffeq_ode_method = 'euler'
        self.a_cfg_scale = 3.0
        self.swin_res_threshold = 128
        self.window_size = 8
        self.ref_path = None
        self.pose_path = None
        self.gaze_path = None
        self.aud_path = None
        self.crop = True
        self.source_path = None
        self.driving_path = None


class DataProcessor:
    def __init__(self, opt):
        self.opt = opt
        self.fps = opt.fps
        self.sampling_rate = opt.sampling_rate
        print("Loading Face Alignment (CPU first)...")
        # Load FaceAlignment on the CPU so that initialization does not touch the GPU.
        self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D,
                                               device='cpu', flip_input=False)
        print("Loading Wav2Vec2...")
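        # Prefer the wav2vec2 checkpoint downloaded into ./checkpoints; fall back
        # to fetching 'facebook/wav2vec2-base-960h' from the Hub if it is missing.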
        local_path = opt.wav2vec_model_path
        if os.path.exists(local_path) and os.path.exists(os.path.join(local_path, "config.json")):
            print(f"Loading local wav2vec from {local_path}")
            self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
                local_path, local_files_only=True)
        else:
            print("Local wav2vec model not found, downloading from 'facebook/wav2vec2-base-960h'...")
            self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
                "facebook/wav2vec2-base-960h")
        self.transform = transforms.Compose([transforms.Resize((512, 512)),
                                             transforms.ToTensor()])

    def process_img(self, img: Image.Image) -> Image.Image:
        img_arr = np.array(img)
        # Handle grayscale and RGBA inputs
        if img_arr.ndim == 2:
            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_GRAY2RGB)
        elif img_arr.shape[2] == 4:
            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2RGB)
        h, w = img_arr.shape[:2]
        try:
            # Detection on a large image can be slow, so log it.
            print("Detecting face on original high-res image...")
            bboxes = self.fa.face_detector.detect_from_image(img_arr)
            if bboxes is None or len(bboxes) == 0:
                # Fall through to the center-crop fallback below.
                print("Face detection failed on original image.")
                bboxes = None
        except Exception as e:
            print(f"Face detection failed: {e}")
            bboxes = None
        valid_bboxes = []
        if bboxes is not None:
            valid_bboxes = [(int(x1), int(y1), int(x2), int(y2), score)
                            for (x1, y1, x2, y2, score) in bboxes if score > 0.5]
        if not valid_bboxes:
            print("Warning: No face detected. Using center crop.")
            cx, cy = w // 2, h // 2
            # Default to a square crop whose side is half the shorter image edge.
            half = min(w, h) // 4
            x1_new, x2_new = cx - half, cx + half
            y1_new, y2_new = cy - half, cy + half
        else:
            # Use the highest-confidence face.
            x1, y1, x2, y2, _ = max(valid_bboxes, key=lambda x: x[4])
            cx = (x1 + x2) // 2
            cy = (y1 + y2) // 2
            w_face = x2 - x1
            h_face = y2 - y1
            # Half side length of the expanded square; adjust the 0.8 expansion
            # factor to control the size of the cropped region.
            half_side = int(max(w_face, h_face) * 0.8)
            x1_new = cx - half_side
            y1_new = cy - half_side
            x2_new = cx + half_side
            y2_new = cy + half_side
        # Clamp to the image bounds while keeping the crop square.
        if x1_new < 0:
            x2_new += (0 - x1_new); x1_new = 0
        if y1_new < 0:
            y2_new += (0 - y1_new); y1_new = 0
        if x2_new > w:
            x1_new -= (x2_new - w); x2_new = w
        if y2_new > h:
            y1_new -= (y2_new - h); y2_new = h
        # Make sure the coordinates are inside the image once more.
        x1_new = max(0, x1_new); y1_new = max(0, y1_new)
        x2_new = min(w, x2_new); y2_new = min(h, y2_new)
        # Re-square the region in case the boundary clamping broke it. Strict
        # center alignment would take more work; for simplicity, adjust from the
        # top-left corner.
        curr_w = x2_new - x1_new
        curr_h = y2_new - y1_new
        min_side = min(curr_w, curr_h)
        x2_new = x1_new + min_side
        y2_new = y1_new + min_side
        # Crop from the original image.
        crop_img = img_arr[int(y1_new):int(y2_new), int(x1_new):int(x2_new)]
        crop_pil = Image.fromarray(crop_img)
        resized_pil = crop_pil.resize((self.opt.input_size, self.opt.input_size), Image.LANCZOS)
        return resized_pil

    def process_audio(self, path: str) -> torch.Tensor:
        speech_array, sampling_rate = librosa.load(path, sr=self.sampling_rate)
        return self.wav2vec_preprocessor(speech_array, sampling_rate=sampling_rate,
                                         return_tensors='pt').input_values[0]

    def crop_video_stable(self, from_mp4_file_path, to_mp4_file_path,
                          expanded_ratio=0.6, skip_per_frame=1):
        if os.path.exists(to_mp4_file_path):
            os.remove(to_mp4_file_path)
        video = cv2.VideoCapture(from_mp4_file_path)
        index = 0
        bboxes_lists = []
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Analyzing video for stable cropping: {from_mp4_file_path}")
        while video.isOpened():
            success = video.grab()
            if not success:
                break
            if index % skip_per_frame == 0:
                success, frame = video.retrieve()
                if not success:
                    break
                h, w = frame.shape[:2]
                mult = 360.0 / h
                resized_frame = cv2.resize(frame, dsize=(0, 0), fx=mult, fy=mult,
                                           interpolation=cv2.INTER_AREA if mult < 1 else cv2.INTER_CUBIC)
                try:
                    detected_bboxes = self.fa.face_detector.detect_from_image(resized_frame)
                except Exception:
                    detected_bboxes = None
                current_frame_bboxes = []
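                # Detection ran on a frame resized to 360 px in height for speed;
                # the boxes below are divided by `mult` to map them back to the
                # original resolution.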
                if detected_bboxes is not None:
                    for d_box in detected_bboxes:
                        bx1, by1, bx2, by2, score = d_box
                        if score > 0.5:
                            current_frame_bboxes.append([int(bx1 / mult), int(by1 / mult),
                                                         int(bx2 / mult), int(by2 / mult), score])
                if len(current_frame_bboxes) > 0:
                    # Keep the widest detected face in this frame.
                    max_bboxes = max(current_frame_bboxes, key=lambda bbox: bbox[2] - bbox[0])
                    bboxes_lists.append(max_bboxes)
            index += 1
        video.release()

        x_center_lists, y_center_lists, width_lists, height_lists = [], [], [], []
        for bbox in bboxes_lists:
            x1, y1, x2, y2 = bbox[:4]
            x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2
            x_center_lists.append(x_center)
            y_center_lists.append(y_center)
            width_lists.append(x2 - x1)
            height_lists.append(y2 - y1)
        if not (x_center_lists and y_center_lists and width_lists and height_lists):
            # No face was detected in any frame: pass the video through unchanged.
            shutil.copy(from_mp4_file_path, to_mp4_file_path)
            return

        # Use medians so a few bad detections cannot jolt the crop window.
        x_center = sorted(x_center_lists)[len(x_center_lists) // 2]
        y_center = sorted(y_center_lists)[len(y_center_lists) // 2]
        median_width = sorted(width_lists)[len(width_lists) // 2]
        median_height = sorted(height_lists)[len(height_lists) // 2]
        expanded_width = int(median_width * (1 + expanded_ratio))
        expanded_height = int(median_height * (1 + expanded_ratio))
        fixed_cropped_width = min(max(expanded_width, expanded_height), width, height)
        x1, y1 = int(x_center - fixed_cropped_width / 2), int(y_center - fixed_cropped_width / 2)
        x1 = max(0, x1); y1 = max(0, y1)
        if x1 + fixed_cropped_width > width:
            x1 = width - fixed_cropped_width
        if y1 + fixed_cropped_width > height:
            y1 = height - fixed_cropped_width
        target_size = self.opt.input_size
        cmd = (f'ffmpeg -i "{from_mp4_file_path}" '
               f'-filter:v "crop={fixed_cropped_width}:{fixed_cropped_width}:{x1}:{y1},'
               f'scale={target_size}:{target_size}:flags=lanczos" '
               f'-c:v libx264 -crf 18 -preset slow -c:a aac -b:a 128k '
               f'"{to_mp4_file_path}" -y -loglevel error')
        if os.system(cmd) != 0:
            # Fall back to the uncropped video if ffmpeg fails.
            shutil.copy(from_mp4_file_path, to_mp4_file_path)


class InferenceAgent:
    def __init__(self, opt):
        torch.cuda.empty_cache()
        self.opt = opt
        self.device = opt.device
        self.data_processor = DataProcessor(opt)
        print("Loading Models...")
        self.renderer = IMTRenderer(self.opt).to(self.device)
        self.generator = FMGenerator(self.opt).to(self.device)
        if not os.path.exists(self.opt.renderer_path) or not os.path.exists(self.opt.generator_path):
            raise FileNotFoundError("Checkpoints not found even after download attempt.")
        self._load_ckpt(self.renderer, self.opt.renderer_path, "gen.")
        self._load_fm_ckpt(self.generator, self.opt.generator_path)
        self.renderer.eval()
        self.generator.eval()

    # Important: ZeroGPU requires moving the models to CUDA inside the request handlers.
    def to(self, device):
        if self.device != device:
            print(f"Moving models to {device}...")
            self.device = device
            self.renderer = self.renderer.to(device)
            self.generator = self.generator.to(device)

    def _load_ckpt(self, model, path, prefix="gen."):
        if not os.path.exists(path):
            print(f"Warning: Checkpoint {path} not found.")
            return
        checkpoint = torch.load(path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        clean_state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items()
                            if k.startswith(prefix)}
        model.load_state_dict(clean_state_dict, strict=False)
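
    # The FM generator checkpoint nests its weights under a 'model.' prefix;
    # parameters are copied one by one under no_grad so that extra or missing
    # keys cannot abort loading.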
    def _load_fm_ckpt(self, model, path):
        if not os.path.exists(path):
            return
        checkpoint = torch.load(path, map_location='cpu')
        state_dict = checkpoint.get('state_dict', checkpoint)
        if 'model' in state_dict:
            state_dict = state_dict['model']
        prefix = 'model.'
        clean_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in clean_dict:
                    param.copy_(clean_dict[name].to(self.device))

    def save_video(self, vid_tensor, fps, audio_path=None):
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
            raw_path = tmp.name
        if vid_tensor.dim() == 4:
            vid = vid_tensor.permute(0, 2, 3, 1).detach().cpu().numpy()
        # Map [-1, 1] outputs to [0, 1] before converting to uint8.
        if vid.min() < 0:
            vid = (vid + 1) / 2
        vid = np.clip(vid, 0, 1)
        vid = (vid * 255).astype(np.uint8)
        height, width = vid.shape[1], vid.shape[2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(raw_path, fourcc, fps, (width, height))
        for frame in vid:
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        writer.release()
        if audio_path:
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_out:
                final_path = tmp_out.name
            # Mux the audio track back in without re-encoding the video.
            cmd = f'ffmpeg -y -i "{raw_path}" -i "{audio_path}" -c:v copy -c:a aac -shortest "{final_path}"'
            subprocess.call(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            if os.path.exists(raw_path):
                os.remove(raw_path)
            return final_path
        else:
            return raw_path

    @torch.no_grad()
    def run_audio_inference(self, img_pil, aud_path, crop, seed, nfe, cfg_scale):
        s_pil = (self.data_processor.process_img(img_pil) if crop
                 else img_pil.resize((self.opt.input_size, self.opt.input_size)))
        s_tensor = self.data_processor.transform(s_pil).unsqueeze(0).to(self.device)
        a_tensor = self.data_processor.process_audio(aud_path).unsqueeze(0).to(self.device)
        data = {'s': s_tensor, 'a': a_tensor, 'pose': None, 'cam': None, 'gaze': None, 'ref_x': None}
        f_r, g_r = self.renderer.dense_feature_encoder(s_tensor)
        t_lat = self.renderer.latent_token_encoder(s_tensor)
        if isinstance(t_lat, tuple):
            t_lat = t_lat[0]
        data['ref_x'] = t_lat
        torch.manual_seed(seed)
        # Sample the motion latent sequence; nfe trades speed for fidelity,
        # a_cfg_scale controls audio guidance strength.
        sample = self.generator.sample(data, a_cfg_scale=cfg_scale, nfe=nfe, seed=seed)
        d_hat = []
        T = sample.shape[1]
        ta_r = self.renderer.adapt(t_lat, g_r)
        m_r = self.renderer.latent_token_decoder(ta_r)
        for t in range(T):
            ta_c = self.renderer.adapt(sample[:, t, ...], g_r)
            m_c = self.renderer.latent_token_decoder(ta_c)
            out_frame = self.renderer.decode(m_c, m_r, f_r)
            d_hat.append(out_frame)
        vid_tensor = torch.stack(d_hat, dim=1).squeeze(0)
        return self.save_video(vid_tensor, self.opt.fps, aud_path)

    @torch.no_grad()
    def run_video_inference(self, source_img_pil, driving_video_path, crop):
        s_pil = (self.data_processor.process_img(source_img_pil) if crop
                 else source_img_pil.resize((self.opt.input_size, self.opt.input_size)))
        s_tensor = self.data_processor.transform(s_pil).unsqueeze(0).to(self.device)
        f_r, i_r = self.renderer.app_encode(s_tensor)
        t_r = self.renderer.mot_encode(s_tensor)
        ta_r = self.renderer.adapt(t_r, i_r)
        ma_r = self.renderer.mot_decode(ta_r)
        final_driving_path = driving_video_path
        temp_crop_video = None
        if crop:
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
                temp_crop_video = tmp.name
            self.data_processor.crop_video_stable(driving_video_path, temp_crop_video)
            final_driving_path = temp_crop_video
        cap = cv2.VideoCapture(final_driving_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        vid_results = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_pil = Image.fromarray(frame).resize((self.opt.input_size, self.opt.input_size))
            d_tensor = self.data_processor.transform(frame_pil).unsqueeze(0).to(self.device)
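            # Motion transfer: encode the driving frame's motion, adapt it to the
            # source via i_r, then render against the source appearance features f_r.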
            t_c = self.renderer.mot_encode(d_tensor)
            ta_c = self.renderer.adapt(t_c, i_r)
            ma_c = self.renderer.mot_decode(ta_c)
            out = self.renderer.decode(ma_c, ma_r, f_r)
            vid_results.append(out.cpu())
        cap.release()
        if temp_crop_video and os.path.exists(temp_crop_video):
            os.remove(temp_crop_video)
        if not vid_results:
            raise Exception("Driving video reading failed.")
        vid_tensor = torch.cat(vid_results, dim=0)
        return self.save_video(vid_tensor, fps=fps, audio_path=driving_video_path)


print("Initializing Configuration...")
cfg = AppConfig()
agent = None
try:
    if os.path.exists(cfg.renderer_path) and os.path.exists(cfg.generator_path):
        agent = InferenceAgent(cfg)
    else:
        print("Error: Checkpoints not found. Please upload 'renderer.ckpt' and 'generator.ckpt' via the Files tab.")
except Exception as e:
    print(f"Initialization Error: {e}")
    import traceback
    traceback.print_exc()


# The @spaces.GPU decorator is required on ZeroGPU!
@spaces.GPU
def fn_audio_driven(image, audio, crop, seed, nfe, cfg_scale, progress=gr.Progress()):
    if agent is None:
        raise gr.Error("Models not loaded properly. Check logs.")
    if image is None or audio is None:
        raise gr.Error("Missing image or audio.")
    # Move the models to the GPU on demand.
    if torch.cuda.is_available():
        agent.to("cuda")
    img_pil = Image.fromarray(image).convert('RGB')
    try:
        return agent.run_audio_inference(img_pil, audio, crop, int(seed), int(nfe), float(cfg_scale))
    except Exception as e:
        raise gr.Error(f"Error: {e}")


# The @spaces.GPU decorator is required on ZeroGPU!
@spaces.GPU
def fn_video_driven(source_image, driving_video, crop, progress=gr.Progress()):
    if agent is None:
        raise gr.Error("Models not loaded properly. Check logs.")
    if source_image is None or driving_video is None:
        raise gr.Error("Missing inputs.")
    # Move the models to the GPU on demand.
    if torch.cuda.is_available():
        agent.to("cuda")
    img_pil = Image.fromarray(source_image).convert('RGB')
    try:
        return agent.run_video_inference(img_pil, driving_video, crop)
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Error: {e}")


with gr.Blocks(title="IMTalker Demo") as demo:
    gr.Markdown("# 🗣️ IMTalker: Efficient Audio-driven Talking Face Generation")

    # Best-practice notes
    with gr.Accordion("💡 Best Practices (Click to read)", open=False):
        gr.Markdown("""
        To obtain the highest quality generation results, we recommend following these guidelines:

        1. **Input Image Composition**: Please ensure the input image features the person's head as the primary subject. Since our model is explicitly trained on facial data, it does not support full-body video generation.
            * The inference pipeline automatically **crops the input image** to focus on the face by default.
            * **Note on Resolution**: The model generates video at a fixed resolution of **512×512**. Using extremely high-resolution inputs will result in downscaling, so prioritize facial clarity over raw image dimensions.
        2. **Audio Selection**: Our model was trained primarily on **English datasets**. Consequently, we recommend using **English audio** inputs to achieve the best lip-synchronization performance and naturalness.
        3. **Background Quality**: We strongly recommend using source images with **solid colored** or **blurred (bokeh)** backgrounds. Complex or highly detailed backgrounds may lead to visual artifacts or jitter in the generated video.
        """)
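    # Two workflows: audio-driven (FM generator + renderer) and video-driven
    # (renderer-only face reenactment).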
    with gr.Tabs():
        # ==========================
        # Tab 1: Audio Driven
        # ==========================
        with gr.TabItem("Audio Driven"):
            with gr.Row():
                with gr.Column():
                    # 1. Image input
                    a_img = gr.Image(label="Source Image", type="numpy", height=512, width=512)
                    # --- Image examples (standalone) ---
                    # Make sure the matching source_x.png files exist under examples/.
                    gr.Examples(
                        examples=[
                            ["examples/source_1.png"],
                            ["examples/source_2.png"],
                            ["examples/source_3.jpg"],
                            ["examples/source_4.png"],
                            ["examples/source_5.png"],
                            ["examples/source_6.png"],
                        ],
                        inputs=[a_img],
                        label="Example Images",
                        cache_examples=False,
                    )
                    # 2. Audio input
                    a_aud = gr.Audio(label="Driving Audio", type="filepath")
                    # --- Audio examples (standalone) ---
                    # Make sure the matching audio_x.wav files exist under examples/.
                    gr.Examples(
                        examples=[
                            ["examples/audio_1.wav"],
                            ["examples/audio_2.wav"],
                            ["examples/audio_3.wav"],
                            ["examples/audio_4.wav"],
                            ["examples/audio_5.wav"],
                        ],
                        inputs=[a_aud],
                        label="Example Audios",
                        cache_examples=False,
                    )
                    with gr.Accordion("Settings", open=True):
                        a_crop = gr.Checkbox(label="Auto Crop Face", value=False)
                        a_seed = gr.Number(label="Seed", value=42)
                        a_nfe = gr.Slider(5, 50, value=10, step=1, label="Steps (NFE)")
                        a_cfg = gr.Slider(1.0, 5.0, value=2.0, label="CFG Scale")
                    a_btn = gr.Button("Generate (Audio Driven)", variant="primary")
                with gr.Column():
                    a_out = gr.Video(label="Result", height=512, width=512)
            a_btn.click(fn_audio_driven, [a_img, a_aud, a_crop, a_seed, a_nfe, a_cfg], a_out)

        # ==========================
        # Tab 2: Video Driven
        # ==========================
        with gr.TabItem("Video Driven"):
            with gr.Row():
                with gr.Column():
                    # 1. Image input
                    v_img = gr.Image(label="Source Image", type="numpy", height=512, width=512)
                    # --- Image examples (standalone) ---
                    gr.Examples(
                        examples=[
                            ["examples/source_7.png"],
                            ["examples/source_8.png"],
                            ["examples/source_9.jpg"],
                            ["examples/source_10.png"],
                            ["examples/source_11.png"],
                        ],
                        inputs=[v_img],
                        label="Example Images",
                        cache_examples=False,
                    )
                    # 2. Video input
                    v_vid = gr.Video(label="Driving Video", sources=["upload"], height=512, width=512)
                    # --- Video examples (standalone) ---
                    # Make sure the matching driving_x.mp4 files exist under examples/.
                    gr.Examples(
                        examples=[
                            ["examples/driving_1.mp4"],
                            ["examples/driving_2.mp4"],
                            ["examples/driving_3.mp4"],
                            ["examples/driving_4.mp4"],
                            ["examples/driving_5.mp4"],
                        ],
                        inputs=[v_vid],
                        label="Example Videos",
                        cache_examples=False,
                    )
                    v_crop = gr.Checkbox(label="Auto Crop (Both Source & Driving)", value=True)
                    v_btn = gr.Button("Generate (Video Driven)", variant="primary")
                with gr.Column():
                    v_out = gr.Video(label="Result", height=512, width=512)
            v_btn.click(fn_video_driven, [v_img, v_vid, v_crop], v_out)

if __name__ == "__main__":
    demo.queue().launch()