"""DistilBERT for sentence-pair boundary classification."""

from transformers import (
    AutoTokenizer,
    DistilBertForSequenceClassification,
    PreTrainedTokenizerFast,
)

from src.datasets.combined_pairs_dataset import NUM_LABELS, ID2LABEL, LABEL2ID

# Default checkpoint (Hub id or local path) used by both loaders below.
BASE_MODEL = "distilbert-base-uncased"


def load_distilbert(
    pretrained: str = BASE_MODEL,
) -> DistilBertForSequenceClassification:
    """Instantiate DistilBERT for sentence-pair classification.

    The classification head is sized from ``NUM_LABELS`` and the config's
    label mappings come from ``ID2LABEL`` / ``LABEL2ID``, all imported from
    the combined-pairs dataset module, so the model stays in sync with the
    dataset's label set.

    Args:
        pretrained: Model id or checkpoint path forwarded to
            ``DistilBertForSequenceClassification.from_pretrained``.

    Returns:
        The loaded ``DistilBertForSequenceClassification`` model.
    """
    return DistilBertForSequenceClassification.from_pretrained(
        pretrained,
        num_labels=NUM_LABELS,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
    )


def load_distilbert_tokenizer(
    pretrained: str = BASE_MODEL,
) -> PreTrainedTokenizerFast:
    """Load the fast tokenizer that matches ``pretrained``.

    Args:
        pretrained: Model id or checkpoint path forwarded to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``PreTrainedTokenizerFast`` instance.

    Raises:
        TypeError: If ``AutoTokenizer`` could only provide a slow
            (non-fast) tokenizer for ``pretrained``.
    """
    # ``use_fast=True`` is only a preference: AutoTokenizer silently falls
    # back to the slow tokenizer when no fast one is available. Fail loudly
    # instead so the declared return type actually holds for callers.
    tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=True)
    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        raise TypeError(
            f"No fast tokenizer available for {pretrained!r}; "
            f"got {type(tokenizer).__name__} instead."
        )
    return tokenizer