|
|
from typing import Dict, List, Any |
|
|
from transformers import pipeline, AutoConfig, AutoModelForTokenClassification, AutoTokenizer, BertTokenizerFast |
|
|
import os |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference handler that runs a transformers token-classification pipeline.

    On construction it reads ``gliner_config.json`` from the directory
    ``path``, loads model weights from ``model_id`` using that configuration,
    loads a fast BERT tokenizer, and wraps the loaded model and tokenizer in a
    ``"token-classification"`` pipeline. Calling the instance runs that
    pipeline on the ``"inputs"`` field of the request payload.
    """

    def __init__(self, path="", *, model_id="urchade/gliner_multi-v2.1",
                 tokenizer_id="bert-base-uncased"):
        """Load configuration, model and tokenizer, then build the pipeline.

        Args:
            path: Directory expected to contain ``gliner_config.json``. The
                default ``""`` resolves the file relative to the current
                working directory.
            model_id: Hub id or local directory passed to
                ``AutoModelForTokenClassification.from_pretrained``.
            tokenizer_id: Hub id or local directory passed to
                ``BertTokenizerFast.from_pretrained``.

        Raises:
            FileNotFoundError: If ``gliner_config.json`` is not a regular file
                under ``path``.
        """
        config_path = os.path.join(path, "gliner_config.json")

        # isfile rather than exists: a directory with this name is not a
        # usable configuration file.
        if not os.path.isfile(config_path):
            raise FileNotFoundError(f"Custom configuration file not found at {config_path}")

        # NOTE(review): AutoConfig needs a ``model_type`` that transformers
        # recognises; confirm gliner_config.json provides one.
        config = AutoConfig.from_pretrained(config_path)

        # NOTE(review): this checkpoint is distributed for the ``gliner``
        # library; confirm it loads correctly as a transformers
        # token-classification model.
        self.model = AutoModelForTokenClassification.from_pretrained(model_id, config=config)

        # NOTE(review): an uncased BERT tokenizer may not match the model's
        # vocabulary; confirm this is the intended tokenizer.
        self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_id)

        # Hand the already-loaded model to the pipeline. Passing ``path``
        # here instead would make the pipeline reload weights from ``path``
        # with its default config.json, discarding the custom configuration
        # and model loaded above.
        self.pipeline = pipeline("token-classification", model=self.model, tokenizer=self.tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run token classification on a request payload.

        Args:
            data: Request payload. The text is read from its ``"inputs"``
                key; a missing key is treated as an empty string.

        Returns:
            The token-classification pipeline's output for that text.
        """
        inputs = data.get("inputs", "")
        return self.pipeline(inputs)
|
|
|