import torch
import transformers
from typing import Dict, Any, List

class EndpointHandler:
    def __init__(self, path=""):
        model_id = "meta-llama/Llama-2-13b-chat-hf"
        model_config = transformers.AutoConfig.from_pretrained(model_id)
        # device_map="auto" lets accelerate place the weights on the
        # available GPU(s), spilling over to CPU if necessary.
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            device_map="auto",
        )
        self.model.eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        # Inference Endpoints pass the request payload as a dict; fall
        # back to the raw payload if no "input" key is present.
        inputs = data.pop("input", data)
        return self.embed(inputs)

    def embed(self, text):
        with torch.no_grad():
            # Move the tokenized input to the model's device, since
            # device_map="auto" may have placed the weights on GPU.
            encoded_input = self.tokenizer(text, return_tensors="pt").to(self.model.device)
            model_output = self.model(**encoded_input, output_hidden_states=True)
            # Average the last four hidden layers, then mean-pool over the
            # token dimension to get one vector per input sequence.
            last_four_layers = model_output.hidden_states[-4:]
            embedding = torch.stack(last_four_layers).mean(dim=0).mean(dim=1)
            # Convert to plain lists so the response is JSON-serializable.
            return embedding.tolist()
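
For reference, a minimal local smoke test might look like the sketch below. This is a hypothetical usage example, not part of the handler file: it assumes the gated Llama 2 weights are accessible and that enough GPU memory is available (the 13B model needs roughly 26 GB in float16). Note also that embedding a batch of strings would require padding, and Llama tokenizers ship without a pad token, so the handler as written is best called with one string at a time.

# Hypothetical local smoke test; instantiating the handler downloads
# and loads the full 13B checkpoint.
handler = EndpointHandler()

# One string in, one pooled embedding out
# (hidden_size is 5120 for Llama-2-13B).
result = handler({"input": "Hello, world!"})
print(len(result), len(result[0]))  # 1 5120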