# host / logic.py
# lord-reso's picture
# Update logic.py
# 07c1f47 verified
import matplotlib.pyplot as plt
import numpy as np
import torch
import base64
import io
from io import BytesIO
import matplotlib.pyplot as plt
from hparams import create_hparams
from model import Tacotron2
from layers import TacotronSTFT
from train import load_model
from text import text_to_sequence
from utils import load_wav_to_torch
import os
import random
import librosa
import librosa.display
# --- Module-level inference setup ---
# Select GPU when available, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
# Tacotron2 hyperparameters; sampling rate pinned to 22.05 kHz
# (must match the speaker reference audio — see extract_speech_embedding).
hparams = create_hparams()
hparams.sampling_rate = 22050
# Shared STFT/mel-spectrogram extractor, built from the same hyperparameters
# and moved to the compute device.
stft = TacotronSTFT(
hparams.filter_length, hparams.hop_length, hparams.win_length, hparams.n_mel_channels,
hparams.sampling_rate, hparams.mel_fmin, hparams.mel_fmax).to(device)
# Function to plot data
def plot_data(data, figsize=(16, 4), titles=('Mel Spectrogram (Original)', 'Mel Spectrogram (Postnet)', 'Alignment'),
              xlabel=('Time Steps', 'Time Steps', 'Decoder Time Steps'),
              ylabel=('Mel Channels', 'Mel Channels', 'Encoder Time Steps'), colorbar_labels=None):
    """Render each 2-D array in *data* as an image panel (with colorbar) and
    return the whole figure as a base64-encoded PNG string.

    Args:
        data: sequence of 2-D arrays (mel spectrograms, alignment matrices).
        figsize: overall figure size passed to ``plt.subplots``.
        titles, xlabel, ylabel: per-panel labels; pass a falsy value to skip.
            (Defaults are tuples rather than lists to avoid shared mutable
            default arguments.)
        colorbar_labels: optional per-panel colorbar labels.

    Returns:
        str: base64-encoded PNG of the rendered figure.
    """
    # squeeze=False keeps `axes` 2-D, so indexing works even for a single
    # panel (plt.subplots would otherwise return a bare Axes object).
    fig, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False)
    for i, panel in enumerate(data):
        ax = axes[0][i]
        im = ax.imshow(panel, aspect='auto', origin='lower', interpolation='none', cmap='viridis')
        if titles:
            ax.set_title(titles[i])
        if xlabel:
            ax.set_xlabel(xlabel[i])
        if ylabel:
            ax.set_ylabel(ylabel[i])
        # Add a color bar per panel
        cbar = fig.colorbar(im, ax=ax)
        if colorbar_labels:
            cbar.set_label(colorbar_labels[i])
    fig.tight_layout()
    img_buffer = io.BytesIO()
    fig.savefig(img_buffer, format='png', bbox_inches='tight', pad_inches=0)
    # Close this specific figure (not just the "current" one) to avoid
    # leaking figures when other code has plots open.
    plt.close(fig)
    return base64.b64encode(img_buffer.getvalue()).decode('utf-8')
#Function to plot timedomain waveform
def plot_waveforms(audio_data):
    """Plot the time-domain waveform of *audio_data* (encoded audio bytes)
    and return the plot as a base64-encoded PNG string."""
    # Decode the in-memory audio at its native sampling rate.
    samples, sample_rate = librosa.load(BytesIO(audio_data), sr=None)

    # Draw the waveform.
    plt.figure(figsize=(10, 4))
    librosa.display.waveshow(samples, sr=sample_rate)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.title("Waveform")

    # Serialize the plot to PNG and base64-encode it.
    png_buffer = BytesIO()
    plt.savefig(png_buffer, format="png")
    plt.close()
    return base64.b64encode(png_buffer.getvalue()).decode('utf-8')
# load speaker model
# load speaker model
def load_speaker_model(speaker_model_path):
    """Load a pre-trained SpeakerEncoder from *speaker_model_path* and freeze
    all of its weights (inference only)."""
    from speaker.model import SpeakerEncoder

    run_device = torch.device('cuda' if use_cuda else 'cpu')
    loss_device = torch.device("cpu")
    encoder = SpeakerEncoder(run_device, loss_device)
    encoder.load_state_dict(torch.load(speaker_model_path, map_location='cpu'))
    # Inference only: block gradient flow through the speaker encoder.
    for weight in encoder.parameters():
        weight.requires_grad = False
    return encoder
# Global speaker encoder: loaded once at import time, moved to the compute
# device, and set to eval/float32 mode for inference.
speaker_model = load_speaker_model('speaker/saved_models/saved_model_e273_LargeBatch.pt').to(device).eval().float()
def extract_speech_embedding(audio_path: str):
    """Compute a speaker embedding for the WAV file at *audio_path*.

    The waveform is converted to a mel spectrogram and a window of at most
    128 frames is fed to the frozen global speaker encoder.

    Raises:
        ValueError: if the file's sampling rate differs from the STFT's.
    """
    audio, sampling_rate = load_wav_to_torch(audio_path)
    if sampling_rate != stft.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(sampling_rate, stft.sampling_rate))
    # Normalize 16-bit PCM to [-1, 1] and add a batch dimension.
    audio_norm = audio / 32768.0
    audio_norm = audio_norm.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False).to(device)
    melspec = stft.mel_spectrogram(audio_norm).transpose(1, 2).float()  # (1, frames, n_mels)
    if melspec.shape[1] <= 128:
        # Bug fix: original code referenced an undefined name `mel` here,
        # raising NameError for clips of 128 frames or fewer.
        mel_slice = melspec
    else:
        # Randomly crop a 128-frame window for the speaker encoder.
        slice_start = random.randint(0, melspec.shape[1] - 128)
        mel_slice = melspec[:, slice_start:slice_start + 128]
    speaker_embedding = speaker_model(mel_slice)
    return speaker_embedding
def synthesize_voice(text_input, checkpoint_path, speaker_audio_path='speaker_audio/ariana.wav'):
    """Run Tacotron2 inference for *text_input* and return plot-ready arrays.

    Args:
        text_input: text to synthesize (cleaned with 'transliteration_cleaners').
        checkpoint_path: path to a Tacotron2 checkpoint containing 'state_dict'.
        speaker_audio_path: reference WAV used to compute the speaker
            embedding. The default preserves the previously hard-coded
            speaker, so existing callers are unaffected.

    Returns:
        tuple of numpy arrays: (mel spectrogram, postnet mel spectrogram,
        transposed alignment matrix), each with the batch dimension removed.
    """
    # Load Tacotron2 model from checkpoint (CPU-mapped, then moved to device).
    model = load_model(hparams)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device).eval().float()

    # Text -> padded ID sequence with a leading batch dimension.
    sequence = np.array(text_to_sequence(text_input, ['transliteration_cleaners']))[None, :]
    sequence = torch.autograd.Variable(torch.from_numpy(sequence)).to(device).long()
    speaker_embedding = extract_speech_embedding(speaker_audio_path)

    # Mel spectrogram and alignment; no_grad avoids building an autograd
    # graph during pure inference (outputs are unchanged).
    with torch.no_grad():
        mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence, speaker_embedding)
    mel_output_data = mel_outputs.data.cpu().numpy()[0]
    mel_output_postnet_data = mel_outputs_postnet.data.cpu().numpy()[0]
    alignments_data = alignments.data.cpu().numpy()[0].T
    return mel_output_data, mel_output_postnet_data, alignments_data