Archeane committed on
Commit
e18e670
·
1 Parent(s): 01e069a

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +30 -0
handler.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Any, Dict, List

import torch
import transformers
3
+
4
class EndpointHandler:
    """Custom inference handler that turns text into a pooled embedding.

    The handler loads the Llama-2 13B chat checkpoint together with its
    tokenizer, and exposes ``__call__`` as the request entry point.  The
    embedding is the average of the last four hidden-state layers, averaged
    again over dimension 1 of each layer tensor.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer.

        Args:
            path: Accepted for interface compatibility but not used; the
                checkpoint is always resolved from the hard-coded model id
                below.
        """
        model_id = 'meta-llama/Llama-2-13b-chat-hf'
        model_config = transformers.AutoConfig.from_pretrained(model_id)
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            device_map='auto',
        )
        # Inference only: switch off dropout and other train-time behaviour.
        self.model.eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Embed the text carried by a request payload.

        The text is looked up under ``"input"`` first and ``"inputs"`` second.
        When neither key is present the payload itself is handed to
        :meth:`embed`, which matches the historical fallback.  The payload is
        read without being modified.

        Args:
            data: Request payload.

        Returns:
            Whatever :meth:`embed` returns for the extracted text.
            NOTE(review): that is a ``torch.Tensor``, not the annotated list
            of dicts -- confirm what the caller expects before changing either.
        """
        for key in ("input", "inputs"):
            if key in data:
                return self.embed(data[key])
        return self.embed(data)

    def embed(self, text):
        """Return the pooled hidden-state embedding for ``text``.

        Args:
            text: Input forwarded unchanged to the tokenizer.

        Returns:
            A tensor: the mean of the last four hidden-state layers, then the
            mean over dimension 1 of the result (presumably the token axis of
            a ``(batch, tokens, hidden)`` tensor -- TODO confirm).
        """
        # The model is placed by `device_map='auto'`; the tokenizer output has
        # to follow it explicitly.
        device = self.model.device
        with torch.no_grad():
            encoded = self.tokenizer(text, return_tensors="pt")
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
            output = self.model(**encoded, output_hidden_states=True)
            # Read the field by name: its position in the output tuple moves
            # when optional fields are absent.
            last_four_layers = output.hidden_states[-4:]
            return torch.stack(last_four_layers).mean(dim=0).mean(dim=1)