BK_QNADateExtractor / handler.py
earlsab
remove special characters
c1b7ccc
import os
import re
from typing import Dict, List, Any

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
class EndpointHandler():
    """Endpoint handler that extracts a start date and an end date from text.

    A fine-tuned extractive question-answering model is asked two fixed
    questions ("What is the start date?" / "What is the end date?") about the
    request text; each decoded answer span is then stripped of stray
    punctuation before being returned.
    """

    # Directory, relative to the ``path`` passed to ``__init__``, that holds
    # the fine-tuned model and tokenizer files.
    _MODEL_SUBDIR = "deberta-qa-finetuned"

    # Matches runs of characters that are neither word characters, whitespace
    # nor "/", when the run is not preceded by a digit or not followed by one.
    # A single punctuation character sitting directly between two digits
    # (e.g. the "-" in "08-08-2025") therefore survives.
    _SPECIAL_CHARACTERS = re.compile(r'(?<!\d)[^\w\s/]+|[^\w\s/]+(?!\d)')

    def __init__(self, path=""):
        """Load the tokenizer and QA model once so every request can reuse them.

        Args:
            path: Directory that contains the ``deberta-qa-finetuned`` folder.
        """
        # os.path.join keeps an empty ``path`` relative to the working
        # directory instead of producing the root path "/deberta-qa-finetuned".
        self.date_model_path = os.path.join(path, self._MODEL_SUBDIR)
        self.date_tokenizer = AutoTokenizer.from_pretrained(self.date_model_path)
        self.date_model = AutoModelForQuestionAnswering.from_pretrained(self.date_model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.date_model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Extract the start and end date from a request payload.

        Args:
            data: Request payload; ``data["inputs"]`` is the text to search.

        Returns:
            A dict with the keys ``"start_date"`` and ``"end_date"``, each
            holding the cleaned answer string produced by the model.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
        """
        text = data["inputs"]
        start_date = self.remove_special_characters(self.extract_start_date(text))
        end_date = self.remove_special_characters(self.extract_end_date(text))
        return {"start_date": start_date, "end_date": end_date}

    def remove_special_characters(self, s):
        """Return ``s`` without stray punctuation and surrounding whitespace.

        Word characters, whitespace and "/" are always kept; a single other
        character is kept only when it sits directly between two digits.
        """
        return self._SPECIAL_CHARACTERS.sub('', s).strip()

    def extract_start_date(self, text):
        """Return the model's answer span for the start-date question."""
        return self._answer_question("What is the start date?", text)

    def extract_end_date(self, text):
        """Return the model's answer span for the end-date question."""
        return self._answer_question("What is the end date?", text)

    def _answer_question(self, question, text):
        """Run extractive QA for ``question`` over ``text`` and decode the answer."""
        encoded = self.date_tokenizer(question, text, return_tensors="pt")
        # Put the inputs on the same device the model was moved to in __init__.
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            outputs = self.date_model(**encoded)

        # Most likely first token and (exclusive) last token of the answer.
        # If the predicted end precedes the start, the slice below is empty.
        span_start = int(torch.argmax(outputs.start_logits))
        span_end = int(torch.argmax(outputs.end_logits)) + 1

        answer_tokens = encoded["input_ids"][0][span_start:span_end]
        return self.date_tokenizer.decode(answer_tokens, skip_special_tokens=True)
if __name__ == "__main__":
    # Manual smoke test: load the model from the current directory, run one
    # sample request through the handler and print the extracted dates.
    # EndpointHandler is already defined in this module, so it is used
    # directly instead of re-importing the file under a second module name.
    handler = EndpointHandler(path=".")

    sample_payload = {"inputs": "I am quite excited how this will turn out 08-08-2025 - 09-08-2025"}

    prediction = handler(sample_payload)
    print(prediction)