# MTP-3 / app.py — Hugging Face Space entry point
# Uploaded by teszenofficial ("Create app.py", commit 1044bb0, verified)
import os
import sys
import torch
import pickle
import time
import gc
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from huggingface_hub import snapshot_download
import uvicorn
# ======================
# DEVICE CONFIGURATION
# ======================
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")

# Thread tuning for CPU inference: use half the cores, at least one.
if DEVICE == "cpu":
    # os.cpu_count() may return None on some platforms; fall back to 1
    # (the original `os.cpu_count() // 2` would raise TypeError there).
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))

# Inference-only service: disable autograd globally to save memory and time.
torch.set_grad_enabled(False)
MODEL_REPO = "TeszenAI/mtp-3.1"

# ======================
# MODEL DOWNLOAD AND LOADING
# ======================
print(f"📦 Descargando modelo desde {MODEL_REPO}...")
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtptz_repo",
)
# The repo ships its own model.py / tokenizer.py — make them importable.
sys.path.insert(0, repo_path)

from model import MTPMiniModel
from tokenizer import MTPTokenizer

print("🔧 Cargando tensores y configuración...")
# NOTE(review): pickle.load executes arbitrary code if the downloaded repo is
# ever compromised — acceptable only because MODEL_REPO is a trusted source.
checkpoint_file = os.path.join(repo_path, "mtp_mini.pkl")
with open(checkpoint_file, "rb") as f:
    checkpoint = pickle.load(f)

tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model"))
VOCAB_SIZE = tokenizer.sp.get_piece_size()
config = checkpoint["config"]
model_cfg = config["model"]

# Whether the checkpoint was trained with the SwiGLU feed-forward variant.
use_swiglu = model_cfg.get("use_swiglu", False)

print(f"🧠 Inicializando modelo...")
print(f" → Vocabulario: {VOCAB_SIZE}")
print(f" → Dimensión: {model_cfg['d_model']}")
print(f" → Capas: {model_cfg['n_layers']}")
print(f" → SwiGLU: {'✓' if use_swiglu else '✗'}")

model = MTPMiniModel(
    vocab_size=VOCAB_SIZE,
    d_model=model_cfg["d_model"],
    n_layers=model_cfg["n_layers"],
    n_heads=model_cfg["n_heads"],
    d_ff=model_cfg["d_ff"],
    max_seq_len=model_cfg["max_seq_len"],
    dropout=0.0,  # inference only — no dropout
    use_swiglu=use_swiglu,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Dynamic int8 quantization of Linear layers speeds up CPU inference.
if DEVICE == "cpu":
    print("⚡ Aplicando cuantización dinámica para CPU...")
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )

model.to(DEVICE)

param_count = sum(p.numel() for p in model.parameters())
print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
# ======================
# API CONFIG
# ======================
app = FastAPI(
    title="MTP-3 API",
    description="API para modelo de lenguaje MTP-3 mejorado",
    version="3.0",
)

# Fully open CORS: the UI is served by this same app, but the API is also
# intended to be reachable from third-party frontends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
class PromptRequest(BaseModel):
    """Request schema for POST /generate.

    Field bounds (length, ge/le ranges) are validated by pydantic before the
    endpoint body runs, so the handler can trust these values.
    """

    text: str = Field(..., max_length=2000, description="Texto de entrada")
    max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar")
    temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
    top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
    top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Penalización por repetición")
def build_prompt(user_input: str) -> str:
    """Wrap raw user text in the instruction/response template the model expects."""
    prefix = "### Instrucción:\n"
    suffix = "\n\n### Respuesta:\n"
    return prefix + user_input + suffix
# ======================
# ⚡ LOAD MANAGEMENT
# ======================
# Naive in-process counter of concurrent generation requests, used to throttle
# sampling parameters under load. Single-worker only — not shared across
# processes, and not protected by a lock (handler mutates it without awaiting
# in between, so it is safe under a single asyncio event loop).
ACTIVE_REQUESTS = 0


@app.post("/generate")
async def generate(req: PromptRequest):
    """Main text-generation endpoint.

    Builds the instruction prompt, runs model.generate, strips EOS /
    out-of-vocab ids, cuts the reply at the next "###" section marker,
    and returns the decoded text with a token count.

    NOTE(review): model.generate is synchronous, so this blocks the event
    loop for the duration of generation.
    """
    global ACTIVE_REQUESTS
    ACTIVE_REQUESTS += 1
    # Everything after the increment lives inside try/finally so the counter
    # is always released — the original leaked it if tokenization raised
    # before its (later) try block, and duplicated the decrement on the
    # empty-input early return.
    try:
        # Degrade gracefully under load: shorter, slightly cooler generations.
        dyn_max_tokens = req.max_tokens
        dyn_temperature = req.temperature
        if ACTIVE_REQUESTS > 2:
            print(f"⚠️ Carga alta ({ACTIVE_REQUESTS} requests). Ajustando parámetros.")
            dyn_max_tokens = min(dyn_max_tokens, 120)
            dyn_temperature = max(0.5, dyn_temperature * 0.9)

        user_input = req.text.strip()
        if not user_input:
            return {"reply": "", "tokens_generated": 0}

        full_prompt = build_prompt(user_input)
        tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
        input_ids = torch.tensor([tokens], device=DEVICE)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=dyn_max_tokens,
                temperature=dyn_temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                repetition_penalty=req.repetition_penalty,
            )

        # Keep only newly generated ids; drop EOS and anything out of vocab.
        gen_tokens = output_ids[0, len(tokens):].tolist()
        safe_tokens = [
            t for t in gen_tokens
            if 0 <= t < VOCAB_SIZE and t != tokenizer.eos_id()
        ]
        response = tokenizer.decode(safe_tokens).strip()

        # Cut the reply at the next section marker, if the model emitted one.
        if "###" in response:
            response = response.split("###")[0].strip()

        return {
            "reply": response,
            "tokens_generated": len(safe_tokens),
            "model": "MTP-3",
        }
    except Exception as e:
        # Best-effort error reporting to the client instead of a 500.
        print(f"❌ Error durante generación: {e}")
        return {
            "reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
            "error": str(e),
        }
    finally:
        ACTIVE_REQUESTS -= 1
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
# ======================
# 📡 STREAMING SSE
# ======================
@app.get("/generate_sse")
def generate_sse(
    text: str,
    max_tokens: int = 150,
    temperature: float = 0.7,
):
    """Streaming endpoint using Server-Sent Events (one token per event)."""
    global ACTIVE_REQUESTS
    ACTIVE_REQUESTS += 1

    def event_stream():
        try:
            prompt = build_prompt(text)
            ids = [tokenizer.bos_id()] + tokenizer.encode(prompt)
            input_ids = torch.tensor([ids], device=DEVICE)

            # Throttle under load, mirroring /generate.
            high_load = ACTIVE_REQUESTS > 2
            limit = 100 if high_load else max_tokens
            temp = max(0.5, temperature * 0.9) if high_load else temperature

            for _ in range(limit):
                with torch.no_grad():
                    logits, _ = model(input_ids)
                # Last position only, clipped to the tokenizer's vocabulary.
                logits = logits[:, -1, :VOCAB_SIZE]

                # Plain temperature sampling (no top-k/top-p on this endpoint).
                probs = torch.softmax(logits / temp, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1).item()

                if next_id == tokenizer.eos_id():
                    break

                if 0 <= next_id < VOCAB_SIZE:
                    token_text = tokenizer.decode([next_id])
                    # Stop as soon as the model starts a new "###" section.
                    if "###" in token_text:
                        break
                    yield f"data:{token_text}\n\n"
                    input_ids = torch.cat(
                        [input_ids, torch.tensor([[next_id]], device=DEVICE)],
                        dim=1,
                    )
                    # Small pacing delay between tokens.
                    time.sleep(0.01)
            yield "data:[DONE]\n\n"
        except Exception as e:
            yield f"data:[ERROR: {str(e)}]\n\n"
        finally:
            ACTIVE_REQUESTS -= 1
            if DEVICE == "cuda":
                torch.cuda.empty_cache()

    return StreamingResponse(event_stream(), media_type="text/event-stream")
# ======================
# 📊 INFO ENDPOINTS
# ======================
@app.get("/health")
def health_check():
    """Lightweight liveness/status probe for the service."""
    total_params = sum(p.numel() for p in model.parameters())
    return {
        "status": "healthy",
        "model": "MTP-3",
        "device": DEVICE,
        "active_requests": ACTIVE_REQUESTS,
        "vocab_size": VOCAB_SIZE,
        "parameters": total_params,
    }
@app.get("/info")
def model_info():
    """Detailed model and architecture metadata.

    Reads the architecture straight from the loaded checkpoint config so the
    report always matches the instantiated model.
    """
    mcfg = config["model"]

    # Build the feature list conditionally — the original appended a literal
    # None (JSON null) to the array when SwiGLU was disabled.
    improvements = [
        "RoPE (Rotary Position Embedding)",
        "RMSNorm",
        "Label Smoothing",
        "Repetition Penalty",
    ]
    if mcfg.get("use_swiglu"):
        improvements.append("SwiGLU (opcional)")

    return {
        "model_name": "MTP-3",
        "version": "3.0",
        "architecture": {
            "d_model": mcfg["d_model"],
            "n_layers": mcfg["n_layers"],
            "n_heads": mcfg["n_heads"],
            "d_ff": mcfg["d_ff"],
            "max_seq_len": mcfg["max_seq_len"],
            "vocab_size": VOCAB_SIZE,
            "use_swiglu": mcfg.get("use_swiglu", False),
        },
        "parameters": sum(p.numel() for p in model.parameters()),
        "device": DEVICE,
        "improvements": improvements,
    }
# ======================
# 🎨 WEB UI
# ======================
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat UI (inline HTML/CSS/JS).

    The page posts the user's message to POST /generate and types out the
    reply client-side; it does not use the /generate_sse streaming endpoint.
    """
    return """
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<title>MTP 3</title>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
<style>
:root {
--bg-color: #131314;
--surface-color: #1E1F20;
--accent-color: #4a9eff;
--text-primary: #e3e3e3;
--text-secondary: #9aa0a6;
--user-bubble: #282a2c;
--bot-actions-color: #c4c7c5;
--logo-url: url('https://i.postimg.cc/yxS54PF3/IMG-3082.jpg');
}
* { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
body {
margin: 0;
background-color: var(--bg-color);
font-family: 'Inter', sans-serif;
color: var(--text-primary);
height: 100dvh;
display: flex;
flex-direction: column;
overflow: hidden;
}
header {
padding: 12px 20px;
display: flex;
align-items: center;
justify-content: space-between;
background: rgba(19, 19, 20, 0.85);
backdrop-filter: blur(12px);
position: fixed;
top: 0;
width: 100%;
z-index: 50;
border-bottom: 1px solid rgba(255,255,255,0.05);
}
.brand-wrapper {
display: flex;
align-items: center;
gap: 12px;
cursor: pointer;
}
.brand-logo {
width: 32px;
height: 32px;
border-radius: 50%;
background-image: var(--logo-url);
background-size: cover;
background-position: center;
border: 1px solid rgba(255,255,255,0.1);
}
.brand-text {
font-weight: 500;
font-size: 1.05rem;
display: flex;
align-items: center;
gap: 8px;
}
.version-badge {
font-size: 0.75rem;
background: rgba(74, 158, 255, 0.15);
color: #8ab4f8;
padding: 2px 8px;
border-radius: 12px;
font-weight: 600;
}
.chat-scroll {
flex: 1;
overflow-y: auto;
padding: 80px 20px 40px 20px;
display: flex;
flex-direction: column;
gap: 30px;
max-width: 850px;
margin: 0 auto;
width: 100%;
scroll-behavior: smooth;
}
.msg-row {
display: flex;
gap: 16px;
width: 100%;
opacity: 0;
transform: translateY(10px);
animation: slideUpFade 0.4s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
}
.msg-row.user { justify-content: flex-end; }
.msg-row.bot { justify-content: flex-start; align-items: flex-start; }
.msg-content {
line-height: 1.6;
font-size: 1rem;
word-wrap: break-word;
max-width: 85%;
}
.user .msg-content {
background-color: var(--user-bubble);
padding: 10px 18px;
border-radius: 18px;
border-top-right-radius: 4px;
color: #fff;
}
.bot .msg-content-wrapper {
display: flex;
flex-direction: column;
gap: 8px;
width: 100%;
}
.bot .msg-text {
padding-top: 6px;
color: var(--text-primary);
}
.bot-avatar {
width: 34px;
height: 34px;
min-width: 34px;
border-radius: 50%;
background-image: var(--logo-url);
background-size: cover;
box-shadow: 0 2px 6px rgba(0,0,0,0.2);
}
.bot-actions {
display: flex;
gap: 10px;
opacity: 0;
transition: opacity 0.3s;
margin-top: 5px;
}
.action-btn {
background: transparent;
border: none;
color: var(--text-secondary);
cursor: pointer;
padding: 4px;
border-radius: 4px;
display: flex;
align-items: center;
transition: color 0.2s, background 0.2s;
}
.action-btn:hover {
color: var(--text-primary);
background: rgba(255,255,255,0.08);
}
.action-btn svg { width: 16px; height: 16px; fill: currentColor; }
.typing-cursor::after {
content: '';
display: inline-block;
width: 10px;
height: 10px;
background: var(--accent-color);
border-radius: 50%;
margin-left: 5px;
vertical-align: middle;
animation: blink 1s infinite;
}
.footer-container {
padding: 0 20px 20px 20px;
background: linear-gradient(to top, var(--bg-color) 85%, transparent);
position: relative;
z-index: 60;
}
.input-box {
max-width: 850px;
margin: 0 auto;
background: var(--surface-color);
border-radius: 28px;
padding: 8px 10px 8px 20px;
display: flex;
align-items: center;
border: 1px solid rgba(255,255,255,0.1);
transition: border-color 0.2s, box-shadow 0.2s;
}
.input-box:focus-within {
border-color: rgba(74, 158, 255, 0.5);
box-shadow: 0 0 0 2px rgba(74, 158, 255, 0.1);
}
#userInput {
flex: 1;
background: transparent;
border: none;
color: white;
font-size: 1rem;
font-family: inherit;
padding: 10px 0;
}
#mainBtn {
background: white;
color: black;
border: none;
width: 36px;
height: 36px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
margin-left: 8px;
transition: transform 0.2s;
}
#mainBtn:hover { transform: scale(1.05); }
.disclaimer {
text-align: center;
font-size: 0.75rem;
color: #666;
margin-top: 12px;
}
@keyframes slideUpFade {
from { opacity: 0; transform: translateY(15px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
@keyframes pulseAvatar {
0% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0.4); }
70% { box-shadow: 0 0 0 8px rgba(74, 158, 255, 0); }
100% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0); }
}
.pulsing { animation: pulseAvatar 1.5s infinite; }
::-webkit-scrollbar { width: 8px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb { background: #333; border-radius: 4px; }
</style>
</head>
<body>
<header>
<div class="brand-wrapper" onclick="location.reload()">
<div class="brand-logo"></div>
<div class="brand-text">
MTP <span class="version-badge">3</span>
</div>
</div>
</header>
<div id="chatScroll" class="chat-scroll">
<div class="msg-row bot" style="animation-delay: 0.1s;">
<div class="bot-avatar"></div>
<div class="msg-content-wrapper">
<div class="msg-text">
¡Hola! Soy MTP 3. ¿En qué puedo ayudarte hoy?
</div>
</div>
</div>
</div>
<div class="footer-container">
<div class="input-box">
<input type="text" id="userInput" placeholder="Escribe un mensaje..." autocomplete="off">
<button id="mainBtn" onclick="handleBtnClick()"></button>
</div>
<div class="disclaimer">
MTP puede cometer errores. Considera verificar la información importante.
</div>
</div>
<script>
const chatScroll = document.getElementById('chatScroll');
const userInput = document.getElementById('userInput');
const mainBtn = document.getElementById('mainBtn');
let isGenerating = false;
let abortController = null;
let typingTimeout = null;
let lastUserPrompt = "";
const ICON_SEND = `<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"></path></svg>`;
const ICON_STOP = `<svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="0"><rect x="2" y="2" width="20" height="20" rx="4" ry="4"></rect></svg>`;
mainBtn.innerHTML = ICON_SEND;
function scrollToBottom() {
chatScroll.scrollTop = chatScroll.scrollHeight;
}
function setBtnState(state) {
if (state === 'sending') {
mainBtn.innerHTML = ICON_STOP;
isGenerating = true;
} else {
mainBtn.innerHTML = ICON_SEND;
isGenerating = false;
abortController = null;
}
}
function handleBtnClick() {
if (isGenerating) {
stopGeneration();
} else {
sendMessage();
}
}
function stopGeneration() {
if (abortController) abortController.abort();
if (typingTimeout) clearTimeout(typingTimeout);
const activeCursor = document.querySelector('.typing-cursor');
if (activeCursor) activeCursor.classList.remove('typing-cursor');
const activeAvatar = document.querySelector('.pulsing');
if (activeAvatar) activeAvatar.classList.remove('pulsing');
setBtnState('idle');
userInput.focus();
}
async function sendMessage(textOverride = null) {
const text = textOverride || userInput.value.trim();
if (!text) return;
lastUserPrompt = text;
if (!textOverride) {
userInput.value = '';
addMessage(text, 'user');
}
setBtnState('sending');
abortController = new AbortController();
const botRow = document.createElement('div');
botRow.className = 'msg-row bot';
const avatar = document.createElement('div');
avatar.className = 'bot-avatar pulsing';
const wrapper = document.createElement('div');
wrapper.className = 'msg-content-wrapper';
const msgText = document.createElement('div');
msgText.className = 'msg-text';
wrapper.appendChild(msgText);
botRow.appendChild(avatar);
botRow.appendChild(wrapper);
chatScroll.appendChild(botRow);
scrollToBottom();
try {
const response = await fetch('/generate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text: text }),
signal: abortController.signal
});
const data = await response.json();
if (!isGenerating) return;
avatar.classList.remove('pulsing');
const reply = data.reply || "No entendí eso.";
await typeWriter(msgText, reply);
if (isGenerating) {
addActions(wrapper, reply);
setBtnState('idle');
}
} catch (error) {
if (error.name === 'AbortError') {
msgText.textContent += " [Detenido]";
} else {
avatar.classList.remove('pulsing');
msgText.textContent = "Error de conexión.";
msgText.style.color = "#ff8b8b";
setBtnState('idle');
}
}
}
function addMessage(text, sender) {
const row = document.createElement('div');
row.className = `msg-row ${sender}`;
const content = document.createElement('div');
content.className = 'msg-content';
content.textContent = text;
row.appendChild(content);
chatScroll.appendChild(row);
scrollToBottom();
}
function typeWriter(element, text, speed = 12) {
return new Promise(resolve => {
let i = 0;
element.classList.add('typing-cursor');
function type() {
if (!isGenerating) {
element.classList.remove('typing-cursor');
resolve();
return;
}
if (i < text.length) {
element.textContent += text.charAt(i);
i++;
scrollToBottom();
typingTimeout = setTimeout(type, speed + Math.random() * 5);
} else {
element.classList.remove('typing-cursor');
resolve();
}
}
type();
});
}
function addActions(wrapperElement, textToCopy) {
const actionsDiv = document.createElement('div');
actionsDiv.className = 'bot-actions';
const copyBtn = document.createElement('button');
copyBtn.className = 'action-btn';
copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
copyBtn.onclick = () => {
navigator.clipboard.writeText(textToCopy);
};
const regenBtn = document.createElement('button');
regenBtn.className = 'action-btn';
regenBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
regenBtn.onclick = () => {
sendMessage(lastUserPrompt);
};
actionsDiv.appendChild(copyBtn);
actionsDiv.appendChild(regenBtn);
wrapperElement.appendChild(actionsDiv);
requestAnimationFrame(() => actionsDiv.style.opacity = "1");
scrollToBottom();
}
userInput.addEventListener('keydown', (e) => {
if (e.key === 'Enter') handleBtnClick();
});
window.onload = () => userInput.focus();
</script>
</body>
</html>
"""
if __name__ == "__main__":
    # Hugging Face Spaces injects PORT; fall back to 7860 for local runs.
    port = int(os.environ.get("PORT", 7860))
    print(f"\n🚀 Iniciando servidor en puerto {port}...")
    print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
    print(f"📡 API docs: http://0.0.0.0:{port}/docs")
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")