teszenofficial committed on
Commit
2d65976
·
verified ·
1 Parent(s): 9c2275b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -42
app.py CHANGED
@@ -3,7 +3,7 @@ import sys
3
  import torch
4
  import json
5
  import gc
6
- import time
7
  from fastapi import FastAPI
8
  from fastapi.responses import HTMLResponse
9
  from fastapi.middleware.cors import CORSMiddleware
@@ -17,11 +17,11 @@ import sentencepiece as spm
17
 
18
  if torch.cuda.is_available():
19
  DEVICE = "cuda"
20
- print("✅ GPU")
21
  torch.backends.cudnn.benchmark = True
22
  else:
23
  DEVICE = "cpu"
24
- print("⚠️ CPU")
25
  torch.set_num_threads(4)
26
 
27
  torch.set_grad_enabled(False)
@@ -114,7 +114,9 @@ class MTPModel(nn.Module):
114
  self.max_len = max_len
115
  self.token_embedding = nn.Embedding(vocab_size, d_model)
116
  self.pos_encoding = PositionalEncoding(d_model, max_len)
117
- self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
 
 
118
  self.norm = LayerNorm(d_model)
119
  self.lm_head = nn.Linear(d_model, vocab_size)
120
 
@@ -129,7 +131,7 @@ class MTPModel(nn.Module):
129
  return self.lm_head(x)
130
 
131
  @torch.inference_mode()
132
- def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=40):
133
  generated = input_ids
134
  for _ in range(max_new_tokens):
135
  logits = self(generated)
@@ -143,7 +145,7 @@ class MTPModel(nn.Module):
143
  if next_token == 3 or next_token == 0 or next_token == 1:
144
  break
145
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
146
- if len(generated[0]) > 180:
147
  break
148
  return generated
149
 
@@ -151,25 +153,52 @@ print("📦 Descargando modelo...")
151
  repo_path = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir="mtp_repo")
152
 
153
  config_path = os.path.join(repo_path, "config.json")
154
- with open(config_path, "r") as f:
155
- config = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
 
 
 
 
158
  sp = spm.SentencePieceProcessor()
159
  sp.load(tokenizer_path)
160
- config["vocab_size"] = sp.get_piece_size()
 
161
 
162
- print(f" Vocab: {config['vocab_size']}")
163
- print(f" Dim: {config['d_model']}, Layers: {config['n_layers']}")
 
 
 
164
 
165
  model = MTPModel(**config)
166
  model.to(DEVICE)
167
 
168
  model_path = os.path.join(repo_path, "mtp_model.pt")
169
- state_dict = torch.load(model_path, map_location=DEVICE)
170
- model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
171
  model.eval()
172
- print(f"✅ Modelo: {sum(p.numel() for p in model.parameters()):,} params")
173
 
174
  app = FastAPI()
175
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
@@ -180,24 +209,29 @@ class PromptRequest(BaseModel):
180
def build_prompt(user_input):
    """Wrap the user's message in the instruction/response template.

    Args:
        user_input: The text to place under the instruction header.

    Returns:
        The prompt string: the instruction header, the user text, a blank
        line, and the response header followed by a newline.
    """
    # The template text is a runtime string and is kept byte-for-byte.
    return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
182
 
 
 
 
 
 
 
 
183
  @app.post("/generate")
184
  async def generate(req: PromptRequest):
185
- start = time.time()
186
  user_input = req.text.strip()
187
-
188
  if not user_input:
189
- return {"reply": "Escribe un mensaje", "time": 0}
190
 
191
  prompt = build_prompt(user_input)
192
  tokens = sp.encode(prompt)
193
 
194
- if len(tokens) > 800:
195
- tokens = tokens[-800:]
196
 
197
  input_ids = torch.tensor([tokens], device=DEVICE)
198
 
199
  try:
200
- output_ids = model.generate(input_ids, max_new_tokens=80, temperature=0.75, top_k=40)
201
 
202
  gen_tokens = output_ids[0, len(tokens):].tolist()
203
 
@@ -208,26 +242,15 @@ async def generate(req: PromptRequest):
208
  clean_tokens.append(t)
209
 
210
  response = sp.decode(clean_tokens).strip() if clean_tokens else ""
 
211
 
212
- markers = ["### Respuesta:", "Respuesta:", "[/INST]", "Asistente:", "Usuario:"]
213
- for marker in markers:
214
- if marker in response:
215
- response = response.split(marker)[-1]
216
-
217
- response = response.replace('<unk>', '').replace('<pad>', '').replace('<s>', '').replace('</s>', '')
218
- response = ' '.join(response.split())
219
-
220
- if not response or len(response) < 2:
221
- response = "Entiendo tu pregunta. ¿Podrías darme más detalles?"
222
 
223
- elapsed = time.time() - start
224
- print(f"✅ {user_input[:25]}... -> {elapsed:.1f}s ({len(clean_tokens)} tokens)")
225
-
226
- return {"reply": response[:350], "time": elapsed}
227
 
228
  except Exception as e:
229
  print(f"❌ Error: {e}")
230
- return {"reply": "Error, intenta de nuevo", "time": 0}
231
 
232
  @app.get("/health")
233
  def health():
@@ -360,11 +383,11 @@ body {
360
  <body>
361
  <div class="header">
362
  <h1><span class="dot"></span> MTP Assistant</h1>
363
- <p>Modelo Transformer | Generación 100% por IA</p>
364
  </div>
365
  <div class="chat" id="chat">
366
  <div class="message bot">
367
- <div class="message-content">¡Hola! Soy MTP. ¿En qué puedo ayudarte?</div>
368
  </div>
369
  </div>
370
  <div class="input-area">
@@ -415,8 +438,6 @@ async function send() {
415
  sendBtn.disabled = true;
416
  addTyping();
417
 
418
- const startTime = Date.now();
419
-
420
  try {
421
  const res = await fetch('/generate', {
422
  method: 'POST',
@@ -424,13 +445,11 @@ async function send() {
424
  body: JSON.stringify({ text: text })
425
  });
426
  const data = await res.json();
427
- const elapsed = ((Date.now() - startTime) / 1000).toFixed(1);
428
  removeTyping();
429
  addMessage(data.reply || "No pude generar respuesta.", false);
430
- console.log(`Respuesta en ${elapsed}s`);
431
  } catch (err) {
432
  removeTyping();
433
- addMessage("Error de conexión. Intenta de nuevo.", false);
434
  } finally {
435
  loading = false;
436
  sendBtn.disabled = false;
 
3
  import torch
4
  import json
5
  import gc
6
+ import re
7
  from fastapi import FastAPI
8
  from fastapi.responses import HTMLResponse
9
  from fastapi.middleware.cors import CORSMiddleware
 
17
 
18
  if torch.cuda.is_available():
19
  DEVICE = "cuda"
20
+ print("✅ GPU detectada")
21
  torch.backends.cudnn.benchmark = True
22
  else:
23
  DEVICE = "cpu"
24
+ print("⚠️ CPU mode")
25
  torch.set_num_threads(4)
26
 
27
  torch.set_grad_enabled(False)
 
114
  self.max_len = max_len
115
  self.token_embedding = nn.Embedding(vocab_size, d_model)
116
  self.pos_encoding = PositionalEncoding(d_model, max_len)
117
+ self.blocks = nn.ModuleList([
118
+ TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
119
+ ])
120
  self.norm = LayerNorm(d_model)
121
  self.lm_head = nn.Linear(d_model, vocab_size)
122
 
 
131
  return self.lm_head(x)
132
 
133
  @torch.inference_mode()
134
+ def generate(self, input_ids, max_new_tokens=150, temperature=0.7, top_k=50):
135
  generated = input_ids
136
  for _ in range(max_new_tokens):
137
  logits = self(generated)
 
145
  if next_token == 3 or next_token == 0 or next_token == 1:
146
  break
147
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
148
+ if len(generated[0]) > 200:
149
  break
150
  return generated
151
 
 
153
  repo_path = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir="mtp_repo")
154
 
155
  config_path = os.path.join(repo_path, "config.json")
156
+ if os.path.exists(config_path):
157
+ with open(config_path, "r") as f:
158
+ config = json.load(f)
159
+ print(f"✅ Configuración cargada: d_model={config.get('d_model', 512)}, layers={config.get('n_layers', 8)}")
160
+ else:
161
+ print("⚠️ Usando configuración por defecto (igual que colab.py)")
162
+ config = {
163
+ "vocab_size": 8000,
164
+ "d_model": 512,
165
+ "n_heads": 8,
166
+ "n_layers": 8,
167
+ "d_ff": 2048,
168
+ "dropout": 0.1,
169
+ "max_len": 1024
170
+ }
171
 
172
  tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
173
+ if not os.path.exists(tokenizer_path):
174
+ print(f"❌ Tokenizador no encontrado")
175
+ sys.exit(1)
176
+
177
  sp = spm.SentencePieceProcessor()
178
  sp.load(tokenizer_path)
179
+ VOCAB_SIZE = sp.get_piece_size()
180
+ config["vocab_size"] = VOCAB_SIZE
181
 
182
+ print(f"🧠 Inicializando modelo MTP...")
183
+ print(f" Vocabulario: {VOCAB_SIZE}")
184
+ print(f" → Dimensión: {config['d_model']}")
185
+ print(f" → Capas: {config['n_layers']}")
186
+ print(f" → Heads: {config['n_heads']}")
187
 
188
  model = MTPModel(**config)
189
  model.to(DEVICE)
190
 
191
  model_path = os.path.join(repo_path, "mtp_model.pt")
192
+ if os.path.exists(model_path):
193
+ state_dict = torch.load(model_path, map_location=DEVICE)
194
+ model.load_state_dict(state_dict, strict=False)
195
+ print("✅ Pesos cargados correctamente")
196
+ else:
197
+ print(f"❌ Modelo no encontrado")
198
+ sys.exit(1)
199
+
200
  model.eval()
201
+ print(f"✅ Modelo listo: {sum(p.numel() for p in model.parameters()):,} params")
202
 
203
  app = FastAPI()
204
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
 
209
def build_prompt(user_input):
    """Build the instruction-style prompt for a single user message.

    Args:
        user_input: The text inserted between the instruction header and
            the response header.

    Returns:
        A string made of the instruction header, ``user_input``, a blank
        line, and the response header terminated by a newline.
    """
    # The template text is a runtime string and is kept byte-for-byte.
    return f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
211
 
212
def clean_response(text):
    """Remove special-token markers and collapse whitespace in a string.

    Args:
        text: The string to clean. Any falsy value (``""`` or ``None``)
            is accepted.

    Returns:
        ``""`` for falsy input. Otherwise ``text`` with the literal
        markers ``<unk>``, ``<pad>``, ``<s>`` and ``</s>`` removed, every
        run of whitespace (newlines included) replaced by one space, and
        leading/trailing whitespace stripped.
    """
    if not text:
        return ""
    # Remove markers first so the gaps they leave are collapsed by the
    # whitespace pass below.
    text = re.sub(r'<unk>|<pad>|<s>|</s>', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
218
+
219
  @app.post("/generate")
220
  async def generate(req: PromptRequest):
 
221
  user_input = req.text.strip()
 
222
  if not user_input:
223
+ return {"reply": ""}
224
 
225
  prompt = build_prompt(user_input)
226
  tokens = sp.encode(prompt)
227
 
228
+ if len(tokens) > 900:
229
+ tokens = tokens[-900:]
230
 
231
  input_ids = torch.tensor([tokens], device=DEVICE)
232
 
233
  try:
234
+ output_ids = model.generate(input_ids, max_new_tokens=120, temperature=0.7, top_k=50)
235
 
236
  gen_tokens = output_ids[0, len(tokens):].tolist()
237
 
 
242
  clean_tokens.append(t)
243
 
244
  response = sp.decode(clean_tokens).strip() if clean_tokens else ""
245
+ response = clean_response(response)
246
 
247
+ print(f"📝 {user_input[:40]} -> {len(clean_tokens)} tokens")
 
 
 
 
 
 
 
 
 
248
 
249
+ return {"reply": response[:500]}
 
 
 
250
 
251
  except Exception as e:
252
  print(f"❌ Error: {e}")
253
+ return {"reply": ""}
254
 
255
  @app.get("/health")
256
  def health():
 
383
  <body>
384
  <div class="header">
385
  <h1><span class="dot"></span> MTP Assistant</h1>
386
+ <p>Modelo Transformer 512-dim / 8-capas</p>
387
  </div>
388
  <div class="chat" id="chat">
389
  <div class="message bot">
390
+ <div class="message-content">Hola, soy MTP. ¿En qué puedo ayudarte?</div>
391
  </div>
392
  </div>
393
  <div class="input-area">
 
438
  sendBtn.disabled = true;
439
  addTyping();
440
 
 
 
441
  try {
442
  const res = await fetch('/generate', {
443
  method: 'POST',
 
445
  body: JSON.stringify({ text: text })
446
  });
447
  const data = await res.json();
 
448
  removeTyping();
449
  addMessage(data.reply || "No pude generar respuesta.", false);
 
450
  } catch (err) {
451
  removeTyping();
452
+ addMessage("Error de conexión.", false);
453
  } finally {
454
  loading = false;
455
  sendBtn.disabled = false;