# host/logic.py
import base64
import io
import os
import subprocess

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import torch

from hparams import create_hparams
from train import load_model
from text import text_to_sequence

# Plot mel spectrograms and the attention alignment side by side, returning
# the figure as a base64-encoded PNG (handy for embedding in a web response).
def plot_data(data, figsize=(16, 4),
              titles=['Mel Spectrogram (Original)', 'Mel Spectrogram (Postnet)', 'Alignment'],
              xlabel=['Time Steps', 'Time Steps', 'Decoder Time Steps'],
              ylabel=['Mel Channels', 'Mel Channels', 'Encoder Time Steps'],
              colorbar_labels=None):
    fig, axes = plt.subplots(1, len(data), figsize=figsize)
    for i in range(len(data)):
        im = axes[i].imshow(data[i], aspect='auto', origin='lower',
                            interpolation='none', cmap='viridis')
        if titles:
            axes[i].set_title(titles[i])
        if xlabel:
            axes[i].set_xlabel(xlabel[i])
        if ylabel:
            axes[i].set_ylabel(ylabel[i])
        # Add a color bar for each panel
        cbar = fig.colorbar(im, ax=axes[i])
        if colorbar_labels:
            cbar.set_label(colorbar_labels[i])
    plt.tight_layout()

    # Serialize the figure to an in-memory PNG and base64-encode it
    img_buffer = io.BytesIO()
    plt.savefig(img_buffer, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
    return img_base64

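# Example (illustrative, not part of the original app): the returned string
# can be dropped straight into an HTML <img> tag when rendering a page, e.g.
#   img_b64 = plot_data([mel, mel_postnet, alignments])
#   html = f'<img src="data:image/png;base64,{img_b64}"/>'
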
# Plot the time-domain waveform of an audio file, returning the figure as a
# base64-encoded PNG.
def plot_waveforms(audio_file, sr=22050):
    # Load the audio waveform (resampled to `sr` if necessary)
    y, sr = librosa.load(audio_file, sr=sr)

    # Plot the waveform; waveshow derives the time axis from `sr`
    plt.figure(figsize=(16, 4))
    librosa.display.waveshow(y, sr=sr)
    plt.title('Time vs Amplitude')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.tight_layout()

    # Serialize the figure to an in-memory PNG and base64-encode it
    img_buffer = io.BytesIO()
    plt.savefig(img_buffer, format='png', bbox_inches='tight', pad_inches=0)
    plt.close()
    img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
    return img_base64

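# Example (illustrative): visualize a waveform produced by the vocoder step
# below. The exact output filename depends on how your HiFi-GAN checkout's
# inference_e2e.py names its outputs, so adjust the path to match:
#   wave_b64 = plot_waveforms('audio_output/mel1_generated_e2e.wav')
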
# Synthesize speech from text: run Tacotron2 to produce a mel spectrogram,
# save it to disk, then hand it to HiFi-GAN for vocoding.
def synthesize_voice(text_input, checkpoint_path):
    # Build hyperparameters and load the Tacotron2 model (requires a CUDA GPU)
    hparams = create_hparams()
    hparams.sampling_rate = 22050
    model = load_model(hparams)
    model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
    model = model.cuda().eval().half()

    # Convert the (transliterated Nepali) text to a sequence of symbol IDs
    sequence = np.array(text_to_sequence(text_input, ['transliteration_cleaners']))[None, :]
    sequence = torch.from_numpy(sequence).cuda().long()

    # Run inference to get mel spectrograms and the attention alignment
    mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
    mel_output_data = mel_outputs.data.cpu().numpy()[0]
    mel_output_postnet_data = mel_outputs_postnet.data.cpu().numpy()[0]
    alignments_data = alignments.data.cpu().numpy()[0].T

    # Save the (pre-postnet) mel spectrogram and run HiFi-GAN on it
    input_mels_dir = 'mel_files/'
    output_dir = 'audio_output/'
    os.makedirs(input_mels_dir, exist_ok=True)
    np.save(os.path.join(input_mels_dir, 'mel1.npy'), mel_output_data)
    run_hifigan_inference(input_mels_dir, output_dir)

    return mel_output_data, mel_output_postnet_data, alignments_data

# Invoke the HiFi-GAN end-to-end inference script as a subprocess to turn
# the saved mel files into waveforms. Assumes the script lives in the
# hifigan/ directory relative to the working directory.
def run_hifigan_inference(input_mels_dir, output_dir):
    script_path = os.path.join('hifigan', 'inference_e2e.py')
    subprocess.run(
        ['python', script_path,
         '--checkpoint_file', 'generator_v1',
         '--input_mels_dir', input_mels_dir,
         '--output_dir', output_dir],
        check=True,  # raise if the vocoder script fails
    )
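
# Minimal smoke test, assuming a CUDA GPU, a Tacotron2 checkpoint, and a
# HiFi-GAN "generator_v1" checkpoint reachable from hifigan/inference_e2e.py.
# The checkpoint path and sample sentence below are placeholders, not files
# shipped with this repo; adjust them to your setup.
if __name__ == '__main__':
    mel, mel_post, align = synthesize_voice('namaste', 'tacotron2_checkpoint.pt')
    plot_b64 = plot_data([mel, mel_post, align])
    print(f'Spectrogram/alignment plot: {len(plot_b64)} base64 characters')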