# music-api / app.py
# Author: Simonc-44
# Last commit: d6e665d "Update app.py" (verified)
import asyncio
import os
import threading
import uuid

import numpy as np
import scipy.io.wavfile as wavfile
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# --- CONFIGURATION ---
app = FastAPI()
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub id of the checkpoint used for both processor and model.
MODEL_ID = "facebook/musicgen-small"
print(f"🚀 Chargement du modèle MusicGen sur {device}...")

# Load the processor and model once, at import time. If loading fails the app
# still starts; request handlers must check `model is None` before use.
try:
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = MusicgenForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
    print("✅ Modèle chargé avec succès.")
except Exception as e:
    print(f"❌ Erreur critique chargement modèle: {e}")
    # Reset both globals so neither is left undefined or half-initialized
    # (e.g. processor loaded but model failed).
    processor = None
    model = None
class MusicRequest(BaseModel):
    """JSON body accepted by the ``POST /generate`` endpoint.

    Attributes:
        prompt: Text description of the music to generate.
        duration: Requested clip length in seconds; defaults to 10. The
            endpoint clamps it before generating.
    """

    # Free-text prompt passed to the MusicGen processor.
    prompt: str
    # Requested length in seconds.
    duration: int = 10
# Longest clip, in seconds, that a single request may generate.
_MAX_DURATION_S = 30
# Decoder tokens requested per second of audio; presumably matches the ~50 Hz
# frame rate of MusicGen's audio codec — confirm if the checkpoint changes.
_TOKENS_PER_SECOND = 50
# Inference runs in a worker thread so the event loop stays free to serve
# other requests. This lock keeps it to one generation at a time: the model
# and processor are shared globals whose thread-safety is not guaranteed, and
# concurrent generations would multiply memory use.
_generation_lock = threading.Lock()


def _generate_wav(prompt: str, duration: int) -> str:
    """Generate ``duration`` seconds of music for ``prompt`` and save it as WAV.

    Blocking call, meant to run off the event loop.

    Returns:
        The file name of the written WAV, relative to the current working
        directory.
    """
    max_new_tokens = int(duration * _TOKENS_PER_SECOND)
    with _generation_lock:
        inputs = processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to(device)
        audio_values = model.generate(**inputs, max_new_tokens=max_new_tokens)
        sampling_rate = model.config.audio_encoder.sampling_rate
        # Index [0, 0]: the first (and only) prompt of the batch; the second
        # axis is presumably the audio channel — confirm against MusicGen docs.
        audio_data = audio_values[0, 0].cpu().numpy()
    filename = f"music_{uuid.uuid4()}.wav"
    wavfile.write(filename, sampling_rate, audio_data)
    return filename


@app.post("/generate")
async def generate_music(request: MusicRequest):
    """Generate a music clip from a text prompt.

    The requested duration is clamped to the range 1..30 seconds. The WAV is
    written to the working directory, and its URL is returned so the static
    file mount can serve it.

    Returns:
        ``{"audio_url": "/music_<uuid>.wav"}``.

    Raises:
        HTTPException: 500 if the model failed to load or generation fails.
    """
    print(f"🎵 Génération : {request.prompt} ({request.duration}s)")
    if model is None:
        raise HTTPException(status_code=500, detail="Modèle non chargé")
    # Clamp both ends: a zero or negative duration would ask generate() for a
    # non-positive token budget.
    duration = max(1, min(request.duration, _MAX_DURATION_S))
    loop = asyncio.get_running_loop()
    try:
        # Run the blocking inference in the default thread pool instead of
        # stalling the event loop for the whole generation.
        filename = await loop.run_in_executor(
            None, _generate_wav, request.prompt, duration
        )
    except Exception as e:
        print(f"❌ Erreur génération : {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"audio_url": f"/{filename}"}
@app.get("/")
async def read_index():
    """Serve the front-end page, ``index.html``, from the working directory."""
    index_page = "index.html"
    return FileResponse(path=index_page)
# Serve static files, including the generated audio files, from the working
# directory. This is registered after the API routes; Starlette matches routes
# in registration order, so "/" and "/generate" keep precedence over the mount.
# NOTE(review): directory="." also exposes every other file in the working
# directory (including this source file) — consider a dedicated folder.
app.mount(
    "/",
    StaticFiles(directory=".", html=True),
    name="static",
)
if __name__ == "__main__":
    # Listen on all interfaces. Port 7860 is presumably the one the hosting
    # platform expects (Hugging Face Spaces default) — confirm before changing.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)