import matplotlib.pyplot as plt
import numpy as np
import torch
import base64
import io
from io import BytesIO
import matplotlib.pyplot as plt
from hparams import create_hparams
from model import Tacotron2
from layers import TacotronSTFT
from train import load_model
from text import text_to_sequence
from utils import load_wav_to_torch
import os
import random
import librosa
import librosa.display
# Module-level synthesis setup shared by every function below:
# device selection, Tacotron2 hyper-parameters, and the STFT module that
# converts raw audio into mel spectrograms.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
hparams = create_hparams()
hparams.sampling_rate = 22050  # expected sample rate of all input audio (Hz)
# Mel-spectrogram extractor, moved to the selected device once at import time.
stft = TacotronSTFT(
    hparams.filter_length, hparams.hop_length, hparams.win_length, hparams.n_mel_channels,
    hparams.sampling_rate, hparams.mel_fmin, hparams.mel_fmax).to(device)
# Function to plot data
def plot_data(data, figsize=(16, 4),
              titles=('Mel Spectrogram (Original)', 'Mel Spectrogram (Postnet)', 'Alignment'),
              xlabel=('Time Steps', 'Time Steps', 'Decoder Time Steps'),
              ylabel=('Mel Channels', 'Mel Channels', 'Encoder Time Steps'),
              colorbar_labels=None):
    """Plot each 2-D array in *data* side by side and return the figure
    as a base64-encoded PNG string.

    Args:
        data: sequence of 2-D arrays (e.g. mel spectrograms, alignment).
        figsize: figure size passed to ``plt.subplots``.
        titles, xlabel, ylabel: per-subplot labels; any may be falsy to skip.
        colorbar_labels: optional per-subplot colorbar labels.

    Returns:
        str: base64-encoded PNG of the rendered figure.
    """
    # Defaults are tuples (not lists) to avoid the shared-mutable-default pitfall.
    fig, axes = plt.subplots(1, len(data), figsize=figsize)
    # subplots() returns a bare Axes (not an array) when len(data) == 1,
    # which would make axes[i] fail below — normalize to a 1-D array.
    axes = np.atleast_1d(axes)
    for i, array in enumerate(data):
        im = axes[i].imshow(array, aspect='auto', origin='lower', interpolation='none', cmap='viridis')
        if titles:
            axes[i].set_title(titles[i])
        if xlabel:
            axes[i].set_xlabel(xlabel[i])
        if ylabel:
            axes[i].set_ylabel(ylabel[i])
        # Add color bar
        cbar = fig.colorbar(im, ax=axes[i])
        if colorbar_labels:
            cbar.set_label(colorbar_labels[i])
    plt.tight_layout()
    img_buffer = io.BytesIO()
    plt.savefig(img_buffer, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)  # close this figure specifically, not whatever is "current"
    img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
    return img_base64
# Function to plot a time-domain waveform
def plot_waveforms(audio_data):
    """Render the time-domain waveform of *audio_data* (raw encoded audio
    bytes) and return the plot as a base64-encoded PNG string."""
    # Decode the audio from an in-memory buffer at its native sample rate.
    samples, sample_rate = librosa.load(BytesIO(audio_data), sr=None)

    # Draw the waveform.
    plt.figure(figsize=(10, 4))
    librosa.display.waveshow(samples, sr=sample_rate)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.title("Waveform")

    # Serialize the figure to PNG in memory.
    png_buffer = BytesIO()
    plt.savefig(png_buffer, format="png")
    png_buffer.seek(0)
    plt.close()

    # Return the PNG as base64 text.
    return base64.b64encode(png_buffer.read()).decode('utf-8')
# load speaker model
def load_speaker_model(speaker_model_path):
    """Construct a SpeakerEncoder, restore its weights from
    *speaker_model_path*, and freeze all parameters (inference only)."""
    from speaker.model import SpeakerEncoder

    run_device = torch.device('cuda' if use_cuda else 'cpu')
    loss_device = torch.device("cpu")
    encoder = SpeakerEncoder(run_device, loss_device)
    encoder.load_state_dict(torch.load(speaker_model_path, map_location='cpu'))

    # Freeze the weights: the speaker encoder is never trained here.
    for weight in encoder.parameters():
        weight.requires_grad = False
    return encoder
# Module-level speaker encoder: loaded once at import time, moved to `device`,
# and put in eval mode with float32 weights. Used by extract_speech_embedding.
speaker_model = load_speaker_model('speaker/saved_models/saved_model_e273_LargeBatch.pt').to(device).eval().float()
def extract_speech_embedding(audio_path: str):
    """Compute a speaker embedding for the WAV file at *audio_path*.

    The audio is normalized, converted to a mel spectrogram, and a random
    128-frame slice (or the whole spectrogram when it is 128 frames or
    shorter) is fed through the frozen speaker encoder.

    Raises:
        ValueError: if the file's sample rate differs from the STFT's
            configured sample rate.
    """
    audio, sampling_rate = load_wav_to_torch(audio_path)
    if sampling_rate != stft.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(sampling_rate, stft.sampling_rate))
    audio_norm = audio / 32768.0  # 16-bit PCM -> [-1, 1]
    # (1, num_samples) on the synthesis device; deprecated torch.autograd.Variable
    # replaced by plain tensor ops (gradients are irrelevant: the model is frozen).
    audio_norm = audio_norm.unsqueeze(0).to(device)
    # (1, time_frames, n_mel_channels)
    melspec = stft.mel_spectrogram(audio_norm).transpose(1, 2).float()
    if melspec.shape[1] <= 128:
        # Bug fix: the original assigned the undefined name `mel` here,
        # raising NameError for any clip of 128 frames or fewer.
        mel_slice = melspec
    else:
        slice_start = random.randint(0, melspec.shape[1] - 128)
        mel_slice = melspec[:, slice_start:slice_start + 128]
    speaker_embedding = speaker_model(mel_slice)
    return speaker_embedding
def synthesize_voice(text_input, checkpoint_path, speaker_audio_path='speaker_audio/ariana.wav'):
    """Synthesize mel spectrograms and an attention alignment for *text_input*.

    Args:
        text_input: text to synthesize (cleaned with 'transliteration_cleaners').
        checkpoint_path: path to a Tacotron2 checkpoint (dict with 'state_dict').
        speaker_audio_path: reference WAV used to condition the voice;
            defaults to the previously hard-coded sample for backward
            compatibility.

    Returns:
        tuple: (mel_output, mel_output_postnet, alignment) as numpy arrays;
        the alignment is transposed to (encoder_steps, decoder_steps).
    """
    # Load Tacotron2 model from checkpoint.
    model = load_model(hparams)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device).eval().float()

    # Encode text to an id sequence of shape (1, seq_len).
    sequence = np.array(text_to_sequence(text_input, ['transliteration_cleaners']))[None, :]
    sequence = torch.from_numpy(sequence).to(device).long()
    speaker_embedding = extract_speech_embedding(speaker_audio_path)

    # Mel spectrogram and alignment; inference only, so skip autograd tracking.
    with torch.no_grad():
        mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence, speaker_embedding)

    mel_output_data = mel_outputs.data.cpu().numpy()[0]
    mel_output_postnet_data = mel_outputs_postnet.data.cpu().numpy()[0]
    alignments_data = alignments.data.cpu().numpy()[0].T
    return mel_output_data, mel_output_postnet_data, alignments_data