Leches33 committed on
Commit
55ee526
·
verified ·
1 Parent(s): b7c83b3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import gradio as gr
6
+ import os
7
+
8
+ # --- MISMOS HIPERPARÁMETROS ---
9
# Model hyperparameters.
# NOTE(review): presumably these must equal the values used when the
# "mini_gpt.pth" checkpoint was trained, otherwise loading it will fail --
# confirm against the training script.
embed_size = 128  # width of the token/position embeddings (d_model)
num_heads = 4     # attention heads per transformer layer
num_layers = 3    # number of stacked transformer layers
block_size = 64   # maximum context length (rows in the position table)
vocab_size = 256  # one token id per character code point 0-255
device = "cpu"    # device the model and its input tensors are placed on
15
+
16
+ # --- TU ARQUITECTURA ---
17
class MiniGPT(nn.Module):
    """Small causal transformer language model over integer token ids.

    Token and learned position embeddings are summed, passed through
    ``num_layers`` pre-norm ``nn.TransformerEncoderLayer`` blocks under a
    causal mask, layer-normalised, and projected to vocabulary logits.

    The sizes come from the module-level ``embed_size``, ``num_heads``,
    ``num_layers`` and ``block_size`` settings.
    """

    def __init__(self, v_size):
        """Build the layers.

        Args:
            v_size: Number of distinct token ids (embedding rows and
                output logits per position).
        """
        super().__init__()
        # Attribute names are kept as-is: they are the state_dict keys.
        self.token_embedding = nn.Embedding(v_size, embed_size)
        self.pos_embedding = nn.Embedding(block_size, embed_size)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_size,
                nhead=num_heads,
                dim_feedforward=embed_size * 4,
                batch_first=True,
                dropout=0.1,
                norm_first=True,
            )
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_size)
        self.fc_out = nn.Linear(embed_size, v_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits for every position.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``. ``T``
                must not exceed ``block_size`` because positions index the
                learned position table.
            targets: Accepted for interface compatibility; not used.

        Returns:
            A tuple ``(logits, None)`` where ``logits`` has shape
            ``(B, T, v_size)``. The second element is always ``None``.
        """
        _, seq_len = idx.shape
        # Build helper tensors on the input's own device rather than the
        # module-level default, so the model also works after being moved.
        positions = torch.arange(seq_len, device=idx.device)
        x = self.token_embedding(idx) + self.pos_embedding(positions)[None, :, :]
        # True above the diagonal: position i may not attend to j > i.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=idx.device), diagonal=1
        ).bool()
        for block in self.blocks:
            x = block(x, src_mask=causal_mask)
        x = self.ln(x)
        logits = self.fc_out(x)
        return logits, None
42
+
43
+ # --- CARGAR EL MODELO ---
44
# Build the model and load trained weights when a checkpoint file is present.
# If "mini_gpt.pth" is missing, the model keeps its freshly initialised weights.
model = MiniGPT(vocab_size).to(device)
if os.path.exists("mini_gpt.pth"):
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source; consider weights_only=True if the installed
    # torch version supports it.
    model.load_state_dict(torch.load("mini_gpt.pth", map_location=device))
# Inference only: switch off dropout regardless of whether weights were loaded.
model.eval()
48
+
49
+ # --- FUNCIÓN DE RESPUESTA ---
50
def responder(mensaje, historial, *, max_new_tokens=150, temperature=0.8):
    """Generate the model's reply to a single chat message.

    The prompt is ``"\\nUsuario: {mensaje}\\nIA: "`` encoded one token per
    character. Characters are sampled one at a time until a newline is
    produced, the reply ends with ``"Usuario:"``, or ``max_new_tokens``
    characters have been generated.

    Args:
        mensaje: The user's message text.
        historial: Chat history supplied by the chat UI. It is accepted for
            interface compatibility but not used to build the prompt.
        max_new_tokens: Upper bound on the number of sampled characters.
        temperature: Softmax temperature; must be positive.

    Returns:
        The generated reply with any ``"Usuario:"`` marker removed and
        surrounding whitespace stripped.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    contexto = f"\nUsuario: {mensaje}\nIA: "
    # Characters outside the vocabulary are replaced by a space (code 32) so
    # every id is a valid embedding index.
    tokens = [ord(c) if ord(c) < vocab_size else 32 for c in contexto]
    ai_txt = ""

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed only the most recent block_size tokens: the model's
            # position table has no entries beyond that.
            idx = torch.tensor(
                [tokens[-block_size:]], dtype=torch.long, device=device
            )
            logits, _ = model(idx)
            probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            char = chr(next_token)

            # Stop at end of line, or once the model has started writing the
            # next user turn.
            if char == "\n" or ai_txt.endswith("Usuario:"):
                break
            tokens.append(next_token)
            ai_txt += char

    return ai_txt.replace("Usuario:", "").strip()
68
+
69
+ # --- INTERFAZ ---
70
# Chat UI that routes each user message through responder().
demo = gr.ChatInterface(fn=responder, title="Mi IA Personal", description="Modelo MiniGPT entrenado.")

if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    demo.launch()