Spaces:
Sleeping
Sleeping
| import os | |
| import warnings | |
| import traceback | |
| warnings.simplefilter("ignore") | |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | |
| import io | |
| import torch | |
| import numpy as np | |
| from audiocraft.models import musicgen | |
| from scipy.io.wavfile import write as wav_write | |
| try: | |
| from logger import logging | |
| except: | |
| import logging | |
class GenerateAudio:
    """Text-to-music helper built on a pretrained MusicGen model.

    Typical use: construct the object (which loads the model), call
    :meth:`generate_audio` with one or more text prompts, then persist the
    result with :meth:`save_audio` or :meth:`get_audio_buffer`.
    """

    def __init__(self, model="musicgen-stereo-small"):
        """Load the requested model onto the GPU if available, else the CPU.

        Args:
            model: Model name, with or without the ``facebook/`` prefix.

        Raises:
            ValueError: If the model cannot be loaded.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = self.get_model_name(model)
        self.model = self.get_model(self.model_name, self.device)
        # Kept for backward compatibility; the generated audio itself is
        # stored in ``self.result``.
        self.generated_audio = None
        # ``save_audio`` / ``get_audio_buffer`` test ``self.result is None``,
        # so it must exist before the first generation.
        self.result = None
        self.sampling_rate = None

    @staticmethod
    def get_model(model, device):
        """Return the pretrained MusicGen model ``model`` placed on ``device``.

        Raises:
            ValueError: If loading fails for any reason (the original
                exception is chained as the cause).
        """
        try:
            loaded = musicgen.MusicGen.get_pretrained(model, device=device)
            logging.info(f"Loaded model: {loaded}")
            return loaded
        except Exception as e:
            logging.error(
                f"Failed to load model: {e}, Traceback: {traceback.format_exc()}"
            )
            raise ValueError(f"Failed to load model: {e}") from e

    @staticmethod
    def get_model_name(model_name):
        """Return ``model_name`` prefixed with ``facebook/`` unless already so."""
        if model_name.startswith("facebook/"):
            return model_name
        return f"facebook/{model_name}"

    @staticmethod
    def duration_sanity_check(duration):
        """Clamp ``duration`` (seconds) to the range [1, 30], warning on change."""
        if duration < 1:
            logging.warning(
                "Duration is less than 1 second. Setting duration to 1 second."
            )
            return 1
        if duration > 30:
            logging.warning(
                "Duration is greater than 30 seconds. Setting duration to 30 seconds."
            )
            return 30
        return duration

    @staticmethod
    def prompts_sanity_check(prompts):
        """Normalise ``prompts`` to a list of strings.

        Args:
            prompts: A single string or a list of strings.

        Returns:
            A list of prompt strings (a lone string is wrapped in a list).

        Raises:
            ValueError: If ``prompts`` is neither a string nor a list of
                strings, or if more than 8 prompts are given.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list):
            raise ValueError("Prompts should be a string or a list of strings.")
        elif not all(isinstance(prompt, str) for prompt in prompts):
            raise ValueError("Prompts should be a string or a list of strings.")
        if len(prompts) > 8:  # Too many prompts will cause OOM error
            raise ValueError("Maximum number of prompts allowed is 8.")
        return prompts

    def generate_audio(self, prompts, duration=10):
        """Generate one audio clip per prompt.

        Args:
            prompts: A string or a list of up to 8 strings.
            duration: Requested clip length in seconds; clamped to [1, 30].

        Returns:
            Tuple ``(sampling_rate, audio)`` where ``audio`` is a NumPy array
            holding one clip per prompt along its first axis.

        Raises:
            ValueError: If the prompts are invalid or generation fails.
        """
        duration = self.duration_sanity_check(duration)
        prompts = self.prompts_sanity_check(prompts)
        try:
            self.sampling_rate = self.model.sample_rate
            # ``duration`` is already clamped to at most 30 seconds, so a
            # single generation pass always suffices.
            self.model.set_generation_params(duration=duration)
            generated = self.model.generate(prompts, progress=False)
            # Swap the last two axes (the net effect of the former
            # ``.T`` followed by ``transpose((2, 0, 1))``). Assuming the model
            # returns (batch, channels, samples) -- TODO confirm -- each clip
            # becomes (samples, channels), the layout scipy's WAV writer takes.
            self.result = generated.cpu().numpy().transpose((0, 2, 1))
            logging.info(
                f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
            )
            return self.sampling_rate, self.result
        except Exception as e:
            logging.error(
                f"Failed to generate audio: {e}, Traceback: {traceback.format_exc()}"
            )
            raise ValueError(f"Failed to generate audio: {e}") from e

    def _require_audio(self):
        """Raise ``ValueError`` unless audio and its sampling rate are set."""
        if self.result is None:
            raise ValueError("Audio is not generated yet.")
        if self.sampling_rate is None:
            raise ValueError("Sampling rate is not available.")

    def save_audio(self, audio_dir="generated_audio"):
        """Write every generated clip to ``audio_dir`` as ``audio_<i>.wav``.

        The directory is created if missing; existing files with the same
        names are overwritten.

        Returns:
            List of the written file paths, in clip order.

        Raises:
            ValueError: If no audio has been generated yet.
        """
        self._require_audio()
        os.makedirs(audio_dir, exist_ok=True)
        paths = []
        for i, audio in enumerate(self.result):
            path = os.path.join(audio_dir, f"audio_{i}.wav")
            wav_write(path, self.sampling_rate, audio)
            paths.append(path)
        return paths

    def get_audio_buffer(self):
        """Return every generated clip as an in-memory WAV file.

        Returns:
            List of ``io.BytesIO`` objects, each rewound to position 0.

        Raises:
            ValueError: If no audio has been generated yet.
        """
        self._require_audio()
        buffers = []
        for audio in self.result:
            buffer = io.BytesIO()
            wav_write(buffer, self.sampling_rate, audio)
            buffer.seek(0)
            buffers.append(buffer)
        return buffers
if __name__ == "__main__":
    # Demo run: generate three 10-second clips, write them to disk, then
    # also fetch them as in-memory WAV buffers.
    demo_prompts = [
        "A piano playing a jazz melody",
        "A guitar playing a rock riff",
        "A LoFi music for coding",
    ]
    generator = GenerateAudio()
    sample_rate, result = generator.generate_audio(demo_prompts, duration=10)

    saved_paths = generator.save_audio()
    print(f"Saved audio to: {saved_paths}")

    wav_buffers = generator.get_audio_buffer()
    print(f"Audio buffers: {wav_buffers}")