# mms_fr/app.py — MMS French text-to-speech API (commit 11e203c, author h-rand)
from fastapi import FastAPI, Response, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import VitsModel, AutoTokenizer
import torch
import scipy.io.wavfile
import io
import os
# FastAPI application instance serving the TTS endpoints defined below.
app = FastAPI()
# Wide-open CORS: any origin/method/header is accepted.
# NOTE(review): fine for a demo Space, too permissive for production — confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- CONFIGURATION ---
# Hugging Face model id of the Meta MMS French VITS text-to-speech checkpoint.
MODEL_ID = "facebook/mms-tts-fra"
# Populated by load_model(); both stay None until the first successful load.
model = None
tokenizer = None
print("⏳ Démarrage du serveur MMS Français...")
def load_model():
    """Load the MMS French VITS model and tokenizer into module globals.

    Idempotent: returns immediately when both globals are already set.

    Returns:
        bool: True on success, False when loading failed. Errors are
        printed rather than raised so the server can still start and a
        later request can retry the load.
    """
    global model, tokenizer
    # Fast path, hoisted out of the try block. Fix: check BOTH globals —
    # the original tested only `model`, which would skip reloading the
    # tokenizer if a previous attempt failed between the two downloads.
    if model is not None and tokenizer is not None:
        return True
    try:
        print(f"📥 Chargement du modèle {MODEL_ID}...")
        # CPU is sufficient for MMS (very lightweight model).
        model = VitsModel.from_pretrained(MODEL_ID)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        print("✅ Modèle MMS Français chargé !")
        return True
    except Exception as e:
        # Best-effort: report the failure instead of crashing startup.
        print(f"❌ Erreur critique chargement : {e}")
        return False
# Eagerly load the model at import time so the first request is not slow;
# failure is tolerated here — the /tts handler retries on demand.
load_model()
@app.post("/tts")
async def generate_speech(data: dict):
    """Synthesize French speech for ``data["text"]`` and return it as WAV.

    Args:
        data: JSON request body; the text to speak is read from key "text".

    Returns:
        Response: raw WAV bytes with media type ``audio/wav``.

    Raises:
        HTTPException: 500 when the model is unavailable or synthesis fails,
            400 when no text was provided.
    """
    # Lazy retry: if startup loading failed, attempt it again per-request.
    if model is None and not load_model():
        raise HTTPException(status_code=500, detail="Modèle indisponible")
    text = data.get("text", "")
    if not text:
        raise HTTPException(status_code=400, detail="Texte vide")
    try:
        # 1. Tokenization
        inputs = tokenizer(text, return_tensors="pt")
        # 2. Inference (no gradients = less RAM)
        with torch.no_grad():
            waveform = model(**inputs).waveform
        # 3. Audio conversion: squeeze the batch dim into a 1-D float array.
        audio_array = waveform.float().numpy().squeeze()
        sample_rate = model.config.sampling_rate
        # 4. Write the WAV file into an in-memory buffer.
        buffer = io.BytesIO()
        scipy.io.wavfile.write(buffer, rate=sample_rate, data=audio_array)
        return Response(content=buffer.getvalue(), media_type="audio/wav")
    except Exception as e:
        print(f"❌ Erreur génération : {e}")
        # Fix: raise a proper HTTPException instead of returning the raw
        # exception text in a text/plain body — consistent with the other
        # error paths and avoids leaking internals to the client.
        raise HTTPException(status_code=500, detail="Erreur génération audio") from e
@app.get("/")
def home():
    """Health-check endpoint: confirms the service is up and which model it serves."""
    status_payload = {"status": "MMS French Ready 🇫🇷"}
    return status_payload