|
|
import os
import re
from typing import Dict, List, Any

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference handler that extracts a start date and an end date from text.

    One extractive question-answering model is loaded from
    ``<path>/deberta-qa-finetuned`` and asked two fixed questions about the
    request text ("What is the start date?" / "What is the end date?"). The
    decoded answer spans are then cleaned of stray punctuation.
    """

    # Matches characters that are neither word characters, whitespace nor "/".
    # Because of the two lookarounds, such a character survives only when it
    # sits directly between two digits, which keeps date separators like the
    # "-" in "08-08-2025" while dropping surrounding brackets/periods.
    _SPECIAL_CHARS_RE = re.compile(r'(?<!\d)[^\w\s/]+|[^\w\s/]+(?!\d)')

    def __init__(self, path=""):
        """Load the tokenizer and QA model and move the model to the best device.

        Args:
            path: Directory containing the ``deberta-qa-finetuned`` model
                folder. The default empty string means the current working
                directory.
        """
        # os.path.join avoids turning the default path="" into the absolute
        # path "/deberta-qa-finetuned" (plain string concatenation did that).
        self.date_model_path = os.path.join(path, "deberta-qa-finetuned")
        self.date_tokenizer = AutoTokenizer.from_pretrained(self.date_model_path)
        self.date_model = AutoModelForQuestionAnswering.from_pretrained(self.date_model_path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.date_model.to(self.device)
        # Inference only: make sure dropout and similar layers are disabled.
        self.date_model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Extract the start and end dates mentioned in ``data["inputs"]``.

        Args:
            data: Request payload; ``data["inputs"]`` is the text to search.

        Returns:
            A dict with ``"start_date"`` and ``"end_date"`` keys holding the
            cleaned answer strings. A value may be empty if the selected span
            consists only of special tokens.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` key.
        """
        text = data["inputs"]
        return {
            "start_date": self.remove_special_characters(self.extract_start_date(text)),
            "end_date": self.remove_special_characters(self.extract_end_date(text)),
        }

    def remove_special_characters(self, s):
        """Strip stray punctuation from ``s`` while keeping date separators.

        Characters other than word characters, whitespace and "/" are removed
        unless they sit directly between two digits; leading and trailing
        whitespace is then stripped.
        """
        return self._SPECIAL_CHARS_RE.sub('', s).strip()

    def extract_start_date(self, text):
        """Return the raw (uncleaned) model answer for the start date in ``text``."""
        return self._answer_question("What is the start date?", text)

    def extract_end_date(self, text):
        """Return the raw (uncleaned) model answer for the end date in ``text``."""
        return self._answer_question("What is the end date?", text)

    def _answer_question(self, question, text):
        """Run extractive QA for ``question`` over ``text`` and decode the answer span."""
        encoded = self.date_tokenizer(question, text, return_tensors="pt")
        # Follow the device chosen in __init__ rather than re-querying CUDA.
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            outputs = self.date_model(**encoded)

        # The batch holds a single (question, text) pair, hence index 0.
        start, end = self._best_span(outputs.start_logits[0], outputs.end_logits[0])
        answer_tokens = encoded["input_ids"][0][start:end + 1]
        return self.date_tokenizer.decode(answer_tokens, skip_special_tokens=True)

    @staticmethod
    def _best_span(start_logits, end_logits):
        """Return ``(start, end)`` maximising ``start_logits[start] + end_logits[end]``
        subject to ``start <= end`` (both indices inclusive).

        Taking the two argmaxes independently can yield ``end < start`` and
        hence an empty answer. When the independent argmaxes already satisfy
        ``start <= end``, this returns exactly that pair.

        Args:
            start_logits: 1-D tensor of per-token start scores.
            end_logits: 1-D tensor of per-token end scores, same length.
        """
        # For each candidate end j, the best admissible start score is the
        # running maximum of start_logits over positions 0..j.
        best_start_scores = torch.cummax(start_logits, dim=0).values
        end = int(torch.argmax(best_start_scores + end_logits))
        # argmax returns the first maximal index, matching the original tie-break.
        start = int(torch.argmax(start_logits[: end + 1]))
        return start, end
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local smoke test: load the model from the current directory and run one
    # request. EndpointHandler is defined in this module, so it is used
    # directly; importing it from a module named "handler" would fail unless
    # this file happened to be saved as handler.py.
    handler = EndpointHandler(path=".")

    payload = {"inputs": "I am quite excited how this will turn out 08-08-2025 - 09-08-2025"}
    prediction = handler(payload)
    print(prediction)
|
|
|