| from everything import * | |
| from bert import BertModel | |
def get_finetuned_bert(mode: str):
    """Build a ``bert-base-uncased`` BertModel and load a fine-tuned checkpoint into it.

    Args:
        mode: ``'sup'`` loads the state dict stored at ``SUP_BERT``;
            ``'unsup'`` loads the one stored at ``UNSUP_BERT``.

    Returns:
        The model with the checkpoint's weights loaded, moved to CUDA when
        ``USE_GPU`` is truthy and to the CPU otherwise.

    Raises:
        ValueError: if ``mode`` is neither ``'sup'`` nor ``'unsup'``.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently route any bad mode to 'unsup'.
    if mode not in ('sup', 'unsup'):
        raise ValueError(f"mode must be 'sup' or 'unsup', got {mode!r}")
    checkpoint_path = SUP_BERT if mode == 'sup' else UNSUP_BERT
    device = torch.device('cuda') if USE_GPU else torch.device('cpu')

    bert = BertModel.from_pretrained('bert-base-uncased')
    # Deserialize onto the CPU: a checkpoint saved from GPU tensors cannot be
    # loaded on a CUDA-less host without map_location. The whole model is then
    # moved to the target device once, below.
    state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    bert.load_state_dict(state_dict)
    return bert.to(device)