Update app.py
Browse files
app.py
CHANGED
|
@@ -9,10 +9,17 @@ from huggingface_hub import snapshot_download
|
|
| 9 |
import uvicorn
|
| 10 |
|
| 11 |
# ======================
|
| 12 |
-
# CONFIGURACIÓN
|
| 13 |
# ======================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
MODEL_REPO = "teszenofficial/mtptz"
|
| 15 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
|
| 17 |
# ======================
|
| 18 |
# DESCARGA DEL MODELO
|
|
@@ -57,13 +64,14 @@ model = MTPMiniModel(
|
|
| 57 |
dropout=0.0
|
| 58 |
)
|
| 59 |
|
|
|
|
| 60 |
model.load_state_dict(model_data["model_state_dict"])
|
| 61 |
model.to(DEVICE)
|
| 62 |
model.eval()
|
| 63 |
-
print(f"MTP 1.1 listo en {DEVICE}")
|
| 64 |
|
| 65 |
# ======================
|
| 66 |
-
# API
|
| 67 |
# ======================
|
| 68 |
app = FastAPI(title="MTP 1.1 API")
|
| 69 |
|
|
@@ -78,6 +86,8 @@ def generate(prompt: Prompt):
|
|
| 78 |
|
| 79 |
full_prompt = f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
|
| 80 |
tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
|
|
|
|
|
|
|
| 81 |
input_ids = torch.tensor([tokens], device=DEVICE)
|
| 82 |
|
| 83 |
with torch.no_grad():
|
|
@@ -101,7 +111,7 @@ def generate(prompt: Prompt):
|
|
| 101 |
return {"reply": response}
|
| 102 |
|
| 103 |
# ======================
|
| 104 |
-
# INTERFAZ WEB (FRONTEND)
|
| 105 |
# ======================
|
| 106 |
@app.get("/", response_class=HTMLResponse)
|
| 107 |
def chat_ui():
|
|
@@ -115,9 +125,8 @@ def chat_ui():
|
|
| 115 |
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 116 |
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 117 |
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
|
| 118 |
-
<!-- Iconos SVG -->
|
| 119 |
<style>
|
| 120 |
-
/* --- VARIABLES --- */
|
| 121 |
:root {
|
| 122 |
--bg-color: #131314;
|
| 123 |
--surface-color: #1E1F20;
|
|
@@ -257,7 +266,7 @@ header {
|
|
| 257 |
box-shadow: 0 2px 6px rgba(0,0,0,0.2);
|
| 258 |
}
|
| 259 |
|
| 260 |
-
/* Acciones Bot
|
| 261 |
.bot-actions {
|
| 262 |
display: flex;
|
| 263 |
gap: 10px;
|
|
@@ -406,7 +415,7 @@ header {
|
|
| 406 |
<div class="input-box">
|
| 407 |
<input type="text" id="userInput" placeholder="Escribe un mensaje..." autocomplete="off">
|
| 408 |
<button id="mainBtn" onclick="handleBtnClick()">
|
| 409 |
-
<!--
|
| 410 |
</button>
|
| 411 |
</div>
|
| 412 |
<div class="disclaimer">
|
|
@@ -421,15 +430,15 @@ const mainBtn = document.getElementById('mainBtn');
|
|
| 421 |
|
| 422 |
// Variables de Estado
|
| 423 |
let isGenerating = false;
|
| 424 |
-
let abortController = null;
|
| 425 |
-
let typingTimeout = null;
|
| 426 |
let lastUserPrompt = "";
|
| 427 |
|
| 428 |
// Iconos SVG
|
| 429 |
const ICON_SEND = `<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"></path></svg>`;
|
| 430 |
const ICON_STOP = `<svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="0"><rect x="2" y="2" width="20" height="20" rx="4" ry="4"></rect></svg>`;
|
| 431 |
|
| 432 |
-
//
|
| 433 |
mainBtn.innerHTML = ICON_SEND;
|
| 434 |
|
| 435 |
// --- UTILS ---
|
|
@@ -448,9 +457,8 @@ function setBtnState(state) {
|
|
| 448 |
}
|
| 449 |
}
|
| 450 |
|
| 451 |
-
// --- CORE
|
| 452 |
|
| 453 |
-
// Manejador del Click
|
| 454 |
function handleBtnClick() {
|
| 455 |
if (isGenerating) {
|
| 456 |
stopGeneration();
|
|
@@ -460,14 +468,10 @@ function handleBtnClick() {
|
|
| 460 |
}
|
| 461 |
|
| 462 |
function stopGeneration() {
|
| 463 |
-
// 1. Cancelar fetch
|
| 464 |
if (abortController) abortController.abort();
|
| 465 |
-
|
| 466 |
-
// 2. Cancelar escritura
|
| 467 |
if (typingTimeout) clearTimeout(typingTimeout);
|
| 468 |
|
| 469 |
-
//
|
| 470 |
-
// Buscamos el cursor activo para quitarlo
|
| 471 |
const activeCursor = document.querySelector('.typing-cursor');
|
| 472 |
if (activeCursor) activeCursor.classList.remove('typing-cursor');
|
| 473 |
|
|
@@ -484,7 +488,6 @@ async function sendMessage(textOverride = null) {
|
|
| 484 |
|
| 485 |
lastUserPrompt = text;
|
| 486 |
|
| 487 |
-
// UI Updates
|
| 488 |
if (!textOverride) {
|
| 489 |
userInput.value = '';
|
| 490 |
addMessage(text, 'user');
|
|
@@ -493,13 +496,16 @@ async function sendMessage(textOverride = null) {
|
|
| 493 |
setBtnState('sending');
|
| 494 |
abortController = new AbortController();
|
| 495 |
|
| 496 |
-
//
|
| 497 |
const botRow = document.createElement('div');
|
| 498 |
botRow.className = 'msg-row bot';
|
|
|
|
| 499 |
const avatar = document.createElement('div');
|
| 500 |
avatar.className = 'bot-avatar pulsing';
|
|
|
|
| 501 |
const wrapper = document.createElement('div');
|
| 502 |
wrapper.className = 'msg-content-wrapper';
|
|
|
|
| 503 |
const msgText = document.createElement('div');
|
| 504 |
msgText.className = 'msg-text';
|
| 505 |
|
|
@@ -519,14 +525,14 @@ async function sendMessage(textOverride = null) {
|
|
| 519 |
|
| 520 |
const data = await response.json();
|
| 521 |
|
| 522 |
-
if (!isGenerating) return;
|
| 523 |
|
| 524 |
avatar.classList.remove('pulsing');
|
| 525 |
const reply = data.reply || "No entendí eso.";
|
| 526 |
|
| 527 |
await typeWriter(msgText, reply);
|
| 528 |
|
| 529 |
-
if (isGenerating) {
|
| 530 |
addActions(wrapper, reply);
|
| 531 |
setBtnState('idle');
|
| 532 |
}
|
|
@@ -560,7 +566,7 @@ function typeWriter(element, text, speed = 12) {
|
|
| 560 |
element.classList.add('typing-cursor');
|
| 561 |
|
| 562 |
function type() {
|
| 563 |
-
if (!isGenerating) {
|
| 564 |
element.classList.remove('typing-cursor');
|
| 565 |
resolve();
|
| 566 |
return;
|
|
@@ -570,9 +576,7 @@ function typeWriter(element, text, speed = 12) {
|
|
| 570 |
element.textContent += text.charAt(i);
|
| 571 |
i++;
|
| 572 |
scrollToBottom();
|
| 573 |
-
typingTimeout = setTimeout(()
|
| 574 |
-
type();
|
| 575 |
-
}, speed + Math.random() * 5);
|
| 576 |
} else {
|
| 577 |
element.classList.remove('typing-cursor');
|
| 578 |
resolve();
|
|
@@ -586,7 +590,6 @@ function addActions(wrapperElement, textToCopy) {
|
|
| 586 |
const actionsDiv = document.createElement('div');
|
| 587 |
actionsDiv.className = 'bot-actions';
|
| 588 |
|
| 589 |
-
// Copy Btn
|
| 590 |
const copyBtn = document.createElement('button');
|
| 591 |
copyBtn.className = 'action-btn';
|
| 592 |
copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
|
|
@@ -594,7 +597,6 @@ function addActions(wrapperElement, textToCopy) {
|
|
| 594 |
navigator.clipboard.writeText(textToCopy);
|
| 595 |
};
|
| 596 |
|
| 597 |
-
// Regen Btn
|
| 598 |
const regenBtn = document.createElement('button');
|
| 599 |
regenBtn.className = 'action-btn';
|
| 600 |
regenBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
|
|
@@ -610,7 +612,6 @@ function addActions(wrapperElement, textToCopy) {
|
|
| 610 |
scrollToBottom();
|
| 611 |
}
|
| 612 |
|
| 613 |
-
// Listeners
|
| 614 |
userInput.addEventListener('keydown', (e) => {
|
| 615 |
if (e.key === 'Enter') handleBtnClick();
|
| 616 |
});
|
|
|
|
| 9 |
import uvicorn
|
| 10 |
|
| 11 |
# ======================
|
| 12 |
+
# CONFIGURACIÓN DE DISPOSITIVO (GPU/CPU)
|
| 13 |
# ======================
|
| 14 |
+
# Detectar automáticamente si hay una GPU NVIDIA disponible
|
| 15 |
+
if torch.cuda.is_available():
|
| 16 |
+
DEVICE = "cuda"
|
| 17 |
+
print("✅ GPU NVIDIA detectada. Usando CUDA.")
|
| 18 |
+
else:
|
| 19 |
+
DEVICE = "cpu"
|
| 20 |
+
print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")
|
| 21 |
+
|
| 22 |
MODEL_REPO = "teszenofficial/mtptz"
|
|
|
|
| 23 |
|
| 24 |
# ======================
|
| 25 |
# DESCARGA DEL MODELO
|
|
|
|
| 64 |
dropout=0.0
|
| 65 |
)
|
| 66 |
|
| 67 |
+
# Load the weights and move the model to the selected device (DEVICE: "cuda" or "cpu")
|
| 68 |
model.load_state_dict(model_data["model_state_dict"])
|
| 69 |
model.to(DEVICE)
|
| 70 |
model.eval()
|
| 71 |
+
print(f"🚀 MTP 1.1 listo y corriendo en: {DEVICE.upper()}")
|
| 72 |
|
| 73 |
# ======================
|
| 74 |
+
# API FASTAPI
|
| 75 |
# ======================
|
| 76 |
app = FastAPI(title="MTP 1.1 API")
|
| 77 |
|
|
|
|
| 86 |
|
| 87 |
full_prompt = f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
|
| 88 |
tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
|
| 89 |
+
|
| 90 |
+
# IMPORTANT: create the input tensor on the same device as the model (DEVICE)
|
| 91 |
input_ids = torch.tensor([tokens], device=DEVICE)
|
| 92 |
|
| 93 |
with torch.no_grad():
|
|
|
|
| 111 |
return {"reply": response}
|
| 112 |
|
| 113 |
# ======================
|
| 114 |
+
# INTERFAZ WEB (FRONTEND MEJORADO)
|
| 115 |
# ======================
|
| 116 |
@app.get("/", response_class=HTMLResponse)
|
| 117 |
def chat_ui():
|
|
|
|
| 125 |
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 126 |
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 127 |
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
|
|
|
|
| 128 |
<style>
|
| 129 |
+
/* --- VARIABLES & THEME --- */
|
| 130 |
:root {
|
| 131 |
--bg-color: #131314;
|
| 132 |
--surface-color: #1E1F20;
|
|
|
|
| 266 |
box-shadow: 0 2px 6px rgba(0,0,0,0.2);
|
| 267 |
}
|
| 268 |
|
| 269 |
+
/* Acciones Bot */
|
| 270 |
.bot-actions {
|
| 271 |
display: flex;
|
| 272 |
gap: 10px;
|
|
|
|
| 415 |
<div class="input-box">
|
| 416 |
<input type="text" id="userInput" placeholder="Escribe un mensaje..." autocomplete="off">
|
| 417 |
<button id="mainBtn" onclick="handleBtnClick()">
|
| 418 |
+
<!-- Icono dinámico -->
|
| 419 |
</button>
|
| 420 |
</div>
|
| 421 |
<div class="disclaimer">
|
|
|
|
| 430 |
|
| 431 |
// Variables de Estado
|
| 432 |
let isGenerating = false;
|
| 433 |
+
let abortController = null;
|
| 434 |
+
let typingTimeout = null;
|
| 435 |
let lastUserPrompt = "";
|
| 436 |
|
| 437 |
// Iconos SVG
|
| 438 |
const ICON_SEND = `<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"></path></svg>`;
|
| 439 |
const ICON_STOP = `<svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="0"><rect x="2" y="2" width="20" height="20" rx="4" ry="4"></rect></svg>`;
|
| 440 |
|
| 441 |
+
// Inicial
|
| 442 |
mainBtn.innerHTML = ICON_SEND;
|
| 443 |
|
| 444 |
// --- UTILS ---
|
|
|
|
| 457 |
}
|
| 458 |
}
|
| 459 |
|
| 460 |
+
// --- CORE ---
|
| 461 |
|
|
|
|
| 462 |
function handleBtnClick() {
|
| 463 |
if (isGenerating) {
|
| 464 |
stopGeneration();
|
|
|
|
| 468 |
}
|
| 469 |
|
| 470 |
function stopGeneration() {
|
|
|
|
| 471 |
if (abortController) abortController.abort();
|
|
|
|
|
|
|
| 472 |
if (typingTimeout) clearTimeout(typingTimeout);
|
| 473 |
|
| 474 |
+
// UI Limpieza
|
|
|
|
| 475 |
const activeCursor = document.querySelector('.typing-cursor');
|
| 476 |
if (activeCursor) activeCursor.classList.remove('typing-cursor');
|
| 477 |
|
|
|
|
| 488 |
|
| 489 |
lastUserPrompt = text;
|
| 490 |
|
|
|
|
| 491 |
if (!textOverride) {
|
| 492 |
userInput.value = '';
|
| 493 |
addMessage(text, 'user');
|
|
|
|
| 496 |
setBtnState('sending');
|
| 497 |
abortController = new AbortController();
|
| 498 |
|
| 499 |
+
// Bot Placeholder
|
| 500 |
const botRow = document.createElement('div');
|
| 501 |
botRow.className = 'msg-row bot';
|
| 502 |
+
|
| 503 |
const avatar = document.createElement('div');
|
| 504 |
avatar.className = 'bot-avatar pulsing';
|
| 505 |
+
|
| 506 |
const wrapper = document.createElement('div');
|
| 507 |
wrapper.className = 'msg-content-wrapper';
|
| 508 |
+
|
| 509 |
const msgText = document.createElement('div');
|
| 510 |
msgText.className = 'msg-text';
|
| 511 |
|
|
|
|
| 525 |
|
| 526 |
const data = await response.json();
|
| 527 |
|
| 528 |
+
if (!isGenerating) return;
|
| 529 |
|
| 530 |
avatar.classList.remove('pulsing');
|
| 531 |
const reply = data.reply || "No entendí eso.";
|
| 532 |
|
| 533 |
await typeWriter(msgText, reply);
|
| 534 |
|
| 535 |
+
if (isGenerating) {
|
| 536 |
addActions(wrapper, reply);
|
| 537 |
setBtnState('idle');
|
| 538 |
}
|
|
|
|
| 566 |
element.classList.add('typing-cursor');
|
| 567 |
|
| 568 |
function type() {
|
| 569 |
+
if (!isGenerating) {
|
| 570 |
element.classList.remove('typing-cursor');
|
| 571 |
resolve();
|
| 572 |
return;
|
|
|
|
| 576 |
element.textContent += text.charAt(i);
|
| 577 |
i++;
|
| 578 |
scrollToBottom();
|
| 579 |
+
typingTimeout = setTimeout(type, speed + Math.random() * 5);
|
|
|
|
|
|
|
| 580 |
} else {
|
| 581 |
element.classList.remove('typing-cursor');
|
| 582 |
resolve();
|
|
|
|
| 590 |
const actionsDiv = document.createElement('div');
|
| 591 |
actionsDiv.className = 'bot-actions';
|
| 592 |
|
|
|
|
| 593 |
const copyBtn = document.createElement('button');
|
| 594 |
copyBtn.className = 'action-btn';
|
| 595 |
copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
|
|
|
|
| 597 |
navigator.clipboard.writeText(textToCopy);
|
| 598 |
};
|
| 599 |
|
|
|
|
| 600 |
const regenBtn = document.createElement('button');
|
| 601 |
regenBtn.className = 'action-btn';
|
| 602 |
regenBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
|
|
|
|
| 612 |
scrollToBottom();
|
| 613 |
}
|
| 614 |
|
|
|
|
| 615 |
userInput.addEventListener('keydown', (e) => {
|
| 616 |
if (e.key === 'Enter') handleBtnClick();
|
| 617 |
});
|