Cross-encoder for Word-in-Context (WiC) that combines:
[CLS ; h1 ; h2 ; h1-h2 ; h1*h2]

DeBERTa encoder
↓
h1 = hidden state at <tgt> position in sentence 1
h2 = hidden state at <tgt> position in sentence 2
↓
interaction = [CLS, h1, h2, h1-h2, h1*h2] (5 × 1024)
↓
MLP(5120 → 512 → 2)
↓
loss = CrossEntropy + 0.1 × ContrastiveLoss(h1, h2)
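The pipeline in the diagram can be sketched as a small PyTorch module. This is a minimal illustration, not the released training code: the class name `WiCHead` and its argument names are hypothetical, and random tensors stand in for the DeBERTa output. The layer sizes, GELU activation and dropout rate follow the usage snippet in this card.

```python
import torch
import torch.nn as nn


class WiCHead(nn.Module):
    """Hypothetical head: builds [CLS, h1, h2, h1-h2, h1*h2] and classifies it."""

    def __init__(self, hidden_size: int = 1024, mlp_size: int = 512, dropout: float = 0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(5 * hidden_size, mlp_size),  # 5120 -> 512
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_size, 2),                # 512 -> 2
        )

    def forward(self, hidden, pos1, pos2):
        # hidden: (batch, seq_len, hidden_size) encoder output
        # pos1, pos2: (batch,) index of the <tgt> marker in sentence 1 / sentence 2
        rows = torch.arange(hidden.size(0), device=hidden.device)
        cls = hidden[:, 0]
        h1 = hidden[rows, pos1]
        h2 = hidden[rows, pos2]
        interaction = torch.cat([cls, h1, h2, h1 - h2, h1 * h2], dim=-1)
        return self.mlp(interaction), h1, h2


# Shape check with random tensors standing in for encoder output.
head = WiCHead().eval()
hidden = torch.randn(3, 16, 1024)
pos1 = torch.tensor([1, 2, 3])
pos2 = torch.tensor([9, 10, 11])
logits, h1, h2 = head(hidden, pos1, pos2)
print(logits.shape)  # torch.Size([3, 2])
```

Returning `h1` and `h2` alongside the logits keeps them available for the contrastive term of the loss.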
| Split / model | Accuracy |
|---|---|
| Validation | 0.7555 |
| Test | 0.7257 |
| Baseline DeBERTa | 0.7279 |
| SenseBERT (SOTA) | 0.7210 |
Standard WiC cross-encoders classify from the [CLS] token alone, so that single vector must implicitly encode both sentence meaning and word sense.
This model instead extracts the target word's contextual representation from each sentence at its <tgt> marker, compares the two explicitly (through their difference and element-wise product), and directly optimises sense similarity with an auxiliary contrastive loss weighted at 0.1.
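One way the combined objective might look in PyTorch is sketched below. The card does not specify the form of the contrastive term, so this sketch assumes a cosine embedding loss (pull `h1` and `h2` together for same-sense pairs, push them apart otherwise). The function name `wic_loss` and the `margin` parameter are hypothetical. Label 1 is taken to mean "same sense", matching the usage example in this card.

```python
import torch
import torch.nn.functional as F


def wic_loss(logits, h1, h2, labels, weight=0.1, margin=0.0):
    """Cross-entropy plus a weighted contrastive term (assumed form)."""
    # Classification term over the two classes.
    ce = F.cross_entropy(logits, labels)
    # ASSUMPTION: cosine embedding loss as the contrastive term.
    # It expects targets of +1 (similar) or -1 (dissimilar).
    target = labels.float() * 2 - 1
    contrastive = F.cosine_embedding_loss(h1, h2, target, margin=margin)
    return ce + weight * contrastive


# Toy batch: one same-sense pair (label 1) and one different-sense pair (label 0).
logits = torch.tensor([[0.2, 1.5], [1.0, -0.3]])
labels = torch.tensor([1, 0])
h1 = torch.randn(2, 1024)
h2 = torch.randn(2, 1024)
loss = wic_loss(logits, h1, h2, labels)
print(loss.item())
```

With `weight=0.1` the classification loss dominates and the contrastive term acts as a regulariser on the target-word representations.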
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

# Reconstruct the model: encoder and tokenizer from the Hub, plus the MLP head
tokenizer = AutoTokenizer.from_pretrained("Deehan1866/deberta-wic-contrastive")
encoder = AutoModel.from_pretrained("Deehan1866/deberta-wic-contrastive")
classifier = nn.Sequential(
    nn.Linear(5 * 1024, 512), nn.GELU(), nn.Dropout(0.1), nn.Linear(512, 2)
)
# Load classifier_head.pt separately; until then the head is randomly initialised
encoder.eval()
classifier.eval()  # disable dropout at inference

# Mark the target word with <tgt>...</tgt> in both sentences
s1 = "<tgt>bank</tgt> raised its interest rates."
s2 = "She visited her local <tgt>bank</tgt> to deposit a cheque."
enc = tokenizer(s1, s2, return_tensors="pt", truncation=True, max_length=256)

# Locate the <tgt> marker of each sentence in the joint input
tgt_id = tokenizer.convert_tokens_to_ids("<tgt>")
ids = enc["input_ids"][0].tolist()
positions = [i for i, t in enumerate(ids) if t == tgt_id]
if len(positions) != 2:
    raise ValueError("expected exactly one <tgt> marker per sentence")

with torch.no_grad():
    hidden = encoder(**enc).last_hidden_state
    h1 = hidden[0, positions[0]]  # target word in sentence 1
    h2 = hidden[0, positions[1]]  # target word in sentence 2
    cls = hidden[0, 0]
    interaction = torch.cat([cls, h1, h2, h1 - h2, h1 * h2]).unsqueeze(0)
    logits = classifier(interaction)

pred = logits.argmax(dim=-1).item()
print("Same sense" if pred == 1 else "Different sense")
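The example sentences above already contain the `<tgt>...</tgt>` markers. When starting from raw text, a small helper can insert them. The sketch below assumes the character offsets of the target word are known; the helper name `mark_target` is hypothetical and not part of this repository.

```python
def mark_target(sentence: str, start: int, end: int) -> str:
    """Wrap sentence[start:end] in <tgt>...</tgt> markers (hypothetical helper)."""
    if not 0 <= start < end <= len(sentence):
        raise ValueError("target span is out of range")
    return f"{sentence[:start]}<tgt>{sentence[start:end]}</tgt>{sentence[end:]}"


s1 = mark_target("bank raised its interest rates.", 0, 4)
s2 = mark_target("She visited her local bank to deposit a cheque.", 22, 26)
print(s1)  # <tgt>bank</tgt> raised its interest rates.
print(s2)  # She visited her local <tgt>bank</tgt> to deposit a cheque.
```

Working from character offsets rather than searching for the word avoids marking the wrong occurrence when the target appears more than once in a sentence.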