| from functools import lru_cache |
|
|
| import torch |
| from loguru import logger |
| from sentence_transformers import SentenceTransformer |
| from transformers import AutoTokenizer, AutoModel |
|
|
# Run on GPU when one is available; every model in this module is loaded onto
# this device (SentenceTransformer explicitly, AutoModel implicitly on CPU).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


# Candidate embedding checkpoints (Hugging Face hub IDs) this module is meant
# to be used with. Not consumed programmatically in this file — reference only.
list_models = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]
|
|
|
|
class SBert:
    """Thin wrapper around :class:`SentenceTransformer` with an LRU cache.

    Calling the instance with a sentence returns its embedding as a
    ``torch.Tensor`` on ``DEVICE``; repeated calls with the same input are
    served from a per-instance cache.
    """

    def __init__(self, path):
        """Load a SentenceTransformer checkpoint from *path* onto DEVICE."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        # Bind the LRU cache per instance instead of decorating __call__ at
        # class level: a class-level lru_cache keys on `self` and keeps every
        # instance (and its model) alive for the cache's lifetime (ruff B019).
        self._encode = lru_cache(maxsize=10000)(self._encode_uncached)
        logger.info(f'Loaded {self.__class__} from {path} ...')

    def _encode_uncached(self, x):
        """Encode *x* without caching; `self._encode` is the cached variant."""
        return self.model.encode(x, convert_to_tensor=True)

    @lru_cache(maxsize=10000)
    def __call__(self, x) -> torch.Tensor:
        """Return the embedding of *x*, using the per-instance cache."""
        return self._encode(x)
|
|
|
|
class ModelWithPooling:
    """Sentence encoder built from a raw HF ``AutoModel`` plus a pooling step.

    Calling the instance with a text returns a 1-D embedding tensor of size
    ``hidden_dim``. The pooling strategy is selectable per call; results are
    memoized in a per-instance LRU cache keyed on ``(text, pooling)``.
    """

    def __init__(self, path):
        """Load tokenizer and model from the checkpoint at *path*."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        # Per-instance cache instead of a class-level @lru_cache on __call__,
        # which would key on `self` and pin every instance in memory (B019).
        self._encode = lru_cache(maxsize=100)(self._encode_uncached)
        logger.info(f'Loaded {self.__class__} from {path} ...')

    def __call__(self, text: str, pooling='mean'):
        """Return the pooled embedding of *text* (cached).

        :param text: input sentence.
        :param pooling: one of 'cls', 'pooler', 'mean'/'last-avg',
            'first-last-avg'.
        :raises ValueError: on an unknown pooling name.
        """
        return self._encode(text, pooling)

    @torch.no_grad()
    def _encode_uncached(self, text: str, pooling: str) -> torch.Tensor:
        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        outputs = self.model(**inputs, output_hidden_states=True)

        if pooling == 'cls':
            # Hidden state of the first ([CLS]) token of the last layer.
            o = outputs.last_hidden_state[:, 0]

        elif pooling == 'pooler':
            # The model's own pooled output (tanh-projected [CLS] for BERT).
            o = outputs.pooler_output

        elif pooling in ['mean', 'last-avg']:
            # Average the last layer over the sequence dimension.
            # NOTE(review): averages over all tokens incl. padding/specials —
            # fine for single unpadded inputs, as used here.
            last = outputs.last_hidden_state.transpose(1, 2)
            o = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)

        elif pooling == 'first-last-avg':
            # Average of (first transformer layer mean, last layer mean).
            # hidden_states[0] is the embedding layer, so [1] is layer one.
            first = outputs.hidden_states[1].transpose(1, 2)
            last = outputs.hidden_states[-1].transpose(1, 2)
            # Pool each tensor with its OWN sequence length (the original
            # used last.shape[-1] for both; identical today, but fragile).
            first_avg = torch.avg_pool1d(first, kernel_size=first.shape[-1]).squeeze(-1)
            last_avg = torch.avg_pool1d(last, kernel_size=last.shape[-1]).squeeze(-1)
            avg = torch.cat((first_avg.unsqueeze(1), last_avg.unsqueeze(1)), dim=1)
            o = torch.avg_pool1d(avg.transpose(1, 2), kernel_size=2).squeeze(-1)

        else:
            # ValueError is more precise than bare Exception and still caught
            # by any caller handling Exception.
            raise ValueError(f'Unknown pooling {pooling}')

        # Drop the batch dimension: (1, hidden) -> (hidden,).
        o = o.squeeze(0)
        return o
|
|
|
|
def test_sbert():
    """Smoke test: a single sentence embeds to a 768-dimensional vector."""
    model = SBert('bert-base-chinese')
    embedding = model('hello')
    print(embedding.size())
    assert embedding.size() == (768,)
|
|
|
|
def test_hf_model():
    """Smoke test: CLS pooling on the HF model yields a 768-dim vector."""
    model = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    embedding = model('hello', pooling='cls')
    print(embedding.size())
    assert embedding.size() == (768,)
|
|