izzelbas committed
Commit 608c050 · verified · 1 Parent(s): d8b54ee

Updates handler.py function names to match requirements

Files changed (1):
  1. handler.py +27 -48
handler.py CHANGED
@@ -1,59 +1,38 @@
- import torch
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
- from typing import Dict, Any

- class QuestionAnsweringHandler:
-     def __init__(self):
-         self.model = None
-         self.tokenizer = None
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-     def initialize(self, ctx):
-         model_dir = ctx.system_properties.get("model_dir")
-         self.model = AutoModelForQuestionAnswering.from_pretrained(model_dir).to(self.device)
-         self.model.eval()
-         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

-     def preprocess(self, data: Any) -> Dict[str, str]:
-         # Expect JSON with {"question": ..., "context": ...}
-         question = data[0]["body"].get("question", "")
-         context = data[0]["body"].get("context", "")
-         return {"question": question, "context": context}

-     def inference(self, inputs: Dict[str, str]) -> Dict[str, str]:
-         question = inputs["question"]
-         context = inputs["context"]

-         encoded = self.tokenizer(
-             question,
-             context,
-             return_tensors="pt",
-             max_length=512,
-             truncation=True
-         ).to(self.device)

-         with torch.no_grad():
-             outputs = self.model(**encoded)

-         start_logits = outputs.start_logits[0]
-         end_logits = outputs.end_logits[0]

-         max_answer_len = 30
-         input_ids = encoded["input_ids"][0]

-         # Score spans and find the best one
-         best_score = float("-inf")
-         best_span = ""

-         for start in range(len(start_logits)):
-             for end in range(start, min(start + max_answer_len, len(end_logits))):
-                 score = start_logits[start] + end_logits[end]
-                 if score > best_score:
-                     best_score = score
-                     span_ids = input_ids[start:end + 1]
-                     best_span = self.tokenizer.decode(span_ids, skip_special_tokens=True)

-         return {"best_span": best_span.strip()}

-     def postprocess(self, output: Dict[str, str]) -> [Dict[str, str]]:
-         return [output]
 
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
+ import torch

+ class EndpointHandler:
+     def __init__(self, model_path=""):
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.model = AutoModelForQuestionAnswering.from_pretrained(model_path).to(self.device)

+     def __call__(self, data):
+         """
+         data: dict containing 'inputs' with 'question' and 'context' keys
+         """
+         inputs = data.get("inputs", {})
+         question = inputs.get("question")
+         context = inputs.get("context")

+         if not question or not context:
+             return {"error": "Missing question or context"}

+         encoded = self.tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512).to(self.device)

+         with torch.no_grad():
+             output = self.model(**encoded)

+         start_scores = output.start_logits[0]
+         end_scores = output.end_logits[0]

+         # Get best span
+         start_idx = torch.argmax(start_scores)
+         end_idx = torch.argmax(end_scores)

+         if end_idx < start_idx:
+             return {"answer": ""}

+         answer = self.tokenizer.decode(encoded["input_ids"][0][start_idx:end_idx + 1], skip_special_tokens=True)

+         return {"answer": answer}