teszenofficial commited on
Commit
8cd8259
·
verified ·
1 Parent(s): 470d2fa

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -19
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import json
5
  import time
6
  import gc
 
7
  from fastapi import FastAPI, Request
8
  from fastapi.responses import HTMLResponse, StreamingResponse
9
  from fastapi.middleware.cors import CORSMiddleware
@@ -33,6 +34,97 @@ torch.set_grad_enabled(False)
33
  # CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
34
  MODEL_REPO = "TeszenAI/MTP-3" # <-- CAMBIA A TU REPO
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # ======================
37
  # DEFINIR ARQUITECTURA DEL MODELO (MTP)
38
  # ======================
@@ -144,11 +236,14 @@ class MTPModel(nn.Module):
144
  logits = self.lm_head(x)
145
  return logits
146
 
147
- def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
148
- """Método de generación compatible con la interfaz"""
149
  generated = input_ids
 
 
 
150
 
151
- for _ in range(max_new_tokens):
152
  with torch.no_grad():
153
  logits = self(generated)
154
  next_logits = logits[0, -1, :] / temperature
@@ -177,6 +272,13 @@ class MTPModel(nn.Module):
177
  break
178
 
179
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
 
 
 
 
 
 
 
180
 
181
  return generated
182
 
@@ -350,6 +452,9 @@ async def generate(req: PromptRequest):
350
 
351
  if "###" in response:
352
  response = response.split("###")[0].strip()
 
 
 
353
 
354
  return {
355
  "reply": response,
@@ -394,7 +499,7 @@ def model_info():
394
  }
395
 
396
  # ======================
397
- # INTERFAZ WEB (MODERNA)
398
  # ======================
399
  @app.get("/", response_class=HTMLResponse)
400
  def chat_ui():
@@ -451,12 +556,11 @@ header {
451
  width: 32px;
452
  height: 32px;
453
  border-radius: 50%;
454
- background: linear-gradient(135deg, #4a9eff, #7c3aed);
455
- display: flex;
456
- align-items: center;
457
- justify-content: center;
458
- font-weight: bold;
459
- font-size: 14px;
460
  }
461
  .brand-text {
462
  font-weight: 500;
@@ -523,12 +627,10 @@ header {
523
  height: 34px;
524
  min-width: 34px;
525
  border-radius: 50%;
526
- background: linear-gradient(135deg, #4a9eff, #7c3aed);
527
- display: flex;
528
- align-items: center;
529
- justify-content: center;
530
- font-weight: bold;
531
- font-size: 14px;
532
  box-shadow: 0 2px 6px rgba(0,0,0,0.2);
533
  }
534
  .bot-actions {
@@ -630,7 +732,7 @@ header {
630
  <body>
631
  <header>
632
  <div class="brand-wrapper" onclick="location.reload()">
633
- <div class="brand-logo">MTP</div>
634
  <div class="brand-text">
635
  MTP <span class="version-badge">v1</span>
636
  </div>
@@ -638,7 +740,7 @@ header {
638
  </header>
639
  <div id="chatScroll" class="chat-scroll">
640
  <div class="msg-row bot" style="animation-delay: 0.1s;">
641
- <div class="bot-avatar">M</div>
642
  <div class="msg-content-wrapper">
643
  <div class="msg-text">
644
  ¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?
@@ -712,7 +814,6 @@ async function sendMessage(textOverride = null) {
712
  botRow.className = 'msg-row bot';
713
  const avatar = document.createElement('div');
714
  avatar.className = 'bot-avatar pulsing';
715
- avatar.textContent = 'M';
716
  const wrapper = document.createElement('div');
717
  wrapper.className = 'msg-content-wrapper';
718
  const msgText = document.createElement('div');
 
4
  import json
5
  import time
6
  import gc
7
+ import re
8
  from fastapi import FastAPI, Request
9
  from fastapi.responses import HTMLResponse, StreamingResponse
10
  from fastapi.middleware.cors import CORSMiddleware
 
34
  # CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
35
  MODEL_REPO = "TeszenAI/MTP-3" # <-- CAMBIA A TU REPO
36
 
37
+ # ======================
38
+ # FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
39
+ # ======================
40
+
41
+ def clean_response(text: str) -> str:
42
+ """
43
+ Limpia la respuesta eliminando repeticiones, frases sin sentido y
44
+ asegurando que termine correctamente.
45
+ """
46
+ if not text:
47
+ return ""
48
+
49
+ # 1. Eliminar repeticiones excesivas de palabras o frases cortas
50
+ words = text.split()
51
+ cleaned_words = []
52
+ last_phrase = ""
53
+ repeat_count = 0
54
+
55
+ for word in words:
56
+ if word == last_phrase:
57
+ repeat_count += 1
58
+ if repeat_count > 2: # Si repite más de 2 veces seguidas
59
+ continue
60
+ else:
61
+ last_phrase = word
62
+ repeat_count = 0
63
+ cleaned_words.append(word)
64
+
65
+ text = " ".join(cleaned_words)
66
+
67
+ # 2. Eliminar patrones sin sentido (repeticiones de letras, caracteres raros)
68
+ text = re.sub(r'(.)\1{4,}', r'\1\1', text) # aaa... -> aa
69
+ text = re.sub(r'[^a-zA-ZáéíóúñüÁÉÍÓÚÑÜ0-9\s.,;:!?¿¡()\-"]+', '', text)
70
+
71
+ # 3. Cortar en la primera frase que parezca final coherente
72
+ stop_patterns = [
73
+ r'(\.\s*)$', # Punto final
74
+ r'[.!?](\s+)?$', # Fin de oración
75
+ r'(gracias|hasta luego|adiós|saludos|fin|fin del mensaje)$',
76
+ r'(¿algo más\?|¿necesitas algo más\?|¿en qué más puedo ayudarte\?)'
77
+ ]
78
+
79
+ for pattern in stop_patterns:
80
+ match = re.search(pattern, text, re.IGNORECASE)
81
+ if match:
82
+ # Cortar justo después del patrón de finalización
83
+ end_pos = match.end()
84
+ text = text[:end_pos]
85
+ break
86
+
87
+ # 4. Si la respuesta es muy corta o vacía, devolver mensaje por defecto
88
+ if len(text.strip()) < 10:
89
+ return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
90
+
91
+ # 5. Eliminar espacios múltiples y saltos de línea excesivos
92
+ text = re.sub(r'\s+', ' ', text).strip()
93
+
94
+ return text
95
+
96
+ def should_stop_generation(generated_text: str, min_length: int = 30, max_length: int = 300) -> bool:
97
+ """
98
+ Determina si debemos detener la generación basado en el texto generado.
99
+ """
100
+ # Si ya superamos la longitud máxima
101
+ if len(generated_text) > max_length:
102
+ return True
103
+
104
+ # Si es muy corto y no hay puntuación final
105
+ if len(generated_text) < min_length and not re.search(r'[.!?]$', generated_text):
106
+ return False
107
+
108
+ # Señales de que ya terminó la respuesta
109
+ stop_signals = [
110
+ r'(gracias por tu pregunta|espero haberte ayudado|¿necesitas algo más\?)',
111
+ r'(hasta luego|adiós|quedo atento|saludos cordiales)',
112
+ r'(fin del mensaje|fin de la conversación)'
113
+ ]
114
+
115
+ for signal in stop_signals:
116
+ if re.search(signal, generated_text, re.IGNORECASE):
117
+ return True
118
+
119
+ # Si la última frase parece completa
120
+ last_sentence = generated_text.split('.')[-1].strip()
121
+ if len(last_sentence) > 5 and re.search(r'[.!?]$', last_sentence):
122
+ # Y ya hemos generado suficiente contenido
123
+ if len(generated_text) > min_length:
124
+ return True
125
+
126
+ return False
127
+
128
  # ======================
129
  # DEFINIR ARQUITECTURA DEL MODELO (MTP)
130
  # ======================
 
236
  logits = self.lm_head(x)
237
  return logits
238
 
239
+ def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
240
+ """Método de generación mejorado con detección inteligente de fin"""
241
  generated = input_ids
242
+ generated_text = ""
243
+ min_response_length = 30
244
+ max_response_length = max_new_tokens * 2
245
 
246
+ for step in range(max_new_tokens):
247
  with torch.no_grad():
248
  logits = self(generated)
249
  next_logits = logits[0, -1, :] / temperature
 
272
  break
273
 
274
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
275
+
276
+ # Decodificar parcialmente para verificar si debemos parar (solo cada 10 pasos para eficiencia)
277
+ if step > 10 and step % 10 == 0:
278
+ # Intentar decodificar tokens generados (esto es aproximado, el tokenizador real está fuera)
279
+ if len(generated[0]) > 10:
280
+ if should_stop_generation(str(generated[0].tolist()), min_response_length, max_response_length):
281
+ break
282
 
283
  return generated
284
 
 
452
 
453
  if "###" in response:
454
  response = response.split("###")[0].strip()
455
+
456
+ # Aplicar limpieza inteligente a la respuesta
457
+ response = clean_response(response)
458
 
459
  return {
460
  "reply": response,
 
499
  }
500
 
501
  # ======================
502
+ # INTERFAZ WEB (MODERNA CON LOGO INTEGRADO)
503
  # ======================
504
  @app.get("/", response_class=HTMLResponse)
505
  def chat_ui():
 
556
  width: 32px;
557
  height: 32px;
558
  border-radius: 50%;
559
+ background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
560
+ background-size: cover;
561
+ background-position: center;
562
+ background-repeat: no-repeat;
563
+ border: 1px solid rgba(255,255,255,0.1);
 
564
  }
565
  .brand-text {
566
  font-weight: 500;
 
627
  height: 34px;
628
  min-width: 34px;
629
  border-radius: 50%;
630
+ background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
631
+ background-size: cover;
632
+ background-position: center;
633
+ background-repeat: no-repeat;
 
 
634
  box-shadow: 0 2px 6px rgba(0,0,0,0.2);
635
  }
636
  .bot-actions {
 
732
  <body>
733
  <header>
734
  <div class="brand-wrapper" onclick="location.reload()">
735
+ <div class="brand-logo"></div>
736
  <div class="brand-text">
737
  MTP <span class="version-badge">v1</span>
738
  </div>
 
740
  </header>
741
  <div id="chatScroll" class="chat-scroll">
742
  <div class="msg-row bot" style="animation-delay: 0.1s;">
743
+ <div class="bot-avatar"></div>
744
  <div class="msg-content-wrapper">
745
  <div class="msg-text">
746
  ¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?
 
814
  botRow.className = 'msg-row bot';
815
  const avatar = document.createElement('div');
816
  avatar.className = 'bot-avatar pulsing';
 
817
  const wrapper = document.createElement('div');
818
  wrapper.className = 'msg-content-wrapper';
819
  const msgText = document.createElement('div');