# ner-english-inference / handler.py
# Commit 68ed19a (Kevinger): Enhance EndpointHandler with model optimizations
# and caching; add .gitignore for virtual environment
import os
import logging
import torch
from typing import Any, Dict, List, Union
from flair.data import Sentence
from flair.models import SequenceTagger
# Configure module-level logging so handler activity shows up in endpoint logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference endpoint handler for a Flair NER ``SequenceTagger``.

    Loads the tagger once at startup, runs batched, gradient-free
    prediction, and memoizes results for repeated input strings in a
    size-bounded in-memory cache.
    """

    def __init__(self, path: str):
        """Load the tagger from *path* and move it to the best available device.

        Args:
            path: Directory containing ``pytorch_model.bin``.
        """
        # Lazy %-style args avoid formatting cost when the level is disabled.
        logger.info("Initializing Flair endpoint handler from %s", path)
        model_path = os.path.join(path, "pytorch_model.bin")

        # Prefer GPU when available; fall back to CPU otherwise.
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        logger.info("Using device: %s", device)

        self.tagger = SequenceTagger.load(model_path)
        self.tagger.to(device)
        # eval() disables dropout etc. for deterministic inference.
        self.tagger.eval()

        # Bounded memo of input text -> entity list for repeated requests.
        self.cache: Dict[str, List[Dict[str, Any]]] = {}
        self.cache_size_limit = 1000  # Adjust based on memory constraints

        logger.info("Model successfully loaded and ready for inference")

    def preprocess(self, text: str) -> Sentence:
        """Wrap raw text in a Flair ``Sentence`` (tokenization happens here)."""
        return Sentence(text)

    def predict_batch(self, sentences: List[Sentence]) -> None:
        """Tag *sentences* in place under the ``predicted`` label name."""
        with torch.no_grad():  # Disable gradient tracking for inference
            self.tagger.predict(sentences, label_name="predicted", mini_batch_size=32)

    def postprocess(self, sentence: Sentence) -> List[Dict[str, Any]]:
        """Convert predicted spans into JSON-serializable entity dicts.

        Returns:
            A list of ``{entity_group, word, start, end, score}`` dicts.
            On error the (possibly partial) list built so far is returned,
            so one bad sentence cannot fail a whole batch.
        """
        entities: List[Dict[str, Any]] = []
        try:
            for span in sentence.get_spans("predicted"):
                if not span.tokens:  # Defensive: skip degenerate empty spans
                    continue
                entities.append(
                    {
                        "entity_group": span.tag,
                        "word": span.text,
                        "start": span.tokens[0].start_position,
                        "end": span.tokens[-1].end_position,
                        "score": float(span.score),  # Ensure score is serializable
                    }
                )
        except Exception:
            # logger.exception keeps the traceback; the bare message at
            # ERROR level did not, which made failures hard to debug.
            logger.exception("Error in postprocessing")
        return entities

    def __call__(
        self, data: Union[Dict[str, Any], List[Dict[str, Any]]]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Run NER on a single request or a batch of requests.

        Args:
            data: Either one request dict (``{"inputs": text}``) or a list
                of them; bare strings / non-dict items are stringified.

        Returns:
            One entity list for a single request, or a list of entity
            lists (in request order) for a batch.
        """
        # Handle both single input and batch input cases.
        is_batch_input = isinstance(data, list)
        if not is_batch_input:
            data = [data]

        # Normalize each item to its text and record whether it's cached.
        # NOTE: .get (not .pop) so the caller's request dict is not mutated.
        batch_inputs: List[Any] = []
        for item in data:
            text = item.get("inputs", item) if isinstance(item, dict) else item
            if not isinstance(text, str):
                text = str(text)
            batch_inputs.append((text, text in self.cache))

        # Tag only the cache misses, in one batched prediction call.
        sentences_to_process = [
            self.preprocess(text) for text, is_cached in batch_inputs if not is_cached
        ]
        if sentences_to_process:
            self.predict_batch(sentences_to_process)

        # Assemble results in request order, filling the cache as we go.
        results = []
        sentence_idx = 0
        for text, is_cached in batch_inputs:
            if is_cached:
                result = self.cache[text]
            else:
                result = self.postprocess(sentences_to_process[sentence_idx])
                sentence_idx += 1
                # Cache only while under the size bound (no eviction policy).
                if len(self.cache) < self.cache_size_limit:
                    self.cache[text] = result
            results.append(result)

        # Return a single result if the input was a single request.
        return results if is_batch_input else results[0]