teszenofficial committed on
Commit
82bf5cc
·
verified ·
1 Parent(s): a55b1f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -191
app.py CHANGED
@@ -4,11 +4,12 @@ import torch
4
  import json
5
  import time
6
  import gc
 
7
  from fastapi import FastAPI, Request
8
- from fastapi.responses import HTMLResponse, StreamingResponse
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel, Field
11
- from huggingface_hub import snapshot_download, hf_hub_download
12
  import uvicorn
13
  import math
14
  import torch.nn as nn
@@ -33,7 +34,7 @@ torch.set_grad_enabled(False)
33
  MODEL_REPO = "TeszenAI/MTP-3"
34
 
35
  # ======================
36
- # DEFINIR ARQUITECTURA DEL MODELO (MTP-1.1)
37
  # ======================
38
  class LayerNorm(nn.Module):
39
  def __init__(self, d_model: int, eps: float = 1e-5):
@@ -41,7 +42,6 @@ class LayerNorm(nn.Module):
41
  self.weight = nn.Parameter(torch.ones(d_model))
42
  self.bias = nn.Parameter(torch.zeros(d_model))
43
  self.eps = eps
44
-
45
  def forward(self, x):
46
  mean = x.mean(-1, keepdim=True)
47
  std = x.std(-1, keepdim=True)
@@ -60,7 +60,6 @@ class MultiHeadAttention(nn.Module):
60
  self.w_o = nn.Linear(d_model, d_model)
61
  self.dropout = nn.Dropout(dropout)
62
  self.scale = math.sqrt(self.d_k)
63
-
64
  def forward(self, x, mask=None):
65
  batch_size, seq_len, _ = x.shape
66
  Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
@@ -81,7 +80,6 @@ class FeedForward(nn.Module):
81
  self.linear1 = nn.Linear(d_model, d_ff)
82
  self.linear2 = nn.Linear(d_ff, d_model)
83
  self.dropout = nn.Dropout(dropout)
84
-
85
  def forward(self, x):
86
  return self.linear2(self.dropout(F.gelu(self.linear1(x))))
87
 
@@ -94,7 +92,6 @@ class TransformerBlock(nn.Module):
94
  self.norm2 = LayerNorm(d_model)
95
  self.dropout1 = nn.Dropout(dropout)
96
  self.dropout2 = nn.Dropout(dropout)
97
-
98
  def forward(self, x, mask=None):
99
  attn_output = self.attention(x, mask)
100
  x = x + self.dropout1(attn_output)
@@ -113,22 +110,19 @@ class PositionalEncoding(nn.Module):
113
  pe[:, 0::2] = torch.sin(position * div_term)
114
  pe[:, 1::2] = torch.cos(position * div_term)
115
  self.register_buffer('pe', pe.unsqueeze(0))
116
-
117
  def forward(self, x):
118
  return x + self.pe[:, :x.size(1), :]
119
 
120
  class MTPModel(nn.Module):
121
- def __init__(self, vocab_size: int, d_model: int = 128, n_heads: int = 4,
122
- n_layers: int = 4, d_ff: int = 512, dropout: float = 0.1, max_len: int = 256):
123
  super().__init__()
124
  self.vocab_size = vocab_size
125
  self.d_model = d_model
126
  self.max_len = max_len
127
  self.token_embedding = nn.Embedding(vocab_size, d_model)
128
  self.pos_encoding = PositionalEncoding(d_model, max_len)
129
- self.blocks = nn.ModuleList([
130
- TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
131
- ])
132
  self.norm = LayerNorm(d_model)
133
  self.lm_head = nn.Linear(d_model, vocab_size)
134
 
@@ -140,82 +134,34 @@ class MTPModel(nn.Module):
140
  for block in self.blocks:
141
  x = block(x, mask)
142
  x = self.norm(x)
143
- logits = self.lm_head(x)
144
- return logits
145
-
146
def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50, top_p=0.9,
             repetition_penalty=1.1, eos_token_id=3):
    """Autoregressively sample a continuation of ``input_ids``.

    Args:
        input_ids: Integer tensor of token ids; only batch row 0 is sampled.
        max_new_tokens: Upper bound on the number of tokens appended.
        temperature: Divisor applied to the logits before filtering.
        top_k: Keep only the ``top_k`` highest logits (0 disables the filter).
            Values larger than the vocabulary are clamped.
        top_p: Nucleus threshold; 1.0 disables the filter.
        repetition_penalty: Penalty applied to tokens already present in the
            sequence; 1.0 disables it.
        eos_token_id: Token id that stops generation (not appended).

    Returns:
        The input tensor with the sampled tokens concatenated along dim 1.
    """
    generated = input_ids
    # Crop the context when the model advertises a maximum length, so long
    # generations do not outgrow the positional-encoding table.
    max_len = getattr(self, "max_len", None)

    for _ in range(max_new_tokens):
        context = generated if max_len is None else generated[:, -max_len:]
        with torch.no_grad():
            logits = self(context)
        next_logits = logits[0, -1, :] / temperature

        if repetition_penalty != 1.0:
            seen = torch.unique(generated[0])
            seen_logits = next_logits[seen]
            # Dividing a negative logit would make the token MORE likely, so
            # negative logits are multiplied by the penalty instead.
            next_logits[seen] = torch.where(
                seen_logits < 0,
                seen_logits * repetition_penalty,
                seen_logits / repetition_penalty,
            )

        if top_k > 0:
            # torch.topk raises when k exceeds the vocabulary size.
            k = min(top_k, next_logits.size(-1))
            kth_best = torch.topk(next_logits, k)[0][..., -1, None]
            next_logits[next_logits < kth_best] = float('-inf')

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            next_logits[indices_to_remove] = float('-inf')

        probs = F.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).item()

        if next_token == eos_token_id:
            break

        generated = torch.cat(
            [generated, torch.tensor([[next_token]], device=generated.device)], dim=1
        )

    return generated
181
 
182
  # ======================
183
  # DESCARGA Y CARGA DEL MODELO CON REINTENTOS
184
  # ======================
185
def download_with_retry(repo_id, local_dir, max_retries=3):
    """Download the model snapshot, retrying to ride out timeouts.

    Each failed attempt is logged. Before every retry the function sleeps
    for ``attempt_number * 3`` seconds; the last failure is re-raised.
    Returns the path reported by ``snapshot_download``, or ``None`` when
    ``max_retries`` is not positive (no attempt is made).
    """
    attempt = 0
    while attempt < max_retries:
        attempt_no = attempt + 1
        try:
            print(f"📦 Intento {attempt_no}/{max_retries} - Descargando modelo desde {repo_id}...")

            # Large alternative weight formats are excluded from the download.
            downloaded = snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                local_dir=local_dir,
                resume_download=True,
                local_files_only=False,
                ignore_patterns=["*.h5", "*.ot", "*.msgpack"]
            )

            print(f"✅ Modelo descargado exitosamente en: {downloaded}")
            return downloaded

        except Exception as err:
            print(f"⚠️ Error en intento {attempt_no}: {str(err)[:200]}")
            if attempt_no >= max_retries:
                print("❌ No se pudo descargar el modelo después de múltiples intentos")
                raise
            pause = attempt_no * 3
            print(f"🔄 Reintentando en {pause} segundos...")
            time.sleep(pause)
        attempt += 1
 
214
 
215
- # Intentar descargar el modelo
216
  print(f"🚀 Iniciando carga del modelo desde {MODEL_REPO}...")
217
 
218
- # Verificar si ya existe en caché local
219
  if os.path.exists("mtp_repo") and os.path.exists("mtp_repo/mtp_model.pt"):
220
  print("📁 Modelo encontrado en caché local")
221
  repo_path = "mtp_repo"
@@ -223,10 +169,8 @@ else:
223
  try:
224
  repo_path = download_with_retry(MODEL_REPO, "mtp_repo", max_retries=3)
225
  except Exception as e:
226
- print(f"⚠️ Error crítico: {e}")
227
- print("🏗️ Usando configuración por defecto...")
228
  repo_path = "mtp_repo"
229
- os.makedirs(repo_path, exist_ok=True)
230
 
231
  # Cargar configuración
232
  config_path = os.path.join(repo_path, "config.json")
@@ -234,35 +178,29 @@ if os.path.exists(config_path):
234
  with open(config_path, "r") as f:
235
  config = json.load(f)
236
  else:
 
237
  config = {
238
- "vocab_size": 5000,
239
- "d_model": 128,
240
- "n_heads": 4,
241
- "n_layers": 4,
242
- "d_ff": 512,
243
  "dropout": 0.1,
244
- "max_len": 256
245
  }
246
 
247
  # Cargar tokenizador
248
  tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
249
  if os.path.exists(tokenizer_path):
250
- try:
251
- sp = spm.SentencePieceProcessor()
252
- sp.load(tokenizer_path)
253
- VOCAB_SIZE = sp.get_piece_size()
254
- print(f"✅ Tokenizador cargado: {VOCAB_SIZE} tokens")
255
- except Exception as e:
256
- print(f"⚠️ Error cargando tokenizador: {e}")
257
- VOCAB_SIZE = config.get("vocab_size", 5000)
258
- sp = None
259
  else:
260
- print("⚠️ No se encontró tokenizador, usando vocabulario por defecto")
261
- VOCAB_SIZE = config.get("vocab_size", 5000)
262
  sp = None
263
-
264
- # Actualizar vocab_size en config
265
- config["vocab_size"] = VOCAB_SIZE
266
 
267
  print(f"🧠 Inicializando modelo MTP...")
268
  print(f" → Vocabulario: {VOCAB_SIZE}")
@@ -279,39 +217,21 @@ if os.path.exists(model_path):
279
  try:
280
  state_dict = torch.load(model_path, map_location=DEVICE)
281
  model.load_state_dict(state_dict)
282
- print("✅ Pesos del modelo cargados")
283
  except Exception as e:
284
  print(f"⚠️ Error cargando pesos: {e}")
285
- print(" Usando pesos aleatorios")
286
  else:
287
- print("⚠️ No se encontró mtp_model.pt, usando pesos aleatorios")
288
 
289
  model.eval()
290
 
291
- # Cuantización para CPU
292
- if DEVICE == "cpu":
293
- print("⚡ Optimizando para CPU...")
294
- try:
295
- model = torch.quantization.quantize_dynamic(
296
- model,
297
- {nn.Linear},
298
- dtype=torch.qint8
299
- )
300
- print("✅ Cuantización aplicada")
301
- except Exception as e:
302
- print(f"⚠️ No se pudo aplicar cuantización: {e}")
303
-
304
  param_count = sum(p.numel() for p in model.parameters())
305
  print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
306
 
307
  # ======================
308
  # API CONFIG
309
  # ======================
310
- app = FastAPI(
311
- title="MTP-1.1 API",
312
- description="API para modelo de lenguaje MTP-1.1",
313
- version="1.1"
314
- )
315
 
316
  app.add_middleware(
317
  CORSMiddleware,
@@ -321,58 +241,103 @@ app.add_middleware(
321
  )
322
 
323
class PromptRequest(BaseModel):
    """Request body for the text-generation endpoint.

    Every field carries its own validation bounds; fields other than
    ``text`` fall back to their defaults when omitted.
    """

    # Required input text, limited to 2000 characters.
    text: str = Field(..., max_length=2000, description="Texto de entrada")
    # Generation budget, between 10 and 300 tokens.
    max_tokens: int = Field(150, ge=10, le=300, description="Tokens máximos a generar")
    # Sampling temperature in [0.1, 2.0].
    temperature: float = Field(0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
    # Top-k cut-off in [1, 100].
    top_k: int = Field(50, ge=1, le=100, description="Top-k sampling")
    # Nucleus threshold in [0.1, 1.0].
    top_p: float = Field(0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
    # Repetition penalty in [1.0, 2.0].
    repetition_penalty: float = Field(1.1, ge=1.0, le=2.0, description="Penalización por repetición")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
def build_prompt(user_input: str) -> str:
    """Wrap the user's text in the instruction/response template used by the model."""
    template = "### Instrucción:\n{}\n\n### Respuesta:\n"
    return template.format(user_input)
 
334
 
335
  # ======================
336
- # GESTIÓN DE CARGA
337
  # ======================
338
  ACTIVE_REQUESTS = 0
339
 
340
class MTPTokenizer:
    """Thin wrapper around a SentencePiece processor.

    When no processor is supplied (``sp_model`` is ``None``) every method
    falls back to a crude character-based scheme with fixed special ids.
    """

    def __init__(self, sp_model):
        self.sp = sp_model

    def encode(self, text):
        """Return token ids for ``text``."""
        if self.sp is not None:
            return self.sp.encode(text)
        # Fallback: one id per character, first 200 characters only.
        ids = []
        for ch in text[:200]:
            ids.append(ord(ch) % 1000)
        return ids

    def decode(self, tokens):
        """Return the text for ``tokens``."""
        if self.sp is not None:
            return self.sp.decode(tokens)
        # Fallback: map ids to printable ASCII, anything else becomes a space.
        chars = []
        for tok in tokens:
            code = tok % 128
            chars.append(chr(code) if 32 <= code < 127 else ' ')
        return ''.join(chars)

    def bos_id(self):
        """Beginning-of-sequence id (2 without a processor)."""
        return 2 if self.sp is None else self.sp.bos_id()

    def eos_id(self):
        """End-of-sequence id (3 without a processor)."""
        return 3 if self.sp is None else self.sp.eos_id()

    def pad_id(self):
        """Padding id (0 without a processor)."""
        return 0 if self.sp is None else self.sp.pad_id()
370
 
371
- tokenizer_wrapper = MTPTokenizer(sp)
372
 
373
  @app.post("/generate")
374
  async def generate(req: PromptRequest):
375
- """Endpoint principal de generación de texto"""
376
  global ACTIVE_REQUESTS
377
  ACTIVE_REQUESTS += 1
378
 
@@ -389,83 +354,51 @@ async def generate(req: PromptRequest):
389
  ACTIVE_REQUESTS -= 1
390
  return {"reply": "", "tokens_generated": 0}
391
 
392
- full_prompt = build_prompt(user_input)
393
- tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
394
- input_ids = torch.tensor([tokens], device=DEVICE)
395
-
396
  try:
397
- with torch.no_grad():
398
- output_ids = model.generate(
399
- input_ids,
400
- max_new_tokens=dyn_max_tokens,
401
- temperature=dyn_temperature,
402
- top_k=req.top_k,
403
- top_p=req.top_p,
404
- repetition_penalty=req.repetition_penalty
405
- )
406
-
407
- gen_tokens = output_ids[0, len(tokens):].tolist()
408
-
409
- safe_tokens = [
410
- t for t in gen_tokens
411
- if 0 <= t < VOCAB_SIZE and t != tokenizer_wrapper.eos_id()
412
- ]
413
-
414
- response = tokenizer_wrapper.decode(safe_tokens).strip()
415
-
416
- # Limpiar la respuesta
417
- if "###" in response:
418
- response = response.split("###")[0].strip()
419
-
420
- # Si la respuesta está vacía, devolver mensaje por defecto
421
- if not response or len(response) < 2:
422
- response = "Entendido. ¿En qué más puedo ayudarte?"
423
-
424
  return {
425
  "reply": response,
426
- "tokens_generated": len(safe_tokens),
427
- "model": "MTP-1.1"
428
  }
429
-
430
  except Exception as e:
431
- print(f"❌ Error durante generación: {e}")
432
- return {
433
- "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
434
- "error": str(e)
435
- }
436
-
437
  finally:
438
  ACTIVE_REQUESTS -= 1
439
  if DEVICE == "cuda":
440
  torch.cuda.empty_cache()
441
  gc.collect()
442
 
443
- # ======================
444
- # ENDPOINTS DE INFORMACIÓN
445
- # ======================
446
@app.get("/health")
def health_check():
    """Report service status, device, in-flight request count and vocabulary size.

    ``model_loaded`` only reflects whether the weights file exists on disk.
    """
    weights_on_disk = os.path.exists("mtp_repo/mtp_model.pt")
    payload = dict(
        status="healthy",
        model="MTP-1.1",
        device=DEVICE,
        active_requests=ACTIVE_REQUESTS,
        vocab_size=VOCAB_SIZE,
        model_loaded=weights_on_disk,
    )
    return payload
456
 
457
@app.get("/info")
def model_info():
    """Return the model name, version, configuration, parameter count and device."""
    total_params = 0
    for tensor in model.parameters():
        total_params += tensor.numel()
    return dict(
        model_name="MTP-1.1",
        version="1.1",
        architecture=config,
        parameters=total_params,
        device=DEVICE,
    )
466
 
467
  # ======================
468
- # INTERFAZ WEB (MODERNA DE MTP-3)
469
  # ======================
470
  @app.get("/", response_class=HTMLResponse)
471
  def chat_ui():
@@ -475,7 +408,7 @@ def chat_ui():
475
  <head>
476
  <meta charset="UTF-8">
477
  <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
478
- <title>MTP 1.1</title>
479
  <link rel="preconnect" href="https://fonts.googleapis.com">
480
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
481
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
@@ -795,7 +728,7 @@ async function sendMessage(textOverride = null) {
795
  const response = await fetch('/generate', {
796
  method: 'POST',
797
  headers: { 'Content-Type': 'application/json' },
798
- body: JSON.stringify({ text: text }),
799
  signal: abortController.signal
800
  });
801
  const data = await response.json();
@@ -883,7 +816,7 @@ window.onload = () => userInput.focus();
883
 
884
  if __name__ == "__main__":
885
  port = int(os.environ.get("PORT", 7860))
886
- print(f"\n🚀 Iniciando servidor MTP-1.1 en puerto {port}...")
887
  print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
888
  print(f"📡 API docs: http://0.0.0.0:{port}/docs")
889
 
 
4
  import json
5
  import time
6
  import gc
7
+ import re
8
  from fastapi import FastAPI, Request
9
+ from fastapi.responses import HTMLResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel, Field
12
+ from huggingface_hub import snapshot_download
13
  import uvicorn
14
  import math
15
  import torch.nn as nn
 
34
  MODEL_REPO = "TeszenAI/MTP-3"
35
 
36
  # ======================
37
+ # ARQUITECTURA DEL MODELO (MISMA QUE EN colab.py)
38
  # ======================
39
  class LayerNorm(nn.Module):
40
  def __init__(self, d_model: int, eps: float = 1e-5):
 
42
  self.weight = nn.Parameter(torch.ones(d_model))
43
  self.bias = nn.Parameter(torch.zeros(d_model))
44
  self.eps = eps
 
45
  def forward(self, x):
46
  mean = x.mean(-1, keepdim=True)
47
  std = x.std(-1, keepdim=True)
 
60
  self.w_o = nn.Linear(d_model, d_model)
61
  self.dropout = nn.Dropout(dropout)
62
  self.scale = math.sqrt(self.d_k)
 
63
  def forward(self, x, mask=None):
64
  batch_size, seq_len, _ = x.shape
65
  Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
 
80
  self.linear1 = nn.Linear(d_model, d_ff)
81
  self.linear2 = nn.Linear(d_ff, d_model)
82
  self.dropout = nn.Dropout(dropout)
 
83
  def forward(self, x):
84
  return self.linear2(self.dropout(F.gelu(self.linear1(x))))
85
 
 
92
  self.norm2 = LayerNorm(d_model)
93
  self.dropout1 = nn.Dropout(dropout)
94
  self.dropout2 = nn.Dropout(dropout)
 
95
  def forward(self, x, mask=None):
96
  attn_output = self.attention(x, mask)
97
  x = x + self.dropout1(attn_output)
 
110
  pe[:, 0::2] = torch.sin(position * div_term)
111
  pe[:, 1::2] = torch.cos(position * div_term)
112
  self.register_buffer('pe', pe.unsqueeze(0))
 
113
  def forward(self, x):
114
  return x + self.pe[:, :x.size(1), :]
115
 
116
  class MTPModel(nn.Module):
117
+ def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8,
118
+ n_layers: int = 6, d_ff: int = 1024, dropout: float = 0.1, max_len: int = 512):
119
  super().__init__()
120
  self.vocab_size = vocab_size
121
  self.d_model = d_model
122
  self.max_len = max_len
123
  self.token_embedding = nn.Embedding(vocab_size, d_model)
124
  self.pos_encoding = PositionalEncoding(d_model, max_len)
125
+ self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
 
 
126
  self.norm = LayerNorm(d_model)
127
  self.lm_head = nn.Linear(d_model, vocab_size)
128
 
 
134
  for block in self.blocks:
135
  x = block(x, mask)
136
  x = self.norm(x)
137
+ return self.lm_head(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  # ======================
140
  # DESCARGA Y CARGA DEL MODELO CON REINTENTOS
141
  # ======================
142
def download_with_retry(repo_id, local_dir, max_retries=3):
    """Download the model snapshot, retrying after failures.

    Each failed attempt is logged and followed by a 3-second pause; the
    last failure is re-raised. Returns the path reported by
    ``snapshot_download``, or ``local_dir`` when no attempt is made
    (``max_retries`` not positive).
    """
    attempt = 0
    while attempt < max_retries:
        attempt_no = attempt + 1
        try:
            print(f"📦 Intento {attempt_no}/{max_retries} - Descargando modelo desde {repo_id}...")
            downloaded = snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                local_dir=local_dir,
                resume_download=True,
                local_files_only=False
            )
            print(f"✅ Modelo descargado exitosamente en: {downloaded}")
            return downloaded
        except Exception as err:
            print(f"⚠️ Error en intento {attempt_no}: {str(err)[:200]}")
            if attempt_no >= max_retries:
                raise
            time.sleep(3)
        attempt += 1
    return local_dir
162
 
 
163
  print(f"🚀 Iniciando carga del modelo desde {MODEL_REPO}...")
164
 
 
165
  if os.path.exists("mtp_repo") and os.path.exists("mtp_repo/mtp_model.pt"):
166
  print("📁 Modelo encontrado en caché local")
167
  repo_path = "mtp_repo"
 
169
  try:
170
  repo_path = download_with_retry(MODEL_REPO, "mtp_repo", max_retries=3)
171
  except Exception as e:
172
+ print(f"⚠️ Error: {e}")
 
173
  repo_path = "mtp_repo"
 
174
 
175
  # Cargar configuración
176
  config_path = os.path.join(repo_path, "config.json")
 
178
  with open(config_path, "r") as f:
179
  config = json.load(f)
180
  else:
181
+ # Configuración por defecto (MISMA que en colab.py)
182
  config = {
183
+ "vocab_size": 2000,
184
+ "d_model": 256,
185
+ "n_heads": 8,
186
+ "n_layers": 6,
187
+ "d_ff": 1024,
188
  "dropout": 0.1,
189
+ "max_len": 512
190
  }
191
 
192
  # Cargar tokenizador
193
  tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
194
  if os.path.exists(tokenizer_path):
195
+ sp = spm.SentencePieceProcessor()
196
+ sp.load(tokenizer_path)
197
+ VOCAB_SIZE = sp.get_piece_size()
198
+ config["vocab_size"] = VOCAB_SIZE
199
+ print(f"✅ Tokenizador cargado: {VOCAB_SIZE} tokens")
 
 
 
 
200
  else:
201
+ print(" No se encontró tokenizador")
 
202
  sp = None
203
+ VOCAB_SIZE = config.get("vocab_size", 2000)
 
 
204
 
205
  print(f"🧠 Inicializando modelo MTP...")
206
  print(f" → Vocabulario: {VOCAB_SIZE}")
 
217
  try:
218
  state_dict = torch.load(model_path, map_location=DEVICE)
219
  model.load_state_dict(state_dict)
220
+ print("✅ Pesos del modelo cargados correctamente")
221
  except Exception as e:
222
  print(f"⚠️ Error cargando pesos: {e}")
 
223
  else:
224
+ print("⚠️ No se encontró mtp_model.pt")
225
 
226
  model.eval()
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  param_count = sum(p.numel() for p in model.parameters())
229
  print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
230
 
231
  # ======================
232
  # API CONFIG
233
  # ======================
234
+ app = FastAPI(title="MTP API", description="API para modelo de lenguaje MTP", version="1.0")
 
 
 
 
235
 
236
  app.add_middleware(
237
  CORSMiddleware,
 
241
  )
242
 
243
class PromptRequest(BaseModel):
    """Request body for the text-generation endpoint.

    Every field carries its own validation bounds; fields other than
    ``text`` fall back to their defaults when omitted.
    """

    # Required input text, limited to 2000 characters.
    text: str = Field(..., max_length=2000)
    # Generation budget, between 10 and 300 tokens.
    max_tokens: int = Field(150, ge=10, le=300)
    # Sampling temperature in [0.1, 2.0].
    temperature: float = Field(0.7, ge=0.1, le=2.0)
    # Top-k cut-off in [1, 100].
    top_k: int = Field(50, ge=1, le=100)
    # Nucleus threshold in [0.1, 1.0].
    top_p: float = Field(0.9, ge=0.1, le=1.0)
    # Repetition penalty in [1.0, 2.0].
    repetition_penalty: float = Field(1.1, ge=1.0, le=2.0)
250
+
251
+ # ======================
252
+ # FUNCIÓN DE GENERACIÓN (IGUAL QUE EN colab.py)
253
+ # ======================
254
def generate_response(model, tokenizer, prompt, max_length=150, temperature=0.7, top_k=50, top_p=0.9, device='cpu'):
    """Sample a reply to ``prompt`` and return it as cleaned-up text.

    The prompt is wrapped in the instruction/response template, tokens are
    sampled one at a time with top-k and nucleus filtering, and only the
    newly generated tokens are decoded.

    Args:
        model: Callable returning logits indexed ``[batch, position, vocab]``;
            must expose ``eval()`` and ``max_len``.
        tokenizer: Object providing ``encode``, ``decode`` and ``eos_id``.
        prompt: User text inserted into the template.
        max_length: Maximum number of tokens to sample.
        temperature: Divisor applied to the logits before filtering.
        top_k: Keep only the ``top_k`` highest logits (0 disables the filter).
            Values larger than the vocabulary are clamped.
        top_p: Nucleus threshold; 1.0 disables the filter.
        device: Device the input tensor is moved to.

    Returns:
        The cleaned reply, or a stock sentence when fewer than two
        characters remain after clean-up.
    """
    model.eval()
    formatted_prompt = f"### Instrucción:\n{prompt}\n\n### Respuesta:\n"
    input_ids = list(tokenizer.encode(formatted_prompt))
    generated = list(input_ids)
    eos_id = tokenizer.eos_id()

    for _ in range(max_length):
        # Feed at most the last ``max_len`` tokens to the model.
        window = generated[-model.max_len:]
        input_tensor = torch.tensor([window], dtype=torch.long).to(device)
        with torch.no_grad():
            logits = model(input_tensor)
            next_logits = logits[0, -1, :] / temperature

            if top_k > 0:
                # torch.topk raises when k exceeds the vocabulary size.
                k = min(top_k, next_logits.size(-1))
                kth_best = torch.topk(next_logits, k)[0][..., -1, None]
                next_logits[next_logits < kth_best] = float('-inf')

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_logits[indices_to_remove] = float('-inf')

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).item()

        if next_token == eos_id:
            break

        # Bail out when the last ten tokens are all the same (degenerate loop).
        if len(generated) > 20 and len(set(generated[-10:])) == 1:
            break

        generated.append(next_token)

    # Decode only the continuation. Decoding prompt + continuation and then
    # splitting on the marker leaks the prompt into the reply whenever the
    # marker does not survive the encode/decode round trip.
    response = tokenizer.decode(generated[len(input_ids):])

    # If the model emitted a marker itself, keep the text after the last one.
    if "### Respuesta:" in response:
        response = response.split("### Respuesta:")[-1].strip()
    elif "Respuesta:" in response:
        response = response.split("Respuesta:")[-1].strip()
    elif "[/INST]" in response:
        response = response.split("[/INST]")[-1].strip()

    # Strip known junk words
    garbage_words = ['foompañances', 'ciudadores', 'mejtedon', 'calportedon', 'rápidodcor', 'baon', 'domol']
    for word in garbage_words:
        response = response.replace(word, '')

    # Replace characters outside the allowed set, then collapse whitespace.
    response = re.sub(r'[^\w\s\u00C0-\u00FF\u0100-\u017F.,!?¿¡()\-:;"]+', ' ', response)
    response = re.sub(r'\s+', ' ', response).strip()

    if len(response) < 2:
        response = "Entendido. ¿Algo más en lo que pueda ayudarte?"

    return response
313
 
314
  # ======================
315
+ # ENDPOINTS
316
  # ======================
317
  ACTIVE_REQUESTS = 0
318
 
319
class TokenizerWrapper:
    """Thin wrapper around a SentencePiece processor.

    When no processor is supplied (``sp_model`` is ``None``) every method
    falls back to a crude character-based scheme with fixed special ids.
    All methods decide on ``self.sp is None`` so they always agree on
    which mode is active, even for a processor object that is falsy.
    """

    def __init__(self, sp_model):
        self.sp = sp_model

    def encode(self, text):
        """Return token ids for ``text``."""
        if self.sp is None:
            # Fallback: one id per character, first 200 characters only.
            return [ord(c) % 1000 for c in text[:200]]
        return self.sp.encode(text)

    def decode(self, tokens):
        """Return the text for ``tokens``."""
        if self.sp is None:
            # Fallback: map ids to printable ASCII, anything else becomes a space.
            return ''.join([chr(t % 128) if 32 <= t % 128 < 127 else ' ' for t in tokens])
        return self.sp.decode(tokens)

    def eos_id(self):
        """End-of-sequence id (3 without a processor)."""
        return 3 if self.sp is None else self.sp.eos_id()

    def bos_id(self):
        """Beginning-of-sequence id (2 without a processor)."""
        return 2 if self.sp is None else self.sp.bos_id()

    def pad_id(self):
        """Padding id (0 without a processor)."""
        return 0 if self.sp is None else self.sp.pad_id()
 
 
336
 
337
+ tokenizer_wrapper = TokenizerWrapper(sp)
338
 
339
  @app.post("/generate")
340
  async def generate(req: PromptRequest):
 
341
  global ACTIVE_REQUESTS
342
  ACTIVE_REQUESTS += 1
343
 
 
354
  ACTIVE_REQUESTS -= 1
355
  return {"reply": "", "tokens_generated": 0}
356
 
 
 
 
 
357
  try:
358
+ response = generate_response(
359
+ model, tokenizer_wrapper, user_input,
360
+ max_length=dyn_max_tokens,
361
+ temperature=dyn_temperature,
362
+ top_k=req.top_k,
363
+ top_p=req.top_p,
364
+ device=DEVICE
365
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  return {
367
  "reply": response,
368
+ "tokens_generated": len(response.split()),
369
+ "model": "MTP"
370
  }
 
371
  except Exception as e:
372
+ print(f"❌ Error: {e}")
373
+ return {"reply": "Lo siento, ocurrió un error.", "error": str(e)}
 
 
 
 
374
  finally:
375
  ACTIVE_REQUESTS -= 1
376
  if DEVICE == "cuda":
377
  torch.cuda.empty_cache()
378
  gc.collect()
379
 
 
 
 
380
@app.get("/health")
def health_check():
    """Report service status, device, in-flight request count and vocabulary size."""
    payload = dict(
        status="healthy",
        model="MTP",
        device=DEVICE,
        active_requests=ACTIVE_REQUESTS,
        vocab_size=VOCAB_SIZE,
    )
    return payload
389
 
390
@app.get("/info")
def model_info():
    """Return the model name, version, configuration, parameter count and device."""
    total_params = 0
    for tensor in model.parameters():
        total_params += tensor.numel()
    return dict(
        model_name="MTP",
        version="1.0",
        architecture=config,
        parameters=total_params,
        device=DEVICE,
    )
399
 
400
  # ======================
401
+ # INTERFAZ WEB COMPLETA (CON TODAS LAS FUNCIONES ORIGINALES)
402
  # ======================
403
  @app.get("/", response_class=HTMLResponse)
404
  def chat_ui():
 
408
  <head>
409
  <meta charset="UTF-8">
410
  <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
411
+ <title>MTP 3</title>
412
  <link rel="preconnect" href="https://fonts.googleapis.com">
413
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
414
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
 
728
  const response = await fetch('/generate', {
729
  method: 'POST',
730
  headers: { 'Content-Type': 'application/json' },
731
+ body: JSON.stringify({ text: text, max_tokens: 150, temperature: 0.7 }),
732
  signal: abortController.signal
733
  });
734
  const data = await response.json();
 
816
 
817
  if __name__ == "__main__":
818
  port = int(os.environ.get("PORT", 7860))
819
+ print(f"\n🚀 Iniciando servidor MTP en puerto {port}...")
820
  print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
821
  print(f"📡 API docs: http://0.0.0.0:{port}/docs")
822