import os
import logging
import torch
from typing import Any, Dict, List, Union
from flair.data import Sentence
from flair.models import SequenceTagger
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    """Inference endpoint wrapper around a Flair ``SequenceTagger``.

    Loads the model once at construction, runs batched NER prediction on
    request, and memoizes results per input text in a bounded FIFO cache.
    """

    def __init__(self, path: str):
        """Load the tagger from ``path`` and move it to the best device.

        Args:
            path: Directory containing ``pytorch_model.bin``.
        """
        logger.info("Initializing Flair endpoint handler from %s", path)
        model_path = os.path.join(path, "pytorch_model.bin")

        # Prefer GPU when available; this handler is inference-only.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", device)

        self.tagger = SequenceTagger.load(model_path)
        self.tagger.to(device)
        self.tagger.eval()  # disable dropout etc. for inference

        # Memo of text -> entity list, bounded by FIFO eviction below.
        self.cache: Dict[str, List[Dict[str, Any]]] = {}
        self.cache_size_limit = 1000  # adjust based on memory constraints
        logger.info("Model successfully loaded and ready for inference")

    def preprocess(self, text: str) -> Sentence:
        """Wrap raw text in a Flair ``Sentence`` (tokenization happens here)."""
        return Sentence(text)

    def predict_batch(self, sentences: List[Sentence]) -> None:
        """Tag ``sentences`` in place under the ``predicted`` label name."""
        with torch.no_grad():  # no autograd bookkeeping during inference
            self.tagger.predict(sentences, label_name="predicted", mini_batch_size=32)

    def postprocess(self, sentence: Sentence) -> List[Dict[str, Any]]:
        """Convert predicted spans into JSON-serializable entity dicts.

        Returns:
            A (possibly partial) list of entity dicts; extraction errors
            are logged with a traceback instead of aborting the request.
        """
        entities: List[Dict[str, Any]] = []
        try:
            for span in sentence.get_spans("predicted"):
                if not span.tokens:  # defensive: skip empty spans
                    continue
                entities.append(
                    {
                        "entity_group": span.tag,
                        "word": span.text,
                        "start": span.tokens[0].start_position,
                        "end": span.tokens[-1].end_position,
                        "score": float(span.score),  # ensure JSON-serializable
                    }
                )
        except Exception:
            # BUGFIX: logger.error(f"...{e}") dropped the traceback;
            # logger.exception records it in full.
            logger.exception("Error in postprocessing")
        return entities

    def __call__(
        self, data: Union[Dict[str, Any], List[Dict[str, Any]]]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Run NER over a single request or a batch of requests.

        Args:
            data: Either ``{"inputs": text}`` (single) or a list of such
                dicts (batch). Non-dict items are coerced to ``str``.

        Returns:
            For a single input, a list of entity dicts; for a batch, one
            such list per input, in order.
        """
        is_batch_input = isinstance(data, list)
        if not is_batch_input:
            data = [data]

        # Normalize every item to plain text.
        # BUGFIX: the original used item.pop("inputs", item), which
        # destructively mutated the caller's request dict.
        texts: List[str] = []
        for item in data:
            text = item.get("inputs", item) if isinstance(item, dict) else item
            texts.append(text if isinstance(text, str) else str(text))

        # Freeze cache-hit status up front so the pending-sentence index
        # stays aligned even when a batch repeats an uncached text.
        statuses = [(text, text in self.cache) for text in texts]
        pending = [self.preprocess(text) for text, cached in statuses if not cached]
        if pending:
            self.predict_batch(pending)

        results: List[List[Dict[str, Any]]] = []
        pending_iter = iter(pending)
        for text, cached in statuses:
            if cached:
                result = self.cache[text]
            else:
                result = self.postprocess(next(pending_iter))
                # BUGFIX: the original stopped caching entirely once the
                # limit was reached; FIFO eviction (dicts preserve
                # insertion order) keeps the memo bounded AND useful.
                if len(self.cache) >= self.cache_size_limit:
                    self.cache.pop(next(iter(self.cache)))
                self.cache[text] = result
            results.append(result)

        # Unwrap when the caller sent a single (non-list) request.
        return results if is_batch_input else results[0]