File size: 1,121 Bytes
e18e670
 
7fe961b
e18e670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import transformers
from typing import Dict, Any, List

class EndpointHandler():
    """Hugging Face Inference Endpoints handler that serves text embeddings.

    Loads a Llama-2 chat model and returns mean-pooled embeddings: the
    average of the last four hidden layers, then the average over tokens.
    """

    def __init__(self, path=""):
        # NOTE(review): `path` is the local model directory the Inference
        # Toolkit passes in; this handler ignores it and always pulls the
        # fixed hub model below.
        model_id = 'meta-llama/Llama-2-13b-chat-hf'
        model_config = transformers.AutoConfig.from_pretrained(
            model_id
        )
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            device_map='auto'  # shard/place weights automatically (GPU if available)
        )
        self.model.eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_id,
        )

    def __call__(self, data: Dict[str, Any]) -> torch.Tensor:
        """Entry point invoked per request.

        Accepts the standard Inference Toolkit key "inputs" as well as the
        legacy "input" key; if neither is present, falls back to passing the
        whole payload through (original behavior).
        """
        if "inputs" in data:
            inputs = data.pop("inputs")
        else:
            inputs = data.pop("input", data)
        return self.embed(inputs)

    def embed(self, text) -> torch.Tensor:
        """Return a (batch, hidden_size) embedding tensor for `text`.

        Pooling: stack the last four hidden-state layers, mean over layers
        (dim=0), then mean over the token axis (dim=1).
        """
        with torch.no_grad():
            encoded_input = self.tokenizer(text, return_tensors="pt")
            # Bug fix: inputs must live on the same device as the model.
            # With device_map='auto' the weights may be on GPU while the
            # tokenizer output is on CPU, which raises at forward time.
            encoded_input = encoded_input.to(self.model.device)
            model_output = self.model(**encoded_input, output_hidden_states=True)
            # Bug fix: access hidden states by name. Positional indexing
            # (model_output[2]) depends on which optional fields happen to
            # be populated in the ModelOutput and silently breaks when
            # e.g. past_key_values is absent.
            last_four_layers = model_output.hidden_states[-4:]
            return torch.stack(last_four_layers).mean(dim=0).mean(dim=1)