# -*- coding: utf-8 -*-
import os
import random

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# --- 1. CONFIGURATION (must be identical to the training setup on the PC) ---
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs so sampling is reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


set_seed(42)

# These hyperparameters must match the run that produced mini_gpt.pth.
# embed_size / num_layers / block_size / vocab_size mismatches make
# load_state_dict fail on tensor shapes; num_heads does not change any tensor
# shape, so a mismatch there loads silently but gives wrong outputs.
embed_size = 256
num_heads = 4  # Set to 8 if the PC training run used 8 heads
num_layers = 4
block_size = 256  # number of positional embeddings = maximum context length
vocab_size = 256
device = "cpu"


# --- 2. ARCHITECTURE ---
class MiniGPT(nn.Module):
    """Small causal transformer over token ids.

    Stacks pre-norm ``nn.TransformerEncoderLayer`` blocks and makes them
    causal with an upper-triangular boolean attention mask.
    """

    def __init__(self, v_size):
        """Build the model.

        Args:
            v_size: Vocabulary size (number of distinct token ids).
        """
        super().__init__()
        self.token_embedding = nn.Embedding(v_size, embed_size)
        self.pos_embedding = nn.Embedding(block_size, embed_size)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_size,
                nhead=num_heads,
                dim_feedforward=embed_size * 4,
                batch_first=True,
                dropout=0.1,
                norm_first=True,
            )
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_size)
        self.fc_out = nn.Linear(embed_size, v_size)

    def forward(self, idx):
        """Compute logits for every position of ``idx``.

        Args:
            idx: LongTensor of token ids with shape (batch, seq). Only the last
                ``block_size`` positions are used.

        Returns:
            Tuple ``(logits, None)``; logits has shape
            (batch, min(seq, block_size), v_size). The second element is always
            None (presumably a placeholder for a training loss).
        """
        # Positional embeddings only exist for block_size positions, so keep
        # just the most recent block_size tokens.
        idx = idx[:, -block_size:]
        T = idx.size(1)
        # Create helper tensors on the input's device instead of the global
        # `device`, so the model still works if it is moved elsewhere.
        dev = idx.device
        positions = torch.arange(T, device=dev)
        x = self.token_embedding(idx) + self.pos_embedding(positions)[None, :, :]
        # True above the diagonal = that (query, key) pair is masked out,
        # i.e. a position cannot attend to later positions.
        mask = torch.triu(torch.ones(T, T, device=dev), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, src_mask=mask)
        return self.fc_out(self.ln(x)), None


# --- 3. WEIGHT LOADING ---
CHECKPOINT_PATH = "mini_gpt.pth"

model = MiniGPT(vocab_size).to(device)
if os.path.exists(CHECKPOINT_PATH):
    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # trust (weights_only=True restricts this on PyTorch >= 1.13).
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
else:
    # Without this warning the app would silently serve random weights.
    print(f"WARNING: {CHECKPOINT_PATH} not found; using untrained random weights.")
model.eval()
# --- 4. CONTROLLED GENERATION (HF + PC Fusion) ---
# Characters the model may never emit; they typically mean it is drifting
# into code-like symbol soup.
BANNED_CHARS = "'{}[]()=|_/\\"
BANNED_TOKENS = [ord(c) for c in BANNED_CHARS]
# Any "###" in the reply means the model started a new
# "### Human:" / "### Assistant:" turn.
STOP_MARKER = "###"
MAX_NEW_TOKENS = 150
TEMPERATURE = 0.7
TOP_K = 40
REPEAT_WINDOW = 5  # how many recent tokens the repetition penalty looks at
REPEAT_PENALTY = 2.0  # subtracted from those tokens' (temperature-scaled) logits


def _sample_next_token(tokens, temp=TEMPERATURE, top_k=TOP_K):
    """Sample the next token id from the model given the context ``tokens``."""
    idx = torch.tensor([tokens[-block_size:]], dtype=torch.long).to(device)
    logits, _ = model(idx)
    logits = logits[:, -1, :] / temp
    # Repetition penalty (anti-loop): discourage tokens used very recently.
    for t in set(tokens[-REPEAT_WINDOW:]):
        logits[0, t] -= REPEAT_PENALTY
    # Mask banned characters before sampling so every step yields a usable
    # character instead of being rejected afterwards and wasting the budget.
    logits[0, BANNED_TOKENS] = -float("inf")
    # Top-k filter: discard everything below the k-th highest logit.
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    logits[logits < v[:, [-1]]] = -float("inf")
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()


def responder(mensaje, historial):
    """Generate the assistant's reply to ``mensaje`` (Gradio ChatInterface callback).

    Args:
        mensaje: The user's latest message.
        historial: Chat history passed by Gradio. Currently unused: only the
            latest message goes into the prompt.

    Returns:
        The stripped reply, or a fallback notice (in Spanish) when the stripped
        reply is shorter than 3 characters.
    """
    # Prompt format that steers the model towards the chat structure.
    contexto = f"### Human: {mensaje}\n### Assistant: "
    # One token per character; code points >= 256 are outside the vocabulary
    # and are replaced by a space (32).
    tokens = [ord(c) if ord(c) < 256 else 32 for c in contexto]
    ai_txt = ""

    with torch.no_grad():
        for _ in range(MAX_NEW_TOKENS):
            next_token = _sample_next_token(tokens)
            char = chr(next_token)
            # End at a line break once the reply is reasonably long.
            if char == "\n" and len(ai_txt) > 30:
                break
            tokens.append(next_token)
            ai_txt += char
            # The model began a new turn: cut the reply there so the
            # delimiter (e.g. "### Human:") never leaks into the answer.
            cut = ai_txt.find(STOP_MARKER)
            if cut != -1:
                ai_txt = ai_txt[:cut]
                break

    res_limpia = ai_txt.strip()
    # Too short (or only whitespace): tell the user instead of returning noise.
    if len(res_limpia) < 3:
        return "El modelo está en una fase de entrenamiento inestable. Prueba con otra pregunta o espera a que baje el Loss."
    return res_limpia


# --- 5. INTERFACE ---
demo = gr.ChatInterface(fn=responder, title="IA Personal - Fusion Mode")

if __name__ == "__main__":
    demo.launch()