jerryhai
Track binary files with Git LFS
90f7c1e
raw
history blame
1.65 kB
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
import torch
from torch.nn import functional as F
def repeat_expand_2d(content, target_len):
    """Stretch ``content`` along its last axis to ``target_len`` columns.

    Source column ``k`` owns the span
    ``[k * target_len / src_len, (k + 1) * target_len / src_len)`` of the
    target axis, and every output column ``i`` is a copy of the source column
    whose span contains ``i``. Source columns are therefore repeated at an
    even rate when expanding, and sampled at an even rate when ``target_len``
    is smaller than the source length.

    Args:
        content: 2-D tensor indexed as ``[row, column]``.
        target_len: Number of columns in the result.

    Returns:
        A new ``torch.float`` tensor of shape
        ``[content.shape[0], target_len]`` on ``content.device``.

    Raises:
        IndexError: If ``content`` has no columns while ``target_len`` is
            positive, since there is nothing to repeat.
    """
    src_len = content.shape[-1]
    if target_len == 0:
        return torch.zeros(
            [content.shape[0], 0], dtype=torch.float, device=content.device
        )
    if src_len == 0 and target_len > 0:
        raise IndexError("cannot expand content that has zero columns")

    # Right edge of each source column's span, measured on the target axis.
    edges = (torch.arange(src_len + 1) * target_len / src_len)[1:]
    positions = torch.arange(target_len, dtype=edges.dtype)

    # The number of right edges <= i is the index of the span holding i.
    source_index = torch.searchsorted(edges, positions, right=True)
    # Guard against float rounding pushing the final edge below target_len.
    source_index = source_index.clamp(max=src_len - 1)

    expanded = content[:, source_index.to(content.device)]
    return expanded.to(torch.float)
def save_plot(tensor, savepath):
    """Render array-like data as a heat map with a colorbar and save it.

    Args:
        tensor: Image data accepted by ``Axes.imshow``; the first row is
            drawn at the bottom (``origin="lower"``).
        savepath: Destination handed to ``Figure.savefig``.
    """
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.canvas.draw()
        fig.savefig(savepath)
    finally:
        # Close this exact figure even when drawing or saving fails, so it
        # does not stay registered in pyplot's figure manager.
        plt.close(fig)
def save_audio(file_path, sampling_rate, audio):
    """Write a waveform tensor to ``file_path`` as 16-bit PCM WAV.

    Args:
        file_path: Output location passed to ``scipy.io.wavfile.write``.
        sampling_rate: Sample rate stored in the WAV header.
        audio: Torch tensor of samples; it is detached, moved to the CPU and
            squeezed before conversion.
    """
    waveform = audio.detach().cpu().squeeze().numpy()
    # Clamp just inside [-1, 1] so the scaled samples stay within int16.
    waveform = np.clip(waveform, -0.999, 0.999)
    pcm = (waveform * 32767).astype(np.int16)
    wavfile.write(file_path, sampling_rate, pcm)
def minmax_norm_diff(tensor: torch.Tensor, vmax: float = 2.5, vmin: float = -12) -> torch.Tensor:
    """Clamp ``tensor`` to ``[vmin, vmax]`` and rescale it linearly to ``[-1, 1]``.

    Args:
        tensor: Values to normalize.
        vmax: Upper clamp bound; maps to ``1``.
        vmin: Lower clamp bound; maps to ``-1``.

    Returns:
        The normalized tensor.
    """
    span = vmax - vmin
    offset = torch.clamp(tensor, vmin, vmax) - vmin
    return 2 * offset / span - 1
def reverse_minmax_norm_diff(tensor: torch.Tensor, vmax: float = 2.5, vmin: float = -12) -> torch.Tensor:
    """Map values from ``[-1, 1]`` back onto the ``[vmin, vmax]`` range.

    Inputs are first clamped to ``[-1, 1]``.

    Args:
        tensor: Normalized values.
        vmax: Value that ``1`` maps to.
        vmin: Value that ``-1`` maps to.

    Returns:
        The de-normalized tensor.
    """
    clamped = torch.clamp(tensor, -1.0, 1.0)
    fraction = (clamped + 1) / 2  # position within the range, in [0, 1]
    return fraction * (vmax - vmin) + vmin