Update audioldm_train/modules/latent_diffusion/ddpm.py
Browse files
audioldm_train/modules/latent_diffusion/ddpm.py
CHANGED
|
@@ -1335,7 +1335,7 @@ class LatentDiffusion(DDPM):
|
|
| 1335 |
waveform = self.first_stage_model.vocoder(mel)
|
| 1336 |
waveform = waveform.cpu().detach().numpy()
|
| 1337 |
if save:
|
| 1338 |
-
self.save_waveform(waveform, savepath
|
| 1339 |
return waveform
|
| 1340 |
|
| 1341 |
def encode_first_stage(self, x):
|
|
@@ -1818,44 +1818,31 @@ class LatentDiffusion(DDPM):
|
|
| 1818 |
**kwargs,
|
| 1819 |
)
|
| 1820 |
|
| 1821 |
-
def save_waveform(self, waveform, savepath, name="
|
| 1822 |
-
|
| 1823 |
-
|
| 1824 |
-
|
| 1825 |
-
|
| 1826 |
-
|
| 1827 |
-
|
| 1828 |
-
|
| 1829 |
-
|
| 1830 |
-
|
| 1831 |
-
|
| 1832 |
-
|
| 1833 |
-
|
| 1834 |
-
|
| 1835 |
-
|
| 1836 |
-
|
| 1837 |
-
|
| 1838 |
-
|
| 1839 |
-
|
| 1840 |
-
|
| 1841 |
-
|
|
|
|
|
|
|
|
|
|
| 1842 |
|
| 1843 |
-
if (not ".wav" in name[i])
|
| 1844 |
-
else os.path.basename(name[i]).split(".")[0]
|
| 1845 |
-
),
|
| 1846 |
-
)
|
| 1847 |
-
else:
|
| 1848 |
-
# import pdb
|
| 1849 |
-
# pdb.set_trace()
|
| 1850 |
-
raise NotImplementedError
|
| 1851 |
-
todo_waveform = waveform[i, 0]
|
| 1852 |
-
todo_waveform = (
|
| 1853 |
-
todo_waveform / np.max(np.abs(todo_waveform))
|
| 1854 |
-
) * 0.8 # Normalize the energy of the generation output
|
| 1855 |
-
try:
|
| 1856 |
-
sf.write(path, todo_waveform, samplerate=self.sampling_rate)
|
| 1857 |
-
except:
|
| 1858 |
-
print('waveform name ERROR!!!!!!!!!!!!')
|
| 1859 |
|
| 1860 |
@torch.no_grad()
|
| 1861 |
def sample_log(
|
|
@@ -2054,7 +2041,7 @@ class LatentDiffusion(DDPM):
|
|
| 2054 |
print("Choose the following indexes:", best_index)
|
| 2055 |
except Exception as e:
|
| 2056 |
print("Warning: while calculating CLAP score (not fatal), ", e)
|
| 2057 |
-
self.save_waveform(waveform,
|
| 2058 |
return waveform_save_path
|
| 2059 |
|
| 2060 |
|
|
|
|
| 1335 |
waveform = self.first_stage_model.vocoder(mel)
|
| 1336 |
waveform = waveform.cpu().detach().numpy()
|
| 1337 |
if save:
|
| 1338 |
+
self.save_waveform(waveform, savepath="./")
|
| 1339 |
return waveform
|
| 1340 |
|
| 1341 |
def encode_first_stage(self, x):
|
|
|
|
| 1818 |
**kwargs,
|
| 1819 |
)
|
| 1820 |
|
| 1821 |
+
def save_waveform(self, waveform, savepath="./", name="awesome.wav", n_gen=1):
    """Peak-normalize the first waveform of a batch and write it to disk.

    Only ``waveform[0, 0]`` is saved; any further entries are ignored.
    Saving is best-effort: a failure while creating the directory or
    writing the file is reported on stdout and does not raise.

    Args:
        waveform: Array indexed as ``waveform[0, 0]`` to obtain the samples
            to save (presumably shaped (batch, channel, samples) -- TODO
            confirm against the vocoder output).
        savepath: Directory the file is written into. It is created if it
            does not exist yet.
        name: Output file name. A string is used as-is, with ``.wav``
            appended when it has no extension. A list/tuple is reduced to
            the base name of each entry (a trailing ``.wav`` removed),
            joined with ``_`` and suffixed with ``.wav``.
        n_gen: Unused; kept so existing callers that pass it keep working.

    Raises:
        TypeError: If ``name`` is neither a string nor a list/tuple.
    """
    if isinstance(name, (list, tuple)):
        stems = []
        for item in name:
            # Keep only the base name: an absolute entry would otherwise make
            # os.path.join() below silently discard `savepath`.
            stem = os.path.basename(str(item))
            if stem.endswith(".wav"):
                # Avoid names such as "a.wav_b.wav.wav".
                stem = stem[: -len(".wav")]
            stems.append(stem)
        name = "_".join(stems) + ".wav"
    elif isinstance(name, str):
        # The writer infers the audio format from the file extension, so a
        # bare name would fail to save.
        if not os.path.splitext(name)[1]:
            name += ".wav"
    else:
        raise TypeError("Name must be a string or list")

    # Normalize the energy of the output. A silent (all-zero) signal is left
    # untouched, since dividing by a zero peak would fill it with NaNs.
    todo_waveform = waveform[0, 0]
    peak = np.max(np.abs(todo_waveform))
    if peak > 0:
        todo_waveform = (todo_waveform / peak) * 0.8

    path = os.path.join(savepath, name)

    try:
        if savepath:
            os.makedirs(savepath, exist_ok=True)
        sf.write(path, todo_waveform, samplerate=self.sampling_rate)
        print(f'Waveform saved at -> {path}')
    except Exception as e:
        # Deliberately non-fatal: a failed save must not abort generation.
        print(f'Error saving waveform: {e}')
|
| 1844 |
+
|
| 1845 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1846 |
|
| 1847 |
@torch.no_grad()
|
| 1848 |
def sample_log(
|
|
|
|
| 2041 |
print("Choose the following indexes:", best_index)
|
| 2042 |
except Exception as e:
|
| 2043 |
print("Warning: while calculating CLAP score (not fatal), ", e)
|
| 2044 |
+
self.save_waveform(waveform, savepath="./")
|
| 2045 |
return waveform_save_path
|
| 2046 |
|
| 2047 |
|