teszenofficial commited on
Commit
5fc9932
·
verified ·
1 Parent(s): 7af1d0b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +696 -0
app.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import json
5
+ import time
6
+ import gc
7
+ import re
8
+ from fastapi import FastAPI, Request
9
+ from fastapi.responses import HTMLResponse, StreamingResponse
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel, Field
12
+ from huggingface_hub import snapshot_download
13
+ import uvicorn
14
+ import math
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import sentencepiece as spm
18
+
19
+ # ======================
20
+ # CONFIGURACIÓN DE DISPOSITIVO
21
+ # ======================
22
+ if torch.cuda.is_available():
23
+ DEVICE = "cuda"
24
+ print("✅ GPU NVIDIA detectada. Usando CUDA.")
25
+ else:
26
+ DEVICE = "cpu"
27
+ print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")
28
+
29
+ if DEVICE == "cpu":
30
+ torch.set_num_threads(max(1, os.cpu_count() // 2))
31
+
32
+ torch.set_grad_enabled(False)
33
+
34
+ # CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
35
+ MODEL_REPO = "TeszenAI/MTP-3.2"
36
+
37
+ # ======================
38
+ # FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
39
+ # ======================
40
+
41
+ def truncate_greeting_response(text: str) -> str:
42
+ """
43
+ Para respuestas de saludo, trunca SOLO en el primer PUNTO (.)
44
+ No usa signos de exclamación o interrogación.
45
+ """
46
+ if not text:
47
+ return text
48
+
49
+ # Buscar el primer PUNTO (.)
50
+ end_match = re.search(r'\.', text)
51
+
52
+ if end_match:
53
+ # Cortar justo después del punto
54
+ end_pos = end_match.end()
55
+ truncated = text[:end_pos].strip()
56
+ return truncated
57
+
58
+ # Si no hay punto, devolver solo primeras 80 caracteres
59
+ if len(text) > 80:
60
+ return text[:80] + "..."
61
+ return text
62
+
63
+ def clean_response(text: str, user_input: str = "") -> str:
64
+ """Limpia la respuesta del modelo"""
65
+ if not text:
66
+ return ""
67
+
68
+ # Eliminar repeticiones excesivas
69
+ words = text.split()
70
+ cleaned_words = []
71
+ last_word = ""
72
+ repeat_count = 0
73
+
74
+ for word in words:
75
+ if word == last_word:
76
+ repeat_count += 1
77
+ if repeat_count > 2:
78
+ continue
79
+ else:
80
+ last_word = word
81
+ repeat_count = 0
82
+ cleaned_words.append(word)
83
+
84
+ text = " ".join(cleaned_words)
85
+
86
+ # Eliminar caracteres raros
87
+ text = re.sub(r'(.)\1{4,}', r'\1\1', text)
88
+
89
+ # Detectar si es un saludo
90
+ is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos", "hola?"]
91
+
92
+ if is_greeting and text:
93
+ # Para saludos, truncar SOLO en el primer PUNTO (.)
94
+ punct_match = re.search(r'\.', text)
95
+ if punct_match:
96
+ text = text[:punct_match.end()].strip()
97
+ else:
98
+ # Si no hay punto, tomar solo la primera oración o 60 caracteres
99
+ first_sentence = text.split('.')[0].strip()
100
+ if len(first_sentence) > 5:
101
+ text = first_sentence
102
+ elif len(text) > 60:
103
+ text = text[:60]
104
+
105
+ # Si la respuesta es muy corta o vacía
106
+ if len(text.strip()) < 5:
107
+ if is_greeting:
108
+ return "¡Hola! ¿En qué puedo ayudarte?"
109
+ return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
110
+
111
+ # Eliminar espacios múltiples
112
+ text = re.sub(r'\s+', ' ', text).strip()
113
+
114
+ return text
115
+
116
+ # ======================
117
+ # DEFINIR ARQUITECTURA DEL MODELO (MTP)
118
+ # ======================
119
+ class LayerNorm(nn.Module):
120
+ def __init__(self, d_model: int, eps: float = 1e-5):
121
+ super().__init__()
122
+ self.weight = nn.Parameter(torch.ones(d_model))
123
+ self.bias = nn.Parameter(torch.zeros(d_model))
124
+ self.eps = eps
125
+
126
+ def forward(self, x):
127
+ mean = x.mean(-1, keepdim=True)
128
+ std = x.std(-1, keepdim=True)
129
+ return self.weight * (x - mean) / (std + self.eps) + self.bias
130
+
131
+ class MultiHeadAttention(nn.Module):
132
+ def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
133
+ super().__init__()
134
+ assert d_model % n_heads == 0
135
+ self.d_model = d_model
136
+ self.n_heads = n_heads
137
+ self.d_k = d_model // n_heads
138
+ self.w_q = nn.Linear(d_model, d_model)
139
+ self.w_k = nn.Linear(d_model, d_model)
140
+ self.w_v = nn.Linear(d_model, d_model)
141
+ self.w_o = nn.Linear(d_model, d_model)
142
+ self.dropout = nn.Dropout(dropout)
143
+ self.scale = math.sqrt(self.d_k)
144
+
145
+ def forward(self, x, mask=None):
146
+ batch_size, seq_len, _ = x.shape
147
+ Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
148
+ K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
149
+ V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
150
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
151
+ if mask is not None:
152
+ scores = scores.masked_fill(mask == 0, float('-inf'))
153
+ attn_weights = F.softmax(scores, dim=-1)
154
+ attn_weights = self.dropout(attn_weights)
155
+ attn_output = torch.matmul(attn_weights, V)
156
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
157
+ return self.w_o(attn_output)
158
+
159
+ class FeedForward(nn.Module):
160
+ def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
161
+ super().__init__()
162
+ self.linear1 = nn.Linear(d_model, d_ff)
163
+ self.linear2 = nn.Linear(d_ff, d_model)
164
+ self.dropout = nn.Dropout(dropout)
165
+
166
+ def forward(self, x):
167
+ return self.linear2(self.dropout(F.gelu(self.linear1(x))))
168
+
169
+ class TransformerBlock(nn.Module):
170
+ def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
171
+ super().__init__()
172
+ self.attention = MultiHeadAttention(d_model, n_heads, dropout)
173
+ self.feed_forward = FeedForward(d_model, d_ff, dropout)
174
+ self.norm1 = LayerNorm(d_model)
175
+ self.norm2 = LayerNorm(d_model)
176
+ self.dropout1 = nn.Dropout(dropout)
177
+ self.dropout2 = nn.Dropout(dropout)
178
+
179
+ def forward(self, x, mask=None):
180
+ attn_output = self.attention(x, mask)
181
+ x = x + self.dropout1(attn_output)
182
+ x = self.norm1(x)
183
+ ff_output = self.feed_forward(x)
184
+ x = x + self.dropout2(ff_output)
185
+ x = self.norm2(x)
186
+ return x
187
+
188
+ class PositionalEncoding(nn.Module):
189
+ def __init__(self, d_model: int, max_len: int = 5000):
190
+ super().__init__()
191
+ pe = torch.zeros(max_len, d_model)
192
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
193
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
194
+ pe[:, 0::2] = torch.sin(position * div_term)
195
+ pe[:, 1::2] = torch.cos(position * div_term)
196
+ self.register_buffer('pe', pe.unsqueeze(0))
197
+
198
+ def forward(self, x):
199
+ return x + self.pe[:, :x.size(1), :]
200
+
201
+ class MTPModel(nn.Module):
202
+ def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8,
203
+ n_layers: int = 6, d_ff: int = 1024, dropout: float = 0.1, max_len: int = 512):
204
+ super().__init__()
205
+ self.vocab_size = vocab_size
206
+ self.d_model = d_model
207
+ self.max_len = max_len
208
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
209
+ self.pos_encoding = PositionalEncoding(d_model, max_len)
210
+ self.blocks = nn.ModuleList([
211
+ TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
212
+ ])
213
+ self.norm = LayerNorm(d_model)
214
+ self.lm_head = nn.Linear(d_model, vocab_size)
215
+
216
+ def forward(self, x, mask=None):
217
+ if mask is None:
218
+ mask = torch.tril(torch.ones(x.size(1), x.size(1))).unsqueeze(0).unsqueeze(0).to(x.device)
219
+ x = self.token_embedding(x) * math.sqrt(self.d_model)
220
+ x = self.pos_encoding(x)
221
+ for block in self.blocks:
222
+ x = block(x, mask)
223
+ x = self.norm(x)
224
+ logits = self.lm_head(x)
225
+ return logits
226
+
227
+ def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
228
+ """Genera texto token por token"""
229
+ generated = input_ids
230
+
231
+ for step in range(max_new_tokens):
232
+ with torch.no_grad():
233
+ logits = self(generated)
234
+ next_logits = logits[0, -1, :] / temperature
235
+
236
+ if repetition_penalty != 1.0:
237
+ for token_id in set(generated[0].tolist()):
238
+ next_logits[token_id] /= repetition_penalty
239
+
240
+ if top_k > 0:
241
+ indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
242
+ next_logits[indices_to_remove] = float('-inf')
243
+
244
+ if top_p < 1.0:
245
+ sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
246
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
247
+ sorted_indices_to_remove = cumulative_probs > top_p
248
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
249
+ sorted_indices_to_remove[..., 0] = 0
250
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
251
+ next_logits[indices_to_remove] = float('-inf')
252
+
253
+ probs = F.softmax(next_logits, dim=-1)
254
+ next_token = torch.multinomial(probs, num_samples=1).item()
255
+
256
+ # EOS ID común para SentencePiece
257
+ if next_token == 2 or next_token == 3:
258
+ break
259
+
260
+ generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
261
+
262
+ return generated
263
+
264
+ # ======================
265
+ # DESCARGA Y CARGA DEL MODELO
266
+ # ======================
267
+ print(f"📦 Descargando modelo desde {MODEL_REPO}...")
268
+ repo_path = snapshot_download(
269
+ repo_id=MODEL_REPO,
270
+ repo_type="model",
271
+ local_dir="mtp_repo"
272
+ )
273
+
274
+ # Cargar configuraci��n
275
+ config_path = os.path.join(repo_path, "config.json")
276
+ if os.path.exists(config_path):
277
+ with open(config_path, "r") as f:
278
+ config = json.load(f)
279
+ else:
280
+ config = {
281
+ "vocab_size": 5000,
282
+ "d_model": 256,
283
+ "n_heads": 8,
284
+ "n_layers": 6,
285
+ "d_ff": 1024,
286
+ "dropout": 0.1,
287
+ "max_len": 512
288
+ }
289
+
290
+ # Cargar tokenizador
291
+ tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
292
+ if not os.path.exists(tokenizer_path):
293
+ print(f"❌ Tokenizador no encontrado en {tokenizer_path}")
294
+ sys.exit(1)
295
+
296
+ sp = spm.SentencePieceProcessor()
297
+ sp.load(tokenizer_path)
298
+ VOCAB_SIZE = sp.get_piece_size()
299
+
300
+ # Actualizar vocab_size en config
301
+ config["vocab_size"] = VOCAB_SIZE
302
+
303
+ print(f"🧠 Inicializando modelo MTP...")
304
+ print(f" → Vocabulario: {VOCAB_SIZE}")
305
+ print(f" → Dimensión: {config['d_model']}")
306
+ print(f" → Capas: {config['n_layers']}")
307
+ print(f" → Heads: {config['n_heads']}")
308
+
309
+ model = MTPModel(**config)
310
+ model.to(DEVICE)
311
+
312
+ # Cargar pesos del modelo
313
+ model_path = os.path.join(repo_path, "mtp_model.pt")
314
+ if os.path.exists(model_path):
315
+ state_dict = torch.load(model_path, map_location=DEVICE)
316
+ model.load_state_dict(state_dict, strict=False)
317
+ print("✅ Pesos del modelo cargados")
318
+ else:
319
+ print(f"⚠️ No se encontró {model_path}, usando pesos aleatorios")
320
+
321
+ model.eval()
322
+
323
+ param_count = sum(p.numel() for p in model.parameters())
324
+ print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
325
+
326
+ # ======================
327
+ # API CONFIG
328
+ # ======================
329
+ app = FastAPI(
330
+ title="MTP API",
331
+ description="API para modelo de lenguaje MTP",
332
+ version="1.0"
333
+ )
334
+
335
+ app.add_middleware(
336
+ CORSMiddleware,
337
+ allow_origins=["*"],
338
+ allow_methods=["*"],
339
+ allow_headers=["*"],
340
+ )
341
+
342
+ class PromptRequest(BaseModel):
343
+ text: str = Field(..., max_length=2000, description="Texto de entrada")
344
+ max_tokens: int = Field(default=150, ge=10, le=250, description="Tokens máximos a generar")
345
+ temperature: float = Field(default=0.3, ge=0.1, le=2.0, description="Temperatura de muestreo")
346
+ top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
347
+ top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
348
+ repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Penalización por repetición")
349
+
350
+ def build_prompt(user_input: str) -> str:
351
+ """Construye el prompt en el formato del modelo"""
352
+ return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
353
+
354
+ # ======================
355
+ # GESTIÓN DE CARGA
356
+ # ======================
357
+ ACTIVE_REQUESTS = 0
358
+
359
+ class MTPTokenizer:
360
+ """Wrapper para el tokenizador de SentencePiece"""
361
+ def __init__(self, sp_model):
362
+ self.sp = sp_model
363
+
364
+ def encode(self, text):
365
+ return self.sp.encode(text)
366
+
367
+ def decode(self, tokens):
368
+ return self.sp.decode(tokens)
369
+
370
+ def bos_id(self):
371
+ return self.sp.bos_id()
372
+
373
+ def eos_id(self):
374
+ return self.sp.eos_id()
375
+
376
+ def pad_id(self):
377
+ return self.sp.pad_id()
378
+
379
+ tokenizer_wrapper = MTPTokenizer(sp)
380
+
381
+ @app.post("/generate")
382
+ async def generate(req: PromptRequest):
383
+ """Endpoint principal de generación de texto"""
384
+ global ACTIVE_REQUESTS
385
+ ACTIVE_REQUESTS += 1
386
+
387
+ user_input = req.text.strip()
388
+ if not user_input:
389
+ ACTIVE_REQUESTS -= 1
390
+ return {"reply": "", "tokens_generated": 0}
391
+
392
+ # Detectar si es un saludo
393
+ is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos", "hola?"]
394
+
395
+ # Si es saludo, usar menos tokens
396
+ max_tokens = 30 if is_greeting else req.max_tokens
397
+
398
+ full_prompt = build_prompt(user_input)
399
+ tokens = tokenizer_wrapper.encode(full_prompt)
400
+ input_ids = torch.tensor([tokens], device=DEVICE)
401
+
402
+ try:
403
+ with torch.no_grad():
404
+ output_ids = model.generate(
405
+ input_ids,
406
+ max_new_tokens=max_tokens,
407
+ temperature=req.temperature,
408
+ top_k=req.top_k,
409
+ top_p=req.top_p,
410
+ repetition_penalty=req.repetition_penalty
411
+ )
412
+
413
+ gen_tokens = output_ids[0, len(tokens):].tolist()
414
+
415
+ # Filtrar tokens inválidos
416
+ safe_tokens = [t for t in gen_tokens if 0 <= t < VOCAB_SIZE]
417
+
418
+ if safe_tokens:
419
+ response = tokenizer_wrapper.decode(safe_tokens).strip()
420
+ else:
421
+ response = ""
422
+
423
+ # Limpiar respuesta
424
+ response = clean_response(response, user_input)
425
+
426
+ # Si la respuesta sigue vacía o es muy corta, usar respuesta por defecto
427
+ if len(response) < 3:
428
+ if is_greeting:
429
+ response = "¡Hola! ¿En qué puedo ayudarte?"
430
+ else:
431
+ response = "Lo siento, no pude generar una respuesta. ¿Podrías reformular tu pregunta?"
432
+
433
+ return {
434
+ "reply": response,
435
+ "tokens_generated": len(safe_tokens),
436
+ "model": "MTP"
437
+ }
438
+
439
+ except Exception as e:
440
+ print(f"❌ Error durante generación: {e}")
441
+ if is_greeting:
442
+ fallback = "¡Hola! ¿En qué puedo ayudarte?"
443
+ else:
444
+ fallback = "Lo siento, ocurrió un error al procesar tu solicitud."
445
+ return {
446
+ "reply": fallback,
447
+ "error": str(e)
448
+ }
449
+
450
+ finally:
451
+ ACTIVE_REQUESTS -= 1
452
+ if DEVICE == "cuda":
453
+ torch.cuda.empty_cache()
454
+ gc.collect()
455
+
456
+ # ======================
457
+ # ENDPOINTS DE INFORMACIÓN
458
+ # ======================
459
+ @app.get("/health")
460
+ def health_check():
461
+ return {
462
+ "status": "healthy",
463
+ "model": "MTP",
464
+ "device": DEVICE,
465
+ "active_requests": ACTIVE_REQUESTS,
466
+ "vocab_size": VOCAB_SIZE
467
+ }
468
+
469
+ @app.get("/info")
470
+ def model_info():
471
+ return {
472
+ "model_name": "MTP",
473
+ "version": "1.0",
474
+ "architecture": config,
475
+ "parameters": sum(p.numel() for p in model.parameters()),
476
+ "device": DEVICE
477
+ }
478
+
479
+ # ======================
480
+ # INTERFAZ WEB
481
+ # ======================
482
+ @app.get("/", response_class=HTMLResponse)
483
+ def chat_ui():
484
+ return """
485
+ <!DOCTYPE html>
486
+ <html lang="es">
487
+ <head>
488
+ <meta charset="UTF-8">
489
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
490
+ <title>MTP - Asistente IA</title>
491
+ <style>
492
+ * { margin: 0; padding: 0; box-sizing: border-box; }
493
+ body {
494
+ background: #131314;
495
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
496
+ height: 100vh;
497
+ display: flex;
498
+ flex-direction: column;
499
+ }
500
+ .chat-header {
501
+ padding: 16px 20px;
502
+ background: #1E1F20;
503
+ border-bottom: 1px solid #2a2b2e;
504
+ }
505
+ .chat-header h1 {
506
+ color: white;
507
+ font-size: 1.2rem;
508
+ font-weight: 500;
509
+ }
510
+ .chat-messages {
511
+ flex: 1;
512
+ overflow-y: auto;
513
+ padding: 20px;
514
+ display: flex;
515
+ flex-direction: column;
516
+ gap: 16px;
517
+ }
518
+ .message {
519
+ display: flex;
520
+ gap: 12px;
521
+ max-width: 80%;
522
+ }
523
+ .message.user {
524
+ align-self: flex-end;
525
+ flex-direction: row-reverse;
526
+ }
527
+ .message-content {
528
+ padding: 10px 16px;
529
+ border-radius: 18px;
530
+ font-size: 0.95rem;
531
+ line-height: 1.4;
532
+ }
533
+ .user .message-content {
534
+ background: #4a9eff;
535
+ color: white;
536
+ border-radius: 18px 4px 18px 18px;
537
+ }
538
+ .bot .message-content {
539
+ background: #1E1F20;
540
+ color: #e3e3e3;
541
+ border-radius: 4px 18px 18px 18px;
542
+ }
543
+ .chat-input-container {
544
+ padding: 16px 20px;
545
+ background: #1E1F20;
546
+ border-top: 1px solid #2a2b2e;
547
+ }
548
+ .input-wrapper {
549
+ display: flex;
550
+ gap: 12px;
551
+ max-width: 800px;
552
+ margin: 0 auto;
553
+ }
554
+ #messageInput {
555
+ flex: 1;
556
+ padding: 12px 16px;
557
+ background: #2a2b2e;
558
+ border: none;
559
+ border-radius: 24px;
560
+ color: white;
561
+ font-size: 0.95rem;
562
+ outline: none;
563
+ }
564
+ #messageInput::placeholder {
565
+ color: #888;
566
+ }
567
+ #sendBtn {
568
+ padding: 12px 24px;
569
+ background: #4a9eff;
570
+ border: none;
571
+ border-radius: 24px;
572
+ color: white;
573
+ font-weight: 500;
574
+ cursor: pointer;
575
+ transition: opacity 0.2s;
576
+ }
577
+ #sendBtn:hover { opacity: 0.9; }
578
+ #sendBtn:disabled {
579
+ opacity: 0.5;
580
+ cursor: not-allowed;
581
+ }
582
+ .typing {
583
+ display: flex;
584
+ gap: 4px;
585
+ padding: 10px 16px;
586
+ }
587
+ .typing span {
588
+ width: 8px;
589
+ height: 8px;
590
+ background: #888;
591
+ border-radius: 50%;
592
+ animation: bounce 1.4s infinite ease-in-out;
593
+ }
594
+ .typing span:nth-child(1) { animation-delay: -0.32s; }
595
+ .typing span:nth-child(2) { animation-delay: -0.16s; }
596
+ @keyframes bounce {
597
+ 0%, 80%, 100% { transform: scale(0); }
598
+ 40% { transform: scale(1); }
599
+ }
600
+ </style>
601
+ </head>
602
+ <body>
603
+ <div class="chat-header">
604
+ <h1>🤖 MTP - Asistente IA</h1>
605
+ </div>
606
+ <div class="chat-messages" id="chatMessages">
607
+ <div class="message bot">
608
+ <div class="message-content">¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?</div>
609
+ </div>
610
+ </div>
611
+ <div class="chat-input-container">
612
+ <div class="input-wrapper">
613
+ <input type="text" id="messageInput" placeholder="Escribe tu mensaje..." autocomplete="off">
614
+ <button id="sendBtn">Enviar</button>
615
+ </div>
616
+ </div>
617
+ <script>
618
+ const chatMessages = document.getElementById('chatMessages');
619
+ const messageInput = document.getElementById('messageInput');
620
+ const sendBtn = document.getElementById('sendBtn');
621
+ let isLoading = false;
622
+
623
+ function addMessage(text, isUser) {
624
+ const div = document.createElement('div');
625
+ div.className = `message ${isUser ? 'user' : 'bot'}`;
626
+ div.innerHTML = `<div class="message-content">${text}</div>`;
627
+ chatMessages.appendChild(div);
628
+ chatMessages.scrollTop = chatMessages.scrollHeight;
629
+ return div;
630
+ }
631
+
632
+ function addTypingIndicator() {
633
+ const div = document.createElement('div');
634
+ div.className = 'message bot';
635
+ div.id = 'typingIndicator';
636
+ div.innerHTML = `<div class="typing"><span></span><span></span><span></span></div>`;
637
+ chatMessages.appendChild(div);
638
+ chatMessages.scrollTop = chatMessages.scrollHeight;
639
+ }
640
+
641
+ function removeTypingIndicator() {
642
+ const indicator = document.getElementById('typingIndicator');
643
+ if (indicator) indicator.remove();
644
+ }
645
+
646
+ async function sendMessage() {
647
+ const text = messageInput.value.trim();
648
+ if (!text || isLoading) return;
649
+
650
+ messageInput.value = '';
651
+ addMessage(text, true);
652
+ isLoading = true;
653
+ sendBtn.disabled = true;
654
+ addTypingIndicator();
655
+
656
+ try {
657
+ const response = await fetch('/generate', {
658
+ method: 'POST',
659
+ headers: { 'Content-Type': 'application/json' },
660
+ body: JSON.stringify({ text: text })
661
+ });
662
+ const data = await response.json();
663
+ removeTypingIndicator();
664
+ addMessage(data.reply, false);
665
+ } catch (error) {
666
+ removeTypingIndicator();
667
+ addMessage('Error de conexión. Intenta de nuevo.', false);
668
+ } finally {
669
+ isLoading = false;
670
+ sendBtn.disabled = false;
671
+ messageInput.focus();
672
+ }
673
+ }
674
+
675
+ messageInput.addEventListener('keypress', (e) => {
676
+ if (e.key === 'Enter') sendMessage();
677
+ });
678
+ sendBtn.addEventListener('click', sendMessage);
679
+ messageInput.focus();
680
+ </script>
681
+ </body>
682
+ </html>
683
+ """
684
+
685
+ if __name__ == "__main__":
686
+ port = int(os.environ.get("PORT", 7860))
687
+ print(f"\n🚀 Iniciando servidor MTP en puerto {port}...")
688
+ print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
689
+ print(f"📡 API docs: http://0.0.0.0:{port}/docs")
690
+
691
+ uvicorn.run(
692
+ app,
693
+ host="0.0.0.0",
694
+ port=port,
695
+ log_level="info"
696
+ )