import os
import logging
import torch
from typing import Any, Dict, List, Union
from flair.data import Sentence
from flair.models import SequenceTagger

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference endpoint wrapper around a Flair ``SequenceTagger``.

    Loads the model once at construction (on GPU when available), keeps a
    bounded in-memory cache of previously seen input texts, and serves both
    single-item and batched requests through ``__call__``.
    """

    def __init__(self, path: str):
        """Load the tagger from *path* and prepare it for inference.

        Args:
            path: Directory containing the serialized model file
                ``pytorch_model.bin``.
        """
        # Log initialization
        logger.info(f"Initializing Flair endpoint handler from {path}")

        # Load model with performance optimizations
        model_path = os.path.join(path, "pytorch_model.bin")

        # Check if CUDA is available and enable if possible
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        logger.info(f"Using device: {device}")

        # Load the model and move it to the selected device
        self.tagger = SequenceTagger.load(model_path)
        self.tagger.to(device)

        # Evaluation mode disables dropout etc. for deterministic inference
        self.tagger.eval()

        # Cache for commonly requested inputs (text -> postprocessed entities).
        # NOTE(review): entries are never evicted; once the limit is reached
        # new texts are simply not cached.
        self.cache: Dict[str, List[Dict[str, Any]]] = {}
        self.cache_size_limit = 1000  # Adjust based on memory constraints

        logger.info("Model successfully loaded and ready for inference")

    def preprocess(self, text: str) -> Sentence:
        """Wrap raw text in a Flair ``Sentence`` (tokenization happens here)."""
        return Sentence(text)

    def predict_batch(self, sentences: List[Sentence]) -> None:
        """Run the tagger over *sentences* in place, writing ``predicted`` labels.

        Gradients are disabled since this is inference-only.
        """
        with torch.no_grad():  # Disable gradient calculation for inference
            self.tagger.predict(sentences, label_name="predicted", mini_batch_size=32)

    def postprocess(self, sentence: Sentence) -> List[Dict[str, Any]]:
        """Convert predicted spans on *sentence* into JSON-serializable dicts.

        Returns:
            A list of ``{"entity_group", "word", "start", "end", "score"}``
            dicts, one per predicted span. Returns whatever was collected so
            far (possibly empty) if span extraction raises.
        """
        entities: List[Dict[str, Any]] = []
        try:
            for span in sentence.get_spans("predicted"):
                if len(span.tokens) == 0:
                    continue
                entities.append(
                    {
                        "entity_group": span.tag,
                        "word": span.text,
                        "start": span.tokens[0].start_position,
                        "end": span.tokens[-1].end_position,
                        "score": float(span.score),  # Ensure score is serializable
                    }
                )
        except Exception as e:
            logger.error(f"Error in postprocessing: {str(e)}")
        return entities

    def __call__(
        self, data: Union[Dict[str, Any], List[Dict[str, Any]]]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Run NER over one request or a batch of requests.

        Args:
            data: Either a single ``{"inputs": text}`` dict (or raw value)
                or a list of such items.

        Returns:
            The entity list for a single input, or a list of entity lists
            for a batch — mirroring the input shape.
        """
        # Handle both single input and batch input cases
        is_batch_input = isinstance(data, list)
        if not is_batch_input:
            # Convert single input to batch format temporarily
            data = [data]

        # Normalize every item to a plain text string.
        # FIX: use .get (not .pop) so the caller's request dict is not mutated.
        texts: List[str] = []
        for item in data:
            text = item.get("inputs", item) if isinstance(item, dict) else item
            if not isinstance(text, str):
                text = str(text)
            texts.append(text)

        # Build one Sentence per *unique* uncached text, so duplicates within
        # the same batch are only tokenized and predicted once.
        pending: Dict[str, Sentence] = {}
        for text in texts:
            if text not in self.cache and text not in pending:
                pending[text] = self.preprocess(text)

        # Batch process sentences if any need processing
        if pending:
            self.predict_batch(list(pending.values()))

        # Assemble results in input order, serving from cache when possible.
        results: List[List[Dict[str, Any]]] = []
        for text in texts:
            if text in self.cache:
                result = self.cache[text]
            else:
                result = self.postprocess(pending[text])
                # Update cache only while under the size limit
                if len(self.cache) < self.cache_size_limit:
                    self.cache[text] = result
            results.append(result)

        # Return single result if input was single
        if not is_batch_input:
            return results[0]
        return results