File size: 1,121 Bytes
e18e670 7fe961b e18e670 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
import transformers
from typing import Dict, Any, List
class EndpointHandler:
    """Hugging Face Inference Endpoints handler that returns LLaMA-2 embeddings.

    Loads `meta-llama/Llama-2-13b-chat-hf` once at startup and, per request,
    returns an embedding built from the mean of the last four hidden layers.
    """

    def __init__(self, path=""):
        """Load model and tokenizer.

        Args:
            path: Checkpoint path supplied by the endpoint runtime. Unused —
                the hard-coded hub id below is loaded instead.
        """
        model_id = 'meta-llama/Llama-2-13b-chat-hf'
        model_config = transformers.AutoConfig.from_pretrained(model_id)
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            device_map='auto',  # shards/places the model automatically (GPU if available)
        )
        self.model.eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Entry point invoked by the endpoint runtime with the request payload.

        NOTE(review): reads the "input" key (HF endpoints conventionally send
        "inputs"); falls back to the whole payload dict if the key is absent.
        """
        inputs = data.pop("input", data)
        return self.embed(inputs)

    def embed(self, text):
        """Embed *text*: mean over the last four hidden layers, then over tokens.

        Returns a tensor of shape (batch, hidden_dim).
        """
        with torch.no_grad():
            encoded_input = self.tokenizer(text, return_tensors="pt")
            # BUG FIX: the tokenizer returns CPU tensors while device_map='auto'
            # may have placed the model on GPU — move inputs to the model's device.
            encoded_input = encoded_input.to(self.model.device)
            model_output = self.model(**encoded_input, output_hidden_states=True)
            # BUG FIX: access hidden states by name instead of positional index
            # [2], which silently shifts with use_cache/return_dict settings.
            last_four_layers = model_output.hidden_states[-4:]
            # Stack -> (4, batch, seq, dim); average layers (dim=0), then tokens (dim=1).
            return torch.stack(last_four_layers).mean(dim=0).mean(dim=1)