Spaces:
Sleeping
Sleeping
Update hifigan/inference_e2e.py
Browse files- hifigan/inference_e2e.py +5 -26
hifigan/inference_e2e.py
CHANGED
|
@@ -1,26 +1,3 @@
|
|
| 1 |
-
from __future__ import absolute_import, division, print_function, unicode_literals
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
import numpy as np
|
| 5 |
-
import json
|
| 6 |
-
import torch
|
| 7 |
-
from scipy.io.wavfile import write
|
| 8 |
-
from hifigan.env import AttrDict
|
| 9 |
-
from hifigan.models import Generator
|
| 10 |
-
from io import BytesIO
|
| 11 |
-
|
| 12 |
-
h = None
|
| 13 |
-
device = None
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def load_checkpoint(filepath, device):
    """Load a serialized torch checkpoint from disk.

    Args:
        filepath: Path to a checkpoint file produced by ``torch.save``.
        device: Target device (string or ``torch.device``) passed to
            ``torch.load`` as ``map_location`` so CPU-only hosts can load
            GPU-saved checkpoints.

    Returns:
        The deserialized checkpoint object (typically a dict of weights
        and training state).

    Raises:
        FileNotFoundError: If ``filepath`` does not name an existing file.
    """
    # Explicit check instead of `assert`: asserts are stripped under -O,
    # which would turn a missing file into an opaque torch.load failure.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
|
| 22 |
-
|
| 23 |
-
|
| 24 |
def hifi_gan_inference(input_mel, checkpoint_file):
|
| 25 |
print('Initializing Inference Process..')
|
| 26 |
config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json')
|
|
@@ -57,17 +34,19 @@ def hifi_gan_inference(input_mel, checkpoint_file):
|
|
| 57 |
|
| 58 |
x = torch.FloatTensor(x).to(device)
|
| 59 |
y_g_hat = generator(x)
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# Set MAX_WAV_VALUE if not present
|
| 63 |
if 'MAX_WAV_VALUE' not in h:
|
| 64 |
h.MAX_WAV_VALUE = 32768.0 # Adjust this value based on your requirements
|
| 65 |
|
| 66 |
audio = audio * h.MAX_WAV_VALUE
|
| 67 |
-
audio = audio.
|
| 68 |
|
| 69 |
# Save audio to BytesIO
|
| 70 |
output_buffer = BytesIO()
|
| 71 |
write(output_buffer, h.sampling_rate, audio)
|
| 72 |
|
| 73 |
-
return output_buffer.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def hifi_gan_inference(input_mel, checkpoint_file):
|
| 2 |
print('Initializing Inference Process..')
|
| 3 |
config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json')
|
|
|
|
| 34 |
|
| 35 |
x = torch.FloatTensor(x).to(device)
|
| 36 |
y_g_hat = generator(x)
|
| 37 |
+
|
| 38 |
+
# Detach tensor before converting to numpy
|
| 39 |
+
audio = y_g_hat.squeeze().detach().numpy()
|
| 40 |
|
| 41 |
# Set MAX_WAV_VALUE if not present
|
| 42 |
if 'MAX_WAV_VALUE' not in h:
|
| 43 |
h.MAX_WAV_VALUE = 32768.0 # Adjust this value based on your requirements
|
| 44 |
|
| 45 |
audio = audio * h.MAX_WAV_VALUE
|
| 46 |
+
audio = audio.astype('int16')
|
| 47 |
|
| 48 |
# Save audio to BytesIO
|
| 49 |
output_buffer = BytesIO()
|
| 50 |
write(output_buffer, h.sampling_rate, audio)
|
| 51 |
|
| 52 |
+
return output_buffer.getvalue()
|