# MTP / app.py
# Uploaded by teszenofficial ("Upload 2 files", commit 19a3b5c, verified).
import os
import sys
import torch
import pickle
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from huggingface_hub import snapshot_download
import uvicorn
# ======================
# DEVICE CONFIGURATION (GPU/CPU)
# ======================
# Use CUDA whenever torch can see an NVIDIA GPU; otherwise fall back to CPU.
_cuda_ok = torch.cuda.is_available()
DEVICE = "cuda" if _cuda_ok else "cpu"
print(
    "✅ GPU NVIDIA detectada. Usando CUDA."
    if _cuda_ok
    else "⚠️ GPU no detectada. Usando CPU (puede ser más lento)."
)
MODEL_REPO = "teszenofficial/mtptz"
# ======================
# MODEL DOWNLOAD
# ======================
print("--- SISTEMA MTP 1.1 ---")
print(f"Descargando/Verificando modelo desde {MODEL_REPO}...")
# Download the model repository into ./mtptz_repo (relative to the CWD).
repo_path = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir="mtptz_repo",
)
# The downloaded repo is put first on sys.path so the `model` and `tokenizer`
# imports below are resolved from it. NOTE(review): this executes Python code
# fetched from MODEL_REPO, so that repo must be trusted.
sys.path.insert(0, repo_path)
try:
    from model import MTPMiniModel
    from tokenizer import MTPTokenizer
except ImportError:
    # Both classes are required to build the model further down. Re-raise so
    # startup fails here with the real import error instead of a later,
    # misleading NameError when MTPTokenizer / MTPMiniModel are used.
    print("Advertencia: Verifica la estructura de archivos del modelo.")
    raise
# ======================
# MODEL LOADING
# ======================
print("Cargando modelo en memoria...")
# SECURITY NOTE(review): pickle.load can execute arbitrary code, and this file
# comes from the remote MODEL_REPO snapshot. Only point MODEL_REPO at a trusted
# repo (or switch to a safer checkpoint format).
_checkpoint_path = os.path.join(repo_path, "mtp_mini.pkl")
with open(_checkpoint_path, "rb") as _ckpt_file:
    model_data = pickle.load(_ckpt_file)

_tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
tokenizer = MTPTokenizer(_tokenizer_path)

# The checkpoint dict carries the training config; the architecture
# hyper-parameters live under config["model"].
config = model_data["config"]
_vocab_size = model_data["vocab_size"]
_arch = config["model"]
model = MTPMiniModel(
    vocab_size=_vocab_size,
    d_model=_arch["d_model"],
    n_layers=_arch["n_layers"],
    n_heads=_arch["n_heads"],
    d_ff=_arch["d_ff"],
    max_seq_len=_arch["max_seq_len"],
    dropout=0.0,  # inference only: no dropout
)

# Load the trained weights, then move the model to DEVICE (GPU or CPU) and
# switch it to evaluation mode.
model.load_state_dict(model_data["model_state_dict"])
model.to(DEVICE)
model.eval()
print(f"🚀 MTP 1.1 listo y corriendo en: {DEVICE.upper()}")
# ======================
# FASTAPI APP
# ======================
app = FastAPI(title="MTP 1.1 API")


class Prompt(BaseModel):
    """JSON body accepted by ``POST /generate``: the raw user message."""

    text: str
@app.post("/generate")
def generate(prompt: Prompt):
    """Produce a model reply for one user message.

    The message is wrapped in the "### Instrucción / ### Respuesta" template,
    prefixed with the tokenizer's BOS id, and sampled with fixed settings. The
    continuation is cut at the first EOS token and at the first "###" marker.

    Returns:
        ``{"reply": str}``. The reply is ``""`` when the input is blank.
    """
    user_input = prompt.text.strip()
    if not user_input:
        return {"reply": ""}

    full_prompt = f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
    prompt_ids = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
    # The input tensor has to live on the same device as the model.
    input_ids = torch.tensor([prompt_ids], device=DEVICE)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=150,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )

    # Keep only the tokens generated after the prompt, up to the first EOS.
    new_ids = output_ids[0, len(prompt_ids):].tolist()
    eos = tokenizer.eos_id()
    if eos in new_ids:
        new_ids = new_ids[:new_ids.index(eos)]

    reply = tokenizer.decode(new_ids).strip()
    # Discard anything after the model starts a new "###" section.
    reply = reply.partition("###")[0].strip()
    return {"reply": reply}
# ======================
# WEB INTERFACE (chat frontend)
# ======================
# Single-page chat UI served at "/". The page POSTs {"text": ...} to /generate
# and renders the returned "reply" with a typewriter effect. The markup is a
# module-level constant, so it is built once rather than on every request.
_CHAT_HTML = """<!DOCTYPE html>
<html lang="es">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
    <title>MTP 1.1</title>
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
    <style>
        /* --- VARIABLES & THEME --- */
        :root {
            --bg-color: #131314;
            --surface-color: #1E1F20;
            --accent-color: #4a9eff;
            --text-primary: #e3e3e3;
            --text-secondary: #9aa0a6;
            --user-bubble: #282a2c;
            --bot-actions-color: #c4c7c5;
            --logo-url: url('https://i.postimg.cc/yxS54PF3/IMG-3082.jpg');
        }
        * { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
        body {
            margin: 0;
            background-color: var(--bg-color);
            font-family: 'Inter', sans-serif;
            color: var(--text-primary);
            height: 100dvh;
            display: flex;
            flex-direction: column;
            overflow: hidden;
        }
        /* --- HEADER --- */
        header {
            padding: 12px 20px;
            display: flex;
            align-items: center;
            justify-content: space-between;
            background: rgba(19, 19, 20, 0.85);
            backdrop-filter: blur(12px);
            position: fixed;
            top: 0;
            width: 100%;
            z-index: 50;
            border-bottom: 1px solid rgba(255,255,255,0.05);
        }
        .brand-wrapper {
            display: flex;
            align-items: center;
            gap: 12px;
            cursor: pointer;
        }
        .brand-logo {
            width: 32px;
            height: 32px;
            border-radius: 50%;
            background-image: var(--logo-url);
            background-size: cover;
            background-position: center;
            border: 1px solid rgba(255,255,255,0.1);
        }
        .brand-text {
            font-weight: 500;
            font-size: 1.05rem;
            display: flex;
            align-items: center;
            gap: 8px;
        }
        .version-badge {
            font-size: 0.75rem;
            background: rgba(74, 158, 255, 0.15);
            color: #8ab4f8;
            padding: 2px 8px;
            border-radius: 12px;
            font-weight: 600;
        }
        /* --- CHAT AREA --- */
        .chat-scroll {
            flex: 1;
            overflow-y: auto;
            padding: 80px 20px 40px 20px;
            display: flex;
            flex-direction: column;
            gap: 30px;
            max-width: 850px;
            margin: 0 auto;
            width: 100%;
            scroll-behavior: smooth;
        }
        /* Filas de Mensaje */
        .msg-row {
            display: flex;
            gap: 16px;
            width: 100%;
            opacity: 0;
            transform: translateY(10px);
            animation: slideUpFade 0.4s cubic-bezier(0.2, 0.8, 0.2, 1) forwards;
        }
        .msg-row.user { justify-content: flex-end; }
        .msg-row.bot { justify-content: flex-start; align-items: flex-start; }
        /* Contenido */
        .msg-content {
            line-height: 1.6;
            font-size: 1rem;
            word-wrap: break-word;
            max-width: 85%;
        }
        .user .msg-content {
            background-color: var(--user-bubble);
            padding: 10px 18px;
            border-radius: 18px;
            border-top-right-radius: 4px;
            color: #fff;
        }
        .bot .msg-content-wrapper {
            display: flex;
            flex-direction: column;
            gap: 8px;
            width: 100%;
        }
        .bot .msg-text {
            padding-top: 6px;
            color: var(--text-primary);
            /* Keep the line breaks of model replies (they are set via textContent). */
            white-space: pre-wrap;
        }
        /* Avatar Bot */
        .bot-avatar {
            width: 34px;
            height: 34px;
            min-width: 34px;
            border-radius: 50%;
            background-image: var(--logo-url);
            background-size: cover;
            box-shadow: 0 2px 6px rgba(0,0,0,0.2);
        }
        /* Acciones Bot */
        .bot-actions {
            display: flex;
            gap: 10px;
            opacity: 0;
            transition: opacity 0.3s;
            margin-top: 5px;
        }
        .action-btn {
            background: transparent;
            border: none;
            color: var(--text-secondary);
            cursor: pointer;
            padding: 4px;
            border-radius: 4px;
            display: flex;
            align-items: center;
            transition: color 0.2s, background 0.2s;
        }
        .action-btn:hover {
            color: var(--text-primary);
            background: rgba(255,255,255,0.08);
        }
        .action-btn svg { width: 16px; height: 16px; fill: currentColor; }
        /* Efecto Escritura (BOLITA AZUL) */
        .typing-cursor::after {
            content: '';
            display: inline-block;
            width: 10px;
            height: 10px;
            background: var(--accent-color);
            border-radius: 50%;
            margin-left: 5px;
            vertical-align: middle;
            animation: blink 1s infinite;
        }
        /* --- FOOTER & INPUT --- */
        .footer-container {
            padding: 0 20px 20px 20px;
            background: linear-gradient(to top, var(--bg-color) 85%, transparent);
            position: relative;
            z-index: 60;
        }
        .input-box {
            max-width: 850px;
            margin: 0 auto;
            background: var(--surface-color);
            border-radius: 28px;
            padding: 8px 10px 8px 20px;
            display: flex;
            align-items: center;
            border: 1px solid rgba(255,255,255,0.1);
            transition: border-color 0.2s, box-shadow 0.2s;
        }
        .input-box:focus-within {
            border-color: rgba(74, 158, 255, 0.5);
            box-shadow: 0 0 0 2px rgba(74, 158, 255, 0.1);
        }
        #userInput {
            flex: 1;
            background: transparent;
            border: none;
            color: white;
            font-size: 1rem;
            font-family: inherit;
            padding: 10px 0;
        }
        #mainBtn {
            background: white;
            color: black;
            border: none;
            width: 36px;
            height: 36px;
            border-radius: 50%;
            display: flex;
            align-items: center;
            justify-content: center;
            cursor: pointer;
            margin-left: 8px;
            transition: transform 0.2s;
        }
        #mainBtn:hover { transform: scale(1.05); }
        .disclaimer {
            text-align: center;
            font-size: 0.75rem;
            color: #666;
            margin-top: 12px;
        }
        /* --- ANIMACIONES --- */
        @keyframes slideUpFade {
            from { opacity: 0; transform: translateY(15px); }
            to { opacity: 1; transform: translateY(0); }
        }
        @keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
        @keyframes pulseAvatar {
            0% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0.4); }
            70% { box-shadow: 0 0 0 8px rgba(74, 158, 255, 0); }
            100% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0); }
        }
        .pulsing { animation: pulseAvatar 1.5s infinite; }
        ::-webkit-scrollbar { width: 8px; }
        ::-webkit-scrollbar-track { background: transparent; }
        ::-webkit-scrollbar-thumb { background: #333; border-radius: 4px; }
    </style>
</head>
<body>
    <header>
        <div class="brand-wrapper" onclick="location.reload()">
            <div class="brand-logo"></div>
            <div class="brand-text">
                MTP <span class="version-badge">1.1</span>
            </div>
        </div>
    </header>
    <div id="chatScroll" class="chat-scroll">
        <!-- Bienvenida -->
        <div class="msg-row bot" style="animation-delay: 0.1s;">
            <div class="bot-avatar"></div>
            <div class="msg-content-wrapper">
                <!-- Kept on one line: .msg-text preserves whitespace (pre-wrap). -->
                <div class="msg-text">¡Hola! Soy MTP 1.1. ¿En qué puedo ayudarte hoy?</div>
            </div>
        </div>
    </div>
    <div class="footer-container">
        <div class="input-box">
            <input type="text" id="userInput" placeholder="Escribe un mensaje..." autocomplete="off">
            <button id="mainBtn" onclick="handleBtnClick()">
                <!-- Icono dinámico -->
            </button>
        </div>
        <div class="disclaimer">
            MTP puede cometer errores. Considera verificar la información importante.
        </div>
    </div>
    <script>
        const chatScroll = document.getElementById('chatScroll');
        const userInput = document.getElementById('userInput');
        const mainBtn = document.getElementById('mainBtn');
        // Variables de Estado
        let isGenerating = false;
        let abortController = null;
        let typingTimeout = null;
        let lastUserPrompt = "";
        // Iconos SVG
        const ICON_SEND = `<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"></path></svg>`;
        const ICON_STOP = `<svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="0"><rect x="2" y="2" width="20" height="20" rx="4" ry="4"></rect></svg>`;
        // Inicial
        mainBtn.innerHTML = ICON_SEND;
        // --- UTILS ---
        function scrollToBottom() {
            chatScroll.scrollTop = chatScroll.scrollHeight;
        }
        function setBtnState(state) {
            if (state === 'sending') {
                mainBtn.innerHTML = ICON_STOP;
                isGenerating = true;
            } else {
                mainBtn.innerHTML = ICON_SEND;
                isGenerating = false;
                abortController = null;
            }
        }
        // --- CORE ---
        function handleBtnClick() {
            if (isGenerating) {
                stopGeneration();
            } else {
                sendMessage();
            }
        }
        function stopGeneration() {
            if (abortController) abortController.abort();
            if (typingTimeout) clearTimeout(typingTimeout);
            // UI Limpieza
            const activeCursor = document.querySelector('.typing-cursor');
            if (activeCursor) activeCursor.classList.remove('typing-cursor');
            const activeAvatar = document.querySelector('.pulsing');
            if (activeAvatar) activeAvatar.classList.remove('pulsing');
            setBtnState('idle');
            userInput.focus();
        }
        async function sendMessage(textOverride = null) {
            // One generation at a time: a concurrent one (e.g. from a regenerate
            // button) would overwrite abortController/typingTimeout and the first
            // request could no longer be stopped.
            if (isGenerating) return;
            const text = textOverride || userInput.value.trim();
            if (!text) return;
            lastUserPrompt = text;
            if (!textOverride) {
                userInput.value = '';
                addMessage(text, 'user');
            }
            setBtnState('sending');
            abortController = new AbortController();
            // Bot Placeholder
            const botRow = document.createElement('div');
            botRow.className = 'msg-row bot';
            const avatar = document.createElement('div');
            avatar.className = 'bot-avatar pulsing';
            const wrapper = document.createElement('div');
            wrapper.className = 'msg-content-wrapper';
            const msgText = document.createElement('div');
            msgText.className = 'msg-text';
            wrapper.appendChild(msgText);
            botRow.appendChild(avatar);
            botRow.appendChild(wrapper);
            chatScroll.appendChild(botRow);
            scrollToBottom();
            try {
                const response = await fetch('/generate', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ text: text }),
                    signal: abortController.signal
                });
                // Error statuses (422, 500, ...) must not be shown as a reply.
                if (!response.ok) throw new Error('HTTP ' + response.status);
                const data = await response.json();
                if (!isGenerating) return;
                avatar.classList.remove('pulsing');
                const reply = data.reply || "No entendí eso.";
                await typeWriter(msgText, reply);
                if (isGenerating) {
                    addActions(wrapper, reply, text);
                    setBtnState('idle');
                }
            } catch (error) {
                if (error.name === 'AbortError') {
                    msgText.textContent += " [Detenido]";
                } else {
                    avatar.classList.remove('pulsing');
                    msgText.textContent = "Error de conexión.";
                    msgText.style.color = "#ff8b8b";
                    setBtnState('idle');
                }
            }
        }
        function addMessage(text, sender) {
            const row = document.createElement('div');
            row.className = `msg-row ${sender}`;
            const content = document.createElement('div');
            content.className = 'msg-content';
            content.textContent = text;
            row.appendChild(content);
            chatScroll.appendChild(row);
            scrollToBottom();
        }
        function typeWriter(element, text, speed = 12) {
            return new Promise(resolve => {
                let i = 0;
                element.classList.add('typing-cursor');
                function type() {
                    if (!isGenerating) {
                        element.classList.remove('typing-cursor');
                        resolve();
                        return;
                    }
                    if (i < text.length) {
                        element.textContent += text.charAt(i);
                        i++;
                        scrollToBottom();
                        typingTimeout = setTimeout(type, speed + Math.random() * 5);
                    } else {
                        element.classList.remove('typing-cursor');
                        resolve();
                    }
                }
                type();
            });
        }
        // promptText: the prompt that produced this reply, so each regenerate
        // button re-sends its own prompt rather than the most recent one.
        function addActions(wrapperElement, textToCopy, promptText = lastUserPrompt) {
            const actionsDiv = document.createElement('div');
            actionsDiv.className = 'bot-actions';
            const copyBtn = document.createElement('button');
            copyBtn.className = 'action-btn';
            copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
            copyBtn.onclick = () => {
                // navigator.clipboard only exists in secure contexts; copying is best-effort.
                if (navigator.clipboard) navigator.clipboard.writeText(textToCopy).catch(() => {});
            };
            const regenBtn = document.createElement('button');
            regenBtn.className = 'action-btn';
            regenBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
            regenBtn.onclick = () => {
                sendMessage(promptText);
            };
            actionsDiv.appendChild(copyBtn);
            actionsDiv.appendChild(regenBtn);
            wrapperElement.appendChild(actionsDiv);
            requestAnimationFrame(() => actionsDiv.style.opacity = "1");
            scrollToBottom();
        }
        userInput.addEventListener('keydown', (e) => {
            // Ignore the Enter that confirms an IME composition (e.g. CJK input).
            if (e.key === 'Enter' && !e.isComposing) handleBtnClick();
        });
        window.onload = () => userInput.focus();
    </script>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat interface.

    Returns:
        The static HTML document (with inline CSS and JS) held in ``_CHAT_HTML``.
    """
    return _CHAT_HTML
# ======================
# ENTRY POINT
# ======================
def _serve():
    """Run the FastAPI app with uvicorn on all interfaces, port 7860."""
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    _serve()