| """DeBERTa-v3-base for sentence-pair boundary classification.""" |
|
|
| from transformers import ( |
| DebertaV2ForSequenceClassification, |
| DebertaV2Tokenizer, |
| PreTrainedTokenizerBase, |
| ) |
|
|
| from src.datasets.combined_pairs_dataset import NUM_LABELS, ID2LABEL, LABEL2ID |
|
|
# Model identifier used as the default ``pretrained`` argument by
# ``load_deberta`` and ``load_deberta_tokenizer`` in this module.
BASE_MODEL = "microsoft/deberta-v3-base"
|
|
|
|
def load_deberta(pretrained: str = BASE_MODEL) -> DebertaV2ForSequenceClassification:
    """Build the DeBERTa sequence classifier used for sentence-pair labels.

    The checkpoint is loaded through
    ``DebertaV2ForSequenceClassification.from_pretrained`` with
    ``num_labels=NUM_LABELS`` and the ``ID2LABEL`` / ``LABEL2ID`` mappings
    imported from the combined-pairs dataset module.

    Args:
        pretrained: Model name or path forwarded to ``from_pretrained``.
            Defaults to ``BASE_MODEL``.

    Returns:
        The model object returned by ``from_pretrained``.
    """
    # Keep the label-related settings together so they are forwarded to
    # ``from_pretrained`` as one unit.
    label_settings = {
        "num_labels": NUM_LABELS,
        "id2label": ID2LABEL,
        "label2id": LABEL2ID,
    }
    model = DebertaV2ForSequenceClassification.from_pretrained(
        pretrained,
        **label_settings,
    )
    return model
|
|
|
|
def load_deberta_tokenizer(pretrained: str = BASE_MODEL) -> PreTrainedTokenizerBase:
    """Load the DeBERTa tokenizer for the given checkpoint.

    Args:
        pretrained: Model name or path forwarded to
            ``DebertaV2Tokenizer.from_pretrained``. Defaults to
            ``BASE_MODEL``.

    Returns:
        The tokenizer object returned by
        ``DebertaV2Tokenizer.from_pretrained``.
    """
    tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained)
    return tokenizer
|
|