# HuggingFace Spaces app (Space status: Sleeping) — Tacotron2 voice-cloning synthesis demo.
# Standard library
import base64
import io
from io import BytesIO
import os
import random

# Third-party
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import torch

# Project-local
from hparams import create_hparams
from layers import TacotronSTFT
from model import Tacotron2
from text import text_to_sequence
from train import load_model
from utils import load_wav_to_torch
# Device selection: run on GPU when one is available, otherwise CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Shared hyperparameters for the whole pipeline; the sampling rate is pinned
# to 22050 Hz and input WAVs are validated against it in
# extract_speech_embedding.
hparams = create_hparams()
hparams.sampling_rate = 22050

# Module-level STFT used to convert raw audio into mel spectrograms for the
# speaker encoder (project-defined TacotronSTFT; parameters come from hparams).
stft = TacotronSTFT(
    hparams.filter_length, hparams.hop_length, hparams.win_length, hparams.n_mel_channels,
    hparams.sampling_rate, hparams.mel_fmin, hparams.mel_fmax).to(device)
def plot_data(data, figsize=(16, 4),
              titles=('Mel Spectrogram (Original)', 'Mel Spectrogram (Postnet)', 'Alignment'),
              xlabel=('Time Steps', 'Time Steps', 'Decoder Time Steps'),
              ylabel=('Mel Channels', 'Mel Channels', 'Encoder Time Steps'),
              colorbar_labels=None):
    """Render a row of 2-D arrays as images and return the figure as base64 PNG.

    Args:
        data: sequence of 2-D arrays to plot side by side.
        figsize: matplotlib figure size.
        titles, xlabel, ylabel: per-subplot labels; any may be falsy to skip.
        colorbar_labels: optional per-subplot colorbar labels.

    Returns:
        str: base64-encoded PNG of the rendered figure.
    """
    # squeeze=False keeps `axes` a 2-D array even for a single subplot;
    # the original code crashed with len(data) == 1 because plt.subplots
    # then returns a bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False)
    for i in range(len(data)):
        ax = axes[0, i]
        im = ax.imshow(data[i], aspect='auto', origin='lower', interpolation='none', cmap='viridis')
        if titles:
            ax.set_title(titles[i])
        if xlabel:
            ax.set_xlabel(xlabel[i])
        if ylabel:
            ax.set_ylabel(ylabel[i])
        # Add a color bar for each subplot.
        cbar = fig.colorbar(im, ax=ax)
        if colorbar_labels:
            cbar.set_label(colorbar_labels[i])
    plt.tight_layout()
    img_buffer = io.BytesIO()
    fig.savefig(img_buffer, format='png', bbox_inches='tight', pad_inches=0)
    # Close this specific figure (plt.close() with no argument only closes
    # the *current* figure, which may differ under concurrent plotting).
    plt.close(fig)
    img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
    return img_base64
def plot_waveforms(audio_data):
    """Plot the time-domain waveform of encoded audio bytes.

    Args:
        audio_data: raw bytes of an audio file librosa can decode.

    Returns:
        str: base64-encoded PNG of the waveform plot.
    """
    # Decode the in-memory audio; sr=None preserves the native sample rate.
    samples, sample_rate = librosa.load(BytesIO(audio_data), sr=None)

    # Draw the waveform.
    plt.figure(figsize=(10, 4))
    librosa.display.waveshow(samples, sr=sample_rate)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.title("Waveform")

    # Serialize the figure to an in-memory PNG.
    png_buffer = BytesIO()
    plt.savefig(png_buffer, format="png")
    png_buffer.seek(0)
    plt.close()

    # Base64-encode the PNG for embedding in a web response.
    return base64.b64encode(png_buffer.read()).decode('utf-8')
def load_speaker_model(speaker_model_path):
    """Instantiate the speaker encoder, load its weights, and freeze it.

    Args:
        speaker_model_path: path to a saved state_dict for SpeakerEncoder.

    Returns:
        The frozen SpeakerEncoder model.
    """
    from speaker.model import SpeakerEncoder

    run_device = torch.device('cuda' if use_cuda else 'cpu')
    loss_device = torch.device("cpu")
    model = SpeakerEncoder(run_device, loss_device)

    # Weights are loaded onto CPU first; the caller moves the model to the
    # target device.
    state_dict = torch.load(speaker_model_path, map_location='cpu')
    model.load_state_dict(state_dict)

    # The encoder is used purely as a fixed feature extractor, so disable
    # gradients on every parameter (equivalent to the per-parameter loop).
    model.requires_grad_(False)
    return model
# Module-level frozen speaker encoder, moved to the selected device and put in
# eval mode (float32); used below by extract_speech_embedding.
speaker_model = load_speaker_model('speaker/saved_models/saved_model_e273_LargeBatch.pt').to(device).eval().float()
def extract_speech_embedding(audio_path: str):
    """Compute a speaker embedding from a reference WAV file.

    Args:
        audio_path: path to a WAV file whose sample rate must match
            ``stft.sampling_rate``.

    Returns:
        The embedding tensor produced by the module-level ``speaker_model``.

    Raises:
        ValueError: if the file's sample rate differs from the STFT's.
    """
    audio, sampling_rate = load_wav_to_torch(audio_path)
    if sampling_rate != stft.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(sampling_rate, stft.sampling_rate))
    # Normalize 16-bit PCM samples to [-1, 1] and add a batch dimension.
    audio_norm = audio / 32768.0
    audio_norm = audio_norm.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False).to(device)
    # (1, n_mels, T) -> (1, T, n_mels): time-major frames for the encoder.
    melspec = stft.mel_spectrogram(audio_norm).transpose(1, 2).float()
    if melspec.shape[1] <= 128:
        # BUG FIX: the original assigned the undefined name `mel` here,
        # raising NameError for clips of 128 frames or fewer. Short clips
        # are used in full.
        mel_slice = melspec
    else:
        # Longer clips: take a random 128-frame window.
        slice_start = random.randint(0, melspec.shape[1] - 128)
        mel_slice = melspec[:, slice_start:slice_start + 128]
    speaker_embedding = speaker_model(mel_slice)
    return speaker_embedding
def synthesize_voice(text_input, checkpoint_path, speaker_audio_path='speaker_audio/ariana.wav'):
    """Synthesize mel spectrograms and alignments for *text_input*.

    Args:
        text_input: text to synthesize (cleaned with transliteration_cleaners).
        checkpoint_path: Tacotron2 checkpoint containing a 'state_dict' entry.
        speaker_audio_path: reference WAV used to condition on a speaker;
            generalized from the previously hard-coded path (same default,
            so existing callers are unaffected).

    Returns:
        tuple of numpy arrays: (mel_outputs, mel_outputs_postnet,
        alignments transposed to encoder-steps x decoder-steps).
    """
    # Load Tacotron2 from the checkpoint; map to CPU first, then move to the
    # module-level device.
    model = load_model(hparams)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device).eval().float()

    # Text -> integer id sequence with a leading batch dimension.
    sequence = np.array(text_to_sequence(text_input, ['transliteration_cleaners']))[None, :]
    sequence = torch.autograd.Variable(torch.from_numpy(sequence)).to(device).long()

    speaker_embedding = extract_speech_embedding(speaker_audio_path)

    # Run inference without building an autograd graph (saves memory; the
    # returned values are identical).
    with torch.no_grad():
        mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence, speaker_embedding)
    mel_output_data = mel_outputs.data.cpu().numpy()[0]
    mel_output_postnet_data = mel_outputs_postnet.data.cpu().numpy()[0]
    alignments_data = alignments.data.cpu().numpy()[0].T
    return mel_output_data, mel_output_postnet_data, alignments_data