Spaces:
Runtime error
Runtime error
| from functools import lru_cache | |
| import torch | |
| from loguru import logger | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModel | |
# Compute device: CUDA when torch reports it as available, otherwise CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# Model identifiers known to this module. The list is not referenced by the
# visible code; presumably it is offered as a set of choices by the app that
# imports this file — confirm with the caller.
list_models = [
    # sentence-transformers checkpoints
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    # other checkpoints
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
    'Qwen/Qwen3-Embedding-0.6B',
]
class SBert:
    """Thin wrapper around a ``SentenceTransformer`` loaded on ``DEVICE``."""

    def __init__(self, path):
        # Log before and after construction so slow model loads are visible.
        cls = self.__class__
        logger.info(f'Start loading {cls} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        logger.info(f'Load {cls} from {path} ...')

    def __call__(self, x) -> torch.Tensor:
        """Encode ``x`` with the wrapped model and return the result as a tensor."""
        return self.model.encode(x, convert_to_tensor=True)
class ModelWithPooling:
    """Embed text with a Hugging Face ``AutoModel`` plus an explicit pooling step.

    The tokenizer and model are loaded from ``path`` via ``AutoTokenizer`` /
    ``AutoModel``; the model is placed on the module-level ``DEVICE`` (the same
    device ``SBert`` uses) and switched to eval mode.
    """

    # Pooling names accepted by ``__call__``; checked before the forward pass.
    _POOLINGS = ('cls', 'pooler', 'mean', 'last-avg', 'first-last-avg')

    def __init__(self, path):
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path).to(DEVICE)
        self.model.eval()
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average ``hidden`` ([b, s, h]) over the sequence axis, skipping padding.

        ``mask`` is the tokenizer's attention mask ([b, s], 1 = real token). When
        it is ``None``, every position is averaged.
        """
        if mask is None:
            return hidden.mean(dim=1)
        weights = mask.unsqueeze(-1).to(hidden.dtype)  # [b, s, 1]
        # clamp guards against division by zero for an all-padding row.
        return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

    def __call__(self, text: str, pooling='mean'):
        """Encode ``text`` and pool token representations into sentence vectors.

        Args:
            text: Input passed to the tokenizer (padding and truncation enabled).
            pooling: ``'cls'`` (first token of the last layer), ``'pooler'``
                (the model's ``pooler_output``), ``'mean'`` / ``'last-avg'``
                (padding-aware mean of the last layer) or ``'first-last-avg'``
                (mean of the padding-aware averages of ``hidden_states[1]`` and
                the last layer).

        Returns:
            Tensor of shape ``[h]`` when the batch has a single element,
            otherwise ``[b, h]`` (only a leading batch dim of size 1 is squeezed).

        Raises:
            ValueError: Unknown ``pooling`` name, or ``'pooler'`` requested on a
                model whose output carries no ``pooler_output``.
        """
        # Fail fast, before paying for tokenization and a forward pass.
        if pooling not in self._POOLINGS:
            raise ValueError(f'Unknown pooling {pooling}')

        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        inputs = inputs.to(self.model.device)
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            # Intermediate layers are only needed for 'first-last-avg'.
            outputs = self.model(**inputs, output_hidden_states=pooling == 'first-last-avg')

        # With padding=True, batched inputs contain pad positions that must not
        # be averaged in; the attention mask marks the real tokens.
        mask = inputs.get('attention_mask')

        if pooling == 'cls':
            o = outputs.last_hidden_state[:, 0]  # [b, h]
        elif pooling == 'pooler':
            o = getattr(outputs, 'pooler_output', None)  # [b, h]
            if o is None:
                raise ValueError(f'Model returns no pooler_output; pooling {pooling} unsupported')
        elif pooling in ('mean', 'last-avg'):
            o = self._masked_mean(outputs.last_hidden_state, mask)  # [b, h]
        else:  # 'first-last-avg'
            # hidden_states[0] is the embedding output, so index 1 is the first layer.
            first = self._masked_mean(outputs.hidden_states[1], mask)  # [b, h]
            last = self._masked_mean(outputs.hidden_states[-1], mask)  # [b, h]
            o = (first + last) / 2  # [b, h]
        return o.squeeze(0)
def test_sbert():
    """Smoke test: SBert on one sentence returns a vector of size 768.

    Loads the 'bert-base-chinese' checkpoint by name, which may need network access.
    """
    encoder = SBert('bert-base-chinese')
    shape = encoder('hello').size()
    print(shape)
    assert shape == (768,)
def test_hf_model():
    """Smoke test: CLS pooling on one sentence returns a vector of size 768.

    Loads the 'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese' checkpoint by name,
    which may need network access.
    """
    encoder = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    shape = encoder('hello', pooling='cls').size()
    print(shape)
    assert shape == (768,)