izzelbas committed on
Commit
5777bbb
·
verified ·
1 Parent(s): 663e929

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +9 -14
handler.py CHANGED
@@ -3,7 +3,6 @@ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
3
 
4
  class EndpointHandler:
5
  def __init__(self, path=""):
6
- # Load tokenizer and model
7
  self.tokenizer = AutoTokenizer.from_pretrained(path)
8
  self.model = AutoModelForQuestionAnswering.from_pretrained(path)
9
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -11,10 +10,8 @@ class EndpointHandler:
11
  self.model.eval()
12
 
13
  def get_top1_answer(self, question, context, max_answer_len=30):
14
- # Tokenize input
15
  inputs = self.tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512).to(self.device)
16
 
17
- # Inference
18
  with torch.no_grad():
19
  outputs = self.model(**inputs)
20
 
@@ -34,19 +31,17 @@ class EndpointHandler:
34
 
35
  return best_span, best_score
36
 
37
- def preprocess(self, inputs):
38
- # Expecting {"inputs": {"question": "...", "context": "..."}}
39
- payload = inputs.get("inputs", {})
40
- question = payload.get("question", "")
41
- context = payload.get("context", "")
42
- return question, context
 
 
43
 
44
- def predict(self, inputs):
45
- question, context = self.preprocess(inputs)
46
  answer, score = self.get_top1_answer(question, context)
47
  return {"answer": answer, "score": score}
48
 
49
- def postprocess(self, outputs):
50
- return outputs
51
-
52
  handler = EndpointHandler()
 
3
 
4
  class EndpointHandler:
5
  def __init__(self, path=""):
 
6
  self.tokenizer = AutoTokenizer.from_pretrained(path)
7
  self.model = AutoModelForQuestionAnswering.from_pretrained(path)
8
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
10
  self.model.eval()
11
 
12
  def get_top1_answer(self, question, context, max_answer_len=30):
 
13
  inputs = self.tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512).to(self.device)
14
 
 
15
  with torch.no_grad():
16
  outputs = self.model(**inputs)
17
 
 
31
 
32
  return best_span, best_score
33
 
34
+ def __call__(self, data):
35
+ # Hugging Face sends data with "inputs" key
36
+ inputs = data.get("inputs", {})
37
+ question = inputs.get("question")
38
+ context = inputs.get("context")
39
+
40
+ if not question or not context:
41
+ return {"error": "Both 'question' and 'context' must be provided."}
42
 
 
 
43
  answer, score = self.get_top1_answer(question, context)
44
  return {"answer": answer, "score": score}
45
 
46
+ # Must be callable
 
 
47
  handler = EndpointHandler()