earlsab committed
Commit ebe85e4
1 Parent(s): e7225e5

add handler

__pycache__/handler.cpython-312.pyc ADDED
Binary file (4.4 kB).
 
handler.py ADDED
@@ -0,0 +1,87 @@
+ from typing import Dict, Any
+
+ import torch
+ from transformers import AutoModelForQuestionAnswering, AutoTokenizer
+
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         # Preload everything needed at inference time:
+         # the fine-tuned question-answering model used for date extraction.
+         self.date_model_path = "deberta-qa-finetuned"
+         self.date_tokenizer = AutoTokenizer.from_pretrained(self.date_model_path)
+         self.date_model = AutoModelForQuestionAnswering.from_pretrained(self.date_model_path)
+
+         # Run on GPU when available, otherwise fall back to CPU.
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.date_model.to(self.device)
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         data args:
+             inputs (str): free text containing the dates to extract
+         Return:
+             A dict with the extracted start and end dates; it is
+             serialized and returned by the endpoint.
+         """
+         start_date = self.extract_start_date(data["inputs"])
+         end_date = self.extract_end_date(data["inputs"])
+         return {"start_date": start_date, "end_date": end_date}
+
+     def extract_start_date(self, text):
+         question = "What is the start date?"
+
+         # Tokenize the question/context pair and move it to the model's device.
+         inputs = self.date_tokenizer(question, text, return_tensors="pt")
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         # Get model outputs (start and end logits).
+         with torch.no_grad():
+             outputs = self.date_model(**inputs)
+
+         # Identify the most likely start and end token positions.
+         answer_start = torch.argmax(outputs.start_logits)
+         answer_end = torch.argmax(outputs.end_logits) + 1
+
+         # Convert the answer span's token IDs back to a string.
+         answer_tokens = inputs["input_ids"][0][answer_start:answer_end]
+         answer = self.date_tokenizer.decode(answer_tokens, skip_special_tokens=True)
+
+         return answer
+
+     def extract_end_date(self, text):
+         question = "What is the end date?"
+
+         # Tokenize the question/context pair and move it to the model's device.
+         inputs = self.date_tokenizer(question, text, return_tensors="pt")
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         # Get model outputs (start and end logits).
+         with torch.no_grad():
+             outputs = self.date_model(**inputs)
+
+         # Identify the most likely start and end token positions.
+         answer_start = torch.argmax(outputs.start_logits)
+         answer_end = torch.argmax(outputs.end_logits) + 1
+
+         # Convert the answer span's token IDs back to a string.
+         answer_tokens = inputs["input_ids"][0][answer_start:answer_end]
+         answer = self.date_tokenizer.decode(answer_tokens, skip_special_tokens=True)
+
+         return answer
+
+
+ # init handler
+ my_handler = EndpointHandler(path=".")
+
+ # prepare sample payloads
+ non_holiday_payload = {"inputs": "I am quite excited how this will turn out 08-08-2025 - 09-08-2025"}
+ # holiday_payload = {"inputs": "Today is a tough day"}
+
+ # test the handler
+ non_holiday_pred = my_handler(non_holiday_payload)
+ # holiday_pred = my_handler(holiday_payload)
+
+ # show results
+ # print("non_holiday_pred", non_holiday_pred)
+ # print("holiday_pred", holiday_pred)
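
Note: extract_start_date and extract_end_date are identical apart from the question they ask, so the span-extraction logic is a natural candidate for deduplication in a follow-up. A minimal sketch, assuming a hypothetical helper name _extract_answer (not part of this commit):

    # Hypothetical refactor: one shared helper for both date extractors.
    def _extract_answer(self, question: str, text: str) -> str:
        inputs = self.date_tokenizer(question, text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.date_model(**inputs)
        answer_start = torch.argmax(outputs.start_logits)
        answer_end = torch.argmax(outputs.end_logits) + 1
        answer_tokens = inputs["input_ids"][0][answer_start:answer_end]
        return self.date_tokenizer.decode(answer_tokens, skip_special_tokens=True)

    def extract_start_date(self, text):
        return self._extract_answer("What is the start date?", text)

    def extract_end_date(self, text):
        return self._extract_answer("What is the end date?", text)

Both public methods then collapse to one-line wrappers, so the two code paths cannot drift apart.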
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ transformers==4.30.2
+ torch==2.0.1
+ typing==3.7.4.3
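
handler.py follows the custom-handler convention for Hugging Face Inference Endpoints, so once deployed the handler answers plain HTTP requests carrying the same payload shape used in the local test above. A minimal sketch, with placeholder endpoint URL and token values:

    import requests

    # Placeholders; a real deployment supplies its own endpoint URL and access token.
    API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
    HF_TOKEN = "hf_..."

    response = requests.post(
        API_URL,
        headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
        json={"inputs": "I am quite excited how this will turn out 08-08-2025 - 09-08-2025"},
    )
    # Response shape: {"start_date": "...", "end_date": "..."}; the values
    # depend on the spans the model extracts.
    print(response.json())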