teszenofficial commited on
Commit
919c8d2
·
verified ·
1 Parent(s): d07a348

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -506
app.py CHANGED
@@ -32,137 +32,85 @@ if DEVICE == "cpu":
32
  torch.set_grad_enabled(False)
33
 
34
  # CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
35
- MODEL_REPO = "TeszenAI/MTP-3.1.1" # <-- CAMBIA A TU REPO
36
 
37
  # ======================
38
  # FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
39
  # ======================
40
 
41
  def truncate_greeting_response(text: str) -> str:
42
- """
43
- Si la respuesta comienza con un saludo (Hola, ¡Hola!, etc.),
44
- la trunca después del primer punto, signo de interrogación o exclamación.
45
- """
46
  if not text:
47
  return text
48
 
49
- # Buscar el primer punto, signo de interrogación o exclamación que cierra una frase
50
  end_pattern = r'([.!?])'
51
  end_match = re.search(end_pattern, text)
52
 
53
  if end_match:
54
- # Cortar justo después del signo de puntuación
55
  end_pos = end_match.end()
56
  truncated = text[:end_pos].strip()
57
  return truncated
58
 
59
- # Si no encuentra puntuación, devolver solo la primera línea o las primeras 50 caracteres
60
- first_line = text.split('\n')[0].strip()
61
- if len(first_line) > 5:
62
- return first_line
63
-
64
  return text
65
 
66
  def clean_response(text: str, user_input: str = "") -> str:
67
- """
68
- Limpia la respuesta eliminando repeticiones, frases sin sentido y
69
- asegurando que termine correctamente.
70
- """
71
  if not text:
72
  return ""
73
 
74
- # 1. Eliminar repeticiones excesivas de palabras o frases cortas
75
  words = text.split()
76
  cleaned_words = []
77
- last_phrase = ""
78
  repeat_count = 0
79
 
80
  for word in words:
81
- if word == last_phrase:
82
  repeat_count += 1
83
- if repeat_count > 2: # Si repite más de 2 veces seguidas
84
  continue
85
  else:
86
- last_phrase = word
87
  repeat_count = 0
88
  cleaned_words.append(word)
89
 
90
  text = " ".join(cleaned_words)
91
 
92
- # 2. Eliminar patrones sin sentido (repeticiones de letras, caracteres raros)
93
- text = re.sub(r'(.)\1{4,}', r'\1\1', text) # aaa... -> aa
94
- text = re.sub(r'[^a-zA-ZáéíóúñüÁÉÍÓÚÑÜ0-9\s.,;:!?¿¡()\-"]+', '', text)
95
-
96
- # 3. Detectar si es un saludo (el usuario dijo "hola" o similar)
97
- is_greeting = re.match(r'^(¡?Hola!?\s?)', text, re.IGNORECASE)
98
 
99
- # 4. Si es un saludo, truncar después del primer punto
100
- if is_greeting:
101
- text = truncate_greeting_response(text)
102
- else:
103
- # Para respuestas normales, cortar en patrones de finalización
104
- stop_patterns = [
105
- r'(\.\s*)$', # Punto final
106
- r'[.!?](\s+)?$', # Fin de oración
107
- r'(gracias|hasta luego|adiós|saludos|fin|fin del mensaje)$',
108
- r'(¿algo más\?|¿necesitas algo más\?|¿en qué más puedo ayudarte\?)'
109
- ]
110
-
111
- for pattern in stop_patterns:
112
- match = re.search(pattern, text, re.IGNORECASE)
113
- if match:
114
- end_pos = match.end()
115
- text = text[:end_pos]
116
- break
117
 
118
- # 5. Si la respuesta es muy corta o vacía, devolver mensaje por defecto
119
- if len(text.strip()) < 10:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
121
 
122
- # 6. Eliminar espacios múltiples y saltos de línea excesivos
123
  text = re.sub(r'\s+', ' ', text).strip()
124
 
125
  return text
126
 
127
- def should_stop_generation(generated_text: str, user_input: str = "", min_length: int = 30, max_length: int = 300) -> bool:
128
- """
129
- Determina si debemos detener la generación basado en el texto generado.
130
- """
131
- # Verificar si el usuario dijo hola (para detener generación temprano)
132
- if user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos"]:
133
- # Buscar el primer punto, signo de interrogación o exclamación
134
- match = re.search(r'[.!?]', generated_text)
135
- if match:
136
- return True
137
-
138
- # Si ya superamos la longitud máxima
139
- if len(generated_text) > max_length:
140
- return True
141
-
142
- # Si es muy corto y no hay puntuación final
143
- if len(generated_text) < min_length and not re.search(r'[.!?]$', generated_text):
144
- return False
145
-
146
- # Señales de que ya terminó la respuesta
147
- stop_signals = [
148
- r'(gracias por tu pregunta|espero haberte ayudado|¿necesitas algo más\?)',
149
- r'(hasta luego|adiós|quedo atento|saludos cordiales)',
150
- r'(fin del mensaje|fin de la conversación)'
151
- ]
152
-
153
- for signal in stop_signals:
154
- if re.search(signal, generated_text, re.IGNORECASE):
155
- return True
156
-
157
- # Si la última frase parece completa
158
- last_sentence = generated_text.split('.')[-1].strip()
159
- if len(last_sentence) > 5 and re.search(r'[.!?]$', last_sentence):
160
- # Y ya hemos generado suficiente contenido
161
- if len(generated_text) > min_length:
162
- return True
163
-
164
- return False
165
-
166
  # ======================
167
  # DEFINIR ARQUITECTURA DEL MODELO (MTP)
168
  # ======================
@@ -274,15 +222,9 @@ class MTPModel(nn.Module):
274
  logits = self.lm_head(x)
275
  return logits
276
 
277
- def generate(self, input_ids, user_input="", max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
278
- """Método de generación mejorado con detección inteligente de fin"""
279
  generated = input_ids
280
- generated_text = ""
281
- min_response_length = 30
282
- max_response_length = max_new_tokens * 2
283
-
284
- # Detectar si es un saludo
285
- is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos"]
286
 
287
  for step in range(max_new_tokens):
288
  with torch.no_grad():
@@ -309,23 +251,11 @@ class MTPModel(nn.Module):
309
  probs = F.softmax(next_logits, dim=-1)
310
  next_token = torch.multinomial(probs, num_samples=1).item()
311
 
312
- if next_token == 3: # EOS ID para SentencePiece
 
313
  break
314
 
315
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
316
-
317
- # Decodificar parcialmente para verificar si debemos parar
318
- if step > 5 and step % 5 == 0:
319
- # Decodificar los últimos tokens para verificar puntuación
320
- if is_greeting:
321
- # Para saludos, detener en cuanto haya un punto, signo de interrogación o exclamación
322
- # Necesitamos decodificar para ver el texto real
323
- partial_tokens = generated[0].tolist()
324
- if len(partial_tokens) > 10:
325
- # Decodificar parcialmente (esto es aproximado)
326
- if step > 10:
327
- # Buscar puntuación en el texto decodificado
328
- break
329
 
330
  return generated
331
 
@@ -357,6 +287,10 @@ else:
357
 
358
  # Cargar tokenizador
359
  tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
 
 
 
 
360
  sp = spm.SentencePieceProcessor()
361
  sp.load(tokenizer_path)
362
  VOCAB_SIZE = sp.get_piece_size()
@@ -377,25 +311,13 @@ model.to(DEVICE)
377
  model_path = os.path.join(repo_path, "mtp_model.pt")
378
  if os.path.exists(model_path):
379
  state_dict = torch.load(model_path, map_location=DEVICE)
380
- model.load_state_dict(state_dict)
381
  print("✅ Pesos del modelo cargados")
382
  else:
383
- print("⚠️ No se encontró mtp_model.pt, usando pesos aleatorios")
384
 
385
  model.eval()
386
 
387
- # Cuantización para CPU
388
- if DEVICE == "cpu":
389
- print("⚡ Aplicando cuantización dinámica para CPU...")
390
- try:
391
- model = torch.quantization.quantize_dynamic(
392
- model,
393
- {nn.Linear},
394
- dtype=torch.qint8
395
- )
396
- except Exception as e:
397
- print(f"⚠️ No se pudo aplicar cuantización: {e}")
398
-
399
  param_count = sum(p.numel() for p in model.parameters())
400
  print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
401
 
@@ -417,7 +339,7 @@ app.add_middleware(
417
 
418
  class PromptRequest(BaseModel):
419
  text: str = Field(..., max_length=2000, description="Texto de entrada")
420
- max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar")
421
  temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
422
  top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
423
  top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
@@ -460,37 +382,27 @@ async def generate(req: PromptRequest):
460
  global ACTIVE_REQUESTS
461
  ACTIVE_REQUESTS += 1
462
 
463
- dyn_max_tokens = req.max_tokens
464
- dyn_temperature = req.temperature
465
-
466
- if ACTIVE_REQUESTS > 2:
467
- print(f"⚠️ Carga alta ({ACTIVE_REQUESTS} requests). Ajustando parámetros.")
468
- dyn_max_tokens = min(dyn_max_tokens, 120)
469
- dyn_temperature = max(0.5, dyn_temperature * 0.9)
470
-
471
  user_input = req.text.strip()
472
  if not user_input:
473
  ACTIVE_REQUESTS -= 1
474
  return {"reply": "", "tokens_generated": 0}
 
 
 
 
 
 
475
 
476
  full_prompt = build_prompt(user_input)
477
- tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
478
  input_ids = torch.tensor([tokens], device=DEVICE)
479
-
480
- # Detectar si es un saludo para limitar la generación
481
- is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos"]
482
-
483
- # Si es saludo, limitar max_tokens a 30 (suficiente para una frase corta)
484
- if is_greeting:
485
- dyn_max_tokens = min(dyn_max_tokens, 30)
486
 
487
  try:
488
  with torch.no_grad():
489
  output_ids = model.generate(
490
  input_ids,
491
- user_input=user_input,
492
- max_new_tokens=dyn_max_tokens,
493
- temperature=dyn_temperature,
494
  top_k=req.top_k,
495
  top_p=req.top_p,
496
  repetition_penalty=req.repetition_penalty
@@ -498,29 +410,23 @@ async def generate(req: PromptRequest):
498
 
499
  gen_tokens = output_ids[0, len(tokens):].tolist()
500
 
501
- safe_tokens = [
502
- t for t in gen_tokens
503
- if 0 <= t < VOCAB_SIZE and t != tokenizer_wrapper.eos_id()
504
- ]
505
-
506
- response = tokenizer_wrapper.decode(safe_tokens).strip()
507
 
508
- if "###" in response:
509
- response = response.split("###")[0].strip()
 
 
510
 
511
- # Aplicar limpieza inteligente a la respuesta (incluye truncamiento de saludos)
512
  response = clean_response(response, user_input)
513
 
514
- # Para saludos, asegurar que solo mostramos hasta el primer punto
515
- if is_greeting:
516
- # Buscar el primer punto, signo de interrogación o exclamación
517
- punct_match = re.search(r'[.!?]', response)
518
- if punct_match:
519
- response = response[:punct_match.end()].strip()
520
  else:
521
- # Si no hay puntuación, tomar solo las primeras 60 caracteres
522
- if len(response) > 60:
523
- response = response[:60] + "..."
524
 
525
  return {
526
  "reply": response,
@@ -530,8 +436,13 @@ async def generate(req: PromptRequest):
530
 
531
  except Exception as e:
532
  print(f"❌ Error durante generación: {e}")
 
 
 
 
 
533
  return {
534
- "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
535
  "error": str(e)
536
  }
537
 
@@ -565,7 +476,7 @@ def model_info():
565
  }
566
 
567
  # ======================
568
- # INTERFAZ WEB (MODERNA CON LOGO INTEGRADO)
569
  # ======================
570
  @app.get("/", response_class=HTMLResponse)
571
  def chat_ui():
@@ -574,410 +485,197 @@ def chat_ui():
574
  <html lang="es">
575
  <head>
576
  <meta charset="UTF-8">
577
- <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
578
  <title>MTP - Asistente IA</title>
579
- <link rel="preconnect" href="https://fonts.googleapis.com">
580
- <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
581
- <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
582
  <style>
583
- :root {
584
- --bg-color: #131314;
585
- --surface-color: #1E1F20;
586
- --accent-color: #4a9eff;
587
- --text-primary: #e3e3e3;
588
- --text-secondary: #9aa0a6;
589
- --user-bubble: #282a2c;
590
- }
591
- * { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
592
  body {
593
- margin: 0;
594
- background-color: var(--bg-color);
595
- font-family: 'Inter', sans-serif;
596
- color: var(--text-primary);
597
- height: 100dvh;
598
  display: flex;
599
  flex-direction: column;
600
- overflow: hidden;
601
- }
602
- header {
603
- padding: 12px 20px;
604
- display: flex;
605
- align-items: center;
606
- justify-content: space-between;
607
- background: rgba(19, 19, 20, 0.85);
608
- backdrop-filter: blur(12px);
609
- position: fixed;
610
- top: 0;
611
- width: 100%;
612
- z-index: 50;
613
- border-bottom: 1px solid rgba(255,255,255,0.05);
614
- }
615
- .brand-wrapper {
616
- display: flex;
617
- align-items: center;
618
- gap: 12px;
619
- cursor: pointer;
620
  }
621
- .brand-logo {
622
- width: 32px;
623
- height: 32px;
624
- border-radius: 50%;
625
- background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
626
- background-size: cover;
627
- background-position: center;
628
- background-repeat: no-repeat;
629
- border: 1px solid rgba(255,255,255,0.1);
630
  }
631
- .brand-text {
 
 
632
  font-weight: 500;
633
- font-size: 1.05rem;
634
- display: flex;
635
- align-items: center;
636
- gap: 8px;
637
- }
638
- .version-badge {
639
- font-size: 0.75rem;
640
- background: rgba(74, 158, 255, 0.15);
641
- color: #8ab4f8;
642
- padding: 2px 8px;
643
- border-radius: 12px;
644
- font-weight: 600;
645
  }
646
- .chat-scroll {
647
  flex: 1;
648
  overflow-y: auto;
649
- padding: 80px 20px 40px 20px;
650
  display: flex;
651
  flex-direction: column;
652
- gap: 30px;
653
- max-width: 850px;
654
- margin: 0 auto;
655
- width: 100%;
656
- scroll-behavior: smooth;
657
  }
658
- .msg-row {
659
  display: flex;
660
- gap: 16px;
661
- width: 100%;
662
- opacity: 0;
663
- transform: translateY(10px);
664
- animation: slideUpFade 0.4s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
665
  }
666
- .msg-row.user { justify-content: flex-end; }
667
- .msg-row.bot { justify-content: flex-start; align-items: flex-start; }
668
- .msg-content {
669
- line-height: 1.6;
670
- font-size: 1rem;
671
- word-wrap: break-word;
672
- max-width: 85%;
673
  }
674
- .user .msg-content {
675
- background-color: var(--user-bubble);
676
- padding: 10px 18px;
677
  border-radius: 18px;
678
- border-top-right-radius: 4px;
679
- color: #fff;
680
  }
681
- .bot .msg-content-wrapper {
682
- display: flex;
683
- flex-direction: column;
684
- gap: 8px;
685
- width: 100%;
686
  }
687
- .bot .msg-text {
688
- padding-top: 6px;
689
- color: var(--text-primary);
 
690
  }
691
- .bot-avatar {
692
- width: 34px;
693
- height: 34px;
694
- min-width: 34px;
695
- border-radius: 50%;
696
- background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
697
- background-size: cover;
698
- background-position: center;
699
- background-repeat: no-repeat;
700
- box-shadow: 0 2px 6px rgba(0,0,0,0.2);
701
  }
702
- .bot-actions {
703
  display: flex;
704
- gap: 10px;
705
- opacity: 0;
706
- transition: opacity 0.3s;
707
- margin-top: 5px;
708
- }
709
- .action-btn {
710
- background: transparent;
711
- border: none;
712
- color: var(--text-secondary);
713
- cursor: pointer;
714
- padding: 4px;
715
- border-radius: 4px;
716
- display: flex;
717
- align-items: center;
718
- transition: color 0.2s, background 0.2s;
719
- }
720
- .action-btn:hover {
721
- color: var(--text-primary);
722
- background: rgba(255,255,255,0.08);
723
- }
724
- .action-btn svg { width: 16px; height: 16px; fill: currentColor; }
725
- .typing-cursor::after {
726
- content: '▊';
727
- display: inline-block;
728
- margin-left: 2px;
729
- animation: blink 1s infinite;
730
- }
731
- .footer-container {
732
- padding: 0 20px 20px 20px;
733
- background: linear-gradient(to top, var(--bg-color) 85%, transparent);
734
- position: relative;
735
- z-index: 60;
736
- }
737
- .input-box {
738
- max-width: 850px;
739
  margin: 0 auto;
740
- background: var(--surface-color);
741
- border-radius: 28px;
742
- padding: 8px 10px 8px 20px;
743
- display: flex;
744
- align-items: center;
745
- border: 1px solid rgba(255,255,255,0.1);
746
- transition: border-color 0.2s, box-shadow 0.2s;
747
- }
748
- .input-box:focus-within {
749
- border-color: rgba(74, 158, 255, 0.5);
750
- box-shadow: 0 0 0 2px rgba(74, 158, 255, 0.1);
751
  }
752
- #userInput {
753
  flex: 1;
754
- background: transparent;
 
755
  border: none;
 
756
  color: white;
757
- font-size: 1rem;
758
- font-family: inherit;
759
- padding: 10px 0;
760
  }
761
- #mainBtn {
762
- background: white;
763
- color: black;
 
 
 
764
  border: none;
765
- width: 36px;
766
- height: 36px;
767
- border-radius: 50%;
768
- display: flex;
769
- align-items: center;
770
- justify-content: center;
771
  cursor: pointer;
772
- margin-left: 8px;
773
- transition: transform 0.2s;
774
  }
775
- #mainBtn:hover { transform: scale(1.05); }
776
- .disclaimer {
777
- text-align: center;
778
- font-size: 0.75rem;
779
- color: #666;
780
- margin-top: 12px;
781
  }
782
- @keyframes slideUpFade {
783
- from { opacity: 0; transform: translateY(15px); }
784
- to { opacity: 1; transform: translateY(0); }
 
 
 
 
 
 
 
 
785
  }
786
- @keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
787
- @keyframes pulseAvatar {
788
- 0% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0.4); }
789
- 70% { box-shadow: 0 0 0 8px rgba(74, 158, 255, 0); }
790
- 100% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0); }
791
  }
792
- .pulsing { animation: pulseAvatar 1.5s infinite; }
793
- ::-webkit-scrollbar { width: 8px; }
794
- ::-webkit-scrollbar-track { background: transparent; }
795
- ::-webkit-scrollbar-thumb { background: #333; border-radius: 4px; }
796
  </style>
797
  </head>
798
  <body>
799
- <header>
800
- <div class="brand-wrapper" onclick="location.reload()">
801
- <div class="brand-logo"></div>
802
- <div class="brand-text">
803
- MTP <span class="version-badge">v1</span>
804
- </div>
805
- </div>
806
- </header>
807
- <div id="chatScroll" class="chat-scroll">
808
- <div class="msg-row bot" style="animation-delay: 0.1s;">
809
- <div class="bot-avatar"></div>
810
- <div class="msg-content-wrapper">
811
- <div class="msg-text">
812
- ¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?
813
- </div>
814
- </div>
815
- </div>
816
  </div>
817
- <div class="footer-container">
818
- <div class="input-box">
819
- <input type="text" id="userInput" placeholder="Escribe un mensaje..." autocomplete="off">
820
- <button id="mainBtn" onclick="handleBtnClick()">➤</button>
821
  </div>
822
- <div class="disclaimer">
823
- MTP puede cometer errores. Considera verificar la información importante.
 
 
 
824
  </div>
825
  </div>
826
  <script>
827
- const chatScroll = document.getElementById('chatScroll');
828
- const userInput = document.getElementById('userInput');
829
- const mainBtn = document.getElementById('mainBtn');
830
- let isGenerating = false;
831
- let abortController = null;
832
- let typingTimeout = null;
833
- let lastUserPrompt = "";
834
-
835
- function scrollToBottom() {
836
- chatScroll.scrollTop = chatScroll.scrollHeight;
837
- }
838
 
839
- function setBtnState(state) {
840
- if (state === 'sending') {
841
- mainBtn.innerHTML = '';
842
- isGenerating = true;
843
- } else {
844
- mainBtn.innerHTML = '➤';
845
- isGenerating = false;
846
- abortController = null;
847
- }
848
  }
849
 
850
- function handleBtnClick() {
851
- if (isGenerating) {
852
- stopGeneration();
853
- } else {
854
- sendMessage();
855
- }
 
856
  }
857
 
858
- function stopGeneration() {
859
- if (abortController) abortController.abort();
860
- if (typingTimeout) clearTimeout(typingTimeout);
861
- const activeCursor = document.querySelector('.typing-cursor');
862
- if (activeCursor) activeCursor.classList.remove('typing-cursor');
863
- const activeAvatar = document.querySelector('.pulsing');
864
- if (activeAvatar) activeAvatar.classList.remove('pulsing');
865
- setBtnState('idle');
866
- userInput.focus();
867
  }
868
 
869
- async function sendMessage(textOverride = null) {
870
- const text = textOverride || userInput.value.trim();
871
- if (!text) return;
872
- lastUserPrompt = text;
873
- if (!textOverride) {
874
- userInput.value = '';
875
- addMessage(text, 'user');
876
- }
877
- setBtnState('sending');
878
- abortController = new AbortController();
879
- const botRow = document.createElement('div');
880
- botRow.className = 'msg-row bot';
881
- const avatar = document.createElement('div');
882
- avatar.className = 'bot-avatar pulsing';
883
- const wrapper = document.createElement('div');
884
- wrapper.className = 'msg-content-wrapper';
885
- const msgText = document.createElement('div');
886
- msgText.className = 'msg-text';
887
- wrapper.appendChild(msgText);
888
- botRow.appendChild(avatar);
889
- botRow.appendChild(wrapper);
890
- chatScroll.appendChild(botRow);
891
- scrollToBottom();
892
  try {
893
  const response = await fetch('/generate', {
894
  method: 'POST',
895
  headers: { 'Content-Type': 'application/json' },
896
- body: JSON.stringify({ text: text }),
897
- signal: abortController.signal
898
  });
899
  const data = await response.json();
900
- if (!isGenerating) return;
901
- avatar.classList.remove('pulsing');
902
- const reply = data.reply || "No entendí eso.";
903
- await typeWriter(msgText, reply);
904
- if (isGenerating) {
905
- addActions(wrapper, reply);
906
- setBtnState('idle');
907
- }
908
  } catch (error) {
909
- if (error.name === 'AbortError') {
910
- msgText.textContent += " [Detenido]";
911
- } else {
912
- avatar.classList.remove('pulsing');
913
- msgText.textContent = "Error de conexión.";
914
- msgText.style.color = "#ff8b8b";
915
- setBtnState('idle');
916
- }
917
  }
918
  }
919
 
920
- function addMessage(text, sender) {
921
- const row = document.createElement('div');
922
- row.className = `msg-row ${sender}`;
923
- const content = document.createElement('div');
924
- content.className = 'msg-content';
925
- content.textContent = text;
926
- row.appendChild(content);
927
- chatScroll.appendChild(row);
928
- scrollToBottom();
929
- }
930
-
931
- function typeWriter(element, text, speed = 12) {
932
- return new Promise(resolve => {
933
- let i = 0;
934
- element.classList.add('typing-cursor');
935
- function type() {
936
- if (!isGenerating) {
937
- element.classList.remove('typing-cursor');
938
- resolve();
939
- return;
940
- }
941
- if (i < text.length) {
942
- element.textContent += text.charAt(i);
943
- i++;
944
- scrollToBottom();
945
- typingTimeout = setTimeout(type, speed + Math.random() * 5);
946
- } else {
947
- element.classList.remove('typing-cursor');
948
- resolve();
949
- }
950
- }
951
- type();
952
- });
953
- }
954
-
955
- function addActions(wrapperElement, textToCopy) {
956
- const actionsDiv = document.createElement('div');
957
- actionsDiv.className = 'bot-actions';
958
- const copyBtn = document.createElement('button');
959
- copyBtn.className = 'action-btn';
960
- copyBtn.innerHTML = `<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
961
- copyBtn.onclick = () => {
962
- navigator.clipboard.writeText(textToCopy);
963
- };
964
- const regenBtn = document.createElement('button');
965
- regenBtn.className = 'action-btn';
966
- regenBtn.innerHTML = `<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
967
- regenBtn.onclick = () => {
968
- sendMessage(lastUserPrompt);
969
- };
970
- actionsDiv.appendChild(copyBtn);
971
- actionsDiv.appendChild(regenBtn);
972
- wrapperElement.appendChild(actionsDiv);
973
- requestAnimationFrame(() => actionsDiv.style.opacity = "1");
974
- scrollToBottom();
975
- }
976
-
977
- userInput.addEventListener('keydown', (e) => {
978
- if (e.key === 'Enter') handleBtnClick();
979
  });
980
- window.onload = () => userInput.focus();
 
981
  </script>
982
  </body>
983
  </html>
 
32
  torch.set_grad_enabled(False)
33
 
34
  # CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
35
+ MODEL_REPO = "TeszenAI/MTP-3.1.1"
36
 
37
  # ======================
38
  # FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
39
  # ======================
40
 
41
  def truncate_greeting_response(text: str) -> str:
42
+ """Trunca la respuesta de saludo después del primer punto o signo de puntuación"""
 
 
 
43
  if not text:
44
  return text
45
 
46
+ # Buscar el primer punto, signo de interrogación o exclamación
47
  end_pattern = r'([.!?])'
48
  end_match = re.search(end_pattern, text)
49
 
50
  if end_match:
 
51
  end_pos = end_match.end()
52
  truncated = text[:end_pos].strip()
53
  return truncated
54
 
55
+ # Si no hay puntuación, devolver solo primeras 80 caracteres
56
+ if len(text) > 80:
57
+ return text[:80] + "..."
 
 
58
  return text
59
 
60
  def clean_response(text: str, user_input: str = "") -> str:
61
+ """Limpia la respuesta del modelo"""
 
 
 
62
  if not text:
63
  return ""
64
 
65
+ # Eliminar repeticiones excesivas
66
  words = text.split()
67
  cleaned_words = []
68
+ last_word = ""
69
  repeat_count = 0
70
 
71
  for word in words:
72
+ if word == last_word:
73
  repeat_count += 1
74
+ if repeat_count > 2:
75
  continue
76
  else:
77
+ last_word = word
78
  repeat_count = 0
79
  cleaned_words.append(word)
80
 
81
  text = " ".join(cleaned_words)
82
 
83
+ # Eliminar caracteres raros
84
+ text = re.sub(r'(.)\1{4,}', r'\1\1', text)
 
 
 
 
85
 
86
+ # Detectar si es un saludo
87
+ is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos", "hola?"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ if is_greeting and text:
90
+ # Para saludos, truncar después del primer signo de puntuación
91
+ punct_match = re.search(r'[.!?]', text)
92
+ if punct_match:
93
+ text = text[:punct_match.end()].strip()
94
+ else:
95
+ # Si no hay puntuación, tomar solo la primera oración o 60 caracteres
96
+ first_sentence = text.split('.')[0].split('?')[0].split('!')[0].strip()
97
+ if len(first_sentence) > 5:
98
+ text = first_sentence
99
+ elif len(text) > 60:
100
+ text = text[:60]
101
+
102
+ # Si la respuesta es muy corta o vacía
103
+ if len(text.strip()) < 5:
104
+ # Respuestas por defecto según el tipo de saludo
105
+ if is_greeting:
106
+ return "¡Hola! ¿En qué puedo ayudarte?"
107
  return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
108
 
109
+ # Eliminar espacios múltiples
110
  text = re.sub(r'\s+', ' ', text).strip()
111
 
112
  return text
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  # ======================
115
  # DEFINIR ARQUITECTURA DEL MODELO (MTP)
116
  # ======================
 
222
  logits = self.lm_head(x)
223
  return logits
224
 
225
+ def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
226
+ """Genera texto token por token"""
227
  generated = input_ids
 
 
 
 
 
 
228
 
229
  for step in range(max_new_tokens):
230
  with torch.no_grad():
 
251
  probs = F.softmax(next_logits, dim=-1)
252
  next_token = torch.multinomial(probs, num_samples=1).item()
253
 
254
+ # EOS ID común para SentencePiece
255
+ if next_token == 2 or next_token == 3:
256
  break
257
 
258
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  return generated
261
 
 
287
 
288
  # Cargar tokenizador
289
  tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
290
+ if not os.path.exists(tokenizer_path):
291
+ print(f"❌ Tokenizador no encontrado en {tokenizer_path}")
292
+ sys.exit(1)
293
+
294
  sp = spm.SentencePieceProcessor()
295
  sp.load(tokenizer_path)
296
  VOCAB_SIZE = sp.get_piece_size()
 
311
  model_path = os.path.join(repo_path, "mtp_model.pt")
312
  if os.path.exists(model_path):
313
  state_dict = torch.load(model_path, map_location=DEVICE)
314
+ model.load_state_dict(state_dict, strict=False)
315
  print("✅ Pesos del modelo cargados")
316
  else:
317
+ print(f"⚠️ No se encontró {model_path}, usando pesos aleatorios")
318
 
319
  model.eval()
320
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  param_count = sum(p.numel() for p in model.parameters())
322
  print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
323
 
 
339
 
340
  class PromptRequest(BaseModel):
341
  text: str = Field(..., max_length=2000, description="Texto de entrada")
342
+ max_tokens: int = Field(default=100, ge=10, le=200, description="Tokens máximos a generar")
343
  temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
344
  top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
345
  top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
 
382
  global ACTIVE_REQUESTS
383
  ACTIVE_REQUESTS += 1
384
 
 
 
 
 
 
 
 
 
385
  user_input = req.text.strip()
386
  if not user_input:
387
  ACTIVE_REQUESTS -= 1
388
  return {"reply": "", "tokens_generated": 0}
389
+
390
+ # Detectar si es un saludo
391
+ is_greeting = user_input.lower().strip() in ["hola", "hola!", "hola.", "buenas", "saludos", "hola?"]
392
+
393
+ # Si es saludo, usar menos tokens
394
+ max_tokens = 30 if is_greeting else req.max_tokens
395
 
396
  full_prompt = build_prompt(user_input)
397
+ tokens = tokenizer_wrapper.encode(full_prompt)
398
  input_ids = torch.tensor([tokens], device=DEVICE)
 
 
 
 
 
 
 
399
 
400
  try:
401
  with torch.no_grad():
402
  output_ids = model.generate(
403
  input_ids,
404
+ max_new_tokens=max_tokens,
405
+ temperature=req.temperature,
 
406
  top_k=req.top_k,
407
  top_p=req.top_p,
408
  repetition_penalty=req.repetition_penalty
 
410
 
411
  gen_tokens = output_ids[0, len(tokens):].tolist()
412
 
413
+ # Filtrar tokens inválidos
414
+ safe_tokens = [t for t in gen_tokens if 0 <= t < VOCAB_SIZE]
 
 
 
 
415
 
416
+ if safe_tokens:
417
+ response = tokenizer_wrapper.decode(safe_tokens).strip()
418
+ else:
419
+ response = ""
420
 
421
+ # Limpiar respuesta
422
  response = clean_response(response, user_input)
423
 
424
+ # Si la respuesta sigue vacía o es muy corta, usar respuesta por defecto
425
+ if len(response) < 3:
426
+ if is_greeting:
427
+ response = "¡Hola! ¿En qué puedo ayudarte?"
 
 
428
  else:
429
+ response = "Lo siento, no pude generar una respuesta. ¿Podrías reformular tu pregunta?"
 
 
430
 
431
  return {
432
  "reply": response,
 
436
 
437
  except Exception as e:
438
  print(f"❌ Error durante generación: {e}")
439
+ # Respuesta de fallback
440
+ if is_greeting:
441
+ fallback = "¡Hola! ¿Cómo estás? Estoy aquí para ayudarte."
442
+ else:
443
+ fallback = "Lo siento, ocurrió un error al procesar tu solicitud."
444
  return {
445
+ "reply": fallback,
446
  "error": str(e)
447
  }
448
 
 
476
  }
477
 
478
  # ======================
479
+ # INTERFAZ WEB
480
  # ======================
481
  @app.get("/", response_class=HTMLResponse)
482
  def chat_ui():
 
485
  <html lang="es">
486
  <head>
487
  <meta charset="UTF-8">
488
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
489
  <title>MTP - Asistente IA</title>
 
 
 
490
  <style>
491
+ * { margin: 0; padding: 0; box-sizing: border-box; }
 
 
 
 
 
 
 
 
492
  body {
493
+ background: #131314;
494
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
495
+ height: 100vh;
 
 
496
  display: flex;
497
  flex-direction: column;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  }
499
+ .chat-header {
500
+ padding: 16px 20px;
501
+ background: #1E1F20;
502
+ border-bottom: 1px solid #2a2b2e;
 
 
 
 
 
503
  }
504
+ .chat-header h1 {
505
+ color: white;
506
+ font-size: 1.2rem;
507
  font-weight: 500;
 
 
 
 
 
 
 
 
 
 
 
 
508
  }
509
+ .chat-messages {
510
  flex: 1;
511
  overflow-y: auto;
512
+ padding: 20px;
513
  display: flex;
514
  flex-direction: column;
515
+ gap: 16px;
 
 
 
 
516
  }
517
+ .message {
518
  display: flex;
519
+ gap: 12px;
520
+ max-width: 80%;
 
 
 
521
  }
522
+ .message.user {
523
+ align-self: flex-end;
524
+ flex-direction: row-reverse;
 
 
 
 
525
  }
526
+ .message-content {
527
+ padding: 10px 16px;
 
528
  border-radius: 18px;
529
+ font-size: 0.95rem;
530
+ line-height: 1.4;
531
  }
532
+ .user .message-content {
533
+ background: #4a9eff;
534
+ color: white;
535
+ border-radius: 18px 4px 18px 18px;
 
536
  }
537
+ .bot .message-content {
538
+ background: #1E1F20;
539
+ color: #e3e3e3;
540
+ border-radius: 4px 18px 18px 18px;
541
  }
542
+ .chat-input-container {
543
+ padding: 16px 20px;
544
+ background: #1E1F20;
545
+ border-top: 1px solid #2a2b2e;
 
 
 
 
 
 
546
  }
547
+ .input-wrapper {
548
  display: flex;
549
+ gap: 12px;
550
+ max-width: 800px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  margin: 0 auto;
 
 
 
 
 
 
 
 
 
 
 
552
  }
553
+ #messageInput {
554
  flex: 1;
555
+ padding: 12px 16px;
556
+ background: #2a2b2e;
557
  border: none;
558
+ border-radius: 24px;
559
  color: white;
560
+ font-size: 0.95rem;
561
+ outline: none;
 
562
  }
563
+ #messageInput::placeholder {
564
+ color: #888;
565
+ }
566
+ #sendBtn {
567
+ padding: 12px 24px;
568
+ background: #4a9eff;
569
  border: none;
570
+ border-radius: 24px;
571
+ color: white;
572
+ font-weight: 500;
 
 
 
573
  cursor: pointer;
574
+ transition: opacity 0.2s;
 
575
  }
576
+ #sendBtn:hover { opacity: 0.9; }
577
+ #sendBtn:disabled {
578
+ opacity: 0.5;
579
+ cursor: not-allowed;
 
 
580
  }
581
+ .typing {
582
+ display: flex;
583
+ gap: 4px;
584
+ padding: 10px 16px;
585
+ }
586
+ .typing span {
587
+ width: 8px;
588
+ height: 8px;
589
+ background: #888;
590
+ border-radius: 50%;
591
+ animation: bounce 1.4s infinite ease-in-out;
592
  }
593
+ .typing span:nth-child(1) { animation-delay: -0.32s; }
594
+ .typing span:nth-child(2) { animation-delay: -0.16s; }
595
+ @keyframes bounce {
596
+ 0%, 80%, 100% { transform: scale(0); }
597
+ 40% { transform: scale(1); }
598
  }
 
 
 
 
599
  </style>
600
  </head>
601
  <body>
602
+ <div class="chat-header">
603
+ <h1>🤖 MTP - Asistente IA</h1>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
  </div>
605
+ <div class="chat-messages" id="chatMessages">
606
+ <div class="message bot">
607
+ <div class="message-content">¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?</div>
 
608
  </div>
609
+ </div>
610
+ <div class="chat-input-container">
611
+ <div class="input-wrapper">
612
+ <input type="text" id="messageInput" placeholder="Escribe tu mensaje..." autocomplete="off">
613
+ <button id="sendBtn">Enviar</button>
614
  </div>
615
  </div>
616
  <script>
617
+ const chatMessages = document.getElementById('chatMessages');
618
+ const messageInput = document.getElementById('messageInput');
619
+ const sendBtn = document.getElementById('sendBtn');
620
+ let isLoading = false;
 
 
 
 
 
 
 
621
 
622
+ function addMessage(text, isUser) {
623
+ const div = document.createElement('div');
624
+ div.className = `message ${isUser ? 'user' : 'bot'}`;
625
+ div.innerHTML = `<div class="message-content">${text}</div>`;
626
+ chatMessages.appendChild(div);
627
+ chatMessages.scrollTop = chatMessages.scrollHeight;
628
+ return div;
 
 
629
  }
630
 
631
+ function addTypingIndicator() {
632
+ const div = document.createElement('div');
633
+ div.className = 'message bot';
634
+ div.id = 'typingIndicator';
635
+ div.innerHTML = `<div class="typing"><span></span><span></span><span></span></div>`;
636
+ chatMessages.appendChild(div);
637
+ chatMessages.scrollTop = chatMessages.scrollHeight;
638
  }
639
 
640
+ function removeTypingIndicator() {
641
+ const indicator = document.getElementById('typingIndicator');
642
+ if (indicator) indicator.remove();
 
 
 
 
 
 
643
  }
644
 
645
+ async function sendMessage() {
646
+ const text = messageInput.value.trim();
647
+ if (!text || isLoading) return;
648
+
649
+ messageInput.value = '';
650
+ addMessage(text, true);
651
+ isLoading = true;
652
+ sendBtn.disabled = true;
653
+ addTypingIndicator();
654
+
 
 
 
 
 
 
 
 
 
 
 
 
 
655
  try {
656
  const response = await fetch('/generate', {
657
  method: 'POST',
658
  headers: { 'Content-Type': 'application/json' },
659
+ body: JSON.stringify({ text: text })
 
660
  });
661
  const data = await response.json();
662
+ removeTypingIndicator();
663
+ addMessage(data.reply, false);
 
 
 
 
 
 
664
  } catch (error) {
665
+ removeTypingIndicator();
666
+ addMessage('Error de conexión. Intenta de nuevo.', false);
667
+ } finally {
668
+ isLoading = false;
669
+ sendBtn.disabled = false;
670
+ messageInput.focus();
 
 
671
  }
672
  }
673
 
674
+ messageInput.addEventListener('keypress', (e) => {
675
+ if (e.key === 'Enter') sendMessage();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676
  });
677
+ sendBtn.addEventListener('click', sendMessage);
678
+ messageInput.focus();
679
  </script>
680
  </body>
681
  </html>