teszenofficial committed on
Commit
2632d84
·
verified ·
1 Parent(s): d38f4b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -65
app.py CHANGED
@@ -2,21 +2,28 @@ import os
2
  import sys
3
  import torch
4
  import pickle
 
5
  from fastapi import FastAPI
6
- from fastapi.responses import HTMLResponse
7
  from pydantic import BaseModel
8
  from huggingface_hub import snapshot_download
9
  import uvicorn
10
 
 
 
 
 
 
 
11
  # ======================
12
  # DISPOSITIVO
13
  # ======================
14
  if torch.cuda.is_available():
15
  DEVICE = "cuda"
16
- print("✅ GPU NVIDIA detectada. Usando CUDA.")
17
  else:
18
  DEVICE = "cpu"
19
- print("⚠️ GPU no detectada. Usando CPU (puede ser más lento).")
20
 
21
  MODEL_REPO = "teszenofficial/mtp1"
22
 
@@ -24,8 +31,6 @@ MODEL_REPO = "teszenofficial/mtp1"
24
  # DESCARGA MODELO
25
  # ======================
26
  print("--- SISTEMA MTP 1.1 ---")
27
- print(f"Descargando/Verificando modelo desde {MODEL_REPO}...")
28
-
29
  repo_path = snapshot_download(
30
  repo_id=MODEL_REPO,
31
  repo_type="model",
@@ -38,26 +43,15 @@ from model import MTPMiniModel
38
  from tokenizer import MTPTokenizer
39
 
40
  # ======================
41
- # CARGA DEL MODELO
42
  # ======================
43
- print("Cargando modelo en memoria...")
44
-
45
- # Buscar automáticamente el .pkl
46
- pkl_file = None
47
- for f in os.listdir(repo_path):
48
- if f.endswith(".pkl"):
49
- pkl_file = f
50
- break
51
-
52
- if not pkl_file:
53
- raise FileNotFoundError("❌ No se encontró el archivo .pkl del modelo")
54
 
 
55
  with open(os.path.join(repo_path, pkl_file), "rb") as f:
56
  model_data = pickle.load(f)
57
 
58
- tokenizer = MTPTokenizer(
59
- os.path.join(repo_path, "mtp_tokenizer.model")
60
- )
61
 
62
  config = model_data["config"]
63
 
@@ -75,104 +69,153 @@ model.load_state_dict(model_data["model_state_dict"])
75
  model.to(DEVICE)
76
  model.eval()
77
 
78
- # 🔒 Forzar vocab correcto
79
  VOCAB_SIZE = tokenizer.sp.get_piece_size()
80
  model.vocab_size = VOCAB_SIZE
81
 
82
- print(f"🚀 MTP 1.1 listo y corriendo en: {DEVICE.upper()}")
83
 
84
  # ======================
85
  # FASTAPI
86
  # ======================
87
- app = FastAPI(title="MTP 1.1 API")
88
 
89
  class Prompt(BaseModel):
90
  text: str
91
 
 
 
 
92
  @app.post("/generate")
93
  def generate(prompt: Prompt):
94
  try:
95
- user_input = prompt.text.strip()
96
- if not user_input:
97
  return {"reply": ""}
98
 
99
- full_prompt = f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
100
  tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
101
-
102
  input_ids = torch.tensor([tokens], device=DEVICE)
103
 
104
  with torch.no_grad():
105
- output_ids = model.generate(
106
  input_ids,
107
- max_new_tokens=80, # CPU-safe
108
  temperature=0.7,
109
  top_k=50,
110
  top_p=0.9
111
  )
112
 
113
- gen_tokens = output_ids[0, len(tokens):].tolist()
 
 
114
 
115
- # 🔒 FILTRO CRÍTICO
116
- safe_tokens = [
117
- t for t in gen_tokens
118
- if 0 <= t < VOCAB_SIZE and t != tokenizer.eos_id()
119
- ]
120
-
121
- response = tokenizer.decode(safe_tokens).strip()
122
-
123
- if "###" in response:
124
- response = response.split("###")[0].strip()
125
-
126
- return {"reply": response}
127
 
128
  except Exception as e:
129
- print("❌ ERROR EN /generate:", str(e))
130
- return {
131
- "reply": "Ocurrió un error interno al generar la respuesta."
132
- }
133
 
134
  # ======================
135
- # FRONTEND
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  # ======================
137
  @app.get("/", response_class=HTMLResponse)
138
- def chat_ui():
139
- return """<!DOCTYPE html>
 
140
  <html lang="es">
141
  <head>
142
  <meta charset="UTF-8">
143
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
144
  <title>MTP 1.1</title>
145
  <style>
146
- body{margin:0;background:#131314;color:#e3e3e3;font-family:sans-serif}
147
- #chat{max-width:800px;margin:auto;padding:20px}
148
- .msg{margin:10px 0}
149
  .user{color:#8ab4f8}
150
  .bot{color:#e3e3e3}
151
- input{width:100%;padding:10px;border-radius:8px;border:none}
152
- button{margin-top:10px;padding:10px;border:none;border-radius:8px}
153
  </style>
154
  </head>
155
  <body>
 
156
  <div id="chat">
157
- <div class="msg bot">Hola, soy MTP 1.1 ¿en qué puedo ayudarte?</div>
158
  </div>
159
- <input id="inp" placeholder="Escribe aquí..." />
 
160
  <button onclick="send()">Enviar</button>
161
 
162
  <script>
163
  async function send(){
164
- const inp=document.getElementById('inp');
165
- const text=inp.value.trim();
166
  if(!text)return;
167
- inp.value="";
168
- document.getElementById('chat').innerHTML+=`<div class="msg user">${text}</div>`;
169
- const r=await fetch('/generate',{method:'POST',headers:{'Content-Type':'application/json'},body:JSON.stringify({text})});
170
- const j=await r.json();
171
- document.getElementById('chat').innerHTML+=`<div class="msg bot">${j.reply}</div>`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  }
173
  </script>
 
174
  </body>
175
- </html>"""
 
176
 
177
  # ======================
178
  # ENTRYPOINT
 
import os
import sys
import torch
import pickle
import time
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
from huggingface_hub import snapshot_download
import uvicorn

# ======================
# CPU OPTIMIZATION
# ======================
# os.cpu_count() is documented to return None when the count cannot be
# determined; fall back to 2 so the integer division can never raise
# TypeError.  Half the cores keeps the box responsive while serving.
torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))
# Inference-only service: disable autograd globally to save memory and time.
torch.set_grad_enabled(False)

# ======================
# DEVICE SELECTION
# ======================
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("✅ GPU detectada. Usando CUDA.")
else:
    DEVICE = "cpu"
    print("⚠️ GPU no detectada. Usando CPU.")

# Hugging Face Hub repo holding the checkpoint and tokenizer files.
MODEL_REPO = "teszenofficial/mtp1"
29
 
 
31
  # DESCARGA MODELO
32
  # ======================
33
  print("--- SISTEMA MTP 1.1 ---")
 
 
34
  repo_path = snapshot_download(
35
  repo_id=MODEL_REPO,
36
  repo_type="model",
 
43
  from tokenizer import MTPTokenizer
44
 
# ======================
# MODEL LOADING
# ======================
print("Cargando modelo...")

# Locate the pickled checkpoint inside the downloaded snapshot.
# next(..., None) plus an explicit raise instead of a bare next(): a missing
# checkpoint should surface as a clear FileNotFoundError at startup, not an
# opaque StopIteration.
pkl_file = next((f for f in os.listdir(repo_path) if f.endswith(".pkl")), None)
if pkl_file is None:
    raise FileNotFoundError("❌ No se encontró el archivo .pkl del modelo")

# NOTE(review): pickle.load executes arbitrary code from the file; this is
# only acceptable because the snapshot comes from our own trusted model repo.
with open(os.path.join(repo_path, pkl_file), "rb") as f:
    model_data = pickle.load(f)

# SentencePiece-backed tokenizer shipped alongside the checkpoint.
tokenizer = MTPTokenizer(os.path.join(repo_path, "mtp_tokenizer.model"))

config = model_data["config"]
57
 
 
# Pin the model's vocab size to the tokenizer's actual piece count so that
# generation can never index past the embedding table.
VOCAB_SIZE = tokenizer.sp.get_piece_size()

# Move weights to the target device and switch to inference mode.
model.to(DEVICE)
model.eval()
model.vocab_size = VOCAB_SIZE

print(f"🚀 MTP 1.1 listo en {DEVICE.upper()}")
76
 
# ======================
# FASTAPI
# ======================
app = FastAPI(title="MTP 1.1")


class Prompt(BaseModel):
    """Request body shared by the generation endpoints."""

    # Raw user prompt; handlers strip surrounding whitespace themselves.
    text: str
84
 
# ======================
# NORMAL (NON-STREAMING) GENERATION
# ======================
@app.post("/generate")
def generate(prompt: Prompt):
    """Run a full generation pass and return the decoded reply as JSON."""
    try:
        user_text = prompt.text.strip()
        if not user_text:
            # Nothing to answer for an empty prompt.
            return {"reply": ""}

        # Instruction-style template the model was fine-tuned on.
        full_prompt = f"### Instrucción:\n{user_text}\n\n### Respuesta:\n"
        prompt_ids = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
        input_ids = torch.tensor([prompt_ids], device=DEVICE)

        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=80,
                temperature=0.7,
                top_k=50,
                top_p=0.9
            )

        # Keep only the newly generated tail, dropping the echoed prompt.
        new_ids = output[0, len(prompt_ids):].tolist()

        # Drop out-of-vocabulary ids and the EOS marker before decoding.
        eos = tokenizer.eos_id()
        valid_ids = []
        for token_id in new_ids:
            if 0 <= token_id < VOCAB_SIZE and token_id != eos:
                valid_ids.append(token_id)

        return {"reply": tokenizer.decode(valid_ids).strip()}

    except Exception as e:
        # Service boundary: log and return a generic message rather than
        # leaking a traceback to the client.
        print("❌ ERROR:", e)
        return {"reply": "Error interno."}
 
 
117
 
# ======================
# STREAMING GENERATION (CHATGPT-STYLE)
# ======================
@app.post("/generate_stream")
def generate_stream(prompt: Prompt):
    """Stream the reply token-by-token as plain text.

    NOTE(review): decoding here is greedy (argmax).  The original code
    divided the logits by a 0.7 temperature and ran softmax before taking
    argmax, but both transforms are monotonic and can never change which
    index is largest — that work is dropped here, so the chosen tokens are
    identical.  To honor temperature/top-k like /generate does, this would
    need torch.multinomial over the adjusted distribution instead.
    """
    def stream():
        try:
            text = prompt.text.strip()
            if not text:
                # Empty prompt: emit nothing rather than generating from
                # the bare instruction template (mirrors /generate's guard).
                return

            full_prompt = f"### Instrucción:\n{text}\n\n### Respuesta:\n"
            tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
            input_ids = torch.tensor([tokens], device=DEVICE)

            for _ in range(80):  # max_new_tokens, mirrors /generate
                with torch.no_grad():
                    logits = model(input_ids)[:, -1, :]

                # Restrict to the tokenizer's real vocab before choosing;
                # argmax is then guaranteed to be in [0, VOCAB_SIZE), so no
                # separate range check is needed.
                next_id = torch.argmax(logits[:, :VOCAB_SIZE], dim=-1).item()

                if next_id == tokenizer.eos_id():
                    break

                yield tokenizer.decode([next_id])
                input_ids = torch.cat(
                    [input_ids, torch.tensor([[next_id]], device=DEVICE)],
                    dim=1
                )
                # Small pause so the client renders a typing effect.
                time.sleep(0.015)

        except Exception as e:
            print("❌ STREAM ERROR:", e)
            yield "\n[error]"

    return StreamingResponse(stream(), media_type="text/plain")
155
+
156
# ======================
# FRONTEND (FULL HTML PAGE)
# ======================
@app.get("/", response_class=HTMLResponse)
def ui():
    """Serve the single-page chat UI.

    Two fixes over the previous markup:
    - The user's message is now inserted via textContent instead of an
      innerHTML template literal, so typed HTML/script renders as inert
      text (the old `chat.innerHTML += \\`...${text}...\\`` was an HTML
      injection hole).
    - decoder.decode(value, {stream:true}) so multi-byte UTF-8 sequences
      split across stream chunks are reassembled instead of corrupted.
    """
    return """
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>MTP 1.1</title>
<style>
body{margin:0;background:#131314;color:#e3e3e3;font-family:Inter,system-ui}
#chat{max-width:900px;margin:auto;padding:20px}
.msg{margin:12px 0;white-space:pre-wrap}
.user{color:#8ab4f8}
.bot{color:#e3e3e3}
input{width:100%;padding:12px;border-radius:10px;border:none;background:#1e1f20;color:white}
button{margin-top:10px;padding:10px;border-radius:10px;border:none;background:#4a9eff;color:black;font-weight:bold}
</style>
</head>
<body>

<div id="chat">
  <div class="msg bot">Hola, soy MTP 1.1.</div>
</div>

<input id="inp" placeholder="Escribe algo…" />
<button onclick="send()">Enviar</button>

<script>
async function send(){
  const input=document.getElementById('inp');
  const text=input.value.trim();
  if(!text)return;
  input.value="";
  const chat=document.getElementById('chat');

  // Build the user bubble with textContent so user input is never parsed
  // as HTML (prevents markup/script injection).
  const user=document.createElement('div');
  user.className="msg user";
  user.textContent=text;
  chat.appendChild(user);

  const bot=document.createElement('div');
  bot.className="msg bot";
  chat.appendChild(bot);

  const res=await fetch('/generate_stream',{
    method:'POST',
    headers:{'Content-Type':'application/json'},
    body:JSON.stringify({text})
  });

  const reader=res.body.getReader();
  const decoder=new TextDecoder();
  while(true){
    const {value,done}=await reader.read();
    if(done)break;
    // stream:true keeps partial multi-byte sequences buffered between reads.
    bot.textContent+=decoder.decode(value,{stream:true});
    window.scrollTo(0,document.body.scrollHeight);
  }
}
</script>

</body>
</html>
"""
219
 
220
  # ======================
221
  # ENTRYPOINT