Leches33 committed on
Commit
253e926
verified
1 Parent(s): 9590e6d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +13 -13
handler.py CHANGED
@@ -1,9 +1,10 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import os
5
 
6
- # REPETIMOS TU ARQUITECTURA AQUÍ (Exactamente igual que en tu script)
7
  embed_size = 256
8
  num_heads = 8
9
  num_layers = 4
@@ -23,33 +24,32 @@ class MiniGPT(nn.Module):
23
  self.fc_out = nn.Linear(embed_size, v_size)
24
 
25
  def forward(self, idx):
26
- B, T = idx.shape
27
  x = self.token_embedding(idx) + self.pos_embedding(torch.arange(T, device="cpu"))[None, :, :]
28
  mask = torch.triu(torch.ones(T, T, device="cpu"), diagonal=1).bool()
29
  for block in self.blocks: x = block(x, src_mask=mask)
30
- logits = self.fc_out(self.ln(x))
31
- return logits
32
 
33
  class EndpointHandler:
34
  def __init__(self, path=""):
35
- # Cargar el modelo
36
  self.model = MiniGPT()
37
- checkpoint = os.path.join(path, "pytorch_model.bin")
38
- self.model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
 
39
  self.model.eval()
40
 
41
- def __call__(self, data):
42
- # Procesar la entrada
43
  inputs = data.get("inputs", "")
44
- tokens = [ord(c) if ord(c) < 256 else 32 for c in inputs]
45
 
46
- # Generar (versión simplificada de tu función generate)
47
  res = ""
48
- for _ in range(50): # Generamos 50 caracteres
 
49
  idx = torch.tensor([tokens[-block_size:]])
50
  with torch.no_grad():
51
  logits = self.model(idx)
52
- logits = logits[:, -1, :] / 0.7 # temp fija 0.7
53
  probs = F.softmax(logits, dim=-1)
54
  nxt = torch.multinomial(probs, 1).item()
55
  if nxt == ord('\n'): break
 
1
+ from typing import Dict, List, Any
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  import os
6
 
7
+ # Configuración idéntica a tu script
8
  embed_size = 256
9
  num_heads = 8
10
  num_layers = 4
 
24
  self.fc_out = nn.Linear(embed_size, v_size)
25
 
26
  def forward(self, idx):
27
+ T = idx.shape[1]
28
  x = self.token_embedding(idx) + self.pos_embedding(torch.arange(T, device="cpu"))[None, :, :]
29
  mask = torch.triu(torch.ones(T, T, device="cpu"), diagonal=1).bool()
30
  for block in self.blocks: x = block(x, src_mask=mask)
31
+ return self.fc_out(self.ln(x))
 
32
 
33
  class EndpointHandler:
34
  def __init__(self, path=""):
 
35
  self.model = MiniGPT()
36
+ # Buscamos el archivo de pesos
37
+ checkpoint_path = os.path.join(path, "pytorch_model.bin")
38
+ self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
39
  self.model.eval()
40
 
41
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
 
42
  inputs = data.get("inputs", "")
43
+ if not inputs: return [{"generated_text": ""}]
44
 
45
+ tokens = [ord(c) if ord(c) < 256 else 32 for c in inputs]
46
  res = ""
47
+
48
+ for _ in range(30): # Generamos 30 caracteres para probar
49
  idx = torch.tensor([tokens[-block_size:]])
50
  with torch.no_grad():
51
  logits = self.model(idx)
52
+ logits = logits[:, -1, :] / 0.7
53
  probs = F.softmax(logits, dim=-1)
54
  nxt = torch.multinomial(probs, 1).item()
55
  if nxt == ord('\n'): break