# MTP-2.5 / app.py
# Author: teszenofficial
# Last commit: "Update app.py" (0968217, verified)
import gc
import json
import math
import os
import re
import sys
import threading
import time

import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field
# ======================
# CONFIGURACIÓN DE DISPOSITIVO
# ======================
# Pick the inference device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")
if DEVICE == "cpu":
    # Use half of the logical cores for intra-op parallelism. os.cpu_count()
    # returns None when the count cannot be determined, which previously
    # crashed with a TypeError; fall back to a single thread in that case.
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))
# Disable autograd for the importing thread. PyTorch grad mode is
# thread-local, so code running in other threads must still use no_grad().
torch.set_grad_enabled(False)
MODEL_REPO = "TeszenAI/MTP-2.5"
# ======================
# DEFINIR ARQUITECTURA DEL MODELO (MTP-1.1)
# ======================
class LayerNorm(nn.Module):
    """Normalise over the last dimension, then apply a learnable scale and shift.

    Unlike ``torch.nn.LayerNorm``, this divides by the *unbiased* standard
    deviation plus ``eps`` rather than by sqrt(variance + eps). Changing the
    formula would change outputs for the existing weights.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = torch.std(x, dim=-1, keepdim=True)
        return self.weight * centered / (spread + self.eps) + self.bias
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: model width; must be a positive multiple of ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # These attribute names are state_dict keys of the saved checkpoint.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)."""
        return t.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (batch, seq_len, d_model).

        ``mask`` must broadcast to (batch, n_heads, seq_len, seq_len); positions
        where it equals 0 get a score of -inf and are excluded from attention.
        """
        batch_size, seq_len, _ = x.shape
        Q = self._split_heads(self.w_q(x), batch_size, seq_len)
        K = self._split_heads(self.w_k(x), batch_size, seq_len)
        V = self._split_heads(self.w_v(x), batch_size, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = self.dropout(F.softmax(scores, dim=-1))
        attn_output = torch.matmul(attn_weights, V)
        # Merge heads back to (batch, seq_len, d_model).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(attn_output)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model -> d_ff), GELU, dropout, Linear(d_ff -> d_model)."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Names are state_dict keys of the saved checkpoint.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention followed by a feed-forward MLP.

    Each sub-layer's output goes through dropout and is added to the sub-layer's
    input (residual connection). The sum is then layer-normalised.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Sub-module names and registration order are kept: they are the
        # state_dict keys loaded from the checkpoint.
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        summed = x + self.dropout1(self.attention(x, mask))
        hidden = self.norm1(summed)
        summed = hidden + self.dropout2(self.feed_forward(hidden))
        return self.norm2(summed)
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    Even feature columns hold sin(pos * w_i) and odd columns cos(pos * w_i),
    with w_i = 10000^(-2i / d_model). The table covers ``max_len`` positions.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than frequencies.
        # Trim div_term so the shapes agree; this is a no-op for even d_model.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer (not a parameter): follows .to(device) and is saved in the state_dict.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encodings of the first ``x.size(1)`` positions to ``x``.

        Raises a broadcasting RuntimeError if ``x.size(1)`` exceeds ``max_len``.
        """
        return x + self.pe[:, :x.size(1), :]
class MTPModel(nn.Module):
    """Decoder-only transformer language model served by this app.

    Args:
        vocab_size: number of tokens in the vocabulary.
        d_model: embedding / hidden width.
        n_heads: attention heads per block.
        n_layers: number of stacked TransformerBlocks.
        d_ff: hidden width of each feed-forward sub-layer.
        dropout: dropout probability inside the blocks.
        max_len: number of positions covered by the positional encoding. This
            is also the context window used by ``generate``.
    """

    def __init__(self, vocab_size: int, d_model: int = 128, n_heads: int = 4,
                 n_layers: int = 4, d_ff: int = 512, dropout: float = 0.1, max_len: int = 256):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len
        # Attribute names below are the checkpoint's state_dict keys; keep them.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Return logits of shape (batch, seq_len, vocab_size) for token ids ``x``.

        If ``mask`` is None, a causal (lower-triangular) mask is built so each
        position attends only to itself and earlier positions.
        """
        if mask is None:
            seq_len = x.size(1)
            mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).unsqueeze(0).unsqueeze(0)
        x = self.token_embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for block in self.blocks:
            x = block(x, mask)
        x = self.norm(x)
        return self.lm_head(x)

    @staticmethod
    def _apply_repetition_penalty(logits, token_ids, penalty):
        """Lower the scores of ``token_ids`` in the 1-D ``logits`` (in place).

        Positive logits are divided by ``penalty`` and negative logits are
        multiplied by it, so a penalty > 1 always makes the token less likely.
        Dividing a negative logit, as before, would instead raise it.
        """
        if penalty == 1.0 or not token_ids:
            return logits
        idx = torch.tensor(sorted(set(token_ids)), device=logits.device, dtype=torch.long)
        selected = logits[idx]
        logits[idx] = torch.where(selected < 0, selected * penalty, selected / penalty)
        return logits

    @staticmethod
    def _filter_top_k(logits, top_k):
        """Set every logit below the k-th largest to -inf (no-op when top_k <= 0)."""
        if top_k <= 0:
            return logits
        # Clamp k: torch.topk raises when k exceeds the vocabulary size.
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_largest] = float('-inf')
        return logits

    @staticmethod
    def _filter_top_p(logits, top_p):
        """Nucleus filtering: keep the smallest prefix of tokens, by descending
        probability, whose cumulative probability exceeds ``top_p``."""
        if top_p >= 1.0:
            return logits
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is still kept.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False
        logits[sorted_indices[sorted_to_remove]] = float('-inf')
        return logits

    def generate(self, input_ids, max_new_tokens=100, temperature=0.7, top_k=50, top_p=0.9,
                 repetition_penalty=1.1, eos_token_id=3):
        """Sample up to ``max_new_tokens`` tokens that continue ``input_ids``.

        Only batch row 0 is sampled (``logits[0, -1]``), so pass one sequence of
        shape (1, seq_len). Generation stops early when ``eos_token_id`` is
        sampled, or when the last 10 tokens are all identical.

        Returns:
            (generated, eos_detected): ``generated`` is ``input_ids`` with the
            sampled tokens appended (the EOS token itself is not appended).
            ``eos_detected`` is True when sampling stopped on EOS.
        """
        generated = input_ids
        eos_detected = False
        for _ in range(max_new_tokens):
            # The positional encoding only covers ``max_len`` positions, so feed
            # just the most recent window. Otherwise long prompts or long
            # generations crash with a shape mismatch.
            context = generated[:, -self.max_len:]
            with torch.no_grad():
                logits = self(context)
            next_logits = logits[0, -1, :] / temperature
            next_logits = self._apply_repetition_penalty(next_logits, generated[0].tolist(), repetition_penalty)
            next_logits = self._filter_top_k(next_logits, top_k)
            next_logits = self._filter_top_p(next_logits, top_p)
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            if next_token == eos_token_id:
                eos_detected = True
                break
            # Bail out of degenerate loops: last 10 tokens all the same.
            if generated.size(1) > 10 and len(set(generated[0, -10:].tolist())) == 1:
                break
            generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
        return generated, eos_detected
# ======================
# DESCARGA Y CARGA DEL MODELO
# ======================
print(f"📦 Descargando modelo desde {MODEL_REPO}...")
# Fetch the model repo (weights, tokenizer, config) into ./mtp_repo; repo_path points at it.
repo_path = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir="mtp_repo")
# Load architecture hyper-parameters from the snapshot's config.json,
# falling back to built-in defaults when the file is absent.
config_path = os.path.join(repo_path, "config.json")
if os.path.exists(config_path):
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
else:
    config = {
        "vocab_size": 5000,
        "d_model": 128,
        "n_heads": 4,
        "n_layers": 4,
        "d_ff": 512,
        "dropout": 0.1,
        "max_len": 256
    }
# Load the SentencePiece tokenizer shipped in the repo snapshot and
# record its vocabulary size and special-token ids.
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
sp = spm.SentencePieceProcessor()
sp.load(tokenizer_path)
VOCAB_SIZE = sp.get_piece_size()
EOS_TOKEN_ID, BOS_TOKEN_ID = sp.eos_id(), sp.bos_id()
# The tokenizer is authoritative for the vocabulary size; override config.
config["vocab_size"] = VOCAB_SIZE
print(f"🧠 Inicializando modelo MTP-1.1...")
print(f" → Vocabulario: {VOCAB_SIZE}")
print(f" → EOS token ID: {EOS_TOKEN_ID}")
print(f" → BOS token ID: {BOS_TOKEN_ID}")
print(f" → Dimensión: {config['d_model']}")
print(f" → Capas: {config['n_layers']}")
print(f" → Heads: {config['n_heads']}")
# MTPModel's constructor arguments. A config.json may carry extra metadata
# keys that MTPModel(**config) would reject with a TypeError, so pass only
# the recognised ones. Any that are missing fall back to MTPModel's defaults.
_MODEL_ARGS = ("vocab_size", "d_model", "n_heads", "n_layers", "d_ff", "dropout", "max_len")
model = MTPModel(**{key: config[key] for key in _MODEL_ARGS if key in config})
model.to(DEVICE)
# Load trained weights when the snapshot provides them; otherwise keep the
# randomly initialised parameters so the service can still start.
model_path = os.path.join(repo_path, "mtp_model.pt")
if not os.path.exists(model_path):
    print("⚠️ No se encontró mtp_model.pt, usando pesos aleatorios")
else:
    # SECURITY(review): torch.load unpickles a file downloaded from a remote
    # repo. Depending on the installed PyTorch version this can execute
    # arbitrary code; consider weights_only=True (PyTorch >= 1.13).
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    print("✅ Pesos del modelo cargados")
model.eval()
# Count parameters BEFORE quantizing. quantize_dynamic swaps nn.Linear for
# modules whose int8 weights are packed outside .parameters(), so counting
# afterwards omitted every Linear weight and badly under-reported on CPU.
param_count = sum(p.numel() for p in model.parameters())
if DEVICE == "cpu":
    # Dynamic int8 quantization of the Linear layers for faster CPU inference.
    print("⚡ Aplicando cuantización dinámica para CPU...")
    model = torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8
    )
print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
# ======================
# FUNCIONES DE LIMPIEZA DE RESPUESTAS
# ======================
# Fragments treated as generation garbage; compiled once at import time.
_GARBAGE_PATTERNS = (
    re.compile(r'[a-z]{20,}'),  # 20+ consecutive lowercase letters (nonsense "words")
    re.compile(r'\d{5,}'),  # 5+ consecutive digits
    re.compile(r'[^\w\s\.\,\!\?\-áéíóúüñ]{10,}'),  # 10+ consecutive unusual symbols
)
_WHITESPACE_RUN = re.compile(r'\s+')


def clean_response(text: str, original_prompt: str = None) -> str:
    """Post-process a raw model reply into presentable text.

    Steps, in order:
        1. Strip an echoed copy of the prompt.
        2. Cut everything from the first "###".
        3. Cap runs of identical consecutive words at three copies (only when
           the reply has more than 10 words).
        4. Delete garbage-looking fragments.
        5. Collapse whitespace.
        6. Capitalise the first character.

    Args:
        text: raw decoded model output.
        original_prompt: the user's prompt. If its stripped, lower-cased form
            appears at the start of the lower-cased ``text``, or anywhere in its
            first 50 characters, the first occurrence and everything before it
            are removed. A blank or whitespace-only prompt is ignored.

    Returns:
        The cleaned reply. Returns a canned Spanish message when ``text`` is
        empty or the result is shorter than 3 characters.
    """
    if not text:
        return "Lo siento, no pude generar una respuesta."

    prompt_clean = original_prompt.strip().lower() if original_prompt else ""
    # A whitespace-only prompt strips to "", which "matches" at position 0 and
    # used to chop len(original_prompt) characters off the reply; skip it.
    if prompt_clean:
        text_lower = text.lower()
        if text_lower.startswith(prompt_clean):
            # Slice by the length of what was actually matched (the stripped prompt).
            text = text[len(prompt_clean):].strip()
        elif prompt_clean in text_lower[:50]:
            idx = text_lower.find(prompt_clean)
            text = text[idx + len(prompt_clean):].strip()

    # The model's prompt template uses "###" headers; anything after one is spill-over.
    if "###" in text:
        text = text.split("###")[0].strip()

    words = text.split()
    if len(words) > 10:
        kept = []
        run_length = 0
        for i, word in enumerate(words):
            run_length = run_length + 1 if i and word == words[i - 1] else 0
            if run_length <= 2:
                kept.append(word)
        text = " ".join(kept)

    for pattern in _GARBAGE_PATTERNS:
        text = pattern.sub('', text)

    text = _WHITESPACE_RUN.sub(' ', text).strip()
    if text:
        text = text[0].upper() + text[1:]

    if len(text) < 3:
        return "Entendido. ¿Algo más en lo que pueda ayudarte?"
    return text
# ======================
# API CONFIG
# ======================
app = FastAPI(title="MTP-1.1 API", description="API para modelo de lenguaje MTP-1.1", version="1.1")

# Fully open CORS: any origin may call the API with any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class PromptRequest(BaseModel):
    # JSON body of POST /generate. Pydantic enforces the bounds below.
    # (Intentionally no docstring: pydantic would publish it in the OpenAPI schema.)

    # User message; required, at most 2000 characters.
    text: str = Field(..., max_length=2000, description="Texto de entrada")
    # Sampling controls forwarded to MTPModel.generate.
    max_tokens: int = Field(100, ge=10, le=200, description="Tokens máximos a generar")
    temperature: float = Field(0.7, ge=0.1, le=1.5, description="Temperatura de muestreo")
    top_k: int = Field(50, ge=1, le=100, description="Top-k sampling")
    top_p: float = Field(0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
    repetition_penalty: float = Field(1.1, ge=1.0, le=2.0, description="Penalización por repetición")
def build_prompt(user_input: str) -> str:
    """Wrap the user's text in the model's instruction/response prompt template."""
    template = "### Instrucción:\n{}\n\n### Respuesta:\n"
    return template.format(user_input)
# ======================
# GESTIÓN DE CARGA
# ======================
# Number of /generate requests currently in progress (reported by /health).
ACTIVE_REQUESTS = 0


class MTPTokenizer:
    """Thin wrapper exposing the SentencePiece calls the API uses."""

    def __init__(self, sp_model):
        self.sp = sp_model

    def encode(self, text):
        """Encode ``text`` with the wrapped SentencePiece model."""
        encoded = self.sp.encode(text)
        return encoded

    def decode(self, tokens):
        """Decode a sequence of token ids back into text."""
        decoded = self.sp.decode(tokens)
        return decoded

    def bos_id(self):
        """Id of the beginning-of-sequence token, as reported by SentencePiece."""
        bos = self.sp.bos_id()
        return bos

    def eos_id(self):
        """Id of the end-of-sequence token, as reported by SentencePiece."""
        eos = self.sp.eos_id()
        return eos


tokenizer_wrapper = MTPTokenizer(sp)
# /generate is a plain ``def`` endpoint, so FastAPI runs it in its worker
# threadpool and several requests can be in flight at once. This lock keeps
# the shared ACTIVE_REQUESTS counter consistent across those threads.
_ACTIVE_REQUESTS_LOCK = threading.Lock()


@app.post("/generate")
def generate(req: PromptRequest):
    """Main text-generation endpoint.

    Builds the instruction prompt, samples a reply with ``model.generate``,
    then post-processes it (prompt-echo and garbage removal, truncation).

    Deliberately declared with ``def`` rather than ``async def``. Inference is
    blocking work with no ``await``, so as a coroutine it ran on the event loop
    and froze every other request (including /health) until it finished. That
    also kept ACTIVE_REQUESTS pinned at 1, so the load throttling below could
    never trigger.

    Returns:
        A dict with ``reply``, ``tokens_generated``, ``model`` and
        ``eos_detected``. Blank input returns
        ``{"reply": "", "tokens_generated": 0}``. If generation raises, the
        dict holds a ``reply`` plus ``error``.
    """
    global ACTIVE_REQUESTS
    with _ACTIVE_REQUESTS_LOCK:
        ACTIVE_REQUESTS += 1
        active_now = ACTIVE_REQUESTS
    try:
        user_input = req.text.strip()
        if not user_input:
            return {"reply": "", "tokens_generated": 0}

        full_prompt = build_prompt(user_input)
        bos_id = tokenizer_wrapper.bos_id()
        eos_id = tokenizer_wrapper.eos_id()
        # SentencePiece returns -1 when the model defines no BOS piece.
        # Prepending that would be an out-of-range embedding index.
        tokens = ([bos_id] if bos_id >= 0 else []) + tokenizer_wrapper.encode(full_prompt)
        input_ids = torch.tensor([tokens], device=DEVICE)

        # Under load (more than 2 concurrent requests), cap reply length at
        # 80 tokens and scale temperature by 0.9, with a floor of 0.5.
        dyn_max_tokens = req.max_tokens
        dyn_temperature = req.temperature
        if active_now > 2:
            dyn_max_tokens = min(dyn_max_tokens, 80)
            dyn_temperature = max(0.5, dyn_temperature * 0.9)

        with torch.no_grad():
            output_ids, eos_detected = model.generate(
                input_ids,
                max_new_tokens=dyn_max_tokens,
                temperature=dyn_temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                repetition_penalty=req.repetition_penalty,
                eos_token_id=eos_id
            )

        # Keep only the newly generated tokens (drop the prompt), and filter
        # out-of-vocabulary ids and EOS before decoding.
        gen_tokens = output_ids[0, len(tokens):].tolist()
        safe_tokens = [t for t in gen_tokens if 0 <= t < VOCAB_SIZE and t != eos_id]

        raw_response = tokenizer_wrapper.decode(safe_tokens).strip()
        clean_reply = clean_response(raw_response, user_input)

        # Without EOS the reply may stop mid-sentence. Cut back to the last
        # period, if that still leaves more than 100 characters.
        if not eos_detected and len(clean_reply) > 200:
            last_period = clean_reply.rfind('.')
            if last_period > 100:
                clean_reply = clean_reply[:last_period + 1]

        # Remove known nonsense fragments by plain substring replacement.
        # NOTE(review): "miel" ("honey") is an ordinary Spanish word, and
        # substring replacement also mangles any word containing these fragments.
        nonsense_phrases = [
            "foompañances", "ciudadores", "mejtedon", "calportedon",
            "rápidodcor", "rápidodarse", "miel", "baon", "domol"
        ]
        for phrase in nonsense_phrases:
            clean_reply = clean_reply.replace(phrase, "")
        clean_reply = re.sub(r'\s+', ' ', clean_reply).strip()

        # Still very long with no period near the end: hard-truncate.
        if len(clean_reply) > 300 and '.' not in clean_reply[-50:]:
            clean_reply = clean_reply[:250] + "..."

        return {
            "reply": clean_reply,
            "tokens_generated": len(safe_tokens),
            "model": "MTP-1.1",
            "eos_detected": eos_detected
        }
    except Exception as e:
        # Request boundary: report the failure in the response body instead of a 500.
        print(f"❌ Error durante generación: {e}")
        return {
            "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
            "error": str(e)
        }
    finally:
        with _ACTIVE_REQUESTS_LOCK:
            ACTIVE_REQUESTS -= 1
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
# ======================
# ENDPOINTS DE INFORMACIÓN
# ======================
@app.get("/health")
def health_check():
    """Liveness probe plus basic runtime info (device, in-flight requests, vocab size)."""
    status = {"status": "healthy", "model": "MTP-1.1"}
    status["device"] = DEVICE
    status["active_requests"] = ACTIVE_REQUESTS
    status["vocab_size"] = VOCAB_SIZE
    return status
@app.get("/info")
def model_info():
    """Describe the served model: name, version, architecture config, parameter count, device."""
    # NOTE(review): on CPU the model has been through quantize_dynamic, whose
    # packed int8 Linear weights may not be reported by .parameters(). Verify
    # whether this count under-reports in that case.
    total_params = sum(param.numel() for param in model.parameters())
    return dict(
        model_name="MTP-1.1",
        version="1.1",
        architecture=config,
        parameters=total_params,
        device=DEVICE,
    )
# ======================
# INTERFAZ WEB (MODERNA)
# ======================
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat UI as an inline HTML document.

    The page's script POSTs ``{"text": ...}`` to /generate and shows the
    returned ``reply`` with a client-side typewriter effect. The send button
    turns into a stop button that aborts the in-flight fetch.

    NOTE(review): the per-message "regenerate" button calls ``sendMessage``
    without checking ``isGenerating``. It can therefore start a second request
    while one is still running, overwriting ``abortController`` so the first
    can no longer be stopped. The page is also labelled "MTP 1.1" in <title>
    but "MTP 2.5" in the header and greeting.
    """
    return """
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<title>MTP 1.1 - Chat IA</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
<style>
:root {
--bg-color: #131314;
--surface-color: #1E1F20;
--accent-color: #4a9eff;
--text-primary: #e3e3e3;
--text-secondary: #9aa0a6;
--user-bubble: #282a2c;
}
* { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
body {
margin: 0;
background-color: var(--bg-color);
font-family: 'Inter', sans-serif;
color: var(--text-primary);
height: 100dvh;
display: flex;
flex-direction: column;
overflow: hidden;
}
header {
padding: 12px 20px;
display: flex;
align-items: center;
justify-content: space-between;
background: rgba(19, 19, 20, 0.85);
backdrop-filter: blur(12px);
position: fixed;
top: 0;
width: 100%;
z-index: 50;
border-bottom: 1px solid rgba(255,255,255,0.05);
}
.brand-wrapper {
display: flex;
align-items: center;
gap: 12px;
cursor: pointer;
}
.brand-logo {
width: 32px;
height: 32px;
border-radius: 50%;
background: linear-gradient(135deg, #4a9eff, #8a6eff);
display: flex;
align-items: center;
justify-content: center;
font-weight: bold;
font-size: 14px;
color: white;
}
.brand-text {
font-weight: 500;
font-size: 1.05rem;
display: flex;
align-items: center;
gap: 8px;
}
.version-badge {
font-size: 0.75rem;
background: rgba(74, 158, 255, 0.15);
color: #8ab4f8;
padding: 2px 8px;
border-radius: 12px;
font-weight: 600;
}
.status-badge {
font-size: 0.7rem;
background: rgba(76, 175, 80, 0.15);
color: #4caf50;
padding: 2px 8px;
border-radius: 12px;
font-weight: 500;
display: flex;
align-items: center;
gap: 6px;
}
.status-badge .dot {
width: 8px;
height: 8px;
background: #4caf50;
border-radius: 50%;
animation: pulse 1.5s infinite;
}
@keyframes pulse {
0%, 100% { opacity: 1; transform: scale(1); }
50% { opacity: 0.5; transform: scale(0.8); }
}
.chat-scroll {
flex: 1;
overflow-y: auto;
padding: 80px 20px 40px 20px;
display: flex;
flex-direction: column;
gap: 30px;
max-width: 850px;
margin: 0 auto;
width: 100%;
scroll-behavior: smooth;
}
.msg-row {
display: flex;
gap: 16px;
width: 100%;
opacity: 0;
transform: translateY(10px);
animation: slideUpFade 0.4s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
}
.msg-row.user { justify-content: flex-end; }
.msg-row.bot { justify-content: flex-start; align-items: flex-start; }
.msg-content {
line-height: 1.6;
font-size: 1rem;
word-wrap: break-word;
max-width: 85%;
}
.user .msg-content {
background-color: var(--user-bubble);
padding: 10px 18px;
border-radius: 18px;
border-top-right-radius: 4px;
color: #fff;
}
.bot .msg-content-wrapper {
display: flex;
flex-direction: column;
gap: 8px;
width: 100%;
}
.bot .msg-text {
padding-top: 6px;
color: var(--text-primary);
}
.bot-avatar {
width: 34px;
height: 34px;
min-width: 34px;
border-radius: 50%;
background: linear-gradient(135deg, #4a9eff, #8a6eff);
display: flex;
align-items: center;
justify-content: center;
font-weight: bold;
font-size: 14px;
color: white;
box-shadow: 0 2px 6px rgba(0,0,0,0.2);
}
.bot-actions {
display: flex;
gap: 10px;
opacity: 0;
transition: opacity 0.3s;
margin-top: 5px;
}
.action-btn {
background: transparent;
border: none;
color: var(--text-secondary);
cursor: pointer;
padding: 4px;
border-radius: 4px;
display: flex;
align-items: center;
transition: color 0.2s, background 0.2s;
}
.action-btn:hover {
color: var(--text-primary);
background: rgba(255,255,255,0.08);
}
.action-btn svg { width: 16px; height: 16px; fill: currentColor; }
.typing-cursor::after {
content: '';
display: inline-block;
width: 10px;
height: 10px;
background: var(--accent-color);
border-radius: 50%;
margin-left: 5px;
vertical-align: middle;
animation: blink 1s infinite;
}
.footer-container {
padding: 0 20px 20px 20px;
background: linear-gradient(to top, var(--bg-color) 85%, transparent);
position: relative;
z-index: 60;
}
.input-box {
max-width: 850px;
margin: 0 auto;
background: var(--surface-color);
border-radius: 28px;
padding: 8px 10px 8px 20px;
display: flex;
align-items: center;
border: 1px solid rgba(255,255,255,0.1);
transition: border-color 0.2s, box-shadow 0.2s;
}
.input-box:focus-within {
border-color: rgba(74, 158, 255, 0.5);
box-shadow: 0 0 0 2px rgba(74, 158, 255, 0.1);
}
#userInput {
flex: 1;
background: transparent;
border: none;
color: white;
font-size: 1rem;
font-family: inherit;
padding: 10px 0;
}
#userInput::placeholder {
color: var(--text-secondary);
}
#mainBtn {
background: white;
color: black;
border: none;
width: 36px;
height: 36px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
margin-left: 8px;
transition: transform 0.2s;
}
#mainBtn:hover { transform: scale(1.05); }
.disclaimer {
text-align: center;
font-size: 0.75rem;
color: #666;
margin-top: 12px;
}
@keyframes slideUpFade {
from { opacity: 0; transform: translateY(15px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
::-webkit-scrollbar { width: 8px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb { background: #333; border-radius: 4px; }
</style>
</head>
<body>
<header>
<div class="brand-wrapper" onclick="location.reload()">
<div class="brand-logo">M</div>
<div class="brand-text">
MTP <span class="version-badge">2.5</span>
</div>
</div>
<div class="status-badge">
<span class="dot"></span>
<span id="statusText">Conectado</span>
</div>
</header>
<div id="chatScroll" class="chat-scroll">
<div class="msg-row bot" style="animation-delay: 0.1s;">
<div class="bot-avatar">M</div>
<div class="msg-content-wrapper">
<div class="msg-text">
¡Hola! Soy MTP 2.5 ¿En qué puedo ayudarte hoy?
</div>
</div>
</div>
</div>
<div class="footer-container">
<div class="input-box">
<input type="text" id="userInput" placeholder="Escribe un mensaje..." autocomplete="off">
<button id="mainBtn" onclick="handleBtnClick()"></button>
</div>
<div class="disclaimer">
MTP puede cometer errores. Considera verificar la información importante.
</div>
</div>
<script>
const chatScroll = document.getElementById('chatScroll');
const userInput = document.getElementById('userInput');
const mainBtn = document.getElementById('mainBtn');
const statusText = document.getElementById('statusText');
let isGenerating = false;
let abortController = null;
let typingTimeout = null;
let lastUserPrompt = "";
const ICON_SEND = `<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"></path></svg>`;
const ICON_STOP = `<svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="0"><rect x="2" y="2" width="20" height="20" rx="4" ry="4"></rect></svg>`;
mainBtn.innerHTML = ICON_SEND;
function scrollToBottom() {
chatScroll.scrollTop = chatScroll.scrollHeight;
}
function setBtnState(state) {
if (state === 'sending') {
mainBtn.innerHTML = ICON_STOP;
isGenerating = true;
statusText.textContent = "Pensando...";
} else {
mainBtn.innerHTML = ICON_SEND;
isGenerating = false;
abortController = null;
statusText.textContent = "Conectado";
}
}
function handleBtnClick() {
if (isGenerating) {
stopGeneration();
} else {
sendMessage();
}
}
function stopGeneration() {
if (abortController) abortController.abort();
if (typingTimeout) clearTimeout(typingTimeout);
const activeCursor = document.querySelector('.typing-cursor');
if (activeCursor) activeCursor.classList.remove('typing-cursor');
setBtnState('idle');
userInput.focus();
}
async function sendMessage(textOverride = null) {
const text = textOverride || userInput.value.trim();
if (!text) return;
lastUserPrompt = text;
if (!textOverride) {
userInput.value = '';
addMessage(text, 'user');
}
setBtnState('sending');
abortController = new AbortController();
const botRow = document.createElement('div');
botRow.className = 'msg-row bot';
const avatar = document.createElement('div');
avatar.className = 'bot-avatar';
avatar.textContent = 'M';
const wrapper = document.createElement('div');
wrapper.className = 'msg-content-wrapper';
const msgText = document.createElement('div');
msgText.className = 'msg-text';
wrapper.appendChild(msgText);
botRow.appendChild(avatar);
botRow.appendChild(wrapper);
chatScroll.appendChild(botRow);
scrollToBottom();
try {
const response = await fetch('/generate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text: text }),
signal: abortController.signal
});
const data = await response.json();
if (!isGenerating) return;
const reply = data.reply || "Lo siento, no pude procesar tu solicitud.";
await typeWriter(msgText, reply);
if (isGenerating) {
addActions(wrapper, reply);
setBtnState('idle');
}
} catch (error) {
if (error.name === 'AbortError') {
msgText.textContent += " [Detenido]";
} else {
msgText.textContent = "Error de conexión. Intenta de nuevo.";
msgText.style.color = "#ff8b8b";
}
setBtnState('idle');
}
}
function addMessage(text, sender) {
const row = document.createElement('div');
row.className = `msg-row ${sender}`;
const content = document.createElement('div');
content.className = 'msg-content';
content.textContent = text;
row.appendChild(content);
chatScroll.appendChild(row);
scrollToBottom();
}
function typeWriter(element, text, speed = 12) {
return new Promise(resolve => {
let i = 0;
element.textContent = '';
element.classList.add('typing-cursor');
function type() {
if (!isGenerating) {
element.classList.remove('typing-cursor');
resolve();
return;
}
if (i < text.length) {
element.textContent += text.charAt(i);
i++;
scrollToBottom();
typingTimeout = setTimeout(type, speed + Math.random() * 5);
} else {
element.classList.remove('typing-cursor');
resolve();
}
}
type();
});
}
function addActions(wrapperElement, textToCopy) {
const actionsDiv = document.createElement('div');
actionsDiv.className = 'bot-actions';
const copyBtn = document.createElement('button');
copyBtn.className = 'action-btn';
copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
copyBtn.onclick = () => {
navigator.clipboard.writeText(textToCopy);
copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M20 6L9 17l-5-5"></path></svg>`;
setTimeout(() => {
copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
}, 1500);
};
const regenBtn = document.createElement('button');
regenBtn.className = 'action-btn';
regenBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
regenBtn.onclick = () => {
sendMessage(lastUserPrompt);
};
actionsDiv.appendChild(copyBtn);
actionsDiv.appendChild(regenBtn);
wrapperElement.appendChild(actionsDiv);
requestAnimationFrame(() => actionsDiv.style.opacity = "1");
scrollToBottom();
}
userInput.addEventListener('keydown', (e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
handleBtnClick();
}
});
window.onload = () => userInput.focus();
</script>
</body>
</html>
"""
if __name__ == "__main__":
    # Listening port comes from the PORT environment variable, defaulting to 7860.
    port = int(os.getenv("PORT", "7860"))
    startup_banner = (
        f"\n🚀 Iniciando servidor MTP-1.1 en puerto {port}...",
        f"🌐 Interfaz web: http://0.0.0.0:{port}",
        f"📡 API docs: http://0.0.0.0:{port}/docs",
    )
    for banner_line in startup_banner:
        print(banner_line)
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")