Leches33 commited on
Commit
d445551
·
verified ·
1 Parent(s): 06c92ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -34
app.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  import random
8
  import numpy as np
9
 
10
- # --- 1. CONFIGURACIÓN Y DETERMINISMO ---
11
  def set_seed(seed=42):
12
  random.seed(seed)
13
  np.random.seed(seed)
@@ -16,7 +16,7 @@ def set_seed(seed=42):
16
  set_seed(42)
17
 
18
  embed_size = 256
19
- num_heads = 4
20
  num_layers = 4
21
  block_size = 256
22
  vocab_size = 256
@@ -41,67 +41,69 @@ class MiniGPT(nn.Module):
41
  B, T = idx.shape
42
  T = min(T, block_size)
43
  idx = idx[:, -T:]
44
- tok_emb = self.token_embedding(idx)
45
- pos = torch.arange(T, device=device)
46
- pos_emb = self.pos_embedding(pos)[None, :, :]
47
- x = tok_emb + pos_emb
48
  mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
49
- for block in self.blocks:
50
- x = block(x, src_mask=mask)
51
  return self.fc_out(self.ln(x)), None
52
 
53
- # --- 3. CARGA ---
54
  model = MiniGPT(vocab_size).to(device)
55
  if os.path.exists("mini_gpt.pth"):
56
- try:
57
- model.load_state_dict(torch.load("mini_gpt.pth", map_location=device))
58
- except:
59
- pass
60
  model.eval()
61
 
62
- # --- 4. GENERACIÓN CON FILTRO RADICAL ANTI-COMILLAS ---
63
  def responder(mensaje, historial):
 
64
  contexto = f"### Human: {mensaje}\n### Assistant: "
65
  tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
66
  ai_txt = ""
 
 
 
 
67
 
68
  with torch.no_grad():
69
- for _ in range(120):
70
  idx = torch.tensor([tokens[-block_size:]], dtype=torch.long).to(device)
71
  logits, _ = model(idx)
72
- logits = logits[:, -1, :] / 0.8 # Temperatura para dar variedad
73
 
74
- # --- BLOQUEO DE REPETICIÓN CRÍTICA ---
75
- # Si el último token fue una comilla (34) o espacio (32), bajamos su probabilidad a casi cero
76
  if len(tokens) > 0:
77
- ultimo = tokens[-1]
78
- if ultimo in [34, 32, 10]: # Comilla, Espacio, Salto de línea
79
- logits[0, ultimo] -= 100.0
80
 
81
- # En lugar de Argmax (que causa bucles), usamos Multinomial suave
 
 
 
82
  probs = F.softmax(logits, dim=-1)
83
  next_token = torch.multinomial(probs, num_samples=1).item()
84
  char = chr(next_token)
85
 
86
- # Si detecta que el modelo intenta repetir la estructura de prompt, paramos
 
 
 
 
87
  if "### Human:" in ai_txt: break
 
88
 
89
  tokens.append(next_token)
90
  ai_txt += char
91
-
92
- # Si ya tenemos una respuesta coherente y salta línea, cerramos
93
- if char == "\n" and len(ai_txt) > 20: break
94
 
95
- # Limpieza final de caracteres basura
96
- limpio = ai_txt.replace('###', '').replace('Assistant:', '').strip()
97
- # Eliminar múltiples comillas seguidas con un filtro simple de Python
98
- while '""' in limpio:
99
- limpio = limpio.replace('""', '"')
100
-
101
- return limpio if len(limpio) > 1 else "Sigo procesando el entrenamiento... ¡Pregúntame otra vez!"
 
102
 
103
  # --- 5. INTERFAZ ---
104
- demo = gr.ChatInterface(fn=responder, title="IA Personal - Filtro Anti-Bucle")
105
 
106
  if __name__ == "__main__":
107
  demo.launch()
 
7
  import random
8
  import numpy as np
9
 
10
+ # --- 1. CONFIGURACIÓN IDÉNTICA AL PC ---
11
  def set_seed(seed=42):
12
  random.seed(seed)
13
  np.random.seed(seed)
 
16
  set_seed(42)
17
 
18
  embed_size = 256
19
+ num_heads = 4 # Ajusta a 8 si en tu PC pusiste 8
20
  num_layers = 4
21
  block_size = 256
22
  vocab_size = 256
 
41
  B, T = idx.shape
42
  T = min(T, block_size)
43
  idx = idx[:, -T:]
44
+ x = self.token_embedding(idx) + self.pos_embedding(torch.arange(T, device=device))[None, :, :]
 
 
 
45
  mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
46
+ for block in self.blocks: x = block(x, src_mask=mask)
 
47
  return self.fc_out(self.ln(x)), None
48
 
49
+ # --- 3. CARGA DE PESOS ---
50
  model = MiniGPT(vocab_size).to(device)
51
  if os.path.exists("mini_gpt.pth"):
52
+ model.load_state_dict(torch.load("mini_gpt.pth", map_location=device))
 
 
 
53
  model.eval()
54
 
55
+ # --- 4. GENERACIÓN CONTROLADA (HF + PC Fusion) ---
56
  def responder(mensaje, historial):
57
+ # Formato de prompt para guiar la estructura
58
  contexto = f"### Human: {mensaje}\n### Assistant: "
59
  tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
60
  ai_txt = ""
61
+
62
+ # Parámetros de "limpieza" en vivo
63
+ temp = 0.7
64
+ top_k = 40
65
 
66
  with torch.no_grad():
67
+ for _ in range(150):
68
  idx = torch.tensor([tokens[-block_size:]], dtype=torch.long).to(device)
69
  logits, _ = model(idx)
70
+ logits = logits[:, -1, :] / temp
71
 
72
+ # Penalización de repetición (Anti-bucle de símbolos)
 
73
  if len(tokens) > 0:
74
+ for t in set(tokens[-5:]): # Miramos los últimos 5 tokens
75
+ logits[0, t] -= 2.0
 
76
 
77
+ # Filtro Top-K (Elimina la basura de baja probabilidad)
78
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
79
+ logits[logits < v[:, [-1]]] = -float('Inf')
80
+
81
  probs = F.softmax(logits, dim=-1)
82
  next_token = torch.multinomial(probs, num_samples=1).item()
83
  char = chr(next_token)
84
 
85
+ # --- SEGURIDAD: Cortar si empieza a alucinar símbolos ---
86
+ if char in "'{}[]()=|_/\\":
87
+ # Si el modelo intenta poner símbolos raros, lo ignoramos o cortamos
88
+ continue
89
+
90
  if "### Human:" in ai_txt: break
91
+ if char == "\n" and len(ai_txt) > 30: break
92
 
93
  tokens.append(next_token)
94
  ai_txt += char
 
 
 
95
 
96
+ # Limpieza final de la respuesta
97
+ res_limpia = ai_txt.strip()
98
+
99
+ # Si la respuesta es demasiado corta o solo espacios, avisamos
100
+ if len(res_limpia) < 3:
101
+ return "El modelo está en una fase de entrenamiento inestable. Prueba con otra pregunta o espera a que baje el Loss."
102
+
103
+ return res_limpia
104
 
105
  # --- 5. INTERFAZ ---
106
+ demo = gr.ChatInterface(fn=responder, title="IA Personal - Fusion Mode")
107
 
108
  if __name__ == "__main__":
109
  demo.launch()