m_gen1 / main.py
JAKTEch's picture
Upload 3 files
4af6544 verified
import uvicorn
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from transformers import MusicgenForConditionalGeneration, AutoProcessor
import torch
import scipy.io.wavfile
import io
import numpy as np
# Application object; the route decorators in this file register on it.
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted, with
# credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is as
# permissive as CORS gets -- restrict allow_origins if this service ever
# starts relying on cookies or other browser credentials.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
print("Initializing MusicGen...")

# Shared inference state, populated by load_model(). Both start out unset;
# after a failed load `model` holds the sentinel string "ERROR" instead of a
# model object.
model = processor = None
def load_model():
    """Load the MusicGen processor and model into the module-level globals.

    Does nothing when a model is already loaded. A previous failed attempt
    (recorded as the sentinel string ``"ERROR"`` in ``model``) is treated the
    same as "not loaded yet", so calling this function again retries the load.

    Never raises: on any exception the error is printed and ``model`` is set
    to ``"ERROR"`` so callers can detect the failure.
    """
    global model, processor

    # A usable model is already in place: nothing to do. The "ERROR" sentinel
    # must NOT count as loaded, otherwise a failed load could never be retried.
    if model is not None and model != "ERROR":
        return

    try:
        print("Loading Model Weights (CPU Mode)...")
        # We use the 'small' model first to GUARANTEE it works.
        # You can change "facebook/musicgen-small" to "facebook/musicgen-large"
        # later if this succeeds.
        repo_id = "facebook/musicgen-small"
        loaded_processor = AutoProcessor.from_pretrained(repo_id)
        loaded_model = MusicgenForConditionalGeneration.from_pretrained(
            repo_id, torch_dtype=torch.float32
        )
        loaded_model.to("cpu")
    except Exception as e:
        print(f"Load Error: {e}")
        model = "ERROR"
    else:
        # Publish both objects only once everything succeeded, so readers
        # never observe a processor paired with a half-loaded model.
        processor = loaded_processor
        model = loaded_model
        print("MusicGen Loaded.")
# Attempt the load once at import time. load_model() catches its own
# exceptions, so a failed load does not prevent the module from importing;
# the failure is recorded in `model` instead.
load_model()
@app.get("/")
async def home():
    """Root endpoint: return a fixed status string showing the app is up."""
    status_message = "sSs MusicGen (Transformers) Online."
    return status_message
def _render_wav(prompt, duration):
    """Generate audio for ``prompt`` and return it as WAV-encoded bytes.

    This is the blocking part of the request: it (re)tries loading the model
    when none is available and then runs generation.

    Args:
        prompt: Text description passed to the processor.
        duration: Requested length in seconds (already clamped by the caller).

    Returns:
        The WAV file contents as ``bytes``, or ``None`` when the model could
        not be loaded.
    """
    if model is None or model == "ERROR":
        load_model()
    if model is None or model == "ERROR":
        return None

    # Calculate tokens (roughly 50 tokens per second of audio)
    max_tokens = int(duration * 50)

    # Process text
    inputs = processor(
        text=[prompt],
        padding=True,
        return_tensors="pt",
    )

    # Generate audio
    audio_values = model.generate(**inputs, max_new_tokens=max_tokens)

    # The sample rate is read from the model config rather than hard-coded.
    sampling_rate = model.config.audio_encoder.sampling_rate

    # Take element [0, 0] of the output -- presumably first batch item, first
    # channel of a (batch, channels, samples) tensor; TODO confirm.
    audio_data = audio_values[0, 0].cpu().numpy()

    # Encode as WAV entirely in memory.
    buffer = io.BytesIO()
    scipy.io.wavfile.write(buffer, sampling_rate, audio_data)
    return buffer.getvalue()


@app.get("/generate")
async def generate_music(prompt: str, duration: int = 10):
    """Generate music for ``prompt`` and return it as a WAV response.

    Args:
        prompt: Text description of the music to generate.
        duration: Requested length in seconds; clamped to the range 2..30.

    Returns:
        A ``Response`` with media type ``audio/wav`` on success, or a
        status-500 ``Response`` whose body describes the failure (model load
        failure or an exception during generation).
    """
    import asyncio
    import traceback

    print(f"Generating: {prompt} ({duration}s)")
    duration = max(2, min(duration, 30))

    try:
        # Model loading and generation are blocking, CPU-bound calls. Running
        # them directly in this coroutine would stall the event loop (and
        # every other request) until they finish, so hand them to a worker
        # thread and await the result.
        loop = asyncio.get_running_loop()
        wav_bytes = await loop.run_in_executor(
            None, _render_wav, prompt, duration
        )
    except Exception as e:
        traceback.print_exc()
        return Response(content=f"Gen Error: {str(e)}", status_code=500)

    if wav_bytes is None:
        return Response(content="Model Load Failed", status_code=500)
    return Response(content=wav_bytes, media_type="audio/wav")
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    # NOTE(review): port 7860 is presumably chosen to match the hosting
    # platform's expected app port -- confirm before changing.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )