|
|
import os
import threading
import uuid

import numpy as np
import scipy.io.wavfile as wavfile
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
|
|
|
|
|
|
|
app = FastAPI()

# Hugging Face checkpoint used for both the processor and the model.
MODEL_ID = "facebook/musicgen-small"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"🚀 Chargement du modèle MusicGen sur {device}...")

# Load once at import time. If anything fails, the server still starts with
# BOTH names bound to None (previously `processor` could be left undefined);
# the /generate handler checks `model` and answers with an HTTP 500.
try:
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = MusicgenForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
    print("✅ Modèle chargé avec succès.")
except Exception as e:
    print(f"❌ Erreur critique chargement modèle: {e}")
    processor = None
    model = None
|
|
|
|
|
class MusicRequest(BaseModel):
    """JSON body accepted by ``POST /generate``."""

    # Free-text description of the music to generate.
    prompt: str
    # Requested clip length in seconds; upper-bounded by the /generate handler.
    duration: int = 10
|
|
|
|
|
# Accepted range for the requested clip length, in seconds.
_MIN_DURATION_S = 1
_MAX_DURATION_S = 30

# Serializes calls to model.generate so concurrent requests never run several
# generations at once on the shared model/device (e.g. to bound GPU memory).
_generation_lock = threading.Lock()


def _generate_wav(prompt: str, duration: int) -> str:
    """Synthesize ``duration`` seconds of audio for ``prompt`` and save it as a WAV.

    Fully synchronous (model inference plus a disk write), so it must be run
    off the event loop. Returns the name of the file written to the current
    working directory.
    """
    # Audio tokens are produced at the codec's frame rate; 50/s is the rate the
    # original code hard-coded for musicgen-small and is kept as a fallback.
    frame_rate = getattr(model.config.audio_encoder, "frame_rate", 50)
    max_new_tokens = int(duration * frame_rate)

    inputs = processor(
        text=[prompt],
        padding=True,
        return_tensors="pt",
    ).to(device)

    with _generation_lock:
        audio_values = model.generate(**inputs, max_new_tokens=max_new_tokens)

    sampling_rate = model.config.audio_encoder.sampling_rate
    # First batch item, first channel; presumably audio_values is
    # (batch, channels, samples) — TODO confirm for stereo checkpoints.
    audio_data = audio_values[0, 0].cpu().numpy()

    filename = f"music_{uuid.uuid4()}.wav"
    wavfile.write(filename, sampling_rate, audio_data)
    return filename


@app.post("/generate")
async def generate_music(request: MusicRequest):
    """Generate a music clip from a text prompt.

    The requested duration is clamped to [_MIN_DURATION_S, _MAX_DURATION_S]
    seconds. Inference runs in a worker thread so the event loop keeps serving
    other requests while a clip is being generated.

    Returns:
        ``{"audio_url": "/<name>.wav"}`` pointing at the WAV file written to the
        working directory.

    Raises:
        HTTPException: 500 if the model failed to load or generation fails
            (the exception text is passed through as ``detail``).
    """
    print(f"🎵 Génération : {request.prompt} ({request.duration}s)")

    if model is None:
        raise HTTPException(status_code=500, detail="Modèle non chargé")

    # Clamp on both sides: a zero/negative duration would otherwise request a
    # non-positive number of tokens and fail inside the model.
    duration = max(_MIN_DURATION_S, min(request.duration, _MAX_DURATION_S))

    try:
        filename = await run_in_threadpool(_generate_wav, request.prompt, duration)
    except Exception as e:
        print(f"❌ Erreur génération : {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # NOTE(review): generated files are never deleted and accumulate in the
    # working directory; consider a cleanup policy.
    return {"audio_url": f"/{filename}"}
|
|
|
|
|
@app.get("/")
async def read_index():
    """Serve the front-end page ``index.html`` from the working directory."""
    index_page = "index.html"
    return FileResponse(index_page)
|
|
|
|
|
|
|
|
# Serve the working directory at "/" (front-end assets and the WAV files that
# /generate writes there). html=True lets directory paths resolve to an
# index.html when one exists.
# NOTE(review): directory="." publishes *every* file in the working directory,
# including this server's own source; consider moving assets and generated
# audio into a dedicated folder.
app.mount(
    "/",
    StaticFiles(directory=".", html=True),
    name="static",
)
|
|
|
|
|
if __name__ == "__main__":
    # Listen on 0.0.0.0:7860 by default; HOST / PORT environment variables can
    # override this, e.g. when a hosting platform assigns the port.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "7860")),
    )
|
|
|