# MTP_2 / app.py
# teszenofficial's picture
# Update app.py
# cda7b0a verified
import os
import sys
import torch
import pickle
import time
import gc
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from huggingface_hub import snapshot_download
import uvicorn
# ======================
# DEVICE CONFIGURATION
# ======================
# Use an NVIDIA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")

# CPU thread tuning: use half of the reported cores, but never fewer than one.
# os.cpu_count() returns None when the count cannot be determined, so fall
# back to 1 instead of failing on `None // 2`.
if DEVICE == "cpu":
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))

# This process only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Hugging Face Hub repository the model files are downloaded from.
MODEL_REPO = "TeszenAI/MTP-4"
# ======================
# MODEL DOWNLOAD AND LOADING
# ======================
print(f"📦 Descargando modelo desde {MODEL_REPO}...")
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtp_repo"
)
# The downloaded repository provides the `model` and `tokenizer` modules
# imported below, so its directory has to be on the import path first.
sys.path.insert(0, repo_path)

# Import the model and tokenizer classes from the downloaded repository.
from model import MTPMiniModel
from tokenizer import MTPTokenizer

print("🔧 Cargando tensores y configuración...")
# SECURITY NOTE(review): pickle.load can execute arbitrary code stored in the
# file, and the repository's Python modules are imported as well. Only point
# MODEL_REPO at a repository you trust.
with open(os.path.join(repo_path, "mtp_mini.pkl"), "rb") as f:
    model_data = pickle.load(f)

tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model"))
VOCAB_SIZE = tokenizer.sp.get_piece_size()
config = model_data["config"]

# Detect whether the model uses SwiGLU.
use_swiglu = config["model"].get("use_swiglu", False)

print("🧠 Inicializando modelo MTP 4...")
print(f" → Vocabulario: {VOCAB_SIZE}")
print(f" → Dimensión: {config['model']['d_model']}")
print(f" → Capas: {config['model']['n_layers']}")
print(f" → Cabezas: {config['model']['n_heads']}")
print(f" → SwiGLU: {'✓' if use_swiglu else '✗'}")

model = MTPMiniModel(
    vocab_size=VOCAB_SIZE,
    d_model=config["model"]["d_model"],
    n_layers=config["model"]["n_layers"],
    n_heads=config["model"]["n_heads"],
    d_ff=config["model"]["d_ff"],
    max_seq_len=config["model"]["max_seq_len"],
    dropout=0.0,  # inference only: dropout is disabled regardless of the config
    use_swiglu=use_swiglu
)
model.load_state_dict(model_data["model_state_dict"])
model.eval()

# Count parameters on the float model, BEFORE quantization. Dynamically
# quantized Linear layers keep their weights in packed buffers rather than in
# nn.Parameter objects, so counting afterwards would under-report the size.
param_count = sum(p.numel() for p in model.parameters())

# Dynamic int8 quantization of the Linear layers, applied on CPU only.
if DEVICE == "cpu":
    print("⚡ Aplicando cuantización dinámica para CPU...")
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8
    )

model.to(DEVICE)
print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
# ======================
# API CONFIGURATION
# ======================
# Application object; the metadata below is what FastAPI is given for the API.
app = FastAPI(
    version="4.0",
    title="MTP 4 API",
    description="API para modelo de lenguaje MTP 4 con RoPE, RMSNorm y SwiGLU",
)

# CORS is left wide open: any origin, any method, any header.
app.add_middleware(
    CORSMiddleware,
    **{
        option: ["*"]
        for option in ("allow_origins", "allow_methods", "allow_headers")
    },
)
class PromptRequest(BaseModel):
    """Request body for the ``/generate`` endpoint.

    Each field declares its own bounds through ``Field``.
    """

    # Prompt text supplied by the caller (at most 2000 characters).
    text: str = Field(..., description="Texto de entrada", max_length=2000)

    # Upper bound on the number of tokens to generate.
    max_tokens: int = Field(150, description="Tokens máximos a generar", ge=10, le=300)

    # Sampling controls.
    temperature: float = Field(0.7, description="Temperatura de muestreo", ge=0.1, le=2.0)
    top_k: int = Field(40, description="Top-k sampling", ge=1, le=100)
    top_p: float = Field(0.92, description="Top-p (nucleus) sampling", ge=0.1, le=1.0)
    repetition_penalty: float = Field(
        1.15, description="Penalización por repetición", ge=1.0, le=2.0
    )

    # Minimum length passed through to the model's generate() call.
    min_length: int = Field(20, description="Longitud mínima de respuesta", ge=5, le=100)
def build_prompt(user_input: str) -> str:
    """Wrap the user's text in the model's instruction/response template."""
    instruction_header = "### Instrucción:\n"
    response_header = "\n\n### Respuesta:\n"
    return f"{instruction_header}{user_input}{response_header}"
# ======================
# ⚡ LOAD MANAGEMENT
# ======================
# Number of generation requests currently being processed. The /generate and
# /generate_sse handlers compare it against the limit below and adjust it.
ACTIVE_REQUESTS = 0
# Once this many requests are active, new ones receive a "busy" reply.
MAX_CONCURRENT_REQUESTS = 3
@app.post("/generate")
async def generate(req: PromptRequest):
    """Generate a complete reply for a prompt and return it with timing stats.

    The request is answered with a "busy" payload when
    ``MAX_CONCURRENT_REQUESTS`` generations are already active. While other
    requests are in flight, the token budget is capped at 120 and the
    temperature is scaled by 0.95 (with a floor of 0.6).

    Args:
        req: Prompt text and sampling parameters.

    Returns:
        A dict. On success it holds ``reply``, ``tokens_generated``,
        ``generation_time`` (seconds), ``tokens_per_second``, ``model`` and
        ``device``. Blank input yields an empty ``reply`` with
        ``tokens_generated`` set to 0. When the server is busy or generation
        fails, it holds a user-facing ``reply`` plus an ``error`` field.
    """
    global ACTIVE_REQUESTS

    if ACTIVE_REQUESTS >= MAX_CONCURRENT_REQUESTS:
        return {
            "reply": "El servidor está ocupado. Por favor, intenta de nuevo en unos segundos.",
            "error": "too_many_requests",
            "active_requests": ACTIVE_REQUESTS
        }

    ACTIVE_REQUESTS += 1
    # Everything after the increment runs inside try/finally so the slot is
    # always released, including when tokenization or tensor creation fails.
    # NOTE(review): model.generate is a blocking call inside an async handler,
    # so nothing else in this coroutine can yield while it runs — confirm
    # that this is intended.
    try:
        # Dynamic adjustment under load.
        dyn_max_tokens = req.max_tokens
        dyn_temperature = req.temperature
        if ACTIVE_REQUESTS > 1:
            print(f"⚠️ Carga alta ({ACTIVE_REQUESTS} requests). Ajustando parámetros.")
            dyn_max_tokens = min(dyn_max_tokens, 120)
            dyn_temperature = max(0.6, dyn_temperature * 0.95)

        user_input = req.text.strip()
        if not user_input:
            return {"reply": "", "tokens_generated": 0}

        full_prompt = build_prompt(user_input)
        tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
        input_ids = torch.tensor([tokens], device=DEVICE)
        eos_id = tokenizer.eos_id()

        start_time = time.time()
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=dyn_max_tokens,
                temperature=dyn_temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                repetition_penalty=req.repetition_penalty,
                min_length=req.min_length,
                eos_token_id=eos_id
            )

        # Keep only the newly generated part, dropping the prompt tokens.
        gen_tokens = output_ids[0, len(tokens):].tolist()

        # Safety filter: stop at the first EOS and skip ids outside the vocab.
        safe_tokens = []
        for token_id in gen_tokens:
            if token_id == eos_id:
                break
            if 0 <= token_id < VOCAB_SIZE:
                safe_tokens.append(token_id)

        response = tokenizer.decode(safe_tokens).strip()

        # Cut the reply at the first section marker.
        if "###" in response:
            response = response.split("###")[0].strip()

        # Strip trailing "." characters when the reply ends in an ellipsis
        # (this is a no-op for the single-character "…").
        if response.endswith(("...", ". . .", "…")):
            response = response.rstrip(".")

        generation_time = time.time() - start_time
        tokens_per_second = len(safe_tokens) / generation_time if generation_time > 0 else 0

        return {
            "reply": response,
            "tokens_generated": len(safe_tokens),
            "generation_time": round(generation_time, 2),
            "tokens_per_second": round(tokens_per_second, 1),
            "model": "MTP-4",
            "device": DEVICE
        }
    except Exception as e:
        # Top-level boundary for this endpoint: log the traceback and answer
        # with an error payload instead of letting the request fail.
        print(f"❌ Error durante generación: {e}")
        import traceback
        traceback.print_exc()
        return {
            "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
            "error": str(e)
        }
    finally:
        ACTIVE_REQUESTS -= 1
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
# ======================
# 📡 SSE STREAMING
# ======================
@app.get("/generate_sse")
def generate_sse(
    text: str,
    max_tokens: int = 150,
    temperature: float = 0.7,
    top_k: int = 40,
    top_p: float = 0.92,
    repetition_penalty: float = 1.15
):
    """Stream a reply token by token as Server-Sent Events.

    Each sampled token is decoded and sent as its own event, and the stream
    ends with a ``[DONE]`` event. Failures are reported in-band as an
    ``[ERROR: ...]`` event. A single ``[ERROR: Servidor ocupado]`` event is
    sent when ``MAX_CONCURRENT_REQUESTS`` generations are already active.

    Args:
        text: User message; wrapped with ``build_prompt`` before tokenizing.
        max_tokens: Requested token budget; capped at 200, or at 100 while
            other requests are in flight.
        temperature: Sampling temperature; scaled by 0.95 (floor 0.6) under
            load and never allowed below 0.1.
        top_k: Keep only the ``top_k`` highest logits; values <= 0 disable it.
        top_p: Nucleus sampling threshold; values >= 1.0 disable it.
        repetition_penalty: Penalty applied to the logits of tokens already
            present in the context; 1.0 disables it.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """

    def sse_event(payload: str) -> str:
        """Frame ``payload`` as one SSE event, one ``data:`` line per text line.

        A raw newline inside a ``data:`` line would end the field early, so
        multi-line payloads are split into several ``data:`` lines.
        """
        lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        return "".join(f"data:{line}\n" for line in lines) + "\n"

    if ACTIVE_REQUESTS >= MAX_CONCURRENT_REQUESTS:
        def error_stream():
            yield "data:[ERROR: Servidor ocupado]\n\n"
        return StreamingResponse(error_stream(), media_type="text/event-stream")

    def event_stream():
        # The counter is assigned in this nested function, so it needs its own
        # `global` declaration; a declaration in the enclosing function would
        # not cover it and the name would be treated as an unbound local.
        global ACTIVE_REQUESTS
        # Take the slot here, right before the try/finally that releases it,
        # so a stream that is never started cannot leave the counter raised.
        ACTIVE_REQUESTS += 1
        try:
            full_prompt = build_prompt(text)
            tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
            input_ids = torch.tensor([tokens], device=DEVICE)
            eos_id = tokenizer.eos_id()

            # Dynamic adjustment: load may only shrink the budget, never
            # raise it above what the caller asked for.
            under_load = ACTIVE_REQUESTS > 1
            limit = min(max_tokens, 100 if under_load else 200)
            temp = max(0.6, temperature * 0.95) if under_load else temperature
            # Guard the division below against zero or negative temperatures;
            # 0.1 is the lower bound PromptRequest uses for /generate.
            temp = max(temp, 0.1)

            for _step in range(limit):
                with torch.no_grad():
                    logits, _ = model(input_ids)
                # Logits for the last position, restricted to the vocabulary.
                logits = logits[:, -1, :VOCAB_SIZE].clone()

                # Repetition penalty: push already-seen tokens towards lower
                # probability whatever the sign of their logit.
                if repetition_penalty != 1.0:
                    for token_id in set(input_ids[0].tolist()):
                        if logits[0, token_id] < 0:
                            logits[0, token_id] *= repetition_penalty
                        else:
                            logits[0, token_id] /= repetition_penalty

                # Temperature scaling.
                logits = logits / temp

                # Top-k filtering: mask everything below the k-th best logit.
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = float('-inf')

                # Top-p (nucleus) filtering.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing the threshold is
                    # kept, and always keep the most likely token.
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        1, sorted_indices, sorted_indices_to_remove
                    )
                    logits[indices_to_remove] = float('-inf')

                # Sample the next token.
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1).item()

                if next_id == eos_id:
                    break

                if 0 <= next_id < VOCAB_SIZE:
                    # NOTE(review): tokens are decoded one at a time; if the
                    # tokenizer drops leading-space markers when decoding a
                    # single piece, word spacing is lost — confirm.
                    token_text = tokenizer.decode([next_id])
                    # Stop when the model starts a new section marker.
                    if "###" in token_text:
                        break
                    yield sse_event(token_text)

                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_id]], device=DEVICE)],
                    dim=1
                )
                time.sleep(0.02)  # Rate control

            yield "data:[DONE]\n\n"
        except Exception as e:
            print(f"❌ Error en streaming: {e}")
            yield sse_event(f"[ERROR: {str(e)}]")
        finally:
            ACTIVE_REQUESTS -= 1
            if DEVICE == "cuda":
                torch.cuda.empty_cache()
            gc.collect()

    return StreamingResponse(event_stream(), media_type="text/event-stream")
# ======================
# 📊 INFORMATION ENDPOINTS
# ======================
@app.get("/health")
def health_check():
    """Report service status, load counters and, on CUDA, GPU memory usage.

    Returns:
        A dict with ``status``, ``model``, ``device``, ``active_requests``,
        ``max_concurrent_requests``, ``vocab_size`` and ``parameters``. When
        running on CUDA it also holds ``gpu_memory_allocated_mb`` and
        ``gpu_memory_reserved_mb``.
    """
    report = {
        "status": "healthy",
        "model": "MTP-4",
        "device": DEVICE,
        "active_requests": ACTIVE_REQUESTS,
        "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
        "vocab_size": VOCAB_SIZE,
        # Reuse the count taken at load time instead of walking every
        # parameter tensor on each health probe.
        "parameters": param_count,
    }
    if DEVICE == "cuda":
        bytes_per_mib = 1024**2
        report["gpu_memory_allocated_mb"] = round(
            torch.cuda.memory_allocated() / bytes_per_mib, 2
        )
        report["gpu_memory_reserved_mb"] = round(
            torch.cuda.memory_reserved() / bytes_per_mib, 2
        )
    return report
@app.get("/info")
def model_info():
    """Return detailed model information.

    Returns:
        A dict with the model name and version, the architecture values read
        from the loaded config, the parameter count (raw and in millions),
        the device, a list of feature labels and selected training settings.
    """
    model_cfg = config["model"]
    training_cfg = config["training"]
    swiglu_enabled = model_cfg.get("use_swiglu", False)

    improvements = [
        "RoPE (Rotary Position Embedding)",
        "RMSNorm (Root Mean Square Normalization)",
        "Label Smoothing (0.1)",
        "Repetition Penalty",
        "Early Stopping",
        "EOS Loss Weight",
        "Length Control",
        "Gradient Accumulation"
    ]
    if swiglu_enabled:
        improvements.append("SwiGLU Activation")

    return {
        "model_name": "MTP-4",
        "version": "4.0",
        "architecture": {
            "d_model": model_cfg["d_model"],
            "n_layers": model_cfg["n_layers"],
            "n_heads": model_cfg["n_heads"],
            "d_ff": model_cfg["d_ff"],
            "max_seq_len": model_cfg["max_seq_len"],
            "vocab_size": VOCAB_SIZE,
            "use_swiglu": swiglu_enabled,
            "dropout": model_cfg["dropout"]
        },
        # Reuse the count taken at load time rather than summing every
        # parameter tensor twice per request.
        "parameters": param_count,
        "parameters_human": f"{param_count/1e6:.1f}M",
        "device": DEVICE,
        "improvements": improvements,
        "training_config": {
            key: training_cfg[key]
            for key in (
                "batch_size",
                "accumulation_steps",
                "learning_rate",
                "weight_decay",
                "epochs",
            )
        }
    }
@app.get("/config")
def get_config():
    """Return the model, training, data and generation sections of the config."""
    # The first three sections are required; "generation" is optional and
    # defaults to an empty dict.
    exposed = {section: config[section] for section in ("model", "training", "data")}
    exposed["generation"] = config.get("generation", {})
    return exposed
# ======================
# 🎨 IMPROVED WEB INTERFACE
# ======================
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat UI as one inline HTML document.

    The page's script posts the user's message to ``/generate`` and types the
    reply out on the client. It also fetches ``/info`` once on load and logs
    the result to the browser console.
    """
    return """
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<title>MTP 4 - Chat Interface</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
<style>
:root {
--bg-color: #0a0a0b;
--surface-color: #1a1a1c;
--accent-color: #6366f1;
--text-primary: #e8e8ea;
--text-secondary: #9ca3af;
--user-bubble: #2d2d30;
--success-color: #10b981;
--warning-color: #f59e0b;
--error-color: #ef4444;
--logo-url: url('https://i.postimg.cc/yxS54PF3/IMG-3082.jpg');
}
* {
box-sizing: border-box;
outline: none;
-webkit-tap-highlight-color: transparent;
}
body {
margin: 0;
background: linear-gradient(135deg, #0a0a0b 0%, #1a1a1c 100%);
font-family: 'Inter', sans-serif;
color: var(--text-primary);
height: 100dvh;
display: flex;
flex-direction: column;
overflow: hidden;
}
header {
padding: 14px 24px;
display: flex;
align-items: center;
justify-content: space-between;
background: rgba(26, 26, 28, 0.9);
backdrop-filter: blur(16px);
position: fixed;
top: 0;
width: 100%;
z-index: 50;
border-bottom: 1px solid rgba(99, 102, 241, 0.1);
}
.brand-wrapper {
display: flex;
align-items: center;
gap: 14px;
cursor: pointer;
}
.brand-logo {
width: 36px;
height: 36px;
border-radius: 50%;
background-image: var(--logo-url);
background-size: cover;
background-position: center;
border: 2px solid rgba(99, 102, 241, 0.3);
box-shadow: 0 0 12px rgba(99, 102, 241, 0.2);
}
.brand-text {
font-weight: 600;
font-size: 1.15rem;
display: flex;
align-items: center;
gap: 10px;
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.version-badge {
font-size: 0.75rem;
background: linear-gradient(135deg, rgba(99, 102, 241, 0.2) 0%, rgba(139, 92, 246, 0.2) 100%);
color: #a5b4fc;
padding: 3px 10px;
border-radius: 14px;
font-weight: 700;
border: 1px solid rgba(99, 102, 241, 0.3);
}
.status-indicator {
width: 10px;
height: 10px;
border-radius: 50%;
background: var(--success-color);
animation: pulse 2s infinite;
box-shadow: 0 0 8px var(--success-color);
}
@keyframes pulse {
0%, 100% { opacity: 1; transform: scale(1); }
50% { opacity: 0.7; transform: scale(0.95); }
}
.chat-scroll {
flex: 1;
overflow-y: auto;
padding: 90px 24px 50px 24px;
display: flex;
flex-direction: column;
gap: 32px;
max-width: 900px;
margin: 0 auto;
width: 100%;
scroll-behavior: smooth;
}
.msg-row {
display: flex;
gap: 18px;
width: 100%;
opacity: 0;
transform: translateY(12px);
animation: slideUpFade 0.5s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
}
.msg-row.user { justify-content: flex-end; }
.msg-row.bot { justify-content: flex-start; align-items: flex-start; }
.msg-content {
line-height: 1.65;
font-size: 1rem;
word-wrap: break-word;
max-width: 85%;
}
.user .msg-content {
background: linear-gradient(135deg, #2d2d30 0%, #3a3a3d 100%);
padding: 12px 20px;
border-radius: 20px;
border-top-right-radius: 6px;
color: #fff;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}
.bot .msg-content-wrapper {
display: flex;
flex-direction: column;
gap: 10px;
width: 100%;
}
.bot .msg-text {
padding-top: 8px;
color: var(--text-primary);
white-space: pre-wrap;
}
.bot-avatar {
width: 38px;
height: 38px;
min-width: 38px;
border-radius: 50%;
background-image: var(--logo-url);
background-size: cover;
box-shadow: 0 0 16px rgba(99, 102, 241, 0.4);
border: 2px solid rgba(99, 102, 241, 0.3);
}
.bot-actions {
display: flex;
gap: 12px;
opacity: 0;
transition: opacity 0.3s;
margin-top: 6px;
}
.action-btn {
background: rgba(99, 102, 241, 0.1);
border: 1px solid rgba(99, 102, 241, 0.2);
color: var(--text-secondary);
cursor: pointer;
padding: 6px 12px;
border-radius: 8px;
display: flex;
align-items: center;
transition: all 0.2s;
font-size: 0.85rem;
}
.action-btn:hover {
color: var(--accent-color);
background: rgba(99, 102, 241, 0.15);
border-color: rgba(99, 102, 241, 0.4);
}
.action-btn svg {
width: 16px;
height: 16px;
fill: currentColor;
margin-right: 5px;
}
.typing-cursor::after {
content: '';
display: inline-block;
width: 3px;
height: 18px;
background: var(--accent-color);
margin-left: 3px;
vertical-align: middle;
animation: blink 0.8s infinite;
}
.footer-container {
padding: 0 24px 24px 24px;
background: linear-gradient(to top, rgba(10, 10, 11, 0.95) 85%, transparent);
position: relative;
z-index: 60;
}
.input-box {
max-width: 900px;
margin: 0 auto;
background: var(--surface-color);
border-radius: 30px;
padding: 10px 12px 10px 24px;
display: flex;
align-items: center;
border: 1px solid rgba(99, 102, 241, 0.2);
transition: all 0.3s;
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.3);
}
.input-box:focus-within {
border-color: rgba(99, 102, 241, 0.6);
box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.15), 0 4px 20px rgba(0, 0, 0, 0.4);
}
#userInput {
flex: 1;
background: transparent;
border: none;
color: white;
font-size: 1rem;
font-family: inherit;
padding: 10px 0;
resize: none;
max-height: 120px;
}
#mainBtn {
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
color: white;
border: none;
width: 40px;
height: 40px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
margin-left: 10px;
transition: all 0.2s;
box-shadow: 0 4px 12px rgba(99, 102, 241, 0.3);
}
#mainBtn:hover {
transform: scale(1.05);
box-shadow: 0 6px 16px rgba(99, 102, 241, 0.5);
}
#mainBtn:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: scale(1);
}
.disclaimer {
text-align: center;
font-size: 0.75rem;
color: #6b7280;
margin-top: 14px;
}
.stats-badge {
font-size: 0.7rem;
color: var(--text-secondary);
margin-top: 6px;
font-family: 'Monaco', monospace;
background: rgba(99, 102, 241, 0.05);
padding: 4px 8px;
border-radius: 6px;
display: inline-block;
}
@keyframes slideUpFade {
from { opacity: 0; transform: translateY(18px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes blink {
0%, 100% { opacity: 1; }
50% { opacity: 0.3; }
}
@keyframes pulseAvatar {
0% { box-shadow: 0 0 0 0 rgba(99, 102, 241, 0.5); }
70% { box-shadow: 0 0 0 10px rgba(99, 102, 241, 0); }
100% { box-shadow: 0 0 0 0 rgba(99, 102, 241, 0); }
}
.pulsing { animation: pulseAvatar 1.5s infinite; }
::-webkit-scrollbar { width: 10px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb {
background: rgba(99, 102, 241, 0.3);
border-radius: 5px;
}
::-webkit-scrollbar-thumb:hover { background: rgba(99, 102, 241, 0.5); }
.error-message {
color: var(--error-color);
font-size: 0.9rem;
padding: 10px 14px;
background: rgba(239, 68, 68, 0.1);
border-radius: 10px;
margin-top: 10px;
border: 1px solid rgba(239, 68, 68, 0.2);
}
</style>
</head>
<body>
<header>
<div class="brand-wrapper" onclick="location.reload()">
<div class="brand-logo"></div>
<div class="brand-text">
MTP <span class="version-badge">4.0</span>
</div>
</div>
<div class="status-indicator" title="Sistema operativo"></div>
</header>
<div id="chatScroll" class="chat-scroll">
<div class="msg-row bot" style="animation-delay: 0.1s;">
<div class="bot-avatar"></div>
<div class="msg-content-wrapper">
<div class="msg-text">
¡Hola! Soy MTP 4, un modelo de lenguaje avanzado con arquitectura Transformer optimizada.
Características principales:
• RoPE - Rotary Position Embedding para mejor contexto
• RMSNorm - Normalización estable y eficiente
• SwiGLU - Función de activación mejorada
• Control inteligente de repetición y coherencia
• Generación fluida y natural
¿En qué puedo ayudarte hoy?
</div>
</div>
</div>
</div>
<div class="footer-container">
<div class="input-box">
<textarea id="userInput" placeholder="Escribe un mensaje..." rows="1" autocomplete="off"></textarea>
<button id="mainBtn" onclick="handleBtnClick()"></button>
</div>
<div class="disclaimer">
MTP 4 puede cometer errores. Considera verificar la información importante.
</div>
</div>
<script>
const chatScroll = document.getElementById('chatScroll');
const userInput = document.getElementById('userInput');
const mainBtn = document.getElementById('mainBtn');
let isGenerating = false;
let abortController = null;
let typingTimeout = null;
let lastUserPrompt = "";
const ICON_SEND = `<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"></path></svg>`;
const ICON_STOP = `<svg width="16" height="16" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="0"><rect x="2" y="2" width="20" height="20" rx="4" ry="4"></rect></svg>`;
mainBtn.innerHTML = ICON_SEND;
// Auto-resize textarea
userInput.addEventListener('input', function() {
this.style.height = 'auto';
this.style.height = Math.min(this.scrollHeight, 120) + 'px';
});
function scrollToBottom() {
chatScroll.scrollTop = chatScroll.scrollHeight;
}
function setBtnState(state) {
if (state === 'sending') {
mainBtn.innerHTML = ICON_STOP;
mainBtn.disabled = false;
isGenerating = true;
} else if (state === 'disabled') {
mainBtn.disabled = true;
isGenerating = false;
} else {
mainBtn.innerHTML = ICON_SEND;
mainBtn.disabled = false;
isGenerating = false;
abortController = null;
}
}
function handleBtnClick() {
if (isGenerating) {
stopGeneration();
} else {
sendMessage();
}
}
function stopGeneration() {
if (abortController) abortController.abort();
if (typingTimeout) clearTimeout(typingTimeout);
const activeCursor = document.querySelector('.typing-cursor');
if (activeCursor) activeCursor.classList.remove('typing-cursor');
const activeAvatar = document.querySelector('.pulsing');
if (activeAvatar) activeAvatar.classList.remove('pulsing');
setBtnState('idle');
userInput.focus();
}
async function sendMessage(textOverride = null) {
const text = textOverride || userInput.value.trim();
if (!text) return;
lastUserPrompt = text;
if (!textOverride) {
userInput.value = '';
userInput.style.height = 'auto';
addMessage(text, 'user');
}
setBtnState('sending');
abortController = new AbortController();
const botRow = document.createElement('div');
botRow.className = 'msg-row bot';
const avatar = document.createElement('div');
avatar.className = 'bot-avatar pulsing';
const wrapper = document.createElement('div');
wrapper.className = 'msg-content-wrapper';
const msgText = document.createElement('div');
msgText.className = 'msg-text';
wrapper.appendChild(msgText);
botRow.appendChild(avatar);
botRow.appendChild(wrapper);
chatScroll.appendChild(botRow);
scrollToBottom();
try {
const startTime = performance.now();
const response = await fetch('/generate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
text: text,
max_tokens: 150,
temperature: 0.7,
top_k: 40,
top_p: 0.92,
repetition_penalty: 1.15,
min_length: 20
}),
signal: abortController.signal
});
const data = await response.json();
if (!isGenerating) return;
avatar.classList.remove('pulsing');
if (data.error) {
msgText.innerHTML = `<span style="color: var(--error-color);">Error: ${data.error}</span>`;
setBtnState('idle');
return;
}
const reply = data.reply || "No entendí eso.";
const endTime = performance.now();
const totalTime = ((endTime - startTime) / 1000).toFixed(2);
await typeWriter(msgText, reply);
if (isGenerating) {
// Agregar estadísticas
const stats = document.createElement('div');
stats.className = 'stats-badge';
stats.textContent = `${data.tokens_generated} tokens • ${data.tokens_per_second} t/s • ${totalTime}s • ${data.device}`;
wrapper.appendChild(stats);
addActions(wrapper, reply);
setBtnState('idle');
}
} catch (error) {
if (error.name === 'AbortError') {
msgText.textContent += " [Detenido]";
} else {
console.error('Error:', error);
avatar.classList.remove('pulsing');
msgText.innerHTML = `<span style="color: var(--error-color);">Error de conexión. Por favor, intenta de nuevo.</span>`;
setBtnState('idle');
}
}
}
function addMessage(text, sender) {
const row = document.createElement('div');
row.className = `msg-row ${sender}`;
const content = document.createElement('div');
content.className = 'msg-content';
content.textContent = text;
row.appendChild(content);
chatScroll.appendChild(row);
scrollToBottom();
}
function typeWriter(element, text, speed = 12) {
return new Promise(resolve => {
let i = 0;
element.classList.add('typing-cursor');
function type() {
if (!isGenerating) {
element.classList.remove('typing-cursor');
resolve();
return;
}
if (i < text.length) {
element.textContent += text.charAt(i);
i++;
scrollToBottom();
typingTimeout = setTimeout(type, speed + Math.random() * 5);
} else {
element.classList.remove('typing-cursor');
resolve();
}
}
type();
});
}
function addActions(wrapperElement, textToCopy) {
const actionsDiv = document.createElement('div');
actionsDiv.className = 'bot-actions';
const copyBtn = document.createElement('button');
copyBtn.className = 'action-btn';
copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>Copiar`;
copyBtn.onclick = () => {
navigator.clipboard.writeText(textToCopy).then(() => {
copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><polyline points="20 6 9 17 4 12"></polyline></svg>Copiado`;
setTimeout(() => {
copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>Copiar`;
}, 2000);
});
};
const regenBtn = document.createElement('button');
regenBtn.className = 'action-btn';
regenBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>Regenerar`;
regenBtn.onclick = () => {
sendMessage(lastUserPrompt);
};
actionsDiv.appendChild(copyBtn);
actionsDiv.appendChild(regenBtn);
wrapperElement.appendChild(actionsDiv);
requestAnimationFrame(() => actionsDiv.style.opacity = "1");
scrollToBottom();
}
userInput.addEventListener('keydown', (e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
handleBtnClick();
}
});
window.onload = () => {
userInput.focus();
// Cargar info del modelo
fetch('/info')
.then(r => r.json())
.then(data => {
console.log('MTP 4 cargado:', data);
})
.catch(e => console.error('Error cargando info:', e));
};
</script>
</body>
</html>
"""
if __name__ == "__main__":
    # Script entry point: read the port from $PORT (default 7860), print the
    # startup banner, then hand control to uvicorn.
    port = int(os.environ.get("PORT", 7860))
    base_url = f"http://0.0.0.0:{port}"
    startup_banner = (
        "\n🚀 Iniciando servidor MTP 4...",
        f"🌐 Interfaz web: {base_url}",
        f"📡 API docs: {base_url}/docs",
        f"📊 Health check: {base_url}/health",
        f"ℹ️ Model info: {base_url}/info",
        "\n✅ Sistema listo. Presiona Ctrl+C para detener.",
    )
    for banner_line in startup_banner:
        print(banner_line)
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")