Paper: *Lessons from Natural Language Inference in the Clinical Domain* (arXiv:1808.06752)
This model is a fine-tuned version of answerdotai/ModernBERT-base on the MedNLI dataset for clinical natural language inference. Results on the MedNLI test set and zero-shot transfer to the RadNLI test set:
| Dataset | Accuracy | Macro F1 | Macro Precision | Macro Recall |
|---|---|---|---|---|
| MedNLI test | 81.4% | 81.4% | 81.5% | 81.4% |
| RadNLI test (zero-shot) | 52.1% | 54.9% | 63.1% | 66.9% |
Comparison with the base model without MedNLI fine-tuning:

| Model | MedNLI Acc | MedNLI F1 | RadNLI Acc | RadNLI F1 |
|---|---|---|---|---|
| ModernBERT-base (zero-shot) | 41.2% | 38.7% | 48.5% | 47.0% |
| This model (fine-tuned) | 81.4% | 81.4% | 52.1% | 54.9% |
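The macro-averaged figures above are unweighted means of the per-class precision, recall, and F1. As a minimal pure-Python sketch of how such a metric is computed (the labels and predictions below are made up for illustration, not the actual evaluation data):

```python
LABELS = ["entailment", "neutral", "contradiction"]

def macro_f1(y_true: list[str], y_pred: list[str]) -> float:
    """Unweighted mean of per-class F1 scores."""
    f1s = []
    for label in LABELS:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall) if precision + recall else 0.0)
    return sum(f1s) / len(f1s)

# Toy example: 4 of 6 predictions correct
y_true = ["entailment", "neutral", "contradiction", "entailment", "neutral", "contradiction"]
y_pred = ["entailment", "neutral", "contradiction", "neutral", "neutral", "entailment"]
print(round(macro_f1(y_true, y_pred), 3))  # 0.656
```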
```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

premise = "The patient was diagnosed with pneumonia."
hypothesis = "The patient has a respiratory infection."

inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True, max_length=256)
outputs = model(**inputs)

labels = ["entailment", "neutral", "contradiction"]
prediction = labels[outputs.logits.argmax().item()]
print(prediction)  # entailment
```
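Beyond the argmax label, the logits can be turned into per-class probabilities with a softmax. A pure-Python sketch with dummy logit values standing in for the model output (assuming the label order above):

```python
import math

labels = ["entailment", "neutral", "contradiction"]
logits = [2.1, 0.3, -1.5]  # dummy values standing in for outputs.logits[0]

# Softmax: exponentiate, then normalize to sum to 1
exps = [math.exp(x) for x in logits]
probs = [e / sum(exps) for e in exps]

for label, p in zip(labels, probs):
    print(f"{label}: {p:.3f}")
```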
Use entailment scores to classify clinical text without task-specific training:
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

ENTAILMENT_IDX = 0  # entailment is the first class

def classify(text: str, candidate_labels: list[str]) -> dict[str, float]:
    """Zero-shot classification using NLI entailment scores."""
    scores = {}
    for label in candidate_labels:
        hypothesis = f"This patient has {label}."
        inputs = tokenizer(text, hypothesis, return_tensors="pt", truncation=True, max_length=256)
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        scores[label] = probs[0, ENTAILMENT_IDX].item()
    # Normalize entailment scores so they sum to 1
    total = sum(scores.values())
    return {k: v / total for k, v in scores.items()}

# Example: binary classification for DVT/PE
clinical_note = (
    "Patient presents with acute shortness of breath and pleuritic chest pain. "
    "CT angiography reveals filling defects in the pulmonary arteries."
)
labels = [
    "Deep Venous Thromboembolism or Pulmonary Embolism",
    "No Deep Venous Thromboembolism or Pulmonary Embolism",
]
result = classify(clinical_note, labels)
print(result)
# {'Deep Venous Thromboembolism or Pulmonary Embolism': 0.92, 'No Deep Venous Thromboembolism or Pulmonary Embolism': 0.08}
```
For processing multiple texts efficiently:
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

ENTAILMENT_IDX = 0

def classify_batch(texts: list[str], candidate_labels: list[str]) -> list[dict[str, float]]:
    """Batch zero-shot classification."""
    # Build all premise-hypothesis pairs
    premises, hypotheses = [], []
    for text in texts:
        for label in candidate_labels:
            premises.append(text)
            hypotheses.append(f"This patient has {label}.")
    inputs = tokenizer(premises, hypotheses, return_tensors="pt", truncation=True, padding=True, max_length=256)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)
    entailment_scores = probs[:, ENTAILMENT_IDX].tolist()
    # Reshape into one score dict per input text and normalize
    results = []
    n_labels = len(candidate_labels)
    for i in range(len(texts)):
        scores = dict(zip(candidate_labels, entailment_scores[i * n_labels:(i + 1) * n_labels]))
        total = sum(scores.values())
        results.append({k: v / total for k, v in scores.items()})
    return results

# Example
notes = [
    "Patient with swelling in left leg and positive D-dimer.",
    "Routine checkup, no complaints, vitals normal.",
]
labels = ["DVT", "No DVT"]
results = classify_batch(notes, labels)
for note, result in zip(notes, results):
    print(f"{note} -> {result}")
```
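The normalized score dicts can be turned into a final decision with a simple rule. A minimal sketch; the 0.6 abstention threshold here is illustrative, not a property of the model, and should be tuned on your own data:

```python
def decide(scores: dict[str, float], min_confidence: float = 0.6) -> str:
    """Pick the top-scoring label, abstaining when confidence is low."""
    label, score = max(scores.items(), key=lambda kv: kv[1])
    return label if score >= min_confidence else "uncertain"

print(decide({"DVT": 0.92, "No DVT": 0.08}))  # DVT
print(decide({"DVT": 0.55, "No DVT": 0.45}))  # uncertain
```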
This model is intended for clinical natural language inference over premise-hypothesis pairs (entailment, neutral, contradiction) and for zero-shot classification of clinical text via entailment scores, as shown above.
If you use this model, please cite MedNLI:
```bibtex
@article{romanov2018lessons,
  title={Lessons from natural language inference in the clinical domain},
  author={Romanov, Alexey and Shivade, Chaitanya},
  journal={arXiv preprint arXiv:1808.06752},
  year={2018}
}
```
And ModernBERT:
```bibtex
@article{modernbert,
  title={Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference},
  author={Warner, Benjamin and Chaffin, Antoine and Clavi{\'e}, Benjamin and Weller, Orion and Hallstr{\"o}m, Oskar and Taghadouini, Said and Gallagher, Alexis and Biswas, Raja and Ladhak, Faisal and Aarsen, Tom and Cooper, Nathan and Adams, Griffin and Howard, Jeremy and Poli, Iacopo},
  journal={arXiv preprint arXiv:2412.13663},
  year={2024}
}
```
Base model: answerdotai/ModernBERT-base