Token Classification
GLiNER
PyTorch
multilingual
bert
Gliner_haystack / handler.py
Rejebc's picture
Upload 7 files
4e97a81 verified
raw
history blame
1.73 kB
from typing import Dict, List, Any
from transformers import pipeline, AutoConfig, AutoModelForTokenClassification, AutoTokenizer, BertTokenizerFast
import os
class EndpointHandler:
    """Inference-endpoint handler exposing a token-classification (NER) pipeline.

    Loads a custom model configuration (``gliner_config.json``) from the
    deployment directory, builds the model once in ``__init__``, and reuses
    it for every request via ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Initialize the model, tokenizer and pipeline.

        Args:
            path: Directory containing the deployed artifacts, including the
                custom ``gliner_config.json`` configuration file.

        Raises:
            FileNotFoundError: If ``gliner_config.json`` is not found under
                ``path``.
        """
        dir_model = "urchade/gliner_multi-v2.1"
        config_path = os.path.join(path, "gliner_config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Custom configuration file not found at {config_path}")
        # Load the custom configuration
        config = AutoConfig.from_pretrained(config_path)
        # Load the model using the custom configuration
        self.model = AutoModelForTokenClassification.from_pretrained(dir_model, config=config)
        # Use a pre-trained tokenizer compatible with the model
        # NOTE(review): bert-base-uncased may not match the multilingual
        # checkpoint's vocabulary — confirm against the deployed model.
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        # BUG FIX: the original passed `model=path` here, re-loading a second
        # model from disk and leaving self.model (with the custom config)
        # completely unused. Reuse the already-constructed model instead.
        self.pipeline = pipeline("token-classification", model=self.model, tokenizer=self.tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run token classification on the request payload.

        Args:
            data (Dict[str, Any]): The input data including:
                - "inputs": The text input from which to extract information.

        Returns:
            List[Dict[str, Any]]: The extracted entities/tokens from the text.
        """
        # Missing "inputs" falls back to an empty string rather than raising.
        inputs = data.get("inputs", "")
        # Run the pipeline for text extraction
        extraction_results = self.pipeline(inputs)
        # Process and return the results as needed
        return extraction_results