teszenofficial commited on
Commit
0eb10b0
·
verified ·
1 Parent(s): dda31c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +347 -98
app.py CHANGED
@@ -3,7 +3,7 @@ import sys
3
  import torch
4
  import pickle
5
  from fastapi import FastAPI
6
- from fastapi.responses import HTMLResponse, JSONResponse
7
  from pydantic import BaseModel
8
  from huggingface_hub import snapshot_download
9
  import uvicorn
@@ -17,6 +17,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
  # ======================
18
  # DOWNLOAD MODEL
19
  # ======================
 
20
  repo_path = snapshot_download(
21
  repo_id=MODEL_REPO,
22
  repo_type="model",
@@ -25,12 +26,18 @@ repo_path = snapshot_download(
25
 
26
  sys.path.insert(0, repo_path)
27
 
28
- from model import MTPMiniModel
29
- from tokenizer import MTPTokenizer
 
 
 
 
 
30
 
31
  # ======================
32
  # LOAD MODEL
33
  # ======================
 
34
  with open(os.path.join(repo_path, "mtp_mini.pkl"), "rb") as f:
35
  model_data = pickle.load(f)
36
 
@@ -53,6 +60,7 @@ model = MTPMiniModel(
53
  model.load_state_dict(model_data["model_state_dict"])
54
  model.to(DEVICE)
55
  model.eval()
 
56
 
57
  # ======================
58
  # FASTAPI
@@ -93,7 +101,7 @@ def generate(prompt: Prompt):
93
  return {"reply": response}
94
 
95
  # ======================
96
- # CHAT WEB (HTML)
97
  # ======================
98
  @app.get("/", response_class=HTMLResponse)
99
  def chat_ui():
@@ -102,122 +110,362 @@ def chat_ui():
102
  <html lang="es">
103
  <head>
104
  <meta charset="UTF-8">
105
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
106
- <title>MTP Chat</title>
 
 
 
107
  <style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  body {
109
- margin: 0;
110
- background: #343541;
111
- font-family: system-ui, sans-serif;
112
- display: flex;
113
- justify-content: center;
114
- height: 100vh;
115
- }
116
- .chat {
117
- width: 100%;
118
- max-width: 420px;
119
- background: #40414f;
120
- display: flex;
121
- flex-direction: column;
122
  }
 
 
123
  header {
124
- background: #202123;
125
- color: white;
126
- padding: 14px;
127
- font-weight: bold;
128
- }
129
- .messages {
130
- flex: 1;
131
- padding: 14px;
132
- overflow-y: auto;
133
- }
134
- .msg {
135
- max-width: 85%;
136
- padding: 12px 14px;
137
- margin-bottom: 10px;
138
- border-radius: 10px;
139
- white-space: pre-wrap;
140
- }
141
- .user {
142
- background: #0b93f6;
143
- color: white;
144
- margin-left: auto;
145
- }
146
- .bot {
147
- background: #444654;
148
- color: white;
149
- }
150
- .input {
151
- display: flex;
152
- padding: 10px;
153
- background: #202123;
154
- }
155
- .input input {
156
- flex: 1;
157
- border: none;
158
- padding: 12px;
159
- border-radius: 8px;
160
- font-size: 16px;
161
- }
162
- .input button {
163
- background: #0b93f6;
164
- border: none;
165
- color: white;
166
- padding: 0 16px;
167
- border-radius: 8px;
168
- margin-left: 8px;
 
 
 
 
 
 
 
 
 
 
169
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  </style>
171
  </head>
172
  <body>
173
 
174
- <div class="chat">
175
- <header>🤖 MTP</header>
176
- <div id="messages" class="messages"></div>
177
- <div class="input">
178
- <input id="input" placeholder="Escribe un mensaje..." autocomplete="off">
179
- <button onclick="send()">Enviar</button>
180
- </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  </div>
182
 
183
  <script>
184
- const messages = document.getElementById("messages");
185
- const input = document.getElementById("input");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- function add(text, cls) {
188
- const div = document.createElement("div");
189
- div.className = "msg " + cls;
190
- div.textContent = text;
191
- messages.appendChild(div);
192
- messages.scrollTop = messages.scrollHeight;
 
 
 
 
 
 
 
 
 
 
 
 
193
  }
194
 
195
  async function send() {
196
- const text = input.value.trim();
197
- if (!text) return;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
- add(text, "user");
200
- input.value = "";
201
 
202
- const typing = document.createElement("div");
203
- typing.className = "msg bot";
204
- typing.textContent = "MTP está escribiendo…";
205
- messages.appendChild(typing);
206
 
207
- const res = await fetch("/generate", {
208
- method: "POST",
209
- headers: { "Content-Type": "application/json" },
210
- body: JSON.stringify({ text })
211
- });
212
 
213
- const json = await res.json();
214
- messages.removeChild(typing);
215
- add(json.reply || "Sin respuesta", "bot");
 
 
 
 
 
 
216
  }
217
 
218
- input.addEventListener("keydown", e => {
219
- if (e.key === "Enter") send();
 
220
  });
 
 
 
 
221
  </script>
222
 
223
  </body>
@@ -228,4 +476,5 @@ input.addEventListener("keydown", e => {
228
  # ENTRYPOINT (HF)
229
  # ======================
230
  if __name__ == "__main__":
231
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
3
  import torch
4
  import pickle
5
  from fastapi import FastAPI
6
+ from fastapi.responses import HTMLResponse
7
  from pydantic import BaseModel
8
  from huggingface_hub import snapshot_download
9
  import uvicorn
 
17
  # ======================
18
  # DOWNLOAD MODEL
19
  # ======================
20
+ print(f"Descargando/Verificando modelo desde {MODEL_REPO}...")
21
  repo_path = snapshot_download(
22
  repo_id=MODEL_REPO,
23
  repo_type="model",
 
26
 
27
  sys.path.insert(0, repo_path)
28
 
29
+ try:
30
+ from model import MTPMiniModel
31
+ from tokenizer import MTPTokenizer
32
+ except ImportError:
33
+ # Fallback por si la estructura del repo requiere ajustes de path
34
+ print("Advertencia: Asegúrate de que model.py y tokenizer.py estén en la raíz del repo descargado.")
35
+ pass
36
 
37
  # ======================
38
  # LOAD MODEL
39
  # ======================
40
+ print("Cargando modelo en memoria...")
41
  with open(os.path.join(repo_path, "mtp_mini.pkl"), "rb") as f:
42
  model_data = pickle.load(f)
43
 
 
60
  model.load_state_dict(model_data["model_state_dict"])
61
  model.to(DEVICE)
62
  model.eval()
63
+ print(f"Modelo cargado en {DEVICE}")
64
 
65
  # ======================
66
  # FASTAPI
 
101
  return {"reply": response}
102
 
103
  # ======================
104
+ # CHAT WEB (HTML/CSS/JS MEJORADO)
105
  # ======================
106
  @app.get("/", response_class=HTMLResponse)
107
  def chat_ui():
 
110
  <html lang="es">
111
  <head>
112
  <meta charset="UTF-8">
113
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
114
+ <title>MTP AI</title>
115
+ <link rel="preconnect" href="https://fonts.googleapis.com">
116
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
117
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap" rel="stylesheet">
118
  <style>
119
+ /* --- VARIABLES Y RESET --- */
120
+ :root {
121
+ --bg-color: #1a1b1e;
122
+ --chat-bg: #101113;
123
+ --header-bg: rgba(26, 27, 30, 0.8);
124
+ --user-msg-bg: #2563eb;
125
+ --bot-msg-bg: #2c2e33;
126
+ --text-color: #e5e7eb;
127
+ --accent-color: #3b82f6;
128
+ --input-area-bg: rgba(26, 27, 30, 0.95);
129
+ }
130
+
131
+ * { box-sizing: border-box; outline: none; }
132
+
133
  body {
134
+ margin: 0;
135
+ background-color: var(--bg-color);
136
+ font-family: 'Inter', sans-serif;
137
+ color: var(--text-color);
138
+ display: flex;
139
+ flex-direction: column;
140
+ height: 100vh; /* Fallback */
141
+ height: 100dvh; /* Dynamic Viewport Height for mobile */
142
+ overflow: hidden;
 
 
 
 
143
  }
144
+
145
+ /* --- HEADER --- */
146
  header {
147
+ background: var(--header-bg);
148
+ backdrop-filter: blur(10px);
149
+ padding: 15px 20px;
150
+ border-bottom: 1px solid #2c2e33;
151
+ display: flex;
152
+ align-items: center;
153
+ gap: 10px;
154
+ position: fixed;
155
+ top: 0;
156
+ width: 100%;
157
+ z-index: 10;
158
+ }
159
+
160
+ .status-dot {
161
+ width: 10px;
162
+ height: 10px;
163
+ background-color: #10b981;
164
+ border-radius: 50%;
165
+ box-shadow: 0 0 8px #10b981;
166
+ }
167
+
168
+ header h1 {
169
+ font-size: 1.1rem;
170
+ margin: 0;
171
+ font-weight: 600;
172
+ }
173
+
174
+ /* --- AREA DE MENSAJES --- */
175
+ .messages-container {
176
+ flex: 1;
177
+ overflow-y: auto;
178
+ padding: 80px 15px 100px 15px; /* Espacio para header e input */
179
+ display: flex;
180
+ flex-direction: column;
181
+ gap: 15px;
182
+ scroll-behavior: smooth;
183
+ }
184
+
185
+ .msg-row {
186
+ display: flex;
187
+ width: 100%;
188
+ }
189
+
190
+ .msg-row.user { justify-content: flex-end; }
191
+ .msg-row.bot { justify-content: flex-start; }
192
+
193
+ .msg-bubble {
194
+ max-width: 85%;
195
+ padding: 12px 16px;
196
+ border-radius: 18px;
197
+ font-size: 0.95rem;
198
+ line-height: 1.5;
199
+ position: relative;
200
+ word-wrap: break-word;
201
+ box-shadow: 0 2px 5px rgba(0,0,0,0.2);
202
  }
203
+
204
+ .msg-row.user .msg-bubble {
205
+ background: var(--user-msg-bg);
206
+ color: white;
207
+ border-bottom-right-radius: 4px;
208
+ animation: popIn 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);
209
+ }
210
+
211
+ .msg-row.bot .msg-bubble {
212
+ background: var(--bot-msg-bg);
213
+ color: var(--text-color);
214
+ border-bottom-left-radius: 4px;
215
+ border: 1px solid #373a40;
216
+ animation: slideUp 0.3s ease-out;
217
+ }
218
+
219
+ /* --- EFECTO DE ESCRITURA --- */
220
+ .cursor::after {
221
+ content: '|';
222
+ animation: blink 1s step-start infinite;
223
+ color: var(--accent-color);
224
+ }
225
+
226
+ /* --- INPUT AREA --- */
227
+ .input-area {
228
+ position: fixed;
229
+ bottom: 0;
230
+ width: 100%;
231
+ background: var(--input-area-bg);
232
+ backdrop-filter: blur(10px);
233
+ padding: 15px;
234
+ border-top: 1px solid #2c2e33;
235
+ display: flex;
236
+ gap: 10px;
237
+ align-items: center;
238
+ z-index: 10;
239
+ }
240
+
241
+ .input-wrapper {
242
+ flex: 1;
243
+ background: #2c2e33;
244
+ border-radius: 24px;
245
+ padding: 2px;
246
+ display: flex;
247
+ align-items: center;
248
+ border: 1px solid transparent;
249
+ transition: border-color 0.2s;
250
+ }
251
+
252
+ .input-wrapper:focus-within {
253
+ border-color: var(--accent-color);
254
+ }
255
+
256
+ input {
257
+ flex: 1;
258
+ background: transparent;
259
+ border: none;
260
+ padding: 12px 16px;
261
+ color: white;
262
+ font-size: 1rem;
263
+ font-family: inherit;
264
+ }
265
+
266
+ button {
267
+ background: var(--accent-color);
268
+ color: white;
269
+ border: none;
270
+ width: 45px;
271
+ height: 45px;
272
+ border-radius: 50%;
273
+ cursor: pointer;
274
+ display: flex;
275
+ align-items: center;
276
+ justify-content: center;
277
+ transition: transform 0.1s, background 0.2s;
278
+ }
279
+
280
+ button:active { transform: scale(0.95); }
281
+ button:hover { background: #2563eb; }
282
+ button svg { width: 20px; height: 20px; fill: white; }
283
+ button:disabled { background: #4b5563; cursor: not-allowed; opacity: 0.7; }
284
+
285
+ /* --- ANIMACIONES --- */
286
+ @keyframes popIn {
287
+ from { opacity: 0; transform: scale(0.8); }
288
+ to { opacity: 1; transform: scale(1); }
289
+ }
290
+
291
+ @keyframes slideUp {
292
+ from { opacity: 0; transform: translateY(10px); }
293
+ to { opacity: 1; transform: translateY(0); }
294
+ }
295
+
296
+ @keyframes blink { 50% { opacity: 0; } }
297
+
298
+ /* --- INDICADOR DE CARGA (Puntos) --- */
299
+ .typing-indicator {
300
+ display: flex;
301
+ align-items: center;
302
+ gap: 4px;
303
+ padding: 5px 0;
304
+ }
305
+
306
+ .dot {
307
+ width: 6px;
308
+ height: 6px;
309
+ background: #9ca3af;
310
+ border-radius: 50%;
311
+ animation: bounce 1.4s infinite ease-in-out both;
312
+ }
313
+
314
+ .dot:nth-child(1) { animation-delay: -0.32s; }
315
+ .dot:nth-child(2) { animation-delay: -0.16s; }
316
+
317
+ @keyframes bounce {
318
+ 0%, 80%, 100% { transform: scale(0); }
319
+ 40% { transform: scale(1); }
320
+ }
321
+
322
+ /* Scrollbar styling */
323
+ ::-webkit-scrollbar { width: 6px; }
324
+ ::-webkit-scrollbar-track { background: transparent; }
325
+ ::-webkit-scrollbar-thumb { background: #4b5563; border-radius: 10px; }
326
+
327
  </style>
328
  </head>
329
  <body>
330
 
331
+ <header>
332
+ <div class="status-dot"></div>
333
+ <h1>MTP Chat</h1>
334
+ </header>
335
+
336
+ <div id="messages" class="messages-container">
337
+ <!-- Mensaje de bienvenida -->
338
+ <div class="msg-row bot">
339
+ <div class="msg-bubble">
340
+ ¡Hola! Soy MTP. ¿En qué puedo ayudarte hoy?
341
+ </div>
342
+ </div>
343
+ </div>
344
+
345
+ <div class="input-area">
346
+ <div class="input-wrapper">
347
+ <input id="input" placeholder="Escribe tu mensaje aquí..." autocomplete="off">
348
+ </div>
349
+ <button id="sendBtn" onclick="send()">
350
+ <svg viewBox="0 0 24 24"><path d="M2.01 21L23 12 2.01 3 2 10l15 2-15 2z"></path></svg>
351
+ </button>
352
  </div>
353
 
354
  <script>
355
+ const messagesEl = document.getElementById("messages");
356
+ const inputEl = document.getElementById("input");
357
+ const sendBtn = document.getElementById("sendBtn");
358
+ let isGenerating = false;
359
+
360
+ // Auto-scroll al fondo
361
+ function scrollToBottom() {
362
+ messagesEl.scrollTop = messagesEl.scrollHeight;
363
+ }
364
+
365
+ // Función para escribir texto letra por letra
366
+ function typeWriter(element, text, speed = 20) {
367
+ let i = 0;
368
+ element.classList.add('cursor'); // Añadir cursor parpadeante
369
+
370
+ function type() {
371
+ if (i < text.length) {
372
+ element.textContent += text.charAt(i);
373
+ i++;
374
+ scrollToBottom();
375
+ setTimeout(type, speed);
376
+ } else {
377
+ element.classList.remove('cursor'); // Quitar cursor al terminar
378
+ }
379
+ }
380
+ type();
381
+ }
382
 
383
+ function addMessage(text, type) {
384
+ const row = document.createElement("div");
385
+ row.className = `msg-row ${type}`;
386
+
387
+ const bubble = document.createElement("div");
388
+ bubble.className = "msg-bubble";
389
+
390
+ if (type === 'user') {
391
+ bubble.textContent = text;
392
+ } else {
393
+ // Para el bot, lo dejamos vacío inicialmente para el efecto
394
+ // o para el indicador de carga
395
+ }
396
+
397
+ row.appendChild(bubble);
398
+ messagesEl.appendChild(row);
399
+ scrollToBottom();
400
+ return bubble; // Retornamos la burbuja para manipularla
401
  }
402
 
403
  async function send() {
404
+ const text = inputEl.value.trim();
405
+ if (!text || isGenerating) return;
406
+
407
+ // 1. Mostrar mensaje de usuario
408
+ addMessage(text, "user");
409
+ inputEl.value = "";
410
+ isGenerating = true;
411
+ sendBtn.disabled = true;
412
+
413
+ // 2. Mostrar indicador de "Escribiendo..."
414
+ const botRow = document.createElement("div");
415
+ botRow.className = "msg-row bot";
416
+ const botBubble = document.createElement("div");
417
+ botBubble.className = "msg-bubble";
418
+
419
+ // HTML de los puntos animados
420
+ botBubble.innerHTML = `
421
+ <div class="typing-indicator">
422
+ <div class="dot"></div>
423
+ <div class="dot"></div>
424
+ <div class="dot"></div>
425
+ </div>
426
+ `;
427
+
428
+ botRow.appendChild(botBubble);
429
+ messagesEl.appendChild(botRow);
430
+ scrollToBottom();
431
+
432
+ try {
433
+ // 3. Petición al backend
434
+ const res = await fetch("/generate", {
435
+ method: "POST",
436
+ headers: { "Content-Type": "application/json" },
437
+ body: JSON.stringify({ text })
438
+ });
439
 
440
+ const json = await res.json();
441
+ const replyText = json.reply || "No pude generar una respuesta.";
442
 
443
+ // 4. Remover burbuja de carga
444
+ messagesEl.removeChild(botRow);
 
 
445
 
446
+ // 5. Crear burbuja real y aplicar efecto máquina de escribir
447
+ const finalBubble = addMessage("", "bot");
448
+ typeWriter(finalBubble, replyText);
 
 
449
 
450
+ } catch (e) {
451
+ messagesEl.removeChild(botRow);
452
+ const errBubble = addMessage("Error de conexión.", "bot");
453
+ errBubble.style.color = "#ef4444";
454
+ } finally {
455
+ isGenerating = false;
456
+ sendBtn.disabled = false;
457
+ inputEl.focus();
458
+ }
459
  }
460
 
461
+ // Enviar con Enter
462
+ inputEl.addEventListener("keydown", e => {
463
+ if (e.key === "Enter") send();
464
  });
465
+
466
+ // Enfoque inicial
467
+ window.onload = () => inputEl.focus();
468
+
469
  </script>
470
 
471
  </body>
 
476
  # ENTRYPOINT (HF)
477
  # ======================
478
  if __name__ == "__main__":
479
+ uvicorn.run(app, host="0.0.0.0", port=7860)
480
+