MarkChenX committed on
Commit
573941c
·
verified ·
1 Parent(s): 31b59c5

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +47 -8
handler.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from model import GPT, GPTConfig
3
 
4
  class EndpointHandler:
@@ -13,7 +14,7 @@ class EndpointHandler:
13
  checkpoint_path = f"{path}/ckpt.pt"
14
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
15
 
16
- # Check if it's a full training checkpoint or a clean state_dict
17
  if isinstance(checkpoint, dict) and "model" in checkpoint:
18
  state_dict = checkpoint["model"]
19
  else:
@@ -26,33 +27,71 @@ class EndpointHandler:
26
  new_key = key[len(prefix):] if key.startswith(prefix) else key
27
  cleaned_state_dict[new_key] = val
28
 
29
- # Load state dict with non-strict to see any mismatches
30
  missing, unexpected = self.model.load_state_dict(cleaned_state_dict, strict=False)
31
  if missing:
32
  print("Warning: missing keys in state_dict:", missing)
33
  if unexpected:
34
  print("Warning: unexpected keys in state_dict:", unexpected)
35
 
 
36
  self.model.eval()
 
 
 
37
  print("Model loaded and ready.")
38
 
39
  def __call__(self, data):
40
  """
41
- data: {"inputs": {"input_ids": [[int, int, ...]]}}
42
- Returns: {"generated_ids": [[...]]}
 
 
 
 
43
  """
44
  try:
45
- input_ids = data.get("inputs", {}).get("input_ids")
46
- if not input_ids:
47
- return {"error": "Missing 'input_ids' in inputs"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
49
  input_tensor = torch.tensor(input_ids).long()
50
 
 
51
  with torch.no_grad():
52
  output_tensor = self.model.generate(input_tensor, max_new_tokens=32)
53
  output_ids = output_tensor.tolist()
54
 
55
- return {"generated_ids": output_ids}
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  except Exception as e:
58
  return {"error": str(e)}
 
1
  import torch
2
+ import tiktoken
3
  from model import GPT, GPTConfig
4
 
5
  class EndpointHandler:
 
14
  checkpoint_path = f"{path}/ckpt.pt"
15
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
16
 
17
+ # Extract state_dict if wrapped
18
  if isinstance(checkpoint, dict) and "model" in checkpoint:
19
  state_dict = checkpoint["model"]
20
  else:
 
27
  new_key = key[len(prefix):] if key.startswith(prefix) else key
28
  cleaned_state_dict[new_key] = val
29
 
30
+ # Load state dict non-strict to inspect mismatches
31
  missing, unexpected = self.model.load_state_dict(cleaned_state_dict, strict=False)
32
  if missing:
33
  print("Warning: missing keys in state_dict:", missing)
34
  if unexpected:
35
  print("Warning: unexpected keys in state_dict:", unexpected)
36
 
37
+ # Ready model
38
  self.model.eval()
39
+ # Initialize tokenizer for text inputs
40
+ self.tokenizer = tiktoken.get_encoding("gpt2")
41
+
42
  print("Model loaded and ready.")
43
 
44
  def __call__(self, data):
45
  """
46
+ Accept either:
47
+ - A raw prompt string (data is str)
48
+ - A dict: {"inputs": "prompt text"}
49
+ - A dict: {"inputs": {"input_ids": [[...]]}}
50
+ Returns:
51
+ {"generated_ids": [[...]], optional "generated_text": str}
52
  """
53
  try:
54
+ # Determine input format
55
+ if isinstance(data, str):
56
+ text = data
57
+ elif isinstance(data, dict):
58
+ inputs = data.get("inputs")
59
+ if isinstance(inputs, str):
60
+ text = inputs
61
+ elif isinstance(inputs, dict) and "input_ids" in inputs:
62
+ input_ids = inputs["input_ids"]
63
+ else:
64
+ return {"error": "Invalid 'inputs'; expected string or dict with 'input_ids'"}
65
+ else:
66
+ return {"error": "Invalid request format"}
67
+
68
+ # If text prompt given, tokenize
69
+ if 'text' in locals():
70
+ # encode text into token IDs
71
+ tokens = self.tokenizer.encode(text)
72
+ input_ids = [tokens]
73
 
74
+ # Convert to tensor
75
  input_tensor = torch.tensor(input_ids).long()
76
 
77
+ # Generate
78
  with torch.no_grad():
79
  output_tensor = self.model.generate(input_tensor, max_new_tokens=32)
80
  output_ids = output_tensor.tolist()
81
 
82
+ # Build response
83
+ result = {"generated_ids": output_ids}
84
+ if 'text' in locals():
85
+ # Decode the first sequence
86
+ generated_tokens = output_ids[0]
87
+ try:
88
+ generated_text = self.tokenizer.decode(generated_tokens)
89
+ except Exception:
90
+ generated_text = None
91
+ if generated_text is not None:
92
+ result["generated_text"] = generated_text
93
+
94
+ return result
95
 
96
  except Exception as e:
97
  return {"error": str(e)}