teszenofficial committed on
Commit
facf38f
·
verified ·
1 Parent(s): c4b6ca5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +931 -0
app.py ADDED
@@ -0,0 +1,931 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import json
5
+ import time
6
+ import gc
7
+ import re
8
+ from fastapi import FastAPI, Request
9
+ from fastapi.responses import HTMLResponse, StreamingResponse
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel, Field
12
+ from huggingface_hub import snapshot_download
13
+ import uvicorn
14
+ import math
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import sentencepiece as spm
18
+
19
+ # ======================
20
+ # CONFIGURACIÓN DE DISPOSITIVO
21
+ # ======================
22
# Pick the compute device once at import time; everything below reads DEVICE.
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")

if DEVICE == "cpu":
    # os.cpu_count() is documented to return None when the count cannot be
    # determined; fall back to 1 so the floor division cannot raise TypeError.
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))

# Inference-only service: disable autograd bookkeeping.
torch.set_grad_enabled(False)

# Change this to the name of your own Hugging Face repository.
MODEL_REPO = "TeszenAI/MTP-3.1.1"  # <-- change to your repo
36
+
37
+ # ======================
38
+ # FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
39
+ # ======================
40
+
41
def clean_response(text: str) -> str:
    """Clean a raw model reply before it is returned to the client.

    Steps, in order:
      1. Drop a word once it has already appeared three times in a row.
      2. Collapse runs of five or more identical non-digit characters to two,
         then strip every character outside the allowed Spanish
         alphanumeric/punctuation set.
      3. Truncate right after the first closing question the assistant asks
         (e.g. "¿Necesitas algo más?").
      4. Return a fallback apology when fewer than 10 characters remain.
      5. Normalise whitespace.

    Args:
        text: Decoded model output.

    Returns:
        The cleaned reply, "" for empty input, or the fallback message when
        almost nothing usable is left.
    """
    if not text:
        return ""

    # 1. Remove excessive consecutive repetitions of the same word.
    cleaned_words = []
    last_word = ""
    repeat_count = 0
    for word in text.split():
        if word == last_word:
            repeat_count += 1
            if repeat_count > 2:  # already emitted three times in a row
                continue
        else:
            last_word = word
            repeat_count = 0
        cleaned_words.append(word)
    text = " ".join(cleaned_words)

    # 2. Collapse stuttered characters ("aaaaaa" -> "aa"). Digits are excluded
    #    on purpose: collapsing them would corrupt numbers (1000000 -> 100).
    text = re.sub(r'(\D)\1{4,}', r'\1\1', text)
    text = re.sub(r'[^a-zA-ZáéíóúñüÁÉÍÓÚÑÜ0-9\s.,;:!?¿¡()\-"]+', '', text)

    # 3. Cut right after the first closing question. End-anchored patterns
    #    cannot shorten the text, so only phrases that may occur mid-text are
    #    searched for here.
    closing = re.search(
        r'(¿algo más\?|¿necesitas algo más\?|¿en qué más puedo ayudarte\?)',
        text,
        re.IGNORECASE,
    )
    if closing:
        text = text[:closing.end()]

    # 4. Too short to be a meaningful answer: return the default message.
    if len(text.strip()) < 10:
        return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"

    # 5. Collapse whitespace left behind by the removals above.
    return re.sub(r'\s+', ' ', text).strip()
95
+
96
def should_stop_generation(generated_text: str, min_length: int = 30, max_length: int = 300) -> bool:
    """Decide whether text generation should stop, based on the text so far.

    Args:
        generated_text: Text produced so far.
        min_length: Character count below which an unterminated text is never
            considered finished.
        max_length: Character count above which generation always stops.

    Returns:
        True when the text exceeds ``max_length``, contains a known closing
        phrase, or is longer than ``min_length`` and ends on a complete
        sentence; False otherwise.
    """
    # Hard cap on length.
    if len(generated_text) > max_length:
        return True

    # Still short and not ending on sentence punctuation: keep going.
    if len(generated_text) < min_length and not re.search(r'[.!?]$', generated_text):
        return False

    # Phrases that signal the reply has already wrapped up.
    stop_signals = [
        r'(gracias por tu pregunta|espero haberte ayudado|¿necesitas algo más\?)',
        r'(hasta luego|adiós|quedo atento|saludos cordiales)',
        r'(fin del mensaje|fin de la conversación)'
    ]
    if any(re.search(signal, generated_text, re.IGNORECASE) for signal in stop_signals):
        return True

    # Enough content and the last sentence looks complete. The sentence is
    # taken together with its terminator; splitting on "." would discard the
    # period and make period-terminated text impossible to detect.
    if len(generated_text) > min_length:
        last_sentence = re.search(r'[^.!?]+[.!?]+$', generated_text.strip())
        if last_sentence and len(last_sentence.group().strip()) > 5:
            return True

    return False
127
+
128
+ # ======================
129
+ # DEFINIR ARQUITECTURA DEL MODELO (MTP)
130
+ # ======================
131
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Unlike ``torch.nn.LayerNorm`` this divides by ``std + eps`` using
    ``Tensor.std`` with its default settings, so it is kept as a custom module.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last axis, then apply weight and bias."""
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.weight * centered / spread + self.bias
142
+
143
class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention split across ``n_heads`` heads."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # per-head width
        # Projections are created in a fixed order (q, k, v, o).
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        return projected.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over ``x`` (batch, seq, d_model); positions where ``mask == 0`` are hidden."""
        batch_size, seq_len, _ = x.shape
        queries = self._split_heads(self.w_q(x), batch_size, seq_len)
        keys = self._split_heads(self.w_k(x), batch_size, seq_len)
        values = self._split_heads(self.w_v(x), batch_size, seq_len)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, values)
        # Merge the heads back into a single d_model-wide vector per position.
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(merged)
170
+
171
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply GELU and dropout, project back to ``d_model``."""
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
180
+
181
class TransformerBlock(nn.Module):
    """One post-norm Transformer layer: attention and feed-forward, each followed by add-and-normalise."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply both sub-layers; normalisation happens after each residual sum."""
        # Residual around self-attention, then normalise.
        x = self.norm1(x + self.dropout1(self.attention(x, mask)))
        # Residual around the feed-forward network, then normalise.
        x = self.norm2(x + self.dropout2(self.feed_forward(x)))
        return x
199
+
200
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table that is added to the token embeddings."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even channels
        table[:, 1::2] = torch.cos(angles)  # odd channels
        # Registered as a buffer: moves with the module and appears in state_dict.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x``."""
        return x + self.pe[:, :x.size(1), :]
212
+
213
class MTPModel(nn.Module):
    """Decoder-only Transformer language model with a sampling-based ``generate``.

    ``forward`` maps token ids of shape (batch, seq) to logits of shape
    (batch, seq, vocab_size) and uses a causal mask when none is supplied.
    """

    def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8,
                 n_layers: int = 6, d_ff: int = 1024, dropout: float = 0.1, max_len: int = 512):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Return next-token logits for every position of ``x``.

        Args:
            x: Integer token ids, indexed as (batch, seq).
            mask: Optional attention mask broadcastable to the attention
                scores; entries equal to 0 are hidden. Defaults to a
                lower-triangular (causal) mask.
        """
        if mask is None:
            seq_len = x.size(1)
            # Build the causal mask directly on the input's device.
            mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).unsqueeze(0).unsqueeze(0)
        x = self.token_embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for block in self.blocks:
            x = block(x, mask)
        x = self.norm(x)
        return self.lm_head(x)

    def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9,
                 repetition_penalty=1.1, eos_token_id=3, decode_fn=None):
        """Sample up to ``max_new_tokens`` tokens after ``input_ids``.

        Only the first sequence of the batch is sampled (logits are read at
        index ``[0, -1]``), so callers should pass a batch of size 1.

        Args:
            input_ids: Prompt token ids, shape (1, prompt_len).
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Logit divisor; must be non-zero.
            top_k: Keep only the ``top_k`` most likely tokens (0 disables).
            top_p: Nucleus sampling threshold (1.0 disables).
            repetition_penalty: Values > 1 make already-seen tokens less likely.
            eos_token_id: Token id that ends generation. The default 3 is the
                id the original code treated as SentencePiece EOS.
                NOTE(review): confirm against ``tokenizer.eos_id()``.
            decode_fn: Optional callable mapping a list of token ids to text.
                When given, the decoded reply is checked every 10 steps with
                ``should_stop_generation``. Without it no text-based early
                stop is attempted, because token ids are not text.

        Returns:
            Tensor of shape (1, prompt_len + n_generated); the EOS token is
            not appended.
        """
        generated = input_ids
        prompt_len = input_ids.size(1)
        min_response_length = 30
        max_response_length = max_new_tokens * 2

        with torch.no_grad():
            for step in range(max_new_tokens):
                # The positional table only covers max_len positions, so feed
                # the model at most the last max_len tokens.
                window = generated[:, -self.max_len:]
                next_logits = self(window)[0, -1, :] / temperature

                if repetition_penalty != 1.0:
                    for token_id in set(generated[0].tolist()):
                        # Dividing a negative logit would raise its
                        # probability, so negative logits are multiplied.
                        if next_logits[token_id] < 0:
                            next_logits[token_id] *= repetition_penalty
                        else:
                            next_logits[token_id] /= repetition_penalty

                if top_k > 0:
                    # torch.topk raises if k exceeds the vocabulary size.
                    k = min(top_k, next_logits.size(-1))
                    indices_to_remove = next_logits < torch.topk(next_logits, k)[0][..., -1, None]
                    next_logits[indices_to_remove] = float('-inf')

                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()

                if next_token == eos_token_id:
                    break

                generated = torch.cat(
                    [generated, torch.tensor([[next_token]], device=generated.device)], dim=1
                )

                # Text-based early stop, checked only every 10 steps for
                # efficiency and only when real text can be produced.
                if decode_fn is not None and step > 10 and step % 10 == 0:
                    reply_so_far = decode_fn(generated[0, prompt_len:].tolist())
                    if should_stop_generation(reply_so_far, min_response_length, max_response_length):
                        break

        return generated
284
+
285
+ # ======================
286
+ # DESCARGA Y CARGA DEL MODELO
287
+ # ======================
288
# Download (or reuse) a local snapshot of the model repository into ./mtp_repo.
print(f"📦 Descargando modelo desde {MODEL_REPO}...")
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtp_repo"
)

# Load the architecture configuration; fall back to built-in defaults when the
# repository ships no config.json.
# NOTE(review): every key in config.json is passed to MTPModel(**config) below,
# so an unexpected key would raise TypeError — confirm the file's schema.
config_path = os.path.join(repo_path, "config.json")
if os.path.exists(config_path):
    with open(config_path, "r") as f:
        config = json.load(f)
else:
    config = {
        "vocab_size": 5000,
        "d_model": 256,
        "n_heads": 8,
        "n_layers": 6,
        "d_ff": 1024,
        "dropout": 0.1,
        "max_len": 512
    }

# Load the SentencePiece tokenizer shipped with the model.
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
sp = spm.SentencePieceProcessor()
sp.load(tokenizer_path)
VOCAB_SIZE = sp.get_piece_size()

# The tokenizer is the source of truth for the vocabulary size.
config["vocab_size"] = VOCAB_SIZE

print(f"🧠 Inicializando modelo MTP...")
print(f" → Vocabulario: {VOCAB_SIZE}")
print(f" → Dimensión: {config['d_model']}")
print(f" → Capas: {config['n_layers']}")
print(f" → Heads: {config['n_heads']}")

model = MTPModel(**config)
model.to(DEVICE)

# Load the trained weights; without them the model keeps its random
# initialisation and only a warning is printed.
# NOTE(review): torch.load unpickles the downloaded file; consider
# weights_only=True if the checkpoint is a plain state_dict — confirm.
model_path = os.path.join(repo_path, "mtp_model.pt")
if os.path.exists(model_path):
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    print("✅ Pesos del modelo cargados")
else:
    print("⚠️ No se encontró mtp_model.pt, usando pesos aleatorios")

model.eval()

# On CPU, apply dynamic int8 quantization to the nn.Linear layers. This is
# best-effort: on failure the unquantized model is kept.
if DEVICE == "cpu":
    print("⚡ Aplicando cuantización dinámica para CPU...")
    try:
        model = torch.quantization.quantize_dynamic(
            model,
            {nn.Linear},
            dtype=torch.qint8
        )
    except Exception as e:
        print(f"⚠️ No se pudo aplicar cuantización: {e}")

# NOTE(review): this count is taken after quantization, so it may not include
# weights held by quantized Linear modules — confirm if the number matters.
param_count = sum(p.numel() for p in model.parameters())
print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
354
+
355
+ # ======================
356
+ # API CONFIG
357
+ # ======================
358
# FastAPI application object; the route decorators below register against it.
app = FastAPI(
    title="MTP API",
    description="API para modelo de lenguaje MTP",
    version="1.0"
)

# CORS is left fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
370
+
371
class PromptRequest(BaseModel):
    """Request body for POST /generate; the bounds are declared on each Field."""

    # User message, at most 2000 characters.
    text: str = Field(..., max_length=2000, description="Texto de entrada")
    # Maximum number of tokens to generate (10-300).
    max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar")
    # Sampling temperature (0.1-2.0).
    temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
    # Top-k sampling cutoff (1-100).
    top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
    # Nucleus (top-p) sampling threshold (0.1-1.0).
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
    # Repetition penalty (1.0-2.0); 1.0 disables it.
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Penalización por repetición")
378
+
379
def build_prompt(user_input: str) -> str:
    """Wrap the user's message in the instruction/response template the model expects."""
    template = "### Instrucción:\n{}\n\n### Respuesta:\n"
    return template.format(user_input)
382
+
383
+ # ======================
384
+ # GESTIÓN DE CARGA
385
+ # ======================
386
# Number of /generate requests currently in flight; also reported by /health.
ACTIVE_REQUESTS = 0
387
+
388
class MTPTokenizer:
    """Thin adapter that forwards every call to a SentencePiece processor."""

    def __init__(self, sp_model):
        # The wrapped processor; all methods below delegate to it.
        self.sp = sp_model

    def encode(self, text):
        """Return what the wrapped processor's ``encode`` returns for ``text``."""
        token_ids = self.sp.encode(text)
        return token_ids

    def decode(self, tokens):
        """Return what the wrapped processor's ``decode`` returns for ``tokens``."""
        decoded = self.sp.decode(tokens)
        return decoded

    def bos_id(self):
        """Beginning-of-sequence id reported by the wrapped processor."""
        processor = self.sp
        return processor.bos_id()

    def eos_id(self):
        """End-of-sequence id reported by the wrapped processor."""
        processor = self.sp
        return processor.eos_id()

    def pad_id(self):
        """Padding id reported by the wrapped processor."""
        processor = self.sp
        return processor.pad_id()
407
+
408
# Shared tokenizer wrapper around the loaded SentencePiece processor `sp`.
tokenizer_wrapper = MTPTokenizer(sp)
409
+
410
@app.post("/generate")
async def generate(req: PromptRequest):
    """Main text-generation endpoint.

    Builds the instruction prompt, samples a continuation from the model,
    removes special/out-of-range tokens, cuts the reply at the next "###"
    section marker and post-processes it with ``clean_response``.

    Under load (more than two requests in flight) the token budget is capped
    at 120 and the temperature is lowered.

    Returns:
        ``{"reply", "tokens_generated", "model"}`` on success,
        ``{"reply": "", "tokens_generated": 0}`` for blank input, or
        ``{"reply", "error"}`` when generation fails.
    """
    global ACTIVE_REQUESTS
    import asyncio  # local import so this block does not rely on file-level imports

    ACTIVE_REQUESTS += 1
    # Everything after the increment lives inside try/finally so the counter
    # is released on every exit path, including tokenisation errors.
    try:
        dyn_max_tokens = req.max_tokens
        dyn_temperature = req.temperature

        if ACTIVE_REQUESTS > 2:
            print(f"⚠️ Carga alta ({ACTIVE_REQUESTS} requests). Ajustando parámetros.")
            dyn_max_tokens = min(dyn_max_tokens, 120)
            dyn_temperature = max(0.5, dyn_temperature * 0.9)

        user_input = req.text.strip()
        if not user_input:
            return {"reply": "", "tokens_generated": 0}

        full_prompt = build_prompt(user_input)
        tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
        input_ids = torch.tensor([tokens], device=DEVICE)

        def _run_generation():
            # Grad mode is per-thread in PyTorch, so disable it inside the worker.
            with torch.no_grad():
                return model.generate(
                    input_ids,
                    max_new_tokens=dyn_max_tokens,
                    temperature=dyn_temperature,
                    top_k=req.top_k,
                    top_p=req.top_p,
                    repetition_penalty=req.repetition_penalty
                )

        # Generation is blocking, compute-bound work. Running it in a worker
        # thread keeps the event loop free to accept other requests, which
        # also makes ACTIVE_REQUESTS reflect real concurrency.
        output_ids = await asyncio.get_running_loop().run_in_executor(None, _run_generation)

        gen_tokens = output_ids[0, len(tokens):].tolist()

        # Drop out-of-vocabulary ids and EOS before decoding.
        eos_id = tokenizer_wrapper.eos_id()
        safe_tokens = [t for t in gen_tokens if 0 <= t < VOCAB_SIZE and t != eos_id]

        response = tokenizer_wrapper.decode(safe_tokens).strip()

        # The model may start a new "### ..." section; keep only the answer.
        if "###" in response:
            response = response.split("###")[0].strip()

        response = clean_response(response)

        return {
            "reply": response,
            "tokens_generated": len(safe_tokens),
            "model": "MTP"
        }

    except Exception as e:
        # Top-level boundary: report the failure to the client instead of a 500.
        print(f"❌ Error durante generación: {e}")
        return {
            "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
            "error": str(e)
        }

    finally:
        ACTIVE_REQUESTS -= 1
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
477
+
478
+ # ======================
479
+ # ENDPOINTS DE INFORMACIÓN
480
+ # ======================
481
@app.get("/health")
def health_check():
    """Liveness probe: report the device, in-flight request count and vocabulary size."""
    status_report = dict(
        status="healthy",
        model="MTP",
        device=DEVICE,
        active_requests=ACTIVE_REQUESTS,
        vocab_size=VOCAB_SIZE,
    )
    return status_report
490
+
491
@app.get("/info")
def model_info():
    """Describe the served model: name, version, architecture config, parameter count, device."""
    total_parameters = 0
    for tensor in model.parameters():
        total_parameters += tensor.numel()

    details = {}
    details["model_name"] = "MTP"
    details["version"] = "1.0"
    details["architecture"] = config
    details["parameters"] = total_parameters
    details["device"] = DEVICE
    return details
500
+
501
+ # ======================
502
+ # INTERFAZ WEB (MODERNA CON LOGO INTEGRADO)
503
+ # ======================
504
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat UI.

    The page is one static HTML document with inline CSS and JavaScript. Its
    script POSTs ``{"text": ...}`` to ``/generate`` and types the ``reply``
    field of the JSON response into the chat, with stop, copy and regenerate
    controls.
    """
    return """
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<title>MTP - Asistente IA</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
<style>
:root {
--bg-color: #131314;
--surface-color: #1E1F20;
--accent-color: #4a9eff;
--text-primary: #e3e3e3;
--text-secondary: #9aa0a6;
--user-bubble: #282a2c;
}
* { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
body {
margin: 0;
background-color: var(--bg-color);
font-family: 'Inter', sans-serif;
color: var(--text-primary);
height: 100dvh;
display: flex;
flex-direction: column;
overflow: hidden;
}
header {
padding: 12px 20px;
display: flex;
align-items: center;
justify-content: space-between;
background: rgba(19, 19, 20, 0.85);
backdrop-filter: blur(12px);
position: fixed;
top: 0;
width: 100%;
z-index: 50;
border-bottom: 1px solid rgba(255,255,255,0.05);
}
.brand-wrapper {
display: flex;
align-items: center;
gap: 12px;
cursor: pointer;
}
.brand-logo {
width: 32px;
height: 32px;
border-radius: 50%;
background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
background-size: cover;
background-position: center;
background-repeat: no-repeat;
border: 1px solid rgba(255,255,255,0.1);
}
.brand-text {
font-weight: 500;
font-size: 1.05rem;
display: flex;
align-items: center;
gap: 8px;
}
.version-badge {
font-size: 0.75rem;
background: rgba(74, 158, 255, 0.15);
color: #8ab4f8;
padding: 2px 8px;
border-radius: 12px;
font-weight: 600;
}
.chat-scroll {
flex: 1;
overflow-y: auto;
padding: 80px 20px 40px 20px;
display: flex;
flex-direction: column;
gap: 30px;
max-width: 850px;
margin: 0 auto;
width: 100%;
scroll-behavior: smooth;
}
.msg-row {
display: flex;
gap: 16px;
width: 100%;
opacity: 0;
transform: translateY(10px);
animation: slideUpFade 0.4s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
}
.msg-row.user { justify-content: flex-end; }
.msg-row.bot { justify-content: flex-start; align-items: flex-start; }
.msg-content {
line-height: 1.6;
font-size: 1rem;
word-wrap: break-word;
max-width: 85%;
}
.user .msg-content {
background-color: var(--user-bubble);
padding: 10px 18px;
border-radius: 18px;
border-top-right-radius: 4px;
color: #fff;
}
.bot .msg-content-wrapper {
display: flex;
flex-direction: column;
gap: 8px;
width: 100%;
}
.bot .msg-text {
padding-top: 6px;
color: var(--text-primary);
}
.bot-avatar {
width: 34px;
height: 34px;
min-width: 34px;
border-radius: 50%;
background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
background-size: cover;
background-position: center;
background-repeat: no-repeat;
box-shadow: 0 2px 6px rgba(0,0,0,0.2);
}
.bot-actions {
display: flex;
gap: 10px;
opacity: 0;
transition: opacity 0.3s;
margin-top: 5px;
}
.action-btn {
background: transparent;
border: none;
color: var(--text-secondary);
cursor: pointer;
padding: 4px;
border-radius: 4px;
display: flex;
align-items: center;
transition: color 0.2s, background 0.2s;
}
.action-btn:hover {
color: var(--text-primary);
background: rgba(255,255,255,0.08);
}
.action-btn svg { width: 16px; height: 16px; fill: currentColor; }
.typing-cursor::after {
content: '▊';
display: inline-block;
margin-left: 2px;
animation: blink 1s infinite;
}
.footer-container {
padding: 0 20px 20px 20px;
background: linear-gradient(to top, var(--bg-color) 85%, transparent);
position: relative;
z-index: 60;
}
.input-box {
max-width: 850px;
margin: 0 auto;
background: var(--surface-color);
border-radius: 28px;
padding: 8px 10px 8px 20px;
display: flex;
align-items: center;
border: 1px solid rgba(255,255,255,0.1);
transition: border-color 0.2s, box-shadow 0.2s;
}
.input-box:focus-within {
border-color: rgba(74, 158, 255, 0.5);
box-shadow: 0 0 0 2px rgba(74, 158, 255, 0.1);
}
#userInput {
flex: 1;
background: transparent;
border: none;
color: white;
font-size: 1rem;
font-family: inherit;
padding: 10px 0;
}
#mainBtn {
background: white;
color: black;
border: none;
width: 36px;
height: 36px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
margin-left: 8px;
transition: transform 0.2s;
}
#mainBtn:hover { transform: scale(1.05); }
.disclaimer {
text-align: center;
font-size: 0.75rem;
color: #666;
margin-top: 12px;
}
@keyframes slideUpFade {
from { opacity: 0; transform: translateY(15px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
@keyframes pulseAvatar {
0% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0.4); }
70% { box-shadow: 0 0 0 8px rgba(74, 158, 255, 0); }
100% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0); }
}
.pulsing { animation: pulseAvatar 1.5s infinite; }
::-webkit-scrollbar { width: 8px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb { background: #333; border-radius: 4px; }
</style>
</head>
<body>
<header>
<div class="brand-wrapper" onclick="location.reload()">
<div class="brand-logo"></div>
<div class="brand-text">
MTP <span class="version-badge">v1</span>
</div>
</div>
</header>
<div id="chatScroll" class="chat-scroll">
<div class="msg-row bot" style="animation-delay: 0.1s;">
<div class="bot-avatar"></div>
<div class="msg-content-wrapper">
<div class="msg-text">
¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?
</div>
</div>
</div>
</div>
<div class="footer-container">
<div class="input-box">
<input type="text" id="userInput" placeholder="Escribe un mensaje..." autocomplete="off">
<button id="mainBtn" onclick="handleBtnClick()">➤</button>
</div>
<div class="disclaimer">
MTP puede cometer errores. Considera verificar la información importante.
</div>
</div>
<script>
const chatScroll = document.getElementById('chatScroll');
const userInput = document.getElementById('userInput');
const mainBtn = document.getElementById('mainBtn');
let isGenerating = false;
let abortController = null;
let typingTimeout = null;
let lastUserPrompt = "";

function scrollToBottom() {
chatScroll.scrollTop = chatScroll.scrollHeight;
}

function setBtnState(state) {
if (state === 'sending') {
mainBtn.innerHTML = '⏹';
isGenerating = true;
} else {
mainBtn.innerHTML = '➤';
isGenerating = false;
abortController = null;
}
}

function handleBtnClick() {
if (isGenerating) {
stopGeneration();
} else {
sendMessage();
}
}

function stopGeneration() {
if (abortController) abortController.abort();
if (typingTimeout) clearTimeout(typingTimeout);
const activeCursor = document.querySelector('.typing-cursor');
if (activeCursor) activeCursor.classList.remove('typing-cursor');
const activeAvatar = document.querySelector('.pulsing');
if (activeAvatar) activeAvatar.classList.remove('pulsing');
setBtnState('idle');
userInput.focus();
}

async function sendMessage(textOverride = null) {
const text = textOverride || userInput.value.trim();
if (!text) return;
lastUserPrompt = text;
if (!textOverride) {
userInput.value = '';
addMessage(text, 'user');
}
setBtnState('sending');
abortController = new AbortController();
const botRow = document.createElement('div');
botRow.className = 'msg-row bot';
const avatar = document.createElement('div');
avatar.className = 'bot-avatar pulsing';
const wrapper = document.createElement('div');
wrapper.className = 'msg-content-wrapper';
const msgText = document.createElement('div');
msgText.className = 'msg-text';
wrapper.appendChild(msgText);
botRow.appendChild(avatar);
botRow.appendChild(wrapper);
chatScroll.appendChild(botRow);
scrollToBottom();
try {
const response = await fetch('/generate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text: text }),
signal: abortController.signal
});
const data = await response.json();
if (!isGenerating) return;
avatar.classList.remove('pulsing');
const reply = data.reply || "No entendí eso.";
await typeWriter(msgText, reply);
if (isGenerating) {
addActions(wrapper, reply);
setBtnState('idle');
}
} catch (error) {
if (error.name === 'AbortError') {
msgText.textContent += " [Detenido]";
} else {
avatar.classList.remove('pulsing');
msgText.textContent = "Error de conexión.";
msgText.style.color = "#ff8b8b";
setBtnState('idle');
}
}
}

function addMessage(text, sender) {
const row = document.createElement('div');
row.className = `msg-row ${sender}`;
const content = document.createElement('div');
content.className = 'msg-content';
content.textContent = text;
row.appendChild(content);
chatScroll.appendChild(row);
scrollToBottom();
}

function typeWriter(element, text, speed = 12) {
return new Promise(resolve => {
let i = 0;
element.classList.add('typing-cursor');
function type() {
if (!isGenerating) {
element.classList.remove('typing-cursor');
resolve();
return;
}
if (i < text.length) {
element.textContent += text.charAt(i);
i++;
scrollToBottom();
typingTimeout = setTimeout(type, speed + Math.random() * 5);
} else {
element.classList.remove('typing-cursor');
resolve();
}
}
type();
});
}

function addActions(wrapperElement, textToCopy) {
const actionsDiv = document.createElement('div');
actionsDiv.className = 'bot-actions';
const copyBtn = document.createElement('button');
copyBtn.className = 'action-btn';
copyBtn.innerHTML = `<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
copyBtn.onclick = () => {
navigator.clipboard.writeText(textToCopy);
};
const regenBtn = document.createElement('button');
regenBtn.className = 'action-btn';
regenBtn.innerHTML = `<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
regenBtn.onclick = () => {
sendMessage(lastUserPrompt);
};
actionsDiv.appendChild(copyBtn);
actionsDiv.appendChild(regenBtn);
wrapperElement.appendChild(actionsDiv);
requestAnimationFrame(() => actionsDiv.style.opacity = "1");
scrollToBottom();
}

userInput.addEventListener('keydown', (e) => {
if (e.key === 'Enter') handleBtnClick();
});
window.onload = () => userInput.focus();
</script>
</body>
</html>
"""
919
+
920
if __name__ == "__main__":
    # The listening port comes from the PORT environment variable, defaulting to 7860.
    port = int(os.environ.get("PORT", 7860))

    for banner_line in (
        f"\n🚀 Iniciando servidor MTP en puerto {port}...",
        f"🌐 Interfaz web: http://0.0.0.0:{port}",
        f"📡 API docs: http://0.0.0.0:{port}/docs",
    ):
        print(banner_line)

    # Bind on all interfaces so the server is reachable from outside a container.
    server_options = {"host": "0.0.0.0", "port": port, "log_level": "info"}
    uvicorn.run(app, **server_options)