# MultiPerson/wan/utils/infer_utils.py
import torch
import numpy as np
import cv2
import os
import librosa
import math
def calculate_frame_num_from_audio(audio_paths, fps=24, mode="pad"):
"""
Calculate corresponding frame number based on audio file length
Args:
audio_paths (list): List of audio file paths
fps (int): Video frame rate, default 24fps
mode (str): Audio processing mode, "pad" or "concat".
In "pad" mode, returns max duration.
In "concat" mode, returns sum of all durations.
Returns:
        int: Calculated frame number, adjusted to the 4n + 1 format; raises ValueError if no usable audio file is found
"""
if not audio_paths:
raise ValueError("No audio files, cannot determine frame number")
if mode == "concat":
# Concat mode: sum all audio durations
total_duration = 0
for audio_path in audio_paths:
if audio_path and os.path.exists(audio_path):
try:
# Use librosa to get audio duration
duration = librosa.get_duration(filename=audio_path)
total_duration += duration
print(f"audio file {audio_path} duration: {duration:.2f} seconds")
except Exception as e:
raise ValueError(f"Failed to read audio file {audio_path}: {e}")
if total_duration > 0:
# Calculate frame number, round up
frame_num = int(math.ceil(total_duration * fps))
# Ensure frame number is in 4n+1 format (model requirement)
frame_num = ((frame_num - 1) // 4) * 4 + 1
print(f"Calculated frame number (concat mode): {frame_num} based on total audio duration {total_duration:.2f}s and frame rate {fps}fps")
return frame_num
else:
raise ValueError("No audio files, cannot determine frame number")
else:
# Pad mode: use max duration (original behavior)
max_duration = 0
for audio_path in audio_paths:
if audio_path and os.path.exists(audio_path):
try:
# Use librosa to get audio duration
duration = librosa.get_duration(filename=audio_path)
max_duration = max(max_duration, duration)
print(f"audio file {audio_path} duration: {duration:.2f} seconds")
except Exception as e:
raise ValueError(f"Failed to read audio file {audio_path}: {e}")
if max_duration > 0:
# Calculate frame number, round up
frame_num = int(math.ceil(max_duration * fps))
# Ensure frame number is in 4n+1 format (model requirement)
frame_num = ((frame_num - 1) // 4) * 4 + 1
print(f"Calculated frame number (pad mode): {frame_num} based on max audio duration {max_duration:.2f}s and frame rate {fps}fps")
return frame_num
else:
raise ValueError("No audio files, cannot determine frame number")
# Count the model's parameters (in millions)
def count_parameters(model):
total_params = sum(p.numel() for p in model.parameters())
total_params_in_millions = total_params / 1e6 # Convert to millions
return total_params_in_millions
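# Example: a model with 1.3e9 parameters returns 1300.0 (the value is in millions of parameters).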
# Build null-conditioned audio_ref_features, adapted for the multi-person case
def create_null_audio_ref_features(audio_ref_features):
null_features = {}
    # Handle ref_face_list (multi-person case)
if 'ref_face_list' in audio_ref_features and audio_ref_features['ref_face_list']:
null_ref_face_list = []
for ref_face in audio_ref_features['ref_face_list']:
if ref_face is not None:
null_ref_face_list.append(ref_face.clone().detach())
else:
null_ref_face_list.append(None)
null_features['ref_face_list'] = null_ref_face_list
else:
null_features['ref_face_list'] = []
    # Handle audio_list (multi-person case)
if 'audio_list' in audio_ref_features and audio_ref_features['audio_list']:
null_audio_list = []
for audio in audio_ref_features['audio_list']:
if audio is not None:
null_audio_list.append(torch.zeros_like(audio))
else:
null_audio_list.append(None)
null_features['audio_list'] = null_audio_list
else:
null_features['audio_list'] = []
return null_features
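# Minimal usage sketch (the dict layout mirrors the audio/face features built elsewhere in this
# pipeline; the variable names below are illustrative only):
#   feats = {"ref_face_list": [ref_face_emb, None], "audio_list": [audio_emb, None]}
#   null_feats = create_null_audio_ref_features(feats)
#   # Reference faces are kept as detached clones while audio features are zeroed,
#   # presumably to serve as the audio-unconditional branch for classifier-free guidance.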
def process_audio_features(
audio_paths=None,
audio=None,
mode="pad",
F=None,
frame_num=None,
task_key=None,
fps=None,
wav2vec_model=None,
vocal_separator_model=None,
audio_output_dir=None,
device=None,
use_half=False,
half_dtype=None,
preprocess_audio=None,
resample_audio=None,
trim_to_4s=False, # Fast mode: trim audio to 4 seconds
):
"""
Process audio files and extract audio features.
Args:
audio_paths (list): List of audio file paths (new format, supports multiple audio files)
audio (str): Single audio file path (legacy format)
mode (str): Audio processing mode, "pad" or "concat"
F (int): Target frame number (already calculated outside)
frame_num (int): Frame number for cache file naming (legacy)
task_key (str): Task key for cache file naming
fps (int): Frames per second
wav2vec_model (str): Path to wav2vec model
vocal_separator_model (str): Path to vocal separator model
audio_output_dir (str): Directory for audio output
device: Device to use for processing
use_half (bool): Whether to use half precision
        half_dtype: Half precision dtype (e.g. torch.float16)
preprocess_audio: Function to preprocess audio
        resample_audio: Function to resample audio
        trim_to_4s (bool): Fast mode, trim audio to roughly 4 seconds (97 frames in 4n + 1 format)
Returns:
list: List of audio feature tensors
"""
from .audio_utils import preprocess_audio as _preprocess_audio, resample_audio as _resample_audio
# Use provided functions or import from audio_utils
if preprocess_audio is None:
preprocess_audio = _preprocess_audio
if resample_audio is None:
resample_audio = _resample_audio
audio_feat_list = []
if audio_paths and len(audio_paths) > 0:
print(f"Processing {len(audio_paths)} audio files in {mode} mode: {audio_paths}")
cache_dir = os.path.join(audio_output_dir, "audio_preprocess")
os.makedirs(cache_dir, exist_ok=True)
if mode == "concat":
# Concat mode: record each audio's length, calculate total length,
# and pad each audio with zeros in non-speaker segments
audio_lengths = [] # Store actual length of each audio in frames
raw_audio_feat_list = [] # Store raw audio features before padding
# First pass: process all audios and record their actual lengths
for i, audio_path in enumerate(audio_paths):
if audio_path and os.path.exists(audio_path):
print(f"Processing audio {i} (first pass): {audio_path}")
target_resampled_audio_path = os.path.join(cache_dir, f"{os.path.basename(audio_path).split('.')[0]}-{task_key}_16k_concat.wav")
if not os.path.exists(target_resampled_audio_path):
resample_audio(
audio_path,
target_resampled_audio_path,
)
with torch.no_grad():
# Process audio without padding to get actual length
audio_emb, audio_length = preprocess_audio(
wav_path=target_resampled_audio_path,
num_generated_frames_per_clip=-1, # -1 means no padding
fps=fps,
wav2vec_model=wav2vec_model,
vocal_separator_model=vocal_separator_model,
cache_dir=cache_dir,
device=device,
)
                        # If half precision is enabled, use half_dtype; otherwise use bfloat16
audio_dtype = half_dtype if use_half else torch.bfloat16
audio_emb = audio_emb.to(device, dtype=audio_dtype)
# Get actual frame length (audio_length is in frames)
actual_frame_length = audio_emb.shape[0]
audio_lengths.append(actual_frame_length)
raw_audio_feat_list.append(audio_emb)
print(f"Audio {i} actual length: {actual_frame_length} frames, shape: {audio_emb.shape}")
else:
print(f"Warning: Audio {i} path is empty or file not found: {audio_path}")
audio_lengths.append(0)
raw_audio_feat_list.append(None)
# Calculate total length from actual processed frames
total_length = sum(audio_lengths)
print(f"Total audio length in concat mode (from processed frames): {total_length} frames")
# Fast mode: trim to 4 seconds if trim_to_4s is True
if trim_to_4s:
                # 4 seconds is fixed at 97 frames (4n + 1 format: 4 s * 24 fps = 96 frames, rounded up to 97)
max_frames_4s = 97
if total_length > max_frames_4s:
print(f"Fast mode: Trimming audio from {total_length} frames to {max_frames_4s} frames (4 seconds)")
# Truncate each audio proportionally
scale_factor = max_frames_4s / total_length
cumulative_length = 0
for i, audio_len in enumerate(audio_lengths):
if audio_len > 0:
new_audio_len = int(audio_len * scale_factor)
# Ensure it fits within remaining space
remaining_space = max_frames_4s - cumulative_length
new_audio_len = min(new_audio_len, remaining_space)
audio_lengths[i] = new_audio_len
# Truncate the corresponding raw audio feature
if raw_audio_feat_list[i] is not None:
raw_audio_feat_list[i] = raw_audio_feat_list[i][:new_audio_len]
cumulative_length += new_audio_len
total_length = sum(audio_lengths)
print(f"After trimming: total_length = {total_length} frames")
# Ensure total length is in 4n+1 format (model requirement)
total_length = ((total_length - 1) // 4) * 4 + 1
print(f"Adjusted total length to 4n+1 format: {total_length} frames")
# Note: F was already calculated outside and passed as parameter
# We should not update F here because it has been used to create other tensors (noise, mask, etc.)
# If there's a mismatch, it means the calculation outside was inaccurate, but we'll use F as is
if total_length > F:
print(f"Warning: Actual processed frames ({total_length}) > pre-calculated F ({F}). Using F={F} to maintain consistency with other tensors.")
elif total_length < F:
print(f"Info: Actual processed frames ({total_length}) < pre-calculated F ({F}). Using F={F}.")
else:
print(f"Info: Actual processed frames ({total_length}) matches pre-calculated F={F}.")
# Second pass: create padded audio features for each audio
# Each audio is placed in its corresponding time segment, with zeros elsewhere
cumulative_length = 0
reference_feat_shape = None
# First, find a reference feature shape from valid audio
for raw_audio_feat in raw_audio_feat_list:
if raw_audio_feat is not None:
reference_feat_shape = raw_audio_feat.shape[1:] # Get shape without frame dimension
break
if reference_feat_shape is None:
raise ValueError("No valid audio files found in concat mode")
for i, (raw_audio_feat, audio_len) in enumerate(zip(raw_audio_feat_list, audio_lengths)):
if raw_audio_feat is not None and audio_len > 0:
# Create zero tensor with total length and same feature shape
padded_audio_feat = torch.zeros(
(F,) + reference_feat_shape,
dtype=raw_audio_feat.dtype,
device=raw_audio_feat.device
)
# Place audio data in its corresponding time segment
end_pos = min(cumulative_length + audio_len, F)
actual_audio_len = end_pos - cumulative_length
padded_audio_feat[cumulative_length:end_pos] = raw_audio_feat[:actual_audio_len]
audio_feat_list.append(padded_audio_feat)
print(f"Audio {i} padded: placed at frames [{cumulative_length}:{end_pos}], shape: {padded_audio_feat.shape}")
cumulative_length += audio_len
else:
# Create zero features for missing audio with total length
zero_audio_feat = torch.zeros(
(F,) + reference_feat_shape,
dtype=torch.bfloat16 if not use_half else half_dtype,
device=device
)
audio_feat_list.append(zero_audio_feat)
print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
else:
# Pad mode: keep existing logic, but apply trim_to_4s if needed
for i, audio_path in enumerate(audio_paths):
if audio_path and os.path.exists(audio_path):
print(f"Processing audio {i}: {audio_path}")
target_resampled_audio_path = os.path.join(cache_dir, f"{os.path.basename(audio_path).split('.')[0]}-{task_key}_16k_{F}.wav")
if not os.path.exists(target_resampled_audio_path):
resample_audio(
audio_path,
target_resampled_audio_path,
)
with torch.no_grad():
print(f"wav2vec_model: {wav2vec_model}")
print(f"cache_dir:{cache_dir}")
# Fast mode: if trim_to_4s, limit to 4 seconds
target_frames = F
if trim_to_4s:
                            # 4 seconds is fixed at 97 frames (4n + 1 format: 4 s * 24 fps = 96 frames, rounded up to 97)
max_frames_4s = 97
target_frames = min(F, max_frames_4s)
if F > max_frames_4s:
print(f"Fast mode: Trimming audio {i} from {F} frames to {max_frames_4s} frames (4 seconds)")
# Use dynamically determined frame number
audio_emb, audio_length = preprocess_audio(
wav_path=target_resampled_audio_path,
num_generated_frames_per_clip=target_frames, # Use target frames (may be trimmed)
fps=fps,
wav2vec_model=wav2vec_model,
vocal_separator_model=vocal_separator_model,
cache_dir=cache_dir,
device=device,
)
                        # If half precision is enabled, use half_dtype; otherwise use bfloat16
audio_dtype = half_dtype if use_half else torch.bfloat16
audio_emb = audio_emb.to(device, dtype=audio_dtype)
# Ensure we don't exceed F frames (for consistency with other tensors)
audio_feat = audio_emb[:F] # Use F to maintain consistency
audio_feat_list.append(audio_feat)
print(f"Audio {i} processed, shape: {audio_feat.shape}")
else:
print(f"Warning: Audio {i} path is empty or file not found: {audio_path}")
# Create zero features for missing audio
if len(audio_feat_list) > 0:
# Use first audio's shape to create zero features
zero_audio_feat = torch.zeros_like(audio_feat_list[0])
audio_feat_list.append(zero_audio_feat)
else:
print(f"Error: No valid audio files found, cannot create zero features")
else:
# Compatible with old format: use single audio parameter
if audio is not None:
print(f"Processing single audio (legacy format): {audio}")
cache_dir = os.path.join(audio_output_dir, "audio_preprocess")
os.makedirs(cache_dir, exist_ok=True)
target_resampled_audio_path = os.path.join(cache_dir, f"{os.path.basename(audio).split('.')[0]}-16k.wav")
            if not os.path.exists(target_resampled_audio_path):
                resample_audio(
                    audio,
                    target_resampled_audio_path,
                )
            # Always read features from the 16 kHz resampled file
            audio = target_resampled_audio_path
with torch.no_grad():
# Fast mode: if trim_to_4s, limit to 4 seconds
target_frames = F
if trim_to_4s:
                    # 4 seconds is fixed at 97 frames (4n + 1 format: 4 s * 24 fps = 96 frames, rounded up to 97)
max_frames_4s = 97
target_frames = min(F, max_frames_4s)
if F > max_frames_4s:
print(f"Fast mode: Trimming single audio from {F} frames to {max_frames_4s} frames (4 seconds)")
# Use dynamically determined frame number
audio_emb, audio_length = preprocess_audio(
wav_path=audio,
num_generated_frames_per_clip=target_frames, # Use target frames (may be trimmed)
fps=fps,
wav2vec_model=wav2vec_model,
vocal_separator_model=vocal_separator_model,
cache_dir=cache_dir,
device=device,
)
                # If half precision is enabled, use half_dtype; otherwise use bfloat16
audio_dtype = half_dtype if use_half else torch.bfloat16
audio_emb = audio_emb.to(device, dtype=audio_dtype)
# Ensure we don't exceed F frames (for consistency with other tensors)
audio_feat = audio_emb[:F] # Use F to maintain consistency
audio_feat_list.append(audio_feat)
print(f"Single audio processed, shape: {audio_feat.shape}")
else:
print("No audio files provided")
return audio_feat_list
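# Minimal usage sketch (paths and model locations are illustrative; preprocess_audio and
# resample_audio default to the implementations in .audio_utils):
#   audio_feats = process_audio_features(
#       audio_paths=["speaker1.wav", "speaker2.wav"],
#       mode="pad", F=81, fps=24, task_key="demo",
#       wav2vec_model="/path/to/wav2vec", vocal_separator_model="/path/to/separator",
#       audio_output_dir="./outputs", device=torch.device("cuda"),
#   )
#   # Returns one feature tensor of F frames per input audio file.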
@torch.cuda.amp.autocast(dtype=torch.float32)
def optimized_scale(positive_flat, negative_flat):
    # Compute the L2 norms of the positive and negative branches
positive_norm = torch.norm(positive_flat, dim=-1, keepdim=True)
negative_norm = torch.norm(negative_flat, dim=-1, keepdim=True)
# Calculate cosine similarity
cosine_sim = torch.sum(positive_flat * negative_flat, dim=-1, keepdim=True) / (positive_norm * negative_norm + 1e-8)
# Calculate scale factor
scale = (positive_norm / (negative_norm + 1e-8)) * cosine_sim
return scale
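# Note: up to the eps terms, the returned scale equals (positive . negative) / ||negative||^2,
# i.e. the least-squares coefficient for projecting the positive branch onto the negative one.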
def expand_face_mask_flexible(face_mask, width_scale_factor, height_scale_factor):
"""
将face_mask中值为1的区域按指定的宽度和高度倍数独立扩大
Args:
face_mask: tensor, shape: [H, W],原始的face mask
width_scale_factor: float, 宽度扩大倍数
height_scale_factor: float, 高度扩大倍数
Returns:
tensor: shape: [H, W],扩大后的face mask
"""
if width_scale_factor == 1.0 and height_scale_factor == 1.0:
return face_mask
    # Find the bounding box of the non-zero region in the mask
mask_indices = torch.nonzero(face_mask > 0.5)
if mask_indices.numel() == 0:
return face_mask
    # Compute the bounding box of the current mask
min_h, min_w = mask_indices.min(dim=0)[0]
max_h, max_w = mask_indices.max(dim=0)[0]
    # Compute the center point
center_h = (min_h + max_h) / 2.0
center_w = (min_w + max_w) / 2.0
    # Compute the size of the current bbox
current_h = max_h - min_h + 1
current_w = max_w - min_w + 1
    # Compute the expanded size; width and height are scaled independently
new_h = int(current_h * height_scale_factor)
new_w = int(current_w * width_scale_factor)
    # Compute the new bounding box (expanded around the center)
new_min_h = int(center_h - new_h / 2.0)
new_max_h = int(center_h + new_h / 2.0)
new_min_w = int(center_w - new_w / 2.0)
new_max_w = int(center_w + new_w / 2.0)
    # Make sure the new bounding box stays within the original image
H, W = face_mask.shape
new_min_h = max(0, new_min_h)
new_max_h = min(H - 1, new_max_h)
new_min_w = max(0, new_min_w)
new_max_w = min(W - 1, new_max_w)
    # Create the new mask
expanded_mask = torch.zeros_like(face_mask)
    # Resize the original mask region into the new bounding box
if new_max_h > new_min_h and new_max_w > new_min_w:
        # Extract the original mask content
original_content = face_mask[min_h:max_h+1, min_w:max_w+1]
        # Scale the original content to the new size
target_h = new_max_h - new_min_h + 1
target_w = new_max_w - new_min_w + 1
if target_h > 0 and target_w > 0:
scaled_content = torch.nn.functional.interpolate(
original_content.unsqueeze(0).unsqueeze(0),
size=(target_h, target_w),
mode='bilinear',
align_corners=False
).squeeze(0).squeeze(0)
            # Place the scaled content at the new position
expanded_mask[new_min_h:new_max_h+1, new_min_w:new_max_w+1] = scaled_content
return expanded_mask
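# Minimal usage sketch (a hypothetical 512x512 mask, widened by 1.5x at unchanged height):
#   face_mask = torch.zeros(512, 512)
#   face_mask[200:300, 220:290] = 1.0
#   wider_mask = expand_face_mask_flexible(face_mask, width_scale_factor=1.5, height_scale_factor=1.0)
#   # The 1-region is bilinearly resized around its own center to roughly 1.5x its original width.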
def gen_inference_masks(masks, img_shape, num_frames=None):
"""
为推理生成与训练时相同格式的mask
注意:推理时的mask是按整个图片标记的,不需要切割50%的逻辑
为了适配训练格式,需要添加batch维度和帧维度 [H, W] -> [1, F, H, W]
Args:
masks: list of tensors, 人脸检测模型生成的mask列表,每个mask都是[H, W]格式
img_shape: tuple, 图像形状 (H, W)
num_frames: int, 视频帧数
Returns:
dict: 包含face_mask_list的字典,human_mask_list设为None
"""
H, W = img_shape
F = num_frames if num_frames is not None else 1
num_faces = len(masks)
print(f"gen_inference_masks: 处理{num_faces}个人脸,图像尺寸{H}x{W},帧数{F}")
with torch.no_grad():
face_mask_list = []
        # Generate a multi-frame mask for each face
for i, mask in enumerate(masks):
            # Create the multi-frame mask: every frame reuses the same face_mask
face_mask_multi = mask.unsqueeze(0).unsqueeze(0).repeat(1, 1, F, 1, 1) # [B, C, F, H, W]
face_mask_list.append(face_mask_multi)
        # Build the concat mask: concatenate all masks along the width dimension
if num_faces > 1:
face_mask_concat = torch.cat(face_mask_list, dim=4) # [B, C, F, H, num_faces*W]
else:
face_mask_concat = face_mask_list[0]
return {
"face_mask_list": face_mask_list,
"human_mask_list": None, # 不再使用human mask
"face_mask_concat": face_mask_concat,
"num_faces": num_faces
}
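# Minimal usage sketch (two hypothetical per-face masks on a 480x832 frame, 81 output frames):
#   out = gen_inference_masks([mask_a, mask_b], img_shape=(480, 832), num_frames=81)
#   # out["face_mask_list"][0].shape == (1, 1, 81, 480, 832)
#   # out["face_mask_concat"].shape  == (1, 1, 81, 480, 2 * 832)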
def expand_bbox_and_crop_image(img, bbox, width_scale_factor, height_scale_factor):
"""
将bbox按scale_factor放大并从图像中安全切割对应区域
Args:
img: tensor, shape: [C, H, W], 输入图像 (值域为-1到1)
bbox: list or tuple, [x1, y1, x2, y2], bbox坐标
width_scale_factor: float, 宽度放大倍数
height_scale_factor: float, 高度放大倍数
Returns:
tuple: (cropped_image, new_bbox)
- cropped_image: tensor, shape: [C, new_h, new_w], 切割后的图像
- new_bbox: list, [new_x1, new_y1, new_x2, new_y2], 调整后的bbox坐标
"""
    # Get the original bbox coordinates
x1, y1, x2, y2 = bbox
    # Get the image size
_, img_h, img_w = img.shape
    # Compute the bbox center and its original size
center_x = (x1 + x2) / 2.0
center_y = (y1 + y2) / 2.0
original_w = x2 - x1
original_h = y2 - y1
    # Compute the enlarged size
new_w = original_w * width_scale_factor
new_h = original_h * height_scale_factor
    # Compute the enlarged bbox coordinates (centered on the bbox center)
new_x1 = center_x - new_w / 2.0
new_y1 = center_y - new_h / 2.0
new_x2 = center_x + new_w / 2.0
new_y2 = center_y + new_h / 2.0
    # Clamp the bbox to the image boundary while keeping a minimum size
new_x1 = max(0, new_x1)
new_y1 = max(0, new_y1)
new_x2 = min(img_w, new_x2)
new_y2 = min(img_h, new_y2)
    # Make sure the cropped region is at least 1 pixel in each dimension
if new_x2 <= new_x1:
        # If the width is zero or negative, fall back to the smallest usable width
if center_x < img_w / 2:
new_x1 = max(0, int(center_x) - 1)
new_x2 = min(img_w, new_x1 + max(1, int(original_w)))
else:
new_x2 = min(img_w, int(center_x) + 1)
new_x1 = max(0, new_x2 - max(1, int(original_w)))
if new_y2 <= new_y1:
        # If the height is zero or negative, fall back to the smallest usable height
if center_y < img_h / 2:
new_y1 = max(0, int(center_y) - 1)
new_y2 = min(img_h, new_y1 + max(1, int(original_h)))
else:
new_y2 = min(img_h, int(center_y) + 1)
new_y1 = max(0, new_y2 - max(1, int(original_h)))
    # Convert to integer coordinates
new_x1, new_y1, new_x2, new_y2 = int(new_x1), int(new_y1), int(new_x2), int(new_y2)
    # Final check to make sure the coordinates are valid
assert new_x2 > new_x1 and new_y2 > new_y1, f"Invalid bbox after adjustment: [{new_x1}, {new_y1}, {new_x2}, {new_y2}]"
    # Crop the enlarged region from the original image
cropped_image = img[:, new_y1:new_y2, new_x1:new_x2]
return cropped_image, [new_x1, new_y1, new_x2, new_y2]
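# Minimal usage sketch (hypothetical image and bbox, following the [C, H, W] / [x1, y1, x2, y2]
# conventions documented above):
#   img = torch.rand(3, 720, 1280) * 2 - 1
#   crop, new_bbox = expand_bbox_and_crop_image(img, [500, 200, 700, 400],
#                                               width_scale_factor=1.2, height_scale_factor=1.4)
#   # The 200x200 box grows to 240x280 around its center, is clamped to the image,
#   # and the returned crop has shape [3, 280, 240].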
def gen_smooth_transition_mask_for_dit(face_mask, lat_h, lat_w, F, device, mask_dtype, target_translate=(0, 0), target_scale=1.0):
"""
Generate smooth transition mask based on face_mask and latent shape for DIT mask
First frame is all white (all 1s), subsequent frames gradually transition from original position to target position and scale
Args:
face_mask: tensor, shape: [H, W]
lat_h: int, latent height
lat_w: int, latent width
F: int, number of frames in original video
device: torch.device, device to create tensors on
mask_dtype: torch.dtype, dtype for mask tensors
        target_translate: tuple, (x, y) target translation; only the second element is used, as the horizontal shift in latent pixels
target_scale: float, target scale ratio
Returns:
        tensor: shape: [4, (F + 3) // 4, lat_h, lat_w], mask for DIT
"""
# Resize face_mask to latent size
face_mask_resized = torch.nn.functional.interpolate(
face_mask.unsqueeze(0).unsqueeze(0), # [1, 1, H, W]
size=(lat_h, lat_w),
mode='bilinear',
align_corners=False
).squeeze(0).squeeze(0) # [lat_h, lat_w]
# Create mask, first frame all white (all 1s), remaining frames gradually transition
msk = torch.zeros(1, F, lat_h, lat_w, device=device, dtype=mask_dtype)
msk[:, 0:1] = 1.0 # First frame all white
if F > 1:
# Generate different transformation parameters for each frame to achieve smooth transition
for frame_idx in range(1, F):
# Calculate transition progress for current frame (0 to 1)
progress = (frame_idx - 1) / (F - 2) if F > 2 else 1.0
            # Linear transition: progress is used directly so changes are uniform across frames
# Translation and scale for current frame (only horizontal translation allowed)
current_translate = (
0, # Vertical direction always 0, no vertical movement allowed
int(target_translate[1] * progress) # Only use horizontal translation
)
current_scale = 1.0 + (target_scale - 1.0) * progress
# Generate mask for current frame
if current_scale != 1.0:
# Calculate scaled size
scaled_h = int(lat_h * current_scale)
scaled_w = int(lat_w * current_scale)
# Scale mask
scaled_mask = torch.nn.functional.interpolate(
face_mask_resized.unsqueeze(0).unsqueeze(0), # [1, 1, lat_h, lat_w]
size=(scaled_h, scaled_w),
mode='bilinear',
align_corners=False
).squeeze(0).squeeze(0) # [scaled_h, scaled_w]
# Create zero mask of target size
transformed_mask = torch.zeros(lat_h, lat_w, device=device, dtype=mask_dtype)
# Calculate placement position (centered)
start_h = max(0, (lat_h - scaled_h) // 2)
start_w = max(0, (lat_w - scaled_w) // 2)
end_h = min(lat_h, start_h + scaled_h)
end_w = min(lat_w, start_w + scaled_w)
# Calculate crop range in scaled_mask
src_start_h = max(0, (scaled_h - lat_h) // 2)
src_start_w = max(0, (scaled_w - lat_w) // 2)
src_end_h = src_start_h + (end_h - start_h)
src_end_w = src_start_w + (end_w - start_w)
# Place scaled mask to target position
transformed_mask[start_h:end_h, start_w:end_w] = scaled_mask[src_start_h:src_end_h, src_start_w:src_end_w]
else:
transformed_mask = face_mask_resized.clone().to(dtype=mask_dtype)
# Apply horizontal translation, stop when touching boundary
translate_w = current_translate[1] # Only take horizontal translation
if translate_w != 0:
# Find horizontal boundaries of mask
mask_indices = torch.nonzero(transformed_mask > 0.5)
if mask_indices.numel() > 0:
mask_min_w = mask_indices[:, 1].min().item()
mask_max_w = mask_indices[:, 1].max().item()
# Calculate actual available horizontal translation amount
if translate_w < 0:
# When moving left, check left boundary
max_translate_w = -mask_min_w
actual_translate_w = max(translate_w, max_translate_w)
else:
# When moving right, check right boundary
max_translate_w = lat_w - 1 - mask_max_w
actual_translate_w = min(translate_w, max_translate_w)
# If there is valid translation amount, execute translation
if actual_translate_w != 0:
# Use torch.roll for horizontal translation, but ensure not exceeding boundary
if abs(actual_translate_w) <= min(mask_min_w, lat_w - 1 - mask_max_w):
# Only use roll within safe range
transformed_mask = torch.roll(transformed_mask, shifts=actual_translate_w, dims=1)
else:
# Manually copy to avoid wrapping
new_mask = torch.zeros_like(transformed_mask, dtype=mask_dtype)
if actual_translate_w > 0:
# Move right
new_mask[:, actual_translate_w:] = transformed_mask[:, :-actual_translate_w]
else:
# Move left
new_mask[:, :actual_translate_w] = transformed_mask[:, -actual_translate_w:]
transformed_mask = new_mask
# Assign mask for current frame
msk[:, frame_idx:frame_idx+1] = transformed_mask.unsqueeze(0).unsqueeze(0)
# Reference encode_image_vae processing method, convert mask to format required by DIT
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0] # shape: [4, F, lat_h, lat_w]
return msk
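# Minimal usage sketch (hypothetical 60x104 latent grid for an 81-frame clip; dtype and device
# follow the conventions used elsewhere in this file):
#   msk = gen_smooth_transition_mask_for_dit(face_mask, lat_h=60, lat_w=104, F=81,
#                                            device=torch.device("cuda"), mask_dtype=torch.bfloat16,
#                                            target_translate=(0, 10), target_scale=1.2)
#   # msk.shape == (4, 21, 60, 104): the first frame is repeated 4x, giving F + 3 = 84 frames,
#   # which are then folded into groups of 4, i.e. (F + 3) // 4 = 21 latent frames.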