Khriis commited on
Commit
0d62938
·
1 Parent(s): d37eeb3

Fixed handler names

Browse files
Files changed (1) hide show
  1. handler.py +30 -7
handler.py CHANGED
@@ -1,15 +1,23 @@
1
  import torch
2
  import importlib.util
3
  import sys
4
- import pathlib
 
5
  from transformers import AutoModel, AutoTokenizer
6
 
7
- class InferenceHandler:
8
- def __init__(self):
9
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  # Import custom model definition from local file
12
- model_path = "cross_scorer_model.py"
 
 
 
 
 
 
 
13
  spec = importlib.util.spec_from_file_location("cross_scorer_model", model_path)
14
  mod = importlib.util.module_from_spec(spec)
15
  sys.modules["cross_scorer_model"] = mod
@@ -20,7 +28,12 @@ class InferenceHandler:
20
  self.model = mod.CrossScorerCrossEncoder(encoder).to(self.device)
21
 
22
  # Load weights
23
- weights_path = "reflection_scorer_weight.pt"
 
 
 
 
 
24
  state = torch.load(weights_path, map_location=self.device)
25
  sd = state.get("model_state_dict", state)
26
  self.model.load_state_dict(sd, strict=False)
@@ -30,14 +43,24 @@ class InferenceHandler:
30
  # Initialize tokenizer
31
  self.tokenizer = AutoTokenizer.from_pretrained("roberta-base")
32
 
33
- def handle(self, inputs: list) -> list:
 
 
 
 
 
 
 
 
 
 
 
34
  results = []
35
  for item in inputs:
36
  prompt = item.get("prompt")
37
  response = item.get("response")
38
 
39
  if not prompt or not response:
40
- # Handle missing keys gracefully, though instructions imply strict format
41
  results.append({"error": "Missing prompt or response"})
42
  continue
43
 
 
1
  import torch
2
  import importlib.util
3
  import sys
4
+ import os
5
+ from typing import Dict, List, Any
6
  from transformers import AutoModel, AutoTokenizer
7
 
8
+ class EndpointHandler():
9
+ def __init__(self, path=""):
10
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # Import custom model definition from local file
13
+ # The file is expected to be in the same directory as handler.py, which 'path' points to
14
+ model_filename = "cross_scorer_model.py"
15
+ model_path = os.path.join(path, model_filename)
16
+
17
+ # Fallback if path is empty or "." and file is in CWD
18
+ if not os.path.exists(model_path):
19
+ model_path = model_filename
20
+
21
  spec = importlib.util.spec_from_file_location("cross_scorer_model", model_path)
22
  mod = importlib.util.module_from_spec(spec)
23
  sys.modules["cross_scorer_model"] = mod
 
28
  self.model = mod.CrossScorerCrossEncoder(encoder).to(self.device)
29
 
30
  # Load weights
31
+ weights_filename = "reflection_scorer_weight.pt"
32
+ weights_path = os.path.join(path, weights_filename)
33
+
34
+ if not os.path.exists(weights_path):
35
+ weights_path = weights_filename
36
+
37
  state = torch.load(weights_path, map_location=self.device)
38
  sd = state.get("model_state_dict", state)
39
  self.model.load_state_dict(sd, strict=False)
 
43
  # Initialize tokenizer
44
  self.tokenizer = AutoTokenizer.from_pretrained("roberta-base")
45
 
46
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
47
+ """
48
+ data args:
49
+ inputs (:obj: `list` | `dict`): The inputs to the model.
50
+ """
51
+ # get inputs
52
+ inputs = data.pop("inputs", data)
53
+
54
+ # If inputs is a dict (single item), wrap in list to reuse logic, or handle list
55
+ if isinstance(inputs, dict):
56
+ inputs = [inputs]
57
+
58
  results = []
59
  for item in inputs:
60
  prompt = item.get("prompt")
61
  response = item.get("response")
62
 
63
  if not prompt or not response:
 
64
  results.append({"error": "Missing prompt or response"})
65
  continue
66