lord-reso commited on
Commit
cae7b2b
·
verified ·
1 Parent(s): dc5c54f

Update hifigan/inference_e2e.py

Browse files
Files changed (1) hide show
  1. hifigan/inference_e2e.py +5 -26
hifigan/inference_e2e.py CHANGED
@@ -1,26 +1,3 @@
1
- from __future__ import absolute_import, division, print_function, unicode_literals
2
-
3
- import os
4
- import numpy as np
5
- import json
6
- import torch
7
- from scipy.io.wavfile import write
8
- from hifigan.env import AttrDict
9
- from hifigan.models import Generator
10
- from io import BytesIO
11
-
12
- h = None
13
- device = None
14
-
15
-
16
- def load_checkpoint(filepath, device):
17
- assert os.path.isfile(filepath)
18
- print("Loading '{}'".format(filepath))
19
- checkpoint_dict = torch.load(filepath, map_location=device)
20
- print("Complete.")
21
- return checkpoint_dict
22
-
23
-
24
  def hifi_gan_inference(input_mel, checkpoint_file):
25
  print('Initializing Inference Process..')
26
  config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json')
@@ -57,17 +34,19 @@ def hifi_gan_inference(input_mel, checkpoint_file):
57
 
58
  x = torch.FloatTensor(x).to(device)
59
  y_g_hat = generator(x)
60
- audio = y_g_hat.squeeze()
 
 
61
 
62
  # Set MAX_WAV_VALUE if not present
63
  if 'MAX_WAV_VALUE' not in h:
64
  h.MAX_WAV_VALUE = 32768.0 # Adjust this value based on your requirements
65
 
66
  audio = audio * h.MAX_WAV_VALUE
67
- audio = audio.cpu().numpy().astype('int16')
68
 
69
  # Save audio to BytesIO
70
  output_buffer = BytesIO()
71
  write(output_buffer, h.sampling_rate, audio)
72
 
73
- return output_buffer.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def hifi_gan_inference(input_mel, checkpoint_file):
2
  print('Initializing Inference Process..')
3
  config_file = os.path.join(os.path.split(checkpoint_file)[0], 'config.json')
 
34
 
35
  x = torch.FloatTensor(x).to(device)
36
  y_g_hat = generator(x)
37
+
38
+ # Detach tensor before converting to numpy
39
+ audio = y_g_hat.squeeze().detach().numpy()
40
 
41
  # Set MAX_WAV_VALUE if not present
42
  if 'MAX_WAV_VALUE' not in h:
43
  h.MAX_WAV_VALUE = 32768.0 # Adjust this value based on your requirements
44
 
45
  audio = audio * h.MAX_WAV_VALUE
46
+ audio = audio.astype('int16')
47
 
48
  # Save audio to BytesIO
49
  output_buffer = BytesIO()
50
  write(output_buffer, h.sampling_rate, audio)
51
 
52
+ return output_buffer.getvalue()