import os
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from safetensors.torch import load_file
import random
import glob
import numpy as np
from torchvision.transforms.functional import resize
from torchvision.transforms import InterpolationMode
from torchvision.utils import save_image


def init_distributed(args):
    """Initialize the distributed training environment (NCCL backend).

    Optionally sets up xfuser sequence parallelism (USP: Ulysses x ring) and
    seeds each sequence-parallel group with a rank-0-broadcast base seed so
    that ranks in the same data-parallel group draw identical randomness.

    Args:
        args: namespace with ``use_usp``, ``ulysses_degree``, ``ring_degree``.

    Returns:
        tuple: (global_rank, world_size, local_rank, sp_size, dp_size)
    """
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Defaults when sequence parallelism is disabled: pure data parallelism.
    sp_size = 1
    dp_size = world_size
    if args.use_usp:
        sp_size = args.ulysses_degree * args.ring_degree
        dp_size = world_size // sp_size
        assert sp_size <= world_size, f"sequence parallel size ({sp_size}) must be less than or equal to world size ({world_size})."
        assert world_size % sp_size == 0, f"world size ({world_size}) must be divisible by sequence parallel size ({sp_size})."
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment, get_data_parallel_rank
        init_distributed_environment(rank=global_rank, world_size=world_size)
        initialize_model_parallel(
            data_parallel_degree=dp_size,
            sequence_parallel_degree=sp_size,
            ring_degree=args.ring_degree,
            ulysses_degree=args.ulysses_degree,
        )
        # NOTE(review): the per-group seeding below must live inside the
        # `use_usp` branch — `get_data_parallel_rank` is only imported here,
        # so calling it at function level would raise NameError when
        # `use_usp` is False. Confirm against the original intent.
        if dist.is_initialized():
            sp_group_id = get_data_parallel_rank()
            # Rank 0 draws the base seed; every other rank receives it via
            # broadcast so all processes agree on the same base.
            if global_rank == 0:
                base_seed = torch.randint(0, 1000000, (1,)).item()
            else:
                base_seed = 0
            base_seed = torch.tensor([base_seed], device="cuda")
            dist.broadcast(base_seed, src=0)
            base_seed = base_seed.item()
            # Offset by the data-parallel group id so different DP groups see
            # different randomness while ranks within a group stay in sync.
            seed = base_seed + sp_group_id
            torch.manual_seed(seed)
    return global_rank, world_size, local_rank, sp_size, dp_size


def custom_collate_fn(batch):
    """Collate samples whose videos (and aligned audio) differ in length.

    Truncates every sample to the shortest video in the batch so the batch
    is rectangular; all fields are returned as plain lists (no stacking).

    Args:
        batch: list of dicts; each must contain "video" shaped [C, F, H, W]
            (frames on dim 1) and may contain per-frame audio embeddings.

    Returns:
        dict: key -> list of per-sample tensors/values, frame-truncated.
    """
    keys = batch[0].keys()
    data = {}
    # Shortest frame count across the batch; everything is cut to this.
    batch_video_length = min(item["video"].shape[1] for item in batch)
    # Audio embeddings are frame-aligned (frames on dim 0), so they are
    # truncated with the same length as the video.
    frame_aligned_keys = ('audio_emb', 'audio_embed_speaker1', 'audio_embed_speaker2')
    for key in keys:
        if key == 'video':
            # Video frames live on dim 1: [C, F, H, W].
            data[key] = [item[key][:, :batch_video_length] for item in batch]
        elif key in frame_aligned_keys:
            data[key] = [item[key][:batch_video_length] for item in batch]
        else:
            data[key] = [item[key] for item in batch]
    return data


def get_distributed_sampler(dataset, args, global_rank, world_size, sp_size, dp_size):
    """Build a DistributedSampler over the data-parallel dimension.

    Returns None when there is a single data-parallel replica (dp_size == 1).
    With USP enabled the replica rank comes from xfuser's data-parallel rank;
    otherwise the global rank is used directly.

    Args:
        dataset: dataset to sample from.
        args: namespace with ``use_usp``.
        global_rank / world_size: torch.distributed ranks.
        sp_size / dp_size: sequence- and data-parallel degrees.

    Returns:
        DistributedSampler | None
    """
    if dp_size == 1:
        sampler = None
    else:
        if args.use_usp:
            from xfuser.core.distributed import get_data_parallel_rank
            dp_rank = get_data_parallel_rank()
        else:
            # Avoid touching xfuser's _DP state when USP is off.
            dp_rank = global_rank
        sampler = DistributedSampler(
            dataset,
            num_replicas=dp_size,
            rank=dp_rank,
            shuffle=False
        )
        if global_rank == 0:
            print(f"Using DistributedSampler: dp_size={dp_size}, dp_rank={dp_rank}")
    return sampler


def load_predefined_prompt_embeddings(base_dir="./multi_person/prompt_emb_concat", text_len=512):
    """Load the three groups of predefined prompt embeddings from disk.

    Each ``.safetensors`` file must contain a ``context`` tensor of shape
    [seq_len, dim]; it is zero-padded or truncated to ``text_len`` rows.
    Missing directories and unreadable files are skipped with a warning.

    Args:
        base_dir: root directory holding the three prompt subdirectories.
        text_len: target sequence length for padding/truncation.

    Returns:
        dict: {'talk_prompts': [...], 'silent_prompts_left': [...],
               'silent_prompts_right': [...]} — lists of [text_len, dim]
               tensors.
    """
    prompt_embeddings = {
        'talk_prompts': [],
        'silent_prompts_left': [],
        'silent_prompts_right': []
    }
    # Subdirectory name per embedding category (identical by convention).
    subdirs = {
        'talk_prompts': 'talk_prompts',
        'silent_prompts_left': 'silent_prompts_left',
        'silent_prompts_right': 'silent_prompts_right'
    }

    for key, subdir in subdirs.items():
        dir_path = os.path.join(base_dir, subdir)
        if not os.path.exists(dir_path):
            print(f"Warning: Directory {dir_path} does not exist")
            continue

        pattern = os.path.join(dir_path, "*.safetensors")
        files = sorted(glob.glob(pattern))

        for file_path in files:
            try:
                prompt_data = load_file(file_path)
                prompt_emb = prompt_data['context']

                # Pad with zeros (matching dtype — torch.cat requires it) or
                # truncate so every embedding is exactly text_len rows.
                if prompt_emb.shape[0] < text_len:
                    padding = torch.zeros(
                        text_len - prompt_emb.shape[0],
                        prompt_emb.shape[1],
                        dtype=prompt_emb.dtype,
                    )
                    prompt_emb = torch.cat([prompt_emb, padding], dim=0)
                else:
                    prompt_emb = prompt_emb[:text_len]

                prompt_embeddings[key].append(prompt_emb)
            except Exception as e:
                # Best-effort loading: report and keep going.
                print(f"Error loading {file_path}: {e}")
                continue

    print(f"Loaded {len(prompt_embeddings['talk_prompts'])} talk prompts")
    print(f"Loaded {len(prompt_embeddings['silent_prompts_left'])} silent left prompts")
    print(f"Loaded {len(prompt_embeddings['silent_prompts_right'])} silent right prompts")
    return prompt_embeddings


def get_random_prompt_embedding(prompt_embeddings, prompt_type=None, device=None, dtype=None):
    """Pick one prompt embedding at random from the preloaded pools.

    Args:
        prompt_embeddings: dict returned by load_predefined_prompt_embeddings.
        prompt_type: one of 'talk_prompts', 'silent_prompts_left',
            'silent_prompts_right'; when None a non-empty type is chosen at
            random.
        device: optional target device for the returned tensor.
        dtype: optional target dtype for the returned tensor.

    Returns:
        torch.Tensor: the selected prompt embedding.

    Raises:
        ValueError: if no embeddings exist, the type is unknown, or the
            requested type's pool is empty.
    """
    if prompt_type is None:
        # Only consider types that actually have embeddings loaded.
        available_types = [k for k, v in prompt_embeddings.items() if len(v) > 0]
        if not available_types:
            raise ValueError("No prompt embeddings available")
        prompt_type = random.choice(available_types)

    if prompt_type not in prompt_embeddings:
        raise ValueError(f"Unknown prompt type: {prompt_type}")
    if len(prompt_embeddings[prompt_type]) == 0:
        raise ValueError(f"No embeddings available for type: {prompt_type}")

    selected_embedding = random.choice(prompt_embeddings[prompt_type])

    # Move/cast only when requested; the pooled tensor itself is not mutated.
    if device is not None:
        selected_embedding = selected_embedding.to(device)
    if dtype is not None:
        selected_embedding = selected_embedding.to(dtype)
    return selected_embedding


def create_silence_video(video_tensor, cycle_frames=3):
    """Build a "silence" video by looping the first ``cycle_frames`` frames.

    Args:
        video_tensor: [b, 3, F, H, W] source video tensor.
        cycle_frames: number of leading frames to loop (default 3); clamped
            to the video length.

    Returns:
        torch.Tensor: [b, 3, F, H, W] silence video.

    Playback patterns produced by the code:
        - cycle_frames=1: 1111...
        - cycle_frames=2: 121212...
        - cycle_frames>=3: palindromic ping-pong over the cycle, e.g.
          cycle_frames=3 -> 12321232..., cycle_frames=4 -> 123432123432...
    """
    batch_size, channels, num_frames, height, width = video_tensor.shape

    # Never loop over more frames than the clip actually has.
    cycle_frames = min(cycle_frames, num_frames)

    # The frames that will be cycled: [b, 3, cycle_frames, H, W].
    cycle_video = video_tensor[:, :, :cycle_frames, :, :]

    if cycle_frames <= 2:
        if cycle_frames == 1:
            # Single frame repeated for the whole duration.
            silence_video = cycle_video.repeat(1, 1, num_frames, 1, 1)
        else:
            # Two frames alternating: 121212...
            frame_1 = cycle_video[:, :, 0:1, :, :]  # [b, 3, 1, H, W]
            frame_2 = cycle_video[:, :, 1:2, :, :]  # [b, 3, 1, H, W]
            repeat_times = (num_frames + 1) // 2
            frame_1_repeated = frame_1.repeat(1, 1, repeat_times, 1, 1)
            frame_2_repeated = frame_2.repeat(1, 1, repeat_times, 1, 1)
            # Interleave along a new dim then flatten -> 1,2,1,2,...
            silence_frames = torch.stack([frame_1_repeated, frame_2_repeated], dim=3)
            silence_frames = silence_frames.view(batch_size, channels, repeat_times * 2, height, width)
            silence_video = silence_frames[:, :, :num_frames, :, :]
    else:
        # Ping-pong mode: forward pass then the interior frames reversed,
        # e.g. cycle_frames=4 gives one period 1-2-3-4-3-2.
        forward_frames = cycle_video  # [b, 3, cycle_frames, H, W]
        reverse_frames = torch.flip(cycle_video[:, :, 1:-1, :, :], dims=[2])  # [b, 3, cycle_frames-2, H, W]

        one_cycle = torch.cat([forward_frames, reverse_frames], dim=2)  # [b, 3, 2*cycle_frames-2, H, W]
        cycle_length = 2 * cycle_frames - 2

        # Enough whole periods to cover num_frames, then truncate.
        num_cycles = (num_frames + cycle_length - 1) // cycle_length
        repeated_cycles = one_cycle.repeat(1, 1, num_cycles, 1, 1)  # [b, 3, num_cycles*cycle_length, H, W]
        silence_video = repeated_cycles[:, :, :num_frames, :, :]  # [b, 3, F, H, W]

    return silence_video


def extract_square_faces_from_ref_images(face_parser, ref_images, video_paths, first_frame_faces, crop_size=224, device=None, torch_dtype=None, global_rank=0, current_global_step=0):
    """Detect and crop square faces from reference images, with fallback.

    For each reference image, runs the face parser and crops the largest
    detected face; if detection fails (or raises), falls back to the crop
    face supplied by the dataset. Every output is resized to
    crop_size x crop_size.

    Args:
        face_parser: FaceInference instance exposing ``infer_from_array``.
        ref_images: list of tensors, each [3, H, W]; values may be in
            [-1, 1] or [0, 1] (both handled).
        video_paths: list of str, parallel to ref_images.
        first_frame_faces: tensor [b, 3, face_H, face_W] — dataset-provided
            fallback faces.
        crop_size: output face edge length.
        device: target device for output tensors.
        torch_dtype: target dtype for output tensors.
        global_rank: used to restrict fallback logging to rank 0.
        current_global_step: current training step (kept for debug logging).

    Returns:
        list of tensors: each [3, crop_size, crop_size], values in [0, 1].
    """
    square_faces = []
    for i, (ref_image, video_path) in enumerate(zip(ref_images, video_paths)):
        try:
            # [3, H, W] -> [H, W, 3] numpy for the detector; cast to float32
            # first to avoid BFloat16 -> numpy conversion failures.
            ref_image_np = ref_image.permute(1, 2, 0).cpu().float().numpy()

            # Normalize value range: map [-1, 1] inputs to [0, 1]; inputs
            # already in [0, 1] pass through unchanged.
            if ref_image_np.min() < 0:
                ref_image_np = (ref_image_np + 1) / 2
            ref_image_np = np.clip(ref_image_np, 0, 1)
            ref_image_np = (ref_image_np * 255).astype(np.uint8)

            # Ask the parser for the single largest face.
            face_result = face_parser.infer_from_array(ref_image_np, n=1)

            if face_result and 'bboxes' in face_result and len(face_result['bboxes']) > 0:
                # bbox format: [x, y, width, height].
                bbox = face_result['bboxes'][0]
                x, y, w, h = bbox

                face_crop = ref_image_np[int(y):int(y + h), int(x):int(x + w)]

                # Back to a [3, h, w] float tensor in [0, 1], then resize.
                face_tensor = torch.from_numpy(face_crop).permute(2, 0, 1).float() / 255.0
                face_tensor = resize(face_tensor, size=(crop_size, crop_size), interpolation=InterpolationMode.BILINEAR)
                square_faces.append(face_tensor.to(device, dtype=torch_dtype))
            else:
                # No face detected -> fall back to the dataset-provided crop.
                dataset_face = first_frame_faces[i]  # [3, face_H, face_W]
                if global_rank == 0:
                    print(f"[Face Parser] 未检测到人脸 {i},回退到dataset提供的crop人脸")
                    print(f"[Face Parser] Dataset人脸形状: {dataset_face.shape}")
                    print(f"[Face Parser] Dataset人脸值域: [{dataset_face.min():.4f}, {dataset_face.max():.4f}]")
                face_tensor = resize(dataset_face, size=(crop_size, crop_size), interpolation=InterpolationMode.BILINEAR)
                if global_rank == 0:
                    print(f"[Face Parser] 调整尺寸后的dataset人脸形状: {face_tensor.shape}")
                    print(f"[Face Parser] 调整尺寸后的dataset人脸值域: [{face_tensor.min():.4f}, {face_tensor.max():.4f}]")
                square_faces.append(face_tensor.to(device, dtype=torch_dtype))
        except Exception as e:
            # Any parser/crop failure -> same dataset-face fallback, so one
            # bad sample never aborts the batch.
            dataset_face = first_frame_faces[i]  # [3, face_H, face_W]
            if global_rank == 0:
                print(f"[Face Parser] 处理图像 {i} 时出错: {str(e)},回退到dataset提供的crop人脸")
                print(f"[Face Parser] 异常回退 - Dataset人脸形状: {dataset_face.shape}")
                print(f"[Face Parser] 异常回退 - Dataset人脸值域: [{dataset_face.min():.4f}, {dataset_face.max():.4f}]")
            face_tensor = resize(dataset_face, size=(crop_size, crop_size), interpolation=InterpolationMode.BILINEAR)
            if global_rank == 0:
                print(f"[Face Parser] 异常回退 - 调整尺寸后的tensor形状: {face_tensor.shape}")
                print(f"[Face Parser] 异常回退 - 调整尺寸后的tensor值域: [{face_tensor.min():.4f}, {face_tensor.max():.4f}]")
            square_faces.append(face_tensor.to(device, dtype=torch_dtype))
    return square_faces


def save_face_parser_debug_images(ref_cropped_list, ref_image, video_paths, first_frame_faces, current_step, crop_image_size, debug_dir="./logs/debug/ref", global_rank=0):
    """Save face-parser debug images (rank 0 only) under ``debug_dir``.

    For every batch item writes the original reference image, the dataset
    crop face, the parsed square face (normalized and raw), and an info.txt
    with paths and tensor value ranges. All failures are caught and logged.

    Args:
        ref_cropped_list: list of tensors — faces produced by the parser.
        ref_image: tensor, original reference images (batched) or None.
        video_paths: list of str video paths.
        first_frame_faces: tensor, dataset-provided crop faces or None.
        current_step: int, current training step.
        crop_image_size: int, face crop edge length.
        debug_dir: str, output directory.
        global_rank: int; only rank 0 writes anything.
    """
    try:
        # Fix: previously only the mkdir calls were guarded by rank 0 while
        # batch_dir was used unconditionally, so non-zero ranks hit a
        # NameError (swallowed by the except below). Debug output is
        # rank-0-only, so bail out early on every other rank.
        if global_rank != 0:
            return

        os.makedirs(debug_dir, exist_ok=True)
        os.chmod(debug_dir, 0o777)

        for batch_idx, (ref_cropped, video_path) in enumerate(zip(ref_cropped_list, video_paths)):
            batch_dir = os.path.join(debug_dir, f"step_{current_step}_batch_{batch_idx}")
            os.makedirs(batch_dir, exist_ok=True)
            os.chmod(batch_dir, 0o777)

            # --- original reference image -------------------------------
            if ref_image is not None:
                ref_img_single = ref_image[batch_idx]  # [3, H, W]
                # Map [-1, 1] inputs to [0, 1]; [0, 1] inputs pass through.
                if ref_img_single.min() < 0:
                    ref_img_single = (ref_img_single + 1) / 2
                ref_img_single = torch.clamp(ref_img_single, 0, 1)
                ref_img_path = os.path.join(batch_dir, "original_ref_image.png")
                save_image(ref_img_single, ref_img_path)

            # --- dataset-provided crop face -----------------------------
            if first_frame_faces is not None:
                dataset_face_single = first_frame_faces[batch_idx]  # [3, face_H, face_W]
                if dataset_face_single.min() < 0:
                    dataset_face_single = (dataset_face_single + 1) / 2
                dataset_face_single = torch.clamp(dataset_face_single, 0, 1)
                dataset_face_path = os.path.join(batch_dir, "dataset_crop_face.png")
                save_image(dataset_face_single, dataset_face_path)

            # --- parser-generated square face ---------------------------
            if ref_cropped is not None:
                # NOTE(review): ref_cropped is already the per-sample element
                # of ref_cropped_list; indexing it again with batch_idx only
                # makes sense if each element is itself batched ([b, 3, h, w]).
                # Confirm the caller's shape — kept as-is to preserve behavior.
                ref_cropped_single = ref_cropped[batch_idx]  # [3, crop_size, crop_size]

                print(f"[Debug] Parsed face tensor shape: {ref_cropped_single.shape}")
                print(f"[Debug] Parsed face tensor dtype: {ref_cropped_single.dtype}")
                print(f"[Debug] Parsed face tensor device: {ref_cropped_single.device}")
                print(f"[Debug] Parsed face tensor stats: min={ref_cropped_single.min():.4f}, max={ref_cropped_single.max():.4f}, mean={ref_cropped_single.mean():.4f}, std={ref_cropped_single.std():.4f}")

                # Flag pathological values before writing the image.
                if torch.isnan(ref_cropped_single).any():
                    print(f"[Debug] Warning: Parsed face contains NaN values!")
                if torch.isinf(ref_cropped_single).any():
                    print(f"[Debug] Warning: Parsed face contains Inf values!")

                if ref_cropped_single.min() < 0:
                    ref_cropped_single = (ref_cropped_single + 1) / 2
                ref_cropped_single = torch.clamp(ref_cropped_single, 0, 1)

                face_path = os.path.join(batch_dir, "parsed_square_face.png")
                save_image(ref_cropped_single, face_path)

                # Also keep an unnormalized copy for comparison.
                raw_face_path = os.path.join(batch_dir, "parsed_square_face_raw.png")
                save_image(ref_cropped[batch_idx], raw_face_path)

            # --- metadata sidecar ---------------------------------------
            info_path = os.path.join(batch_dir, "info.txt")
            with open(info_path, 'w') as f:
                f.write(f"Video Path: {video_path}\n")
                f.write(f"Step: {current_step}\n")
                f.write(f"Batch Index: {batch_idx}\n")
                f.write(f"Face Size: {crop_image_size}x{crop_image_size}\n")
                if first_frame_faces is not None:
                    f.write(f"Dataset Face Size: {first_frame_faces[batch_idx].shape[1]}x{first_frame_faces[batch_idx].shape[2]}\n")
                if ref_image is not None:
                    ref_img_single = ref_image[batch_idx]
                    f.write(f"Original Ref Image Range: [{ref_img_single.min():.4f}, {ref_img_single.max():.4f}]\n")
                if first_frame_faces is not None:
                    dataset_face_single = first_frame_faces[batch_idx]
                    f.write(f"Dataset Face Range: [{dataset_face_single.min():.4f}, {dataset_face_single.max():.4f}]\n")
                if ref_cropped is not None:
                    ref_cropped_single = ref_cropped[batch_idx]
                    f.write(f"Parsed Face Range: [{ref_cropped_single.min():.4f}, {ref_cropped_single.max():.4f}]\n")

            print(f"[Debug] 保存人脸解析器调试图像到: {batch_dir}")
    except Exception as e:
        print(f"[Debug] 保存人脸解析器调试图像时出错: {str(e)}")