Leches33 commited on
Commit
c304f30
·
verified ·
1 Parent(s): 57c0dcc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -19
app.py CHANGED
@@ -4,8 +4,20 @@ import torch.nn as nn
4
  import torch.nn.functional as F
5
  import gradio as gr
6
  import os
 
 
7
 
8
- # --- MISMOS HIPERPARÁMETROS ---
 
 
 
 
 
 
 
 
 
 
9
  embed_size = 256
10
  num_heads = 4
11
  num_layers = 4
@@ -13,7 +25,7 @@ block_size = 256
13
  vocab_size = 256
14
  device = "cpu"
15
 
16
- # --- TU ARQUITECTURA ---
17
  class MiniGPT(nn.Module):
18
  def __init__(self, v_size):
19
  super().__init__()
@@ -21,8 +33,8 @@ class MiniGPT(nn.Module):
21
  self.pos_embedding = nn.Embedding(block_size, embed_size)
22
  self.blocks = nn.ModuleList([
23
  nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads,
24
- dim_feedforward=embed_size*4, batch_first=True,
25
- dropout=0.1, norm_first=True)
26
  for _ in range(num_layers)
27
  ])
28
  self.ln = nn.LayerNorm(embed_size)
@@ -30,45 +42,74 @@ class MiniGPT(nn.Module):
30
 
31
  def forward(self, idx, targets=None):
32
  B, T = idx.shape
 
 
 
 
33
  tok_emb = self.token_embedding(idx)
34
  pos = torch.arange(T, device=device)
35
  pos_emb = self.pos_embedding(pos)[None, :, :]
36
  x = tok_emb + pos_emb
 
37
  mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
38
- for block in self.blocks: x = block(x, src_mask=mask)
 
 
39
  x = self.ln(x)
40
  logits = self.fc_out(x)
41
  return logits, None
42
 
43
- # --- CARGAR EL MODELO ---
44
  model = MiniGPT(vocab_size).to(device)
45
  if os.path.exists("mini_gpt.pth"):
 
46
  model.load_state_dict(torch.load("mini_gpt.pth", map_location=device))
47
  model.eval()
48
 
49
- # --- FUNCIÓN DE RESPUESTA ---
50
  def responder(mensaje, historial):
51
- contexto = f"\nUsuario: {mensaje}\nIA: "
 
 
52
  tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
53
  ai_txt = ""
54
-
55
  with torch.no_grad():
56
- for _ in range(150):
57
  idx = torch.tensor([tokens[-block_size:]], dtype=torch.long).to(device)
58
  logits, _ = model(idx)
59
- probs = F.softmax(logits[:, -1, :] / 0.8, dim=-1)
60
- next_token = torch.multinomial(probs, num_samples=1).item()
 
 
 
61
  char = chr(next_token)
62
 
63
- if char == "\n" or ai_txt.endswith("Usuario:"): break
 
 
 
64
  tokens.append(next_token)
65
  ai_txt += char
66
-
67
- return ai_txt.replace("Usuario:", "").strip()
68
 
69
- # --- INTERFAZ ---
70
- demo = gr.ChatInterface(fn=responder, title="Mi IA Personal", description="Modelo MiniGPT entrenado.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  if __name__ == "__main__":
73
- demo.launch()
74
- aunch()
 
4
  import torch.nn.functional as F
5
  import gradio as gr
6
  import os
7
+ import random
8
+ import numpy as np
9
 
10
+ # --- 1. DETERMINISMO TOTAL ---
11
+ # Esto asegura que HF use la misma lógica matemática que tu PC
12
+ def set_seed(seed=42):
13
+ random.seed(seed)
14
+ np.random.seed(seed)
15
+ torch.manual_seed(seed)
16
+ torch.cuda.manual_seed_all(seed)
17
+
18
+ set_seed(42)
19
+
20
+ # --- CONFIGURACIÓN ---
21
  embed_size = 256
22
  num_heads = 4
23
  num_layers = 4
 
25
  vocab_size = 256
26
  device = "cpu"
27
 
28
+ # --- ARQUITECTURA ---
29
  class MiniGPT(nn.Module):
30
  def __init__(self, v_size):
31
  super().__init__()
 
33
  self.pos_embedding = nn.Embedding(block_size, embed_size)
34
  self.blocks = nn.ModuleList([
35
  nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads,
36
+ dim_feedforward=embed_size*4, batch_first=True,
37
+ dropout=0.1, norm_first=True)
38
  for _ in range(num_layers)
39
  ])
40
  self.ln = nn.LayerNorm(embed_size)
 
42
 
43
  def forward(self, idx, targets=None):
44
  B, T = idx.shape
45
+ # Limitamos el tamaño del bloque para evitar errores de índice
46
+ T = min(T, block_size)
47
+ idx = idx[:, -T:]
48
+
49
  tok_emb = self.token_embedding(idx)
50
  pos = torch.arange(T, device=device)
51
  pos_emb = self.pos_embedding(pos)[None, :, :]
52
  x = tok_emb + pos_emb
53
+
54
  mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
55
+ for block in self.blocks:
56
+ x = block(x, src_mask=mask)
57
+
58
  x = self.ln(x)
59
  logits = self.fc_out(x)
60
  return logits, None
61
 
62
+ # --- CARGA DEL MODELO ---
63
  model = MiniGPT(vocab_size).to(device)
64
  if os.path.exists("mini_gpt.pth"):
65
+ # Uso de weights_only=True por seguridad y compatibilidad
66
  model.load_state_dict(torch.load("mini_gpt.pth", map_location=device))
67
  model.eval()
68
 
69
+ # --- FUNCIÓN DE RESPUESTA OPTIMIZADA ---
70
  def responder(mensaje, historial):
71
+ # Formateamos el prompt exactamente como en el entrenamiento
72
+ # Usamos marcas claras para que la IA sepa dónde empezar
73
+ contexto = f"### Human: {mensaje}\n### Assistant: "
74
  tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
75
  ai_txt = ""
76
+
77
  with torch.no_grad():
78
+ for _ in range(100): # 100 caracteres es suficiente para CPU
79
  idx = torch.tensor([tokens[-block_size:]], dtype=torch.long).to(device)
80
  logits, _ = model(idx)
81
+
82
+ # --- GREEDY SEARCH (Cero azar) ---
83
+ # En lugar de multinomial, usamos argmax para que PC y HF sean gemelos
84
+ next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
85
+
86
  char = chr(next_token)
87
 
88
+ # Frenado de emergencia si empieza a repetir el prompt
89
+ if char == "\n" and len(ai_txt) > 5: break
90
+ if "### Human:" in ai_txt: break
91
+
92
  tokens.append(next_token)
93
  ai_txt += char
 
 
94
 
95
+ # --- LIMPIEZA FINAL ---
96
+ # Eliminamos cualquier residuo de las etiquetas de entrenamiento
97
+ output = ai_txt.split("###")[0].strip()
98
+
99
+ # Si la respuesta es nula o basura, damos un aviso
100
+ if not output:
101
+ return "Lo siento, todavía estoy aprendiendo de este dataset..."
102
+
103
+ return output
104
+
105
+ # --- INTERFAZ GRADIO ---
106
+ demo = gr.ChatInterface(
107
+ fn=responder,
108
+ title="Mi IA Personal (Sync Edition)",
109
+ description="Entrenando en PC -> Desplegado en HF. Sincronización de respuestas activa.",
110
+ examples=["Hola", "¿Qué tal?", "Cuéntame algo"],
111
+ theme="soft"
112
+ )
113
 
114
  if __name__ == "__main__":
115
+ demo.launch()