MusicLSTMDemo / MQGAN /feature_extractors.py
ZDisket
add1
827b824
import os
import json
import torch
# Relative import for TorchSTFT, assuming stft.py is in the same package
from .stft import TorchSTFT
MAX_WAV_VALUE = 32768.0
class ISTFTNetFE(torch.nn.Module):
def __init__(self, gen, stft):
super(ISTFTNetFE, self).__init__()
self.gen = gen
self.stft = stft
def infer(self, x):
wav_predictions = self(x)
audio = wav_predictions.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype('int16')
return audio
def infer_cpuistft(self, x):
self.stft = self.stft.cpu()
self.stft.window = self.stft.window.cpu()
with torch.no_grad():
spec, phase = self.gen(x)
spec, phase = spec.cpu(), phase.cpu()
y_g_hat = self.stft.inverse(spec, phase)
audio = y_g_hat.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype('int16')
return audio
def forward(self, x):
with torch.no_grad():
spec, phase = self.gen(x)
y_g_hat = self.stft.inverse(spec, phase)
return y_g_hat
def export_ts(self, out_foldername, sampling_rate, ex_devices = ["cuda", "cpu"]):
dummy_mel = torch.randn((1, 88, 600)) # create dummy mel input
for dev in ex_devices:
self.stft = self.stft.to(dev)
self.stft.window = self.stft.window.to(dev)
self.gen = self.gen.to(dev)
raw_fn = f"istft_{dev}.pt"
full_fn = os.path.join(out_foldername, raw_fn)
os.makedirs(out_foldername, exist_ok = True)
traced_istft = torch.jit.trace_module(self.gen, {"forward": (dummy_mel.to(dev))}, check_trace=False, strict=True)
torch.jit.save(traced_istft, full_fn)
# export istft and other info
configs_dict = {
"gen_istft_n_fft" : self.stft.filter_length,
"gen_istft_hop_size" : self.stft.hop_length,
"gen_istft_n_fft" : self.stft.win_length, # Note: Original code had n_fft twice, assuming win_length for the second
"sampling_rate" : sampling_rate
}
out_jsonfn = os.path.join(out_foldername, "config.json")
with open(out_jsonfn, "w") as outfile:
json.dump(configs_dict, outfile)
def load_ts(self, in_foldername, in_dev = "cuda"):
self.gen = torch.jit.load(
os.path.join(in_foldername, f"istft_{in_dev}.pt")
)
json_fn = os.path.join(in_foldername, "config.json")
with open(json_fn) as json_file:
json_dat = json.load(json_file)
# Ensure TorchSTFT is defined or imported correctly to be used here
self.stft = TorchSTFT(filter_length=json_dat["gen_istft_n_fft"],
hop_length=json_dat["gen_istft_hop_size"],
win_length= json_dat["gen_istft_n_fft"]).to(in_dev) # Assuming win_length for the second n_fft
self.stft.window = self.stft.window.to(in_dev)
self.sampling_rate = json_dat["sampling_rate"]