# music-gen / app.py
# Origin: Hugging Face Space by ken123777 — commit 7fc5540 (verified), "Update app.py"
import gradio as gr
from transformers import AutoProcessor, AutoModelForTextToWaveform
import torch
import librosa
import numpy as np
import scipy.io.wavfile
# Model configuration: MusicGen-Melody supports both text-only and
# text+melody conditioning (the optional humming input in the UI below).
MODEL_ID = "facebook/musicgen-melody"
print(f"Loading Model: {MODEL_ID}...")
# Rely on AutoProcessor (transformers 4.40.2) to resolve the right processor class.
processor = AutoProcessor.from_pretrained(MODEL_ID)
# AutoModel resolves this checkpoint to MusicgenMelodyForConditionalGeneration.
model = AutoModelForTextToWaveform.from_pretrained(MODEL_ID)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Model loaded on {device}")
def generate(text, audio_path, duration, guidance_scale, top_k):
    """Generate music from a text prompt and an optional reference melody.

    Args:
        text: Natural-language description of the desired music.
        audio_path: Optional filesystem path to a melody/humming audio file;
            falsy value means text-only generation.
        duration: Target clip length in seconds.
        guidance_scale: Classifier-free guidance strength for generation.
        top_k: Top-k sampling cutoff. Gradio sliders deliver floats, so this
            is cast to int before being passed to ``model.generate``.

    Returns:
        Path ("output.wav") of the generated 16-bit PCM WAV file.
    """
    melody = None
    # MusicGen's audio encoder operates at 32 kHz; resample the melody to match.
    input_sr = 32000
    if audio_path:
        try:
            melody, _ = librosa.load(audio_path, sr=input_sr, mono=True)
        except Exception as e:
            # Best effort: log and fall back to text-only generation.
            print(f"Audio Load Failed: {e}")

    # Delegate preprocessing to the processor; include the melody only when
    # it was loaded successfully.
    if melody is not None:
        inputs = processor(
            text=[text],
            audio=[melody],
            sampling_rate=input_sr,
            padding=True,
            return_tensors="pt",
        ).to(device)
    else:
        inputs = processor(
            text=[text],
            padding=True,
            return_tensors="pt",
        ).to(device)

    # MusicGen emits ~50 audio tokens per second of generated audio.
    max_new_tokens = int(duration * 50)

    audio_values = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        guidance_scale=guidance_scale,
        do_sample=True,
        top_k=int(top_k),  # GenerationConfig.top_k must be an int; slider yields float
    )

    # Use the model's own output sampling rate (distinct from the input SR).
    out_sr = model.config.audio_encoder.sampling_rate
    audio_data = audio_values[0, 0].cpu().numpy()
    # Clip to [-1, 1] before int16 conversion to avoid wrap-around noise.
    audio_data = np.clip(audio_data, -1.0, 1.0)
    audio_data = (audio_data * 32767).astype(np.int16)
    output_path = "output.wav"
    scipy.io.wavfile.write(output_path, rate=out_sr, data=audio_data)
    return output_path
# UI layout. Labels/placeholders are user-facing Korean strings and are kept
# exactly as authored (they are runtime values, not comments).
with gr.Blocks(title="나만의 MusicGen 서버") as demo:
    gr.Markdown("# 🎵 나만의 AI 작곡가 (MusicGen - Melody Mode)")
    with gr.Row():
        with gr.Column():
            # Inputs: text prompt, optional melody file (path is handed to
            # librosa in generate()), and clip length in seconds.
            txt_input = gr.Textbox(label="음악 설명", placeholder="Jazz, Lo-fi...")
            audio_input = gr.Audio(label="멜로디 입력 (Humming)", type="filepath")
            num_duration = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="길이 (초)")
            # Advanced sampling controls, passed straight through to generate().
            with gr.Accordion("고급 설정", open=False):
                slider_guidance = gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale")
                slider_topk = gr.Slider(minimum=10, maximum=500, value=250, label="Top-k")
            btn = gr.Button("🎵 음악 생성하기", variant="primary")
        with gr.Column():
            # Output: playable generated WAV.
            audio_output = gr.Audio(label="생성 결과")
    # Wire the button to generate(); api_name exposes it as /predict.
    btn.click(generate, inputs=[txt_input, audio_input, num_duration, slider_guidance, slider_topk], outputs=audio_output, api_name="predict")
demo.queue().launch()