Leches33 commited on
Commit
4eb713f
verified
1 Parent(s): 9a527aa

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +59 -0
handler.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+
6
+ # REPETIMOS TU ARQUITECTURA AQU脥 (Exactamente igual que en tu script)
7
+ embed_size = 256
8
+ num_heads = 8
9
+ num_layers = 4
10
+ block_size = 256
11
+
12
+ class MiniGPT(nn.Module):
13
+ def __init__(self, v_size=256):
14
+ super().__init__()
15
+ self.token_embedding = nn.Embedding(v_size, embed_size)
16
+ self.pos_embedding = nn.Embedding(block_size, embed_size)
17
+ self.blocks = nn.ModuleList([
18
+ nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads,
19
+ dim_feedforward=embed_size*4, batch_first=True,
20
+ dropout=0.1, norm_first=True) for _ in range(num_layers)
21
+ ])
22
+ self.ln = nn.LayerNorm(embed_size)
23
+ self.fc_out = nn.Linear(embed_size, v_size)
24
+
25
+ def forward(self, idx):
26
+ B, T = idx.shape
27
+ x = self.token_embedding(idx) + self.pos_embedding(torch.arange(T, device="cpu"))[None, :, :]
28
+ mask = torch.triu(torch.ones(T, T, device="cpu"), diagonal=1).bool()
29
+ for block in self.blocks: x = block(x, src_mask=mask)
30
+ logits = self.fc_out(self.ln(x))
31
+ return logits
32
+
33
+ class EndpointHandler:
34
+ def __init__(self, path=""):
35
+ # Cargar el modelo
36
+ self.model = MiniGPT()
37
+ checkpoint = os.path.join(path, "mini_gpt.pth")
38
+ self.model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
39
+ self.model.eval()
40
+
41
+ def __call__(self, data):
42
+ # Procesar la entrada
43
+ inputs = data.get("inputs", "")
44
+ tokens = [ord(c) if ord(c) < 256 else 32 for c in inputs]
45
+
46
+ # Generar (versi贸n simplificada de tu funci贸n generate)
47
+ res = ""
48
+ for _ in range(50): # Generamos 50 caracteres
49
+ idx = torch.tensor([tokens[-block_size:]])
50
+ with torch.no_grad():
51
+ logits = self.model(idx)
52
+ logits = logits[:, -1, :] / 0.7 # temp fija 0.7
53
+ probs = F.softmax(logits, dim=-1)
54
+ nxt = torch.multinomial(probs, 1).item()
55
+ if nxt == ord('\n'): break
56
+ tokens.append(nxt)
57
+ res += chr(nxt)
58
+
59
+ return [{"generated_text": res}]