File size: 4,857 Bytes
5bcf511
 
 
 
 
5bbe04d
5bcf511
 
 
993218e
5bcf511
 
993218e
5bcf511
a2616e4
5bbe04d
5bcf511
 
79bd67b
 
 
d678d4b
 
 
 
 
 
df3bce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bbe04d
 
 
 
df3bce7
5bbe04d
 
df3bce7
5bbe04d
 
 
 
 
 
 
 
df3bce7
 
5bbe04d
 
 
 
df3bce7
becd625
 
 
 
27292be
becd625
 
 
 
 
 
 
 
 
 
 
1b2b4ed
07c1f47
becd625
 
 
 
 
 
 
 
 
2a96f72
becd625
 
 
 
 
 
 
 
 
df3bce7
d678d4b
df3bce7
79bd67b
 
 
 
df3bce7
 
0fc06bc
df3bce7
79bd67b
becd625
 
df3bce7
becd625
df3bce7
 
 
 
8894b5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import matplotlib.pyplot as plt
import numpy as np
import torch
import base64
import io
from io import BytesIO
import matplotlib.pyplot as plt
from hparams import create_hparams
from model import Tacotron2
from layers import TacotronSTFT
from train import load_model
from text import text_to_sequence
from utils import load_wav_to_torch
import os
import random
import librosa
import librosa.display

# Pick GPU when available; every tensor/model below is moved to `device`.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Hyperparameters plus the mel-spectrogram extractor used on the reference
# speaker audio. Sampling rate is pinned to 22.05 kHz to match the data the
# models were trained on.
hparams = create_hparams()
hparams.sampling_rate = 22050
stft = TacotronSTFT(
    hparams.filter_length, hparams.hop_length, hparams.win_length, hparams.n_mel_channels, 
    hparams.sampling_rate, hparams.mel_fmin, hparams.mel_fmax).to(device)

# Function to plot a row of 2-D arrays and return them as a base64 PNG.
def plot_data(data, figsize=(16, 4), titles=('Mel Spectrogram (Original)', 'Mel Spectrogram (Postnet)', 'Alignment'),
              xlabel=('Time Steps', 'Time Steps', 'Decoder Time Steps'),
              ylabel=('Mel Channels', 'Mel Channels', 'Encoder Time Steps'), colorbar_labels=None):
    """Render each 2-D array in ``data`` side by side and return a base64 PNG.

    Args:
        data: sequence of 2-D arrays drawn with ``imshow``.
        figsize: figure size forwarded to ``plt.subplots``.
        titles / xlabel / ylabel: per-subplot labels; a falsy value disables
            that decoration entirely. (Defaults are tuples, not lists, to
            avoid the mutable-default-argument pitfall.)
        colorbar_labels: optional per-subplot colorbar labels.

    Returns:
        str: base64-encoded PNG of the rendered figure.
    """
    fig, axes = plt.subplots(1, len(data), figsize=figsize)
    # Bug fix: for len(data) == 1, plt.subplots returns a bare Axes object
    # (not an array), so axes[i] would raise; normalize to a 1-D array.
    axes = np.atleast_1d(axes)
    for i, array in enumerate(data):
        im = axes[i].imshow(array, aspect='auto', origin='lower', interpolation='none', cmap='viridis')

        if titles:
            axes[i].set_title(titles[i])
        if xlabel:
            axes[i].set_xlabel(xlabel[i])
        if ylabel:
            axes[i].set_ylabel(ylabel[i])

        # Add color bar
        cbar = fig.colorbar(im, ax=axes[i])
        if colorbar_labels:
            cbar.set_label(colorbar_labels[i])

    plt.tight_layout()
    img_buffer = io.BytesIO()
    plt.savefig(img_buffer, format='png', bbox_inches='tight', pad_inches=0)
    # Close THIS figure (plain plt.close() only closes the "current" one,
    # which may differ if callers created figures concurrently).
    plt.close(fig)

    img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')

    return img_base64

# Function to plot a time-domain waveform from raw audio bytes; the plot is
# returned as a base64-encoded PNG string.
def plot_waveforms(audio_data):
    """Decode ``audio_data`` (raw audio file bytes) and plot its waveform."""
    # Decode the bytes with librosa at the file's native sample rate.
    samples, sample_rate = librosa.load(BytesIO(audio_data), sr=None)

    # Draw the amplitude-vs-time plot.
    plt.figure(figsize=(10, 4))
    librosa.display.waveshow(samples, sr=sample_rate)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.title("Waveform")

    # Render the figure into an in-memory PNG buffer.
    png_buffer = BytesIO()
    plt.savefig(png_buffer, format="png")
    png_buffer.seek(0)
    plt.close()

    # Hand the PNG back as base64 text.
    return base64.b64encode(png_buffer.read()).decode('utf-8')


# load speaker model
def load_speaker_model(speaker_model_path):
    """Load a SpeakerEncoder checkpoint and freeze it for inference."""
    from speaker.model import SpeakerEncoder

    # Run on GPU when available; the loss side always lives on CPU.
    run_device = torch.device('cuda' if use_cuda else 'cpu')
    loss_device = torch.device("cpu")

    encoder = SpeakerEncoder(run_device, loss_device)
    state_dict = torch.load(speaker_model_path, map_location='cpu')
    encoder.load_state_dict(state_dict)

    # Freeze every weight so the encoder is never updated downstream.
    for weight in encoder.parameters():
        weight.requires_grad = False

    return encoder
    
speaker_model = load_speaker_model('speaker/saved_models/saved_model_e273_LargeBatch.pt').to(device).eval().float()

def extract_speech_embedding(audio_path: str):
    """Compute a speaker embedding for a reference WAV file.

    Args:
        audio_path: path to a WAV whose sample rate must match the STFT config.

    Returns:
        The speaker-encoder output for an (up to) 128-frame mel slice.

    Raises:
        ValueError: if the file's sample rate differs from ``stft.sampling_rate``.
    """
    audio, sampling_rate = load_wav_to_torch(audio_path)
    if sampling_rate != stft.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(sampling_rate, stft.sampling_rate))
    
    # Normalize 16-bit PCM to [-1, 1] and add a batch dimension.
    audio_norm = audio / 32768.0
    audio_norm = audio_norm.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False).to(device)
    # transpose(1, 2) makes the tensor time-major: (1, frames, n_mels).
    melspec = stft.mel_spectrogram(audio_norm).transpose(1, 2).float()

    if melspec.shape[1] <= 128:
        # Bug fix: original referenced the undefined name `mel` here, raising
        # NameError for clips of 128 frames or fewer; short clips should feed
        # the full spectrogram to the encoder.
        mel_slice = melspec
    else:
        # Longer clips: take a random 128-frame window.
        slice_start = random.randint(0, melspec.shape[1] - 128)
        mel_slice = melspec[:, slice_start:slice_start + 128]
    speaker_embedding = speaker_model(mel_slice)
    return speaker_embedding
    
def synthesize_voice(text_input, checkpoint_path):
    """Run Tacotron2 inference on ``text_input`` with weights from ``checkpoint_path``.

    Returns a tuple of numpy arrays: the decoder mel spectrogram, the
    postnet-refined mel spectrogram, and the transposed attention alignment.
    """
    # Restore Tacotron2 weights (mapped to CPU, then moved to the run device).
    tacotron = load_model(hparams)
    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    tacotron.load_state_dict(state['state_dict'])
    tacotron = tacotron.to(device).eval().float()

    # Encode the input text and fetch the reference-speaker embedding.
    speaker_audio_path = 'speaker_audio/ariana.wav'
    encoded = np.array(text_to_sequence(text_input, ['transliteration_cleaners']))[None, :]
    encoded = torch.autograd.Variable(torch.from_numpy(encoded)).to(device).long()
    speaker_embedding = extract_speech_embedding(speaker_audio_path)

    # Run inference; pull the first (only) batch element back to numpy.
    mel_outputs, mel_outputs_postnet, _, alignments = tacotron.inference(encoded, speaker_embedding)
    mel_output_data = mel_outputs.data.cpu().numpy()[0]
    mel_output_postnet_data = mel_outputs_postnet.data.cpu().numpy()[0]
    alignments_data = alignments.data.cpu().numpy()[0].T

    return mel_output_data, mel_output_postnet_data, alignments_data