Update app.py
Browse files
app.py
CHANGED
|
@@ -3,7 +3,7 @@ import sys
|
|
| 3 |
import torch
|
| 4 |
import json
|
| 5 |
import gc
|
| 6 |
-
import
|
| 7 |
from fastapi import FastAPI
|
| 8 |
from fastapi.responses import HTMLResponse
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -17,11 +17,11 @@ import sentencepiece as spm
|
|
| 17 |
|
| 18 |
if torch.cuda.is_available():
|
| 19 |
DEVICE = "cuda"
|
| 20 |
-
print("✅ GPU")
|
| 21 |
torch.backends.cudnn.benchmark = True
|
| 22 |
else:
|
| 23 |
DEVICE = "cpu"
|
| 24 |
-
print("⚠️ CPU")
|
| 25 |
torch.set_num_threads(4)
|
| 26 |
|
| 27 |
torch.set_grad_enabled(False)
|
|
@@ -114,7 +114,9 @@ class MTPModel(nn.Module):
|
|
| 114 |
self.max_len = max_len
|
| 115 |
self.token_embedding = nn.Embedding(vocab_size, d_model)
|
| 116 |
self.pos_encoding = PositionalEncoding(d_model, max_len)
|
| 117 |
-
self.blocks = nn.ModuleList([
|
|
|
|
|
|
|
| 118 |
self.norm = LayerNorm(d_model)
|
| 119 |
self.lm_head = nn.Linear(d_model, vocab_size)
|
| 120 |
|
|
@@ -129,7 +131,7 @@ class MTPModel(nn.Module):
|
|
| 129 |
return self.lm_head(x)
|
| 130 |
|
| 131 |
@torch.inference_mode()
|
| 132 |
-
def generate(self, input_ids, max_new_tokens=
|
| 133 |
generated = input_ids
|
| 134 |
for _ in range(max_new_tokens):
|
| 135 |
logits = self(generated)
|
|
@@ -143,7 +145,7 @@ class MTPModel(nn.Module):
|
|
| 143 |
if next_token == 3 or next_token == 0 or next_token == 1:
|
| 144 |
break
|
| 145 |
generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
|
| 146 |
-
if len(generated[0]) >
|
| 147 |
break
|
| 148 |
return generated
|
| 149 |
|
|
@@ -151,25 +153,52 @@ print("📦 Descargando modelo...")
|
|
| 151 |
repo_path = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir="mtp_repo")
|
| 152 |
|
| 153 |
config_path = os.path.join(repo_path, "config.json")
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
sp = spm.SentencePieceProcessor()
|
| 159 |
sp.load(tokenizer_path)
|
| 160 |
-
|
|
|
|
| 161 |
|
| 162 |
-
print(f"
|
| 163 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
model = MTPModel(**config)
|
| 166 |
model.to(DEVICE)
|
| 167 |
|
| 168 |
model_path = os.path.join(repo_path, "mtp_model.pt")
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
model.eval()
|
| 172 |
-
print(f"✅ Modelo: {sum(p.numel() for p in model.parameters()):,} params")
|
| 173 |
|
| 174 |
app = FastAPI()
|
| 175 |
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
|
@@ -180,24 +209,29 @@ class PromptRequest(BaseModel):
|
|
| 180 |
def build_prompt(user_input):
|
| 181 |
return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
@app.post("/generate")
|
| 184 |
async def generate(req: PromptRequest):
|
| 185 |
-
start = time.time()
|
| 186 |
user_input = req.text.strip()
|
| 187 |
-
|
| 188 |
if not user_input:
|
| 189 |
-
return {"reply": "
|
| 190 |
|
| 191 |
prompt = build_prompt(user_input)
|
| 192 |
tokens = sp.encode(prompt)
|
| 193 |
|
| 194 |
-
if len(tokens) >
|
| 195 |
-
tokens = tokens[-
|
| 196 |
|
| 197 |
input_ids = torch.tensor([tokens], device=DEVICE)
|
| 198 |
|
| 199 |
try:
|
| 200 |
-
output_ids = model.generate(input_ids, max_new_tokens=
|
| 201 |
|
| 202 |
gen_tokens = output_ids[0, len(tokens):].tolist()
|
| 203 |
|
|
@@ -208,26 +242,15 @@ async def generate(req: PromptRequest):
|
|
| 208 |
clean_tokens.append(t)
|
| 209 |
|
| 210 |
response = sp.decode(clean_tokens).strip() if clean_tokens else ""
|
|
|
|
| 211 |
|
| 212 |
-
|
| 213 |
-
for marker in markers:
|
| 214 |
-
if marker in response:
|
| 215 |
-
response = response.split(marker)[-1]
|
| 216 |
-
|
| 217 |
-
response = response.replace('<unk>', '').replace('<pad>', '').replace('<s>', '').replace('</s>', '')
|
| 218 |
-
response = ' '.join(response.split())
|
| 219 |
-
|
| 220 |
-
if not response or len(response) < 2:
|
| 221 |
-
response = "Entiendo tu pregunta. ¿Podrías darme más detalles?"
|
| 222 |
|
| 223 |
-
|
| 224 |
-
print(f"✅ {user_input[:25]}... -> {elapsed:.1f}s ({len(clean_tokens)} tokens)")
|
| 225 |
-
|
| 226 |
-
return {"reply": response[:350], "time": elapsed}
|
| 227 |
|
| 228 |
except Exception as e:
|
| 229 |
print(f"❌ Error: {e}")
|
| 230 |
-
return {"reply": "
|
| 231 |
|
| 232 |
@app.get("/health")
|
| 233 |
def health():
|
|
@@ -360,11 +383,11 @@ body {
|
|
| 360 |
<body>
|
| 361 |
<div class="header">
|
| 362 |
<h1><span class="dot"></span> MTP Assistant</h1>
|
| 363 |
-
<p>Modelo Transformer
|
| 364 |
</div>
|
| 365 |
<div class="chat" id="chat">
|
| 366 |
<div class="message bot">
|
| 367 |
-
<div class="message-content">
|
| 368 |
</div>
|
| 369 |
</div>
|
| 370 |
<div class="input-area">
|
|
@@ -415,8 +438,6 @@ async function send() {
|
|
| 415 |
sendBtn.disabled = true;
|
| 416 |
addTyping();
|
| 417 |
|
| 418 |
-
const startTime = Date.now();
|
| 419 |
-
|
| 420 |
try {
|
| 421 |
const res = await fetch('/generate', {
|
| 422 |
method: 'POST',
|
|
@@ -424,13 +445,11 @@ async function send() {
|
|
| 424 |
body: JSON.stringify({ text: text })
|
| 425 |
});
|
| 426 |
const data = await res.json();
|
| 427 |
-
const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
|
| 428 |
removeTyping();
|
| 429 |
addMessage(data.reply || "No pude generar respuesta.", false);
|
| 430 |
-
console.log(`Respuesta en ${elapsed}s`);
|
| 431 |
} catch (err) {
|
| 432 |
removeTyping();
|
| 433 |
-
addMessage("Error de conexión.
|
| 434 |
} finally {
|
| 435 |
loading = false;
|
| 436 |
sendBtn.disabled = false;
|
|
|
|
| 3 |
import torch
|
| 4 |
import json
|
| 5 |
import gc
|
| 6 |
+
import re
|
| 7 |
from fastapi import FastAPI
|
| 8 |
from fastapi.responses import HTMLResponse
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 17 |
|
| 18 |
# Pick the compute device once at import time and apply per-device torch settings.
_cuda_ready = torch.cuda.is_available()
DEVICE = "cuda" if _cuda_ready else "cpu"

if not _cuda_ready:
    print("⚠️ CPU mode")
    # Fixed intra-op thread count for CPU execution.
    torch.set_num_threads(4)
else:
    print("✅ GPU detectada")
    # Enable cuDNN's benchmark mode for the CUDA path.
    torch.backends.cudnn.benchmark = True

# Turn off autograd gradient tracking process-wide.
torch.set_grad_enabled(False)
|
|
|
|
| 114 |
self.max_len = max_len
|
| 115 |
self.token_embedding = nn.Embedding(vocab_size, d_model)
|
| 116 |
self.pos_encoding = PositionalEncoding(d_model, max_len)
|
| 117 |
+
self.blocks = nn.ModuleList([
|
| 118 |
+
TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
|
| 119 |
+
])
|
| 120 |
self.norm = LayerNorm(d_model)
|
| 121 |
self.lm_head = nn.Linear(d_model, vocab_size)
|
| 122 |
|
|
|
|
| 131 |
return self.lm_head(x)
|
| 132 |
|
| 133 |
@torch.inference_mode()
|
| 134 |
+
def generate(self, input_ids, max_new_tokens=150, temperature=0.7, top_k=50):
|
| 135 |
generated = input_ids
|
| 136 |
for _ in range(max_new_tokens):
|
| 137 |
logits = self(generated)
|
|
|
|
| 145 |
if next_token == 3 or next_token == 0 or next_token == 1:
|
| 146 |
break
|
| 147 |
generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
|
| 148 |
+
if len(generated[0]) > 200:
|
| 149 |
break
|
| 150 |
return generated
|
| 151 |
|
|
|
|
| 153 |
repo_path = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir="mtp_repo")
|
| 154 |
|
| 155 |
# Hyperparameters used when config.json is absent, and as fallbacks for any
# key that a partial config.json leaves out.
_DEFAULT_CONFIG = {
    "vocab_size": 8000,
    "d_model": 512,
    "n_heads": 8,
    "n_layers": 8,
    "d_ff": 2048,
    "dropout": 0.1,
    "max_len": 1024,
}

config_path = os.path.join(repo_path, "config.json")
if os.path.exists(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        # Overlay the file's values on the defaults so that later direct
        # lookups such as config['d_model'] cannot raise KeyError when the
        # file omits a key (the log line below previously masked that case
        # by using .get() with the same fallback values).
        config = {**_DEFAULT_CONFIG, **json.load(f)}
    print(f"✅ Configuración cargada: d_model={config['d_model']}, layers={config['n_layers']}")
else:
    print("⚠️ Usando configuración por defecto (igual que colab.py)")
    # Copy so later mutation of `config` never alters the defaults table.
    config = dict(_DEFAULT_CONFIG)
|
| 171 |
|
| 172 |
# Load the SentencePiece tokenizer shipped in the downloaded repo; without it
# the service cannot encode prompts, so exit with status 1 when it is missing.
tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
if os.path.exists(tokenizer_path):
    sp = spm.SentencePieceProcessor()
    sp.load(tokenizer_path)
else:
    print("❌ Tokenizador no encontrado")
    sys.exit(1)

# The tokenizer's piece count replaces whatever vocab_size the config carried.
VOCAB_SIZE = sp.get_piece_size()
config.update(vocab_size=VOCAB_SIZE)
|
| 181 |
|
| 182 |
+
# Log the architecture that is about to be instantiated.
print("🧠 Inicializando modelo MTP...")
print(f" → Vocabulario: {VOCAB_SIZE}")
print(f" → Dimensión: {config['d_model']}")
print(f" → Capas: {config['n_layers']}")
print(f" → Heads: {config['n_heads']}")

model = MTPModel(**config)
model.to(DEVICE)

model_path = os.path.join(repo_path, "mtp_model.pt")
if not os.path.exists(model_path):
    print("❌ Modelo no encontrado")
    sys.exit(1)

# NOTE(review): torch.load unpickles the checkpoint, and this file comes from a
# downloaded repository. If the checkpoint is a plain state dict, consider
# passing weights_only=True (available in recent PyTorch releases) — confirm
# against the PyTorch version this app is deployed with.
state_dict = torch.load(model_path, map_location=DEVICE)

# strict=False tolerates key mismatches; surface them instead of discarding the
# report, so a checkpoint/architecture mismatch is not logged as a clean load.
load_report = model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys:
    print(f"⚠️ Claves ausentes en el checkpoint: {load_report.missing_keys}")
if load_report.unexpected_keys:
    print(f"⚠️ Claves inesperadas en el checkpoint: {load_report.unexpected_keys}")
if not (load_report.missing_keys or load_report.unexpected_keys):
    print("✅ Pesos cargados correctamente")

model.eval()
print(f"✅ Modelo listo: {sum(p.numel() for p in model.parameters()):,} params")
|
| 202 |
|
| 203 |
app = FastAPI()
|
| 204 |
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
|
|
|
| 209 |
def build_prompt(user_input):
    """Wrap *user_input* in the instruction/response prompt template.

    The text is placed under an "### Instrucción:" heading and followed by an
    empty "### Respuesta:" heading.
    """
    instruction_header = "### Instrucción:\n"
    response_header = "\n\n### Respuesta:\n"
    return f"{instruction_header}{user_input}{response_header}"
|
| 211 |
|
| 212 |
+
def clean_response(text):
    """Remove special-token markers from *text* and normalize its whitespace.

    Deletes every literal ``<unk>``, ``<pad>``, ``<s>`` and ``</s>``, collapses
    each whitespace run to a single space, and trims both ends. Falsy input
    (``None`` or ``""``) yields ``""``.
    """
    if text:
        without_markers = re.sub(r'<unk>|<pad>|<s>|</s>', '', text)
        return re.sub(r'\s+', ' ', without_markers).strip()
    return ""
|
| 218 |
+
|
| 219 |
@app.post("/generate")
|
| 220 |
async def generate(req: PromptRequest):
|
|
|
|
| 221 |
user_input = req.text.strip()
|
|
|
|
| 222 |
if not user_input:
|
| 223 |
+
return {"reply": ""}
|
| 224 |
|
| 225 |
prompt = build_prompt(user_input)
|
| 226 |
tokens = sp.encode(prompt)
|
| 227 |
|
| 228 |
+
if len(tokens) > 900:
|
| 229 |
+
tokens = tokens[-900:]
|
| 230 |
|
| 231 |
input_ids = torch.tensor([tokens], device=DEVICE)
|
| 232 |
|
| 233 |
try:
|
| 234 |
+
output_ids = model.generate(input_ids, max_new_tokens=120, temperature=0.7, top_k=50)
|
| 235 |
|
| 236 |
gen_tokens = output_ids[0, len(tokens):].tolist()
|
| 237 |
|
|
|
|
| 242 |
clean_tokens.append(t)
|
| 243 |
|
| 244 |
response = sp.decode(clean_tokens).strip() if clean_tokens else ""
|
| 245 |
+
response = clean_response(response)
|
| 246 |
|
| 247 |
+
print(f"📝 {user_input[:40]} -> {len(clean_tokens)} tokens")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
+
return {"reply": response[:500]}
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
except Exception as e:
|
| 252 |
print(f"❌ Error: {e}")
|
| 253 |
+
return {"reply": ""}
|
| 254 |
|
| 255 |
@app.get("/health")
|
| 256 |
def health():
|
|
|
|
| 383 |
<body>
|
| 384 |
<div class="header">
|
| 385 |
<h1><span class="dot"></span> MTP Assistant</h1>
|
| 386 |
+
<p>Modelo Transformer 512-dim / 8-capas</p>
|
| 387 |
</div>
|
| 388 |
<div class="chat" id="chat">
|
| 389 |
<div class="message bot">
|
| 390 |
+
<div class="message-content">Hola, soy MTP. ¿En qué puedo ayudarte?</div>
|
| 391 |
</div>
|
| 392 |
</div>
|
| 393 |
<div class="input-area">
|
|
|
|
| 438 |
sendBtn.disabled = true;
|
| 439 |
addTyping();
|
| 440 |
|
|
|
|
|
|
|
| 441 |
try {
|
| 442 |
const res = await fetch('/generate', {
|
| 443 |
method: 'POST',
|
|
|
|
| 445 |
body: JSON.stringify({ text: text })
|
| 446 |
});
|
| 447 |
const data = await res.json();
|
|
|
|
| 448 |
removeTyping();
|
| 449 |
addMessage(data.reply || "No pude generar respuesta.", false);
|
|
|
|
| 450 |
} catch (err) {
|
| 451 |
removeTyping();
|
| 452 |
+
addMessage("Error de conexión.", false);
|
| 453 |
} finally {
|
| 454 |
loading = false;
|
| 455 |
sendBtn.disabled = false;
|