Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,84 +35,147 @@ torch.set_grad_enabled(False)
|
|
| 35 |
MODEL_REPO = "TeszenAI/MTP-3.1.1"
|
| 36 |
|
| 37 |
# ======================
|
| 38 |
-
# FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD
|
| 39 |
# ======================
|
| 40 |
|
| 41 |
-
def
|
| 42 |
-
"""
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
return text
|
| 48 |
|
| 49 |
-
#
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
return truncated
|
| 57 |
|
| 58 |
-
|
| 59 |
-
if len(text) > 80:
|
| 60 |
-
return text[:80] + "..."
|
| 61 |
-
return text
|
| 62 |
|
| 63 |
def clean_response(text: str, user_input: str = "") -> str:
|
| 64 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 65 |
if not text:
|
| 66 |
return ""
|
| 67 |
|
| 68 |
-
# Eliminar repeticiones excesivas
|
| 69 |
words = text.split()
|
| 70 |
cleaned_words = []
|
| 71 |
last_word = ""
|
| 72 |
repeat_count = 0
|
|
|
|
| 73 |
|
| 74 |
for word in words:
|
|
|
|
| 75 |
if word == last_word:
|
| 76 |
repeat_count += 1
|
| 77 |
-
if repeat_count >
|
| 78 |
continue
|
| 79 |
else:
|
| 80 |
last_word = word
|
| 81 |
repeat_count = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
cleaned_words.append(word)
|
| 83 |
|
| 84 |
text = " ".join(cleaned_words)
|
| 85 |
|
| 86 |
-
# Eliminar caracteres raros
|
| 87 |
-
text = re.sub(r'(.)\1{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
if
|
| 96 |
-
|
|
|
|
|
|
|
| 97 |
else:
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
text = first_sentence
|
| 102 |
-
elif len(text) > 60:
|
| 103 |
-
text = text[:60]
|
| 104 |
-
|
| 105 |
-
# Si la respuesta es muy corta o vacía
|
| 106 |
-
if len(text.strip()) < 5:
|
| 107 |
-
if is_greeting:
|
| 108 |
-
return "¡Hola! ¿En qué puedo ayudarte?"
|
| 109 |
-
return "Lo siento, no pude generar una respuesta clara. ¿Podrías reformular tu pregunta?"
|
| 110 |
-
|
| 111 |
-
# Eliminar espacios múltiples
|
| 112 |
text = re.sub(r'\s+', ' ', text).strip()
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
return text
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
# ======================
|
| 117 |
# DEFINIR ARQUITECTURA DEL MODELO (MTP)
|
| 118 |
# ======================
|
|
@@ -224,8 +287,12 @@ class MTPModel(nn.Module):
|
|
| 224 |
logits = self.lm_head(x)
|
| 225 |
return logits
|
| 226 |
|
| 227 |
-
def generate(self, input_ids, max_new_tokens=150, temperature=0.
|
| 228 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
generated = input_ids
|
| 230 |
|
| 231 |
for step in range(max_new_tokens):
|
|
@@ -233,14 +300,17 @@ class MTPModel(nn.Module):
|
|
| 233 |
logits = self(generated)
|
| 234 |
next_logits = logits[0, -1, :] / temperature
|
| 235 |
|
|
|
|
| 236 |
if repetition_penalty != 1.0:
|
| 237 |
for token_id in set(generated[0].tolist()):
|
| 238 |
next_logits[token_id] /= repetition_penalty
|
| 239 |
|
|
|
|
| 240 |
if top_k > 0:
|
| 241 |
indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
|
| 242 |
next_logits[indices_to_remove] = float('-inf')
|
| 243 |
|
|
|
|
| 244 |
if top_p < 1.0:
|
| 245 |
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
|
| 246 |
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
|
@@ -328,8 +398,8 @@ print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)
|
|
| 328 |
# ======================
|
| 329 |
app = FastAPI(
|
| 330 |
title="MTP API",
|
| 331 |
-
description="API para modelo de lenguaje MTP",
|
| 332 |
-
version="1.
|
| 333 |
)
|
| 334 |
|
| 335 |
app.add_middleware(
|
|
@@ -341,15 +411,11 @@ app.add_middleware(
|
|
| 341 |
|
| 342 |
class PromptRequest(BaseModel):
|
| 343 |
text: str = Field(..., max_length=2000, description="Texto de entrada")
|
| 344 |
-
max_tokens: int = Field(default=
|
| 345 |
-
temperature: float = Field(default=0.
|
| 346 |
-
top_k: int = Field(default=
|
| 347 |
-
top_p: float = Field(default=0.
|
| 348 |
-
repetition_penalty: float = Field(default=1.
|
| 349 |
-
|
| 350 |
-
def build_prompt(user_input: str) -> str:
|
| 351 |
-
"""Construye el prompt en el formato del modelo"""
|
| 352 |
-
return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
|
| 353 |
|
| 354 |
# ======================
|
| 355 |
# GESTIÓN DE CARGA
|
|
@@ -380,7 +446,7 @@ tokenizer_wrapper = MTPTokenizer(sp)
|
|
| 380 |
|
| 381 |
@app.post("/generate")
|
| 382 |
async def generate(req: PromptRequest):
|
| 383 |
-
"""Endpoint principal de generación de texto"""
|
| 384 |
global ACTIVE_REQUESTS
|
| 385 |
ACTIVE_REQUESTS += 1
|
| 386 |
|
|
@@ -389,15 +455,41 @@ async def generate(req: PromptRequest):
|
|
| 389 |
ACTIVE_REQUESTS -= 1
|
| 390 |
return {"reply": "", "tokens_generated": 0}
|
| 391 |
|
| 392 |
-
# Detectar
|
| 393 |
-
|
| 394 |
|
| 395 |
-
#
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
tokens = tokenizer_wrapper.encode(full_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
input_ids = torch.tensor([tokens], device=DEVICE)
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
try:
|
| 403 |
with torch.no_grad():
|
|
@@ -420,30 +512,28 @@ async def generate(req: PromptRequest):
|
|
| 420 |
else:
|
| 421 |
response = ""
|
| 422 |
|
| 423 |
-
# Limpiar respuesta
|
| 424 |
response = clean_response(response, user_input)
|
| 425 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
# Si la respuesta sigue vacía o es muy corta, usar respuesta por defecto
|
| 427 |
-
if len(response) <
|
| 428 |
-
|
| 429 |
-
response = "¡Hola! ¿En qué puedo ayudarte?"
|
| 430 |
-
else:
|
| 431 |
-
response = "Lo siento, no pude generar una respuesta. ¿Podrías reformular tu pregunta?"
|
| 432 |
|
| 433 |
return {
|
| 434 |
"reply": response,
|
| 435 |
"tokens_generated": len(safe_tokens),
|
| 436 |
-
"model": "MTP"
|
|
|
|
| 437 |
}
|
| 438 |
|
| 439 |
except Exception as e:
|
| 440 |
print(f"❌ Error durante generación: {e}")
|
| 441 |
-
if is_greeting:
|
| 442 |
-
fallback = "¡Hola! ¿En qué puedo ayudarte?"
|
| 443 |
-
else:
|
| 444 |
-
fallback = "Lo siento, ocurrió un error al procesar tu solicitud."
|
| 445 |
return {
|
| 446 |
-
"reply":
|
| 447 |
"error": str(e)
|
| 448 |
}
|
| 449 |
|
|
@@ -463,21 +553,23 @@ def health_check():
|
|
| 463 |
"model": "MTP",
|
| 464 |
"device": DEVICE,
|
| 465 |
"active_requests": ACTIVE_REQUESTS,
|
| 466 |
-
"vocab_size": VOCAB_SIZE
|
|
|
|
| 467 |
}
|
| 468 |
|
| 469 |
@app.get("/info")
|
| 470 |
def model_info():
|
| 471 |
return {
|
| 472 |
"model_name": "MTP",
|
| 473 |
-
"version": "1.
|
| 474 |
"architecture": config,
|
| 475 |
"parameters": sum(p.numel() for p in model.parameters()),
|
| 476 |
-
"device": DEVICE
|
|
|
|
| 477 |
}
|
| 478 |
|
| 479 |
# ======================
|
| 480 |
-
# INTERFAZ WEB
|
| 481 |
# ======================
|
| 482 |
@app.get("/", response_class=HTMLResponse)
|
| 483 |
def chat_ui():
|
|
@@ -487,7 +579,7 @@ def chat_ui():
|
|
| 487 |
<head>
|
| 488 |
<meta charset="UTF-8">
|
| 489 |
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 490 |
-
<title>MTP - Asistente IA</title>
|
| 491 |
<style>
|
| 492 |
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 493 |
body {
|
|
@@ -507,6 +599,11 @@ body {
|
|
| 507 |
font-size: 1.2rem;
|
| 508 |
font-weight: 500;
|
| 509 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
.chat-messages {
|
| 511 |
flex: 1;
|
| 512 |
overflow-y: auto;
|
|
@@ -597,15 +694,37 @@ body {
|
|
| 597 |
0%, 80%, 100% { transform: scale(0); }
|
| 598 |
40% { transform: scale(1); }
|
| 599 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
</style>
|
| 601 |
</head>
|
| 602 |
<body>
|
| 603 |
<div class="chat-header">
|
| 604 |
<h1>🤖 MTP - Asistente IA</h1>
|
|
|
|
| 605 |
</div>
|
| 606 |
<div class="chat-messages" id="chatMessages">
|
| 607 |
<div class="message bot">
|
| 608 |
-
<div class="message-content">¡Hola! Soy MTP, tu asistente de IA. ¿En qué puedo ayudarte
|
| 609 |
</div>
|
| 610 |
</div>
|
| 611 |
<div class="chat-input-container">
|
|
@@ -613,6 +732,12 @@ body {
|
|
| 613 |
<input type="text" id="messageInput" placeholder="Escribe tu mensaje..." autocomplete="off">
|
| 614 |
<button id="sendBtn">Enviar</button>
|
| 615 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
</div>
|
| 617 |
<script>
|
| 618 |
const chatMessages = document.getElementById('chatMessages');
|
|
@@ -623,12 +748,18 @@ let isLoading = false;
|
|
| 623 |
function addMessage(text, isUser) {
|
| 624 |
const div = document.createElement('div');
|
| 625 |
div.className = `message ${isUser ? 'user' : 'bot'}`;
|
| 626 |
-
div.innerHTML = `<div class="message-content">${text}</div>`;
|
| 627 |
chatMessages.appendChild(div);
|
| 628 |
chatMessages.scrollTop = chatMessages.scrollHeight;
|
| 629 |
return div;
|
| 630 |
}
|
| 631 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
function addTypingIndicator() {
|
| 633 |
const div = document.createElement('div');
|
| 634 |
div.className = 'message bot';
|
|
@@ -643,12 +774,12 @@ function removeTypingIndicator() {
|
|
| 643 |
if (indicator) indicator.remove();
|
| 644 |
}
|
| 645 |
|
| 646 |
-
async function sendMessage() {
|
| 647 |
-
const
|
| 648 |
-
if (!
|
| 649 |
|
| 650 |
-
messageInput.value = '';
|
| 651 |
-
addMessage(
|
| 652 |
isLoading = true;
|
| 653 |
sendBtn.disabled = true;
|
| 654 |
addTypingIndicator();
|
|
@@ -657,14 +788,14 @@ async function sendMessage() {
|
|
| 657 |
const response = await fetch('/generate', {
|
| 658 |
method: 'POST',
|
| 659 |
headers: { 'Content-Type': 'application/json' },
|
| 660 |
-
body: JSON.stringify({ text:
|
| 661 |
});
|
| 662 |
const data = await response.json();
|
| 663 |
removeTypingIndicator();
|
| 664 |
addMessage(data.reply, false);
|
| 665 |
} catch (error) {
|
| 666 |
removeTypingIndicator();
|
| 667 |
-
addMessage('Error de conexión.
|
| 668 |
} finally {
|
| 669 |
isLoading = false;
|
| 670 |
sendBtn.disabled = false;
|
|
@@ -675,7 +806,15 @@ async function sendMessage() {
|
|
| 675 |
messageInput.addEventListener('keypress', (e) => {
|
| 676 |
if (e.key === 'Enter') sendMessage();
|
| 677 |
});
|
| 678 |
-
sendBtn.addEventListener('click', sendMessage);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
messageInput.focus();
|
| 680 |
</script>
|
| 681 |
</body>
|
|
@@ -687,6 +826,11 @@ if __name__ == "__main__":
|
|
| 687 |
print(f"\n🚀 Iniciando servidor MTP en puerto {port}...")
|
| 688 |
print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
|
| 689 |
print(f"📡 API docs: http://0.0.0.0:{port}/docs")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
|
| 691 |
uvicorn.run(
|
| 692 |
app,
|
|
|
|
| 35 |
MODEL_REPO = "TeszenAI/MTP-3.1.1"
|
| 36 |
|
| 37 |
# ======================
|
| 38 |
+
# FUNCIONES DE LIMPIEZA Y CONTROL DE CALIDAD (MEJORADAS)
|
| 39 |
# ======================
|
| 40 |
|
| 41 |
+
def is_greeting(user_input: str) -> bool:
|
| 42 |
+
"""Detecta si el mensaje es un saludo simple"""
|
| 43 |
+
greetings = ["hola", "hola!", "hola.", "buenas", "saludos", "hola?", "buen día", "buenas tardes", "buenas noches"]
|
| 44 |
+
return user_input.lower().strip() in greetings
|
| 45 |
+
|
| 46 |
+
def truncate_response(text: str, max_length: int = 300) -> str:
|
| 47 |
+
"""Trunca respuesta de forma limpia en oraciones completas"""
|
| 48 |
+
if not text or len(text) <= max_length:
|
| 49 |
return text
|
| 50 |
|
| 51 |
+
# Intentar truncar en el último punto dentro del límite
|
| 52 |
+
last_period = text[:max_length].rfind('.')
|
| 53 |
+
if last_period > max_length // 2:
|
| 54 |
+
return text[:last_period + 1].strip()
|
| 55 |
|
| 56 |
+
# Si no hay punto, truncar en espacio
|
| 57 |
+
last_space = text[:max_length].rfind(' ')
|
| 58 |
+
if last_space > max_length // 2:
|
| 59 |
+
return text[:last_space].strip() + "..."
|
|
|
|
| 60 |
|
| 61 |
+
return text[:max_length].strip() + "..."
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def clean_response(text: str, user_input: str = "") -> str:
|
| 64 |
+
"""
|
| 65 |
+
Limpia la respuesta del modelo de forma más agresiva
|
| 66 |
+
Elimina alucinaciones, repeticiones y caracteres raros
|
| 67 |
+
"""
|
| 68 |
if not text:
|
| 69 |
return ""
|
| 70 |
|
| 71 |
+
# Eliminar repeticiones excesivas de palabras (más agresivo)
|
| 72 |
words = text.split()
|
| 73 |
cleaned_words = []
|
| 74 |
last_word = ""
|
| 75 |
repeat_count = 0
|
| 76 |
+
last_two_words = []
|
| 77 |
|
| 78 |
for word in words:
|
| 79 |
+
# Detectar repeticiones inmediatas
|
| 80 |
if word == last_word:
|
| 81 |
repeat_count += 1
|
| 82 |
+
if repeat_count > 1: # Más estricto: solo permitir 1 repetición
|
| 83 |
continue
|
| 84 |
else:
|
| 85 |
last_word = word
|
| 86 |
repeat_count = 0
|
| 87 |
+
|
| 88 |
+
# Detectar patrones repetitivos de 2-3 palabras
|
| 89 |
+
last_two_words.append(word)
|
| 90 |
+
if len(last_two_words) > 3:
|
| 91 |
+
last_two_words.pop(0)
|
| 92 |
+
|
| 93 |
+
if len(last_two_words) >= 2:
|
| 94 |
+
# Si las últimas 2-3 palabras ya aparecieron antes, omitir
|
| 95 |
+
if len(cleaned_words) > len(last_two_words) * 2:
|
| 96 |
+
pattern = ' '.join(last_two_words)
|
| 97 |
+
text_so_far = ' '.join(cleaned_words[-len(last_two_words)*2:])
|
| 98 |
+
if pattern in text_so_far:
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
cleaned_words.append(word)
|
| 102 |
|
| 103 |
text = " ".join(cleaned_words)
|
| 104 |
|
| 105 |
+
# Eliminar caracteres raros y patrones no deseados
|
| 106 |
+
text = re.sub(r'(.)\1{5,}', r'\1\1', text) # Caracteres repetidos más de 5 veces
|
| 107 |
+
text = re.sub(r'[^\w\s\.\,\!\?\-\'\¡\¿áéíóúñÑ]', '', text) # Caracteres especiales no deseados
|
| 108 |
+
|
| 109 |
+
# Eliminar URLs y menciones extrañas
|
| 110 |
+
text = re.sub(r'https?://\S+|www\.\S+', '[enlace]', text)
|
| 111 |
+
text = re.sub(r'@\w+', '', text)
|
| 112 |
+
|
| 113 |
+
# Eliminar secuencias numéricas largas
|
| 114 |
+
text = re.sub(r'\b\d{10,}\b', '', text)
|
| 115 |
+
|
| 116 |
+
# Para saludos, respuesta corta y precisa
|
| 117 |
+
if is_greeting(user_input):
|
| 118 |
+
# Respuesta de saludo estándar sin inventar
|
| 119 |
+
return "¡Hola! ¿En qué puedo ayudarte?"
|
| 120 |
+
|
| 121 |
+
# Si la respuesta es muy larga, truncar
|
| 122 |
+
if len(text) > 500:
|
| 123 |
+
text = truncate_response(text, 400)
|
| 124 |
+
|
| 125 |
+
# Eliminar frases de "auto-referencia" comunes que indican alucinación
|
| 126 |
+
hallucination_patterns = [
|
| 127 |
+
r'(?i)como modelo de lenguaje (?:IA|inteligencia artificial|AI)',
|
| 128 |
+
r'(?i)soy una (?:IA|inteligencia artificial)',
|
| 129 |
+
r'(?i)no tengo (?:emociones|sentimientos|conciencia)',
|
| 130 |
+
r'(?i)disculpa las molestias',
|
| 131 |
+
r'(?i)lo siento(?:,)? (?:no puedo|no sé|no entiendo)',
|
| 132 |
+
r'(?i)como (?:asistente|IA) virtual',
|
| 133 |
+
r'(?i)basado en mi (?:entrenamiento|conocimiento)',
|
| 134 |
+
r'(?i)no (?:tengo|poseo) (?:acceso|información)',
|
| 135 |
+
]
|
| 136 |
|
| 137 |
+
for pattern in hallucination_patterns:
|
| 138 |
+
text = re.sub(pattern, '', text, flags=re.IGNORECASE)
|
| 139 |
|
| 140 |
+
# Si después de limpiar la respuesta es muy corta o vacía
|
| 141 |
+
if len(text.strip()) < 10:
|
| 142 |
+
# Respuesta por defecto según el contexto
|
| 143 |
+
if any(q in user_input.lower() for q in ['cómo estás', 'como estas', 'que tal']):
|
| 144 |
+
return "Estoy bien, gracias por preguntar. ¿En qué puedo ayudarte?"
|
| 145 |
+
elif any(q in user_input.lower() for q in ['quién eres', 'quien eres', 'que eres']):
|
| 146 |
+
return "Soy un asistente de IA. ¿En qué puedo ayudarte?"
|
| 147 |
else:
|
| 148 |
+
return "No pude procesar tu solicitud correctamente. ¿Podrías reformular tu pregunta?"
|
| 149 |
+
|
| 150 |
+
# Eliminar espacios múltiples y limpiar
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
text = re.sub(r'\s+', ' ', text).strip()
|
| 152 |
|
| 153 |
+
# Asegurar que termine con puntuación
|
| 154 |
+
if text and text[-1] not in '.!?':
|
| 155 |
+
text += '.'
|
| 156 |
+
|
| 157 |
return text
|
| 158 |
|
| 159 |
+
def format_prompt(user_input: str) -> str:
|
| 160 |
+
"""
|
| 161 |
+
Formato de prompt más estructurado para reducir alucinaciones
|
| 162 |
+
"""
|
| 163 |
+
# Detectar tipo de pregunta para mejor contexto
|
| 164 |
+
user_lower = user_input.lower()
|
| 165 |
+
|
| 166 |
+
if is_greeting(user_input):
|
| 167 |
+
return "### Instrucción:\nSaluda cortésmente.\n\n### Respuesta:\n¡Hola! ¿En qué puedo ayudarte?"
|
| 168 |
+
|
| 169 |
+
# Prompt estructurado para respuestas más precisas
|
| 170 |
+
prompt = f"""### Instrucción:
|
| 171 |
+
Responde la siguiente pregunta de forma CONCISA y DIRECTA. No inventes información. Si no sabes la respuesta, dilo claramente.
|
| 172 |
+
|
| 173 |
+
Pregunta: {user_input}
|
| 174 |
+
|
| 175 |
+
### Respuesta:"""
|
| 176 |
+
|
| 177 |
+
return prompt
|
| 178 |
+
|
| 179 |
# ======================
|
| 180 |
# DEFINIR ARQUITECTURA DEL MODELO (MTP)
|
| 181 |
# ======================
|
|
|
|
| 287 |
logits = self.lm_head(x)
|
| 288 |
return logits
|
| 289 |
|
| 290 |
+
def generate(self, input_ids, max_new_tokens=150, temperature=0.5, top_k=30, top_p=0.85, repetition_penalty=1.2):
|
| 291 |
+
"""
|
| 292 |
+
Genera texto con parámetros más conservadores para respuestas precisas
|
| 293 |
+
temperature más baja = menos creatividad
|
| 294 |
+
repetition_penalty más alto = menos repeticiones
|
| 295 |
+
"""
|
| 296 |
generated = input_ids
|
| 297 |
|
| 298 |
for step in range(max_new_tokens):
|
|
|
|
| 300 |
logits = self(generated)
|
| 301 |
next_logits = logits[0, -1, :] / temperature
|
| 302 |
|
| 303 |
+
# Penalización por repetición más agresiva
|
| 304 |
if repetition_penalty != 1.0:
|
| 305 |
for token_id in set(generated[0].tolist()):
|
| 306 |
next_logits[token_id] /= repetition_penalty
|
| 307 |
|
| 308 |
+
# Top-k más restrictivo
|
| 309 |
if top_k > 0:
|
| 310 |
indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
|
| 311 |
next_logits[indices_to_remove] = float('-inf')
|
| 312 |
|
| 313 |
+
# Top-p más restrictivo
|
| 314 |
if top_p < 1.0:
|
| 315 |
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
|
| 316 |
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
|
|
|
| 398 |
# ======================
|
| 399 |
app = FastAPI(
|
| 400 |
title="MTP API",
|
| 401 |
+
description="API para modelo de lenguaje MTP - Respuestas precisas y concisas",
|
| 402 |
+
version="1.1"
|
| 403 |
)
|
| 404 |
|
| 405 |
app.add_middleware(
|
|
|
|
| 411 |
|
| 412 |
class PromptRequest(BaseModel):
|
| 413 |
text: str = Field(..., max_length=2000, description="Texto de entrada")
|
| 414 |
+
max_tokens: int = Field(default=80, ge=10, le=150, description="Tokens máximos a generar")
|
| 415 |
+
temperature: float = Field(default=0.5, ge=0.1, le=1.5, description="Temperatura de muestreo (menor = más preciso)")
|
| 416 |
+
top_k: int = Field(default=30, ge=1, le=80, description="Top-k sampling")
|
| 417 |
+
top_p: float = Field(default=0.85, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
|
| 418 |
+
repetition_penalty: float = Field(default=1.2, ge=1.0, le=2.0, description="Penalización por repetición")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
|
| 420 |
# ======================
|
| 421 |
# GESTIÓN DE CARGA
|
|
|
|
| 446 |
|
| 447 |
@app.post("/generate")
|
| 448 |
async def generate(req: PromptRequest):
|
| 449 |
+
"""Endpoint principal de generación de texto - Versión mejorada para respuestas precisas"""
|
| 450 |
global ACTIVE_REQUESTS
|
| 451 |
ACTIVE_REQUESTS += 1
|
| 452 |
|
|
|
|
| 455 |
ACTIVE_REQUESTS -= 1
|
| 456 |
return {"reply": "", "tokens_generated": 0}
|
| 457 |
|
| 458 |
+
# Detectar tipo de mensaje
|
| 459 |
+
is_greeting_msg = is_greeting(user_input)
|
| 460 |
|
| 461 |
+
# Para saludos, respuesta directa sin generar
|
| 462 |
+
if is_greeting_msg:
|
| 463 |
+
ACTIVE_REQUESTS -= 1
|
| 464 |
+
return {
|
| 465 |
+
"reply": "¡Hola! ¿En qué puedo ayudarte?",
|
| 466 |
+
"tokens_generated": 0,
|
| 467 |
+
"model": "MTP",
|
| 468 |
+
"mode": "direct"
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
# Para preguntas muy cortas o confusas, pedir aclaración
|
| 472 |
+
if len(user_input) < 5 and not is_greeting_msg:
|
| 473 |
+
ACTIVE_REQUESTS -= 1
|
| 474 |
+
return {
|
| 475 |
+
"reply": "¿Podrías ser más específico? No entendí tu pregunta.",
|
| 476 |
+
"tokens_generated": 0,
|
| 477 |
+
"model": "MTP",
|
| 478 |
+
"mode": "clarify"
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
# Construir prompt estructurado
|
| 482 |
+
full_prompt = format_prompt(user_input)
|
| 483 |
tokens = tokenizer_wrapper.encode(full_prompt)
|
| 484 |
+
|
| 485 |
+
# Limitar longitud de entrada para evitar contextos muy largos
|
| 486 |
+
if len(tokens) > 256:
|
| 487 |
+
tokens = tokens[:256]
|
| 488 |
+
|
| 489 |
input_ids = torch.tensor([tokens], device=DEVICE)
|
| 490 |
+
|
| 491 |
+
# Ajustar max_tokens según la pregunta
|
| 492 |
+
max_tokens = min(req.max_tokens, 100) # Limitar a 100 tokens máximo para respuestas concisas
|
| 493 |
|
| 494 |
try:
|
| 495 |
with torch.no_grad():
|
|
|
|
| 512 |
else:
|
| 513 |
response = ""
|
| 514 |
|
| 515 |
+
# Limpiar respuesta (elimina alucinaciones y repeticiones)
|
| 516 |
response = clean_response(response, user_input)
|
| 517 |
|
| 518 |
+
# Verificar si la respuesta es demasiado larga o no tiene sentido
|
| 519 |
+
if len(response) > 400:
|
| 520 |
+
response = truncate_response(response, 350)
|
| 521 |
+
|
| 522 |
# Si la respuesta sigue vacía o es muy corta, usar respuesta por defecto
|
| 523 |
+
if len(response) < 10:
|
| 524 |
+
response = "Lo siento, no pude generar una respuesta precisa. ¿Podrías reformular tu pregunta?"
|
|
|
|
|
|
|
|
|
|
| 525 |
|
| 526 |
return {
|
| 527 |
"reply": response,
|
| 528 |
"tokens_generated": len(safe_tokens),
|
| 529 |
+
"model": "MTP",
|
| 530 |
+
"mode": "generated"
|
| 531 |
}
|
| 532 |
|
| 533 |
except Exception as e:
|
| 534 |
print(f"❌ Error durante generación: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
return {
|
| 536 |
+
"reply": "Ocurrió un error al procesar tu solicitud. Por favor, intenta de nuevo.",
|
| 537 |
"error": str(e)
|
| 538 |
}
|
| 539 |
|
|
|
|
| 553 |
"model": "MTP",
|
| 554 |
"device": DEVICE,
|
| 555 |
"active_requests": ACTIVE_REQUESTS,
|
| 556 |
+
"vocab_size": VOCAB_SIZE,
|
| 557 |
+
"mode": "precise"
|
| 558 |
}
|
| 559 |
|
| 560 |
@app.get("/info")
|
| 561 |
def model_info():
|
| 562 |
return {
|
| 563 |
"model_name": "MTP",
|
| 564 |
+
"version": "1.1",
|
| 565 |
"architecture": config,
|
| 566 |
"parameters": sum(p.numel() for p in model.parameters()),
|
| 567 |
+
"device": DEVICE,
|
| 568 |
+
"description": "Modelo optimizado para respuestas precisas y concisas"
|
| 569 |
}
|
| 570 |
|
| 571 |
# ======================
|
| 572 |
+
# INTERFAZ WEB MEJORADA
|
| 573 |
# ======================
|
| 574 |
@app.get("/", response_class=HTMLResponse)
|
| 575 |
def chat_ui():
|
|
|
|
| 579 |
<head>
|
| 580 |
<meta charset="UTF-8">
|
| 581 |
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 582 |
+
<title>MTP - Asistente IA Preciso</title>
|
| 583 |
<style>
|
| 584 |
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 585 |
body {
|
|
|
|
| 599 |
font-size: 1.2rem;
|
| 600 |
font-weight: 500;
|
| 601 |
}
|
| 602 |
+
.chat-header p {
|
| 603 |
+
color: #888;
|
| 604 |
+
font-size: 0.75rem;
|
| 605 |
+
margin-top: 4px;
|
| 606 |
+
}
|
| 607 |
.chat-messages {
|
| 608 |
flex: 1;
|
| 609 |
overflow-y: auto;
|
|
|
|
| 694 |
0%, 80%, 100% { transform: scale(0); }
|
| 695 |
40% { transform: scale(1); }
|
| 696 |
}
|
| 697 |
+
.suggestion-buttons {
|
| 698 |
+
display: flex;
|
| 699 |
+
gap: 8px;
|
| 700 |
+
margin-top: 10px;
|
| 701 |
+
flex-wrap: wrap;
|
| 702 |
+
justify-content: center;
|
| 703 |
+
}
|
| 704 |
+
.suggestion {
|
| 705 |
+
background: #2a2b2e;
|
| 706 |
+
border: none;
|
| 707 |
+
border-radius: 20px;
|
| 708 |
+
padding: 6px 14px;
|
| 709 |
+
color: #aaa;
|
| 710 |
+
font-size: 0.8rem;
|
| 711 |
+
cursor: pointer;
|
| 712 |
+
transition: all 0.2s;
|
| 713 |
+
}
|
| 714 |
+
.suggestion:hover {
|
| 715 |
+
background: #3a3b3e;
|
| 716 |
+
color: white;
|
| 717 |
+
}
|
| 718 |
</style>
|
| 719 |
</head>
|
| 720 |
<body>
|
| 721 |
<div class="chat-header">
|
| 722 |
<h1>🤖 MTP - Asistente IA</h1>
|
| 723 |
+
<p>Respuestas precisas y concisas | Modo conservador</p>
|
| 724 |
</div>
|
| 725 |
<div class="chat-messages" id="chatMessages">
|
| 726 |
<div class="message bot">
|
| 727 |
+
<div class="message-content">¡Hola! Soy MTP, tu asistente de IA. Haré lo posible por darte respuestas precisas y concisas. ¿En qué puedo ayudarte?</div>
|
| 728 |
</div>
|
| 729 |
</div>
|
| 730 |
<div class="chat-input-container">
|
|
|
|
| 732 |
<input type="text" id="messageInput" placeholder="Escribe tu mensaje..." autocomplete="off">
|
| 733 |
<button id="sendBtn">Enviar</button>
|
| 734 |
</div>
|
| 735 |
+
<div class="suggestion-buttons">
|
| 736 |
+
<button class="suggestion" data-text="¿Qué es la inteligencia artificial?">🤖 ¿Qué es la IA?</button>
|
| 737 |
+
<button class="suggestion" data-text="¿Cómo estás?">😊 ¿Cómo estás?</button>
|
| 738 |
+
<button class="suggestion" data-text="¿Quién eres?">👋 ¿Quién eres?</button>
|
| 739 |
+
<button class="suggestion" data-text="Hola">👋 Hola</button>
|
| 740 |
+
</div>
|
| 741 |
</div>
|
| 742 |
<script>
|
| 743 |
const chatMessages = document.getElementById('chatMessages');
|
|
|
|
| 748 |
function addMessage(text, isUser) {
|
| 749 |
const div = document.createElement('div');
|
| 750 |
div.className = `message ${isUser ? 'user' : 'bot'}`;
|
| 751 |
+
div.innerHTML = `<div class="message-content">${escapeHtml(text)}</div>`;
|
| 752 |
chatMessages.appendChild(div);
|
| 753 |
chatMessages.scrollTop = chatMessages.scrollHeight;
|
| 754 |
return div;
|
| 755 |
}
|
| 756 |
|
| 757 |
+
function escapeHtml(text) {
|
| 758 |
+
const div = document.createElement('div');
|
| 759 |
+
div.textContent = text;
|
| 760 |
+
return div.innerHTML;
|
| 761 |
+
}
|
| 762 |
+
|
| 763 |
function addTypingIndicator() {
|
| 764 |
const div = document.createElement('div');
|
| 765 |
div.className = 'message bot';
|
|
|
|
| 774 |
if (indicator) indicator.remove();
|
| 775 |
}
|
| 776 |
|
| 777 |
+
async function sendMessage(text = null) {
|
| 778 |
+
const messageText = text || messageInput.value.trim();
|
| 779 |
+
if (!messageText || isLoading) return;
|
| 780 |
|
| 781 |
+
if (!text) messageInput.value = '';
|
| 782 |
+
addMessage(messageText, true);
|
| 783 |
isLoading = true;
|
| 784 |
sendBtn.disabled = true;
|
| 785 |
addTypingIndicator();
|
|
|
|
| 788 |
const response = await fetch('/generate', {
|
| 789 |
method: 'POST',
|
| 790 |
headers: { 'Content-Type': 'application/json' },
|
| 791 |
+
body: JSON.stringify({ text: messageText })
|
| 792 |
});
|
| 793 |
const data = await response.json();
|
| 794 |
removeTypingIndicator();
|
| 795 |
addMessage(data.reply, false);
|
| 796 |
} catch (error) {
|
| 797 |
removeTypingIndicator();
|
| 798 |
+
addMessage('Error de conexión. Por favor, intenta de nuevo.', false);
|
| 799 |
} finally {
|
| 800 |
isLoading = false;
|
| 801 |
sendBtn.disabled = false;
|
|
|
|
| 806 |
messageInput.addEventListener('keypress', (e) => {
|
| 807 |
if (e.key === 'Enter') sendMessage();
|
| 808 |
});
|
| 809 |
+
sendBtn.addEventListener('click', () => sendMessage());
|
| 810 |
+
|
| 811 |
+
// Sugerencias
|
| 812 |
+
document.querySelectorAll('.suggestion').forEach(btn => {
|
| 813 |
+
btn.addEventListener('click', () => {
|
| 814 |
+
sendMessage(btn.dataset.text);
|
| 815 |
+
});
|
| 816 |
+
});
|
| 817 |
+
|
| 818 |
messageInput.focus();
|
| 819 |
</script>
|
| 820 |
</body>
|
|
|
|
| 826 |
print(f"\n🚀 Iniciando servidor MTP en puerto {port}...")
|
| 827 |
print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
|
| 828 |
print(f"📡 API docs: http://0.0.0.0:{port}/docs")
|
| 829 |
+
print(f"\n⚙️ Configuración para respuestas PRECISAS:")
|
| 830 |
+
print(f" • Temperature: 0.5 (menos creatividad)")
|
| 831 |
+
print(f" • Top-k: 30 (muestreo más restrictivo)")
|
| 832 |
+
print(f" • Top-p: 0.85")
|
| 833 |
+
print(f" • Repetition penalty: 1.2 (reduce repeticiones)")
|
| 834 |
|
| 835 |
uvicorn.run(
|
| 836 |
app,
|