
ModernBERT-base fine-tuned on MedNLI

This model is a fine-tuned version of answerdotai/ModernBERT-base on the MedNLI dataset for clinical natural language inference.

Model Description

  • Base model: ModernBERT-base (149M parameters)
  • Task: Natural Language Inference (3-class: entailment, neutral, contradiction)
  • Domain: Clinical/Medical text
  • Training data: MedNLI (11,232 training examples)

Performance

| Dataset | Accuracy | Macro F1 | Macro Precision | Macro Recall |
|---|---|---|---|---|
| MedNLI test | 81.4% | 81.4% | 81.5% | 81.4% |
| RadNLI test (zero-shot) | 52.1% | 54.9% | 63.1% | 66.9% |

Comparison with base model

| Model | MedNLI Acc | MedNLI F1 | RadNLI Acc | RadNLI F1 |
|---|---|---|---|---|
| ModernBERT-base (zero-shot) | 41.2% | 38.7% | 48.5% | 47.0% |
| This model (fine-tuned) | 81.4% | 81.4% | 52.1% | 54.9% |
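Macro F1 here is the unweighted mean of per-class F1 over the three NLI classes. A plain-Python sketch of that computation (the labels below are illustrative, not real MedNLI predictions):

```python
def macro_f1(y_true: list[int], y_pred: list[int], n_classes: int = 3) -> float:
    """Unweighted mean of per-class F1 scores."""
    f1s = []
    for c in range(n_classes):
        # Per-class counts: treat class c as positive, everything else as negative
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * precision * recall / (precision + recall) if precision + recall else 0.0)
    return sum(f1s) / n_classes

# Illustrative labels only (0=entailment, 1=neutral, 2=contradiction)
y_true = [0, 1, 2, 0, 1]
y_pred = [0, 1, 2, 0, 2]
print(round(macro_f1(y_true, y_pred), 3))  # 0.778
```

Unlike micro-averaging, this weights each class equally regardless of how often it occurs, which matters when the NLI classes are imbalanced.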

Training Details

Hyperparameters

  • Learning rate: 8e-5
  • Epochs: 2
  • Batch size: 32
  • Weight decay: 8e-6
  • LR scheduler: linear
  • Optimizer: AdamW (beta1=0.9, beta2=0.98, epsilon=1e-6)
  • Max sequence length: 256
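For reference, these hyperparameters map onto Hugging Face `TrainingArguments` as sketched below. This is an assumption for illustration; the card does not state that the `Trainer` API was used, and `output_dir` is a made-up name:

```python
from transformers import TrainingArguments

# Sketch only: the reported hyperparameters expressed as TrainingArguments
args = TrainingArguments(
    output_dir="modernbert-base-mednli",  # hypothetical path
    learning_rate=8e-5,
    num_train_epochs=2,
    per_device_train_batch_size=32,
    weight_decay=8e-6,
    lr_scheduler_type="linear",
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
)
```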

Hardware

  • Single GPU training

Usage

Natural Language Inference

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

premise = "The patient was diagnosed with pneumonia."
hypothesis = "The patient has a respiratory infection."

inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True, max_length=256)
with torch.no_grad():
    outputs = model(**inputs)

labels = ["entailment", "neutral", "contradiction"]
prediction = labels[outputs.logits.argmax().item()]
print(prediction)  # entailment
```

Zero-Shot Text Classification

Use entailment scores to classify clinical text without task-specific training:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

ENTAILMENT_IDX = 0  # entailment is the first class

def classify(text: str, candidate_labels: list[str]) -> dict[str, float]:
    """Zero-shot classification using NLI entailment scores."""
    scores = {}
    for label in candidate_labels:
        hypothesis = f"This patient has {label}."
        inputs = tokenizer(text, hypothesis, return_tensors="pt", truncation=True, max_length=256)

        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            scores[label] = probs[0, ENTAILMENT_IDX].item()

    # Normalize scores
    total = sum(scores.values())
    return {k: v / total for k, v in scores.items()}

# Example: Binary classification for DVT/PE
clinical_note = "Patient presents with acute shortness of breath and pleuritic chest pain. CT angiography reveals filling defects in the pulmonary arteries."
labels = ["Deep Venous Thromboembolism or Pulmonary Embolism", "No Deep Venous Thromboembolism or Pulmonary Embolism"]

result = classify(clinical_note, labels)
print(result)
# {'Deep Venous Thromboembolism or Pulmonary Embolism': 0.92, 'No Deep Venous Thromboembolism or Pulmonary Embolism': 0.08}
```
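The final step in `classify` simply rescales the per-label entailment scores to sum to 1; the result is a relative ranking over candidate labels, not jointly calibrated probabilities. The arithmetic in isolation, with made-up unnormalized scores:

```python
def normalize(scores: dict[str, float]) -> dict[str, float]:
    """Rescale raw entailment scores so they sum to 1."""
    total = sum(scores.values())
    return {k: v / total for k, v in scores.items()}

raw = {"DVT/PE": 3.0, "No DVT/PE": 1.0}  # made-up unnormalized scores
print(normalize(raw))  # {'DVT/PE': 0.75, 'No DVT/PE': 0.25}
```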

Batch Classification

For processing multiple texts efficiently:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

ENTAILMENT_IDX = 0

def classify_batch(texts: list[str], candidate_labels: list[str]) -> list[dict[str, float]]:
    """Batch zero-shot classification."""
    # Build all premise-hypothesis pairs
    premises, hypotheses = [], []
    for text in texts:
        for label in candidate_labels:
            premises.append(text)
            hypotheses.append(f"This patient has {label}.")

    inputs = tokenizer(premises, hypotheses, return_tensors="pt", truncation=True, padding=True, max_length=256)

    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        entailment_scores = probs[:, ENTAILMENT_IDX].tolist()

    # Reshape and normalize
    results = []
    n_labels = len(candidate_labels)
    for i in range(len(texts)):
        scores = dict(zip(candidate_labels, entailment_scores[i * n_labels:(i + 1) * n_labels]))
        total = sum(scores.values())
        results.append({k: v / total for k, v in scores.items()})

    return results

# Example
notes = [
    "Patient with swelling in left leg and positive D-dimer.",
    "Routine checkup, no complaints, vitals normal."
]
labels = ["DVT", "No DVT"]

results = classify_batch(notes, labels)
for note, result in zip(notes, results):
    print(f"{note}: {result}")
```
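`classify_batch` lays the pairs out text-major (all hypotheses for the first text, then all for the second, and so on), which is exactly what the `i * n_labels` slice relies on when regrouping scores. The indexing in isolation, with dummy scores:

```python
def reshape_scores(flat: list[float], n_texts: int, n_labels: int) -> list[list[float]]:
    """Slice a flat, text-major score list back into one group per text."""
    return [flat[i * n_labels:(i + 1) * n_labels] for i in range(n_texts)]

# 2 texts x 2 labels, flattened in the same order classify_batch builds pairs
flat = [0.9, 0.1, 0.2, 0.8]
print(reshape_scores(flat, n_texts=2, n_labels=2))  # [[0.9, 0.1], [0.2, 0.8]]
```

If the pair-building loops were nested the other way (label-major), this slicing would silently mix scores across texts, so the two orderings must stay in sync.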

Intended Use

This model is intended for:

  • Clinical NLI tasks
  • Medical text understanding
  • Research in clinical NLP

Limitations

  • Trained only on MedNLI; limited generalization to other clinical domains (e.g., radiology)
  • Performance may vary on text significantly different from clinical notes
  • Not validated for clinical decision-making

Citation

If you use this model, please cite MedNLI:

@article{romanov2018lessons,
  title={Lessons from natural language inference in the clinical domain},
  author={Romanov, Alexey and Shivade, Chaitanya},
  journal={arXiv preprint arXiv:1808.06752},
  year={2018}
}

And ModernBERT:

@article{modernbert,
  title={Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference},
  author={Warner, Benjamin and Chaffin, Antoine and Clavi{\'e}, Benjamin and others},
  journal={arXiv preprint arXiv:2412.13663},
  year={2024}
}