| |
| import math |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor |
|
|
|
|
def get_sample_indices(original_fps,
                       total_frames,
                       target_fps,
                       num_sample,
                       fixed_start=None):
    """Choose source-frame indices that resample a clip to ``target_fps``.

    ``num_sample`` time points spaced ``1 / target_fps`` seconds apart are laid
    out from a start frame, and each one is mapped to the nearest frame of a
    sequence recorded at ``original_fps``.

    Args:
        original_fps: Frame rate of the source sequence.
        total_frames: Number of frames in the source sequence.
        target_fps: Frame rate the samples should correspond to.
        num_sample: Number of indices to return.
        fixed_start: Source frame to start from. When ``None`` or negative, a
            start frame is drawn uniformly at random (via ``np.random``) such
            that the whole window fits inside the source.

    Returns:
        A 1-D integer ``numpy.ndarray`` of length ``num_sample`` whose values
        lie in ``[0, total_frames - 1]``.

    Raises:
        ValueError: If ``num_sample / target_fps`` seconds is longer than the
            source, or no valid random start position exists.

    Note:
        With an explicit ``fixed_start`` the window is not checked against the
        end of the source; indices past the end are clipped to the last frame,
        so the tail of the result may repeat ``total_frames - 1``.
    """
    required_duration = num_sample / target_fps
    if required_duration > total_frames / original_fps:
        raise ValueError("required_duration must be less than video length")

    if fixed_start is not None and fixed_start >= 0:
        start_frame = fixed_start
    else:
        # Only the random path needs to know how many source frames the
        # window spans, in order to bound the start position.
        required_origin_frames = int(np.ceil(required_duration * original_fps))
        max_start = total_frames - required_origin_frames
        if max_start < 0:
            raise ValueError("video length is too short")
        start_frame = np.random.randint(0, max_start + 1)

    start_time = start_frame / original_fps
    end_time = start_time + required_duration
    # endpoint=False keeps the spacing at exactly 1 / target_fps.
    time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)

    frame_indices = np.round(time_points * original_fps).astype(int)
    return np.clip(frame_indices, 0, total_frames - 1)
|
|
|
|
def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """Linearly resample a feature sequence along its time axis.

    Args:
        features: Tensor laid out as ``[batch, T, channels]`` (the original
            note gives ``[1, T, 512]``).
        input_fps: Frame rate of ``features`` (audio feature rate).
        output_fps: Frame rate to resample to (video rate).
        output_len: Number of output time steps. When ``None`` it is derived
            as ``int(T / input_fps * output_fps)``.

    Returns:
        Tensor of shape ``[batch, output_len, channels]``.
    """
    # F.interpolate resamples the last axis, so move time to the end first.
    channels_first = features.transpose(1, 2)
    duration = channels_first.shape[2] / float(input_fps)
    target_len = int(duration * output_fps) if output_len is None else output_len
    resampled = F.interpolate(
        channels_first,
        size=target_len,
        mode='linear',
        align_corners=True,
    )
    return resampled.transpose(1, 2)
|
|
|
|
class AudioEncoder:
    """Wav2Vec2 model holder plus helpers that bucket audio features per frame.

    The bucketing methods take a feature tensor of shape
    ``[num_layers, audio_frame_num, audio_dim]`` and, for every sampled frame,
    gather a small window of neighbouring feature frames.
    """

    def __init__(self, device='cpu', model_id="facebook/wav2vec2-base-960h"):
        """Load the pretrained processor and CTC model.

        Args:
            device: Device the model is moved to.
            model_id: Identifier passed to ``from_pretrained``.
        """
        self.processor = Wav2Vec2Processor.from_pretrained(model_id)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)

        # Frame rate that get_audio_embed_bucket_fps treats as the rate of the
        # incoming `audio_embed` sequence.
        self.video_rate = 30

    @staticmethod
    def _gather_frame_windows(audio_embed, center_indices, window_stride, m):
        """Stack a ``2 * m + 1`` frame window around every center index.

        Window positions are ``center + k * window_stride`` for ``k`` in
        ``[-m, m]``, clamped to the valid frame range. A center at or past the
        end of ``audio_embed`` yields an all-zero entry instead.

        Returns:
            ``[len(center_indices), num_layers, audio_dim * (2 * m + 1)]`` when
            ``audio_embed`` has more than one layer, otherwise
            ``[len(center_indices), audio_dim * (2 * m + 1)]``.
        """
        num_layers, audio_frame_num, audio_dim = audio_embed.shape
        keep_layers = num_layers > 1
        last_frame = audio_frame_num - 1

        window_width = audio_dim * (2 * m + 1)
        pad_shape = [num_layers, window_width] if keep_layers else [window_width]
        # new_zeros inherits dtype and device, so padded buckets keep the
        # input precision instead of being promoted to float32.
        padding = audio_embed.new_zeros(pad_shape)

        windows = []
        for center in center_indices:
            center = int(center)
            if center >= audio_frame_num:
                windows.append(padding)
                continue
            picked = [
                min(max(center + k * window_stride, 0), last_frame)
                for k in range(-m, m + 1)
            ]
            if keep_layers:
                window = audio_embed[:, picked].flatten(start_dim=-2, end_dim=-1)
            else:
                window = audio_embed[0][picked].flatten()
            windows.append(window)
        return torch.stack(windows, dim=0)

    def get_audio_embed_bucket(self,
                               audio_embed,
                               stride=2,
                               batch_frames=12,
                               m=2):
        """Bucket features at a fixed stride, padded to whole batches.

        Args:
            audio_embed: Tensor of shape
                ``[num_layers, audio_frame_num, audio_dim]``.
            stride: Spacing, in feature frames, between consecutive centers.
            batch_frames: Number of centers per batch.
            m: Neighbours taken on each side of a center. Neighbours are
                spaced 2 feature frames apart regardless of ``stride``.

        Returns:
            ``(batch_audio_eb, min_batch_num)``. The first item holds
            ``min_batch_num * batch_frames`` entries, zero-filled past the end
            of the audio. The second is the number of batches.
        """
        _, audio_frame_num, _ = audio_embed.shape

        min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
        bucket_num = min_batch_num * batch_frames
        centers = [stride * i for i in range(bucket_num)]

        batch_audio_eb = self._gather_frame_windows(
            audio_embed, centers, window_stride=2, m=m)
        return batch_audio_eb, min_batch_num

    def get_audio_embed_bucket_fps(self,
                                   audio_embed,
                                   fps=16,
                                   batch_frames=81,
                                   m=0):
        """Bucket features resampled from ``self.video_rate`` to ``fps``.

        Args:
            audio_embed: Tensor of shape
                ``[num_layers, audio_frame_num, audio_dim]``, treated as being
                sampled at ``self.video_rate`` frames per second.
            fps: Target frame rate of the buckets.
            batch_frames: Number of target frames per batch.
            m: Neighbours taken on each side of a center.

        Returns:
            ``(batch_audio_eb, min_batch_num)``. The first item holds
            ``min_batch_num * batch_frames`` entries, zero-filled past the end
            of the audio. The second is the number of batches.
        """
        _, audio_frame_num, _ = audio_embed.shape

        scale = self.video_rate / fps
        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
        bucket_num = min_batch_num * batch_frames

        # Length the audio would need in order to cover every bucket; frames
        # beyond the real audio come back as zero padding.
        padded_frame_num = math.ceil(bucket_num / fps * self.video_rate)
        centers = get_sample_indices(
            original_fps=self.video_rate,
            total_frames=padded_frame_num,
            target_fps=fps,
            num_sample=bucket_num,
            fixed_start=0)

        # int(video_rate / fps) is 0 when fps > video_rate, and a zero step is
        # invalid for window construction, so never go below 1.
        window_stride = max(1, int(self.video_rate / fps))

        batch_audio_eb = self._gather_frame_windows(
            audio_embed, centers, window_stride=window_stride, m=m)
        return batch_audio_eb, min_batch_num
|
|