Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| import os | |
| def load_bert(model_path): | |
| bert = BERT(model_path) | |
| bert.eval() | |
| bert.text_model.training = False | |
| for p in bert.parameters(): | |
| p.requires_grad = False | |
| return bert | |
| class BERT(nn.Module): | |
| def __init__(self, modelpath: str): | |
| super().__init__() | |
| from transformers import AutoTokenizer, AutoModel | |
| from transformers import logging | |
| logging.set_verbosity_error() | |
| # Tokenizer | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # Tokenizer | |
| self.tokenizer = AutoTokenizer.from_pretrained(modelpath) | |
| # Text model | |
| self.text_model = AutoModel.from_pretrained(modelpath) | |
| def forward(self, texts): | |
| encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) | |
| output = self.text_model(**encoded_inputs.to(self.text_model.device)).last_hidden_state | |
| mask = encoded_inputs.attention_mask.to(dtype=bool) | |
| # output = output * mask.unsqueeze(-1) | |
| return output, mask | |