| from typing import Dict, Any |
| from transformers import pipeline |
| import holidays |
| import PIL.Image |
| import io |
|
|
class PreTrainedPipeline:
    """Document question answering with an optional US-holiday short-circuit.

    Wraps a Hugging Face ``document-question-answering`` pipeline. When the
    request carries a ``date`` that is in the US holiday calendar, a fixed
    holiday message is returned instead of running the model.
    """

    def __init__(self, model_path="PrimWong/layout_qa_hparam_tuning"):
        """Load the QA pipeline and the US holiday calendar.

        Args:
            model_path: Model id or local path passed to ``transformers.pipeline``.
        """
        self.pipeline = pipeline("document-question-answering", model=model_path)
        self.holidays = holidays.US()

    @staticmethod
    def _prepare_image(image: Any) -> Any:
        """Decode raw image bytes into a PIL image; pass anything else through unchanged.

        Paths, URLs and PIL images are left for the pipeline to handle
        (per the transformers DocumentQuestionAnsweringPipeline docs).
        """
        if isinstance(image, (bytes, bytearray)):
            return PIL.Image.open(io.BytesIO(image))
        return image

    def __call__(self, data: Dict[str, Any]) -> str:
        """Answer a question about a document image, with optional holiday checking.

        Args:
            data: Request payload. The question is read from ``'inputs'``
                (falling back to ``'text'``), the document from ``'image'``
                (raw bytes, or anything the pipeline accepts), and an
                optional ``'date'`` is checked against the US holiday calendar.

        Returns:
            The top-ranked answer string, ``"Today is a holiday!"`` when
            ``date`` is a US holiday, or ``""`` if the model returns no answers.

        Raises:
            ValueError: If no question or no image is supplied (and the
                holiday short-circuit did not apply).
        """
        question = data.get("inputs")
        if question is None:
            # The payload was previously documented with a 'text' key; accept it too.
            question = data.get("text")
        date = data.get("date")

        # Holiday check takes precedence over everything else, as before.
        if date and date in self.holidays:
            return "Today is a holiday!"

        if not question:
            raise ValueError("A question must be provided under 'inputs' (or 'text').")
        image = data.get("image")
        if image is None:
            raise ValueError("An 'image' is required for document question answering.")

        prediction = self.pipeline(image=self._prepare_image(image), question=question)

        # The pipeline returns a list of candidate answer dicts (best first,
        # per transformers docs); tolerate a single dict as well.
        if isinstance(prediction, list):
            if not prediction:
                return ""
            prediction = prediction[0]
        return prediction["answer"]
|
|