import gradio as gr
import librosa
import numpy as np
import scipy.io.wavfile
import torch
from transformers import AutoProcessor, AutoModelForTextToWaveform
|
|
| |
# Model checkpoint: MusicGen "melody" variant, which can condition generation
# on a reference audio clip in addition to a text prompt.
MODEL_ID = "facebook/musicgen-melody"
print(f"Loading Model: {MODEL_ID}...")

# Processor bundles the text tokenizer and the audio feature extractor
# used to prepare both the prompt and the optional melody input.
processor = AutoProcessor.from_pretrained(MODEL_ID)

model = AutoModelForTextToWaveform.from_pretrained(MODEL_ID)

# Prefer GPU when available — CPU generation works but is very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model loaded on {device}")
|
|
def generate(text, audio_path, duration, guidance_scale, top_k):
    """Generate music from a text prompt and an optional melody clip.

    Args:
        text: Natural-language description of the desired music.
        audio_path: Optional path to a reference melody file; falsy values
            (None / "") fall back to text-only generation.
        duration: Target length in seconds (MusicGen emits ~50 tokens/second).
        guidance_scale: Classifier-free guidance strength.
        top_k: Sampling pool size. Gradio sliders deliver floats, so the
            value is cast to int before being passed to ``generate``.

    Returns:
        Path to the generated 16-bit PCM WAV file.
    """
    melody = None
    input_sr = 32000  # musicgen-melody conditions on 32 kHz mono audio

    if audio_path:
        try:
            # Resample to the conditioning rate and collapse to mono.
            melody, _ = librosa.load(audio_path, sr=input_sr, mono=True)
        except Exception as e:
            # Best effort: if the clip is unreadable, log it and continue
            # with text-only generation instead of failing the request.
            print(f"Audio Load Failed: {e}")
            melody = None

    if melody is not None:
        inputs = processor(
            text=[text],
            audio=[melody],
            sampling_rate=input_sr,
            padding=True,
            return_tensors="pt",
        ).to(device)
    else:
        inputs = processor(
            text=[text],
            padding=True,
            return_tensors="pt",
        ).to(device)

    # MusicGen produces roughly 50 audio tokens per second of output.
    max_new_tokens = int(duration * 50)

    audio_values = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        guidance_scale=guidance_scale,
        do_sample=True,
        top_k=int(top_k),  # transformers expects an int; sliders emit floats
    )

    # The decoder's output rate may differ from the 32 kHz conditioning rate,
    # so read it from the model config rather than reusing input_sr.
    output_sr = model.config.audio_encoder.sampling_rate
    audio_data = audio_values[0, 0].cpu().numpy()

    # Convert the float waveform in [-1, 1] to 16-bit PCM for the WAV file.
    audio_data = np.clip(audio_data, -1.0, 1.0)
    audio_data = (audio_data * 32767).astype(np.int16)

    # NOTE(review): a fixed filename is overwritten by concurrent requests;
    # consider tempfile.NamedTemporaryFile if the queue allows parallelism.
    output_path = "output.wav"
    scipy.io.wavfile.write(output_path, rate=output_sr, data=audio_data)

    return output_path
|
|
| |
# --- Gradio UI ---------------------------------------------------------------
# Left column: prompt, optional melody upload, duration, and sampling knobs.
# Right column: the generated audio. The click handler is also exposed as an
# API endpoint named "predict".
with gr.Blocks(title="나만의 MusicGen 서버") as demo:
    gr.Markdown("# 🎵 나만의 AI 작곡가 (MusicGen - Melody Mode)")

    with gr.Row():
        with gr.Column():
            txt_input = gr.Textbox(label="음악 설명", placeholder="Jazz, Lo-fi...")
            # type="filepath" hands generate() a path on disk for librosa.load.
            audio_input = gr.Audio(label="멜로디 입력 (Humming)", type="filepath")
            num_duration = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="길이 (초)")

            with gr.Accordion("고급 설정", open=False):
                slider_guidance = gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale")
                slider_topk = gr.Slider(minimum=10, maximum=500, value=250, label="Top-k")

            btn = gr.Button("🎵 음악 생성하기", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="생성 결과")

    btn.click(
        generate,
        inputs=[txt_input, audio_input, num_duration, slider_guidance, slider_topk],
        outputs=audio_output,
        api_name="predict",
    )

# queue() serializes long-running generation requests before launching.
demo.queue().launch()