alppo commited on
Commit
421323e
·
1 Parent(s): 0cc41af

Fix slice overwrite in slicer module

Browse files
__pycache__/mel_module.cpython-312.pyc ADDED
Binary file (6.77 kB). View file
 
__pycache__/slicer_module.cpython-312.pyc CHANGED
Binary files a/__pycache__/slicer_module.cpython-312.pyc and b/__pycache__/slicer_module.cpython-312.pyc differ
 
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import os
2
  import sys
 
3
  import torch
4
  import gradio as gr
5
  from vae_module import VAE, Encoder, Decoder, loss_function
6
  from config import config
7
  from slicer_module import get_slices
8
  from diffusers import UNet2DConditionModel, DDPMScheduler
9
-
10
 
11
  vae = VAE()
12
  vae.load_state_dict(torch.load('vae_model_state_dict.pth', map_location=torch.device('cpu')))
@@ -16,14 +17,45 @@ vae.eval()
16
  model = UNet2DConditionModel.from_pretrained(config.hub_model_id, subfolder="unet")
17
  noise_scheduler = DDPMScheduler.from_pretrained(config.hub_model_id, subfolder="scheduler")
18
 
 
 
19
  def generate_new_track(audio_paths):
20
 
21
  for i, audio_path in enumerate(audio_paths):
 
22
  get_slices(audio_path)
23
-
24
- return
 
 
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Define the Gradio interface
28
  interface = gr.Interface(
29
  fn=generate_new_track,
 
1
  import os
2
  import sys
3
+ import numpy as np
4
  import torch
5
  import gradio as gr
6
  from vae_module import VAE, Encoder, Decoder, loss_function
7
  from config import config
8
  from slicer_module import get_slices
9
  from diffusers import UNet2DConditionModel, DDPMScheduler
10
+ from mel_module import Mel
11
 
12
  vae = VAE()
13
  vae.load_state_dict(torch.load('vae_model_state_dict.pth', map_location=torch.device('cpu')))
 
17
  model = UNet2DConditionModel.from_pretrained(config.hub_model_id, subfolder="unet")
18
  noise_scheduler = DDPMScheduler.from_pretrained(config.hub_model_id, subfolder="scheduler")
19
 
20
+
21
+
22
def generate_new_track(audio_paths):
    """Slice each uploaded audio file, then compute and print the mean latent embedding.

    Parameters
    ----------
    audio_paths : iterable of str or None
        Paths of the uploaded audio files as provided by Gradio.
    """
    # Gradio passes None (or an empty list) when nothing was uploaded; the
    # original loop would raise TypeError on None.
    if not audio_paths:
        return

    for audio_path in audio_paths:
        # Each file is cut into fixed-length .wav slices in the shared
        # 'slices' directory consumed by get_embedding().
        get_slices(audio_path)

    embedding = get_embedding()
    print(embedding)  # NOTE(review): debug output — consider logging instead
30
+
31
 
32
 
33
def get_embedding(slices_dir='slices'):
    """Return the mean latent representation of all sliced audio files.

    Encodes every ``.wav`` slice in ``slices_dir`` through the global VAE,
    normalizes each latent to [-1, 1], and averages them.

    Parameters
    ----------
    slices_dir : str, optional
        Directory containing the ``.wav`` slices (default ``'slices'``).

    Returns
    -------
    torch.Tensor or None
        Mean latent with a leading batch dimension, or ``None`` when the
        directory holds no ``.wav`` files.
    """
    latents = []

    # sorted() makes the traversal order deterministic across filesystems.
    for slice_file in sorted(os.listdir(slices_dir)):
        if not slice_file.endswith('.wav'):  # only audio slices
            continue
        mel = Mel(os.path.join(slices_dir, slice_file))
        spectrogram = mel.get_spectrogram()
        # (H, W) -> (1, 1, H, W): add the batch/channel dims the encoder expects.
        tensor = torch.tensor(spectrogram).float().unsqueeze(0).unsqueeze(0)
        with torch.no_grad():  # inference only — no autograd graph needed
            mu, log_var = vae.encode(tensor)
        latent = torch.cat((mu, log_var), dim=1)
        # Normalize to [-1, 1]; guard the zero-range case (constant latent),
        # which previously produced a division by zero / NaNs.
        min_val = latent.min()
        max_val = latent.max()
        span = max_val - min_val
        if span > 0:
            latent = 2 * ((latent - min_val) / span) - 1
        else:
            latent = torch.zeros_like(latent)
        latents.append(latent.unsqueeze(0))

    if not latents:
        return None

    latents_tensor = torch.cat(latents, dim=0)
    return latents_tensor.mean(dim=0, keepdim=True)
57
+
58
+
59
  # Define the Gradio interface
60
  interface = gr.Interface(
61
  fn=generate_new_track,
mel_module.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from config import config
3
+ import numpy as np
4
+ import librosa
5
+ from PIL import Image
6
+
7
+ import warnings
8
+ warnings.filterwarnings("ignore", category=UserWarning, module='librosa')
9
+
10
class Mel:
    """Mel-spectrogram wrapper converting between an audio file, a spectrogram
    array, and a PIL image.

    Exactly one of ``file_path``, ``spectrogram``, or ``image`` should be
    supplied; the constructor derives the other representations from it.
    """

    def __init__(
        self,
        file_path: Optional[str] = None,
        spectrogram: Optional[np.ndarray] = None,
        image: Optional[Image.Image] = None,
        x_res: int = config.image_size,
        y_res: int = config.image_size,
        sample_rate: int = config.sample_rate,
        n_fft: int = 2048,
        hop_length: int = 882,
        top_db: int = 80,
        n_iter: int = 32,
    ):
        """Store the spectrogram parameters and load the given source.

        Parameters
        ----------
        file_path : str, optional
            Path to a ``.wav`` audio file to convert to a spectrogram.
        spectrogram : np.ndarray, optional
            Pre-computed spectrogram (uint8-compatible values).
        image : PIL.Image.Image, optional
            Spectrogram image to convert back to an array.
        x_res, y_res : int
            Output spectrogram width (time frames) and height (mel bins).
        sample_rate : int
            Sample rate used for loading and inversion.
        n_fft, hop_length, top_db, n_iter : int
            STFT / dB-scaling / Griffin-Lim parameters.
        """
        self.hop_length = hop_length
        self.sr = sample_rate
        self.n_fft = n_fft
        self.top_db = top_db
        self.n_iter = n_iter
        self.x_res = x_res
        self.y_res = y_res
        self.n_mels = self.y_res
        self.slice_size = self.x_res * self.hop_length - 1
        self.file_path = file_path
        self.spectrogram = spectrogram
        self.image = image

        if file_path is not None and not isinstance(file_path, str):
            raise ValueError("file_path must be a string")
        if spectrogram is not None and not isinstance(spectrogram, np.ndarray):
            raise ValueError("spectrogram must be an ndarray")
        if image is not None and not isinstance(image, Image.Image):
            raise ValueError("image must be a PIL Image")

        # Dispatch on whichever source was supplied (file wins, then image,
        # then raw spectrogram) — order preserved from the original code.
        if file_path is not None:
            self.load_file()
        elif image is not None:
            self.load_spectrogram()
        elif spectrogram is not None:
            self.load_image()
        else:
            print("Both file path and image are None!")

    def load_file(self):
        """Load the audio file and compute its uint8 mel spectrogram and image."""
        try:
            # BUGFIX: use a suffix test instead of a substring test so that
            # names like 'x.wav.bak' are not treated as wav files.
            if self.file_path.endswith(".wav"):
                audio, _ = librosa.load(self.file_path, mono=True, sr=self.sr)
                # Pad so at least x_res frames of spectrogram are available.
                if len(audio) < self.x_res * self.hop_length:
                    audio = np.concatenate([audio, np.zeros((self.x_res * self.hop_length - len(audio),))])
                # Compute mel spectrogram
                S = librosa.feature.melspectrogram(
                    y=audio, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels, fmax=self.sr//2
                )
                log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
                log_S = log_S[:self.y_res, :self.x_res]  # crop to the desired size
                # Map [-top_db, 0] dB to [0, 255] with rounding (+0.5 then truncate).
                self.spectrogram = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
                self.image = Image.fromarray(self.spectrogram)
        except Exception as e:
            # Best-effort load: report and leave spectrogram/image unset.
            print(f"Error loading {self.file_path}: {e}")

    def load_spectrogram(self):
        """Derive the spectrogram array from the stored PIL image."""
        self.spectrogram = np.array(self.image)

    def load_image(self):
        """Derive the PIL image from the stored spectrogram array."""
        self.spectrogram = self.spectrogram.astype("uint8")
        self.image = Image.fromarray(self.spectrogram)

    def get_spectrogram(self):
        """Return the spectrogram as a uint8 ndarray (or None if not loaded)."""
        return self.spectrogram

    def get_image(self):
        """Return the spectrogram as a PIL image (or None if not loaded)."""
        return self.image

    def get_audio(self):
        """Invert the stored spectrogram back to a waveform via Griffin-Lim.

        Returns
        -------
        np.ndarray
            The reconstructed audio signal at ``self.sr``.
        """
        # Undo the uint8 [0, 255] -> dB [-top_db, 0] mapping from load_file().
        log_S = self.spectrogram.astype("float") * self.top_db / 255 - self.top_db
        S = librosa.db_to_power(log_S)
        audio = librosa.feature.inverse.mel_to_audio(
            S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
        )
        # BUGFIX: the original called display(Audio(...)) — IPython-only names
        # that are never imported here, so this always raised NameError.
        # Return the waveform and let the caller decide how to present it.
        return audio

    def plot_spectrogram(self):
        """Render the spectrogram with matplotlib."""
        # BUGFIX: plt was referenced but never imported (NameError). Import
        # lazily so matplotlib is only required when plotting is requested.
        import matplotlib.pyplot as plt
        plt.figure(figsize=(10, 4))
        plt.imshow(self.spectrogram, aspect='auto', origin='lower', cmap='viridis')
        plt.colorbar(label='Magnitude')
        plt.title('Mel Spectrogram')
        plt.xlabel('Time (frames)')
        plt.ylabel('Frequency (Mel bins)')
        plt.tight_layout()
        plt.show()
requirements.txt CHANGED
@@ -1,4 +1,8 @@
 
1
  diffusers
2
  torch
3
  librosa
4
- soundfile
 
 
 
 
1
+ accelerate
2
  diffusers
3
  torch
4
  librosa
5
+ soundfile
6
+ gradio
7
+ pillow
8
+ numpy
slicer_module.py CHANGED
def get_slices(file_path, sample_rate=44100, slice_duration=10, output_dir='slices'):
    """Split an audio file into fixed-length .wav slices without overwriting
    slices written by earlier calls.

    Parameters
    ----------
    file_path : str
        Audio file to slice.
    sample_rate : int, optional
        Resample rate passed to librosa (default 44100).
    slice_duration : int, optional
        Length of each slice in seconds (default 10).
    output_dir : str, optional
        Directory receiving ``slice_NNNN.wav`` files (default ``'slices'``).
    """
    os.makedirs(output_dir, exist_ok=True)

    audio, _ = librosa.load(file_path, sr=sample_rate)
    slice_samples = slice_duration * sample_rate

    # Continue numbering after any existing slices so repeated calls append
    # instead of overwriting. BUGFIX: tolerate .wav files that do not match
    # the slice_NNNN.wav pattern — the original int(...) parse crashed on them.
    start_index = 0
    for name in os.listdir(output_dir):
        if not name.endswith('.wav'):
            continue
        try:
            idx = int(name.split('_')[1].split('.')[0])
        except (IndexError, ValueError):
            continue  # foreign .wav file — skip rather than crash
        start_index = max(start_index, idx + 1)

    # Write every full-length slice.
    num_slices = len(audio) // slice_samples
    for i in range(num_slices):
        start_sample = i * slice_samples
        audio_slice = audio[start_sample:start_sample + slice_samples]
        output_file = os.path.join(output_dir, f'slice_{start_index + i:04d}.wav')
        sf.write(output_file, audio_slice, sample_rate)

    # Write the trailing partial slice, if any.
    if len(audio) % slice_samples != 0:
        start_sample = num_slices * slice_samples
        output_file = os.path.join(output_dir, f'slice_{start_index + num_slices:04d}.wav')
        sf.write(output_file, audio[start_sample:], sample_rate)
32
 
33
  if __name__ == "__main__":