File size: 971 Bytes
62a40cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import AutoModel, AutoTokenizer
import torch

class EndpointHandler:
    """Callable that returns a model's last hidden state for input text.

    The tokenizer and model are loaded from ``model_dir`` on the first call
    instead of in ``__init__``, so constructing a handler performs no I/O.
    """

    def __init__(self, model_dir):
        """Store the model location; loading is deferred until first use.

        Args:
            model_dir: Value handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.model = None
        self.tokenizer = None
        self.model_dir = model_dir

    def __call__(self, data):
        """Tokenize ``data["inputs"]`` and run it through the model.

        Args:
            data: Request payload exposing ``.get``; its ``"inputs"`` entry
                is either a single string or a non-empty list/tuple of
                texts forwarded to the tokenizer as one batch.

        Returns:
            ``{"outputs": ...}`` where the value is
            ``last_hidden_state.tolist()``, i.e. nested Python lists.
            Positions added by ``padding=True`` are not stripped.

        Raises:
            ValueError: If ``"inputs"`` is missing, ``None``, empty, or not
                a string / list / tuple.
        """
        # Validate before loading so a malformed request fails fast instead
        # of first paying for (or failing inside) the model load.
        texts = self._extract_texts(data)
        self._ensure_loaded()

        encoded_input = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**encoded_input)

        return {"outputs": outputs.last_hidden_state.tolist()}

    def _ensure_loaded(self):
        """Load whichever of tokenizer/model has not been set yet."""
        # NOTE(review): not guarded by a lock; concurrent first calls may
        # each trigger a load -- confirm how the serving runtime calls this.
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        if self.model is None:
            self.model = AutoModel.from_pretrained(self.model_dir)

    @staticmethod
    def _extract_texts(data):
        """Return the batch to tokenize, wrapping a lone string in a list."""
        inputs = data.get("inputs")
        if isinstance(inputs, str):
            return [inputs]
        if inputs is None:
            raise ValueError('request payload must contain an "inputs" field')
        if not isinstance(inputs, (list, tuple)):
            raise ValueError(
                '"inputs" must be a string or a list of strings, '
                f"got {type(inputs).__name__}"
            )
        if not inputs:
            raise ValueError('"inputs" must not be empty')
        return inputs