File size: 4,303 Bytes
55c59f1
68ed19a
 
55c59f1
68ed19a
55c59f1
 
4f43fc0
68ed19a
 
 
 
4f43fc0
 
55c59f1
68ed19a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55c59f1
68ed19a
 
 
 
 
 
 
 
 
55c59f1
68ed19a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f43fc0
55c59f1
68ed19a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import logging
import torch

from typing import Any, Dict, List, Union
from flair.data import Sentence
from flair.models import SequenceTagger

# Configure logging
# Module-level logger (PEP 282 convention). basicConfig is a no-op if the
# hosting application already configured the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference-endpoint handler wrapping a Flair ``SequenceTagger``.

    The tagger is loaded once at construction, moved to GPU when one is
    available, and run in eval mode with gradients disabled. Postprocessed
    results are memoized per input text in a bounded in-memory cache.
    """

    def __init__(self, path: str):
        """Load the tagger from *path* and prepare it for inference.

        Args:
            path: Directory containing the serialized model
                (``pytorch_model.bin``).
        """
        logger.info("Initializing Flair endpoint handler from %s", path)

        model_path = os.path.join(path, "pytorch_model.bin")

        # Prefer GPU when the runtime exposes one.
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        logger.info("Using device: %s", device)

        self.tagger = SequenceTagger.load(model_path)
        self.tagger.to(device)

        # eval() disables dropout/batch-norm updates for deterministic inference.
        self.tagger.eval()

        # Bounded memoization of postprocessed results, keyed by input text.
        self.cache: Dict[str, List[Dict[str, Any]]] = {}
        self.cache_size_limit = 1000  # max entries; tune to available memory

        logger.info("Model successfully loaded and ready for inference")

    def preprocess(self, text: str) -> Sentence:
        """Wrap raw text in a Flair ``Sentence`` (tokenization happens here)."""
        return Sentence(text)

    def predict_batch(self, sentences: List[Sentence]) -> None:
        """Tag *sentences* in place under the ``predicted`` label namespace."""
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            self.tagger.predict(sentences, label_name="predicted", mini_batch_size=32)

    def postprocess(self, sentence: Sentence) -> List[Dict[str, Any]]:
        """Convert predicted spans into JSON-serializable entity dicts.

        Returns:
            One ``{entity_group, word, start, end, score}`` dict per span.
            Errors are logged and yield a (possibly partial) list rather
            than propagating to the caller.
        """
        entities: List[Dict[str, Any]] = []

        try:
            for span in sentence.get_spans("predicted"):
                # Defensive: skip degenerate spans with no tokens.
                if not span.tokens:
                    continue

                entities.append(
                    {
                        "entity_group": span.tag,
                        "word": span.text,
                        "start": span.tokens[0].start_position,
                        "end": span.tokens[-1].end_position,
                        "score": float(span.score),  # tensor -> plain float for JSON
                    }
                )
        except Exception as e:
            logger.error("Error in postprocessing: %s", e)

        return entities

    def __call__(
        self, data: Union[Dict[str, Any], List[Dict[str, Any]]]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Run inference on a single payload or a batch of payloads.

        Accepts either ``{"inputs": "some text"}`` (or any stringable value)
        or a list of such items. Returns one entity list per input, in the
        same order; a single (non-list) payload returns a single entity list.
        """
        is_batch_input = isinstance(data, list)
        if not is_batch_input:
            # Normalize the single-input case to a one-element batch.
            data = [data]

        # Extract the text from each item WITHOUT mutating the caller's
        # payload. (The previous implementation used dict.pop, which
        # destructively removed the "inputs" key from the request dict.)
        texts: List[str] = []
        for item in data:
            text = item.get("inputs", item) if isinstance(item, dict) else item
            texts.append(text if isinstance(text, str) else str(text))

        # Record cache hits up front so duplicate uncached texts in one batch
        # each get their own Sentence and the index bookkeeping below stays
        # aligned even after the first duplicate populates the cache.
        cached_flags = [text in self.cache for text in texts]
        sentences_to_process = [
            self.preprocess(text)
            for text, hit in zip(texts, cached_flags)
            if not hit
        ]

        if sentences_to_process:
            self.predict_batch(sentences_to_process)

        # Assemble results in input order, mixing cached and fresh entries.
        results: List[List[Dict[str, Any]]] = []
        sentence_idx = 0

        for text, hit in zip(texts, cached_flags):
            if hit:
                result = self.cache[text]
            else:
                result = self.postprocess(sentences_to_process[sentence_idx])
                sentence_idx += 1

                # Only grow the cache while under the cap (no eviction).
                if len(self.cache) < self.cache_size_limit:
                    self.cache[text] = result

            results.append(result)

        # Unwrap to match the shape of the original input.
        if not is_batch_input:
            return results[0]

        return results