"""DeBERTa-v3-base for sentence-pair boundary classification.""" from transformers import ( DebertaV2ForSequenceClassification, DebertaV2Tokenizer, PreTrainedTokenizerBase, ) from src.datasets.combined_pairs_dataset import NUM_LABELS, ID2LABEL, LABEL2ID BASE_MODEL = "microsoft/deberta-v3-base" def load_deberta( pretrained: str = BASE_MODEL, ) -> DebertaV2ForSequenceClassification: """Instantiate DeBERTa-v3-base for 3-class sentence-pair classification.""" return DebertaV2ForSequenceClassification.from_pretrained( pretrained, num_labels=NUM_LABELS, id2label=ID2LABEL, label2id=LABEL2ID, ) def load_deberta_tokenizer( pretrained: str = BASE_MODEL, ) -> PreTrainedTokenizerBase: return DebertaV2Tokenizer.from_pretrained(pretrained)