MarcosFRGames's picture
Update app.py
15d78d3 verified
from flask import Flask, request, jsonify, Response, send_file
import os
import json
import logging
import threading
import tempfile
import time
import gc
import torch
import numpy as np
from datetime import datetime
import requests
from concurrent.futures import ThreadPoolExecutor
import io
import soundfile as sf
# Basic logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
# Load the model configuration: a JSON list of model entries consumed by
# TTSManager (each with at least "id" and "url").
# Explicit UTF-8 so the config decodes identically on every platform instead of
# depending on the locale's default encoding.
with open('engines.json', 'r', encoding='utf-8') as f:
    TTS_MODELS = json.load(f)
# Service limits
MAX_AUDIO_LENGTH = 30  # maximum seconds of generated audio
MAX_TEXT_LENGTH = 500  # maximum number of input characters
class TTSManager:
    """Download, load and run the TTS models listed in the JSON configuration.

    Each config entry is downloaded from its ``url``, loaded by the loader for
    its ``type`` ("transformers" by default, "coqui" or "speecht5") and stored
    in ``self.models`` under its ``id``. Generation is serialized by a
    non-blocking lock: a request that arrives while another generation is in
    progress is rejected with an error dict instead of being queued.
    """

    # Seconds generate_speech() waits for one generation before giving up.
    GENERATION_TIMEOUT = 120
    # Seed for the fallback SpeechT5 speaker embedding, so requests without
    # "speaker_embeddings" always get the same voice instead of a random one.
    DEFAULT_SPEAKER_SEED = 0

    def __init__(self, models_config):
        """Create the HTTP session and eagerly load every configured model.

        Args:
            models_config: list of dicts with at least ``id`` and ``url`` and
                optionally ``type`` (default ``"transformers"``).
        """
        self.models = {}
        self.models_config = models_config
        # Not used by the generation path below; kept in case callers rely on it.
        self.executor = ThreadPoolExecutor(max_workers=2)
        # Only one generation may run at a time (acquired non-blocking).
        self.generation_lock = threading.Lock()
        self.session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(pool_connections=2, pool_maxsize=2)
        self.session.mount('http://', adapter)
        self.session.mount('https://', adapter)
        self.load_all_models()

    def load_all_models(self):
        """Download and load every model in ``self.models_config`` into memory.

        Failures are isolated per model: a failing entry is stored with
        ``loaded=False`` and its error message, and the remaining entries are
        still processed.
        """
        loaders = {
            "transformers": self._load_transformers_model,
            "coqui": self._load_coqui_model,
            "speecht5": self._load_speecht5_model,
        }
        for model_config in self.models_config:
            try:
                model_id = model_config["id"]
                model_url = model_config["url"]
                model_type = model_config.get("type", "transformers")
                # Reject unknown types before spending time on a download.
                loader = loaders.get(model_type)
                if loader is None:
                    raise ValueError(f"Tipo de modelo no soportado: {model_type}")

                logger.info(f"🚀 Cargando modelo TTS: {model_id}")
                temp_path = self._download_model(model_url, model_id)

                actual_mb = os.path.getsize(temp_path) / (1024 * 1024)
                logger.info(f"📊 Tamaño descargado para {model_id}: {actual_mb:.2f} MB")

                logger.info(f"🔄 Cargando {model_id} en RAM...")
                model_instance = loader(temp_path, model_config)

                # The model now lives in memory; the downloaded file is no longer needed.
                os.remove(temp_path)
                logger.info(f"🗑️ Archivo temporal {temp_path} eliminado")

                self.models[model_id] = {
                    "instance": model_instance,
                    "loaded": True,
                    "config": model_config,
                    "type": model_type,
                    "loaded_at": datetime.now().isoformat()
                }
                logger.info(f"✅ Modelo TTS {model_id} cargado exitosamente")
            except Exception as e:
                # .get(): an entry without "id" must be recorded, not crash the loader.
                failed_id = model_config.get('id', 'unknown')
                logger.error(f"❌ Error cargando modelo {failed_id}: {e}")
                self.models[failed_id] = {
                    "instance": None,
                    "loaded": False,
                    "config": model_config,
                    "error": str(e)
                }

    def _download_model(self, model_url, model_id):
        """Download ``model_url`` into ``/tmp/tts_models`` and return the file path.

        An existing file at the target path is treated as a cache hit. Data is
        streamed into a ``.part`` file that is renamed only after the download
        completes, so an interrupted download is never reused as a cached model.

        Raises:
            requests.HTTPError: on a non-success HTTP status.
        """
        temp_dir = "/tmp/tts_models"
        os.makedirs(temp_dir, exist_ok=True)

        # Ids such as "org/name" would otherwise point into a missing sub-directory.
        safe_id = model_id.replace("/", "_").replace(os.sep, "_")
        file_extension = self._get_file_extension(model_url)
        temp_path = os.path.join(temp_dir, f"{safe_id}{file_extension}")

        if os.path.exists(temp_path):
            logger.info(f"📂 Usando modelo cacheado en temporal: {temp_path}")
            return temp_path

        logger.info(f"📥 Descargando modelo desde: {model_url}")
        partial_path = temp_path + ".part"
        log_every = 100 * 1024 * 1024  # progress log every 100 MB
        next_log = log_every
        downloaded = 0
        try:
            # Long timeout for large models; the context manager closes the
            # streamed connection even if writing fails.
            with self.session.get(model_url, stream=True, timeout=600) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=32768):
                        if not chunk:
                            continue
                        f.write(chunk)
                        downloaded += len(chunk)
                        # Threshold check instead of "% == 0": chunk sizes vary.
                        if downloaded >= next_log:
                            mb_downloaded = downloaded / (1024 * 1024)
                            logger.info(f"📥 Descargados {mb_downloaded:.1f} MB...")
                            next_log += log_every
        except BaseException:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, temp_path)

        logger.info(f"✅ Descarga completada: {temp_path}")
        return temp_path

    def _get_file_extension(self, url):
        """Return the extension (with leading dot) of the URL's last path segment.

        Only the file name is inspected, so dots in directory names are ignored
        and the query string is excluded. Falls back to ``.bin``.
        """
        from urllib.parse import urlparse
        import posixpath
        file_name = posixpath.basename(urlparse(url).path)
        extension = posixpath.splitext(file_name)[1]
        return extension or '.bin'

    def _load_transformers_model(self, model_path, config):
        """Load a transformers TTS model and its processor from ``model_path``.

        Uses float16 on CUDA and float32 on CPU.
        """
        # NOTE(review): confirm the installed transformers version exports
        # AutoModelForTextToSpeech, and that from_pretrained accepts the single
        # downloaded file (it normally expects a directory or hub id).
        from transformers import AutoModelForTextToSpeech, AutoProcessor
        logger.info(f"🤖 Cargando modelo transformers desde: {model_path}")
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logger.info(f"💻 Usando dispositivo: {device}")
        model = AutoModelForTextToSpeech.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if device == "cuda:0" else torch.float32,
            low_cpu_mem_usage=True
        ).to(device)
        processor = AutoProcessor.from_pretrained(model_path)
        model.eval()  # inference mode (disables dropout etc.)
        return {
            "model": model,
            "processor": processor,
            "device": device,
            "model_type": "transformers"
        }

    def _load_coqui_model(self, model_path, config):
        """Load a Coqui TTS model from ``model_path`` (GPU when available)."""
        from TTS.api import TTS
        logger.info(f"🤖 Cargando modelo Coqui TTS desde: {model_path}")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"💻 Usando dispositivo: {device}")
        tts_instance = TTS(model_path, gpu=(device == "cuda"))
        return {
            "tts": tts_instance,
            "device": device,
            "model_type": "coqui"
        }

    def _load_speecht5_model(self, model_path, config):
        """Load a SpeechT5 model, its processor and, if configured, a HiFi-GAN vocoder.

        The vocoder is downloaded from ``config["vocoder_url"]`` when present;
        otherwise ``vocoder`` is ``None``.
        """
        from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
        logger.info(f"🤖 Cargando modelo SpeechT5 desde: {model_path}")
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        logger.info(f"💻 Usando dispositivo: {device}")
        processor = SpeechT5Processor.from_pretrained(model_path)
        model = SpeechT5ForTextToSpeech.from_pretrained(model_path).to(device)
        vocoder = None
        if "vocoder_url" in config:
            vocoder_path = self._download_model(config["vocoder_url"], f"{config['id']}_vocoder")
            vocoder = SpeechT5HifiGan.from_pretrained(vocoder_path).to(device)
            os.remove(vocoder_path)
        model.eval()
        if vocoder:
            vocoder.eval()
        return {
            "processor": processor,
            "model": model,
            "vocoder": vocoder,
            "device": device,
            "model_type": "speecht5"
        }

    def get_model(self, model_id):
        """Return the stored entry for ``model_id``, or ``None`` if unknown."""
        return self.models.get(model_id)

    def generate_speech(self, model_id, text, **kwargs):
        """Synthesize ``text`` with the model ``model_id``.

        Only one generation runs at a time; a concurrent call returns an error
        dict immediately. Text longer than ``MAX_TEXT_LENGTH`` is truncated.

        Returns:
            ``{"audio", "sample_rate", "duration"}`` on success, or
            ``{"error": message}`` when the server is busy, the model is not
            loaded, or generation exceeds ``GENERATION_TIMEOUT`` seconds.

        Raises:
            Whatever the model-specific generator raises, including
            ``ValueError`` for an unsupported model type.
        """
        if not self.generation_lock.acquire(blocking=False):
            return {"error": "Servidor ocupado - Generación en progreso"}
        # Once the worker thread starts, it owns the lock and releases it when
        # the generation really finishes. Releasing it here on timeout would let
        # a second generation run concurrently with the one still in progress.
        worker_started = False
        try:
            model_data = self.get_model(model_id)
            if not model_data or not model_data["loaded"]:
                error_msg = f"Modelo {model_id} no cargado"
                if model_data and "error" in model_data:
                    error_msg += f": {model_data['error']}"
                return {"error": error_msg}

            if len(text) > MAX_TEXT_LENGTH:
                text = text[:MAX_TEXT_LENGTH]
                logger.warning(f"Texto truncado a {MAX_TEXT_LENGTH} caracteres")

            outcome = {}

            def worker():
                try:
                    outcome["result"] = self._dispatch_generation(model_data, text, kwargs)
                except Exception as e:
                    outcome["exception"] = e
                finally:
                    self.generation_lock.release()

            gen_thread = threading.Thread(target=worker, daemon=True)
            gen_thread.start()
            worker_started = True
            gen_thread.join(timeout=self.GENERATION_TIMEOUT)

            if gen_thread.is_alive():
                # The worker keeps the lock until it finishes, so new requests
                # are rejected as busy rather than overlapping with it.
                return {"error": f"Timeout en generación ({self.GENERATION_TIMEOUT} segundos)"}
            if "exception" in outcome:
                raise outcome["exception"]
            return outcome["result"]
        finally:
            if not worker_started:
                self.generation_lock.release()
            gc.collect()

    def _dispatch_generation(self, model_data, text, params):
        """Call the generator matching ``model_data["type"]``.

        Raises:
            ValueError: if the model type has no generator.
        """
        generators = {
            "transformers": self._generate_transformers_speech,
            "coqui": self._generate_coqui_speech,
            "speecht5": self._generate_speecht5_speech,
        }
        model_type = model_data["type"]
        generator = generators.get(model_type)
        if generator is None:
            raise ValueError(f"Tipo de modelo no soportado: {model_type}")
        return generator(model_data, text, params)

    def _truncate_audio(self, audio_array, sample_rate, max_seconds=None):
        """Cut ``audio_array`` to at most ``max_seconds`` (default ``MAX_AUDIO_LENGTH``)."""
        if max_seconds is None:
            max_seconds = MAX_AUDIO_LENGTH
        max_samples = int(max_seconds * sample_rate)
        if len(audio_array) > max_samples:
            logger.warning(f"Audio truncado a {max_seconds} segundos")
            return audio_array[:max_samples]
        return audio_array

    def _generate_transformers_speech(self, model_data, text, params):
        """Generate audio with a transformers model.

        ``params["speed"]`` (if present) is applied by time-stretching the result.
        """
        model = model_data["instance"]["model"]
        processor = model_data["instance"]["processor"]
        device = model_data["instance"]["device"]

        inputs = processor(text=text, return_tensors="pt").to(device)
        with torch.no_grad():
            speech = model.generate(**inputs)
        # The CUDA path loads the model in float16, which soundfile cannot
        # write; cast to float32 before leaving torch.
        audio_array = speech.float().cpu().numpy().squeeze()
        sample_rate = getattr(model.config, "sample_rate", 16000)

        speed = params.get("speed", 1.0)
        if speed != 1.0:
            audio_array = self._adjust_speed(audio_array, sample_rate, speed)
        # Same length limit as the other generators.
        audio_array = self._truncate_audio(audio_array, sample_rate)

        return {
            "audio": audio_array,
            "sample_rate": sample_rate,
            "duration": len(audio_array) / sample_rate
        }

    def _generate_coqui_speech(self, model_data, text, params):
        """Generate audio with Coqui TTS.

        Reads ``speaker``, ``language`` (default ``"es"``) and ``speed``
        (default 1.0) from ``params``.
        """
        tts = model_data["instance"]["tts"]
        speaker = params.get("speaker")
        language = params.get("language", "es")
        speed = params.get("speed", 1.0)

        if hasattr(tts, 'tts_to_file'):
            # Close the descriptor before Coqui writes to the path. The
            # try/finally removes the temp file even if synthesis fails.
            fd, tmp_path = tempfile.mkstemp(suffix='.wav')
            os.close(fd)
            try:
                tts.tts_to_file(
                    text=text,
                    speaker=speaker,
                    language=language,
                    speed=speed,
                    file_path=tmp_path
                )
                audio_array, sample_rate = sf.read(tmp_path)
            finally:
                os.unlink(tmp_path)
        else:
            # Older API: returns the samples directly.
            audio_array = np.asarray(tts.tts(
                text=text,
                speaker=speaker,
                language=language,
                speed=speed
            ))
            sample_rate = 24000  # XTTS default

        audio_array = self._truncate_audio(audio_array, sample_rate)

        return {
            "audio": audio_array,
            "sample_rate": sample_rate,
            "duration": len(audio_array) / sample_rate
        }

    def _generate_speecht5_speech(self, model_data, text, params):
        """Generate 16 kHz audio with SpeechT5.

        ``params["speaker_embeddings"]`` may be a tensor or a (nested) list. A
        1-D embedding is given a batch dimension. When absent, a fixed-seed
        pseudo-random (1, 512) embedding is used, so the default voice is stable
        across requests.
        """
        processor = model_data["instance"]["processor"]
        model = model_data["instance"]["model"]
        vocoder = model_data["instance"]["vocoder"]
        device = model_data["instance"]["device"]

        inputs = processor(text=text, return_tensors="pt").to(device)

        speaker_embeddings = params.get("speaker_embeddings")
        if speaker_embeddings is None:
            seeded = torch.Generator().manual_seed(self.DEFAULT_SPEAKER_SEED)
            speaker_embeddings = torch.randn((1, 512), generator=seeded).to(device)
        elif isinstance(speaker_embeddings, list):
            # float32 so integer-valued JSON lists do not become int64 tensors.
            speaker_embeddings = torch.tensor(speaker_embeddings, dtype=torch.float32).to(device)
        if isinstance(speaker_embeddings, torch.Tensor) and speaker_embeddings.dim() == 1:
            speaker_embeddings = speaker_embeddings.unsqueeze(0)

        with torch.no_grad():
            speech = model.generate_speech(
                inputs["input_ids"],
                speaker_embeddings,
                vocoder=vocoder
            )
        audio_array = speech.cpu().numpy().squeeze()
        sample_rate = 16000  # SpeechT5 outputs 16 kHz

        speed = params.get("speed", 1.0)
        if speed != 1.0:
            audio_array = self._adjust_speed(audio_array, sample_rate, speed)
        audio_array = self._truncate_audio(audio_array, sample_rate)

        return {
            "audio": audio_array,
            "sample_rate": sample_rate,
            "duration": len(audio_array) / sample_rate
        }

    def _adjust_speed(self, audio_array, sample_rate, speed_factor):
        """Time-stretch audio by ``speed_factor`` while keeping pitch (needs librosa).

        Returns the input unchanged when the factor is 1.0 or librosa is missing.
        """
        if speed_factor == 1.0:
            return audio_array
        try:
            import librosa
        except ImportError:
            logger.warning("Librosa no instalado, omitiendo ajuste de velocidad")
            return audio_array
        return librosa.effects.time_stretch(y=audio_array, rate=speed_factor)

    def get_loaded_models(self):
        """Return the ids of all successfully loaded models."""
        return [model_id for model_id, data in self.models.items() if data["loaded"]]

    def get_all_models_status(self):
        """Return a per-model summary: loaded flag, type, config, and error/loaded_at if set."""
        status = {}
        for model_id, data in self.models.items():
            entry = {
                "loaded": data["loaded"],
                "type": data.get("type", "unknown"),
                "config": data["config"]
            }
            if "error" in data:
                entry["error"] = data["error"]
            if "loaded_at" in data:
                entry["loaded_at"] = data["loaded_at"]
            status[model_id] = entry
        return status
# Single shared manager; constructing it downloads and loads every configured
# model (TTSManager.__init__ calls load_all_models).
tts_manager = TTSManager(models_config=TTS_MODELS)
def audio_to_wav_bytes(audio_array, sample_rate):
    """Encode an audio array as an in-memory WAV file.

    Args:
        audio_array: sample array (or sequence) accepted by ``soundfile.write``.
        sample_rate: sampling rate in Hz.

    Returns:
        ``io.BytesIO`` holding the WAV data, rewound to position 0.
    """
    audio_array = np.asarray(audio_array)
    # soundfile only writes float64/float32/int32/int16 data. float16 output
    # (half-precision models on CUDA) must be widened first.
    if audio_array.dtype == np.float16:
        audio_array = audio_array.astype(np.float32)
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, audio_array, sample_rate, format='WAV')
    wav_buffer.seek(0)
    return wav_buffer
@app.route('/')
def home():
    """Render the HTML landing page: limits, device, model status and endpoints."""
    loaded_models = tts_manager.get_loaded_models()
    rows = []
    for model_id, model_data in tts_manager.models.items():
        model_type = model_data.get("type", "unknown")
        status = "✅" if model_data["loaded"] else "❌"
        rows.append(f"<li>{model_id} ({model_type}): {status}</li>")
    status_html = "<ul>" + "".join(rows) + "</ul>"
    return f'''
<!DOCTYPE html>
<html>
<head>
<title>TTS API - Text to Speech</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; }}
.config {{ background: #f0f0f0; padding: 15px; border-radius: 5px; margin-bottom: 20px; }}
.endpoint {{ background: #e8f4f8; padding: 10px; border-left: 4px solid #2196F3; margin: 10px 0; }}
</style>
</head>
<body>
<h1>🔊 TTS API - Text to Speech</h1>
<div class="config">
<h3>⚙️ Configuración</h3>
<p><strong>Max Text Length:</strong> {MAX_TEXT_LENGTH} caracteres</p>
<p><strong>Max Audio Length:</strong> {MAX_AUDIO_LENGTH} segundos</p>
<p><strong>Device:</strong> {"CUDA/GPU" if torch.cuda.is_available() else "CPU"}</p>
</div>
<h2>📦 Modelos TTS cargados:</h2>
{status_html}
<p>Total modelos: {len(loaded_models)}/{len(TTS_MODELS)}</p>
<h2>🔗 Endpoints disponibles:</h2>
<div class="endpoint">
<strong>GET /tts?text=&lt;texto&gt;[&params]</strong><br>
Genera audio desde texto. Parámetros opcionales:<br>
• model= (ID del modelo, default: primer modelo)<br>
• speed= (0.5-2.0, velocidad de habla)<br>
• language= (idioma, ej: es, en)<br>
• speaker= (voz específica)<br>
• download= (true/false, forzar descarga)
</div>
<div class="endpoint">
<strong>POST /v1/audio/speech</strong><br>
Compatible con OpenAI Audio API
</div>
<div class="endpoint">
<strong>POST /generate</strong><br>
Endpoint alternativo con JSON
</div>
<div class="endpoint">
<strong>GET /health</strong><br>
Estado del servicio
</div>
<div class="endpoint">
<strong>GET /models</strong><br>
Lista todos los modelos disponibles
</div>
</body>
</html>
'''
@app.route('/v1/audio/speech', methods=['POST'])
def openai_compatible_endpoint():
    """OpenAI Audio API-compatible endpoint (``POST /v1/audio/speech``).

    Reads ``input`` (text) and ``model`` (default: first configured model) from
    the JSON body. Every other key is forwarded to the generator as a keyword
    parameter. Returns WAV audio as an attachment, or a JSON error with status
    400 (bad request) or 500.
    """
    try:
        # silent=True: a missing or non-JSON body becomes a 400 instead of an
        # exception reported as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "El cuerpo de la petición debe ser un objeto JSON"}), 400
        text = data.get('input', '')
        model_id = data.get('model', TTS_MODELS[0]["id"])
        if not text:
            return jsonify({"error": "El campo 'input' es requerido"}), 400
        if not isinstance(text, str):
            return jsonify({"error": "El campo 'input' debe ser texto"}), 400
        if not isinstance(model_id, str):
            return jsonify({"error": "El campo 'model' debe ser texto"}), 400
        if len(text) > MAX_TEXT_LENGTH:
            return jsonify({"error": f"Texto demasiado largo (máximo {MAX_TEXT_LENGTH} caracteres)"}), 400
        # 'text'/'model_id' would collide with generate_speech's own positional
        # parameters and raise TypeError.
        reserved = {'input', 'model', 'text', 'model_id'}
        params = {k: v for k, v in data.items() if k not in reserved}
        result = tts_manager.generate_speech(model_id, text, **params)
        if "error" in result:
            return jsonify(result), 500
        wav_buffer = audio_to_wav_bytes(result["audio"], result["sample_rate"])
        return Response(
            wav_buffer.read(),
            mimetype='audio/wav',
            headers={'Content-Disposition': 'attachment; filename="speech.wav"'}
        )
    except Exception as e:
        logger.error(f"Error en OpenAI endpoint: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/tts', methods=['GET'])
def tts_get_endpoint():
    """Generate speech from query-string parameters (``GET /tts``).

    Query parameters:
      * ``text`` (required)
      * ``model`` (default: first configured model)
      * ``speed`` (0.5-2.0, default 1.0)
      * ``language`` (default ``es``)
      * ``speaker`` (optional)
      * ``download`` (``true`` forces an attachment)

    Returns WAV audio, or a JSON error with status 400 (bad input) or 500.
    """
    try:
        text = request.args.get('text', '')
        model_id = request.args.get('model', TTS_MODELS[0]["id"])
        # Only the speed conversion is a parameter error. A ValueError raised
        # later by the model must not be reported to the client as bad input.
        try:
            speed = float(request.args.get('speed', 1.0))
        except ValueError as e:
            return jsonify({"error": f"Parámetros inválidos: {str(e)}"}), 400
        language = request.args.get('language', 'es')
        speaker = request.args.get('speaker')
        download = request.args.get('download', 'false').lower() == 'true'

        if not text:
            return jsonify({"error": "El parámetro 'text' es requerido"}), 400
        if len(text) > MAX_TEXT_LENGTH:
            return jsonify({"error": f"Texto demasiado largo (máximo {MAX_TEXT_LENGTH} caracteres)"}), 400
        # The chained comparison is also False for NaN, so "speed=nan" is
        # rejected (separate < / > checks let it through).
        if not 0.5 <= speed <= 2.0:
            return jsonify({"error": "El parámetro 'speed' debe estar entre 0.5 y 2.0"}), 400

        params = {"speed": speed, "language": language}
        if speaker:
            params["speaker"] = speaker

        result = tts_manager.generate_speech(model_id, text, **params)
        if "error" in result:
            return jsonify(result), 500

        wav_buffer = audio_to_wav_bytes(result["audio"], result["sample_rate"])
        filename = f"tts_{model_id}.wav"
        if download:
            return send_file(
                wav_buffer,
                mimetype='audio/wav',
                as_attachment=True,
                download_name=filename
            )
        return Response(
            wav_buffer.read(),
            mimetype='audio/wav',
            headers={'Content-Disposition': f'inline; filename="{filename}"'}
        )
    except Exception as e:
        logger.error(f"Error en TTS GET: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/generate', methods=['POST'])
def generate_endpoint():
    """Alternative JSON generation endpoint (``POST /generate``).

    Reads ``text`` and ``model`` (default: first configured model) from the
    JSON body. Every other key is forwarded to the generator as a keyword
    parameter. Returns inline WAV audio, or a JSON error with status 400 or 500.
    """
    try:
        # silent=True: a missing or non-JSON body becomes a 400, not a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "El cuerpo de la petición debe ser un objeto JSON"}), 400
        text = data.get('text', '')
        model_id = data.get('model', TTS_MODELS[0]["id"])
        if not text:
            return jsonify({"error": "El campo 'text' es requerido"}), 400
        if not isinstance(text, str):
            return jsonify({"error": "El campo 'text' debe ser texto"}), 400
        if not isinstance(model_id, str):
            return jsonify({"error": "El campo 'model' debe ser texto"}), 400
        if len(text) > MAX_TEXT_LENGTH:
            return jsonify({"error": f"Texto demasiado largo (máximo {MAX_TEXT_LENGTH} caracteres)"}), 400
        # 'model_id' would collide with generate_speech's positional parameter.
        reserved = {'text', 'model', 'model_id'}
        params = {k: v for k, v in data.items() if k not in reserved}
        result = tts_manager.generate_speech(model_id, text, **params)
        if "error" in result:
            return jsonify(result), 500
        wav_buffer = audio_to_wav_bytes(result["audio"], result["sample_rate"])
        return Response(
            wav_buffer.read(),
            mimetype='audio/wav',
            headers={'Content-Disposition': 'inline; filename="generated.wav"'}
        )
    except Exception as e:
        logger.error(f"Error en generate endpoint: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/health', methods=['GET'])
def health():
    """Report service health; status is "healthy" when at least one model is loaded."""
    loaded = tts_manager.get_loaded_models()
    limits = {
        "max_text_length": MAX_TEXT_LENGTH,
        "max_audio_length": MAX_AUDIO_LENGTH
    }
    payload = {
        "status": "error" if not loaded else "healthy",
        "loaded_models": loaded,
        "total_models": len(TTS_MODELS),
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "config": limits
    }
    return jsonify(payload)
@app.route('/models', methods=['GET'])
def list_models():
    """List the configured models together with each one's load status and the limits."""
    limits = {
        "max_text_length": MAX_TEXT_LENGTH,
        "max_audio_length": MAX_AUDIO_LENGTH
    }
    body = dict(
        available_models=TTS_MODELS,
        status=tts_manager.get_all_models_status(),
        config=limits
    )
    return jsonify(body)
@app.route('/models/<model_id>', methods=['GET'])
def get_model_status(model_id):
    """Return the stored status of one model, or a 404 JSON error if the id is unknown."""
    entry = tts_manager.get_model(model_id)
    if not entry:
        return jsonify({"error": f"Modelo '{model_id}' no encontrado"}), 404
    summary = {
        "model": model_id,
        "loaded": entry["loaded"],
        "type": entry.get("type", "unknown"),
        "config": entry["config"],
        "error": entry.get("error"),
        "loaded_at": entry.get("loaded_at")
    }
    return jsonify(summary)
if __name__ == '__main__':
    # Serve on every interface, port 7860, with the debugger off.
    app.run(debug=False, port=7860, host='0.0.0.0')