Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import torch
|
|
| 4 |
import json
|
| 5 |
import time
|
| 6 |
import gc
|
|
|
|
| 7 |
from fastapi import FastAPI, Request
|
| 8 |
from fastapi.responses import HTMLResponse, StreamingResponse
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -33,6 +34,97 @@ torch.set_grad_enabled(False)
|
|
| 33 |
# CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
|
| 34 |
MODEL_REPO = "TeszenAI/MTP-3" # <-- CAMBIA A TU REPO
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# ======================
|
| 37 |
# DEFINIR ARQUITECTURA DEL MODELO (MTP)
|
| 38 |
# ======================
|
|
@@ -144,11 +236,14 @@ class MTPModel(nn.Module):
|
|
| 144 |
logits = self.lm_head(x)
|
| 145 |
return logits
|
| 146 |
|
| 147 |
-
def generate(self, input_ids, max_new_tokens=
|
| 148 |
-
"""Método de generación
|
| 149 |
generated = input_ids
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
-
for
|
| 152 |
with torch.no_grad():
|
| 153 |
logits = self(generated)
|
| 154 |
next_logits = logits[0, -1, :] / temperature
|
|
@@ -177,6 +272,13 @@ class MTPModel(nn.Module):
|
|
| 177 |
break
|
| 178 |
|
| 179 |
generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
return generated
|
| 182 |
|
|
@@ -350,6 +452,9 @@ async def generate(req: PromptRequest):
|
|
| 350 |
|
| 351 |
if "###" in response:
|
| 352 |
response = response.split("###")[0].strip()
|
|
|
|
|
|
|
|
|
|
| 353 |
|
| 354 |
return {
|
| 355 |
"reply": response,
|
|
@@ -394,7 +499,7 @@ def model_info():
|
|
| 394 |
}
|
| 395 |
|
| 396 |
# ======================
|
| 397 |
-
# INTERFAZ WEB (MODERNA)
|
| 398 |
# ======================
|
| 399 |
@app.get("/", response_class=HTMLResponse)
|
| 400 |
def chat_ui():
|
|
@@ -451,12 +556,11 @@ header {
|
|
| 451 |
width: 32px;
|
| 452 |
height: 32px;
|
| 453 |
border-radius: 50%;
|
| 454 |
-
background:
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
font-size: 14px;
|
| 460 |
}
|
| 461 |
.brand-text {
|
| 462 |
font-weight: 500;
|
|
@@ -523,12 +627,10 @@ header {
|
|
| 523 |
height: 34px;
|
| 524 |
min-width: 34px;
|
| 525 |
border-radius: 50%;
|
| 526 |
-
background:
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
font-weight: bold;
|
| 531 |
-
font-size: 14px;
|
| 532 |
box-shadow: 0 2px 6px rgba(0,0,0,0.2);
|
| 533 |
}
|
| 534 |
.bot-actions {
|
|
@@ -630,7 +732,7 @@ header {
|
|
| 630 |
<body>
|
| 631 |
<header>
|
| 632 |
<div class="brand-wrapper" onclick="location.reload()">
|
| 633 |
-
<div class="brand-logo">
|
| 634 |
<div class="brand-text">
|
| 635 |
MTP <span class="version-badge">v1</span>
|
| 636 |
</div>
|
|
@@ -638,7 +740,7 @@ header {
|
|
| 638 |
</header>
|
| 639 |
<div id="chatScroll" class="chat-scroll">
|
| 640 |
<div class="msg-row bot" style="animation-delay: 0.1s;">
|
| 641 |
-
<div class="bot-avatar">
|
| 642 |
<div class="msg-content-wrapper">
|
| 643 |
<div class="msg-text">
|
| 644 |
¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?
|
|
@@ -712,7 +814,6 @@ async function sendMessage(textOverride = null) {
|
|
| 712 |
botRow.className = 'msg-row bot';
|
| 713 |
const avatar = document.createElement('div');
|
| 714 |
avatar.className = 'bot-avatar pulsing';
|
| 715 |
-
avatar.textContent = 'M';
|
| 716 |
const wrapper = document.createElement('div');
|
| 717 |
wrapper.className = 'msg-content-wrapper';
|
| 718 |
const msgText = document.createElement('div');
|
|
|
|
| 4 |
import json
|
| 5 |
import time
|
| 6 |
import gc
|
| 7 |
+
import re
|
| 8 |
from fastapi import FastAPI, Request
|
| 9 |
from fastapi.responses import HTMLResponse, StreamingResponse
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 34 |
# CAMBIA ESTO POR EL NOMBRE DE TU REPO EN HUGGING FACE
|
| 35 |
MODEL_REPO = "TeszenAI/MTP-3" # <-- CAMBIA A TU REPO
|
| 36 |
|
| 37 |
+
# ======================
|
| 38 |
+
# FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
|
| 39 |
+
# ======================
|
| 40 |
+
|
| 41 |
+
def clean_response(text: str) -> str:
|
| 42 |
+
"""
|
| 43 |
+
Limpia la respuesta eliminando repeticiones, frases sin sentido y
|
| 44 |
+
asegurando que termine correctamente.
|
| 45 |
+
"""
|
| 46 |
+
if not text:
|
| 47 |
+
return ""
|
| 48 |
+
|
| 49 |
+
# 1. Eliminar repeticiones excesivas de palabras o frases cortas
|
| 50 |
+
words = text.split()
|
| 51 |
+
cleaned_words = []
|
| 52 |
+
last_phrase = ""
|
| 53 |
+
repeat_count = 0
|
| 54 |
+
|
| 55 |
+
for word in words:
|
| 56 |
+
if word == last_phrase:
|
| 57 |
+
repeat_count += 1
|
| 58 |
+
if repeat_count > 2: # Si repite más de 2 veces seguidas
|
| 59 |
+
continue
|
| 60 |
+
else:
|
| 61 |
+
last_phrase = word
|
| 62 |
+
repeat_count = 0
|
| 63 |
+
cleaned_words.append(word)
|
| 64 |
+
|
| 65 |
+
text = " ".join(cleaned_words)
|
| 66 |
+
|
| 67 |
+
# 2. Eliminar patrones sin sentido (repeticiones de letras, caracteres raros)
|
| 68 |
+
text = re.sub(r'(.)\1{4,}', r'\1\1', text) # aaa... -> aa
|
| 69 |
+
text = re.sub(r'[^a-zA-ZáéíóúñüÁÉÍÓÚÑÜ0-9\s.,;:!?¿¡()\-"]+', '', text)
|
| 70 |
+
|
| 71 |
+
# 3. Cortar en la primera frase que parezca final coherente
|
| 72 |
+
stop_patterns = [
|
| 73 |
+
r'(\.\s*)$', # Punto final
|
| 74 |
+
r'[.!?](\s+)?$', # Fin de oración
|
| 75 |
+
r'(gracias|hasta luego|adiós|saludos|fin|fin del mensaje)$',
|
| 76 |
+
r'(¿algo más\?|¿necesitas algo más\?|¿en qué más puedo ayudarte\?)'
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
for pattern in stop_patterns:
|
| 80 |
+
match = re.search(pattern, text, re.IGNORECASE)
|
| 81 |
+
if match:
|
| 82 |
+
# Cortar justo después del patrón de finalización
|
| 83 |
+
end_pos = match.end()
|
| 84 |
+
text = text[:end_pos]
|
| 85 |
+
break
|
| 86 |
+
|
| 87 |
+
# 4. Si la respuesta es muy corta o vacía, devolver mensaje por defecto
|
| 88 |
+
if len(text.strip()) < 10:
|
| 89 |
+
return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
|
| 90 |
+
|
| 91 |
+
# 5. Eliminar espacios múltiples y saltos de línea excesivos
|
| 92 |
+
text = re.sub(r'\s+', ' ', text).strip()
|
| 93 |
+
|
| 94 |
+
return text
|
| 95 |
+
|
| 96 |
+
def should_stop_generation(generated_text: str, min_length: int = 30, max_length: int = 300) -> bool:
|
| 97 |
+
"""
|
| 98 |
+
Determina si debemos detener la generación basado en el texto generado.
|
| 99 |
+
"""
|
| 100 |
+
# Si ya superamos la longitud máxima
|
| 101 |
+
if len(generated_text) > max_length:
|
| 102 |
+
return True
|
| 103 |
+
|
| 104 |
+
# Si es muy corto y no hay puntuación final
|
| 105 |
+
if len(generated_text) < min_length and not re.search(r'[.!?]$', generated_text):
|
| 106 |
+
return False
|
| 107 |
+
|
| 108 |
+
# Señales de que ya terminó la respuesta
|
| 109 |
+
stop_signals = [
|
| 110 |
+
r'(gracias por tu pregunta|espero haberte ayudado|¿necesitas algo más\?)',
|
| 111 |
+
r'(hasta luego|adiós|quedo atento|saludos cordiales)',
|
| 112 |
+
r'(fin del mensaje|fin de la conversación)'
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
for signal in stop_signals:
|
| 116 |
+
if re.search(signal, generated_text, re.IGNORECASE):
|
| 117 |
+
return True
|
| 118 |
+
|
| 119 |
+
# Si la última frase parece completa
|
| 120 |
+
last_sentence = generated_text.split('.')[-1].strip()
|
| 121 |
+
if len(last_sentence) > 5 and re.search(r'[.!?]$', last_sentence):
|
| 122 |
+
# Y ya hemos generado suficiente contenido
|
| 123 |
+
if len(generated_text) > min_length:
|
| 124 |
+
return True
|
| 125 |
+
|
| 126 |
+
return False
|
| 127 |
+
|
| 128 |
# ======================
|
| 129 |
# DEFINIR ARQUITECTURA DEL MODELO (MTP)
|
| 130 |
# ======================
|
|
|
|
| 236 |
logits = self.lm_head(x)
|
| 237 |
return logits
|
| 238 |
|
| 239 |
+
def generate(self, input_ids, max_new_tokens=150, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
|
| 240 |
+
"""Método de generación mejorado con detección inteligente de fin"""
|
| 241 |
generated = input_ids
|
| 242 |
+
generated_text = ""
|
| 243 |
+
min_response_length = 30
|
| 244 |
+
max_response_length = max_new_tokens * 2
|
| 245 |
|
| 246 |
+
for step in range(max_new_tokens):
|
| 247 |
with torch.no_grad():
|
| 248 |
logits = self(generated)
|
| 249 |
next_logits = logits[0, -1, :] / temperature
|
|
|
|
| 272 |
break
|
| 273 |
|
| 274 |
generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
|
| 275 |
+
|
| 276 |
+
# Decodificar parcialmente para verificar si debemos parar (solo cada 10 pasos para eficiencia)
|
| 277 |
+
if step > 10 and step % 10 == 0:
|
| 278 |
+
# Intentar decodificar tokens generados (esto es aproximado, el tokenizador real está fuera)
|
| 279 |
+
if len(generated[0]) > 10:
|
| 280 |
+
if should_stop_generation(str(generated[0].tolist()), min_response_length, max_response_length):
|
| 281 |
+
break
|
| 282 |
|
| 283 |
return generated
|
| 284 |
|
|
|
|
| 452 |
|
| 453 |
if "###" in response:
|
| 454 |
response = response.split("###")[0].strip()
|
| 455 |
+
|
| 456 |
+
# Aplicar limpieza inteligente a la respuesta
|
| 457 |
+
response = clean_response(response)
|
| 458 |
|
| 459 |
return {
|
| 460 |
"reply": response,
|
|
|
|
| 499 |
}
|
| 500 |
|
| 501 |
# ======================
|
| 502 |
+
# INTERFAZ WEB (MODERNA CON LOGO INTEGRADO)
|
| 503 |
# ======================
|
| 504 |
@app.get("/", response_class=HTMLResponse)
|
| 505 |
def chat_ui():
|
|
|
|
| 556 |
width: 32px;
|
| 557 |
height: 32px;
|
| 558 |
border-radius: 50%;
|
| 559 |
+
background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
|
| 560 |
+
background-size: cover;
|
| 561 |
+
background-position: center;
|
| 562 |
+
background-repeat: no-repeat;
|
| 563 |
+
border: 1px solid rgba(255,255,255,0.1);
|
|
|
|
| 564 |
}
|
| 565 |
.brand-text {
|
| 566 |
font-weight: 500;
|
|
|
|
| 627 |
height: 34px;
|
| 628 |
min-width: 34px;
|
| 629 |
border-radius: 50%;
|
| 630 |
+
background-image: url('https://i.postimg.cc/c4BRhSnR/8F838209-6DD9-4E1C-96BB-621EC3B78E68.png');
|
| 631 |
+
background-size: cover;
|
| 632 |
+
background-position: center;
|
| 633 |
+
background-repeat: no-repeat;
|
|
|
|
|
|
|
| 634 |
box-shadow: 0 2px 6px rgba(0,0,0,0.2);
|
| 635 |
}
|
| 636 |
.bot-actions {
|
|
|
|
| 732 |
<body>
|
| 733 |
<header>
|
| 734 |
<div class="brand-wrapper" onclick="location.reload()">
|
| 735 |
+
<div class="brand-logo"></div>
|
| 736 |
<div class="brand-text">
|
| 737 |
MTP <span class="version-badge">v1</span>
|
| 738 |
</div>
|
|
|
|
| 740 |
</header>
|
| 741 |
<div id="chatScroll" class="chat-scroll">
|
| 742 |
<div class="msg-row bot" style="animation-delay: 0.1s;">
|
| 743 |
+
<div class="bot-avatar"></div>
|
| 744 |
<div class="msg-content-wrapper">
|
| 745 |
<div class="msg-text">
|
| 746 |
¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte hoy?
|
|
|
|
| 814 |
botRow.className = 'msg-row bot';
|
| 815 |
const avatar = document.createElement('div');
|
| 816 |
avatar.className = 'bot-avatar pulsing';
|
|
|
|
| 817 |
const wrapper = document.createElement('div');
|
| 818 |
wrapper.className = 'msg-content-wrapper';
|
| 819 |
const msgText = document.createElement('div');
|