DeBERTa-v3-large: Target-Word Contrastive WiC

Cross-encoder for Word-in-Context (WiC) that combines:

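  1. Explicit target-word interaction features: [CLS ; h1 ; h2 ; h1-h2 ; h1*h2], where h1 and h2 are the target word's hidden states in the two sentences and * is the element-wise product
  2. A supervised contrastive loss on the target-word representations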
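
The feature vector in item 1 is five equal-length vectors concatenated. The NumPy sketch below is illustrative and is not the model's own code. It assumes 1-D vectors of the encoder hidden size (1024 for DeBERTa-v3-large) and treats h1*h2 as the element-wise product. The function name `interaction_features` is hypothetical.

```python
import numpy as np

def interaction_features(cls: np.ndarray, h1: np.ndarray, h2: np.ndarray) -> np.ndarray:
    """Build [CLS ; h1 ; h2 ; h1-h2 ; h1*h2] for one example (hypothetical helper).

    Each input is a 1-D vector of hidden size d; the output has length 5 * d.
    """
    assert cls.shape == h1.shape == h2.shape and cls.ndim == 1
    # h1 - h2 captures how the two senses differ; h1 * h2 (element-wise) where they agree
    return np.concatenate([cls, h1, h2, h1 - h2, h1 * h2])

d = 1024  # DeBERTa-v3-large hidden size
rng = np.random.default_rng(0)
cls, h1, h2 = (rng.standard_normal(d) for _ in range(3))
feats = interaction_features(cls, h1, h2)
print(feats.shape)  # (5120,)
```

With d = 1024 this gives the 5 × 1024 = 5120 inputs consumed by the MLP head.
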

Architecture

DeBERTa encoder
    ↓
h1 = hidden state at <tgt> position in sentence 1
h2 = hidden state at <tgt> position in sentence 2
    ↓
interaction = [CLS, h1, h2, h1-h2, h1*h2]  (5 × 1024 = 5120 dims)
    ↓
MLP(5120 → 512 → 2)
    ↓
loss = CrossEntropy + 0.1 × ContrastiveLoss(h1, h2)
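
The training objective in the last step can be sketched in NumPy as follows. The card does not specify the exact form of ContrastiveLoss. This sketch ASSUMES a simple pairwise cosine loss in which same-sense pairs (label 1) are pulled toward cosine similarity 1 and different-sense pairs (label 0) are penalised only while their cosine exceeds a margin. The margin of 0.5 and all function names are hypothetical. Only the 0.1 weight comes from the diagram.

```python
import numpy as np

def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean softmax cross-entropy; logits has shape (batch, 2), labels are 0/1 ints."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

def pair_contrastive(h1: np.ndarray, h2: np.ndarray, labels: np.ndarray,
                     margin: float = 0.5) -> float:
    """ASSUMED pairwise contrastive form (not confirmed by the model card).

    Label 1 (same sense): loss 1 - cos(h1, h2).
    Label 0 (different sense): loss max(0, cos(h1, h2) - margin).
    """
    cos = (h1 * h2).sum(axis=1) / (np.linalg.norm(h1, axis=1) * np.linalg.norm(h2, axis=1))
    per_pair = np.where(labels == 1, 1.0 - cos, np.maximum(0.0, cos - margin))
    return float(per_pair.mean())

def total_loss(logits, h1, h2, labels, weight: float = 0.1) -> float:
    # CrossEntropy + 0.1 x ContrastiveLoss(h1, h2)
    return cross_entropy(logits, labels) + weight * pair_contrastive(h1, h2, labels)

# Toy batch of two pairs with 2-d "hidden states"; label 1 = same sense
logits = np.array([[0.0, 0.0], [2.0, 0.0]])
h1 = np.array([[1.0, 0.0], [1.0, 0.0]])
h2 = np.array([[1.0, 0.0], [0.0, 1.0]])
labels = np.array([1, 0])
loss = total_loss(logits, h1, h2, labels)
print(round(loss, 4))
```

In this toy batch both pairs already satisfy the contrastive term (cosine 1 for the same-sense pair, cosine 0 for the different-sense pair), so the loss reduces to the cross-entropy alone. A batch-level supervised contrastive loss (SupCon) would be a different, equally plausible reading of "supervised contrastive loss".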

Performance

Split / model              Accuracy
Validation (this model)    0.7555
Test (this model)          0.7257
Baseline DeBERTa           0.7279
SenseBERT (SOTA)           0.7210

Key Idea

Standard WiC cross-encoders classify from the [CLS] token alone, so that single vector must implicitly encode both the meaning of the sentence pair and the sense of the target word. This model instead extracts the target word's contextual representation from each sentence, compares the two explicitly through the interaction features, and directly optimises their sense similarity with the contrastive loss.

Usage

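import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

REPO = "Deehan1866/deberta-wic-contrastive"
# Local copy of the separately distributed head weights. Assumption: the file
# holds a state_dict for exactly the nn.Sequential layout defined below.
HEAD_PATH = "classifier_head.pt"

# Reconstruct the model. The tokenizer must already contain the <tgt> marker token.
tokenizer = AutoTokenizer.from_pretrained(REPO)
encoder = AutoModel.from_pretrained(REPO).eval()

# Classification head: MLP(5 * 1024 -> 512 -> 2)
classifier = nn.Sequential(
    nn.Linear(5 * 1024, 512), nn.GELU(), nn.Dropout(0.1), nn.Linear(512, 2)
)
classifier.load_state_dict(torch.load(HEAD_PATH, map_location="cpu"))
classifier.eval()  # disable dropout at inference time

# The target word ("bank") is wrapped in <tgt> ... </tgt> in both sentences
s1 = "<tgt>bank</tgt> raised its interest rates."
s2 = "She visited her local <tgt>bank</tgt> to deposit a cheque."

enc = tokenizer(s1, s2, return_tensors="pt", truncation=True, max_length=256)
tgt_id = tokenizer.convert_tokens_to_ids("<tgt>")
assert tgt_id != tokenizer.unk_token_id, "tokenizer does not know the <tgt> token"

# One <tgt> marker per sentence; truncation could cut the second one off
ids = enc["input_ids"][0].tolist()
positions = [i for i, t in enumerate(ids) if t == tgt_id]
assert len(positions) == 2, f"expected 2 <tgt> markers, found {len(positions)}"

with torch.no_grad():
    hidden = encoder(**enc).last_hidden_state  # (1, seq_len, 1024)
    cls = hidden[0, 0]                         # [CLS] representation
    h1 = hidden[0, positions[0]]               # target word in sentence 1
    h2 = hidden[0, positions[1]]               # target word in sentence 2
    interaction = torch.cat([cls, h1, h2, h1 - h2, h1 * h2]).unsqueeze(0)  # (1, 5120)
    logits = classifier(interaction)
    pred = logits.argmax(dim=-1).item()        # label 1 = same sense

print("Same sense" if pred == 1 else "Different sense")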
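
The example above hard-codes the <tgt> markers. When the data gives the target's position as a whitespace-token index, the markers can be inserted programmatically. The helper below is a hypothetical sketch, not part of this repository. It assumes single-space-tokenised sentences and the same <tgt>word</tgt> formatting used in the example.

```python
def mark_token(sentence: str, index: int,
               open_tag: str = "<tgt>", close_tag: str = "</tgt>") -> str:
    """Wrap the whitespace-separated token at `index` in target markers (hypothetical helper)."""
    tokens = sentence.split()
    if not 0 <= index < len(tokens):
        raise IndexError(f"token index {index} out of range for {len(tokens)} tokens")
    tokens[index] = f"{open_tag}{tokens[index]}{close_tag}"
    return " ".join(tokens)

s1 = mark_token("bank raised its interest rates.", 0)
s2 = mark_token("She visited her local bank to deposit a cheque.", 4)
print(s1)  # <tgt>bank</tgt> raised its interest rates.
print(s2)  # She visited her local <tgt>bank</tgt> to deposit a cheque.
```

Because it works on whitespace tokens, punctuation attached to the target (for example "bank,") ends up inside the markers. Runs of whitespace are also collapsed to single spaces.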

Model size: 0.4B params
Tensor type: F32 (Safetensors)