Leches33 committed on
Commit
b45ba12
verified
1 Parent(s): 0207db0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -56
handler.py CHANGED
@@ -1,59 +1,7 @@
1
- from typing import Dict, List, Any
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import os
6
-
7
- # Configuración idéntica a tu script
8
- embed_size = 256
9
- num_heads = 8
10
- num_layers = 4
11
- block_size = 256
12
-
13
- class MiniGPT(nn.Module):
14
- def __init__(self, v_size=256):
15
- super().__init__()
16
- self.token_embedding = nn.Embedding(v_size, embed_size)
17
- self.pos_embedding = nn.Embedding(block_size, embed_size)
18
- self.blocks = nn.ModuleList([
19
- nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads,
20
- dim_feedforward=embed_size*4, batch_first=True,
21
- dropout=0.1, norm_first=True) for _ in range(num_layers)
22
- ])
23
- self.ln = nn.LayerNorm(embed_size)
24
- self.fc_out = nn.Linear(embed_size, v_size)
25
-
26
- def forward(self, idx):
27
- T = idx.shape[1]
28
- x = self.token_embedding(idx) + self.pos_embedding(torch.arange(T, device="cpu"))[None, :, :]
29
- mask = torch.triu(torch.ones(T, T, device="cpu"), diagonal=1).bool()
30
- for block in self.blocks: x = block(x, src_mask=mask)
31
- return self.fc_out(self.ln(x))
32
-
33
  class EndpointHandler:
34
  def __init__(self, path=""):
35
- self.model = MiniGPT()
36
- # Buscamos el archivo de pesos
37
- checkpoint_path = os.path.join(path, "pytorch_model.bin")
38
- self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
39
- self.model.eval()
40
 
41
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
42
- inputs = data.get("inputs", "")
43
- if not inputs: return [{"generated_text": ""}]
44
-
45
- tokens = [ord(c) if ord(c) < 256 else 32 for c in inputs]
46
- res = ""
47
-
48
- for _ in range(30): # Generamos 30 caracteres para probar
49
- idx = torch.tensor([tokens[-block_size:]])
50
- with torch.no_grad():
51
- logits = self.model(idx)
52
- logits = logits[:, -1, :] / 0.7
53
- probs = F.softmax(logits, dim=-1)
54
- nxt = torch.multinomial(probs, 1).item()
55
- if nxt == ord('\n'): break
56
- tokens.append(nxt)
57
- res += chr(nxt)
58
-
59
- return [{"generated_text": res}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  class EndpointHandler:
2
  def __init__(self, path=""):
3
+ # No cargamos nada para probar
4
+ pass
 
 
 
5
 
6
+ def __call__(self, data):
7
+ return [{"generated_text": "API FUNCIONANDO"}]