import torch
import numpy as np
import cv2
import os
import librosa
import math


def calculate_frame_num_from_audio(audio_paths, fps=24, mode="pad"):
    """
    Calculate the corresponding frame number from the audio file length(s).

    Args:
        audio_paths (list): List of audio file paths
        fps (int): Video frame rate, default 24 fps
        mode (str): Audio processing mode, "pad" or "concat".
            In "pad" mode the maximum duration is used.
            In "concat" mode the sum of all durations is used.

    Returns:
        int: Calculated frame number, rounded down to the nearest 4n+1.

    Raises:
        ValueError: If no audio path is given, none of the listed files exist,
            or a file cannot be read.
    """
    if not audio_paths:
        raise ValueError("No audio files, cannot determine frame number")

    durations = []
    for audio_path in audio_paths:
        if audio_path and os.path.exists(audio_path):
            try:
                # Note: newer librosa versions rename the `filename` keyword to `path`.
                duration = librosa.get_duration(filename=audio_path)
                durations.append(duration)
                print(f"audio file {audio_path} duration: {duration:.2f} seconds")
            except Exception as e:
                raise ValueError(f"Failed to read audio file {audio_path}: {e}")

    # "concat" plays the clips back to back, so durations add up;
    # "pad" plays them in parallel, so the longest clip dictates the length.
    total_duration = sum(durations) if mode == "concat" else max(durations, default=0)

    if total_duration <= 0:
        raise ValueError("No readable audio files, cannot determine frame number")

    frame_num = int(math.ceil(total_duration * fps))
    # Round down to the nearest 4n+1, as required by the video latent layout.
    frame_num = ((frame_num - 1) // 4) * 4 + 1
    print(f"Calculated frame number ({mode} mode): {frame_num} based on "
          f"{'total' if mode == 'concat' else 'max'} audio duration {total_duration:.2f}s "
          f"and frame rate {fps}fps")
    return frame_num
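

# --------------------------------------------------------------------------- #
# Illustrative sketch (not called anywhere in this module): how a duration is
# mapped to a 4n+1 frame count, mirroring calculate_frame_num_from_audio. The
# numbers below are made-up sample values, not taken from any real audio file.
# --------------------------------------------------------------------------- #
def _example_frame_count_rounding():
    duration_s = 3.5   # hypothetical clip length in seconds
    fps = 24
    raw_frames = int(math.ceil(duration_s * fps))    # 84
    aligned = ((raw_frames - 1) // 4) * 4 + 1        # rounded down to 4n+1 -> 81
    assert aligned == 81
    return aligned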


def count_parameters(model):
    """Return the total number of parameters of `model`, in millions."""
    total_params = sum(p.numel() for p in model.parameters())
    total_params_in_millions = total_params / 1e6
    return total_params_in_millions
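

# --------------------------------------------------------------------------- #
# Illustrative sketch (not called anywhere in this module): counting the
# parameters of a tiny throwaway layer. The layer is purely a stand-in.
# --------------------------------------------------------------------------- #
def _example_count_parameters():
    import torch.nn as nn
    layer = nn.Linear(10, 10)   # 10 * 10 weights + 10 biases = 110 parameters
    assert abs(count_parameters(layer) - 110 / 1e6) < 1e-12
    return count_parameters(layer)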


def create_null_audio_ref_features(audio_ref_features):
    """
    Build a "null" copy of `audio_ref_features` in which the reference face
    features are kept (as detached clones) while the audio features are
    replaced by zero tensors of the same shape.
    """
    null_features = {}

    # Keep the reference face features unchanged (detached clones).
    if 'ref_face_list' in audio_ref_features and audio_ref_features['ref_face_list']:
        null_ref_face_list = []
        for ref_face in audio_ref_features['ref_face_list']:
            if ref_face is not None:
                null_ref_face_list.append(ref_face.clone().detach())
            else:
                null_ref_face_list.append(None)
        null_features['ref_face_list'] = null_ref_face_list
    else:
        null_features['ref_face_list'] = []

    # Zero out the audio features so the audio condition is effectively dropped.
    if 'audio_list' in audio_ref_features and audio_ref_features['audio_list']:
        null_audio_list = []
        for audio in audio_ref_features['audio_list']:
            if audio is not None:
                null_audio_list.append(torch.zeros_like(audio))
            else:
                null_audio_list.append(None)
        null_features['audio_list'] = null_audio_list
    else:
        null_features['audio_list'] = []

    return null_features
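

# --------------------------------------------------------------------------- #
# Illustrative sketch (not called anywhere in this module): the "null" features
# keep the reference-face tensors but zero the audio tensors. Shapes are
# arbitrary toy values.
# --------------------------------------------------------------------------- #
def _example_create_null_audio_ref_features():
    features = {
        'ref_face_list': [torch.randn(4, 8)],
        'audio_list': [torch.randn(16, 12), None],
    }
    null = create_null_audio_ref_features(features)
    assert torch.equal(null['ref_face_list'][0], features['ref_face_list'][0])
    assert torch.count_nonzero(null['audio_list'][0]) == 0
    assert null['audio_list'][1] is None
    return null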


def process_audio_features(
    audio_paths=None,
    audio=None,
    mode="pad",
    F=None,
    frame_num=None,
    task_key=None,
    fps=None,
    wav2vec_model=None,
    vocal_separator_model=None,
    audio_output_dir=None,
    device=None,
    use_half=False,
    half_dtype=None,
    preprocess_audio=None,
    resample_audio=None,
    trim_to_4s=False,
):
    """
    Process audio files and extract audio features.

    Args:
        audio_paths (list): List of audio file paths (new format, supports multiple audio files)
        audio (str): Single audio file path (legacy format)
        mode (str): Audio processing mode, "pad" or "concat"
        F (int): Target frame number (already calculated outside)
        frame_num (int): Frame number for cache file naming (legacy)
        task_key (str): Task key for cache file naming
        fps (int): Frames per second
        wav2vec_model (str): Path to wav2vec model
        vocal_separator_model (str): Path to vocal separator model
        audio_output_dir (str): Directory for audio output
        device: Device to use for processing
        use_half (bool): Whether to use half precision
        half_dtype: Half precision dtype (e.g. torch.float16 or torch.bfloat16)
        preprocess_audio: Function to preprocess audio
        resample_audio: Function to resample audio
        trim_to_4s (bool): If True, cap the audio features at 4 seconds (97 frames)

    Returns:
        list: List of audio feature tensors
    """
    from .audio_utils import preprocess_audio as _preprocess_audio, resample_audio as _resample_audio

    # Fall back to the default implementations if no callables were injected.
    if preprocess_audio is None:
        preprocess_audio = _preprocess_audio
    if resample_audio is None:
        resample_audio = _resample_audio

    audio_feat_list = []

    if audio_paths and len(audio_paths) > 0:
        print(f"Processing {len(audio_paths)} audio files in {mode} mode: {audio_paths}")
        cache_dir = os.path.join(audio_output_dir, "audio_preprocess")
        os.makedirs(cache_dir, exist_ok=True)

        if mode == "concat":
            # First pass: extract features for every clip at its full length so
            # the actual per-clip frame counts are known before placement.
            audio_lengths = []
            raw_audio_feat_list = []

            for i, audio_path in enumerate(audio_paths):
                if audio_path and os.path.exists(audio_path):
                    print(f"Processing audio {i} (first pass): {audio_path}")
                    target_resampled_audio_path = os.path.join(cache_dir, f"{os.path.basename(audio_path).split('.')[0]}-{task_key}_16k_concat.wav")
                    if not os.path.exists(target_resampled_audio_path):
                        resample_audio(
                            audio_path,
                            target_resampled_audio_path,
                        )
                    with torch.no_grad():
                        # num_generated_frames_per_clip=-1 keeps the full clip length.
                        audio_emb, audio_length = preprocess_audio(
                            wav_path=target_resampled_audio_path,
                            num_generated_frames_per_clip=-1,
                            fps=fps,
                            wav2vec_model=wav2vec_model,
                            vocal_separator_model=vocal_separator_model,
                            cache_dir=cache_dir,
                            device=device,
                        )

                    audio_dtype = half_dtype if use_half else torch.bfloat16
                    audio_emb = audio_emb.to(device, dtype=audio_dtype)

                    actual_frame_length = audio_emb.shape[0]
                    audio_lengths.append(actual_frame_length)
                    raw_audio_feat_list.append(audio_emb)
                    print(f"Audio {i} actual length: {actual_frame_length} frames, shape: {audio_emb.shape}")
                else:
                    print(f"Warning: Audio {i} path is empty or file not found: {audio_path}")
                    audio_lengths.append(0)
                    raw_audio_feat_list.append(None)

            total_length = sum(audio_lengths)
            print(f"Total audio length in concat mode (from processed frames): {total_length} frames")

            if trim_to_4s:
                # Fast mode: cap the concatenated audio at 4 seconds (97 = 4 * 24 + 1 frames).
                max_frames_4s = 97
                if total_length > max_frames_4s:
                    print(f"Fast mode: Trimming audio from {total_length} frames to {max_frames_4s} frames (4 seconds)")
                    # Shrink every clip proportionally so they still fit in order.
                    scale_factor = max_frames_4s / total_length
                    cumulative_length = 0
                    for i, audio_len in enumerate(audio_lengths):
                        if audio_len > 0:
                            new_audio_len = int(audio_len * scale_factor)
                            # Never exceed the remaining frame budget.
                            remaining_space = max_frames_4s - cumulative_length
                            new_audio_len = min(new_audio_len, remaining_space)
                            audio_lengths[i] = new_audio_len
                            if raw_audio_feat_list[i] is not None:
                                raw_audio_feat_list[i] = raw_audio_feat_list[i][:new_audio_len]
                            cumulative_length += new_audio_len
                    total_length = sum(audio_lengths)
                    print(f"After trimming: total_length = {total_length} frames")

            # Align the total length to the 4n+1 layout expected by the video latents.
            total_length = ((total_length - 1) // 4) * 4 + 1
            print(f"Adjusted total length to 4n+1 format: {total_length} frames")

            # F was pre-computed outside; keep using it so all tensors stay consistent.
            if total_length > F:
                print(f"Warning: Actual processed frames ({total_length}) > pre-calculated F ({F}). Using F={F} to maintain consistency with other tensors.")
            elif total_length < F:
                print(f"Info: Actual processed frames ({total_length}) < pre-calculated F ({F}). Using F={F}.")
            else:
                print(f"Info: Actual processed frames ({total_length}) matches pre-calculated F={F}.")

            # Second pass: place each clip at its offset inside an F-frame zero tensor.
            cumulative_length = 0
            reference_feat_shape = None

            # Use the first valid clip to determine the per-frame feature shape.
            for raw_audio_feat in raw_audio_feat_list:
                if raw_audio_feat is not None:
                    reference_feat_shape = raw_audio_feat.shape[1:]
                    break

            if reference_feat_shape is None:
                raise ValueError("No valid audio files found in concat mode")

            for i, (raw_audio_feat, audio_len) in enumerate(zip(raw_audio_feat_list, audio_lengths)):
                if raw_audio_feat is not None and audio_len > 0:
                    padded_audio_feat = torch.zeros(
                        (F,) + reference_feat_shape,
                        dtype=raw_audio_feat.dtype,
                        device=raw_audio_feat.device
                    )

                    # Clamp the placement window to F frames.
                    end_pos = min(cumulative_length + audio_len, F)
                    actual_audio_len = end_pos - cumulative_length
                    padded_audio_feat[cumulative_length:end_pos] = raw_audio_feat[:actual_audio_len]

                    audio_feat_list.append(padded_audio_feat)
                    print(f"Audio {i} padded: placed at frames [{cumulative_length}:{end_pos}], shape: {padded_audio_feat.shape}")
                    cumulative_length += audio_len
                else:
                    # Missing clip: contribute an all-zero feature tensor.
                    zero_audio_feat = torch.zeros(
                        (F,) + reference_feat_shape,
                        dtype=torch.bfloat16 if not use_half else half_dtype,
                        device=device
                    )
                    audio_feat_list.append(zero_audio_feat)
                    print(f"Audio {i} is missing, created zero features with shape: {zero_audio_feat.shape}")
        else:
            # "pad" mode: every clip is extracted directly at the target length F.
            for i, audio_path in enumerate(audio_paths):
                if audio_path and os.path.exists(audio_path):
                    print(f"Processing audio {i}: {audio_path}")
                    target_resampled_audio_path = os.path.join(cache_dir, f"{os.path.basename(audio_path).split('.')[0]}-{task_key}_16k_{F}.wav")
                    if not os.path.exists(target_resampled_audio_path):
                        resample_audio(
                            audio_path,
                            target_resampled_audio_path,
                        )
                    with torch.no_grad():
                        print(f"wav2vec_model: {wav2vec_model}")
                        print(f"cache_dir: {cache_dir}")

                        target_frames = F
                        if trim_to_4s:
                            # Fast mode: cap at 4 seconds (97 frames).
                            max_frames_4s = 97
                            target_frames = min(F, max_frames_4s)
                            if F > max_frames_4s:
                                print(f"Fast mode: Trimming audio {i} from {F} frames to {max_frames_4s} frames (4 seconds)")

                        audio_emb, audio_length = preprocess_audio(
                            wav_path=target_resampled_audio_path,
                            num_generated_frames_per_clip=target_frames,
                            fps=fps,
                            wav2vec_model=wav2vec_model,
                            vocal_separator_model=vocal_separator_model,
                            cache_dir=cache_dir,
                            device=device,
                        )

                    audio_dtype = half_dtype if use_half else torch.bfloat16
                    audio_emb = audio_emb.to(device, dtype=audio_dtype)

                    # Truncate to exactly F frames.
                    audio_feat = audio_emb[:F]
                    audio_feat_list.append(audio_feat)
                    print(f"Audio {i} processed, shape: {audio_feat.shape}")
                else:
                    print(f"Warning: Audio {i} path is empty or file not found: {audio_path}")

            if len(audio_feat_list) > 0:
                # Append an all-zero feature tensor as the last entry.
                zero_audio_feat = torch.zeros_like(audio_feat_list[0])
                audio_feat_list.append(zero_audio_feat)
            else:
                print("Error: No valid audio files found, cannot create zero features")
    else:
        # Legacy single-audio path.
        if audio is not None:
            print(f"Processing single audio (legacy format): {audio}")
            cache_dir = os.path.join(audio_output_dir, "audio_preprocess")
            os.makedirs(cache_dir, exist_ok=True)

            target_resampled_audio_path = os.path.join(cache_dir, f"{os.path.basename(audio).split('.')[0]}-16k.wav")
            if not os.path.exists(target_resampled_audio_path):
                resample_audio(
                    audio,
                    target_resampled_audio_path,
                )
            with torch.no_grad():
                target_frames = F
                if trim_to_4s:
                    # Fast mode: cap at 4 seconds (97 frames).
                    max_frames_4s = 97
                    target_frames = min(F, max_frames_4s)
                    if F > max_frames_4s:
                        print(f"Fast mode: Trimming single audio from {F} frames to {max_frames_4s} frames (4 seconds)")

                # Always feed the resampled file, whether it was just written or already cached,
                # matching the behavior of the multi-audio branches above.
                audio_emb, audio_length = preprocess_audio(
                    wav_path=target_resampled_audio_path,
                    num_generated_frames_per_clip=target_frames,
                    fps=fps,
                    wav2vec_model=wav2vec_model,
                    vocal_separator_model=vocal_separator_model,
                    cache_dir=cache_dir,
                    device=device,
                )

            audio_dtype = half_dtype if use_half else torch.bfloat16
            audio_emb = audio_emb.to(device, dtype=audio_dtype)

            # Truncate to exactly F frames.
            audio_feat = audio_emb[:F]
            audio_feat_list.append(audio_feat)
            print(f"Single audio processed, shape: {audio_feat.shape}")
        else:
            print("No audio files provided")

    return audio_feat_list
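

# --------------------------------------------------------------------------- #
# Illustrative call sketch (not executed on import): how process_audio_features
# is typically wired up for two clips in "concat" mode. Every path and model
# location below is a placeholder, not a file shipped with this module, and F
# would normally come from calculate_frame_num_from_audio.
# --------------------------------------------------------------------------- #
def _example_process_audio_features():
    clips = ["/path/to/speaker_a.wav", "/path/to/speaker_b.wav"]   # placeholders
    frame_count = 97   # e.g. calculate_frame_num_from_audio(clips, fps=24, mode="concat")
    return process_audio_features(
        audio_paths=clips,
        mode="concat",
        F=frame_count,
        task_key="demo",
        fps=24,
        wav2vec_model="/path/to/wav2vec_checkpoint",             # placeholder
        vocal_separator_model="/path/to/vocal_separator_model",  # placeholder
        audio_output_dir="./outputs",
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )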


# Note: torch.cuda.amp.autocast is the legacy spelling; newer PyTorch prefers
# torch.amp.autocast("cuda", ...). The decorator forces float32 computation here.
@torch.cuda.amp.autocast(dtype=torch.float32)
def optimized_scale(positive_flat, negative_flat):
    """
    Compute, per sample, the scale of the projection of `positive_flat` onto
    `negative_flat` along the last dimension:
    scale = <positive, negative> / ||negative||^2 (evaluated in float32).
    """
    positive_norm = torch.norm(positive_flat, dim=-1, keepdim=True)
    negative_norm = torch.norm(negative_flat, dim=-1, keepdim=True)

    # Cosine similarity between the two flattened tensors.
    cosine_sim = torch.sum(positive_flat * negative_flat, dim=-1, keepdim=True) / (positive_norm * negative_norm + 1e-8)

    # (||p|| / ||n||) * cos(p, n) == <p, n> / ||n||^2 (up to the 1e-8 stabilizers).
    scale = (positive_norm / (negative_norm + 1e-8)) * cosine_sim

    return scale
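

# --------------------------------------------------------------------------- #
# Illustrative sketch (not called anywhere in this module): optimized_scale
# returns the least-squares projection coefficient <p, n> / ||n||^2, verified
# here against a direct computation on random toy tensors.
# --------------------------------------------------------------------------- #
def _example_optimized_scale():
    p = torch.randn(2, 4096)
    n = torch.randn(2, 4096)
    scale = optimized_scale(p, n)
    reference = (p * n).sum(dim=-1, keepdim=True) / n.pow(2).sum(dim=-1, keepdim=True)
    assert torch.allclose(scale, reference, atol=1e-4)
    return scale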


def expand_face_mask_flexible(face_mask, width_scale_factor, height_scale_factor):
    """
    Independently expand the region of `face_mask` whose values are 1 by the
    given width and height factors.

    Args:
        face_mask: tensor, shape [H, W], original face mask
        width_scale_factor: float, horizontal expansion factor
        height_scale_factor: float, vertical expansion factor

    Returns:
        tensor: shape [H, W], expanded face mask
    """
    if width_scale_factor == 1.0 and height_scale_factor == 1.0:
        return face_mask

    # Locate the bounding box of the active (value > 0.5) region.
    mask_indices = torch.nonzero(face_mask > 0.5)
    if mask_indices.numel() == 0:
        return face_mask

    min_h, min_w = mask_indices.min(dim=0)[0]
    max_h, max_w = mask_indices.max(dim=0)[0]

    # Center of the current bounding box.
    center_h = (min_h + max_h) / 2.0
    center_w = (min_w + max_w) / 2.0

    # Current box size.
    current_h = max_h - min_h + 1
    current_w = max_w - min_w + 1

    # Scaled box size.
    new_h = int(current_h * height_scale_factor)
    new_w = int(current_w * width_scale_factor)

    # New box corners, centered on the original box.
    new_min_h = int(center_h - new_h / 2.0)
    new_max_h = int(center_h + new_h / 2.0)
    new_min_w = int(center_w - new_w / 2.0)
    new_max_w = int(center_w + new_w / 2.0)

    # Clamp to the image bounds.
    H, W = face_mask.shape
    new_min_h = max(0, new_min_h)
    new_max_h = min(H - 1, new_max_h)
    new_min_w = max(0, new_min_w)
    new_max_w = min(W - 1, new_max_w)

    expanded_mask = torch.zeros_like(face_mask)

    if new_max_h > new_min_h and new_max_w > new_min_w:
        # Resample the original mask content into the enlarged box.
        original_content = face_mask[min_h:max_h+1, min_w:max_w+1]

        target_h = new_max_h - new_min_h + 1
        target_w = new_max_w - new_min_w + 1

        if target_h > 0 and target_w > 0:
            scaled_content = torch.nn.functional.interpolate(
                original_content.unsqueeze(0).unsqueeze(0),
                size=(target_h, target_w),
                mode='bilinear',
                align_corners=False
            ).squeeze(0).squeeze(0)

            expanded_mask[new_min_h:new_max_h+1, new_min_w:new_max_w+1] = scaled_content

    return expanded_mask
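

# --------------------------------------------------------------------------- #
# Illustrative sketch (not called anywhere in this module): expanding a toy
# rectangular mask by 1.5x in width and 2.0x in height. Sizes are arbitrary.
# --------------------------------------------------------------------------- #
def _example_expand_face_mask_flexible():
    mask = torch.zeros(64, 64)
    mask[24:40, 28:36] = 1.0   # 16 x 8 block of ones
    expanded = expand_face_mask_flexible(mask, width_scale_factor=1.5, height_scale_factor=2.0)
    # The active area grows roughly by the requested factors (clamped to bounds).
    assert expanded.shape == mask.shape
    assert (expanded > 0.5).sum() > (mask > 0.5).sum()
    return expanded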


def gen_inference_masks(masks, img_shape, num_frames=None):
    """
    Generate masks for inference in the same format used during training.
    Note: at inference time the masks cover the whole image, so the 50% cropping
    logic used during training is not needed. To match the training format,
    batch and frame dimensions are added: [H, W] -> [1, 1, F, H, W].

    Args:
        masks: list of tensors, face masks produced by the face detection model, each of shape [H, W]
        img_shape: tuple, image shape (H, W)
        num_frames: int, number of video frames

    Returns:
        dict: contains face_mask_list (and their width-wise concatenation); human_mask_list is set to None
    """
    H, W = img_shape
    F = num_frames if num_frames is not None else 1
    num_faces = len(masks)

    print(f"gen_inference_masks: processing {num_faces} faces, image size {H}x{W}, {F} frames")

    with torch.no_grad():
        face_mask_list = []

        # Broadcast each [H, W] mask across F frames (result: [1, 1, F, H, W]).
        for i, mask in enumerate(masks):
            face_mask_multi = mask.unsqueeze(0).unsqueeze(0).repeat(1, 1, F, 1, 1)
            face_mask_list.append(face_mask_multi)

        # Concatenate multiple faces along the width dimension.
        if num_faces > 1:
            face_mask_concat = torch.cat(face_mask_list, dim=4)
        else:
            face_mask_concat = face_mask_list[0]

    return {
        "face_mask_list": face_mask_list,
        "human_mask_list": None,
        "face_mask_concat": face_mask_concat,
        "num_faces": num_faces
    }
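

# --------------------------------------------------------------------------- #
# Illustrative sketch (not called anywhere in this module): two toy 64x64 face
# masks broadcast over 5 frames; the concatenated mask stacks faces along width.
# --------------------------------------------------------------------------- #
def _example_gen_inference_masks():
    masks = [torch.ones(64, 64), torch.zeros(64, 64)]
    out = gen_inference_masks(masks, img_shape=(64, 64), num_frames=5)
    assert out["face_mask_list"][0].shape == (1, 1, 5, 64, 64)
    assert out["face_mask_concat"].shape == (1, 1, 5, 64, 128)
    return out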


def expand_bbox_and_crop_image(img, bbox, width_scale_factor, height_scale_factor):
    """
    Enlarge a bbox by the given scale factors and safely crop the corresponding
    region from the image.

    Args:
        img: tensor, shape [C, H, W], input image (values in [-1, 1])
        bbox: list or tuple, [x1, y1, x2, y2] bbox coordinates
        width_scale_factor: float, horizontal enlargement factor
        height_scale_factor: float, vertical enlargement factor

    Returns:
        tuple: (cropped_image, new_bbox)
            - cropped_image: tensor, shape [C, new_h, new_w], cropped region
            - new_bbox: list, [new_x1, new_y1, new_x2, new_y2], adjusted bbox coordinates
    """
    x1, y1, x2, y2 = bbox

    _, img_h, img_w = img.shape

    # Center and size of the original bbox.
    center_x = (x1 + x2) / 2.0
    center_y = (y1 + y2) / 2.0
    original_w = x2 - x1
    original_h = y2 - y1

    # Enlarged size.
    new_w = original_w * width_scale_factor
    new_h = original_h * height_scale_factor

    # Enlarged bbox, centered on the original one.
    new_x1 = center_x - new_w / 2.0
    new_y1 = center_y - new_h / 2.0
    new_x2 = center_x + new_w / 2.0
    new_y2 = center_y + new_h / 2.0

    # Clamp to the image bounds.
    new_x1 = max(0, new_x1)
    new_y1 = max(0, new_y1)
    new_x2 = min(img_w, new_x2)
    new_y2 = min(img_h, new_y2)

    # Degenerate width: rebuild a minimal valid box around the center.
    if new_x2 <= new_x1:
        if center_x < img_w / 2:
            new_x1 = max(0, int(center_x) - 1)
            new_x2 = min(img_w, new_x1 + max(1, int(original_w)))
        else:
            new_x2 = min(img_w, int(center_x) + 1)
            new_x1 = max(0, new_x2 - max(1, int(original_w)))

    # Degenerate height: same treatment vertically.
    if new_y2 <= new_y1:
        if center_y < img_h / 2:
            new_y1 = max(0, int(center_y) - 1)
            new_y2 = min(img_h, new_y1 + max(1, int(original_h)))
        else:
            new_y2 = min(img_h, int(center_y) + 1)
            new_y1 = max(0, new_y2 - max(1, int(original_h)))

    new_x1, new_y1, new_x2, new_y2 = int(new_x1), int(new_y1), int(new_x2), int(new_y2)

    assert new_x2 > new_x1 and new_y2 > new_y1, f"Invalid bbox after adjustment: [{new_x1}, {new_y1}, {new_x2}, {new_y2}]"

    # Crop: tensor layout is [C, H, W], so index rows (y) first, then columns (x).
    cropped_image = img[:, new_y1:new_y2, new_x1:new_x2]

    return cropped_image, [new_x1, new_y1, new_x2, new_y2]
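

# --------------------------------------------------------------------------- #
# Illustrative sketch (not called anywhere in this module): enlarging a 40x40
# bbox by 1.5x on a toy 3x128x128 image. All numbers are arbitrary.
# --------------------------------------------------------------------------- #
def _example_expand_bbox_and_crop_image():
    img = torch.zeros(3, 128, 128)
    crop, new_bbox = expand_bbox_and_crop_image(
        img, bbox=[40, 40, 80, 80], width_scale_factor=1.5, height_scale_factor=1.5
    )
    # 40 * 1.5 = 60 pixels per side, re-centered on (60, 60) -> [30, 30, 90, 90].
    assert new_bbox == [30, 30, 90, 90]
    assert crop.shape == (3, 60, 60)
    return crop, new_bbox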


def gen_smooth_transition_mask_for_dit(face_mask, lat_h, lat_w, F, device, mask_dtype, target_translate=(0, 0), target_scale=1.0):
    """
    Generate a smooth transition mask for the DiT, based on face_mask and the latent shape.
    The first frame is all white (all 1s); subsequent frames gradually transition from the
    original position and scale to the target position and scale.

    Args:
        face_mask: tensor, shape [H, W]
        lat_h: int, latent height
        lat_w: int, latent width
        F: int, number of frames in the original video (expected to be of the form 4n+1)
        device: torch.device, device to create tensors on
        mask_dtype: torch.dtype, dtype for mask tensors
        target_translate: tuple, (x, y) target translation; only the second component is applied, as a horizontal shift
        target_scale: float, target scale ratio

    Returns:
        tensor: shape [4, (F + 3) // 4, lat_h, lat_w], mask for the DiT
    """
    # Resize the face mask to the latent resolution.
    face_mask_resized = torch.nn.functional.interpolate(
        face_mask.unsqueeze(0).unsqueeze(0),
        size=(lat_h, lat_w),
        mode='bilinear',
        align_corners=False
    ).squeeze(0).squeeze(0)

    # Frame 0 is all ones; the remaining frames are filled below.
    msk = torch.zeros(1, F, lat_h, lat_w, device=device, dtype=mask_dtype)
    msk[:, 0:1] = 1.0

    if F > 1:
        for frame_idx in range(1, F):
            # Interpolation progress in [0, 1] across frames 1..F-1.
            progress = (frame_idx - 1) / (F - 2) if F > 2 else 1.0

            # Only the second component of target_translate is applied, as a
            # width-axis (horizontal) shift.
            current_translate = (
                0,
                int(target_translate[1] * progress)
            )
            current_scale = 1.0 + (target_scale - 1.0) * progress

            if current_scale != 1.0:
                # Rescale the mask around the latent center.
                scaled_h = int(lat_h * current_scale)
                scaled_w = int(lat_w * current_scale)

                scaled_mask = torch.nn.functional.interpolate(
                    face_mask_resized.unsqueeze(0).unsqueeze(0),
                    size=(scaled_h, scaled_w),
                    mode='bilinear',
                    align_corners=False
                ).squeeze(0).squeeze(0)

                transformed_mask = torch.zeros(lat_h, lat_w, device=device, dtype=mask_dtype)

                # Destination window inside the latent canvas.
                start_h = max(0, (lat_h - scaled_h) // 2)
                start_w = max(0, (lat_w - scaled_w) // 2)
                end_h = min(lat_h, start_h + scaled_h)
                end_w = min(lat_w, start_w + scaled_w)

                # Matching source window inside the scaled mask.
                src_start_h = max(0, (scaled_h - lat_h) // 2)
                src_start_w = max(0, (scaled_w - lat_w) // 2)
                src_end_h = src_start_h + (end_h - start_h)
                src_end_w = src_start_w + (end_w - start_w)

                transformed_mask[start_h:end_h, start_w:end_w] = scaled_mask[src_start_h:src_end_h, src_start_w:src_end_w]
            else:
                transformed_mask = face_mask_resized.clone().to(dtype=mask_dtype)

            # Horizontal translation, clamped so the mask never leaves the canvas.
            translate_w = current_translate[1]
            if translate_w != 0:
                mask_indices = torch.nonzero(transformed_mask > 0.5)
                if mask_indices.numel() > 0:
                    mask_min_w = mask_indices[:, 1].min().item()
                    mask_max_w = mask_indices[:, 1].max().item()

                    if translate_w < 0:
                        # Shifting left: do not push the mask past the left edge.
                        max_translate_w = -mask_min_w
                        actual_translate_w = max(translate_w, max_translate_w)
                    else:
                        # Shifting right: do not push the mask past the right edge.
                        max_translate_w = lat_w - 1 - mask_max_w
                        actual_translate_w = min(translate_w, max_translate_w)

                    if actual_translate_w != 0:
                        # Use a cheap roll when nothing would wrap around; otherwise shift with zero fill.
                        if abs(actual_translate_w) <= min(mask_min_w, lat_w - 1 - mask_max_w):
                            transformed_mask = torch.roll(transformed_mask, shifts=actual_translate_w, dims=1)
                        else:
                            new_mask = torch.zeros_like(transformed_mask, dtype=mask_dtype)
                            if actual_translate_w > 0:
                                # Shift right, filling the left side with zeros.
                                new_mask[:, actual_translate_w:] = transformed_mask[:, :-actual_translate_w]
                            else:
                                # Shift left, filling the right side with zeros.
                                new_mask[:, :actual_translate_w] = transformed_mask[:, -actual_translate_w:]
                            transformed_mask = new_mask

            msk[:, frame_idx:frame_idx+1] = transformed_mask.unsqueeze(0).unsqueeze(0)

    # Repeat the first frame 4 times and fold the frame axis into groups of 4,
    # matching the latent temporal compression (requires F = 4n + 1).
    msk = torch.concat([
        torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
    ], dim=1)

    msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
    msk = msk.transpose(1, 2)[0]
    return msk
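

# --------------------------------------------------------------------------- #
# Illustrative sketch (not called anywhere in this module): building a
# transition mask for a toy 9-frame (4n+1) clip at a 16x16 latent resolution.
# The translation / scale targets are arbitrary.
# --------------------------------------------------------------------------- #
def _example_gen_smooth_transition_mask_for_dit():
    face_mask = torch.zeros(64, 64)
    face_mask[16:48, 16:48] = 1.0
    msk = gen_smooth_transition_mask_for_dit(
        face_mask,
        lat_h=16,
        lat_w=16,
        F=9,
        device=torch.device("cpu"),
        mask_dtype=torch.float32,
        target_translate=(0, 2),
        target_scale=1.2,
    )
    # 4 repeated copies of frame 0 plus 8 transition frames -> (9 + 3) // 4 groups.
    assert msk.shape == (4, 3, 16, 16)
    return msk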