File size: 3,733 Bytes
55ee526
 
 
 
 
 
c304f30
 
55ee526
d445551
c304f30
 
 
 
 
 
 
44872d5
d445551
44872d5
 
55ee526
 
 
06c92ca
55ee526
 
 
 
 
 
 
c304f30
 
55ee526
 
 
 
 
4aa47aa
55ee526
c304f30
 
d445551
55ee526
d445551
4aa47aa
55ee526
d445551
55ee526
 
d445551
55ee526
 
d445551
55ee526
d445551
c304f30
55ee526
 
d445551
 
 
 
c304f30
55ee526
d445551
55ee526
 
d445551
c304f30
d445551
06c92ca
d445551
 
06c92ca
d445551
 
 
 
06c92ca
 
55ee526
06c92ca
d445551
 
 
 
 
06c92ca
d445551
c304f30
55ee526
 
 
d445551
 
 
 
 
 
 
 
c304f30
06c92ca
d445551
55ee526
 
c304f30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import os
import random
import numpy as np

# --- 1. CONFIGURACIÓN IDÉNTICA AL PC ---
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with *seed*."""
    # Same seeding order as before: random -> numpy -> torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


set_seed(42)

# Hyper-parameters: keep them identical to the configuration used for
# training on the PC, or the checkpoint will not match this model.
vocab_size = 256  # token ids are character codes below 256
block_size = 256  # maximum context length (number of positions)
embed_size = 256
num_heads = 4  # change to 8 if the PC training run used 8 heads
num_layers = 4
device = "cpu"

# --- 2. ARQUITECTURA ---
class MiniGPT(nn.Module):
    """Character-level GPT built from ``nn.TransformerEncoderLayer`` blocks.

    A causal mask makes each position attend only to itself and earlier
    positions. Layer sizes come from the module-level ``embed_size``,
    ``num_heads``, ``num_layers`` and ``block_size`` settings, which must match
    the configuration the loaded checkpoint was trained with.

    Args:
        v_size: vocabulary size (number of distinct token ids).
    """

    def __init__(self, v_size):
        super().__init__()
        self.token_embedding = nn.Embedding(v_size, embed_size)
        self.pos_embedding = nn.Embedding(block_size, embed_size)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_size,
                nhead=num_heads,
                dim_feedforward=embed_size * 4,
                batch_first=True,
                dropout=0.1,
                norm_first=True,
            )
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_size)
        self.fc_out = nn.Linear(embed_size, v_size)

    def forward(self, idx):
        """Compute next-token logits for every position of ``idx``.

        Args:
            idx: integer token-id tensor of shape (batch, seq). Only the last
                ``block_size`` positions are used.

        Returns:
            ``(logits, None)``. ``logits`` has shape
            (batch, min(seq, block_size), v_size). The second element is
            always ``None`` here (presumably a loss slot used during
            training — TODO confirm).
        """
        seq_len = min(idx.size(1), block_size)
        idx = idx[:, -seq_len:]
        # Build positions and mask on the input's own device instead of the
        # module-level ``device`` setting, so the model keeps working after
        # being moved with ``.to(...)``.
        dev = idx.device
        positions = torch.arange(seq_len, device=dev)
        x = self.token_embedding(idx) + self.pos_embedding(positions)[None, :, :]
        # Boolean causal mask: True above the diagonal means "may NOT attend".
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=dev), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, src_mask=causal_mask)
        return self.fc_out(self.ln(x)), None

# --- 3. CARGA DE PESOS ---
checkpoint_path = "mini_gpt.pth"
model = MiniGPT(vocab_size).to(device)
if os.path.exists(checkpoint_path):
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # you trust. On PyTorch >= 1.13 consider passing weights_only=True.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    # Without a checkpoint the model keeps its random initialization and only
    # produces noise; make that visible instead of failing silently.
    print(f"WARNING: '{checkpoint_path}' not found; using randomly initialized weights.")
# Inference mode: disables dropout in the transformer layers.
model.eval()

# --- 4. GENERACIÓN CONTROLADA (HF + PC Fusion) ---
# Characters the model tends to emit as noise; they are never sampled.
_BLOCKED_CHARS = "'{}[]()=|_/\\"
# Prompt marker that opens a new user turn; generation stops when it appears.
_STOP_SEQUENCE = "### Human:"


def _encode(text):
    """Map *text* to token ids.

    Characters with a code point below 256 map to that code point; any other
    character becomes a space (32).
    """
    return [ord(c) if ord(c) < 256 else 32 for c in text]


def responder(mensaje, historial, temp=0.7, top_k=40, max_new_tokens=150):
    """Generate the assistant's reply to *mensaje* with the loaded model.

    Args:
        mensaje: the user's latest message.
        historial: chat history passed in by ``gr.ChatInterface``. It is
            currently unused, so each reply only sees the current message.
        temp: sampling temperature (must be > 0); lower is more conservative.
        top_k: sample only among the ``top_k`` highest-scoring tokens.
        max_new_tokens: maximum number of characters to generate.

    Returns:
        The stripped reply. If the stripped reply is shorter than 3
        characters, a fixed notice string is returned instead.

    Raises:
        ValueError: if ``temp`` is not positive.
    """
    if temp <= 0:
        raise ValueError("temp must be > 0")

    # Prompt format that steers the model toward a single assistant turn.
    contexto = f"### Human: {mensaje}\n### Assistant: "
    tokens = _encode(contexto)
    ai_txt = ""

    # Blocked characters are masked out *before* sampling. The previous
    # version sampled them and then `continue`d, which wasted a generation
    # step without emitting anything.
    blocked_ids = sorted(ord(c) for c in set(_BLOCKED_CHARS) if ord(c) < vocab_size)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            idx = torch.tensor([tokens[-block_size:]], dtype=torch.long, device=device)
            logits, _ = model(idx)
            logits = logits[:, -1, :] / temp

            # Repetition penalty: discourage tokens seen in the last 5 steps
            # (anti symbol-loop). `tokens` always holds the prompt, so it is
            # never empty.
            recent = sorted(set(tokens[-5:]))
            logits[0, recent] -= 2.0

            if blocked_ids:
                logits[0, blocked_ids] = -float("inf")

            # Top-k filter: drop everything below the k-th highest score.
            k = max(1, min(top_k, logits.size(-1)))
            v, _ = torch.topk(logits, k)
            logits[logits < v[:, [-1]]] = -float("inf")

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            char = chr(next_token)

            # End the reply at a line break once it has some content.
            if char == "\n" and len(ai_txt) > 30:
                break

            tokens.append(next_token)
            ai_txt += char

            # The model started a new user turn: cut the marker off so it
            # does not leak into the reply.
            if _STOP_SEQUENCE in ai_txt:
                ai_txt = ai_txt[: ai_txt.index(_STOP_SEQUENCE)]
                break

    res_limpia = ai_txt.strip()

    # Empty or near-empty output: tell the user instead of showing noise.
    if len(res_limpia) < 3:
        return "El modelo está en una fase de entrenamiento inestable. Prueba con otra pregunta o espera a que baje el Loss."

    return res_limpia

# --- 5. INTERFAZ ---
# Chat UI: every submitted message is handed to `responder`.
APP_TITLE = "IA Personal - Fusion Mode"
demo = gr.ChatInterface(title=APP_TITLE, fn=responder)

if __name__ == "__main__":
    # Launch the Gradio app only when executed as a script, not on import.
    demo.launch()