Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.nn.functional as F
|
|
@@ -7,7 +8,7 @@ import os
|
|
| 7 |
import random
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
-
# --- 1. DETERMINISMO
|
| 11 |
def set_seed(seed=42):
|
| 12 |
random.seed(seed)
|
| 13 |
np.random.seed(seed)
|
|
@@ -40,68 +41,94 @@ class MiniGPT(nn.Module):
|
|
| 40 |
self.ln = nn.LayerNorm(embed_size)
|
| 41 |
self.fc_out = nn.Linear(embed_size, v_size)
|
| 42 |
|
| 43 |
-
def forward(self, idx
|
| 44 |
B, T = idx.shape
|
| 45 |
T = min(T, block_size)
|
| 46 |
idx = idx[:, -T:]
|
| 47 |
-
|
| 48 |
tok_emb = self.token_embedding(idx)
|
| 49 |
pos = torch.arange(T, device=device)
|
| 50 |
pos_emb = self.pos_embedding(pos)[None, :, :]
|
| 51 |
x = tok_emb + pos_emb
|
| 52 |
-
|
| 53 |
mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
|
| 54 |
for block in self.blocks:
|
| 55 |
x = block(x, src_mask=mask)
|
| 56 |
-
|
| 57 |
-
x = self.ln(x)
|
| 58 |
-
logits = self.fc_out(x)
|
| 59 |
-
return logits, None
|
| 60 |
|
| 61 |
-
# ---
|
| 62 |
model = MiniGPT(vocab_size).to(device)
|
| 63 |
if os.path.exists("mini_gpt.pth"):
|
| 64 |
try:
|
|
|
|
| 65 |
model.load_state_dict(torch.load("mini_gpt.pth", map_location=device))
|
| 66 |
except Exception as e:
|
| 67 |
-
print(f"Error
|
| 68 |
model.eval()
|
| 69 |
|
| 70 |
-
# --- FUNCIÓN DE RESPUESTA
|
| 71 |
def responder(mensaje, historial):
|
|
|
|
| 72 |
contexto = f"### Human: {mensaje}\n### Assistant: "
|
| 73 |
tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
|
| 74 |
ai_txt = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
with torch.no_grad():
|
| 77 |
-
for _ in range(
|
| 78 |
idx = torch.tensor([tokens[-block_size:]], dtype=torch.long).to(device)
|
| 79 |
logits, _ = model(idx)
|
| 80 |
|
| 81 |
-
#
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
char = chr(next_token)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
tokens.append(next_token)
|
| 89 |
ai_txt += char
|
| 90 |
|
|
|
|
| 91 |
output = ai_txt.split("###")[0].strip()
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
|
|
|
| 95 |
|
| 96 |
return output
|
| 97 |
|
| 98 |
-
# --- INTERFAZ
|
| 99 |
-
# Eliminado el argumento 'theme' para evitar el TypeError en el servidor de HF
|
| 100 |
demo = gr.ChatInterface(
|
| 101 |
fn=responder,
|
| 102 |
-
title="Mi IA Personal (
|
| 103 |
-
description="
|
| 104 |
-
examples=["Hola", "¿Qué
|
| 105 |
)
|
| 106 |
|
| 107 |
if __name__ == "__main__":
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.nn.functional as F
|
|
|
|
| 8 |
import random
|
| 9 |
import numpy as np
|
| 10 |
|
| 11 |
+
# --- 1. DETERMINISMO ---
|
| 12 |
def set_seed(seed=42):
|
| 13 |
random.seed(seed)
|
| 14 |
np.random.seed(seed)
|
|
|
|
| 41 |
self.ln = nn.LayerNorm(embed_size)
|
| 42 |
self.fc_out = nn.Linear(embed_size, v_size)
|
| 43 |
|
| 44 |
+
def forward(self, idx):
    """Run the model on a batch of token ids and return per-position logits.

    Args:
        idx: 2-D integer tensor of token ids (unpacked as batch, time).
            Only the last ``block_size`` positions are used.

    Returns:
        A ``(logits, None)`` tuple. ``logits`` is the output of
        ``self.fc_out`` applied to the layer-normalised block output; the
        second element is always ``None``.
    """
    # Unpacking two values keeps the original requirement that idx is 2-D.
    _, seq_len = idx.shape
    seq_len = min(seq_len, block_size)
    idx = idx[:, -seq_len:]

    # Build helper tensors on the input's own device rather than on the
    # module-level global, so the module keeps working if the model and its
    # inputs are moved to another device after construction.
    dev = idx.device

    tok_emb = self.token_embedding(idx)
    pos = torch.arange(seq_len, device=dev)
    pos_emb = self.pos_embedding(pos)[None, :, :]  # add a leading batch axis
    x = tok_emb + pos_emb

    # Boolean mask that is True strictly above the diagonal.
    # NOTE(review): presumably True marks positions a token may not attend
    # to (causal masking) -- confirm against the definition of the blocks.
    mask = torch.triu(torch.ones(seq_len, seq_len, device=dev), diagonal=1).bool()
    for block in self.blocks:
        x = block(x, src_mask=mask)

    return self.fc_out(self.ln(x)), None
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
+
# --- CARGAR EL MODELO ---
|
| 58 |
model = MiniGPT(vocab_size).to(device)
if os.path.exists("mini_gpt.pth"):
    try:
        # Read the checkpoint onto the target device, then copy the weights
        # into the freshly built model.
        state_dict = torch.load("mini_gpt.pth", map_location=device)
        model.load_state_dict(state_dict)
    except Exception as e:
        # Best effort: on any failure the model keeps its initial weights
        # and the error is only reported on stdout.
        print(f"Error cargando pesos: {e}")
# Switch to inference mode whether or not a checkpoint was loaded.
model.eval()
|
| 66 |
|
| 67 |
+
# --- FUNCIÓN DE RESPUESTA CON FILTROS ANTI-BUCLE ---
|
| 68 |
def responder(mensaje, historial):
    """Generate the assistant's reply to ``mensaje`` with the character-level model.

    The prompt is built as a "### Human:" line followed by a "### Assistant: "
    prefix and encoded one code point per token (code points >= 256 are
    replaced by 32, a space). Up to 150 characters are then decoded greedily.

    Decoding stops early when:
      * the same character is produced more than three times in a row after
        its first occurrence,
      * a newline is produced once more than 10 characters have been written,
      * a "###" marker appears in the generated text.

    Args:
        mensaje: The user's message.
        historial: Chat history passed by the caller; not used here.

    Returns:
        The generated text before the first "###" marker, stripped of
        surrounding whitespace, or a fixed fallback sentence when fewer than
        two characters remain.
    """
    contexto = f"### Human: {mensaje}\n### Assistant: "
    tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
    ai_txt = ""

    # Repetition-tracking state.
    ultimo_char = ""
    contador_repeticion = 0

    with torch.no_grad():
        for _ in range(150):
            idx = torch.tensor([tokens[-block_size:]], dtype=torch.long).to(device)
            logits, _ = model(idx)

            # Keep only the last time step and rescale by the temperature.
            logits = logits[:, -1, :] / 0.8

            # Repetition penalty: lower the score of the three most recent
            # tokens (prompt tokens included) to discourage short loops.
            for t in set(tokens[-3:]):
                logits[0, t] -= 1.5

            # Greedy choice of the highest-scoring token.
            next_token = torch.argmax(logits, dim=-1).item()
            char = chr(next_token)

            # Stuck detector: count consecutive repeats of the same character.
            if char == ultimo_char:
                contador_repeticion += 1
            else:
                contador_repeticion = 0
                ultimo_char = char
            if contador_repeticion > 3:
                break

            # A newline after some text has been written ends the reply.
            if char == "\n" and len(ai_txt) > 10:
                break

            tokens.append(next_token)
            ai_txt += char

            # Everything from the first "###" onward is discarded below, so
            # stop as soon as the marker shows up instead of spending more
            # forward passes on text that is thrown away.
            if "###" in ai_txt:
                break

    # Drop any role marker the model started to write, plus stray whitespace.
    output = ai_txt.split("###")[0].strip()

    # Too-short results (including empty ones) get a fixed fallback message.
    if len(output) < 2:
        return "Estoy procesando la información... intenta preguntarme algo más específico."

    return output
|
| 125 |
|
| 126 |
+
# --- INTERFAZ ---
|
|
|
|
| 127 |
# Chat UI wired to `responder`; title, description and example prompts are
# the user-facing strings shown by the interface.
demo = gr.ChatInterface(
    fn=responder,
    title="Mi IA Personal (Optimized)",
    description=(
        "Modelo MiniGPT con filtros de repetición y sincronización de pesos activa."
    ),
    examples=[
        "Hola",
        "¿Qué has aprendido?",
        "Cuéntame una historia",
    ],
)
|
| 133 |
|
| 134 |
if __name__ == "__main__":
|