# MTP-3.3.1 / app.py
# Hugging Face file-view header (not code): avatar of "teszenofficial";
# last commit "Update app.py" (e20058a, verified).
import os
import sys
import torch
import json
import time
import gc
import re
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from huggingface_hub import snapshot_download
import uvicorn
import math
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
# ======================
# CONFIGURACIÓN DE DISPOSITIVO
# ======================
# Pick the inference device once at import time; the model and the input
# tensors built later in this file are moved to DEVICE.
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU NVIDIA detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU.")

if DEVICE == "cpu":
    # os.cpu_count() returns None when the CPU count cannot be determined,
    # and torch.set_num_threads() requires an int, so fall back to 1.
    torch.set_num_threads(os.cpu_count() or 1)

# This process only runs inference, so autograd tracking is switched off.
# NOTE(review): the reviewed copy lost its indentation; this call is applied
# regardless of device here — confirm it was not meant to be CPU-only.
torch.set_grad_enabled(False)

# Hugging Face Hub repository the config, tokenizer and weights come from.
MODEL_REPO = "TeszenAI/MTP-3.3.1"
# ======================
# LIMPIEZA DE RESPUESTAS
# ======================
def clean_response(text: str, max_chars: int = 300, max_repeats: int = 2) -> str:
    """Trim a generated reply and collapse runaway word repetition.

    Args:
        text: Raw decoded model output.
        max_chars: Maximum number of characters kept from ``text``. When the
            cut lands inside a word, the trailing fragment is dropped so the
            reply never ends with a broken word (unless that fragment is the
            only word available).
        max_repeats: How many times a word may immediately repeat itself. A
            run of identical consecutive words keeps at most
            ``max_repeats + 1`` copies.

    Returns:
        ``""`` for empty input; a fixed apology when fewer than 3 characters
        survive the cleaning; otherwise the cleaned text with every
        whitespace run collapsed to a single space.
    """
    if not text:
        return ""

    max_chars = max(max_chars, 0)
    cut_mid_word = False
    if len(text) > max_chars:
        # The cut splits a word when the characters on both sides of the
        # boundary are non-whitespace.
        cut_mid_word = (
            max_chars > 0
            and not text[max_chars - 1].isspace()
            and not text[max_chars].isspace()
        )
        text = text[:max_chars]

    words = text.split()
    if cut_mid_word and len(words) > 1:
        words.pop()

    kept = []
    previous = None
    run_length = 0
    for word in words:
        if word == previous:
            run_length += 1
        else:
            previous = word
            run_length = 0
        if run_length <= max_repeats:
            kept.append(word)

    # split() + join() already normalises whitespace, so no regex pass is needed.
    cleaned = " ".join(kept)
    if len(cleaned) < 3:
        return "Lo siento, no pude generar una respuesta clara."
    return cleaned
# ======================
# ARQUITECTURA DEL MODELO (OPTIMIZADA)
# ======================
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale/shift.

    ``eps`` is added to the standard deviation (not to the variance), and the
    standard deviation comes from ``Tensor.std`` with its default settings.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        """Create a unit scale and zero shift of size ``d_model``."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Return ``weight * (x - mean) / (std + eps) + bias`` along dim -1."""
        mean = x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        centered = x - mean
        return self.weight * centered / spread + self.bias
class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention split across ``n_heads`` heads."""

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        """Build the Q/K/V/output projections.

        ``d_model`` must be divisible by ``n_heads``; each head works on
        ``d_model // n_heads`` features.
        """
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        per_head = projected.view(batch_size, seq_len, self.n_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over ``x`` (a 3-D tensor: batch, sequence, d_model).

        Positions where ``mask == 0`` are excluded from the softmax.
        """
        batch_size, seq_len, _unused = x.shape

        queries = self._split_heads(self.w_q(x), batch_size, seq_len)
        keys = self._split_heads(self.w_k(x), batch_size, seq_len)
        values = self._split_heads(self.w_v(x), batch_size, seq_len)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        mixed = torch.matmul(weights, values)

        # Merge the heads back into a single d_model-wide feature axis.
        merged = mixed.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(merged)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """Create the expanding (d_model -> d_ff) and contracting layers."""
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
class TransformerBlock(nn.Module):
    """One transformer layer: attention and feed-forward, each with a residual.

    Normalisation is applied to the input of each sub-layer (before the
    attention and before the feed-forward), and the sub-layer output is added
    back to the un-normalised stream after dropout.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        """Create the two sub-layers with their norms and dropouts."""
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run attention then feed-forward; ``mask`` is passed to attention."""
        attended = self.attention(self.norm1(x), mask)
        x = x + self.dropout1(attended)

        transformed = self.feed_forward(self.norm2(x))
        return x + self.dropout2(transformed)
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table added to the token embeddings."""

    def __init__(self, d_model: int, max_len: int = 5000):
        """Precompute the table for positions ``0 .. max_len - 1``.

        Even feature columns hold sines and odd columns hold cosines. The
        table is stored as the buffer ``pe`` with shape (1, max_len, d_model).
        """
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns; slicing keeps the assignment shapes aligned. For an even
        # d_model the slice is the whole tensor, so the table is unchanged.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x``."""
        return x + self.pe[:, :x.size(1), :]
class MTPModel(nn.Module):
    """Decoder-only transformer language model with a sampling loop."""

    def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8,
                 n_layers: int = 6, d_ff: int = 1024, dropout: float = 0.1, max_len: int = 512):
        """Build embedding, positional table, ``n_layers`` blocks and LM head.

        ``max_len`` sizes the positional table and is also the context window
        used by :meth:`generate`.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Return logits for the token ids in ``x`` (indexed as batch, seq).

        When ``mask`` is omitted a lower-triangular (causal) mask is used.
        """
        if mask is None:
            seq_len = x.size(1)
            # Allocate the mask directly on x's device instead of building it
            # on the CPU and copying it over on every call.
            mask = torch.tril(
                torch.ones(seq_len, seq_len, device=x.device)
            ).unsqueeze(0).unsqueeze(0)
        x = self.token_embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for block in self.blocks:
            x = block(x, mask)
        return self.lm_head(self.norm(x))

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=100, temperature=0.6, top_k=40, eos_id=3):
        """Sample up to ``max_new_tokens`` tokens after ``input_ids``.

        Only the first row of the batch is sampled (logits are read at batch
        index 0). The whole context is re-run through the model at each step.

        Args:
            input_ids: Integer tensor of token ids, shape (1, prompt_len).
            max_new_tokens: Upper bound on the number of appended tokens.
            temperature: Logit divisor; values <= 0 select greedy decoding
                instead of dividing by zero.
            top_k: Sample only among the ``top_k`` highest logits; a value
                <= 0 disables the filter.
            eos_id: Token id that stops generation. Token id 0 also stops it
                (presumably padding/unknown — confirm against the tokenizer).

        Returns:
            ``input_ids`` with the sampled tokens appended along dim 1; the
            stopping token itself is not appended.
        """
        generated = input_ids
        for _ in range(max_new_tokens):
            # Keep at most the last max_len tokens as context.
            context = generated if generated.size(1) <= self.max_len else generated[:, -self.max_len:]
            next_logits = self(context)[0, -1, :]

            if temperature <= 0:
                next_token = int(torch.argmax(next_logits).item())
            else:
                next_logits = next_logits / temperature
                if top_k > 0:
                    k = min(top_k, next_logits.size(-1))
                    top_k_vals, top_k_indices = torch.topk(next_logits, k)
                    next_logits = torch.full_like(next_logits, float('-inf'))
                    next_logits[top_k_indices] = top_k_vals
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()

            if next_token == eos_id or next_token == 0:
                break
            generated = torch.cat(
                [generated, torch.tensor([[next_token]], device=generated.device)], dim=1
            )
        return generated
# ======================
# CARGA DEL MODELO
# ======================
# Download (or reuse) the model repository into ./mtp_repo.
print(f"📦 Descargando modelo desde {MODEL_REPO}...")
repo_path = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir="mtp_repo")

# Configuration: use the repository's config.json when present, otherwise
# fall back to the built-in hyper-parameters.
config_path = os.path.join(repo_path, "config.json")
if os.path.exists(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
else:
    config = {"d_model": 256, "n_heads": 8, "n_layers": 6, "d_ff": 1024, "dropout": 0.1, "max_len": 512}

# Tokenizer: the vocabulary size always comes from the SentencePiece model,
# overriding any value in the config.
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
sp = spm.SentencePieceProcessor()
sp.load(tokenizer_path)
VOCAB_SIZE = sp.get_piece_size()
config["vocab_size"] = VOCAB_SIZE

print("🧠 Inicializando modelo...")
# NOTE(review): every key of config is forwarded as a keyword argument, so an
# unexpected key in config.json raises TypeError here.
model = MTPModel(**config)
model.to(DEVICE)

# Load the weights.
model_path = os.path.join(repo_path, "mtp_model.pt")
if os.path.exists(model_path):
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code from an untrusted checkpoint. Consider weights_only=True
    # once the checkpoint is confirmed to be a plain state dict.
    state_dict = torch.load(model_path, map_location=DEVICE)
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False hides key mismatches; report them so a partially loaded
    # model does not go unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"⚠️ Claves faltantes: {len(load_result.missing_keys)}, "
            f"claves inesperadas: {len(load_result.unexpected_keys)}"
        )
    print("✅ Pesos cargados")
else:
    # Without the weights file the model keeps its random initialisation.
    print(f"⚠️ No se encontró {model_path}; el modelo usa pesos sin entrenar.")
model.eval()

param_count = sum(p.numel() for p in model.parameters())
print(f"✅ Modelo listo: {param_count:,} params ({param_count/1e6:.1f}M)")
# ======================
# API
# ======================
# FastAPI application object on which this file's routes are registered.
app = FastAPI(title="MTP API", version="3.3.1")

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
class PromptRequest(BaseModel):
    """JSON body accepted by ``POST /generate``."""

    # User prompt; required, at most 2000 characters.
    text: str = Field(default=..., max_length=2000)

    # Requested number of new tokens, validated to the range 20..150.
    max_tokens: int = Field(100, ge=20, le=150)

    # Sampling temperature, validated to the range 0.3..1.0.
    temperature: float = Field(0.6, ge=0.3, le=1.0)
def build_prompt(user_input: str) -> str:
    """Wrap ``user_input`` in the instruction/response prompt template."""
    template = "### Instrucción:\n{}\n\n### Respuesta:\n"
    return template.format(user_input)
@app.post("/generate")
async def generate(req: PromptRequest):
    """Generate a reply for ``req.text``.

    Returns a dict with ``reply`` (cleaned text), ``tokens`` (number of
    generated tokens kept after filtering ids 0-3) and ``time`` (seconds
    spent in generation). Blank input yields ``{"reply": ""}``; any exception
    is logged and turned into a generic apology reply.

    The sampling temperature is fixed at 0.6 and ``req.max_tokens`` is capped
    at 100, whatever the request asks for.
    """
    max_prompt_tokens = 400
    special_ids = {0, 1, 2, 3}
    try:
        user_input = req.text.strip()
        if not user_input:
            return {"reply": ""}

        # Fixed parameters for speed and accuracy.
        max_tokens = min(req.max_tokens, 100)
        temperature = 0.6

        tokens = sp.encode(build_prompt(user_input))
        if len(tokens) > max_prompt_tokens:
            # Cutting the encoded prompt at the head would drop the trailing
            # response marker, so trim the user text and re-apply the
            # template instead.
            template_len = len(sp.encode(build_prompt("")))
            budget = max(max_prompt_tokens - template_len, 0)
            trimmed_input = sp.decode(sp.encode(user_input)[:budget])
            # The final slice is only a safety net in case re-encoding the
            # trimmed text yields a few more tokens than budgeted.
            tokens = sp.encode(build_prompt(trimmed_input))[:max_prompt_tokens]

        input_ids = torch.tensor([tokens], device=DEVICE)

        # NOTE(review): model.generate is a blocking call inside an async
        # handler, so the event loop (including /health) stalls while a reply
        # is being generated. Consider offloading it to a worker thread.
        start = time.time()
        output_ids = model.generate(input_ids, max_new_tokens=max_tokens, temperature=temperature, top_k=40)
        elapsed = time.time() - start

        # Keep only the newly generated ids and drop the reserved ids 0-3.
        gen_tokens = [t for t in output_ids[0, len(tokens):].tolist() if t not in special_ids]
        response = sp.decode(gen_tokens).strip() if gen_tokens else ""
        response = clean_response(response)

        return {
            "reply": response,
            "tokens": len(gen_tokens),
            "time": round(elapsed, 2)
        }
    except Exception as e:
        # Deliberate best-effort boundary: the web UI expects a "reply" field.
        print(f"Error: {e}")
        return {"reply": "Lo siento, ocurrió un error. Intenta de nuevo."}
@app.get("/health")
def health():
    """Liveness probe: return a constant status payload."""
    return dict(status="ok", model="MTP-3.3.1")
@app.get("/info")
def info():
    """Report the model name, its parameter count and the device in use."""
    details = {"model": "MTP-3.3.1"}
    details["parameters"] = param_count
    details["device"] = DEVICE
    return details
# ======================
# INTERFAZ WEB SIMPLE Y RÁPIDA
# ======================
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat interface.

    The returned HTML embeds its own CSS and JavaScript. The script POSTs
    ``{text, max_tokens: 100}`` as JSON to ``/generate`` and renders the
    ``reply`` and ``time`` fields of the response; the reply text is
    HTML-escaped before being inserted into the page.
    """
    # Everything inside the literal below is sent to the browser verbatim.
    return """
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MTP - Asistente IA</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
background: #1a1a2e;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
height: 100vh;
display: flex;
flex-direction: column;
}
.header {
background: #16213e;
padding: 15px 20px;
border-bottom: 1px solid #0f3460;
}
.header h1 { color: white; font-size: 1.2rem; }
.header p { color: #888; font-size: 0.75rem; margin-top: 4px; }
.messages {
flex: 1;
overflow-y: auto;
padding: 20px;
display: flex;
flex-direction: column;
gap: 12px;
}
.message {
max-width: 80%;
padding: 10px 15px;
border-radius: 18px;
font-size: 0.9rem;
line-height: 1.4;
animation: fadeIn 0.2s ease;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(5px); }
to { opacity: 1; transform: translateY(0); }
}
.user { background: #0f3460; color: white; align-self: flex-end; border-radius: 18px 4px 18px 18px; }
.bot { background: #16213e; color: #e0e0e0; align-self: flex-start; border-radius: 4px 18px 18px 18px; }
.input-area {
background: #16213e;
padding: 15px 20px;
border-top: 1px solid #0f3460;
display: flex;
gap: 10px;
}
input {
flex: 1;
padding: 12px 15px;
background: #0f3460;
border: none;
border-radius: 25px;
color: white;
font-size: 0.9rem;
outline: none;
}
input::placeholder { color: #888; }
button {
padding: 12px 25px;
background: #e94560;
border: none;
border-radius: 25px;
color: white;
font-weight: bold;
cursor: pointer;
transition: opacity 0.2s;
}
button:hover { opacity: 0.9; }
button:disabled { opacity: 0.5; cursor: not-allowed; }
.typing {
background: #16213e;
padding: 10px 15px;
border-radius: 18px;
align-self: flex-start;
}
.typing span {
display: inline-block;
width: 8px;
height: 8px;
background: #888;
border-radius: 50%;
margin: 0 2px;
animation: bounce 1.4s infinite;
}
.typing span:nth-child(2) { animation-delay: 0.2s; }
.typing span:nth-child(3) { animation-delay: 0.4s; }
@keyframes bounce {
0%, 60%, 100% { transform: translateY(0); }
30% { transform: translateY(-8px); }
}
.time-badge {
font-size: 0.65rem;
color: #666;
margin-top: 4px;
}
@media (max-width: 600px) {
.message { max-width: 95%; }
}
</style>
</head>
<body>
<div class="header">
<h1>🤖 MTP - Asistente IA</h1>
<p>v3.3.1 | Respuestas rápidas y precisas | Temperatura 0.6</p>
</div>
<div class="messages" id="messages">
<div class="message bot">✨ Hola, soy MTP. Pregúntame lo que quieras. Intento ser rápido y preciso.</div>
</div>
<div class="input-area">
<input type="text" id="input" placeholder="Escribe tu mensaje..." autocomplete="off">
<button id="send">Enviar</button>
</div>
<script>
const messages = document.getElementById('messages');
const input = document.getElementById('input');
const sendBtn = document.getElementById('send');
let loading = false;
function addMessage(text, isUser, time = null) {
const div = document.createElement('div');
div.className = `message ${isUser ? 'user' : 'bot'}`;
div.innerHTML = `<div>${escapeHtml(text)}</div>${time ? `<div class="time-badge">⚡ ${time}s</div>` : ''}`;
messages.appendChild(div);
messages.scrollTop = messages.scrollHeight;
}
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
function showTyping() {
const div = document.createElement('div');
div.className = 'typing';
div.id = 'typing';
div.innerHTML = '<span></span><span></span><span></span>';
messages.appendChild(div);
messages.scrollTop = messages.scrollHeight;
}
function hideTyping() {
const el = document.getElementById('typing');
if (el) el.remove();
}
async function sendMessage() {
const text = input.value.trim();
if (!text || loading) return;
input.value = '';
addMessage(text, true);
loading = true;
sendBtn.disabled = true;
showTyping();
try {
const res = await fetch('/generate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text: text, max_tokens: 100 })
});
const data = await res.json();
hideTyping();
addMessage(data.reply, false, data.time);
} catch (err) {
hideTyping();
addMessage('⚠️ Error de conexión. Intenta de nuevo.', false);
} finally {
loading = false;
sendBtn.disabled = false;
input.focus();
}
}
input.addEventListener('keypress', (e) => { if (e.key === 'Enter') sendMessage(); });
sendBtn.addEventListener('click', sendMessage);
input.focus();
</script>
</body>
</html>
"""
if __name__ == "__main__":
    # Listen on every interface; the port comes from the PORT environment
    # variable and defaults to 7860.
    listen_host = "0.0.0.0"
    port = int(os.getenv("PORT", "7860"))
    print(f"\n🚀 Servidor MTP en http://{listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port, log_level="warning")