| | import os |
| | import logging |
| | import torch |
| |
|
| | from typing import Any, Dict, List, Union |
| | from flair.data import Sentence |
| | from flair.models import SequenceTagger |
| |
|
| | |
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class EndpointHandler:
    """Inference-endpoint handler wrapping a Flair ``SequenceTagger``.

    Loads the tagger from ``<path>/pytorch_model.bin``, moves it to GPU when
    one is available, and serves NER-style predictions with a bounded
    in-memory result cache keyed by the exact input text.
    """

    def __init__(self, path: str):
        """Load the model found under *path* and prepare it for inference.

        Args:
            path: Directory expected to contain ``pytorch_model.bin``.
        """
        logger.info("Initializing Flair endpoint handler from %s", path)

        model_path = os.path.join(path, "pytorch_model.bin")

        # Prefer GPU when torch can see one.
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        logger.info("Using device: %s", device)

        self.tagger = SequenceTagger.load(model_path)
        self.tagger.to(device)

        # Inference only: disable dropout and other training-mode behavior.
        self.tagger.eval()

        # text -> postprocessed entity list; insertion stops at the size
        # limit so memory cannot grow without bound on a long-running
        # endpoint.
        self.cache: Dict[str, List[Dict[str, Any]]] = {}
        self.cache_size_limit = 1000

        logger.info("Model successfully loaded and ready for inference")

    def preprocess(self, text: str) -> Sentence:
        """Wrap raw text in a Flair ``Sentence`` (tokenization happens here)."""
        return Sentence(text)

    def predict_batch(self, sentences: List[Sentence]) -> None:
        """Tag *sentences* in place under the ``predicted`` label name."""
        with torch.no_grad():
            self.tagger.predict(sentences, label_name="predicted", mini_batch_size=32)

    def postprocess(self, sentence: Sentence) -> List[Dict[str, Any]]:
        """Convert predicted spans into HF-pipeline-style entity dicts.

        Returns:
            A list of ``{entity_group, word, start, end, score}`` dicts.
            On any extraction error, logs and returns the partial list
            gathered so far (deliberate best-effort behavior).
        """
        entities: List[Dict[str, Any]] = []
        try:
            for span in sentence.get_spans("predicted"):
                # Defensive: skip degenerate spans with no tokens.
                if not span.tokens:
                    continue
                entities.append(
                    {
                        "entity_group": span.tag,
                        "word": span.text,
                        "start": span.tokens[0].start_position,
                        "end": span.tokens[-1].end_position,
                        "score": float(span.score),
                    }
                )
        except Exception as e:
            # Lazy %-formatting; keep best-effort: return what we have.
            logger.error("Error in postprocessing: %s", e)

        return entities

    def __call__(
        self, data: Union[Dict[str, Any], List[Dict[str, Any]]]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Run inference on a single request or a batch of requests.

        Each item may be a dict carrying an ``inputs`` key or a raw string.
        A single (non-list) request returns one entity list; a list of
        requests returns a list of entity lists in request order. Results
        are served from the cache when the exact text was seen before.
        """
        is_batch_input = isinstance(data, list)
        if not is_batch_input:
            data = [data]

        # Normalize every item to a text string and record cache hits.
        # NOTE: .get (not .pop) so the caller's request dict is never
        # mutated as a side effect of inference.
        batch_inputs: List[tuple] = []
        for item in data:
            text = item.get("inputs", item) if isinstance(item, dict) else item
            if not isinstance(text, str):
                text = str(text)
            batch_inputs.append((text, text in self.cache))

        # Tag all cache misses in one batched call.
        sentences_to_process = [
            self.preprocess(text) for text, is_cached in batch_inputs if not is_cached
        ]
        if sentences_to_process:
            self.predict_batch(sentences_to_process)

        # Reassemble results in request order; fresh results fill the cache
        # (cache hits are returned as-is — no pointless re-store).
        results = []
        sentence_idx = 0
        for text, is_cached in batch_inputs:
            if is_cached:
                result = self.cache[text]
            else:
                result = self.postprocess(sentences_to_process[sentence_idx])
                sentence_idx += 1
                # Simple cap: stop inserting once the limit is reached.
                if len(self.cache) < self.cache_size_limit:
                    self.cache[text] = result
            results.append(result)

        # Mirror the input shape: single request -> single result.
        return results if is_batch_input else results[0]
| |
|