Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,11 +4,12 @@ import torch
|
|
| 4 |
import json
|
| 5 |
import time
|
| 6 |
import gc
|
|
|
|
| 7 |
from fastapi import FastAPI, Request
|
| 8 |
-
from fastapi.responses import HTMLResponse
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
from pydantic import BaseModel, Field
|
| 11 |
-
from huggingface_hub import snapshot_download
|
| 12 |
import uvicorn
|
| 13 |
import math
|
| 14 |
import torch.nn as nn
|
|
@@ -33,7 +34,7 @@ torch.set_grad_enabled(False)
|
|
| 33 |
MODEL_REPO = "TeszenAI/MTP-3"
|
| 34 |
|
| 35 |
# ======================
|
| 36 |
-
#
|
| 37 |
# ======================
|
| 38 |
class LayerNorm(nn.Module):
|
| 39 |
def __init__(self, d_model: int, eps: float = 1e-5):
|
|
@@ -41,7 +42,6 @@ class LayerNorm(nn.Module):
|
|
| 41 |
self.weight = nn.Parameter(torch.ones(d_model))
|
| 42 |
self.bias = nn.Parameter(torch.zeros(d_model))
|
| 43 |
self.eps = eps
|
| 44 |
-
|
| 45 |
def forward(self, x):
|
| 46 |
mean = x.mean(-1, keepdim=True)
|
| 47 |
std = x.std(-1, keepdim=True)
|
|
@@ -60,7 +60,6 @@ class MultiHeadAttention(nn.Module):
|
|
| 60 |
self.w_o = nn.Linear(d_model, d_model)
|
| 61 |
self.dropout = nn.Dropout(dropout)
|
| 62 |
self.scale = math.sqrt(self.d_k)
|
| 63 |
-
|
| 64 |
def forward(self, x, mask=None):
|
| 65 |
batch_size, seq_len, _ = x.shape
|
| 66 |
Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
|
|
@@ -81,7 +80,6 @@ class FeedForward(nn.Module):
|
|
| 81 |
self.linear1 = nn.Linear(d_model, d_ff)
|
| 82 |
self.linear2 = nn.Linear(d_ff, d_model)
|
| 83 |
self.dropout = nn.Dropout(dropout)
|
| 84 |
-
|
| 85 |
def forward(self, x):
|
| 86 |
return self.linear2(self.dropout(F.gelu(self.linear1(x))))
|
| 87 |
|
|
@@ -94,7 +92,6 @@ class TransformerBlock(nn.Module):
|
|
| 94 |
self.norm2 = LayerNorm(d_model)
|
| 95 |
self.dropout1 = nn.Dropout(dropout)
|
| 96 |
self.dropout2 = nn.Dropout(dropout)
|
| 97 |
-
|
| 98 |
def forward(self, x, mask=None):
|
| 99 |
attn_output = self.attention(x, mask)
|
| 100 |
x = x + self.dropout1(attn_output)
|
|
@@ -113,22 +110,19 @@ class PositionalEncoding(nn.Module):
|
|
| 113 |
pe[:, 0::2] = torch.sin(position * div_term)
|
| 114 |
pe[:, 1::2] = torch.cos(position * div_term)
|
| 115 |
self.register_buffer('pe', pe.unsqueeze(0))
|
| 116 |
-
|
| 117 |
def forward(self, x):
|
| 118 |
return x + self.pe[:, :x.size(1), :]
|
| 119 |
|
| 120 |
class MTPModel(nn.Module):
|
| 121 |
-
def __init__(self, vocab_size: int, d_model: int =
|
| 122 |
-
n_layers: int =
|
| 123 |
super().__init__()
|
| 124 |
self.vocab_size = vocab_size
|
| 125 |
self.d_model = d_model
|
| 126 |
self.max_len = max_len
|
| 127 |
self.token_embedding = nn.Embedding(vocab_size, d_model)
|
| 128 |
self.pos_encoding = PositionalEncoding(d_model, max_len)
|
| 129 |
-
self.blocks = nn.ModuleList([
|
| 130 |
-
TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
|
| 131 |
-
])
|
| 132 |
self.norm = LayerNorm(d_model)
|
| 133 |
self.lm_head = nn.Linear(d_model, vocab_size)
|
| 134 |
|
|
@@ -140,82 +134,34 @@ class MTPModel(nn.Module):
|
|
| 140 |
for block in self.blocks:
|
| 141 |
x = block(x, mask)
|
| 142 |
x = self.norm(x)
|
| 143 |
-
|
| 144 |
-
return logits
|
| 145 |
-
|
| 146 |
-
def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
|
| 147 |
-
"""Método de generación compatible con la interfaz"""
|
| 148 |
-
generated = input_ids
|
| 149 |
-
|
| 150 |
-
for _ in range(max_new_tokens):
|
| 151 |
-
with torch.no_grad():
|
| 152 |
-
logits = self(generated)
|
| 153 |
-
next_logits = logits[0, -1, :] / temperature
|
| 154 |
-
|
| 155 |
-
if repetition_penalty != 1.0:
|
| 156 |
-
for token_id in set(generated[0].tolist()):
|
| 157 |
-
next_logits[token_id] /= repetition_penalty
|
| 158 |
-
|
| 159 |
-
if top_k > 0:
|
| 160 |
-
indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
|
| 161 |
-
next_logits[indices_to_remove] = float('-inf')
|
| 162 |
-
|
| 163 |
-
if top_p < 1.0:
|
| 164 |
-
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
|
| 165 |
-
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 166 |
-
sorted_indices_to_remove = cumulative_probs > top_p
|
| 167 |
-
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 168 |
-
sorted_indices_to_remove[..., 0] = 0
|
| 169 |
-
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
| 170 |
-
next_logits[indices_to_remove] = float('-inf')
|
| 171 |
-
|
| 172 |
-
probs = F.softmax(next_logits, dim=-1)
|
| 173 |
-
next_token = torch.multinomial(probs, num_samples=1).item()
|
| 174 |
-
|
| 175 |
-
if next_token == 3:
|
| 176 |
-
break
|
| 177 |
-
|
| 178 |
-
generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
|
| 179 |
-
|
| 180 |
-
return generated
|
| 181 |
|
| 182 |
# ======================
|
| 183 |
# DESCARGA Y CARGA DEL MODELO CON REINTENTOS
|
| 184 |
# ======================
|
| 185 |
def download_with_retry(repo_id, local_dir, max_retries=3):
|
| 186 |
-
"""Descarga el modelo con reintentos para evitar timeouts"""
|
| 187 |
-
|
| 188 |
for attempt in range(max_retries):
|
| 189 |
try:
|
| 190 |
print(f"📦 Intento {attempt + 1}/{max_retries} - Descargando modelo desde {repo_id}...")
|
| 191 |
-
|
| 192 |
-
# Configurar timeout más largo para descargas
|
| 193 |
repo_path = snapshot_download(
|
| 194 |
repo_id=repo_id,
|
| 195 |
repo_type="model",
|
| 196 |
local_dir=local_dir,
|
| 197 |
resume_download=True,
|
| 198 |
-
local_files_only=False
|
| 199 |
-
ignore_patterns=["*.h5", "*.ot", "*.msgpack"] # Ignorar archivos grandes innecesarios
|
| 200 |
)
|
| 201 |
-
|
| 202 |
print(f"✅ Modelo descargado exitosamente en: {repo_path}")
|
| 203 |
return repo_path
|
| 204 |
-
|
| 205 |
except Exception as e:
|
| 206 |
print(f"⚠️ Error en intento {attempt + 1}: {str(e)[:200]}")
|
| 207 |
if attempt < max_retries - 1:
|
| 208 |
-
|
| 209 |
-
print(f"🔄 Reintentando en {wait_time} segundos...")
|
| 210 |
-
time.sleep(wait_time)
|
| 211 |
else:
|
| 212 |
-
print("❌ No se pudo descargar el modelo después de múltiples intentos")
|
| 213 |
raise
|
|
|
|
| 214 |
|
| 215 |
-
# Intentar descargar el modelo
|
| 216 |
print(f"🚀 Iniciando carga del modelo desde {MODEL_REPO}...")
|
| 217 |
|
| 218 |
-
# Verificar si ya existe en caché local
|
| 219 |
if os.path.exists("mtp_repo") and os.path.exists("mtp_repo/mtp_model.pt"):
|
| 220 |
print("📁 Modelo encontrado en caché local")
|
| 221 |
repo_path = "mtp_repo"
|
|
@@ -223,10 +169,8 @@ else:
|
|
| 223 |
try:
|
| 224 |
repo_path = download_with_retry(MODEL_REPO, "mtp_repo", max_retries=3)
|
| 225 |
except Exception as e:
|
| 226 |
-
print(f"⚠️ Error
|
| 227 |
-
print("🏗️ Usando configuración por defecto...")
|
| 228 |
repo_path = "mtp_repo"
|
| 229 |
-
os.makedirs(repo_path, exist_ok=True)
|
| 230 |
|
| 231 |
# Cargar configuración
|
| 232 |
config_path = os.path.join(repo_path, "config.json")
|
|
@@ -234,35 +178,29 @@ if os.path.exists(config_path):
|
|
| 234 |
with open(config_path, "r") as f:
|
| 235 |
config = json.load(f)
|
| 236 |
else:
|
|
|
|
| 237 |
config = {
|
| 238 |
-
"vocab_size":
|
| 239 |
-
"d_model":
|
| 240 |
-
"n_heads":
|
| 241 |
-
"n_layers":
|
| 242 |
-
"d_ff":
|
| 243 |
"dropout": 0.1,
|
| 244 |
-
"max_len":
|
| 245 |
}
|
| 246 |
|
| 247 |
# Cargar tokenizador
|
| 248 |
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
|
| 249 |
if os.path.exists(tokenizer_path):
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
except Exception as e:
|
| 256 |
-
print(f"⚠️ Error cargando tokenizador: {e}")
|
| 257 |
-
VOCAB_SIZE = config.get("vocab_size", 5000)
|
| 258 |
-
sp = None
|
| 259 |
else:
|
| 260 |
-
print("
|
| 261 |
-
VOCAB_SIZE = config.get("vocab_size", 5000)
|
| 262 |
sp = None
|
| 263 |
-
|
| 264 |
-
# Actualizar vocab_size en config
|
| 265 |
-
config["vocab_size"] = VOCAB_SIZE
|
| 266 |
|
| 267 |
print(f"🧠 Inicializando modelo MTP...")
|
| 268 |
print(f" → Vocabulario: {VOCAB_SIZE}")
|
|
@@ -279,39 +217,21 @@ if os.path.exists(model_path):
|
|
| 279 |
try:
|
| 280 |
state_dict = torch.load(model_path, map_location=DEVICE)
|
| 281 |
model.load_state_dict(state_dict)
|
| 282 |
-
print("✅ Pesos del modelo cargados")
|
| 283 |
except Exception as e:
|
| 284 |
print(f"⚠️ Error cargando pesos: {e}")
|
| 285 |
-
print(" Usando pesos aleatorios")
|
| 286 |
else:
|
| 287 |
-
print("⚠️ No se encontró mtp_model.pt
|
| 288 |
|
| 289 |
model.eval()
|
| 290 |
|
| 291 |
-
# Cuantización para CPU
|
| 292 |
-
if DEVICE == "cpu":
|
| 293 |
-
print("⚡ Optimizando para CPU...")
|
| 294 |
-
try:
|
| 295 |
-
model = torch.quantization.quantize_dynamic(
|
| 296 |
-
model,
|
| 297 |
-
{nn.Linear},
|
| 298 |
-
dtype=torch.qint8
|
| 299 |
-
)
|
| 300 |
-
print("✅ Cuantización aplicada")
|
| 301 |
-
except Exception as e:
|
| 302 |
-
print(f"⚠️ No se pudo aplicar cuantización: {e}")
|
| 303 |
-
|
| 304 |
param_count = sum(p.numel() for p in model.parameters())
|
| 305 |
print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
|
| 306 |
|
| 307 |
# ======================
|
| 308 |
# API CONFIG
|
| 309 |
# ======================
|
| 310 |
-
app = FastAPI(
|
| 311 |
-
title="MTP-1.1 API",
|
| 312 |
-
description="API para modelo de lenguaje MTP-1.1",
|
| 313 |
-
version="1.1"
|
| 314 |
-
)
|
| 315 |
|
| 316 |
app.add_middleware(
|
| 317 |
CORSMiddleware,
|
|
@@ -321,58 +241,103 @@ app.add_middleware(
|
|
| 321 |
)
|
| 322 |
|
| 323 |
class PromptRequest(BaseModel):
|
| 324 |
-
text: str = Field(..., max_length=2000
|
| 325 |
-
max_tokens: int = Field(default=150, ge=10, le=300
|
| 326 |
-
temperature: float = Field(default=0.7, ge=0.1, le=2.0
|
| 327 |
-
top_k: int = Field(default=50, ge=1, le=100
|
| 328 |
-
top_p: float = Field(default=0.9, ge=0.1, le=1.0
|
| 329 |
-
repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
|
|
|
| 334 |
|
| 335 |
# ======================
|
| 336 |
-
#
|
| 337 |
# ======================
|
| 338 |
ACTIVE_REQUESTS = 0
|
| 339 |
|
| 340 |
-
class
|
| 341 |
-
"""Wrapper para el tokenizador de SentencePiece"""
|
| 342 |
def __init__(self, sp_model):
|
| 343 |
self.sp = sp_model
|
| 344 |
-
|
| 345 |
def encode(self, text):
|
| 346 |
if self.sp is None:
|
| 347 |
-
# Tokenizador simple para fallback
|
| 348 |
return [ord(c) % 1000 for c in text[:200]]
|
| 349 |
return self.sp.encode(text)
|
| 350 |
-
|
| 351 |
def decode(self, tokens):
|
| 352 |
if self.sp is None:
|
| 353 |
return ''.join([chr(t % 128) if 32 <= t % 128 < 127 else ' ' for t in tokens])
|
| 354 |
return self.sp.decode(tokens)
|
| 355 |
-
|
| 356 |
-
def bos_id(self):
|
| 357 |
-
if self.sp is None:
|
| 358 |
-
return 2
|
| 359 |
-
return self.sp.bos_id()
|
| 360 |
-
|
| 361 |
def eos_id(self):
|
| 362 |
-
if self.sp
|
| 363 |
-
|
| 364 |
-
return self.sp.
|
| 365 |
-
|
| 366 |
def pad_id(self):
|
| 367 |
-
if self.sp
|
| 368 |
-
return 0
|
| 369 |
-
return self.sp.pad_id()
|
| 370 |
|
| 371 |
-
tokenizer_wrapper =
|
| 372 |
|
| 373 |
@app.post("/generate")
|
| 374 |
async def generate(req: PromptRequest):
|
| 375 |
-
"""Endpoint principal de generación de texto"""
|
| 376 |
global ACTIVE_REQUESTS
|
| 377 |
ACTIVE_REQUESTS += 1
|
| 378 |
|
|
@@ -389,83 +354,51 @@ async def generate(req: PromptRequest):
|
|
| 389 |
ACTIVE_REQUESTS -= 1
|
| 390 |
return {"reply": "", "tokens_generated": 0}
|
| 391 |
|
| 392 |
-
full_prompt = build_prompt(user_input)
|
| 393 |
-
tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
|
| 394 |
-
input_ids = torch.tensor([tokens], device=DEVICE)
|
| 395 |
-
|
| 396 |
try:
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
)
|
| 406 |
-
|
| 407 |
-
gen_tokens = output_ids[0, len(tokens):].tolist()
|
| 408 |
-
|
| 409 |
-
safe_tokens = [
|
| 410 |
-
t for t in gen_tokens
|
| 411 |
-
if 0 <= t < VOCAB_SIZE and t != tokenizer_wrapper.eos_id()
|
| 412 |
-
]
|
| 413 |
-
|
| 414 |
-
response = tokenizer_wrapper.decode(safe_tokens).strip()
|
| 415 |
-
|
| 416 |
-
# Limpiar la respuesta
|
| 417 |
-
if "###" in response:
|
| 418 |
-
response = response.split("###")[0].strip()
|
| 419 |
-
|
| 420 |
-
# Si la respuesta está vacía, devolver mensaje por defecto
|
| 421 |
-
if not response or len(response) < 2:
|
| 422 |
-
response = "Entendido. ¿En qué más puedo ayudarte?"
|
| 423 |
-
|
| 424 |
return {
|
| 425 |
"reply": response,
|
| 426 |
-
"tokens_generated": len(
|
| 427 |
-
"model": "MTP
|
| 428 |
}
|
| 429 |
-
|
| 430 |
except Exception as e:
|
| 431 |
-
print(f"❌ Error
|
| 432 |
-
return {
|
| 433 |
-
"reply": "Lo siento, ocurrió un error al procesar tu solicitud.",
|
| 434 |
-
"error": str(e)
|
| 435 |
-
}
|
| 436 |
-
|
| 437 |
finally:
|
| 438 |
ACTIVE_REQUESTS -= 1
|
| 439 |
if DEVICE == "cuda":
|
| 440 |
torch.cuda.empty_cache()
|
| 441 |
gc.collect()
|
| 442 |
|
| 443 |
-
# ======================
|
| 444 |
-
# ENDPOINTS DE INFORMACIÓN
|
| 445 |
-
# ======================
|
| 446 |
@app.get("/health")
|
| 447 |
def health_check():
|
| 448 |
return {
|
| 449 |
"status": "healthy",
|
| 450 |
-
"model": "MTP
|
| 451 |
"device": DEVICE,
|
| 452 |
"active_requests": ACTIVE_REQUESTS,
|
| 453 |
-
"vocab_size": VOCAB_SIZE
|
| 454 |
-
"model_loaded": os.path.exists("mtp_repo/mtp_model.pt")
|
| 455 |
}
|
| 456 |
|
| 457 |
@app.get("/info")
|
| 458 |
def model_info():
|
| 459 |
return {
|
| 460 |
-
"model_name": "MTP
|
| 461 |
-
"version": "1.
|
| 462 |
"architecture": config,
|
| 463 |
"parameters": sum(p.numel() for p in model.parameters()),
|
| 464 |
"device": DEVICE
|
| 465 |
}
|
| 466 |
|
| 467 |
# ======================
|
| 468 |
-
# INTERFAZ WEB (
|
| 469 |
# ======================
|
| 470 |
@app.get("/", response_class=HTMLResponse)
|
| 471 |
def chat_ui():
|
|
@@ -475,7 +408,7 @@ def chat_ui():
|
|
| 475 |
<head>
|
| 476 |
<meta charset="UTF-8">
|
| 477 |
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
|
| 478 |
-
<title>MTP
|
| 479 |
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 480 |
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 481 |
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
|
|
@@ -795,7 +728,7 @@ async function sendMessage(textOverride = null) {
|
|
| 795 |
const response = await fetch('/generate', {
|
| 796 |
method: 'POST',
|
| 797 |
headers: { 'Content-Type': 'application/json' },
|
| 798 |
-
body: JSON.stringify({ text: text }),
|
| 799 |
signal: abortController.signal
|
| 800 |
});
|
| 801 |
const data = await response.json();
|
|
@@ -883,7 +816,7 @@ window.onload = () => userInput.focus();
|
|
| 883 |
|
| 884 |
if __name__ == "__main__":
|
| 885 |
port = int(os.environ.get("PORT", 7860))
|
| 886 |
-
print(f"\n🚀 Iniciando servidor MTP
|
| 887 |
print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
|
| 888 |
print(f"📡 API docs: http://0.0.0.0:{port}/docs")
|
| 889 |
|
|
|
|
| 4 |
import json
|
| 5 |
import time
|
| 6 |
import gc
|
| 7 |
+
import re
|
| 8 |
from fastapi import FastAPI, Request
|
| 9 |
+
from fastapi.responses import HTMLResponse
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
from pydantic import BaseModel, Field
|
| 12 |
+
from huggingface_hub import snapshot_download
|
| 13 |
import uvicorn
|
| 14 |
import math
|
| 15 |
import torch.nn as nn
|
|
|
|
| 34 |
MODEL_REPO = "TeszenAI/MTP-3"
|
| 35 |
|
| 36 |
# ======================
|
| 37 |
+
# ARQUITECTURA DEL MODELO (MISMA QUE EN colab.py)
|
| 38 |
# ======================
|
| 39 |
class LayerNorm(nn.Module):
|
| 40 |
def __init__(self, d_model: int, eps: float = 1e-5):
|
|
|
|
| 42 |
self.weight = nn.Parameter(torch.ones(d_model))
|
| 43 |
self.bias = nn.Parameter(torch.zeros(d_model))
|
| 44 |
self.eps = eps
|
|
|
|
| 45 |
def forward(self, x):
|
| 46 |
mean = x.mean(-1, keepdim=True)
|
| 47 |
std = x.std(-1, keepdim=True)
|
|
|
|
| 60 |
self.w_o = nn.Linear(d_model, d_model)
|
| 61 |
self.dropout = nn.Dropout(dropout)
|
| 62 |
self.scale = math.sqrt(self.d_k)
|
|
|
|
| 63 |
def forward(self, x, mask=None):
|
| 64 |
batch_size, seq_len, _ = x.shape
|
| 65 |
Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
|
|
|
|
| 80 |
self.linear1 = nn.Linear(d_model, d_ff)
|
| 81 |
self.linear2 = nn.Linear(d_ff, d_model)
|
| 82 |
self.dropout = nn.Dropout(dropout)
|
|
|
|
| 83 |
def forward(self, x):
|
| 84 |
return self.linear2(self.dropout(F.gelu(self.linear1(x))))
|
| 85 |
|
|
|
|
| 92 |
self.norm2 = LayerNorm(d_model)
|
| 93 |
self.dropout1 = nn.Dropout(dropout)
|
| 94 |
self.dropout2 = nn.Dropout(dropout)
|
|
|
|
| 95 |
def forward(self, x, mask=None):
|
| 96 |
attn_output = self.attention(x, mask)
|
| 97 |
x = x + self.dropout1(attn_output)
|
|
|
|
| 110 |
pe[:, 0::2] = torch.sin(position * div_term)
|
| 111 |
pe[:, 1::2] = torch.cos(position * div_term)
|
| 112 |
self.register_buffer('pe', pe.unsqueeze(0))
|
|
|
|
| 113 |
def forward(self, x):
|
| 114 |
return x + self.pe[:, :x.size(1), :]
|
| 115 |
|
| 116 |
class MTPModel(nn.Module):
|
| 117 |
+
def __init__(self, vocab_size: int, d_model: int = 256, n_heads: int = 8,
|
| 118 |
+
n_layers: int = 6, d_ff: int = 1024, dropout: float = 0.1, max_len: int = 512):
|
| 119 |
super().__init__()
|
| 120 |
self.vocab_size = vocab_size
|
| 121 |
self.d_model = d_model
|
| 122 |
self.max_len = max_len
|
| 123 |
self.token_embedding = nn.Embedding(vocab_size, d_model)
|
| 124 |
self.pos_encoding = PositionalEncoding(d_model, max_len)
|
| 125 |
+
self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
|
|
|
|
|
|
|
| 126 |
self.norm = LayerNorm(d_model)
|
| 127 |
self.lm_head = nn.Linear(d_model, vocab_size)
|
| 128 |
|
|
|
|
| 134 |
for block in self.blocks:
|
| 135 |
x = block(x, mask)
|
| 136 |
x = self.norm(x)
|
| 137 |
+
return self.lm_head(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
# ======================
|
| 140 |
# DESCARGA Y CARGA DEL MODELO CON REINTENTOS
|
| 141 |
# ======================
|
| 142 |
def download_with_retry(repo_id, local_dir, max_retries=3):
    """Download a model snapshot from the Hugging Face Hub, retrying on failure.

    Args:
        repo_id: Hub repository id forwarded to ``snapshot_download``.
        local_dir: Directory the snapshot is written into.
        max_retries: Maximum number of download attempts.

    Returns:
        The path returned by ``snapshot_download`` on the first successful
        attempt. When ``max_retries`` is less than 1 no attempt is made and
        ``local_dir`` is returned unchanged.

    Raises:
        Exception: Re-raises whatever the final attempt raised once every
            attempt has failed.
    """
    for attempt in range(max_retries):
        try:
            print(f"📦 Intento {attempt + 1}/{max_retries} - Descargando modelo desde {repo_id}...")
            # NOTE(review): `resume_download` is reported as deprecated in
            # recent huggingface_hub releases — confirm against the installed
            # version before removing it.
            repo_path = snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                local_dir=local_dir,
                resume_download=True,
                local_files_only=False,
            )
        except Exception as e:
            # Keep the log line short: Hub errors can embed long payloads.
            print(f"⚠️ Error en intento {attempt + 1}: {str(e)[:200]}")
            if attempt >= max_retries - 1:
                # Out of attempts: let the caller decide how to fall back.
                raise
            time.sleep(3)
        else:
            print(f"✅ Modelo descargado exitosamente en: {repo_path}")
            return repo_path
    # Only reachable when the loop body never ran (max_retries < 1).
    return local_dir
|
| 162 |
|
|
|
|
| 163 |
print(f"🚀 Iniciando carga del modelo desde {MODEL_REPO}...")
|
| 164 |
|
|
|
|
| 165 |
if os.path.exists("mtp_repo") and os.path.exists("mtp_repo/mtp_model.pt"):
|
| 166 |
print("📁 Modelo encontrado en caché local")
|
| 167 |
repo_path = "mtp_repo"
|
|
|
|
| 169 |
try:
|
| 170 |
repo_path = download_with_retry(MODEL_REPO, "mtp_repo", max_retries=3)
|
| 171 |
except Exception as e:
|
| 172 |
+
print(f"⚠️ Error: {e}")
|
|
|
|
| 173 |
repo_path = "mtp_repo"
|
|
|
|
| 174 |
|
| 175 |
# Cargar configuración
|
| 176 |
config_path = os.path.join(repo_path, "config.json")
|
|
|
|
| 178 |
with open(config_path, "r") as f:
|
| 179 |
config = json.load(f)
|
| 180 |
else:
|
| 181 |
+
# Configuración por defecto (MISMA que en colab.py)
|
| 182 |
config = {
|
| 183 |
+
"vocab_size": 2000,
|
| 184 |
+
"d_model": 256,
|
| 185 |
+
"n_heads": 8,
|
| 186 |
+
"n_layers": 6,
|
| 187 |
+
"d_ff": 1024,
|
| 188 |
"dropout": 0.1,
|
| 189 |
+
"max_len": 512
|
| 190 |
}
|
| 191 |
|
| 192 |
# Cargar tokenizador
|
| 193 |
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
|
| 194 |
if os.path.exists(tokenizer_path):
|
| 195 |
+
sp = spm.SentencePieceProcessor()
|
| 196 |
+
sp.load(tokenizer_path)
|
| 197 |
+
VOCAB_SIZE = sp.get_piece_size()
|
| 198 |
+
config["vocab_size"] = VOCAB_SIZE
|
| 199 |
+
print(f"✅ Tokenizador cargado: {VOCAB_SIZE} tokens")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
else:
|
| 201 |
+
print("❌ No se encontró tokenizador")
|
|
|
|
| 202 |
sp = None
|
| 203 |
+
VOCAB_SIZE = config.get("vocab_size", 2000)
|
|
|
|
|
|
|
| 204 |
|
| 205 |
print(f"🧠 Inicializando modelo MTP...")
|
| 206 |
print(f" → Vocabulario: {VOCAB_SIZE}")
|
|
|
|
| 217 |
try:
|
| 218 |
state_dict = torch.load(model_path, map_location=DEVICE)
|
| 219 |
model.load_state_dict(state_dict)
|
| 220 |
+
print("✅ Pesos del modelo cargados correctamente")
|
| 221 |
except Exception as e:
|
| 222 |
print(f"⚠️ Error cargando pesos: {e}")
|
|
|
|
| 223 |
else:
|
| 224 |
+
print("⚠️ No se encontró mtp_model.pt")
|
| 225 |
|
| 226 |
model.eval()
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
param_count = sum(p.numel() for p in model.parameters())
|
| 229 |
print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
|
| 230 |
|
| 231 |
# ======================
|
| 232 |
# API CONFIG
|
| 233 |
# ======================
|
| 234 |
+
app = FastAPI(title="MTP API", description="API para modelo de lenguaje MTP", version="1.0")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
app.add_middleware(
|
| 237 |
CORSMiddleware,
|
|
|
|
| 241 |
)
|
| 242 |
|
| 243 |
class PromptRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    Every numeric field is range-checked by its ``Field`` constraints, so
    out-of-range values are rejected during request validation.
    """

    # User prompt; required, at most 2000 characters.
    text: str = Field(..., max_length=2000)
    # Upper bound on the number of tokens to generate (10-300).
    max_tokens: int = Field(default=150, ge=10, le=300)
    # Sampling temperature (0.1-2.0).
    temperature: float = Field(default=0.7, ge=0.1, le=2.0)
    # Top-k cutoff (1-100).
    top_k: int = Field(default=50, ge=1, le=100)
    # Nucleus (top-p) cutoff (0.1-1.0).
    top_p: float = Field(default=0.9, ge=0.1, le=1.0)
    # NOTE(review): accepted and validated here, but the visible call into
    # generate_response does not forward it — confirm whether it should.
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0)
|
| 250 |
+
|
| 251 |
+
# ======================
|
| 252 |
+
# FUNCIÓN DE GENERACIÓN (IGUAL QUE EN colab.py)
|
| 253 |
+
# ======================
|
| 254 |
+
# Markers that may precede the answer if the model echoes the prompt template;
# checked in this order, first match wins.
_RESPONSE_MARKERS = ("### Respuesta:", "Respuesta:", "[/INST]")
# Known junk fragments removed verbatim from the decoded text.
_GARBAGE_WORDS = ('foompañances', 'ciudadores', 'mejtedon', 'calportedon', 'rápidodcor', 'baon', 'domol')
# Anything outside word characters, whitespace, Latin-1/Latin Extended-A
# letters and basic punctuation is replaced by a space.
_DISALLOWED_CHARS = re.compile(r'[^\w\s\u00C0-\u00FF\u0100-\u017F.,!?¿¡()\-:;"]+')
_WHITESPACE_RUN = re.compile(r'\s+')


def generate_response(model, tokenizer, prompt, max_length=150, temperature=0.7, top_k=50, top_p=0.9, device='cpu'):
    """Sample a reply to ``prompt`` from ``model`` and return it as clean text.

    Args:
        model: Callable returning logits indexed as ``[batch, position, vocab]``;
            must expose ``eval()`` and ``max_len``.
        tokenizer: Object exposing ``encode(str)``, ``decode(list)`` and
            ``eos_id()``.
        prompt: Raw user text; wrapped in the instruction/response template.
        max_length: Maximum number of new tokens to sample.
        temperature: Divisor applied to the logits before sampling.
        top_k: Keep only the ``top_k`` highest logits (``0`` disables). Values
            larger than the vocabulary are clamped instead of raising.
        top_p: Nucleus-sampling cumulative-probability cutoff (``1.0`` disables).
        device: Device the input tensor is moved to.

    Returns:
        The cleaned reply, or a fixed default sentence when fewer than two
        characters remain after cleaning.
    """
    model.eval()
    formatted_prompt = f"### Instrucción:\n{prompt}\n\n### Respuesta:\n"
    generated = list(tokenizer.encode(formatted_prompt))
    # Remember where the prompt ends so only newly sampled tokens are decoded.
    prompt_len = len(generated)
    eos_id = tokenizer.eos_id()

    with torch.no_grad():
        for _ in range(max_length):
            # Feed only the most recent max_len tokens to the model.
            window = generated[-model.max_len:]
            input_tensor = torch.tensor([window], dtype=torch.long).to(device)
            logits = model(input_tensor)
            next_logits = logits[0, -1, :] / temperature

            if top_k > 0:
                # torch.topk raises if k exceeds the vocabulary size.
                k = min(top_k, next_logits.size(-1))
                kth_value = torch.topk(next_logits, k)[0][..., -1, None]
                next_logits[next_logits < kth_value] = float('-inf')

            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is
                # kept, and never drop the single most likely token.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                next_logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')

            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).item()

            if next_token == eos_id:
                break

            # Stop early when the last ten tokens are all identical.
            if len(generated) > 20 and len(set(generated[-10:])) == 1:
                break

            generated.append(next_token)

    # Decode only the sampled continuation: the prompt cannot leak into the
    # reply even if the template marker does not survive the tokenizer.
    response = tokenizer.decode(generated[prompt_len:])
    for marker in _RESPONSE_MARKERS:
        if marker in response:
            response = response.split(marker)[-1].strip()
            break

    for word in _GARBAGE_WORDS:
        response = response.replace(word, '')

    response = _DISALLOWED_CHARS.sub(' ', response)
    response = _WHITESPACE_RUN.sub(' ', response).strip()

    if len(response) < 2:
        response = "Entendido. ¿Algo más en lo que pueda ayudarte?"

    return response
|
| 313 |
|
| 314 |
# ======================
|
| 315 |
+
# ENDPOINTS
|
| 316 |
# ======================
|
| 317 |
ACTIVE_REQUESTS = 0
|
| 318 |
|
| 319 |
+
class TokenizerWrapper:
    """Uniform tokenizer interface with a character-level fallback.

    Wraps an object exposing ``encode``, ``decode``, ``eos_id``, ``bos_id``
    and ``pad_id``. When constructed with ``None`` every method falls back to
    a crude character-based scheme so callers never have to branch on
    whether a tokenizer model was loaded.
    """

    # Special-token ids reported when no tokenizer model is available.
    _FALLBACK_PAD_ID = 0
    _FALLBACK_BOS_ID = 2
    _FALLBACK_EOS_ID = 3

    def __init__(self, sp_model):
        """Store ``sp_model``; ``None`` selects the fallback behaviour."""
        self.sp = sp_model

    def encode(self, text):
        """Return token ids for ``text``.

        The fallback maps each of the first 200 characters to
        ``ord(c) % 1000``; longer input is silently truncated.
        """
        if self.sp is None:
            return [ord(c) % 1000 for c in text[:200]]
        return self.sp.encode(text)

    def decode(self, tokens):
        """Return text for ``tokens``.

        The fallback maps each id to ``chr(id % 128)`` when that is a
        printable ASCII character and to a space otherwise, so it is not an
        exact inverse of the fallback ``encode`` for non-ASCII input.
        """
        if self.sp is None:
            return ''.join([chr(t % 128) if 32 <= t % 128 < 127 else ' ' for t in tokens])
        return self.sp.decode(tokens)

    # The id accessors test `is not None`, matching encode/decode: a wrapped
    # object that evaluates falsy (e.g. one defining __len__) must still be
    # consulted rather than replaced by the fallback ids.
    def eos_id(self):
        """Return the end-of-sequence id (3 in fallback mode)."""
        return self.sp.eos_id() if self.sp is not None else self._FALLBACK_EOS_ID

    def bos_id(self):
        """Return the beginning-of-sequence id (2 in fallback mode)."""
        return self.sp.bos_id() if self.sp is not None else self._FALLBACK_BOS_ID

    def pad_id(self):
        """Return the padding id (0 in fallback mode)."""
        return self.sp.pad_id() if self.sp is not None else self._FALLBACK_PAD_ID
|
|
|
|
|
|
|
| 336 |
|
| 337 |
+
tokenizer_wrapper = TokenizerWrapper(sp)
|
| 338 |
|
| 339 |
@app.post("/generate")
|
| 340 |
async def generate(req: PromptRequest):
|
|
|
|
| 341 |
global ACTIVE_REQUESTS
|
| 342 |
ACTIVE_REQUESTS += 1
|
| 343 |
|
|
|
|
| 354 |
ACTIVE_REQUESTS -= 1
|
| 355 |
return {"reply": "", "tokens_generated": 0}
|
| 356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
try:
|
| 358 |
+
response = generate_response(
|
| 359 |
+
model, tokenizer_wrapper, user_input,
|
| 360 |
+
max_length=dyn_max_tokens,
|
| 361 |
+
temperature=dyn_temperature,
|
| 362 |
+
top_k=req.top_k,
|
| 363 |
+
top_p=req.top_p,
|
| 364 |
+
device=DEVICE
|
| 365 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
return {
|
| 367 |
"reply": response,
|
| 368 |
+
"tokens_generated": len(response.split()),
|
| 369 |
+
"model": "MTP"
|
| 370 |
}
|
|
|
|
| 371 |
except Exception as e:
|
| 372 |
+
print(f"❌ Error: {e}")
|
| 373 |
+
return {"reply": "Lo siento, ocurrió un error.", "error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
finally:
|
| 375 |
ACTIVE_REQUESTS -= 1
|
| 376 |
if DEVICE == "cuda":
|
| 377 |
torch.cuda.empty_cache()
|
| 378 |
gc.collect()
|
| 379 |
|
|
|
|
|
|
|
|
|
|
| 380 |
@app.get("/health")
def health_check():
    """Return a liveness payload for the service.

    Returns:
        A dict with a fixed ``"healthy"`` status, the model label, the
        current ``DEVICE``, the ``ACTIVE_REQUESTS`` counter and ``VOCAB_SIZE``.
    """
    return {
        "status": "healthy",
        "model": "MTP",
        "device": DEVICE,
        "active_requests": ACTIVE_REQUESTS,
        "vocab_size": VOCAB_SIZE,
    }
|
| 389 |
|
| 390 |
@app.get("/info")
def model_info():
    """Return static metadata describing the loaded model.

    Returns:
        A dict with the model name and version labels, the ``config`` dict
        used to build the model, the parameter count, and ``DEVICE``.
    """
    # Counted on each request by summing the element count of every parameter.
    parameter_total = sum(p.numel() for p in model.parameters())
    return {
        "model_name": "MTP",
        "version": "1.0",
        "architecture": config,
        "parameters": parameter_total,
        "device": DEVICE,
    }
|
| 399 |
|
| 400 |
# ======================
|
| 401 |
+
# INTERFAZ WEB COMPLETA (CON TODAS LAS FUNCIONES ORIGINALES)
|
| 402 |
# ======================
|
| 403 |
@app.get("/", response_class=HTMLResponse)
|
| 404 |
def chat_ui():
|
|
|
|
| 408 |
<head>
|
| 409 |
<meta charset="UTF-8">
|
| 410 |
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
|
| 411 |
+
<title>MTP 3</title>
|
| 412 |
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 413 |
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 414 |
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
|
|
|
|
| 728 |
const response = await fetch('/generate', {
|
| 729 |
method: 'POST',
|
| 730 |
headers: { 'Content-Type': 'application/json' },
|
| 731 |
+
body: JSON.stringify({ text: text, max_tokens: 150, temperature: 0.7 }),
|
| 732 |
signal: abortController.signal
|
| 733 |
});
|
| 734 |
const data = await response.json();
|
|
|
|
| 816 |
|
| 817 |
if __name__ == "__main__":
|
| 818 |
port = int(os.environ.get("PORT", 7860))
|
| 819 |
+
print(f"\n🚀 Iniciando servidor MTP en puerto {port}...")
|
| 820 |
print(f"🌐 Interfaz web: http://0.0.0.0:{port}")
|
| 821 |
print(f"📡 API docs: http://0.0.0.0:{port}/docs")
|
| 822 |
|