teszenofficial commited on
Commit
f5aa463
·
verified ·
1 Parent(s): aab0558

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +657 -172
app.py CHANGED
@@ -1,316 +1,801 @@
1
  import os
 
2
  import torch
3
- from fastapi import FastAPI
4
- from fastapi.responses import HTMLResponse
 
 
 
 
5
  from fastapi.middleware.cors import CORSMiddleware
6
- from pydantic import BaseModel
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
  import uvicorn
9
- import re
 
 
 
10
 
11
- # ==================== CONFIGURACIÓN ====================
12
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
- print(f"📱 Dispositivo: {DEVICE}")
 
 
 
 
 
 
14
 
15
- # Usar un modelo pequeño pero FUNCIONAL de HuggingFace
16
- # Opciones: "microsoft/DialoGPT-small" (mejor para conversación)
17
- # "TinyLlama/TinyLlama-1.1B-Chat-v1.0" (más potente pero más lento)
18
- MODEL_NAME = "microsoft/DialoGPT-small" # ~60MB, rápido y funcional
19
 
20
- print(f"📦 Cargando modelo {MODEL_NAME}...")
21
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
22
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
23
- model.eval()
24
- print(f"✅ Modelo cargado: {sum(p.numel() for p in model.parameters()):,} parámetros")
25
 
26
- # ==================== API ====================
27
- app = FastAPI()
28
- app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
29
 
30
- class PromptRequest(BaseModel):
31
- text: str
 
32
 
33
- def clean_response(text: str) -> str:
34
  """Limpia la respuesta del modelo"""
35
  if not text:
36
  return ""
37
 
38
- # Eliminar caracteres especiales
39
- text = re.sub(r'<\|.*?\|>', '', text)
40
- text = re.sub(r'\[.*?\]', '', text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  text = re.sub(r'\s+', ' ', text).strip()
42
 
43
- # Limitar longitud
44
- if len(text) > 400:
45
- text = text[:400]
46
- last_dot = text.rfind('.')
47
- if last_dot > 200:
48
- text = text[:last_dot + 1]
 
 
 
 
 
49
 
50
- return text if text else "Lo siento, no pude generar una respuesta."
 
 
 
51
 
52
- @app.post("/generate")
53
- async def generate(req: PromptRequest):
54
- user_input = req.text.strip()
55
- if not user_input:
56
- return {"reply": "Escribe un mensaje"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Formatear entrada para el modelo
59
- formatted_input = f"User: {user_input}\nBot:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # Tokenizar
62
- inputs = tokenizer.encode(formatted_input, return_tensors="pt").to(DEVICE)
63
 
64
- # Generar
65
- with torch.no_grad():
66
- outputs = model.generate(
67
- inputs,
68
- max_new_tokens=100,
69
- temperature=0.7,
70
- top_k=50,
71
- top_p=0.9,
72
- do_sample=True,
73
- pad_token_id=tokenizer.eos_token_id
74
- )
75
 
76
- # Decodificar
77
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
78
 
79
- # Extraer solo la respuesta del bot
80
- if "Bot:" in response:
81
- response = response.split("Bot:")[-1].strip()
82
- elif "User:" in response:
83
- parts = response.split("User:")
84
- response = parts[-1].strip() if len(parts) > 1 else response
85
 
86
- response = clean_response(response)
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- print(f"📝 Usuario: {user_input[:50]}")
89
- print(f"🤖 Respuesta: {response[:100]}")
 
 
 
 
 
 
 
 
 
 
90
 
91
- return {"reply": response}
 
 
 
 
 
 
 
 
 
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  @app.get("/health")
94
- def health():
95
- return {"status": "ok"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
 
 
 
97
  @app.get("/", response_class=HTMLResponse)
98
  def chat_ui():
99
  return """
100
  <!DOCTYPE html>
101
- <html>
102
  <head>
103
  <meta charset="UTF-8">
104
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
105
- <title>MTP - Asistente IA</title>
106
  <style>
107
  * { margin: 0; padding: 0; box-sizing: border-box; }
108
  body {
109
- background: #0a0a0f;
110
  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
111
  height: 100vh;
112
  display: flex;
113
  flex-direction: column;
114
  }
115
- .header {
116
  padding: 16px 20px;
117
- background: #1a1a2e;
118
- border-bottom: 1px solid #2a2a4a;
119
- text-align: center;
120
  }
121
- .header h1 { color: white; font-size: 1.3rem; }
122
- .header p { color: #888; font-size: 0.75rem; margin-top: 4px; }
123
- .chat {
 
 
 
 
 
 
 
 
124
  flex: 1;
125
  overflow-y: auto;
126
- padding: 16px;
127
  display: flex;
128
  flex-direction: column;
129
- gap: 10px;
130
  }
131
  .message {
132
  display: flex;
133
- gap: 8px;
134
- max-width: 85%;
135
- animation: fadeIn 0.2s ease;
136
  }
137
  @keyframes fadeIn {
138
  from { opacity: 0; transform: translateY(10px); }
139
  to { opacity: 1; transform: translateY(0); }
140
  }
141
- .message.user { align-self: flex-end; flex-direction: row-reverse; }
 
 
 
142
  .message-content {
143
- padding: 8px 14px;
144
- border-radius: 16px;
145
- font-size: 0.9rem;
146
  line-height: 1.4;
147
  word-wrap: break-word;
 
148
  }
149
  .user .message-content {
150
- background: #667eea;
151
  color: white;
152
- border-radius: 16px 4px 16px 16px;
153
  }
154
  .bot .message-content {
155
- background: #1e1e2e;
156
- color: #e0e0e0;
157
- border-radius: 4px 16px 16px 16px;
158
- border: 1px solid #2a2a4a;
159
  }
160
- .input-area {
161
- padding: 12px 16px;
162
- background: #0f0f15;
163
- border-top: 1px solid #1a1a2e;
 
164
  }
165
  .input-wrapper {
166
  display: flex;
167
- gap: 10px;
168
  max-width: 800px;
169
  margin: 0 auto;
170
  }
171
- #input {
172
  flex: 1;
173
- padding: 10px 14px;
174
- background: #1a1a2e;
175
- border: 1px solid #2a2a4a;
176
- border-radius: 22px;
177
  color: white;
178
- font-size: 0.9rem;
179
  outline: none;
 
180
  }
181
- #input:focus { border-color: #667eea; }
182
- #send {
183
- padding: 10px 20px;
184
- background: #667eea;
 
 
 
 
 
 
185
  border: none;
186
- border-radius: 22px;
187
  color: white;
188
- font-weight: 600;
189
  cursor: pointer;
 
 
 
 
 
 
 
 
 
 
190
  }
191
- #send:hover { opacity: 0.9; }
192
- #send:disabled { opacity: 0.5; cursor: not-allowed; }
193
  .typing {
194
  display: flex;
195
  gap: 4px;
196
- padding: 8px 14px;
197
  }
198
  .typing span {
199
- width: 6px;
200
- height: 6px;
201
  background: #888;
202
  border-radius: 50%;
203
- animation: bounce 1.4s infinite;
204
  }
 
205
  .typing span:nth-child(2) { animation-delay: -0.16s; }
206
- .typing span:nth-child(3) { animation-delay: -0.32s; }
207
  @keyframes bounce {
208
  0%, 80%, 100% { transform: scale(0); }
209
  40% { transform: scale(1); }
210
  }
211
- .dot {
212
- display: inline-block;
213
- width: 8px;
214
- height: 8px;
215
- background: #4ade80;
216
- border-radius: 50%;
217
- margin-right: 6px;
218
- animation: pulse 2s infinite;
 
 
 
 
 
 
 
 
 
 
 
 
219
  }
220
- @keyframes pulse {
221
- 0%, 100% { opacity: 1; }
222
- 50% { opacity: 0.5; }
 
 
 
 
 
 
 
 
 
 
 
223
  }
224
  </style>
225
  </head>
226
  <body>
227
- <div class="header">
228
- <h1><span class="dot"></span> MTP Assistant</h1>
229
- <p>DialoGPT - Modelo conversacional real</p>
 
 
 
 
 
 
 
 
 
 
230
  </div>
231
- <div class="chat" id="chat">
232
  <div class="message bot">
233
- <div class="message-content">¡Hola! Soy MTP, tu asistente. ¿En qué puedo ayudarte hoy?</div>
234
  </div>
235
  </div>
236
- <div class="input-area">
237
  <div class="input-wrapper">
238
- <input type="text" id="input" placeholder="Escribe tu mensaje..." autocomplete="off">
239
- <button id="send">Enviar</button>
240
  </div>
241
  </div>
 
242
  <script>
243
- const chat = document.getElementById('chat');
244
- const input = document.getElementById('input');
245
- const sendBtn = document.getElementById('send');
246
- let loading = false;
247
 
248
  function addMessage(text, isUser) {
249
  const div = document.createElement('div');
250
  div.className = `message ${isUser ? 'user' : 'bot'}`;
251
  div.innerHTML = `<div class="message-content">${escapeHtml(text)}</div>`;
252
- chat.appendChild(div);
253
- chat.scrollTop = chat.scrollHeight;
 
254
  }
255
 
256
  function escapeHtml(text) {
257
- return text.replace(/</g, '&lt;').replace(/>/g, '&gt;');
 
 
258
  }
259
 
260
- function addTyping() {
261
  const div = document.createElement('div');
262
  div.className = 'message bot';
263
- div.id = 'typing';
264
  div.innerHTML = `<div class="typing"><span></span><span></span><span></span></div>`;
265
- chat.appendChild(div);
266
- chat.scrollTop = chat.scrollHeight;
267
  }
268
 
269
- function removeTyping() {
270
- const t = document.getElementById('typing');
271
- if (t) t.remove();
272
  }
273
 
274
- async function send() {
275
- const text = input.value.trim();
276
- if (!text || loading) return;
277
 
278
- input.value = '';
279
- addMessage(text, true);
280
- loading = true;
281
  sendBtn.disabled = true;
282
- addTyping();
283
 
284
  try {
285
- const res = await fetch('/generate', {
286
  method: 'POST',
287
  headers: { 'Content-Type': 'application/json' },
288
- body: JSON.stringify({ text: text })
289
  });
290
- const data = await res.json();
291
- removeTyping();
292
- addMessage(data.reply || "No pude generar respuesta.", false);
293
- } catch (err) {
294
- removeTyping();
295
- addMessage("Error de conexión. Intenta de nuevo.", false);
296
  } finally {
297
- loading = false;
298
  sendBtn.disabled = false;
299
- input.focus();
300
  }
301
  }
302
 
303
- input.addEventListener('keypress', (e) => {
304
- if (e.key === 'Enter') send();
305
  });
306
- sendBtn.addEventListener('click', send);
307
- input.focus();
 
 
 
308
  </script>
309
  </body>
310
  </html>
311
  """
312
 
 
 
 
313
  if __name__ == "__main__":
314
  port = int(os.environ.get("PORT", 7860))
315
- print(f"\n🚀 Servidor: http://0.0.0.0:{port}")
316
- uvicorn.run(app, host="0.0.0.0", port=port, log_level="warning")
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import sys
3
  import torch
4
+ import json
5
+ import time
6
+ import gc
7
+ import re
8
+ from fastapi import FastAPI, Request
9
+ from fastapi.responses import HTMLResponse, StreamingResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
+ from pydantic import BaseModel, Field
12
+ from huggingface_hub import snapshot_download
13
  import uvicorn
14
+ import math
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import sentencepiece as spm
18
 
19
+ # ======================
20
+ # CONFIGURACIÓN DE DISPOSITIVO
21
+ # ======================
22
+ if torch.cuda.is_available():
23
+ DEVICE = "cuda"
24
+ print("✅ GPU NVIDIA detectada. Usando CUDA.")
25
+ else:
26
+ DEVICE = "cpu"
27
+ print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")
28
 
29
+ if DEVICE == "cpu":
30
+ torch.set_num_threads(max(1, os.cpu_count() // 2))
 
 
31
 
32
+ torch.set_grad_enabled(False)
 
 
 
 
33
 
34
+ # CONFIGURACIÓN DEL MODELO - ACTUALIZADO A VERSIÓN 3.3.1
35
+ MODEL_REPO = "TeszenAI/MTP-3.3.1"
 
36
 
37
+ # ======================
38
+ # FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
39
+ # ======================
40
 
41
+ def clean_response(text: str, user_input: str = "") -> str:
42
  """Limpia la respuesta del modelo"""
43
  if not text:
44
  return ""
45
 
46
+ # Eliminar repeticiones excesivas de palabras
47
+ words = text.split()
48
+ cleaned_words = []
49
+ last_word = ""
50
+ repeat_count = 0
51
+
52
+ for word in words:
53
+ if word.lower() == last_word.lower():
54
+ repeat_count += 1
55
+ if repeat_count > 2:
56
+ continue
57
+ else:
58
+ last_word = word
59
+ repeat_count = 0
60
+ cleaned_words.append(word)
61
+
62
+ text = " ".join(cleaned_words)
63
+
64
+ # Eliminar caracteres repetidos excesivamente
65
+ text = re.sub(r'(.)\1{4,}', r'\1\1', text)
66
+
67
+ # Detectar si es un saludo (más completo)
68
+ greetings = [
69
+ "hola", "hola!", "hola.", "buenas", "saludos", "hola?",
70
+ "buenos días", "buenas tardes", "buenas noches", "hey",
71
+ "hola!", "que tal", "cómo estás", "como estas"
72
+ ]
73
+ is_greeting = user_input.lower().strip() in greetings
74
+
75
+ if is_greeting and text:
76
+ # Para saludos, tomar solo la primera oración
77
+ first_sentence = text.split('.')[0].strip()
78
+ if len(first_sentence) > 5 and len(first_sentence) < 100:
79
+ text = first_sentence
80
+ elif len(text) > 80:
81
+ text = text[:80]
82
+
83
+ # Asegurar que termine con punto si es un saludo
84
+ if text and text[-1] not in '.!?':
85
+ text += '.'
86
+
87
+ # Si la respuesta es muy corta o vacía
88
+ if len(text.strip()) < 5:
89
+ if is_greeting:
90
+ return "¡Hola! ¿En qué puedo ayudarte?"
91
+ return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
92
+
93
+ # Eliminar espacios múltiples y limpiar
94
  text = re.sub(r'\s+', ' ', text).strip()
95
 
96
+ return text
97
+
98
+ # ======================
99
+ # DEFINIR ARQUITECTURA DEL MODELO (MTP V3.3.1)
100
+ # ======================
101
+ class LayerNorm(nn.Module):
102
+ def __init__(self, d_model: int, eps: float = 1e-5):
103
+ super().__init__()
104
+ self.weight = nn.Parameter(torch.ones(d_model))
105
+ self.bias = nn.Parameter(torch.zeros(d_model))
106
+ self.eps = eps
107
 
108
+ def forward(self, x):
109
+ mean = x.mean(-1, keepdim=True)
110
+ std = x.std(-1, keepdim=True)
111
+ return self.weight * (x - mean) / (std + self.eps) + self.bias
112
 
113
+ class MultiHeadAttention(nn.Module):
114
+ def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
115
+ super().__init__()
116
+ assert d_model % n_heads == 0
117
+ self.d_model = d_model
118
+ self.n_heads = n_heads
119
+ self.d_k = d_model // n_heads
120
+ self.w_q = nn.Linear(d_model, d_model)
121
+ self.w_k = nn.Linear(d_model, d_model)
122
+ self.w_v = nn.Linear(d_model, d_model)
123
+ self.w_o = nn.Linear(d_model, d_model)
124
+ self.dropout = nn.Dropout(dropout)
125
+ self.scale = math.sqrt(self.d_k)
126
+
127
+ def forward(self, x, mask=None):
128
+ batch_size, seq_len, _ = x.shape
129
+ Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
130
+ K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
131
+ V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
132
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
133
+ if mask is not None:
134
+ scores = scores.masked_fill(mask == 0, float('-inf'))
135
+ attn_weights = F.softmax(scores, dim=-1)
136
+ attn_weights = self.dropout(attn_weights)
137
+ attn_output = torch.matmul(attn_weights, V)
138
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
139
+ return self.w_o(attn_output)
140
+
141
+ class FeedForward(nn.Module):
142
+ def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
143
+ super().__init__()
144
+ self.linear1 = nn.Linear(d_model, d_ff)
145
+ self.linear2 = nn.Linear(d_ff, d_model)
146
+ self.dropout = nn.Dropout(dropout)
147
+
148
+ def forward(self, x):
149
+ return self.linear2(self.dropout(F.gelu(self.linear1(x))))
150
+
151
+ class TransformerBlock(nn.Module):
152
+ def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
153
+ super().__init__()
154
+ self.attention = MultiHeadAttention(d_model, n_heads, dropout)
155
+ self.feed_forward = FeedForward(d_model, d_ff, dropout)
156
+ self.norm1 = LayerNorm(d_model)
157
+ self.norm2 = LayerNorm(d_model)
158
+ self.dropout1 = nn.Dropout(dropout)
159
+ self.dropout2 = nn.Dropout(dropout)
160
+
161
+ def forward(self, x, mask=None):
162
+ attn_output = self.attention(x, mask)
163
+ x = x + self.dropout1(attn_output)
164
+ x = self.norm1(x)
165
+ ff_output = self.feed_forward(x)
166
+ x = x + self.dropout2(ff_output)
167
+ x = self.norm2(x)
168
+ return x
169
+
170
+ class PositionalEncoding(nn.Module):
171
+ def __init__(self, d_model: int, max_len: int = 5000):
172
+ super().__init__()
173
+ pe = torch.zeros(max_len, d_model)
174
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
175
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
176
+ pe[:, 0::2] = torch.sin(position * div_term)
177
+ pe[:, 1::2] = torch.cos(position * div_term)
178
+ self.register_buffer('pe', pe.unsqueeze(0))
179
+
180
+ def forward(self, x):
181
+ return x + self.pe[:, :x.size(1), :]
182
+
183
+ class MTPModel(nn.Module):
184
+ def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8,
185
+ n_layers: int = 6, d_ff: int = 1024, dropout: float = 0.1, max_len: int = 512):
186
+ super().__init__()
187
+ self.vocab_size = vocab_size
188
+ self.d_model = d_model
189
+ self.max_len = max_len
190
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
191
+ self.pos_encoding = PositionalEncoding(d_model, max_len)
192
+ self.blocks = nn.ModuleList([
193
+ TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
194
+ ])
195
+ self.norm = LayerNorm(d_model)
196
+ self.lm_head = nn.Linear(d_model, vocab_size)
197
+
198
+ def forward(self, x, mask=None):
199
+ if mask is None:
200
+ mask = torch.tril(torch.ones(x.size(1), x.size(1))).unsqueeze(0).unsqueeze(0).to(x.device)
201
+ x = self.token_embedding(x) * math.sqrt(self.d_model)
202
+ x = self.pos_encoding(x)
203
+ for block in self.blocks:
204
+ x = block(x, mask)
205
+ x = self.norm(x)
206
+ logits = self.lm_head(x)
207
+ return logits
208
 
209
+ def generate(self, input_ids, max_new_tokens=150, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.1):
210
+ """Genera texto token por token"""
211
+ generated = input_ids
212
+ eos_id = 3 # EOS token id en SentencePiece
213
+
214
+ for step in range(max_new_tokens):
215
+ with torch.no_grad():
216
+ logits = self(generated)
217
+ next_logits = logits[0, -1, :] / temperature
218
+
219
+ if repetition_penalty != 1.0:
220
+ for token_id in set(generated[0].tolist()):
221
+ next_logits[token_id] /= repetition_penalty
222
+
223
+ if top_k > 0:
224
+ indices_to_remove = next_logits < torch.topk(next_logits, min(top_k, next_logits.size(-1)))[0][..., -1, None]
225
+ next_logits[indices_to_remove] = float('-inf')
226
+
227
+ if top_p < 1.0:
228
+ sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
229
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
230
+ sorted_indices_to_remove = cumulative_probs > top_p
231
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
232
+ sorted_indices_to_remove[..., 0] = 0
233
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
234
+ next_logits[indices_to_remove] = float('-inf')
235
+
236
+ probs = F.softmax(next_logits, dim=-1)
237
+ next_token = torch.multinomial(probs, num_samples=1).item()
238
+
239
+ if next_token == eos_id:
240
+ break
241
+
242
+ generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
243
+
244
+ return generated
245
+
246
+ # ======================
247
+ # DESCARGA Y CARGA DEL MODELO
248
+ # ======================
249
+ print(f"📦 Descargando modelo desde {MODEL_REPO}...")
250
+ repo_path = snapshot_download(
251
+ repo_id=MODEL_REPO,
252
+ repo_type="model",
253
+ local_dir="mtp_repo"
254
+ )
255
+
256
+ # Cargar configuración
257
+ config_path = os.path.join(repo_path, "config.json")
258
+ if os.path.exists(config_path):
259
+ with open(config_path, "r") as f:
260
+ config = json.load(f)
261
+ print(f"✅ Configuración cargada: {config}")
262
+ else:
263
+ # Configuración por defecto para MTP V3.3.1
264
+ config = {
265
+ "vocab_size": 4000,
266
+ "d_model": 256,
267
+ "n_heads": 8,
268
+ "n_layers": 6,
269
+ "d_ff": 1024,
270
+ "dropout": 0.1,
271
+ "max_len": 512
272
+ }
273
+ print(f"⚠️ Usando configuración por defecto: {config}")
274
+
275
+ # Cargar tokenizador
276
+ tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
277
+ if not os.path.exists(tokenizer_path):
278
+ print(f"❌ Tokenizador no encontrado en {tokenizer_path}")
279
+ sys.exit(1)
280
+
281
+ sp = spm.SentencePieceProcessor()
282
+ sp.load(tokenizer_path)
283
+ VOCAB_SIZE = sp.get_piece_size()
284
+ print(f"✅ Tokenizador cargado. Vocabulario: {VOCAB_SIZE}")
285
+
286
+ # Actualizar vocab_size en config
287
+ config["vocab_size"] = VOCAB_SIZE
288
+
289
+ print(f"\n🧠 Inicializando modelo MTP V3.3.1...")
290
+ print(f" → Vocabulario: {VOCAB_SIZE}")
291
+ print(f" → Dimensión: {config['d_model']}")
292
+ print(f" → Capas: {config['n_layers']}")
293
+ print(f" → Heads: {config['n_heads']}")
294
+ print(f" → FFN dimensión: {config['d_ff']}")
295
+ print(f" → Max length: {config['max_len']}")
296
+
297
+ model = MTPModel(**config)
298
+ model.to(DEVICE)
299
+
300
+ # Cargar pesos del modelo
301
+ model_path = os.path.join(repo_path, "mtp_model.pt")
302
+ if os.path.exists(model_path):
303
+ try:
304
+ state_dict = torch.load(model_path, map_location=DEVICE)
305
+ model.load_state_dict(state_dict, strict=False)
306
+ print("✅ Pesos del modelo cargados exitosamente")
307
+ except Exception as e:
308
+ print(f"⚠️ Error cargando pesos: {e}")
309
+ print(" Continuando con pesos aleatorios...")
310
+ else:
311
+ print(f"⚠️ No se encontró {model_path}, usando pesos aleatorios")
312
+
313
+ model.eval()
314
+ param_count = sum(p.numel() for p in model.parameters())
315
+ print(f"✅ Modelo listo: {param_count:,} parámetros ({param_count/1e6:.2f}M)")
316
+
317
+ # ======================
318
+ # API CONFIG
319
+ # ======================
320
+ app = FastAPI(
321
+ title="MTP API V3.3.1",
322
+ description="API para modelo de lenguaje MTP - Asistente IA entrenado desde cero",
323
+ version="3.3.1"
324
+ )
325
+
326
+ app.add_middleware(
327
+ CORSMiddleware,
328
+ allow_origins=["*"],
329
+ allow_methods=["*"],
330
+ allow_headers=["*"],
331
+ )
332
+
333
+ class PromptRequest(BaseModel):
334
+ text: str = Field(..., max_length=2000, description="Texto de entrada")
335
+ max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar")
336
+ temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
337
+ top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
338
+ top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
339
+ repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Penalización por repetición")
340
+
341
+ def build_prompt(user_input: str) -> str:
342
+ """Construye el prompt en el formato del modelo (Alpaca style)"""
343
+ return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
344
+
345
+ # ======================
346
+ # GESTIÓN DE CARGA
347
+ # ======================
348
+ ACTIVE_REQUESTS = 0
349
+
350
+ class MTPTokenizer:
351
+ """Wrapper para el tokenizador de SentencePiece"""
352
+ def __init__(self, sp_model):
353
+ self.sp = sp_model
354
 
355
+ def encode(self, text):
356
+ return self.sp.encode(text)
357
 
358
+ def decode(self, tokens):
359
+ return self.sp.decode(tokens)
 
 
 
 
 
 
 
 
 
360
 
361
+ def bos_id(self):
362
+ return self.sp.bos_id()
363
 
364
+ def eos_id(self):
365
+ return self.sp.eos_id()
 
 
 
 
366
 
367
+ def pad_id(self):
368
+ return self.sp.pad_id()
369
+
370
+ tokenizer_wrapper = MTPTokenizer(sp)
371
+
372
+ # ======================
373
+ # ENDPOINT PRINCIPAL
374
+ # ======================
375
+ @app.post("/generate")
376
+ async def generate(req: PromptRequest):
377
+ """Endpoint principal de generación de texto"""
378
+ global ACTIVE_REQUESTS
379
+ ACTIVE_REQUESTS += 1
380
 
381
+ user_input = req.text.strip()
382
+ if not user_input:
383
+ ACTIVE_REQUESTS -= 1
384
+ return {"reply": "", "tokens_generated": 0}
385
+
386
+ # Detectar si es un saludo
387
+ greetings = [
388
+ "hola", "hola!", "hola.", "buenas", "saludos", "hola?",
389
+ "buenos días", "buenas tardes", "buenas noches", "hey",
390
+ "que tal", "cómo estás", "como estas"
391
+ ]
392
+ is_greeting = user_input.lower().strip() in greetings
393
 
394
+ # Si es saludo, usar menos tokens y temperatura más alta para respuestas creativas
395
+ if is_greeting:
396
+ max_tokens = 30
397
+ temperature = 0.8
398
+ else:
399
+ max_tokens = req.max_tokens
400
+ temperature = req.temperature
401
+
402
+ full_prompt = build_prompt(user_input)
403
+ tokens = tokenizer_wrapper.encode(full_prompt)
404
+ input_ids = torch.tensor([tokens], device=DEVICE)
405
 
406
+ try:
407
+ start_time = time.time()
408
+
409
+ with torch.no_grad():
410
+ output_ids = model.generate(
411
+ input_ids,
412
+ max_new_tokens=max_tokens,
413
+ temperature=temperature,
414
+ top_k=req.top_k,
415
+ top_p=req.top_p,
416
+ repetition_penalty=req.repetition_penalty
417
+ )
418
+
419
+ inference_time = time.time() - start_time
420
+
421
+ # Extraer solo los tokens generados (no el prompt)
422
+ gen_tokens = output_ids[0, len(tokens):].tolist()
423
+
424
+ # Filtrar tokens inválidos
425
+ safe_tokens = [t for t in gen_tokens if 0 <= t < VOCAB_SIZE and t != 0] # 0 es pad
426
+
427
+ if safe_tokens:
428
+ response = tokenizer_wrapper.decode(safe_tokens).strip()
429
+ else:
430
+ response = ""
431
+
432
+ # Limpiar respuesta
433
+ response = clean_response(response, user_input)
434
+
435
+ # Si la respuesta sigue vacía o es muy corta, usar respuesta por defecto
436
+ if len(response) < 3:
437
+ if is_greeting:
438
+ response = "¡Hola! ¿En qué puedo ayudarte?"
439
+ else:
440
+ response = "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
441
+
442
+ return {
443
+ "reply": response,
444
+ "tokens_generated": len(safe_tokens),
445
+ "inference_time": round(inference_time, 3),
446
+ "model": "MTP-3.3.1",
447
+ "input_tokens": len(tokens)
448
+ }
449
+
450
+ except Exception as e:
451
+ print(f"❌ Error durante generación: {e}")
452
+ import traceback
453
+ traceback.print_exc()
454
+ if is_greeting:
455
+ fallback = "¡Hola! ¿En qué puedo ayudarte?"
456
+ else:
457
+ fallback = "Lo siento, ocurrió un error al procesar tu solicitud. Intenta de nuevo."
458
+ return {
459
+ "reply": fallback,
460
+ "error": str(e),
461
+ "model": "MTP-3.3.1"
462
+ }
463
+
464
+ finally:
465
+ ACTIVE_REQUESTS -= 1
466
+ if DEVICE == "cuda":
467
+ torch.cuda.empty_cache()
468
+ gc.collect()
469
+
470
+ # ======================
471
+ # ENDPOINTS DE INFORMACIÓN
472
+ # ======================
473
  @app.get("/health")
474
+ def health_check():
475
+ return {
476
+ "status": "healthy",
477
+ "model": "MTP-3.3.1",
478
+ "device": DEVICE,
479
+ "active_requests": ACTIVE_REQUESTS,
480
+ "vocab_size": VOCAB_SIZE,
481
+ "total_params": param_count
482
+ }
483
+
484
+ @app.get("/info")
485
+ def model_info():
486
+ return {
487
+ "model_name": "MTP",
488
+ "version": "3.3.1",
489
+ "architecture": config,
490
+ "parameters": param_count,
491
+ "parameters_millions": round(param_count / 1e6, 2),
492
+ "device": DEVICE,
493
+ "tokenizer_vocab": VOCAB_SIZE,
494
+ "repo": MODEL_REPO
495
+ }
496
 
497
+ # ======================
498
+ # INTERFAZ WEB MEJORADA
499
+ # ======================
500
  @app.get("/", response_class=HTMLResponse)
501
  def chat_ui():
502
  return """
503
  <!DOCTYPE html>
504
+ <html lang="es">
505
  <head>
506
  <meta charset="UTF-8">
507
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
508
+ <title>MTP V3.3.1 - Asistente IA</title>
509
  <style>
510
  * { margin: 0; padding: 0; box-sizing: border-box; }
511
  body {
512
+ background: linear-gradient(135deg, #0a0a0a 0%, #1a1a2e 100%);
513
  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
514
  height: 100vh;
515
  display: flex;
516
  flex-direction: column;
517
  }
518
+ .chat-header {
519
  padding: 16px 20px;
520
+ background: rgba(0, 0, 0, 0.7);
521
+ backdrop-filter: blur(10px);
522
+ border-bottom: 1px solid rgba(255,255,255,0.1);
523
  }
524
+ .chat-header h1 {
525
+ color: white;
526
+ font-size: 1.3rem;
527
+ font-weight: 600;
528
+ }
529
+ .chat-header p {
530
+ color: #888;
531
+ font-size: 0.8rem;
532
+ margin-top: 4px;
533
+ }
534
+ .chat-messages {
535
  flex: 1;
536
  overflow-y: auto;
537
+ padding: 20px;
538
  display: flex;
539
  flex-direction: column;
540
+ gap: 16px;
541
  }
542
  .message {
543
  display: flex;
544
+ gap: 12px;
545
+ max-width: 80%;
546
+ animation: fadeIn 0.3s ease;
547
  }
548
  @keyframes fadeIn {
549
  from { opacity: 0; transform: translateY(10px); }
550
  to { opacity: 1; transform: translateY(0); }
551
  }
552
+ .message.user {
553
+ align-self: flex-end;
554
+ flex-direction: row-reverse;
555
+ }
556
  .message-content {
557
+ padding: 12px 18px;
558
+ border-radius: 20px;
559
+ font-size: 0.95rem;
560
  line-height: 1.4;
561
  word-wrap: break-word;
562
+ max-width: 100%;
563
  }
564
  .user .message-content {
565
+ background: linear-gradient(135deg, #667eea, #764ba2);
566
  color: white;
567
+ border-radius: 20px 4px 20px 20px;
568
  }
569
  .bot .message-content {
570
+ background: rgba(30, 30, 40, 0.9);
571
+ color: #e3e3e3;
572
+ border-radius: 4px 20px 20px 20px;
573
+ border: 1px solid rgba(255,255,255,0.05);
574
  }
575
+ .chat-input-container {
576
+ padding: 16px 20px;
577
+ background: rgba(0, 0, 0, 0.7);
578
+ backdrop-filter: blur(10px);
579
+ border-top: 1px solid rgba(255,255,255,0.1);
580
  }
581
  .input-wrapper {
582
  display: flex;
583
+ gap: 12px;
584
  max-width: 800px;
585
  margin: 0 auto;
586
  }
587
+ #messageInput {
588
  flex: 1;
589
+ padding: 12px 16px;
590
+ background: rgba(255,255,255,0.1);
591
+ border: 1px solid rgba(255,255,255,0.2);
592
+ border-radius: 24px;
593
  color: white;
594
+ font-size: 0.95rem;
595
  outline: none;
596
+ transition: all 0.2s;
597
  }
598
+ #messageInput:focus {
599
+ border-color: #667eea;
600
+ background: rgba(255,255,255,0.15);
601
+ }
602
+ #messageInput::placeholder {
603
+ color: #888;
604
+ }
605
+ #sendBtn {
606
+ padding: 12px 24px;
607
+ background: linear-gradient(135deg, #667eea, #764ba2);
608
  border: none;
609
+ border-radius: 24px;
610
  color: white;
611
+ font-weight: 500;
612
  cursor: pointer;
613
+ transition: all 0.2s;
614
+ }
615
+ #sendBtn:hover {
616
+ transform: scale(1.02);
617
+ opacity: 0.9;
618
+ }
619
+ #sendBtn:disabled {
620
+ opacity: 0.5;
621
+ transform: none;
622
+ cursor: not-allowed;
623
  }
 
 
624
  .typing {
625
  display: flex;
626
  gap: 4px;
627
+ padding: 12px 18px;
628
  }
629
  .typing span {
630
+ width: 8px;
631
+ height: 8px;
632
  background: #888;
633
  border-radius: 50%;
634
+ animation: bounce 1.4s infinite ease-in-out;
635
  }
636
+ .typing span:nth-child(1) { animation-delay: -0.32s; }
637
  .typing span:nth-child(2) { animation-delay: -0.16s; }
 
638
  @keyframes bounce {
639
  0%, 80%, 100% { transform: scale(0); }
640
  40% { transform: scale(1); }
641
  }
642
+ .suggestions {
643
+ display: flex;
644
+ gap: 10px;
645
+ padding: 12px 20px;
646
+ overflow-x: auto;
647
+ background: rgba(0,0,0,0.3);
648
+ }
649
+ .suggestion {
650
+ padding: 6px 14px;
651
+ background: rgba(255,255,255,0.1);
652
+ border-radius: 20px;
653
+ color: #aaa;
654
+ font-size: 0.8rem;
655
+ cursor: pointer;
656
+ transition: all 0.2s;
657
+ white-space: nowrap;
658
+ }
659
+ .suggestion:hover {
660
+ background: linear-gradient(135deg, #667eea, #764ba2);
661
+ color: white;
662
  }
663
+ .version-badge {
664
+ position: fixed;
665
+ bottom: 10px;
666
+ right: 10px;
667
+ background: rgba(0,0,0,0.5);
668
+ padding: 4px 10px;
669
+ border-radius: 20px;
670
+ font-size: 0.7rem;
671
+ color: #888;
672
+ font-family: monospace;
673
+ }
674
+ @media (max-width: 600px) {
675
+ .message { max-width: 95%; }
676
+ .suggestions { display: none; }
677
  }
678
  </style>
679
  </head>
680
  <body>
681
+ <div class="chat-header">
682
+ <h1>🤖 MTP V3.3.1 - Mi Transformer Pequeño</h1>
683
+ <p>Asistente IA entrenado desde cero con arquitectura Transformer | 15M parámetros</p>
684
+ </div>
685
+ <div class="suggestions">
686
+ <div class="suggestion">Hola</div>
687
+ <div class="suggestion">¿Quién eres?</div>
688
+ <div class="suggestion">¿Qué puedes hacer?</div>
689
+ <div class="suggestion">Explícame la IA</div>
690
+ <div class="suggestion">Háblame de BTS</div>
691
+ <div class="suggestion">¿Qué es un agujero negro?</div>
692
+ <div class="suggestion">Dime un chiste</div>
693
+ <div class="suggestion">Adiós</div>
694
  </div>
695
+ <div class="chat-messages" id="chatMessages">
696
  <div class="message bot">
697
+ <div class="message-content">¡Hola! Soy MTP versión 3.3.1, tu asistente de IA entrenado desde cero. Puedo hablar de ciencia, K-Pop (BTS, BLACKPINK), tecnología, filosofía y mucho más. ¿En qué puedo ayudarte hoy?</div>
698
  </div>
699
  </div>
700
+ <div class="chat-input-container">
701
  <div class="input-wrapper">
702
+ <input type="text" id="messageInput" placeholder="Escribe tu mensaje aquí..." autocomplete="off">
703
+ <button id="sendBtn">Enviar</button>
704
  </div>
705
  </div>
706
+ <div class="version-badge">MTP-3.3.1 | Transformer</div>
707
  <script>
708
+ const chatMessages = document.getElementById('chatMessages');
709
+ const messageInput = document.getElementById('messageInput');
710
+ const sendBtn = document.getElementById('sendBtn');
711
+ let isLoading = false;
712
 
713
  function addMessage(text, isUser) {
714
  const div = document.createElement('div');
715
  div.className = `message ${isUser ? 'user' : 'bot'}`;
716
  div.innerHTML = `<div class="message-content">${escapeHtml(text)}</div>`;
717
+ chatMessages.appendChild(div);
718
+ chatMessages.scrollTop = chatMessages.scrollHeight;
719
+ return div;
720
  }
721
 
722
  function escapeHtml(text) {
723
+ const div = document.createElement('div');
724
+ div.textContent = text;
725
+ return div.innerHTML;
726
  }
727
 
728
+ function addTypingIndicator() {
729
  const div = document.createElement('div');
730
  div.className = 'message bot';
731
+ div.id = 'typingIndicator';
732
  div.innerHTML = `<div class="typing"><span></span><span></span><span></span></div>`;
733
+ chatMessages.appendChild(div);
734
+ chatMessages.scrollTop = chatMessages.scrollHeight;
735
  }
736
 
737
+ function removeTypingIndicator() {
738
+ const indicator = document.getElementById('typingIndicator');
739
+ if (indicator) indicator.remove();
740
  }
741
 
742
+ async function sendMessage(text = null) {
743
+ const messageText = text || messageInput.value.trim();
744
+ if (!messageText || isLoading) return;
745
 
746
+ if (!text) messageInput.value = '';
747
+ addMessage(messageText, true);
748
+ isLoading = true;
749
  sendBtn.disabled = true;
750
+ addTypingIndicator();
751
 
752
  try {
753
+ const response = await fetch('/generate', {
754
  method: 'POST',
755
  headers: { 'Content-Type': 'application/json' },
756
+ body: JSON.stringify({ text: messageText })
757
  });
758
+ const data = await response.json();
759
+ removeTypingIndicator();
760
+ addMessage(data.reply, false);
761
+ } catch (error) {
762
+ removeTypingIndicator();
763
+ addMessage('⚠️ Error de conexión. Por favor, intenta de nuevo.', false);
764
  } finally {
765
+ isLoading = false;
766
  sendBtn.disabled = false;
767
+ messageInput.focus();
768
  }
769
  }
770
 
771
+ messageInput.addEventListener('keypress', (e) => {
772
+ if (e.key === 'Enter') sendMessage();
773
  });
774
+ sendBtn.addEventListener('click', () => sendMessage());
775
+ document.querySelectorAll('.suggestion').forEach(el => {
776
+ el.addEventListener('click', () => sendMessage(el.textContent));
777
+ });
778
+ messageInput.focus();
779
  </script>
780
  </body>
781
  </html>
782
  """
783
 
784
+ # ======================
785
+ # MAIN
786
+ # ======================
787
  if __name__ == "__main__":
788
  port = int(os.environ.get("PORT", 7860))
789
+ print("\n" + "=" * 60)
790
+ print(f"🚀 Iniciando servidor MTP V3.3.1 en puerto {port}...")
791
+ print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
792
+ print(f"📡 API docs: http://0.0.0.0:{port}/docs")
793
+ print(f"❤️ Health check: http://0.0.0.0:{port}/health")
794
+ print("=" * 60)
795
+
796
+ uvicorn.run(
797
+ app,
798
+ host="0.0.0.0",
799
+ port=port,
800
+ log_level="info"
801
+ )