import os
import pickle
from typing import Any, Dict, List

import numpy as np
import torch
from transformers import BertModel, BertTokenizer
def unpickle_obj(filepath):
    """Deserialize and return the object stored in the pickle file *filepath*.

    After loading, a confirmation line naming the file is printed to stdout.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file, so only pass paths to trusted artifacts.
    """
    with open(filepath, mode="rb") as stream:
        loaded = pickle.load(stream)
    print(f"unpickled {filepath}")
    return loaded
class EndpointHandler:
    """Inference handler that scores every (text, query) pair in a request.

    Texts and queries are embedded with a BERT encoder. The classifier
    unpickled from ``bert_lr.pkl`` is then applied to each
    ``text_embedding - query_embedding`` difference vector via
    ``predict_proba``.
    """

    def __init__(self, path=""):
        """Load the classifier, tokenizer and BERT encoder from *path*.

        Args:
            path: Location holding ``bert_lr.pkl`` and the BERT
                tokenizer/model files.
        """
        # NOTE(review): unpickling can execute arbitrary code; *path* must
        # only ever point at a trusted model artifact.
        self.model = unpickle_obj(os.path.join(path, "bert_lr.pkl"))
        self.tokenizer = BertTokenizer.from_pretrained(path, local_files_only=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # local_files_only matches the tokenizer: both are read from *path*.
        self.bert = BertModel.from_pretrained(path, local_files_only=True).to(self.device)

    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Return one mean-pooled BERT embedding per input text.

        Args:
            texts: Strings to embed. The tokenizer treats a single string as a
                batch of one.

        Returns:
            A CPU numpy array with one row per text.
        """
        if not isinstance(texts, str) and len(texts) == 0:
            return np.empty((0, self.bert.config.hidden_size), dtype=np.float32)
        inputs = self.tokenizer(texts, return_tensors="pt", truncation=True,
                                padding=True, max_length=512).to(self.device)
        with torch.no_grad():
            hidden = self.bert(**inputs).last_hidden_state
        # Average over real tokens only. A plain hidden.mean(dim=1) also
        # averages the padding positions added by padding=True, which makes a
        # text's embedding depend on the lengths of the other texts in the
        # same batch.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)
        return (summed / counts).cpu().numpy()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score every (text, query) pair in the request payload.

        Args:
            data: Request payload. ``queries`` and ``texts`` are read from
                ``data["inputs"]`` when that key exists, otherwise from
                ``data`` itself. The payload is not modified.

        Returns:
            ``[{"outputs": rows}]``, where ``rows`` holds one ``predict_proba``
            row per pair in text-major order: the row for ``texts[t]`` and
            ``queries[q]`` is at index ``t * len(queries) + q``.
        """
        # .get instead of .pop so the caller's dict is left untouched.
        inputs = data.get("inputs", data)
        queries = inputs["queries"]
        texts = inputs["texts"]
        # The tokenizer treats a bare string as a batch of one; mirror that so
        # the emptiness check below only fires for genuinely empty batches.
        if isinstance(queries, str):
            queries = [queries]
        if isinstance(texts, str):
            texts = [texts]
        if not queries or not texts:
            # Nothing to score; return early instead of handing the
            # classifier an empty matrix.
            return [{"outputs": []}]
        queries_vec = np.asarray(self.get_embeddings(queries))
        texts_vec = np.asarray(self.get_embeddings(texts))
        # Broadcast (T, 1, H) - (Q, H) -> (T, Q, H), then flatten to one row
        # per (text, query) pair, text-major.
        diff = (texts_vec[:, np.newaxis] - queries_vec).reshape(-1, queries_vec.shape[-1])
        return [{"outputs": self.model.predict_proba(diff).tolist()}]