# MultiPerson / utils / training_utils.py
# zzz66's picture
# Initial commit with LFS
# 400a879
import os
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from safetensors.torch import load_file
import random
import glob
import numpy as np
from torchvision.transforms.functional import resize
from torchvision.transforms import InterpolationMode
from torchvision.utils import save_image
def init_distributed(args):
    """Initialize the distributed training environment and report its layout.

    Creates the default NCCL process group when none exists and selects the
    CUDA device named by the ``LOCAL_RANK`` environment variable. When
    ``args.use_usp`` is truthy, xfuser's distributed environment and
    model-parallel groups are initialized as well, and torch is seeded with a
    base seed drawn on rank 0 (broadcast to all ranks) plus the value
    returned by ``get_data_parallel_rank()``.

    Args:
        args: Object exposing ``use_usp`` and, when that is truthy,
            ``ulysses_degree`` and ``ring_degree``.

    Returns:
        Tuple ``(global_rank, world_size, local_rank, sp_size, dp_size)``.
        Without USP, ``sp_size`` is 1 and ``dp_size`` equals ``world_size``.

    Raises:
        AssertionError: If the sequence-parallel size exceeds, or does not
            evenly divide, the world size.
        KeyError: If ``LOCAL_RANK`` is missing from the environment.
    """
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    if not args.use_usp:
        # No sequence parallelism: sp_size is 1 and dp_size is the world size.
        return global_rank, world_size, local_rank, 1, world_size

    sp_size = args.ulysses_degree * args.ring_degree
    dp_size = world_size // sp_size
    assert sp_size <= world_size, f"sequence parallel size ({sp_size}) must be less than or equal to world size ({world_size})."
    assert world_size % sp_size == 0, f"world size ({world_size}) must be divisible by sequence parallel size ({sp_size})."

    # Imported lazily so xfuser is only required when USP is enabled.
    from xfuser.core.distributed import (
        initialize_model_parallel,
        init_distributed_environment,
        get_data_parallel_rank,
    )

    init_distributed_environment(rank=global_rank, world_size=world_size)
    initialize_model_parallel(
        data_parallel_degree=dp_size,
        sequence_parallel_degree=sp_size,
        ring_degree=args.ring_degree,
        ulysses_degree=args.ulysses_degree,
    )

    if dist.is_initialized():
        seed_offset = get_data_parallel_rank()
        # Only rank 0 draws a seed; the placeholder on other ranks is
        # overwritten by the broadcast below.
        drawn = torch.randint(0, 1000000, (1,)).item() if global_rank == 0 else 0
        seed_buffer = torch.tensor([drawn], device="cuda")
        dist.broadcast(seed_buffer, src=0)
        torch.manual_seed(seed_buffer.item() + seed_offset)

    return global_rank, world_size, local_rank, sp_size, dp_size
def custom_collate_fn(batch):
    """Collate samples into per-key lists, trimming clips to a common length.

    The shortest ``video`` entry in the batch (measured along dim 1) sets the
    common length. Every ``video`` is cut to that length along dim 1, and the
    ``audio_emb`` / ``audio_embed_speaker1`` / ``audio_embed_speaker2``
    entries are cut to the same length along dim 0. Values of all other keys
    are gathered unchanged.

    Args:
        batch: Non-empty list of sample dicts. The key set is taken from the
            first sample, and every sample must contain ``video``.

    Returns:
        Dict mapping each key of the first sample to a list holding one
        entry per sample (entries are not stacked).
    """
    field_names = batch[0].keys()
    frame_trimmed = ("audio_emb", "audio_embed_speaker1", "audio_embed_speaker2")
    shortest = min(sample["video"].shape[1] for sample in batch)

    collated = {}
    for name in field_names:
        if name == "video":
            # The frame axis of the video is dim 1.
            column = [sample[name][:, :shortest] for sample in batch]
        elif name in frame_trimmed:
            # Audio features are cut along dim 0 to the same length.
            column = [sample[name][:shortest] for sample in batch]
        else:
            column = [sample[name] for sample in batch]
        collated[name] = column
    return collated
def get_distributed_sampler(dataset, args, global_rank, world_size, sp_size, dp_size):
    """Build the sampler that shards ``dataset`` across data-parallel replicas.

    Args:
        dataset: Dataset handed to ``DistributedSampler``.
        args: Object exposing ``use_usp``.
        global_rank: Rank of this process in the default process group.
        world_size: Unused here; kept for interface compatibility.
        sp_size: Unused here; kept for interface compatibility.
        dp_size: Number of data-parallel replicas.

    Returns:
        ``None`` when ``dp_size`` is 1; otherwise a non-shuffling
        ``DistributedSampler`` with ``dp_size`` replicas.
    """
    if dp_size == 1:
        return None

    if args.use_usp:
        # With USP the replica index comes from xfuser's data-parallel group.
        from xfuser.core.distributed import get_data_parallel_rank
        replica_rank = get_data_parallel_rank()
    else:
        # Without USP the xfuser data-parallel group is not consulted;
        # the global rank is used as the replica index.
        replica_rank = global_rank

    sampler = DistributedSampler(
        dataset,
        num_replicas=dp_size,
        rank=replica_rank,
        shuffle=False,
    )
    if global_rank == 0:
        print(f"Using DistributedSampler: dp_size={dp_size}, dp_rank={replica_rank}")
    return sampler
def load_predefined_prompt_embeddings(base_dir="./multi_person/prompt_emb_concat", text_len=512):
    """Load the three groups of pre-computed prompt embeddings from disk.

    For each of the sub-directories ``talk_prompts``, ``silent_prompts_left``
    and ``silent_prompts_right`` under ``base_dir``, every ``*.safetensors``
    file (in sorted path order) is loaded and its ``'context'`` tensor is
    zero-padded or truncated along dim 0 to exactly ``text_len`` rows.

    Loading is best-effort: a missing sub-directory prints a warning, and a
    file that fails to load prints an error; both are skipped.

    Args:
        base_dir: Directory containing the three sub-directories.
        text_len: Target length along dim 0 of every returned embedding.

    Returns:
        dict with keys ``'talk_prompts'``, ``'silent_prompts_left'`` and
        ``'silent_prompts_right'``, each mapping to a (possibly empty) list
        of tensors.
    """
    prompt_types = ('talk_prompts', 'silent_prompts_left', 'silent_prompts_right')
    prompt_embeddings = {prompt_type: [] for prompt_type in prompt_types}

    for prompt_type in prompt_types:
        dir_path = os.path.join(base_dir, prompt_type)
        if not os.path.exists(dir_path):
            print(f"Warning: Directory {dir_path} does not exist")
            continue

        for file_path in sorted(glob.glob(os.path.join(dir_path, "*.safetensors"))):
            try:
                prompt_emb = load_file(file_path)['context']
                missing_rows = text_len - prompt_emb.shape[0]
                if missing_rows > 0:
                    # new_zeros inherits dtype and device from the embedding,
                    # so padding a half-precision tensor neither promotes it
                    # to float32 nor mixes dtypes inside torch.cat.
                    padding = prompt_emb.new_zeros(missing_rows, prompt_emb.shape[1])
                    prompt_emb = torch.cat([prompt_emb, padding], dim=0)
                else:
                    prompt_emb = prompt_emb[:text_len]
                prompt_embeddings[prompt_type].append(prompt_emb)
            except Exception as e:
                # Deliberately broad: one unreadable or malformed file must
                # not abort loading of the remaining prompts.
                print(f"Error loading {file_path}: {e}")
                continue

    print(f"Loaded {len(prompt_embeddings['talk_prompts'])} talk prompts")
    print(f"Loaded {len(prompt_embeddings['silent_prompts_left'])} silent left prompts")
    print(f"Loaded {len(prompt_embeddings['silent_prompts_right'])} silent right prompts")
    return prompt_embeddings
def get_random_prompt_embedding(prompt_embeddings, prompt_type=None, device=None, dtype=None):
    """Pick one pre-loaded prompt embedding at random.

    Args:
        prompt_embeddings: Dict mapping a prompt type to a list of tensors,
            as returned by ``load_predefined_prompt_embeddings``.
        prompt_type: Key to draw from. When ``None``, a type is first chosen
            at random among those with at least one embedding.
        device: Optional target device for the returned tensor.
        dtype: Optional target dtype for the returned tensor.

    Returns:
        torch.Tensor: The chosen embedding, converted to ``device`` and then
        ``dtype`` when those are given.

    Raises:
        ValueError: If no embeddings exist at all, if ``prompt_type`` is not
            a key of ``prompt_embeddings``, or if its list is empty.
    """
    if prompt_type is None:
        populated = [name for name, items in prompt_embeddings.items() if len(items) > 0]
        if not populated:
            raise ValueError("No prompt embeddings available")
        prompt_type = random.choice(populated)

    if prompt_type not in prompt_embeddings:
        raise ValueError(f"Unknown prompt type: {prompt_type}")

    candidates = prompt_embeddings[prompt_type]
    if len(candidates) == 0:
        raise ValueError(f"No embeddings available for type: {prompt_type}")

    chosen = random.choice(candidates)
    # Apply the device move first, then the dtype cast; skip whichever is None.
    for target in (device, dtype):
        if target is not None:
            chosen = chosen.to(target)
    return chosen
def create_silence_video(video_tensor, cycle_frames=3):
    """Build a "silent" clip by looping the first few frames of a video.

    The first ``cycle_frames`` frames are played forward and then backward
    (without repeating either end frame), and that period is tiled until the
    output has as many frames as the input.

    Frame patterns (1-based frame numbers):
        - cycle_frames=1: 1111...
        - cycle_frames=2: 1212...
        - cycle_frames=3: 12321232...
        - cycle_frames=4: 123432123432...

    Args:
        video_tensor: 5-D tensor ``[b, c, F, H, W]`` with frames along dim 2.
        cycle_frames: Number of leading frames to loop over (default 3).
            Values larger than ``F`` are reduced to ``F``.

    Returns:
        A new tensor with the same shape, dtype and device as
        ``video_tensor``.

    Raises:
        ValueError: If ``video_tensor`` is not 5-D or ``cycle_frames < 1``.
    """
    # Unpacking enforces the 5-D layout; only the frame count is needed.
    _, _, num_frames, _, _ = video_tensor.shape
    if cycle_frames < 1:
        raise ValueError(f"cycle_frames must be >= 1, got {cycle_frames}")
    if num_frames == 0:
        # Nothing to loop over; return an empty clip of the same shape.
        return video_tensor.clone()

    cycle_frames = min(cycle_frames, num_frames)

    # One period: forward 0..c-1, then backward c-2..1. For c == 1 this is
    # [0] and for c == 2 it is [0, 1], so all cases share one code path.
    period = list(range(cycle_frames)) + list(range(cycle_frames - 2, 0, -1))
    repeats = -(-num_frames // len(period))  # ceiling division
    frame_index = torch.tensor(
        (period * repeats)[:num_frames],
        dtype=torch.long,
        device=video_tensor.device,
    )
    return video_tensor.index_select(2, frame_index)
def _crop_detected_face(face_parser, ref_image, crop_size):
    """Detect a face in ``ref_image`` and return it resized to a square, or None if none is found."""
    # Cast to float32 before leaving torch: numpy has no bfloat16.
    ref_image_np = ref_image.permute(1, 2, 0).cpu().float().numpy()
    # A negative value is taken to mean the image is in [-1, 1]; otherwise it
    # is treated as already being in [0, 1].
    if ref_image_np.min() < 0:
        ref_image_np = (ref_image_np + 1) / 2
    ref_image_np = np.clip(ref_image_np, 0, 1)
    ref_image_np = (ref_image_np * 255).astype(np.uint8)

    # Per the original author's note, n=1 requests only the largest face.
    face_result = face_parser.infer_from_array(ref_image_np, n=1)
    if not (face_result and 'bboxes' in face_result and len(face_result['bboxes']) > 0):
        return None

    # The bbox is consumed as [x, y, width, height].
    x, y, w, h = face_result['bboxes'][0]
    img_h, img_w = ref_image_np.shape[:2]
    # Clamp to the image: a negative start index would otherwise wrap around
    # under numpy slicing and yield an empty or wrong region.
    left, top = max(int(x), 0), max(int(y), 0)
    right, bottom = min(int(x + w), img_w), min(int(y + h), img_h)
    if right <= left or bottom <= top:
        raise ValueError(f"face bbox [{x}, {y}, {w}, {h}] lies outside the {img_w}x{img_h} image")

    face_crop = ref_image_np[top:bottom, left:right]
    face_tensor = torch.from_numpy(face_crop).permute(2, 0, 1).float() / 255.0
    return resize(face_tensor, size=(crop_size, crop_size), interpolation=InterpolationMode.BILINEAR)


def _resize_dataset_face(dataset_face, crop_size, device, torch_dtype, verbose, headline, tag, resized_name):
    """Resize the dataset-provided face crop to a square, logging the fallback when ``verbose``."""
    if verbose:
        print(headline)
        print(f"[Face Parser] {tag}Dataset人脸形状: {dataset_face.shape}")
        print(f"[Face Parser] {tag}Dataset人脸值域: [{dataset_face.min():.4f}, {dataset_face.max():.4f}]")
    face_tensor = resize(dataset_face, size=(crop_size, crop_size), interpolation=InterpolationMode.BILINEAR)
    if verbose:
        print(f"[Face Parser] {tag}调整尺寸后的{resized_name}形状: {face_tensor.shape}")
        print(f"[Face Parser] {tag}调整尺寸后的{resized_name}值域: [{face_tensor.min():.4f}, {face_tensor.max():.4f}]")
    return face_tensor.to(device, dtype=torch_dtype)


def extract_square_faces_from_ref_images(face_parser, ref_images, video_paths, first_frame_faces, crop_size=224, device=None, torch_dtype=None, global_rank=0, current_global_step=0):
    """Extract one square face per reference image, falling back to dataset crops.

    Each reference image is passed to ``face_parser``. When a face is found,
    its bounding box (clamped to the image) is cropped and resized to
    ``crop_size`` x ``crop_size``. When no face is found, or any step raises,
    the matching entry of ``first_frame_faces`` is resized and used instead.

    NOTE(review): a detected face is returned in [0, 1], whereas a fallback
    face keeps the value range of ``first_frame_faces`` — confirm that the
    dataset crops use the range the caller expects.

    Args:
        face_parser: Object exposing ``infer_from_array(image, n=...)`` whose
            result is indexed with ``'bboxes'``.
        ref_images: Sequence of tensors indexed as ``[3, H, W]``.
        video_paths: Sequence paired with ``ref_images``; it only bounds the
            iteration (the paths themselves are not read).
        first_frame_faces: Fallback crops, indexed per sample.
        crop_size: Side length of the returned square faces.
        device: Target device of the returned tensors.
        torch_dtype: Target dtype of the returned tensors.
        global_rank: Fallback log messages are printed only when this is 0.
        current_global_step: Unused; kept for interface compatibility.

    Returns:
        list of tensors, one per processed sample, each
        ``[3, crop_size, crop_size]``.
    """
    verbose = global_rank == 0
    square_faces = []
    for i, (ref_image, _video_path) in enumerate(zip(ref_images, video_paths)):
        try:
            face_tensor = _crop_detected_face(face_parser, ref_image, crop_size)
            if face_tensor is not None:
                face_tensor = face_tensor.to(device, dtype=torch_dtype)
        except Exception as e:
            # Best-effort: any detection or cropping failure falls back to
            # the dataset crop instead of aborting the training step.
            face_tensor = _resize_dataset_face(
                first_frame_faces[i], crop_size, device, torch_dtype, verbose,
                headline=f"[Face Parser] 处理图像 {i} 时出错: {str(e)},回退到dataset提供的crop人脸",
                tag="异常回退 - ",
                resized_name="tensor",
            )
        else:
            if face_tensor is None:
                # No face detected: use the dataset crop.
                face_tensor = _resize_dataset_face(
                    first_frame_faces[i], crop_size, device, torch_dtype, verbose,
                    headline=f"[Face Parser] 未检测到人脸 {i},回退到dataset提供的crop人脸",
                    tag="",
                    resized_name="dataset人脸",
                )
        square_faces.append(face_tensor)
    return square_faces
def _to_unit_range(image):
    """Map an image to [0, 1]: rescale from [-1, 1] if any value is negative, then clamp."""
    if image.min() < 0:
        image = (image + 1) / 2
    return torch.clamp(image, 0, 1)


def save_face_parser_debug_images(ref_cropped_list, ref_image, video_paths, first_frame_faces, current_step, crop_image_size, debug_dir="./logs/debug/ref", global_rank=0):
    """Dump face-parser debug images and an info file for every sample.

    On rank 0, a directory ``step_{current_step}_batch_{idx}`` is created
    under ``debug_dir`` for each sample, holding the original reference
    image, the dataset face crop, the parsed square face (normalized and
    raw) and an ``info.txt`` summary. Other ranks write nothing. Any
    exception is caught and reported with a print, so this never raises.

    Args:
        ref_cropped_list: Sequence iterated together with ``video_paths``.
            NOTE(review): each element is indexed again with the sample
            index, which suggests batched tensors rather than single faces —
            confirm against the caller.
        ref_image: Reference images indexed per sample, or ``None``.
        video_paths: Video path of each sample.
        first_frame_faces: Dataset face crops indexed per sample, or ``None``.
        current_step: Training step, used in the directory name and info file.
        crop_image_size: Face size recorded in the info file.
        debug_dir: Root directory for the debug output.
        global_rank: Only rank 0 creates directories and writes files.
    """
    try:
        if global_rank == 0:
            os.makedirs(debug_dir, exist_ok=True)
            os.chmod(debug_dir, 0o777)

        for batch_idx, (ref_cropped, video_path) in enumerate(zip(ref_cropped_list, video_paths)):
            if global_rank != 0:
                continue

            batch_dir = os.path.join(debug_dir, f"step_{current_step}_batch_{batch_idx}")
            os.makedirs(batch_dir, exist_ok=True)
            os.chmod(batch_dir, 0o777)

            # Original reference image.
            if ref_image is not None:
                save_image(
                    _to_unit_range(ref_image[batch_idx]),
                    os.path.join(batch_dir, "original_ref_image.png"),
                )

            # Face crop supplied by the dataset.
            if first_frame_faces is not None:
                save_image(
                    _to_unit_range(first_frame_faces[batch_idx]),
                    os.path.join(batch_dir, "dataset_crop_face.png"),
                )

            # Square face produced by the face parser.
            if ref_cropped is not None:
                parsed_face = ref_cropped[batch_idx]
                print(f"[Debug] Parsed face tensor shape: {parsed_face.shape}")
                print(f"[Debug] Parsed face tensor dtype: {parsed_face.dtype}")
                print(f"[Debug] Parsed face tensor device: {parsed_face.device}")
                print(f"[Debug] Parsed face tensor stats: min={parsed_face.min():.4f}, max={parsed_face.max():.4f}, mean={parsed_face.mean():.4f}, std={parsed_face.std():.4f}")
                if torch.isnan(parsed_face).any():
                    print(f"[Debug] Warning: Parsed face contains NaN values!")
                if torch.isinf(parsed_face).any():
                    print(f"[Debug] Warning: Parsed face contains Inf values!")

                save_image(
                    _to_unit_range(parsed_face),
                    os.path.join(batch_dir, "parsed_square_face.png"),
                )
                # Also keep the un-normalized tensor for comparison.
                save_image(
                    ref_cropped[batch_idx],
                    os.path.join(batch_dir, "parsed_square_face_raw.png"),
                )

            # Text summary; value ranges are those of the un-normalized tensors.
            with open(os.path.join(batch_dir, "info.txt"), 'w') as f:
                f.write(f"Video Path: {video_path}\n")
                f.write(f"Step: {current_step}\n")
                f.write(f"Batch Index: {batch_idx}\n")
                f.write(f"Face Size: {crop_image_size}x{crop_image_size}\n")
                if first_frame_faces is not None:
                    f.write(f"Dataset Face Size: {first_frame_faces[batch_idx].shape[1]}x{first_frame_faces[batch_idx].shape[2]}\n")
                if ref_image is not None:
                    raw_ref = ref_image[batch_idx]
                    f.write(f"Original Ref Image Range: [{raw_ref.min():.4f}, {raw_ref.max():.4f}]\n")
                if first_frame_faces is not None:
                    raw_dataset_face = first_frame_faces[batch_idx]
                    f.write(f"Dataset Face Range: [{raw_dataset_face.min():.4f}, {raw_dataset_face.max():.4f}]\n")
                if ref_cropped is not None:
                    raw_parsed = ref_cropped[batch_idx]
                    f.write(f"Parsed Face Range: [{raw_parsed.min():.4f}, {raw_parsed.max():.4f}]\n")

            print(f"[Debug] 保存人脸解析器调试图像到: {batch_dir}")
    except Exception as e:
        # Debug output is best-effort and must never interrupt training.
        print(f"[Debug] 保存人脸解析器调试图像时出错: {str(e)}")