license: apache-2.0
language:
  - en
library_name: transformers
tags:
  - medical
  - clinical
  - nli
  - natural-language-inference
  - modernbert
datasets:
  - mednli
base_model: answerdotai/ModernBERT-base
metrics:
  - accuracy
  - f1
pipeline_tag: text-classification

ModernBERT-base fine-tuned on MedNLI

This model is a fine-tuned version of answerdotai/ModernBERT-base on the MedNLI dataset for clinical natural language inference.

Model Description

  • Base model: ModernBERT-base (149M parameters)
  • Task: Natural Language Inference (3-class: entailment, neutral, contradiction)
  • Domain: Clinical/Medical text
  • Training data: MedNLI (11,232 training examples)
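The usage examples in this card assume the label order entailment = 0, neutral = 1, contradiction = 2; verify against `model.config.id2label` for your checkpoint. A minimal sketch of that assumed mapping:

```python
# Assumed label order for this NLI head; confirm with model.config.id2label.
ID2LABEL = {0: "entailment", 1: "neutral", 2: "contradiction"}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}
```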

Performance

| Dataset | Accuracy | Macro F1 | Macro Precision | Macro Recall |
|---|---|---|---|---|
| MedNLI test | 81.4% | 81.4% | 81.5% | 81.4% |
| RadNLI test (zero-shot) | 52.1% | 54.9% | 63.1% | 66.9% |

Comparison with base model (zero-shot)

| Model | MedNLI Acc | MedNLI F1 | RadNLI Acc | RadNLI F1 |
|---|---|---|---|---|
| ModernBERT-base (zero-shot) | 41.2% | 38.7% | 48.5% | 47.0% |
| This model (fine-tuned) | 81.4% | 81.4% | 52.1% | 54.9% |

Training Details

Hyperparameters

  • Learning rate: 8e-5
  • Epochs: 2
  • Batch size: 32
  • Weight decay: 8e-6
  • LR scheduler: linear
  • Optimizer: AdamW (beta1=0.9, beta2=0.98, epsilon=1e-6)
  • Max sequence length: 256
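The hyperparameters above map onto the standard transformers `TrainingArguments` as sketched below; `output_dir` and any settings not listed here are assumptions, not the authors' exact configuration.

```python
from transformers import TrainingArguments

# Sketch of the training configuration implied by the listed hyperparameters.
# output_dir is a placeholder; unlisted settings keep their library defaults.
args = TrainingArguments(
    output_dir="modernbert-mednli",
    learning_rate=8e-5,
    num_train_epochs=2,
    per_device_train_batch_size=32,
    weight_decay=8e-6,
    lr_scheduler_type="linear",
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
)
```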

Hardware

  • Single GPU training

Usage

Natural Language Inference

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")
model.eval()

premise = "The patient was diagnosed with pneumonia."
hypothesis = "The patient has a respiratory infection."

inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True, max_length=256)
with torch.no_grad():
    outputs = model(**inputs)

labels = ["entailment", "neutral", "contradiction"]
prediction = labels[outputs.logits.argmax().item()]
print(prediction)  # entailment

Zero-Shot Text Classification

Use entailment scores to classify clinical text without task-specific training:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

ENTAILMENT_IDX = 0  # entailment is the first class

def classify(text: str, candidate_labels: list[str]) -> dict[str, float]:
    """Zero-shot classification using NLI entailment scores."""
    scores = {}
    for label in candidate_labels:
        hypothesis = f"This patient has {label}."
        inputs = tokenizer(text, hypothesis, return_tensors="pt", truncation=True, max_length=256)

        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            scores[label] = probs[0, ENTAILMENT_IDX].item()

    # Normalize scores
    total = sum(scores.values())
    return {k: v / total for k, v in scores.items()}

# Example: Binary classification for DVT/PE
clinical_note = "Patient presents with acute shortness of breath and pleuritic chest pain. CT angiography reveals filling defects in the pulmonary arteries."
labels = ["Deep Venous Thromboembolism or Pulmonary Embolism", "No Deep Venous Thromboembolism or Pulmonary Embolism"]

result = classify(clinical_note, labels)
print(result)
# {'Deep Venous Thromboembolism or Pulmonary Embolism': 0.92, 'No Deep Venous Thromboembolism or Pulmonary Embolism': 0.08}

Batch Classification

For processing multiple texts efficiently:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

ENTAILMENT_IDX = 0

def classify_batch(texts: list[str], candidate_labels: list[str]) -> list[dict[str, float]]:
    """Batch zero-shot classification."""
    # Build all premise-hypothesis pairs
    premises, hypotheses = [], []
    for text in texts:
        for label in candidate_labels:
            premises.append(text)
            hypotheses.append(f"This patient has {label}.")

    inputs = tokenizer(premises, hypotheses, return_tensors="pt", truncation=True, padding=True, max_length=256)

    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        entailment_scores = probs[:, ENTAILMENT_IDX].tolist()

    # Reshape and normalize
    results = []
    n_labels = len(candidate_labels)
    for i in range(len(texts)):
        scores = dict(zip(candidate_labels, entailment_scores[i * n_labels:(i + 1) * n_labels]))
        total = sum(scores.values())
        results.append({k: v / total for k, v in scores.items()})

    return results

# Example
notes = [
    "Patient with swelling in left leg and positive D-dimer.",
    "Routine checkup, no complaints, vitals normal."
]
labels = ["DVT", "No DVT"]

results = classify_batch(notes, labels)
for note, result in zip(notes, results):
    print(f"{note} -> {result}")

Intended Use

This model is intended for:

  • Clinical NLI tasks
  • Medical text understanding
  • Research in clinical NLP

Limitations

  • Trained only on MedNLI; limited generalization to other clinical domains (e.g., radiology)
  • Performance may vary on text significantly different from clinical notes
  • Not validated for clinical decision-making

Citation

If you use this model, please cite MedNLI:

@article{romanov2018lessons,
  title={Lessons from natural language inference in the clinical domain},
  author={Romanov, Alexey and Shivade, Chaitanya},
  journal={arXiv preprint arXiv:1808.06752},
  year={2018}
}

And ModernBERT:

@article{modernbert,
  title={Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference},
  author={Warner, Benjamin and Chaffin, Antoine and Clavi{\'e}, Benjamin and others},
  journal={arXiv preprint arXiv:2412.13663},
  year={2024}
}