teszenofficial commited on
Commit
44c921d
·
verified ·
1 Parent(s): 7d7b9e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -23
app.py CHANGED
@@ -1,9 +1,9 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- MTP 1.0 API - CORREGIDO PARA CPU
4
- - Misma arquitectura (512 dims, 16 heads, 8 layers)
5
- - Tokenizador BPE estilo GPT
6
- - Optimizado para velocidad en CPU
7
  """
8
 
9
  import os
@@ -25,11 +25,10 @@ import torch.nn.functional as F
25
  import sentencepiece as spm
26
 
27
  # ======================
28
- # OPTIMIZACIONES PARA CPU
29
  # ======================
30
  if torch.cuda.is_available():
31
  DEVICE = "cuda"
32
- torch.backends.cudnn.benchmark = True
33
  print("✅ GPU detectada")
34
  else:
35
  DEVICE = "cpu"
@@ -41,7 +40,7 @@ else:
41
  MODEL_REPO = "TeszenAI/MTP-1.0"
42
 
43
  # ======================
44
- # ARQUITECTURA MTP 1.0 (CORREGIDA)
45
  # ======================
46
  class RMSNorm(nn.Module):
47
  __slots__ = ('weight', 'eps')
@@ -97,7 +96,6 @@ class RotaryMultiHeadAttention(nn.Module):
97
  Q = self.w_q(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
98
  K = self.w_k(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
99
  V = self.w_v(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
100
- # Rotación
101
  Q_rot = Q * cos.unsqueeze(0).unsqueeze(0) + self._rotate_half(Q) * sin.unsqueeze(0).unsqueeze(0)
102
  K_rot = K * cos.unsqueeze(0).unsqueeze(0) + self._rotate_half(K) * sin.unsqueeze(0).unsqueeze(0)
103
  scores = torch.matmul(Q_rot, K_rot.transpose(-2, -1)) / self.scale
@@ -146,7 +144,8 @@ class MTP1Model(nn.Module):
146
  return self.lm_head(self.norm(x))
147
 
148
  @torch.no_grad()
149
- def generate(self, input_ids, max_new=120, temperature=0.45, top_k=30, top_p=0.88, repetition_penalty=1.2):
 
150
  generated = input_ids
151
  eos_id = 3
152
  last_tokens = []
@@ -182,11 +181,13 @@ class MTP1Model(nn.Module):
182
  return generated
183
 
184
  # ======================
185
- # LIMPIEZA
186
  # ======================
187
  def clean_response(response: str) -> str:
188
  if not response:
189
  return ""
 
 
190
  words = response.split()
191
  cleaned = []
192
  last = ""
@@ -196,14 +197,29 @@ def clean_response(response: str) -> str:
196
  last = w
197
  response = " ".join(cleaned)
198
  response = re.sub(r'\s+', ' ', response).strip()
 
 
199
  if response and response[0].islower():
200
  response = response[0].upper() + response[1:]
 
 
201
  if response and response[-1] not in '.!?':
202
  response += '.'
 
 
203
  if len(response) > 400:
 
204
  period = response[:400].rfind('.')
205
  if period > 50:
206
  response = response[:period+1]
 
 
 
 
 
 
 
 
207
  return response
208
 
209
  # ======================
@@ -222,7 +238,7 @@ sp.load(tokenizer_path)
222
  config["vocab_size"] = sp.get_piece_size()
223
 
224
  print(f"🧠 Inicializando MTP 1.0...")
225
- print(f" → Vocabulario BPE: {config['vocab_size']} tokens")
226
  print(f" → Dimensiones: {config.get('d_model', 512)}")
227
  print(f" → Capas: {config.get('n_layers', 8)}")
228
 
@@ -263,14 +279,21 @@ async def generate(req: PromptRequest):
263
  ACTIVE_REQUESTS -= 1
264
  return {"reply": ""}
265
 
266
- tokens = sp.encode(build_prompt(user_input))[:350]
 
 
 
 
 
 
 
267
  input_ids = torch.tensor([tokens], device=DEVICE)
268
 
269
  try:
270
  start = time.time()
271
  output_ids = model.generate(
272
  input_ids,
273
- max_new=100,
274
  temperature=0.45,
275
  top_k=30,
276
  top_p=0.88,
@@ -290,13 +313,17 @@ async def generate(req: PromptRequest):
290
 
291
  response = clean_response(response)
292
 
293
- if len(response) < 3:
294
- response = "Lo siento, no pude generar una respuesta clara."
 
 
 
295
 
296
  return {
297
  "reply": response,
298
  "time": round(elapsed, 2),
299
  "tokens": len(safe_tokens),
 
300
  "model": "MTP-1.0"
301
  }
302
 
@@ -322,10 +349,10 @@ def info():
322
  "parameters": param_count,
323
  "parameters_millions": round(param_count / 1e6, 2),
324
  "device": DEVICE,
 
325
  "d_model": config.get('d_model', 512),
326
  "n_layers": config.get('n_layers', 8),
327
- "n_heads": config.get('n_heads', 16),
328
- "vocab_size": config.get('vocab_size')
329
  }
330
 
331
  # ======================
@@ -371,6 +398,7 @@ def chat_ui():
371
  font-size: 0.85rem;
372
  line-height: 1.4;
373
  animation: fadeIn 0.2s ease;
 
374
  }
375
  @keyframes fadeIn {
376
  from { opacity: 0; transform: translateY(5px); }
@@ -457,26 +485,29 @@ def chat_ui():
457
  <body>
458
  <div class="header">
459
  <h1>🤖 MTP 1.0 - Asistente IA</h1>
460
- <p>✨ 512 dims | 16 heads | 8 layers | Respuestas inteligentes</p>
461
  </div>
462
  <div class="messages" id="messages">
463
- <div class="message bot">✨ Hola, soy MTP 1.0. ¿En qué puedo ayudarte?</div>
464
  </div>
465
  <div class="input-area">
466
  <input type="text" id="input" placeholder="Escribe tu pregunta..." autocomplete="off">
467
  <button id="send">Enviar</button>
468
  </div>
469
- <div class="badge">⚡ MTP 1.0 | 🌡️ 0.45</div>
470
  <script>
471
  const messages = document.getElementById('messages');
472
  const input = document.getElementById('input');
473
  const sendBtn = document.getElementById('send');
474
  let loading = false;
475
 
476
- function addMessage(text, isUser, time = null) {
477
  const div = document.createElement('div');
478
  div.className = `message ${isUser ? 'user' : 'bot'}`;
479
- div.innerHTML = `<div>${escapeHtml(text)}</div>${time ? `<div style="font-size:0.6rem;color:#666;margin-top:4px;">⚡ ${time}s</div>` : ''}`;
 
 
 
480
  messages.appendChild(div);
481
  messages.scrollTop = messages.scrollHeight;
482
  }
@@ -519,7 +550,7 @@ def chat_ui():
519
  });
520
  const data = await response.json();
521
  hideTyping();
522
- addMessage(data.reply, false, data.time);
523
  } catch (error) {
524
  hideTyping();
525
  addMessage('⚠️ Error de conexión. Intenta de nuevo.', false);
@@ -544,6 +575,7 @@ if __name__ == "__main__":
544
  print(f"🚀 MTP 1.0 en http://0.0.0.0:{port}")
545
  print(f"📊 Parámetros: {param_count:,} ({param_count/1e6:.2f}M)")
546
  print(f"🌡️ Temperatura: 0.45 | 🔁 Repetition penalty: 1.2")
 
547
  print(f"💻 Dispositivo: {DEVICE.upper()}")
548
  print("=" * 60)
549
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ MTP 1.0 API - OPTIMIZADO (400 caracteres máx)
4
+ - Respuestas concisas pero inteligentes
5
+ - Límite de 400 caracteres por respuesta
6
+ - Rápido en CPU
7
  """
8
 
9
  import os
 
25
  import sentencepiece as spm
26
 
27
  # ======================
28
+ # OPTIMIZACIONES
29
  # ======================
30
  if torch.cuda.is_available():
31
  DEVICE = "cuda"
 
32
  print("✅ GPU detectada")
33
  else:
34
  DEVICE = "cpu"
 
40
  MODEL_REPO = "TeszenAI/MTP-1.0"
41
 
42
  # ======================
43
+ # ARQUITECTURA MTP 1.0
44
  # ======================
45
  class RMSNorm(nn.Module):
46
  __slots__ = ('weight', 'eps')
 
96
  Q = self.w_q(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
97
  K = self.w_k(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
98
  V = self.w_v(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
 
99
  Q_rot = Q * cos.unsqueeze(0).unsqueeze(0) + self._rotate_half(Q) * sin.unsqueeze(0).unsqueeze(0)
100
  K_rot = K * cos.unsqueeze(0).unsqueeze(0) + self._rotate_half(K) * sin.unsqueeze(0).unsqueeze(0)
101
  scores = torch.matmul(Q_rot, K_rot.transpose(-2, -1)) / self.scale
 
144
  return self.lm_head(self.norm(x))
145
 
146
  @torch.no_grad()
147
+ def generate(self, input_ids, max_new=80, temperature=0.45, top_k=30, top_p=0.88, repetition_penalty=1.2):
148
+ """Generación optimizada para respuestas cortas"""
149
  generated = input_ids
150
  eos_id = 3
151
  last_tokens = []
 
181
  return generated
182
 
183
  # ======================
184
+ # LIMPIEZA INTELIGENTE (400 CARACTERES)
185
  # ======================
186
  def clean_response(response: str) -> str:
187
  if not response:
188
  return ""
189
+
190
+ # Eliminar repeticiones
191
  words = response.split()
192
  cleaned = []
193
  last = ""
 
197
  last = w
198
  response = " ".join(cleaned)
199
  response = re.sub(r'\s+', ' ', response).strip()
200
+
201
+ # Capitalizar
202
  if response and response[0].islower():
203
  response = response[0].upper() + response[1:]
204
+
205
+ # Asegurar punto final
206
  if response and response[-1] not in '.!?':
207
  response += '.'
208
+
209
+ # LIMITAR A 400 CARACTERES (inteligentemente)
210
  if len(response) > 400:
211
+ # Buscar el último punto antes de 400
212
  period = response[:400].rfind('.')
213
  if period > 50:
214
  response = response[:period+1]
215
+ else:
216
+ # Buscar última coma o espacio
217
+ space = response[:400].rfind(' ')
218
+ if space > 50:
219
+ response = response[:space] + '...'
220
+ else:
221
+ response = response[:400] + '...'
222
+
223
  return response
224
 
225
  # ======================
 
238
  config["vocab_size"] = sp.get_piece_size()
239
 
240
  print(f"🧠 Inicializando MTP 1.0...")
241
+ print(f" → Vocabulario: {config['vocab_size']} tokens")
242
  print(f" → Dimensiones: {config.get('d_model', 512)}")
243
  print(f" → Capas: {config.get('n_layers', 8)}")
244
 
 
279
  ACTIVE_REQUESTS -= 1
280
  return {"reply": ""}
281
 
282
+ # Detectar saludo para respuesta más corta
283
+ greetings = ["hola", "buenos dias", "buenas tardes", "buenas noches", "hey", "que tal"]
284
+ is_greeting = user_input.lower().strip() in greetings
285
+
286
+ # Ajustar longitud según tipo
287
+ max_new = 50 if is_greeting else 80
288
+
289
+ tokens = sp.encode(build_prompt(user_input))[:300]
290
  input_ids = torch.tensor([tokens], device=DEVICE)
291
 
292
  try:
293
  start = time.time()
294
  output_ids = model.generate(
295
  input_ids,
296
+ max_new=max_new,
297
  temperature=0.45,
298
  top_k=30,
299
  top_p=0.88,
 
313
 
314
  response = clean_response(response)
315
 
316
+ if len(response) < 5:
317
+ if is_greeting:
318
+ response = "¡Hola! ¿En qué puedo ayudarte?"
319
+ else:
320
+ response = "Lo siento, no pude generar una respuesta clara."
321
 
322
  return {
323
  "reply": response,
324
  "time": round(elapsed, 2),
325
  "tokens": len(safe_tokens),
326
+ "characters": len(response),
327
  "model": "MTP-1.0"
328
  }
329
 
 
349
  "parameters": param_count,
350
  "parameters_millions": round(param_count / 1e6, 2),
351
  "device": DEVICE,
352
+ "max_response_chars": 400,
353
  "d_model": config.get('d_model', 512),
354
  "n_layers": config.get('n_layers', 8),
355
+ "n_heads": config.get('n_heads', 16)
 
356
  }
357
 
358
  # ======================
 
398
  font-size: 0.85rem;
399
  line-height: 1.4;
400
  animation: fadeIn 0.2s ease;
401
+ word-wrap: break-word;
402
  }
403
  @keyframes fadeIn {
404
  from { opacity: 0; transform: translateY(5px); }
 
485
  <body>
486
  <div class="header">
487
  <h1>🤖 MTP 1.0 - Asistente IA</h1>
488
+ <p>✨ Respuestas concisas | Máximo 400 caracteres | Rápido</p>
489
  </div>
490
  <div class="messages" id="messages">
491
+ <div class="message bot">✨ Hola, soy MTP 1.0. Respuestas cortas pero inteligentes (máx 400 caracteres). ¿En qué puedo ayudarte?</div>
492
  </div>
493
  <div class="input-area">
494
  <input type="text" id="input" placeholder="Escribe tu pregunta..." autocomplete="off">
495
  <button id="send">Enviar</button>
496
  </div>
497
+ <div class="badge">⚡ MTP 1.0 | 🌡️ 0.45 | 📏 400 chars máx</div>
498
  <script>
499
  const messages = document.getElementById('messages');
500
  const input = document.getElementById('input');
501
  const sendBtn = document.getElementById('send');
502
  let loading = false;
503
 
504
+ function addMessage(text, isUser, time = null, chars = null) {
505
  const div = document.createElement('div');
506
  div.className = `message ${isUser ? 'user' : 'bot'}`;
507
+ let info = '';
508
+ if (time) info += `⚡ ${time}s`;
509
+ if (chars) info += `${info ? ' | ' : ''}📝 ${chars} chars`;
510
+ div.innerHTML = `<div>${escapeHtml(text)}</div>${info ? `<div style="font-size:0.6rem;color:#666;margin-top:4px;">${info}</div>` : ''}`;
511
  messages.appendChild(div);
512
  messages.scrollTop = messages.scrollHeight;
513
  }
 
550
  });
551
  const data = await response.json();
552
  hideTyping();
553
+ addMessage(data.reply, false, data.time, data.characters);
554
  } catch (error) {
555
  hideTyping();
556
  addMessage('⚠️ Error de conexión. Intenta de nuevo.', false);
 
575
  print(f"🚀 MTP 1.0 en http://0.0.0.0:{port}")
576
  print(f"📊 Parámetros: {param_count:,} ({param_count/1e6:.2f}M)")
577
  print(f"🌡️ Temperatura: 0.45 | 🔁 Repetition penalty: 1.2")
578
+ print(f"📏 Máximo de caracteres por respuesta: 400")
579
  print(f"💻 Dispositivo: {DEVICE.upper()}")
580
  print("=" * 60)
581