mg_tts / app.py
h-rand's picture
Update app.py
1c46a93 verified
from fastapi import FastAPI, Response, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import VitsModel, AutoTokenizer
import torch
import scipy.io.wavfile
import io
import numpy as np
import os
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# --- CONFIGURATION ---
# Le projet original utilise "facebook/mms-tts-mlg"
MODEL_ID = "facebook/mms-tts-mlg"
model = None
tokenizer = None
print("⏳ Démarrage du serveur Malagasy TTS...")
def load_model():
global model, tokenizer
try:
if model is not None: return True
print(f"📥 Chargement du modèle {MODEL_ID}...")
# On utilise le CPU pour le plan gratuit (suffisant pour ce modèle léger)
model = VitsModel.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("✅ Modèle Malgache chargé !")
return True
except Exception as e:
print(f"❌ Erreur critique chargement : {e}")
return False
# Chargement au démarrage
load_model()
@app.post("/tts")
async def generate_speech(request: Request, data: dict):
# 🛡️ SÉCURITÉ (Décommente ces 3 lignes si tu veux bloquer les accès hors Cloudflare)
# client_token = request.headers.get("x-dynamic-token")
# if not client_token:
# raise HTTPException(status_code=403, detail="Accès refusé")
# Rechargement si nécessaire (Cold start)
if model is None:
if not load_model():
raise HTTPException(status_code=500, detail="Modèle indisponible")
text = data.get("text", "")
if not text:
raise HTTPException(status_code=400, detail="Texte vide")
print(f"🇲🇬 Génération pour : {text[:30]}...")
try:
# 1. Tokenization
inputs = tokenizer(text, return_tensors="pt")
# 2. Inférence (Sans gradient pour économiser la mémoire)
with torch.no_grad():
output = model(**inputs).waveform
# 3. Conversion Audio & COMPRESSION
# Le modèle sort du float32 (très lourd)
audio_array = output.float().numpy().squeeze()
sample_rate = model.config.sampling_rate
# --- 🚀 OPTIMISATION : Division de la taille par 2 ---
# Normalisation (Met la voix au volume maximum sans grésiller)
max_amp = np.max(np.abs(audio_array))
if max_amp > 0:
audio_array = audio_array / max_amp
# Conversion de Float32 (32-bits) vers Int16 (16-bits)
audio_int16 = (audio_array * 32767.0).astype(np.int16)
# -----------------------------------------------------
# 4. Écriture WAV en mémoire
buffer = io.BytesIO()
scipy.io.wavfile.write(buffer, rate=sample_rate, data=audio_int16)
buffer.seek(0)
return Response(content=buffer.read(), media_type="audio/wav")
except Exception as e:
print(f"❌ Erreur génération : {e}")
return Response(content=str(e), status_code=500)
@app.get("/")
def home():
status = "Ready ✅" if model else "Error ❌"
return {"status": status, "lang": "mlg (Malagasy)", "optimized": "Int16 Compression Active"}