File size: 3,680 Bytes
0cc41af
 
421323e
0cc41af
a3e78b9
0cc41af
 
 
 
421323e
e599c74
 
 
 
 
 
 
a3e78b9
0cc41af
 
 
 
 
 
 
 
6f849d1
0cc41af
6f849d1
0cc41af
421323e
 
6f849d1
 
 
e599c74
 
 
 
421323e
 
 
6f849d1
421323e
6f849d1
421323e
 
 
 
 
 
 
 
 
 
6f849d1
421323e
 
6f849d1
421323e
 
 
e599c74
421323e
0cc41af
 
 
 
 
6f849d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cc41af
6f849d1
 
 
 
0cc41af
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import sys
import numpy as np
import torch
import gradio as gr
from vae_module import VAE, Encoder, Decoder, loss_function
from config import config
from slicer_module import get_slices
from diffusers import UNet2DConditionModel, DDPMScheduler
from mel_module import Mel
from generator_module import Generator
import shutil

# Working folder for audio slices; get_embedding() reads its .wav files from this
# same path (presumably written there by get_slices -- TODO confirm in slicer_module).
slices_folder = 'slices'

if os.path.exists(slices_folder): # delete slices left over from a previous run
    shutil.rmtree(slices_folder)

# Build the VAE, load its weights from the local checkpoint (tensors are mapped to
# CPU while loading), then move the module to the configured device and put it in
# eval mode.
vae = VAE()
vae.load_state_dict(torch.load('vae_model_state_dict.pth', map_location=torch.device('cpu')))
vae.to(config.device)
vae.eval()

# Load the UNet and its noise scheduler from the "unet" and "scheduler" subfolders
# of the pretrained repo named by config.hub_model_id.
model = UNet2DConditionModel.from_pretrained(config.hub_model_id, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(config.hub_model_id, subfolder="scheduler")

def generate_new_track(audio_paths, progress=gr.Progress(track_tqdm=True)):
    """Generate a new track conditioned on the uploaded audio files.

    Every uploaded file is passed to ``get_slices``; ``get_embedding`` then
    turns the resulting slices into a single mean latent, which is handed to
    ``Generator`` together with the UNet, scheduler and VAE.

    Args:
        audio_paths: Uploaded files as delivered by the ``gr.Files`` input.
            A falsy value (``None`` or an empty list) is rejected.
        progress: Gradio progress tracker, forwarded to the generator as
            ``progress_callback``.

    Returns:
        ``config.generated_track_path`` once generation has finished.

    Raises:
        gr.Error: If nothing was uploaded, or if no ``.wav`` slice could be
            encoded. Gradio shows the message to the user instead of an
            opaque traceback.
    """
    if not audio_paths:
        raise gr.Error("Please upload at least one .wav audio file.")

    # The slices folder is otherwise only cleared once, at import time, so
    # without this a second request would also average in the slices of every
    # earlier upload. The folder does not exist on the very first call either,
    # so removing it here recreates the state get_slices already copes with.
    # NOTE(review): assumes get_slices writes into `slices_folder` -- confirm.
    if os.path.exists(slices_folder):
        shutil.rmtree(slices_folder)

    for audio_path in audio_paths:
        print(audio_paths, audio_path)
        get_slices(audio_path)

    embedding = get_embedding()
    if embedding is None:
        # get_embedding returns None when it found no .wav slices to encode.
        raise gr.Error(
            "No audio slices could be extracted from the uploaded files. "
            "Please upload audio in .wav format."
        )
    print("sample latent", embedding.shape)

    generator = Generator(config, model, noise_scheduler, vae, embedding, progress_callback=progress)
    generator.generate()

    return config.generated_track_path

def get_embedding(): # returns middle point of given audio files latent representations
    """Return the mean latent of all ``.wav`` slices in the slices folder.

    Each slice is converted to a spectrogram via ``Mel``, encoded by the VAE,
    and its ``mu``/``log_var`` outputs are concatenated along dim 1. The
    result is min-max rescaled to the range [-1, 1] per slice, and the
    rescaled latents are averaged over all slices.

    Returns:
        The mean latent tensor (the stacking dimension is kept with size 1),
        or ``None`` when the slices folder is missing or holds no ``.wav``
        file.
    """
    slices_dir = 'slices'

    # The folder is removed at startup, so it may legitimately be absent;
    # treat that exactly like "no slices found" rather than raising.
    if not os.path.isdir(slices_dir):
        return None

    latents = []

    # Pure inference: skip autograd bookkeeping for every encoded slice.
    with torch.no_grad():
        for slice_file in os.listdir(slices_dir):
            if not slice_file.endswith('.wav'): # make sure the file is audio
                continue

            mel = Mel(os.path.join(slices_dir, slice_file))
            spectrogram = mel.get_spectrogram()
            # Two leading singleton dimensions are added before encoding.
            tensor = torch.tensor(spectrogram).float().unsqueeze(0).unsqueeze(0)
            # The VAE lives on config.device, so its input must as well;
            # otherwise any non-CPU device fails with a device mismatch.
            tensor = tensor.to(config.device)

            mu, log_var = vae.encode(tensor)
            latent = torch.cat((mu, log_var), dim=1)

            # Min-max rescale this slice's latent to [-1, 1].
            min_val = latent.min()
            max_val = latent.max()
            normalized_tensor = 2 * ((latent - min_val) / (max_val - min_val)) - 1
            latents.append(normalized_tensor.unsqueeze(0))

    if not latents:
        return None

    latents_tensor = torch.cat(latents, dim=0)
    return latents_tensor.mean(dim=0, keepdim=True)


# Input: one or more uploaded audio files; output: path of the generated track.
upload_input = gr.Files(file_count="multiple", label="Upload Your Audio Files")
track_output = gr.Audio(type="filepath", label="Generated Track")

# HTML shown under the title; adjacent literals are concatenated into one string.
app_description = (
    "<h3>Welcome to the AMUSE music generation app</h3>"
    "<p>Here's how it works:</p>"
    "<ol>"
    "<li><strong>Upload Your Audio Files:</strong> Provide audio files from which the taste will be extracted, "
    "and a new track will be generated accordingly. The audio files should be in .wav format!</li>"
    "<li><strong>Process:</strong> The app slices the audio, extracts features, and generates a new track using a VAE and a diffusion model.</li>"
    "<li><strong>Progress:</strong> The progress bar will show the generation process in real-time. Note that this takes a significant amount of time, "
    "so you may leave the site in the free version and come back later to see the result.</li>"
    "<li><strong>Download:</strong> Once the track is generated, you can download it directly.</li>"
    "</ol>"
    "<h4>Notes:</h4>"
    "<ul>"
    "<li>As mentioned earlier, it takes a significant amount of time to generate a new track in the free version of HF Spaces. "
    "So, submit your tracks and forget about it for a little while :) Then come back to see the new track.</li>"
    "<li>Ensure your audio files are clean and of good quality for the best results (sample rate: 44100 and .wav format).</li>"
    "</ul>"
)

# Wire generate_new_track into a simple upload -> audio interface.
interface = gr.Interface(
    fn=generate_new_track,
    inputs=upload_input,
    outputs=track_output,
    title="AMUSE: Music Generation",
    description=app_description,
)

# Start the web app as soon as the module is executed.
interface.launch()