| import uvicorn
|
| from fastapi import FastAPI, Response
|
| from fastapi.middleware.cors import CORSMiddleware
|
| from transformers import MusicgenForConditionalGeneration, AutoProcessor
|
| import torch
|
| import scipy.io.wavfile
|
| import io
|
| import numpy as np
|
|
|
app = FastAPI()

# NOTE(review): with allow_origins=["*"] together with allow_credentials=True,
# Starlette's CORSMiddleware echoes the caller's Origin back rather than "*",
# so any website may issue credentialed requests. Nothing in this file uses
# cookies or auth; allow_credentials=False is probably sufficient — confirm
# with the frontend before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

print("Initializing MusicGen...")

# Shared model state, populated by load_model(). Both start as None; after a
# failed load, `model` holds the sentinel string "ERROR" instead of a model.
model = processor = None
|
|
|
def load_model():
    """Load the MusicGen processor and model into the module globals.

    Does nothing when a model is already loaded. Otherwise loads
    ``facebook/musicgen-small`` in float32 on the CPU and publishes it through
    the globals ``processor`` and ``model``.

    Both ``None`` (never attempted) and ``"ERROR"`` (a previous attempt failed)
    trigger a load, so callers can retry after a failure simply by calling this
    again. Any exception raised while loading is printed and recorded by
    setting ``model`` to the sentinel ``"ERROR"``; it is not re-raised.
    """
    global model, processor

    if model is not None and model != "ERROR":
        return  # already loaded successfully

    try:
        print("Loading Model Weights (CPU Mode)...")
        repo_id = "facebook/musicgen-small"
        # Build into locals and publish to the globals only once every step
        # has succeeded, so a partial load never leaves mismatched state.
        new_processor = AutoProcessor.from_pretrained(repo_id)
        new_model = MusicgenForConditionalGeneration.from_pretrained(
            repo_id, torch_dtype=torch.float32
        )
        new_model.to("cpu")
    except Exception as e:
        # Deliberately broad: this runs at import time, and a failed load must
        # not take the server down; generate_music reports it per request.
        print(f"Load Error: {e}")
        model = "ERROR"
        return

    processor = new_processor
    model = new_model
    print("MusicGen Loaded.")


# Load eagerly at import time so the first /generate request normally finds the
# model ready; generate_music() calls load_model() again if this did not succeed.
load_model()
|
|
|
@app.get("/")
async def home():
    """Liveness endpoint: reply with a fixed banner string."""
    banner = "sSs MusicGen (Transformers) Online."
    return banner
|
|
|
@app.get("/generate")
async def generate_music(prompt: str, duration: int = 10):
    """Generate music from a text prompt and return it as a WAV file.

    Query parameters:
        prompt: free-text description passed to the MusicGen processor.
        duration: requested length in seconds; clamped to the range [2, 30].

    Returns:
        An ``audio/wav`` response on success. A 500 response with a short text
        message if the model could not be loaded ("Model Load Failed") or if
        generation/encoding raised ("Gen Error: ...").
    """
    import asyncio
    import functools
    import traceback

    print(f"Generating: {prompt} ({duration}s)")

    # Bound the request so CPU generation time stays reasonable.
    duration = max(2, min(duration, 30))

    try:
        # Retry loading if startup failed (or never ran); still runs on the
        # event loop, so two requests can never load concurrently.
        if model == "ERROR" or model is None:
            load_model()
        if model == "ERROR":
            return Response(content="Model Load Failed", status_code=500)

        # Assumes ~50 codec frames (decoder steps) per second of audio, the
        # MusicGen EnCodec frame rate — re-check if repo_id ever changes.
        max_tokens = int(duration * 50)

        inputs = processor(text=[prompt], padding=True, return_tensors="pt")

        # model.generate is CPU-bound and can take many seconds. Calling it
        # directly in this async handler would block the event loop and
        # freeze every other request (including "/"). Run it in the default
        # thread-pool executor instead.
        # NOTE(review): concurrent requests now generate in parallel on the
        # shared model; add a semaphore if CPU/memory use becomes a problem.
        loop = asyncio.get_running_loop()
        audio_values = await loop.run_in_executor(
            None,
            functools.partial(model.generate, **inputs, max_new_tokens=max_tokens),
        )

        sampling_rate = model.config.audio_encoder.sampling_rate

        # First batch item, first channel -> 1-D waveform for a mono WAV.
        audio_data = audio_values[0, 0].cpu().numpy()

        buffer = io.BytesIO()
        scipy.io.wavfile.write(buffer, sampling_rate, audio_data)

        return Response(content=buffer.getvalue(), media_type="audio/wav")

    except Exception as e:
        # Top-level request boundary: log the full traceback server-side and
        # report a 500 rather than letting the exception escape.
        traceback.print_exc()
        return Response(content=f"Gen Error: {str(e)}", status_code=500)
|
|
|
if __name__ == "__main__":
    # Serve the app on all network interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")