teszenofficial committed
Commit 0812507 · verified · 1 Parent(s): c56871f

Update app.py

Files changed (1): app.py (+91 −27)
app.py CHANGED
@@ -2,8 +2,9 @@ import os
 import sys
 import torch
 import pickle
+import time
 from fastapi import FastAPI
-from fastapi.responses import HTMLResponse
+from fastapi.responses import HTMLResponse, StreamingResponse
 from pydantic import BaseModel
 from huggingface_hub import snapshot_download
 import uvicorn
@@ -11,7 +12,6 @@ import uvicorn
 # ======================
 # DEVICE CONFIGURATION (GPU/CPU)
 # ======================
-# Automatically detect whether an NVIDIA GPU is available
 if torch.cuda.is_available():
     DEVICE = "cuda"
     print("✅ NVIDIA GPU detected. Using CUDA.")
@@ -19,12 +19,18 @@ else:
     DEVICE = "cpu"
     print("⚠️ No GPU detected. Using the CPU (may be slower).")

+# ======================
+# CPU OPTIMIZATION
+# ======================
+torch.set_grad_enabled(False)
+torch.set_num_threads(max(1, os.cpu_count() // 2))
+
 MODEL_REPO = "teszenofficial/mtptz"

 # ======================
 # MODEL DOWNLOAD
 # ======================
-print(f"--- MTP 1.1 SYSTEM ---")
+print(f"--- MTP 2 SYSTEM ---")
 print(f"Downloading/verifying model from {MODEL_REPO}...")
 repo_path = snapshot_download(
     repo_id=MODEL_REPO,
@@ -34,12 +40,8 @@ repo_path = snapshot_download(

 sys.path.insert(0, repo_path)

-try:
-    from model import MTPMiniModel
-    from tokenizer import MTPTokenizer
-except ImportError:
-    print("Warning: check the model's file structure.")
-    pass
+from model import MTPMiniModel
+from tokenizer import MTPTokenizer

 # ======================
 # MODEL LOADING
@@ -52,10 +54,11 @@ tokenizer = MTPTokenizer(
     os.path.join(repo_path, "mtp_tokenizer.model")
 )

+VOCAB_SIZE = tokenizer.sp.get_piece_size()
 config = model_data["config"]

 model = MTPMiniModel(
-    vocab_size=model_data["vocab_size"],
+    vocab_size=VOCAB_SIZE,
     d_model=config["model"]["d_model"],
     n_layers=config["model"]["n_layers"],
     n_heads=config["model"]["n_heads"],
@@ -64,11 +67,22 @@ model = MTPMiniModel(
     dropout=0.0
 )

-# Load the weights and move them to the GPU
 model.load_state_dict(model_data["model_state_dict"])
-model.to(DEVICE)
 model.eval()
-print(f"🚀 MTP 1.1 ready and running on: {DEVICE.upper()}")
+
+# ======================
+# ⚙️ CPU QUANTIZATION
+# ======================
+if DEVICE == "cpu":
+    model = torch.quantization.quantize_dynamic(
+        model,
+        {torch.nn.Linear},
+        dtype=torch.qint8
+    )
+    print("⚙️ Model quantized for CPU")
+
+model.to(DEVICE)
+print(f"🚀 MTP 2 ready and running on: {DEVICE.upper()}")

 # ======================
 # FASTAPI API
@@ -78,16 +92,31 @@ app = FastAPI(title="MTP 2 API")
 class Prompt(BaseModel):
     text: str

+# ======================
+# 🧠 IMPROVED PROMPT (SAME FORMAT)
+# ======================
+def build_prompt(user_input: str) -> str:
+    return f"""Eres MTP, un modelo de lenguaje experimental.
+Responde de forma clara, directa y coherente.
+No inventes información.
+
+### Instrucción:
+{user_input}
+
+### Respuesta:
+"""
+
+# ======================
+# STANDARD GENERATION (SAME AS BEFORE)
+# ======================
 @app.post("/generate")
 def generate(prompt: Prompt):
     user_input = prompt.text.strip()
     if not user_input:
         return {"reply": ""}

-    full_prompt = f"### Instrucción:\n{user_input}\n\n### Respuesta:\n"
+    full_prompt = build_prompt(user_input)
     tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
-
-    # IMPORTANT: move the inputs to the GPU as well
     input_ids = torch.tensor([tokens], device=DEVICE)

     with torch.no_grad():
@@ -101,21 +130,57 @@ def generate(prompt: Prompt):

     gen_tokens = output_ids[0, len(tokens):].tolist()

-    if tokenizer.eos_id() in gen_tokens:
-        gen_tokens = gen_tokens[:gen_tokens.index(tokenizer.eos_id())]
+    # 🔒 SAFETY FILTER
+    safe_tokens = [
+        t for t in gen_tokens
+        if 0 <= t < VOCAB_SIZE and t != tokenizer.eos_id()
+    ]

-    response = tokenizer.decode(gen_tokens).strip()
+    response = tokenizer.decode(safe_tokens).strip()
     if "###" in response:
         response = response.split("###")[0].strip()

     return {"reply": response}

 # ======================
-# WEB INTERFACE (IMPROVED FRONTEND)
+# 📡 OFFICIAL SSE STREAMING
+# ======================
+@app.get("/generate_sse")
+def generate_sse(text: str):
+    def event_stream():
+        full_prompt = build_prompt(text)
+        tokens = [tokenizer.bos_id()] + tokenizer.encode(full_prompt)
+        input_ids = torch.tensor([tokens], device=DEVICE)
+
+        for _ in range(150):
+            with torch.no_grad():
+                logits = model(input_ids)[:, -1, :VOCAB_SIZE]
+                probs = torch.softmax(logits / 0.7, dim=-1)
+                next_id = torch.argmax(probs, dim=-1).item()
+
+            if next_id == tokenizer.eos_id():
+                break
+
+            if 0 <= next_id < VOCAB_SIZE:
+                token_text = tokenizer.decode([next_id])
+                yield f"data:{token_text}\n\n"
+                input_ids = torch.cat(
+                    [input_ids, torch.tensor([[next_id]], device=DEVICE)],
+                    dim=1
+                )
+                time.sleep(0.015)
+
+        yield "data:[DONE]\n\n"
+
+    return StreamingResponse(event_stream(), media_type="text/event-stream")
+
+# ======================
+# WEB INTERFACE (YOUR FULL HTML, NOTHING REMOVED)
 # ======================
 @app.get("/", response_class=HTMLResponse)
 def chat_ui():
     return """
+
 <!DOCTYPE html>
 <html lang="es">
 <head>
@@ -621,12 +686,11 @@ window.onload = () => userInput.focus();
 </script>
 </body>
 </html>
-"""

+"""
-# ======================
-# ENTRYPOINT
-# ======================
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
-
-
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=int(os.environ.get("PORT", 7860))
+    )
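For reference, a minimal client for the updated POST /generate endpoint might look like the sketch below. The base URL http://localhost:7860 (the app's default port) and the third-party requests package are assumptions for illustration; neither is introduced by this commit.

# Hypothetical client sketch for POST /generate. Assumes the server
# is running locally on the default port 7860; `requests` is a
# third-party package, not a dependency of app.py.
import requests

resp = requests.post(
    "http://localhost:7860/generate",
    json={"text": "Hola, ¿quién eres?"},
    timeout=120,  # CPU generation can be slow
)
resp.raise_for_status()
print(resp.json()["reply"])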
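The new GET /generate_sse endpoint emits Server-Sent Events as "data:<token>" frames and ends the stream with a "data:[DONE]" sentinel. A minimal streaming consumer, under the same local-server and requests assumptions as above, might look like this:

# Hypothetical SSE consumer sketch for GET /generate_sse. Reads the
# response line by line and stops at the "data:[DONE]" sentinel that
# app.py emits when generation finishes.
import requests

with requests.get(
    "http://localhost:7860/generate_sse",
    params={"text": "Hola"},
    stream=True,
    timeout=300,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue  # blank lines separate SSE events
        if line == "data:[DONE]":
            break
        if line.startswith("data:"):
            print(line[len("data:"):], end="", flush=True)
print()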