# AI / app.py
# Last commit: d445551 "Update app.py" (author: Leches33, verified)
# -*- coding: utf-8 -*-
import os
import random
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# --- 1. Configuration (kept identical to the training setup on the PC) ---
def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's global random generators.

    Args:
        seed: Integer seed applied to all three generators (default 42).
    """
    # Same order as before: stdlib `random`, then NumPy, then PyTorch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
set_seed(42)

# Model hyperparameters. The checkpoint loaded later in this file only fits a model
# built with the same sizes it was trained with.
embed_size: int = 256   # width of token and position embeddings
num_heads: int = 4      # attention heads per layer (change to 8 if the PC run used 8)
num_layers: int = 4     # number of Transformer layers
block_size: int = 256   # maximum context length fed to the model
vocab_size: int = 256   # one token per character code 0-255
device: str = "cpu"
# --- 2. Architecture ---
class MiniGPT(nn.Module):
    """Decoder-only Transformer language model with learned positional embeddings.

    Stacks pre-norm ``nn.TransformerEncoderLayer`` blocks and runs them under a causal
    mask, so each position attends only to itself and earlier positions.

    Sizes default to the module-level ``embed_size``, ``num_heads``, ``num_layers`` and
    ``block_size``. As a result, ``MiniGPT(vocab_size)`` builds the same layout (and
    state-dict keys) as before. Pass the sizes explicitly to build other configurations.
    """

    def __init__(self, v_size, *, n_embd=None, n_head=None, n_layer=None, context_size=None):
        """Build the model.

        Args:
            v_size: Vocabulary size (number of distinct token ids).
            n_embd: Embedding width; defaults to module-level ``embed_size``.
            n_head: Attention heads per layer; defaults to ``num_heads``. PyTorch
                requires it to divide ``n_embd``.
            n_layer: Number of Transformer layers; defaults to ``num_layers``.
            context_size: Maximum context length; defaults to ``block_size``.
        """
        super().__init__()
        # Resolve defaults lazily: the globals are only read when a size is omitted.
        n_embd = embed_size if n_embd is None else n_embd
        n_head = num_heads if n_head is None else n_head
        n_layer = num_layers if n_layer is None else n_layer
        context_size = block_size if context_size is None else context_size

        self.block_size = context_size
        # Submodules are created in the original order, so random initialisation
        # (and the state-dict key names) are unchanged.
        self.token_embedding = nn.Embedding(v_size, n_embd)
        self.pos_embedding = nn.Embedding(context_size, n_embd)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=n_embd * 4,
                batch_first=True,
                dropout=0.1,
                norm_first=True,
            )
            for _ in range(n_layer)
        ])
        self.ln = nn.LayerNorm(n_embd)
        self.fc_out = nn.Linear(n_embd, v_size)

    def forward(self, idx):
        """Compute next-token logits for every position of the (cropped) input.

        Args:
            idx: 2-D integer tensor of token ids, (batch, seq). Only the last
                ``self.block_size`` positions are used.

        Returns:
            ``(logits, None)``. ``logits`` has shape
            (batch, min(seq, block_size), v_size). The second element is always ``None``.
        """
        _, T = idx.shape
        T = min(T, self.block_size)
        idx = idx[:, -T:]
        # Build positions and the mask on the input's device rather than the global
        # `device`, so the model keeps working after .to(...).
        positions = torch.arange(T, device=idx.device)
        x = self.token_embedding(idx) + self.pos_embedding(positions)[None, :, :]
        # True strictly above the diagonal = "not allowed to attend": hides future tokens.
        mask = torch.triu(torch.ones(T, T, device=idx.device), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, src_mask=mask)
        return self.fc_out(self.ln(x)), None
# --- 3. Weight loading ---
CHECKPOINT_PATH = "mini_gpt.pth"

model = MiniGPT(vocab_size).to(device)
if os.path.exists(CHECKPOINT_PATH):
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
    # On torch versions that support it, weights_only=True restricts loading to tensors.
    state = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(state)
else:
    # Without a checkpoint the model keeps its random initialisation and replies are
    # noise. Say so instead of failing silently.
    warnings.warn(
        f"Checkpoint {CHECKPOINT_PATH!r} not found; using randomly initialised weights."
    )
# Inference mode: disables dropout inside the Transformer layers.
model.eval()
# --- 4. Controlled generation (HF + PC fusion) ---
def responder(mensaje, historial):
    """Generate one assistant reply for the Gradio chat interface.

    The prompt is encoded one character per token (code points >= 256 become a space).
    Up to 150 characters are then sampled with:
    - temperature scaling,
    - a repetition penalty on the last 5 tokens,
    - a ban on a set of "garbage" symbols,
    - top-k filtering.

    Generation stops early on a newline once the reply exceeds 30 characters, or when
    the model starts a new "### Human:" turn. That marker is removed from the reply.

    Args:
        mensaje: The latest user message.
        historial: Chat history passed by Gradio. Currently unused: each reply is
            generated from ``mensaje`` alone.

    Returns:
        The stripped reply, or a fixed Spanish notice when it has fewer than 3 characters.
    """
    # Prompt format that steers the model toward the Human/Assistant structure.
    contexto = f"### Human: {mensaje}\n### Assistant: "
    # One token per character; anything outside 0-255 does not fit the vocabulary.
    tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
    ai_txt = ""

    # Live "cleanup" sampling parameters.
    temp = 0.7
    top_k = 40
    max_new_tokens = 150
    stop_marker = "### Human:"
    # Symbols the model emits when it starts to degenerate; never allowed in a reply.
    banned_ids = [ord(c) for c in "'{}[]()=|_/\\"]

    with torch.no_grad():
        for _ in range(max_new_tokens):
            idx = torch.tensor([tokens[-block_size:]], dtype=torch.long, device=device)
            logits, _ = model(idx)
            logits = logits[:, -1, :] / temp
            # Repetition penalty (anti symbol loop): push down the last 5 distinct tokens.
            for t in set(tokens[-5:]):
                logits[0, t] -= 2.0
            # Masking banned symbols before sampling gives the same distribution as the old
            # reject-and-retry, but no longer wastes one of the max_new_tokens iterations.
            logits[0, banned_ids] = -float("inf")
            # Top-k filter: drop everything below the k-th largest logit.
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            char = chr(next_token)

            # A newline ends a long-enough reply; the newline itself is not kept.
            if char == "\n" and len(ai_txt) > 30:
                break
            tokens.append(next_token)
            ai_txt += char
            # The model began a new user turn: remove the marker so it never reaches the user.
            if ai_txt.endswith(stop_marker):
                ai_txt = ai_txt[: -len(stop_marker)]
                break

    # Final cleanup of the reply.
    res_limpia = ai_txt.strip()
    # A reply that is (almost) empty gets a fixed notice instead.
    if len(res_limpia) < 3:
        return "El modelo está en una fase de entrenamiento inestable. Prueba con otra pregunta o espera a que baje el Loss."
    return res_limpia
# --- 5. Interface ---
# Chat UI: every user message is routed to `responder`.
demo = gr.ChatInterface(
    title="IA Personal - Fusion Mode",
    fn=responder,
)

if __name__ == "__main__":
    # Only start the server when this file is executed directly, not when imported.
    demo.launch()