| |
|
|
| from typing import Dict, Any |
|
|
| from peft import PeftModel |
| from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig |
| import transformers |
|
|
|
|
| import torch |
| from torch import Tensor |
| import torch.nn.functional as F |
| from torch import nn |
|
|
| from transformers.models.mistral.modeling_mistral import MistralAttention |
| from ExtractableMistralAttention import forward |
|
|
# Monkey-patch MistralAttention with the replacement ``forward`` imported from
# ExtractableMistralAttention. Methods are looked up on the class, so this
# affects every MistralAttention instance, including the ones built later by
# AutoModel.from_pretrained. EndpointHandler reads a ``self_attn.attn_vec``
# attribute afterwards, which presumably this replacement forward sets —
# NOTE(review): confirm against ExtractableMistralAttention.
setattr(MistralAttention, "forward", forward)
|
|
class EndpointHandler:
    """Inference-endpoint handler that embeds a single query or document text.

    The embedding is the L2-normalised last-token hidden state of the model.
    Query responses additionally carry an L2-normalised last-token
    "attention vector" read from the final decoder layer's attention module.
    """

    def __init__(self, model_dir=''):
        """Load the tokenizer and an 8-bit quantised model from ``model_dir``."""
        # Prepended to query texts only (see ``tokenize``); documents are embedded as-is.
        self.instruction = 'Given a web search query, retrieve relevant passages that answer the query:\n'
        # Upper bound on tokens per input, including the appended EOS token.
        self.max_length = 4096
        # Device the tokenised inputs are moved to; the model itself is placed
        # by ``device_map="auto"`` below.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        # NOTE(review): '[PAD]' may not be in the vocabulary. This is harmless
        # here because only one sequence is tokenised per call, so nothing is
        # ever padded. Revisit if batching is added.
        self.tokenizer.pad_token = '[PAD]'
        self.tokenizer.padding_side = 'left'

        # BitsAndBytesConfig has no ``bnb_8bit_compute_dtype`` option (only
        # ``bnb_4bit_compute_dtype``), so passing it had no effect. It is left
        # out rather than suggesting an fp16 compute dtype was configured.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)

        self.model = AutoModel.from_pretrained(
            model_dir,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            # Presumably forced so that the (patched) eager MistralAttention
            # forward is the code path that runs — confirm before changing.
            attn_implementation="eager",
        )
        self.model.eval()

    def last_token_pool(self, last_hidden_states: Tensor,
                        attention_mask: Tensor) -> Tensor:
        """Pick each sequence's last non-padding position.

        Args:
            last_hidden_states: tensor indexed as ``(batch, seq, ...)``.
            attention_mask: ``(batch, seq)`` mask of 1s (real tokens) and 0s (padding).

        Returns:
            ``last_hidden_states`` with the ``seq`` axis removed, taking for each
            row the position of its last real token.
        """
        # If every row has a real token in the final column, the batch is
        # left-padded (or unpadded) and the final column is the answer.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        # Right padding: index each row at (number of real tokens - 1).
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def tokenize(self, text, request_type):
        """Tokenise ``text`` as a batch of one and terminate it with EOS.

        Queries are prefixed with ``self.instruction``. The returned encoding
        lives on ``self.device`` and holds ``input_ids`` / ``attention_mask``
        of at most ``self.max_length`` tokens, the last of which is EOS.
        """
        if request_type == 'query':
            text = self.instruction + text
        # Truncate to max_length - 1 first, then append EOS. Appending EOS to
        # the string before truncating would drop it for long inputs. That
        # matters because the final position is what gets pooled into the
        # embedding, and it is meant to be EOS.
        tokens = self.tokenizer(text, max_length=self.max_length - 1, truncation=True, return_tensors='pt')
        input_ids = tokens['input_ids']
        eos = torch.full((input_ids.shape[0], 1), self.tokenizer.eos_token_id, dtype=input_ids.dtype)
        tokens['input_ids'] = torch.cat([input_ids, eos], dim=1)
        mask = tokens['attention_mask']
        tokens['attention_mask'] = torch.cat([mask, torch.ones_like(eos, dtype=mask.dtype)], dim=1)
        return tokens.to(self.device)

    def extract_attn_vec(self, model=None):
        """Return the ``attn_vec`` stored on the final layer's attention module.

        Args:
            model: model exposing ``layers[-1].self_attn.attn_vec``; defaults to
                ``self.model``.
        """
        if model is None:
            model = self.model
        # ``attn_vec`` is not a stock MistralAttention attribute; presumably it
        # is set by the patched forward during the most recent forward pass —
        # NOTE(review): verify in ExtractableMistralAttention.
        return model.layers[-1].self_attn.attn_vec

    def embed(self, text, request_type):
        """Return ``(embedding, attn_vec)`` for ``text``, each L2-normalised along dim 1."""
        tokens = self.tokenize(text, request_type)
        with torch.no_grad():
            output = self.model(
                input_ids=tokens['input_ids'],
                attention_mask=tokens['attention_mask'],
            ).last_hidden_state
        embedding = self.last_token_pool(output, tokens['attention_mask'])
        embedding = F.normalize(embedding, p=2, dim=1)

        # Read only after the forward pass above has run. Pooling reuses the
        # same mask, so this assumes attn_vec is laid out (batch, seq, ...)
        # like the hidden states — NOTE(review): confirm.
        attn_vec = self.extract_attn_vec(self.model)
        attn_vec = self.last_token_pool(attn_vec, tokens['attention_mask'])
        attn_vec = F.normalize(attn_vec, p=2, dim=1)

        # Drop large references and release cached CUDA allocator memory
        # between requests.
        del output, tokens
        torch.cuda.empty_cache()
        return embedding, attn_vec

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed one text and build the endpoint response.

        Args:
            data: request body. Fields are read from ``data["inputs"]`` when that
                key is present, otherwise from ``data`` itself:
                ``"id"`` (echoed back; ``None`` if absent), ``"text"`` (string to
                embed) and ``"type"`` (``'query'`` or ``'document'``).
                ``data`` is not modified.

        Returns:
            ``{"id", "embedding", "attention_vec"}`` for queries and
            ``{"id", "embedding"}`` for documents; vectors are plain Python lists.

        Raises:
            ValueError: if ``"type"`` is missing or not ``'query'``/``'document'``.
            TypeError: if ``"text"`` is missing or not a string.
        """
        inputs = data.get("inputs", data)
        request_id = inputs.get("id")
        text = inputs.get("text")
        request_type = inputs.get("type")

        # Validate up front so a bad request fails fast instead of paying for
        # a forward pass and then producing no response dict.
        if request_type not in ('query', 'document'):
            raise ValueError(f"'type' must be 'query' or 'document', got {request_type!r}")
        if not isinstance(text, str):
            raise TypeError(f"'text' must be a string, got {type(text).__name__}")

        embedding, attn_vec = self.embed(text, request_type)
        response = {"id": request_id, "embedding": embedding[0].tolist()}
        if request_type == 'query':
            response["attention_vec"] = attn_vec[0].tolist()
        return response