"""BERT for sentence-pair boundary classification."""

from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    PreTrainedTokenizerFast,
)

from src.datasets.combined_pairs_dataset import NUM_LABELS, ID2LABEL, LABEL2ID

# Default checkpoint for both the model and its tokenizer. Both loaders use
# the same default so they stay in sync unless a caller overrides it.
BASE_MODEL = "bert-base-uncased"


def load_bert(
    pretrained: str = BASE_MODEL,
) -> BertForSequenceClassification:
    """Instantiate BERT for ``NUM_LABELS``-way sentence-pair classification.

    The label space is taken from the dataset module. ``NUM_LABELS`` sizes
    the classification head. ``ID2LABEL`` / ``LABEL2ID`` are stored on the
    model config so predicted ids can be mapped back to label names.

    Args:
        pretrained: Hugging Face model id or local path forwarded to
            ``BertForSequenceClassification.from_pretrained``.

    Returns:
        The loaded model. When ``pretrained`` is a base checkpoint that has
        no sequence-classification head (such as the default), that head is
        newly initialized and must be fine-tuned before its outputs mean
        anything.
    """
    return BertForSequenceClassification.from_pretrained(
        pretrained,
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )


def load_bert_tokenizer(
    pretrained: str = BASE_MODEL,
) -> PreTrainedTokenizerFast:
    """Load the tokenizer matching ``pretrained``, preferring the fast version.

    Use the same ``pretrained`` value as for ``load_bert`` so that token ids
    line up with the model's vocabulary.

    Args:
        pretrained: Hugging Face model id or local path forwarded to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer. ``use_fast=True`` only *requests* the Rust-backed
        tokenizer. If ``pretrained`` has no fast tokenizer, ``AutoTokenizer``
        silently returns the Python (slow) one, and the return annotation
        does not hold.
    """
    return AutoTokenizer.from_pretrained(pretrained, use_fast=True)