File size: 1,252 Bytes
fefd7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch

def normalize(audio):
    """Standardize ``audio`` to zero mean and unit std over its last two dims.

    Mean and std are reduced over dimensions ``(-2, -1)`` with ``keepdim=True``.
    They therefore broadcast back against ``audio``, and each index of the
    leading dimensions is normalized independently. The std uses torch's
    default (unbiased) estimator. A small epsilon is added to it so that a
    constant input is mapped to zeros instead of dividing by zero.
    """
    reduce_dims = (-2, -1)
    centered = audio - audio.mean(dim=reduce_dims, keepdim=True)
    scale = audio.std(dim=reduce_dims, keepdim=True) + 1e-5  # epsilon for stability
    return centered / scale

def calculate_padding_mask(pad_frames, total_frames, sr, output_steps, process_seconds, device, B):
    """Build a boolean mask marking the padded tail of the output sequence.

    The audio is treated as a sequence of fixed-length chunks of
    ``process_seconds`` seconds, each producing ``output_steps`` output steps.
    A trailing partial chunk is dropped (the chunk count is truncated with
    ``int``).

    Args:
        pad_frames: Number of padding samples at the end of the audio.
        total_frames: Total number of samples in the (padded) audio.
        sr: Sample rate in samples per second.
        output_steps: Number of output steps produced per chunk.
        process_seconds: Duration of one chunk in seconds.
        device: Device on which the mask is allocated.
        B: Batch size, i.e. number of mask rows (all rows are identical).

    Returns:
        ``(mask, valid_steps)``. ``mask`` is a ``torch.bool`` tensor of shape
        ``(B, total_output_steps)`` that is ``True`` at padded positions.
        ``valid_steps`` is the index of the first padded step, which equals the
        number of unpadded steps and lies in ``[0, total_output_steps]``.
    """
    chunk_frames = sr * process_seconds

    # Number of whole chunks in the audio, times the steps each chunk yields.
    num_chunks = int(total_frames / chunk_frames)
    total_output_steps = output_steps * num_chunks

    # Convert padding samples to output steps with a single division.
    # Truncating the output rate int(output_steps / process_seconds) first would
    # undercount padding whenever that rate is not integral (e.g. 25 steps per
    # 2 s -> 12 steps/s instead of 12.5).
    pad_steps = int(pad_frames * output_steps / chunk_frames)
    # Clamp so that padding longer than the output can never produce a
    # negative slice start, which would mask the wrong positions.
    pad_steps = min(max(pad_steps, 0), total_output_steps)
    valid_steps = total_output_steps - pad_steps

    mask = torch.zeros((B, total_output_steps), dtype=torch.bool, device=device)
    mask[:, valid_steps:] = True
    return mask, valid_steps


def get_timestamps(sample_rate, B, input_audio_len, x):
    """Return evenly spaced timestamps, in milliseconds, for each step of ``x``.

    The audio duration ``input_audio_len / sample_rate`` seconds is split evenly
    across the ``x.shape[1]`` steps. Step ``i`` gets timestamp
    ``i * duration_ms / x.shape[1]``.

    Args:
        sample_rate: Audio sample rate in samples per second.
        B: Number of rows (batch size) of the returned tensor.
        input_audio_len: Length of the input audio in samples.
        x: Tensor whose dimension 1 is the number of steps to timestamp
            (presumably ``(batch, steps, ...)`` — confirm with callers).

    Returns:
        A tensor of torch's default floating dtype with shape
        ``(B, x.shape[1])``, created on the default device (not ``x.device``).
        Every row is identical. An empty ``(B, 0)`` tensor is returned when
        ``x.shape[1] == 0``.
    """
    num_steps = x.shape[1]
    dtype = torch.get_default_dtype()
    if num_steps == 0:
        # Nothing to timestamp; avoids a division by zero below.
        return torch.zeros((B, 0), dtype=dtype)

    step_ms = input_audio_len / sample_rate / num_steps * 1000  # sec -> ms
    # Multiply in float64 (the same precision as Python float arithmetic) and
    # cast once, so the values equal those built element by element from Python floats.
    ts = torch.arange(num_steps, dtype=torch.float64) * step_ms
    # repeat (not expand) so callers get an independent, writable tensor.
    return ts.to(dtype).unsqueeze(0).repeat(B, 1)