MarcosFRGames committed
Commit 15d78d3 · verified · 1 parent: da5bb45

Update app.py

Files changed (1): app.py (+506 −216)
app.py CHANGED
@@ -1,25 +1,34 @@
- from flask import Flask, request, jsonify, Response
  import os
  import logging
  import threading
- import time
- from llama_cpp import Llama
- import requests
  import tempfile
- import json
  import gc
  from concurrent.futures import ThreadPoolExecutor

- app = Flask(__name__)
  logging.basicConfig(level=logging.INFO)

- MAX_CONTEXT_TOKENS = 1024 * 4
- MAX_GENERATION_TOKENS = 1024 * 4

  with open('engines.json', 'r') as f:
-     MODELS = json.load(f)

- class LLMManager:
      def __init__(self, models_config):
          self.models = {}
          self.models_config = models_config
@@ -32,107 +41,226 @@ class LLMManager:
          self.load_all_models()

      def load_all_models(self):
-         """Load all models into RAM"""
          for model_config in self.models_config:
              try:
-                 model_name = model_config["name"]
-                 logging.info(f"🚀 Loading model: {model_name}")

-                 temp_path = self._download_model(model_config["url"])

                  actual_size = os.path.getsize(temp_path)
-                 actual_gb = actual_size / (1024*1024*1024)
-                 logging.info(f"📊 Downloaded size for {model_name}: {actual_gb:.2f} GB")
-
-                 n_batch = model_config.get("n_batch", 96)
-
-                 logging.info(f"🔄 Loading {model_name} into RAM…")
-                 llm_instance = Llama(
-                     model_path=temp_path,
-                     n_ctx=MAX_CONTEXT_TOKENS,
-                     n_batch=n_batch,
-                     n_threads=2,
-                     n_threads_batch=2,
-                     use_mlock=True,
-                     mmap=True,
-                     low_vram=False,
-                     vocab_only=False,
-                     verbose=False,
-                     logits_all=False,
-                     mul_mat_q=True
-                 )

                  os.remove(temp_path)
-
-                 self.models[model_name] = {
-                     "instance": llm_instance,
                      "loaded": True,
-                     "config": model_config
                  }
-                 logging.info(f"✅ Model {model_name} loaded")
-
              except Exception as e:
-                 logging.error(f"❌ Error loading model {model_config['name']}: {e}")
-                 self.models[model_config["name"]] = {
                      "instance": None,
                      "loaded": False,
                      "config": model_config,
                      "error": str(e)
                  }

-     def _download_model(self, model_url):
-         """Download the model"""
-         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".gguf")
-         temp_path = temp_file.name
-         temp_file.close()
-
-         logging.info("📥 Downloading model…")

-         response = self.session.get(model_url, stream=True, timeout=300)
          response.raise_for_status()
-
          downloaded = 0
          with open(temp_path, 'wb') as f:
              for chunk in response.iter_content(chunk_size=32768):
                  if chunk:
                      f.write(chunk)
                      downloaded += len(chunk)

          return temp_path

-     def get_model(self, model_name):
-         """Get a model instance by name"""
-         return self.models.get(model_name)

-     def chat_completion(self, model_name, messages, **kwargs):
-         """Generate a response with a specific model"""
          if not self.generation_lock.acquire(blocking=False):
              return {"error": "Server busy - generation in progress"}

          try:
-             model_data = self.get_model(model_name)

              if not model_data or not model_data["loaded"]:
-                 error_msg = f"Model {model_name} not loaded"
                  if model_data and "error" in model_data:
                      error_msg += f": {model_data['error']}"
                  return {"error": error_msg}

              result = [None]
              exception = [None]

              def generate():
                  try:
-                     result[0] = model_data["instance"].create_chat_completion(
-                         messages=messages,
-                         **kwargs
-                     )
                  except Exception as e:
                      exception[0] = e

              gen_thread = threading.Thread(target=generate, daemon=True)
              gen_thread.start()
-             gen_thread.join(timeout=120)

              if gen_thread.is_alive():
                  return {"error": "Generation timeout (120 seconds)"}
@@ -140,51 +268,205 @@ class LLMManager:
              if exception[0]:
                  raise exception[0]

-             result[0]["provider"] = "telechars-ai"
-             result[0]["model"] = model_name
              return result[0]

          finally:
              self.generation_lock.release()
              gc.collect()
-
      def get_loaded_models(self):
          """Get the list of loaded models"""
          loaded = []
-         for name, data in self.models.items():
              if data["loaded"]:
-                 loaded.append(name)
          return loaded

      def get_all_models_status(self):
          """Get the status of all models"""
          status = {}
-         for name, data in self.models.items():
-             status[name] = {
                  "loaded": data["loaded"],
-                 "url": data["config"]["url"]
              }
              if "error" in data:
-                 status[name]["error"] = data["error"]
          return status

- # Initialize the manager with all models
- llm_manager = LLMManager(MODELS)

  @app.route('/')
  def home():
-     loaded_models = llm_manager.get_loaded_models()
      status_html = "<ul>"
-     for model_name, model_data in llm_manager.models.items():
          status = "✅" if model_data["loaded"] else "❌"
-         status_html += f"<li>{model_name}: {status}</li>"
      status_html += "</ul>"

      return f'''
      <!DOCTYPE html>
      <html>
      <head>
-         <title>TeleChars AI API</title>
          <style>
              body {{ font-family: Arial, sans-serif; margin: 40px; }}
              .config {{ background: #f0f0f0; padding: 15px; border-radius: 5px; margin-bottom: 20px; }}
@@ -192,33 +474,38 @@ def home():
          </style>
      </head>
      <body>
-         <h1>TeleChars AI API</h1>

          <div class="config">
              <h3>⚙️ Configuration</h3>
-             <p><strong>Max Context Tokens:</strong> {MAX_CONTEXT_TOKENS}</p>
-             <p><strong>Max Generation Tokens:</strong> {MAX_GENERATION_TOKENS}</p>
          </div>

-         <h2>📦 Loaded models:</h2>
          {status_html}
-         <p>Total models: {len(loaded_models)}/{len(MODELS)}</p>

          <h2>🔗 Available endpoints:</h2>
          <div class="endpoint">
-             <strong>GET /generate/&lt;message&gt;[?params]</strong><br>
-             Returns only the generated text. Optional parameters:<br>
-             • system= (system instructions)<br>
-             • temperature= (0.0-2.0)<br>
-             • top_p= (0.0-1.0)<br>
-             • top_k= (0-100)<br>
-             • model= (model name)<br>
-             • max_tokens= (maximum tokens to generate, default: {MAX_GENERATION_TOKENS})
          </div>

          <div class="endpoint">
-             <strong>POST /v1/chat/completions</strong><br>
-             OpenAI API compatible
          </div>

          <div class="endpoint">
@@ -234,150 +521,155 @@ def home():
      </html>
      '''

- @app.route('/v1/chat/completions', methods=['POST'])
- def chat_completions():
      try:
          data = request.get_json()
-         messages = data.get('messages', [])
-         model_name = data.get('model', MODELS[0]["name"])

-         if model_name not in llm_manager.models:
-             return jsonify({"error": f"Model '{model_name}' not found. Available models: {list(llm_manager.models.keys())}"}), 400

-         kwargs = {}
-         for key in data.keys():
-             if key not in ['messages', 'model']:
-                 kwargs[key] = data[key]

-         # Apply the token limit if none is specified
-         if 'max_tokens' not in kwargs:
-             kwargs['max_tokens'] = MAX_GENERATION_TOKENS
-         else:
-             # Validate that max_tokens does not exceed the allowed maximum
-             if kwargs['max_tokens'] > MAX_GENERATION_TOKENS:
-                 kwargs['max_tokens'] = MAX_GENERATION_TOKENS

-         result = llm_manager.chat_completion(model_name, messages, **kwargs)
-
          if "error" in result:
              return jsonify(result), 500
-
-         return jsonify(result), 200

      except Exception as e:
          return jsonify({"error": str(e)}), 500

- @app.route('/generate/<path:user_message>', methods=['GET'])
- def generate_endpoint(user_message):
-     """GET endpoint for generating responses - returns plain text only"""
      try:
-         # Read GET parameters with default values
-         system_instruction = request.args.get('system', '')
-         temperature = float(request.args.get('temperature', 0.7))
-         top_p = float(request.args.get('top_p', 0.95))
-         top_k = int(request.args.get('top_k', 0))
-         model_name = request.args.get('model', MODELS[0]["name"])
-         max_tokens = int(request.args.get('max_tokens', MAX_GENERATION_TOKENS))
-
-         # Validate the ranges
-         if not 0 <= temperature <= 2:
-             return Response(
-                 "Error: the 'temperature' parameter must be between 0 and 2",
-                 status=400,
-                 mimetype='text/plain'
-             )

-         if not 0 <= top_p <= 1:
-             return Response(
-                 "Error: the 'top_p' parameter must be between 0 and 1",
-                 status=400,
-                 mimetype='text/plain'
-             )
-
-         if not 0 <= top_k <= 100:
-             return Response(
-                 "Error: the 'top_k' parameter must be between 0 and 100",
-                 status=400,
-                 mimetype='text/plain'
-             )

-         # Cap max_tokens at the configured maximum
-         if max_tokens > MAX_GENERATION_TOKENS:
-             max_tokens = MAX_GENERATION_TOKENS

-         # Validate that the model exists
-         if model_name not in llm_manager.models:
-             return Response(
-                 f"Error: model '{model_name}' not found. Available models: {', '.join(llm_manager.models.keys())}",
-                 status=400,
-                 mimetype='text/plain'
-             )

-         # Build the messages
-         messages = [
-             {"role": "system", "content": system_instruction},
-             {"role": "user", "content": user_message}
-         ]
-
-         # Set the parameters
-         kwargs = {
-             "temperature": temperature,
-             "top_p": top_p,
-             "max_tokens": max_tokens,
-             "stream": False
          }
-
-         if top_k:
-             try:
-                 kwargs["top_k"] = int(top_k)
-             except ValueError:
-                 return Response("Error: top_k must be an integer", status=400)
-
-         # Generate the response
-         result = llm_manager.chat_completion(model_name, messages, **kwargs)

          if "error" in result:
              return Response(
-                 f"Error: {result['error']}",
-                 status=500,
-                 mimetype='text/plain'
              )

-         response_text = result.get("choices", [{}])[0].get("message", {}).get("content", "")

-         if not response_text:
-             response_text = "No response was generated"

-         # Return plain text only
-         return Response(
-             response_text,
-             status=200,
-             mimetype='text/plain'
-         )
-
-     except ValueError as e:
          return Response(
-             f"Error: invalid parameters - {str(e)}. Make sure temperature, top_p and max_tokens are valid numbers.",
-             status=400,
-             mimetype='text/plain'
          )
      except Exception as e:
-         return Response(
-             f"Error: {str(e)}",
-             status=500,
-             mimetype='text/plain'
-         )

  @app.route('/health', methods=['GET'])
  def health():
-     loaded_models = llm_manager.get_loaded_models()
      return jsonify({
          "status": "healthy" if len(loaded_models) > 0 else "error",
          "loaded_models": loaded_models,
-         "total_models": len(MODELS),
          "config": {
-             "max_context_tokens": MAX_CONTEXT_TOKENS,
-             "max_generation_tokens": MAX_GENERATION_TOKENS
          }
      })

@@ -385,30 +677,28 @@ def health():
  def list_models():
      """Endpoint to list all models and their status"""
      return jsonify({
-         "available_models": MODELS,
-         "status": llm_manager.get_all_models_status(),
          "config": {
-             "max_context_tokens": MAX_CONTEXT_TOKENS,
-             "max_generation_tokens": MAX_GENERATION_TOKENS
          }
      })

- @app.route('/models/<model_name>', methods=['GET'])
- def get_model_status(model_name):
      """Endpoint to get the status of a specific model"""
-     model_data = llm_manager.get_model(model_name)
      if not model_data:
-         return jsonify({"error": f"Model '{model_name}' not found"}), 404

      return jsonify({
-         "model": model_name,
          "loaded": model_data["loaded"],
-         "url": model_data["config"]["url"],
          "error": model_data.get("error"),
-         "config": {
-             "max_context_tokens": MAX_CONTEXT_TOKENS,
-             "max_generation_tokens": MAX_GENERATION_TOKENS
-         }
      })

  if __name__ == '__main__':

+ from flask import Flask, request, jsonify, Response, send_file
  import os
+ import json
  import logging
  import threading
  import tempfile
+ import time
  import gc
+ import torch
+ import numpy as np
+ from datetime import datetime
+ import requests
  from concurrent.futures import ThreadPoolExecutor
+ import io
+ import soundfile as sf

+ # Basic logging configuration
  logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ app = Flask(__name__)

+ # Load the model configuration
  with open('engines.json', 'r') as f:
+     TTS_MODELS = json.load(f)
+
+ # Configuration constants
+ MAX_AUDIO_LENGTH = 30  # maximum seconds
+ MAX_TEXT_LENGTH = 500  # maximum characters

+ class TTSManager:
      def __init__(self, models_config):
          self.models = {}
          self.models_config = models_config

          self.load_all_models()

      def load_all_models(self):
+         """Load all TTS models into RAM from URLs"""
          for model_config in self.models_config:
              try:
+                 model_id = model_config["id"]
+                 model_url = model_config["url"]
+                 model_type = model_config.get("type", "transformers")

+                 logger.info(f"🚀 Loading TTS model: {model_id}")

+                 # Download the model to a temporary file
+                 temp_path = self._download_model(model_url, model_id)
+
+                 # Check the file size
                  actual_size = os.path.getsize(temp_path)
+                 actual_mb = actual_size / (1024*1024)
+                 logger.info(f"📊 Downloaded size for {model_id}: {actual_mb:.2f} MB")
+
+                 # Load the model according to its type
+                 logger.info(f"🔄 Loading {model_id} into RAM...")

+                 if model_type == "transformers":
+                     model_instance = self._load_transformers_model(temp_path, model_config)
+                 elif model_type == "coqui":
+                     model_instance = self._load_coqui_model(temp_path, model_config)
+                 elif model_type == "speecht5":
+                     model_instance = self._load_speecht5_model(temp_path, model_config)
+                 else:
+                     raise ValueError(f"Unsupported model type: {model_type}")
+
+                 # Clean up the temporary file
                  os.remove(temp_path)
+                 logger.info(f"🗑️ Temporary file {temp_path} removed")
+
+                 self.models[model_id] = {
+                     "instance": model_instance,
                      "loaded": True,
+                     "config": model_config,
+                     "type": model_type,
+                     "loaded_at": datetime.now().isoformat()
                  }
+                 logger.info(f"✅ TTS model {model_id} loaded successfully")
+
              except Exception as e:
+                 logger.error(f"❌ Error loading model {model_config.get('id', 'unknown')}: {e}")
+                 self.models[model_config["id"]] = {
                      "instance": None,
                      "loaded": False,
                      "config": model_config,
                      "error": str(e)
                  }

+     def _download_model(self, model_url, model_id):
+         """Download a model from a URL to a temporary file"""
+         # Create the temporary directory if it does not exist
+         temp_dir = "/tmp/tts_models"
+         os.makedirs(temp_dir, exist_ok=True)
+
+         # File name based on the model ID
+         file_extension = self._get_file_extension(model_url)
+         temp_path = os.path.join(temp_dir, f"{model_id}{file_extension}")
+
+         # If it already exists in the temporary cache, use it
+         if os.path.exists(temp_path):
+             logger.info(f"📂 Using model cached in temp dir: {temp_path}")
+             return temp_path

+         logger.info(f"📥 Downloading model from: {model_url}")
+
+         # Download with a long timeout for large models
+         response = self.session.get(model_url, stream=True, timeout=600)
          response.raise_for_status()
+
+         # Write the file in chunks
          downloaded = 0
          with open(temp_path, 'wb') as f:
              for chunk in response.iter_content(chunk_size=32768):
                  if chunk:
                      f.write(chunk)
                      downloaded += len(chunk)
+                     if downloaded % (100 * 1024 * 1024) == 0:  # Every 100 MB
+                         mb_downloaded = downloaded / (1024 * 1024)
+                         logger.info(f"📥 Downloaded {mb_downloaded:.1f} MB...")

+         logger.info(f"✅ Download complete: {temp_path}")
          return temp_path

+     def _get_file_extension(self, url):
+         """Get the file extension from a URL"""
+         from urllib.parse import urlparse
+         path = urlparse(url).path
+         if '.' in path:
+             return '.' + path.split('.')[-1]
+         return '.bin'  # Default extension
+
+     def _load_transformers_model(self, model_path, config):
+         """Load a transformers model from a local file"""
+         from transformers import AutoModelForTextToWaveform, AutoProcessor
+
+         logger.info(f"🤖 Loading transformers model from: {model_path}")
+
+         # Pick the device
+         device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         logger.info(f"💻 Using device: {device}")
+
+         # Load the model and processor
+         model = AutoModelForTextToWaveform.from_pretrained(
+             model_path,
+             torch_dtype=torch.float16 if device == "cuda:0" else torch.float32,
+             low_cpu_mem_usage=True
+         ).to(device)
+
+         processor = AutoProcessor.from_pretrained(model_path)
+
+         # Set to evaluation mode
+         model.eval()
+
+         return {
+             "model": model,
+             "processor": processor,
+             "device": device,
+             "model_type": "transformers"
+         }
+
+     def _load_coqui_model(self, model_path, config):
+         """Load a Coqui TTS model from a local file"""
+         from TTS.api import TTS
+
+         logger.info(f"🤖 Loading Coqui TTS model from: {model_path}")
+
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"💻 Using device: {device}")
+
+         # Coqui TTS can load local models
+         tts_instance = TTS(model_path, gpu=(device == "cuda"))
+
+         return {
+             "tts": tts_instance,
+             "device": device,
+             "model_type": "coqui"
+         }
+
+     def _load_speecht5_model(self, model_path, config):
+         """Load a SpeechT5 model from a local file"""
+         from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
+
+         logger.info(f"🤖 Loading SpeechT5 model from: {model_path}")
+
+         device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         logger.info(f"💻 Using device: {device}")
+
+         # Load the components
+         processor = SpeechT5Processor.from_pretrained(model_path)
+         model = SpeechT5ForTextToSpeech.from_pretrained(model_path).to(device)
+
+         # Load the vocoder if one is specified
+         vocoder = None
+         if "vocoder_url" in config:
+             vocoder_path = self._download_model(config["vocoder_url"], f"{config['id']}_vocoder")
+             vocoder = SpeechT5HifiGan.from_pretrained(vocoder_path).to(device)
+             os.remove(vocoder_path)
+
+         # Set to evaluation mode
+         model.eval()
+         if vocoder:
+             vocoder.eval()
+
+         return {
+             "processor": processor,
+             "model": model,
+             "vocoder": vocoder,
+             "device": device,
+             "model_type": "speecht5"
+         }
+
+     def get_model(self, model_id):
+         """Get a model instance by ID"""
+         return self.models.get(model_id)

+     def generate_speech(self, model_id, text, **kwargs):
+         """Generate audio with a specific model"""
          if not self.generation_lock.acquire(blocking=False):
              return {"error": "Server busy - generation in progress"}

          try:
+             model_data = self.get_model(model_id)

              if not model_data or not model_data["loaded"]:
+                 error_msg = f"Model {model_id} not loaded"
                  if model_data and "error" in model_data:
                      error_msg += f": {model_data['error']}"
                  return {"error": error_msg}

+             # Validate the text length
+             if len(text) > MAX_TEXT_LENGTH:
+                 text = text[:MAX_TEXT_LENGTH]
+                 logger.warning(f"Text truncated to {MAX_TEXT_LENGTH} characters")
+
              result = [None]
              exception = [None]

              def generate():
                  try:
+                     model_type = model_data["type"]
+
+                     if model_type == "transformers":
+                         result[0] = self._generate_transformers_speech(model_data, text, kwargs)
+                     elif model_type == "coqui":
+                         result[0] = self._generate_coqui_speech(model_data, text, kwargs)
+                     elif model_type == "speecht5":
+                         result[0] = self._generate_speecht5_speech(model_data, text, kwargs)
+                     else:
+                         exception[0] = ValueError(f"Unsupported model type: {model_type}")
+
                  except Exception as e:
                      exception[0] = e

+             # Run the generation in a separate thread
              gen_thread = threading.Thread(target=generate, daemon=True)
              gen_thread.start()
+             gen_thread.join(timeout=120)  # 2-minute timeout

              if gen_thread.is_alive():
                  return {"error": "Generation timeout (120 seconds)"}

              if exception[0]:
                  raise exception[0]

              return result[0]

          finally:
              self.generation_lock.release()
              gc.collect()
+
+     def _generate_transformers_speech(self, model_data, text, params):
+         """Generate audio with a transformers model"""
+         import torch
+
+         model = model_data["instance"]["model"]
+         processor = model_data["instance"]["processor"]
+         device = model_data["instance"]["device"]
+
+         # Prepare the inputs
+         inputs = processor(text=text, return_tensors="pt").to(device)
+
+         # Generation parameters
+         generate_kwargs = {}
+         if "speed" in params:
+             # Adjust the length based on speed
+             pass  # transformers models do not always support speed adjustment
+
+         # Generate the audio
+         with torch.no_grad():
+             speech = model.generate(**inputs, **generate_kwargs)
+
+         audio_array = speech.cpu().numpy().squeeze()
+         sample_rate = getattr(model.config, "sample_rate", 16000)
+
+         # Apply the speed adjustment if specified
+         if "speed" in params and params["speed"] != 1.0:
+             audio_array = self._adjust_speed(audio_array, sample_rate, params["speed"])
+
+         return {
+             "audio": audio_array,
+             "sample_rate": sample_rate,
+             "duration": len(audio_array) / sample_rate
+         }
+
+     def _generate_coqui_speech(self, model_data, text, params):
+         """Generate audio with Coqui TTS"""
+         tts = model_data["instance"]["tts"]
+
+         # Parameters for Coqui
+         speaker = params.get("speaker")
+         language = params.get("language", "es")
+         speed = params.get("speed", 1.0)
+
+         # Generate the audio
+         if hasattr(tts, 'tts_to_file'):
+             # Use a temporary file
+             with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
+                 tts.tts_to_file(
+                     text=text,
+                     speaker=speaker,
+                     language=language,
+                     speed=speed,
+                     file_path=tmp.name
+                 )
+
+             # Read the generated file
+             audio_array, sample_rate = sf.read(tmp.name)
+             os.unlink(tmp.name)
+         else:
+             # Legacy method
+             audio_array = tts.tts(
+                 text=text,
+                 speaker=speaker,
+                 language=language,
+                 speed=speed
+             )
+             sample_rate = 24000  # Default for XTTS
+
+         # Trim the duration if it is too long
+         max_samples = MAX_AUDIO_LENGTH * sample_rate
+         if len(audio_array) > max_samples:
+             audio_array = audio_array[:max_samples]
+             logger.warning(f"Audio truncated to {MAX_AUDIO_LENGTH} seconds")
+
+         return {
+             "audio": audio_array,
+             "sample_rate": sample_rate,
+             "duration": len(audio_array) / sample_rate
+         }
+
+     def _generate_speecht5_speech(self, model_data, text, params):
+         """Generate audio with SpeechT5"""
+         import torch
+
+         processor = model_data["instance"]["processor"]
+         model = model_data["instance"]["model"]
+         vocoder = model_data["instance"]["vocoder"]
+         device = model_data["instance"]["device"]
+
+         # Prepare the inputs
+         inputs = processor(text=text, return_tensors="pt").to(device)
+
+         # Get or generate the speaker embeddings
+         speaker_embeddings = params.get("speaker_embeddings")
+         if speaker_embeddings is None:
+             # Default embedding
+             speaker_embeddings = torch.randn((1, 512)).to(device)
+         elif isinstance(speaker_embeddings, list):
+             speaker_embeddings = torch.tensor(speaker_embeddings).to(device)
+
+         # Generate the audio
+         with torch.no_grad():
+             speech = model.generate_speech(
+                 inputs["input_ids"],
+                 speaker_embeddings,
+                 vocoder=vocoder
+             )
+
+         audio_array = speech.cpu().numpy().squeeze()
+         sample_rate = 16000  # SpeechT5 uses 16 kHz
+
+         # Adjust the speed if specified
+         if "speed" in params and params["speed"] != 1.0:
+             audio_array = self._adjust_speed(audio_array, sample_rate, params["speed"])
+
+         # Trim the duration
+         max_samples = MAX_AUDIO_LENGTH * sample_rate
+         if len(audio_array) > max_samples:
+             audio_array = audio_array[:max_samples]
+
+         return {
+             "audio": audio_array,
+             "sample_rate": sample_rate,
+             "duration": len(audio_array) / sample_rate
+         }
+
+     def _adjust_speed(self, audio_array, sample_rate, speed_factor):
+         """Adjust the audio speed"""
+         if speed_factor == 1.0:
+             return audio_array
+
+         try:
+             import librosa
+
+             # Change the speed while preserving pitch
+             audio_stretched = librosa.effects.time_stretch(
+                 y=audio_array,
+                 rate=speed_factor
+             )
+
+             return audio_stretched
+         except ImportError:
+             logger.warning("librosa not installed, skipping speed adjustment")
+             return audio_array
+
      def get_loaded_models(self):
          """Get the list of loaded models"""
          loaded = []
+         for model_id, data in self.models.items():
              if data["loaded"]:
+                 loaded.append(model_id)
          return loaded

      def get_all_models_status(self):
          """Get the status of all models"""
          status = {}
+         for model_id, data in self.models.items():
+             status[model_id] = {
                  "loaded": data["loaded"],
+                 "type": data.get("type", "unknown"),
+                 "config": data["config"]
              }
              if "error" in data:
+                 status[model_id]["error"] = data["error"]
+             if "loaded_at" in data:
+                 status[model_id]["loaded_at"] = data["loaded_at"]
          return status

+ # Initialize the TTS manager
+ tts_manager = TTSManager(TTS_MODELS)
+
+ def audio_to_wav_bytes(audio_array, sample_rate):
+     """Convert an audio array to WAV bytes"""
+     wav_buffer = io.BytesIO()
+     sf.write(wav_buffer, audio_array, sample_rate, format='WAV')
+     wav_buffer.seek(0)
+     return wav_buffer

  @app.route('/')
  def home():
+     loaded_models = tts_manager.get_loaded_models()
      status_html = "<ul>"
+     for model_id, model_data in tts_manager.models.items():
          status = "✅" if model_data["loaded"] else "❌"
+         model_type = model_data.get("type", "unknown")
+         status_html += f"<li>{model_id} ({model_type}): {status}</li>"
      status_html += "</ul>"

      return f'''
      <!DOCTYPE html>
      <html>
      <head>
+         <title>TTS API - Text to Speech</title>
          <style>
              body {{ font-family: Arial, sans-serif; margin: 40px; }}
              .config {{ background: #f0f0f0; padding: 15px; border-radius: 5px; margin-bottom: 20px; }}

          </style>
      </head>
      <body>
+         <h1>🔊 TTS API - Text to Speech</h1>

          <div class="config">
              <h3>⚙️ Configuration</h3>
+             <p><strong>Max Text Length:</strong> {MAX_TEXT_LENGTH} characters</p>
+             <p><strong>Max Audio Length:</strong> {MAX_AUDIO_LENGTH} seconds</p>
+             <p><strong>Device:</strong> {"CUDA/GPU" if torch.cuda.is_available() else "CPU"}</p>
          </div>

+         <h2>📦 Loaded TTS models:</h2>
          {status_html}
+         <p>Total models: {len(loaded_models)}/{len(TTS_MODELS)}</p>

          <h2>🔗 Available endpoints:</h2>
          <div class="endpoint">
+             <strong>GET /tts?text=&lt;text&gt;[&params]</strong><br>
+             Generates audio from text. Optional parameters:<br>
+             • model= (model ID, default: first model)<br>
+             • speed= (0.5-2.0, speaking speed)<br>
+             • language= (language, e.g. es, en)<br>
+             • speaker= (specific voice)<br>
+             • download= (true/false, force download)
+         </div>
+
+         <div class="endpoint">
+             <strong>POST /v1/audio/speech</strong><br>
+             OpenAI Audio API compatible
          </div>

          <div class="endpoint">
+             <strong>POST /generate</strong><br>
+             Alternative JSON endpoint
          </div>

          <div class="endpoint">

      </html>
      '''

+ @app.route('/v1/audio/speech', methods=['POST'])
+ def openai_compatible_endpoint():
+     """OpenAI Audio API compatible endpoint"""
      try:
          data = request.get_json()

+         text = data.get('input', '')
+         model_id = data.get('model', TTS_MODELS[0]["id"])

+         if not text:
+             return jsonify({"error": "The 'input' field is required"}), 400

+         if len(text) > MAX_TEXT_LENGTH:
+             return jsonify({"error": f"Text too long (maximum {MAX_TEXT_LENGTH} characters)"}), 400
+
+         # Extract the parameters
+         params = {k: v for k, v in data.items() if k not in ['input', 'model']}
+
+         # Generate the audio
+         result = tts_manager.generate_speech(model_id, text, **params)

          if "error" in result:
              return jsonify(result), 500
+
+         # Convert to WAV bytes
+         wav_buffer = audio_to_wav_bytes(result["audio"], result["sample_rate"])
+
+         # Return as audio
+         return Response(
+             wav_buffer.read(),
+             mimetype='audio/wav',
+             headers={'Content-Disposition': 'attachment; filename="speech.wav"'}
+         )

      except Exception as e:
+         logger.error(f"Error in OpenAI endpoint: {str(e)}")
          return jsonify({"error": str(e)}), 500

+ @app.route('/tts', methods=['GET'])
+ def tts_get_endpoint():
+     """GET endpoint for generating audio from text"""
      try:
+         # Read the parameters
+         text = request.args.get('text', '')
+         model_id = request.args.get('model', TTS_MODELS[0]["id"])
+         speed = float(request.args.get('speed', 1.0))
+         language = request.args.get('language', 'es')
+         speaker = request.args.get('speaker')
+         download = request.args.get('download', 'false').lower() == 'true'

+         # Validations
+         if not text:
+             return jsonify({"error": "The 'text' parameter is required"}), 400

+         if len(text) > MAX_TEXT_LENGTH:
+             return jsonify({"error": f"Text too long (maximum {MAX_TEXT_LENGTH} characters)"}), 400

+         if speed < 0.5 or speed > 2.0:
+             return jsonify({"error": "The 'speed' parameter must be between 0.5 and 2.0"}), 400

+         # Prepare the parameters
+         params = {
+             "speed": speed,
+             "language": language
          }
+         if speaker:
+             params["speaker"] = speaker
+
+         # Generate the audio
+         result = tts_manager.generate_speech(model_id, text, **params)

          if "error" in result:
+             return jsonify(result), 500
+
+         # Convert to WAV bytes
+         wav_buffer = audio_to_wav_bytes(result["audio"], result["sample_rate"])
+
+         # Configure the response
+         filename = f"tts_{model_id}.wav"
+
+         if download:
+             return send_file(
+                 wav_buffer,
+                 mimetype='audio/wav',
+                 as_attachment=True,
+                 download_name=filename
+             )
+         else:
              return Response(
+                 wav_buffer.read(),
+                 mimetype='audio/wav',
+                 headers={'Content-Disposition': f'inline; filename="{filename}"'}
              )
+
+     except ValueError as e:
+         return jsonify({"error": f"Invalid parameters: {str(e)}"}), 400
+     except Exception as e:
+         logger.error(f"Error in TTS GET: {str(e)}")
+         return jsonify({"error": str(e)}), 500

+ @app.route('/generate', methods=['POST'])
+ def generate_endpoint():
+     """Alternative endpoint for audio generation"""
+     try:
+         data = request.get_json()

+         text = data.get('text', '')
+         model_id = data.get('model', TTS_MODELS[0]["id"])

+         if not text:
+             return jsonify({"error": "The 'text' field is required"}), 400
+
+         if len(text) > MAX_TEXT_LENGTH:
+             return jsonify({"error": f"Text too long (maximum {MAX_TEXT_LENGTH} characters)"}), 400
+
+         # Extract the parameters
+         params = {k: v for k, v in data.items() if k not in ['text', 'model']}
+
+         # Generate the audio
+         result = tts_manager.generate_speech(model_id, text, **params)
+
+         if "error" in result:
+             return jsonify(result), 500
+
+         # Convert to bytes
+         wav_buffer = audio_to_wav_bytes(result["audio"], result["sample_rate"])
+
+         # Return as audio
          return Response(
+             wav_buffer.read(),
+             mimetype='audio/wav',
+             headers={'Content-Disposition': 'inline; filename="generated.wav"'}
          )
+
      except Exception as e:
+         logger.error(f"Error in generate endpoint: {str(e)}")
+         return jsonify({"error": str(e)}), 500

  @app.route('/health', methods=['GET'])
  def health():
+     loaded_models = tts_manager.get_loaded_models()
      return jsonify({
          "status": "healthy" if len(loaded_models) > 0 else "error",
          "loaded_models": loaded_models,
+         "total_models": len(TTS_MODELS),
+         "device": "cuda" if torch.cuda.is_available() else "cpu",
          "config": {
+             "max_text_length": MAX_TEXT_LENGTH,
+             "max_audio_length": MAX_AUDIO_LENGTH
          }
      })

  def list_models():
      """Endpoint to list all models and their status"""
      return jsonify({
+         "available_models": TTS_MODELS,
+         "status": tts_manager.get_all_models_status(),
          "config": {
+             "max_text_length": MAX_TEXT_LENGTH,
+             "max_audio_length": MAX_AUDIO_LENGTH
          }
      })

+ @app.route('/models/<model_id>', methods=['GET'])
+ def get_model_status(model_id):
      """Endpoint to get the status of a specific model"""
+     model_data = tts_manager.get_model(model_id)
      if not model_data:
+         return jsonify({"error": f"Model '{model_id}' not found"}), 404

      return jsonify({
+         "model": model_id,
          "loaded": model_data["loaded"],
+         "type": model_data.get("type", "unknown"),
+         "config": model_data["config"],
          "error": model_data.get("error"),
+         "loaded_at": model_data.get("loaded_at")
      })

  if __name__ == '__main__':
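
For reference, TTSManager reads each engines.json entry through its "id" and "url" fields, plus an optional "type" (transformers, coqui, or speecht5) and, for SpeechT5, a "vocoder_url". The sketch below shows a config entry in that shape and a client call against the new GET /tts endpoint; the model ID, URLs, and port are illustrative placeholders, not values from this commit.

# Sketch only: an engines.json entry shaped the way TTSManager reads it,
# followed by a request to GET /tts. The ID, URLs, and port are
# hypothetical placeholders.
import json
import requests

example_engines = [
    {
        "id": "speecht5-demo",                            # hypothetical model ID
        "url": "https://example.com/speecht5-model.bin",  # placeholder download URL
        "type": "speecht5",                               # transformers | coqui | speecht5
        "vocoder_url": "https://example.com/hifigan.bin"  # optional, SpeechT5 only
    }
]

with open("engines.json", "w") as f:
    json.dump(example_engines, f, indent=2)

# Exercise GET /tts once the server is running (the port is assumed; the
# app.run() call is not shown in this diff).
resp = requests.get(
    "http://localhost:7860/tts",
    params={"text": "Hola mundo", "model": "speecht5-demo", "speed": 1.0},
)
resp.raise_for_status()

# The endpoint responds with audio/wav bytes.
with open("speech.wav", "wb") as f:
    f.write(resp.content)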