Leches33 commited on
Commit
4aa47aa
·
verified ·
1 Parent(s): 80288b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -23
app.py CHANGED
@@ -1,4 +1,5 @@
1
  # -*- coding: utf-8 -*-
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
@@ -7,7 +8,7 @@ import os
7
  import random
8
  import numpy as np
9
 
10
- # --- 1. DETERMINISMO TOTAL ---
11
  def set_seed(seed=42):
12
  random.seed(seed)
13
  np.random.seed(seed)
@@ -40,68 +41,94 @@ class MiniGPT(nn.Module):
40
  self.ln = nn.LayerNorm(embed_size)
41
  self.fc_out = nn.Linear(embed_size, v_size)
42
 
43
- def forward(self, idx, targets=None):
44
  B, T = idx.shape
45
  T = min(T, block_size)
46
  idx = idx[:, -T:]
47
-
48
  tok_emb = self.token_embedding(idx)
49
  pos = torch.arange(T, device=device)
50
  pos_emb = self.pos_embedding(pos)[None, :, :]
51
  x = tok_emb + pos_emb
52
-
53
  mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
54
  for block in self.blocks:
55
  x = block(x, src_mask=mask)
56
-
57
- x = self.ln(x)
58
- logits = self.fc_out(x)
59
- return logits, None
60
 
61
- # --- CARGA DEL MODELO ---
62
  model = MiniGPT(vocab_size).to(device)
63
  if os.path.exists("mini_gpt.pth"):
64
  try:
 
65
  model.load_state_dict(torch.load("mini_gpt.pth", map_location=device))
66
  except Exception as e:
67
- print(f"Error al cargar pesos: {e}")
68
  model.eval()
69
 
70
- # --- FUNCIÓN DE RESPUESTA OPTIMIZADA ---
71
  def responder(mensaje, historial):
 
72
  contexto = f"### Human: {mensaje}\n### Assistant: "
73
  tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
74
  ai_txt = ""
 
 
 
 
75
 
76
  with torch.no_grad():
77
- for _ in range(100):
78
  idx = torch.tensor([tokens[-block_size:]], dtype=torch.long).to(device)
79
  logits, _ = model(idx)
80
 
81
- # Greedy Search (Determinista)
82
- next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
 
 
 
 
 
 
 
 
83
  char = chr(next_token)
84
 
85
- if char == "\n" and len(ai_txt) > 5: break
86
- if "### Human:" in ai_txt: break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  tokens.append(next_token)
89
  ai_txt += char
90
 
 
91
  output = ai_txt.split("###")[0].strip()
92
 
93
- if not output:
94
- return "Lo siento, todavía estoy aprendiendo de este dataset..."
 
95
 
96
  return output
97
 
98
- # --- INTERFAZ GRADIO ---
99
- # Eliminado el argumento 'theme' para evitar el TypeError en el servidor de HF
100
  demo = gr.ChatInterface(
101
  fn=responder,
102
- title="Mi IA Personal (Sync Edition)",
103
- description="Entrenando en PC -> Desplegado en HF. Sincronización de respuestas activa.",
104
- examples=["Hola", "¿Qué tal?", "Cuéntame algo"]
105
  )
106
 
107
  if __name__ == "__main__":
 
1
  # -*- coding: utf-8 -*-
2
+ # -*- coding: utf-8 -*-
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
 
8
  import random
9
  import numpy as np
10
 
11
+ # --- 1. DETERMINISMO ---
12
  def set_seed(seed=42):
13
  random.seed(seed)
14
  np.random.seed(seed)
 
41
  self.ln = nn.LayerNorm(embed_size)
42
  self.fc_out = nn.Linear(embed_size, v_size)
43
 
44
+ def forward(self, idx):
45
  B, T = idx.shape
46
  T = min(T, block_size)
47
  idx = idx[:, -T:]
 
48
  tok_emb = self.token_embedding(idx)
49
  pos = torch.arange(T, device=device)
50
  pos_emb = self.pos_embedding(pos)[None, :, :]
51
  x = tok_emb + pos_emb
 
52
  mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
53
  for block in self.blocks:
54
  x = block(x, src_mask=mask)
55
+ return self.fc_out(self.ln(x)), None
 
 
 
56
 
57
+ # --- CARGAR EL MODELO ---
58
  model = MiniGPT(vocab_size).to(device)
59
  if os.path.exists("mini_gpt.pth"):
60
  try:
61
+ # Cargamos los pesos sincronizados desde tu PC
62
  model.load_state_dict(torch.load("mini_gpt.pth", map_location=device))
63
  except Exception as e:
64
+ print(f"Error cargando pesos: {e}")
65
  model.eval()
66
 
67
+ # --- FUNCIÓN DE RESPUESTA CON FILTROS ANTI-BUCLE ---
68
  def responder(mensaje, historial):
69
+ # Usamos el formato de tu dataset para guiar a la IA
70
  contexto = f"### Human: {mensaje}\n### Assistant: "
71
  tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
72
  ai_txt = ""
73
+
74
+ # Variables de control de repetición
75
+ ultimo_char = ""
76
+ contador_repeticion = 0
77
 
78
  with torch.no_grad():
79
+ for _ in range(150): # Aumentamos un poco el límite de respuesta
80
  idx = torch.tensor([tokens[-block_size:]], dtype=torch.long).to(device)
81
  logits, _ = model(idx)
82
 
83
+ # Tomamos el último paso y aplicamos Temperatura suave
84
+ logits = logits[:, -1, :] / 0.8
85
+
86
+ # --- PENALIZACIÓN DE REPETICIÓN ---
87
+ # Bajamos la probabilidad de los últimos 3 tokens usados para evitar bucles
88
+ for t in set(tokens[-3:]):
89
+ logits[0, t] -= 1.5
90
+
91
+ # Greedy Search (Elección del más probable)
92
+ next_token = torch.argmax(logits, dim=-1).item()
93
  char = chr(next_token)
94
 
95
+ # --- DETECTOR DE ATASCOS ---
96
+ if char == ultimo_char:
97
+ contador_repeticion += 1
98
+ else:
99
+ contador_repeticion = 0
100
+ ultimo_char = char
101
+
102
+ # Si repite el mismo carácter (como las comillas) más de 3 veces, cortamos
103
+ if contador_repeticion > 3:
104
+ break
105
+
106
+ # Si detecta un salto de línea y ya ha escrito algo, finaliza
107
+ if char == "\n" and len(ai_txt) > 10:
108
+ break
109
+
110
+ # Si intenta auto-generarse un nuevo humano, finaliza
111
+ if "### Human:" in ai_txt:
112
+ break
113
 
114
  tokens.append(next_token)
115
  ai_txt += char
116
 
117
+ # Limpieza final de etiquetas y caracteres de control
118
  output = ai_txt.split("###")[0].strip()
119
 
120
+ # Si el resultado es basura o está vacío (por el corte de seguridad)
121
+ if not output or len(output) < 2:
122
+ return "Estoy procesando la información... intenta preguntarme algo más específico."
123
 
124
  return output
125
 
126
+ # --- INTERFAZ ---
 
127
  demo = gr.ChatInterface(
128
  fn=responder,
129
+ title="Mi IA Personal (Optimized)",
130
+ description="Modelo MiniGPT con filtros de repetición y sincronización de pesos activa.",
131
+ examples=["Hola", "¿Qué has aprendido?", "Cuéntame una historia"]
132
  )
133
 
134
  if __name__ == "__main__":