---
license: apache-2.0
language:
- en
library_name: transformers
tags:
- medical
- clinical
- nli
- natural-language-inference
- modernbert
datasets:
- mednli
base_model: answerdotai/ModernBERT-base
metrics:
- accuracy
- f1
pipeline_tag: text-classification
---

# ModernBERT-base fine-tuned on MedNLI

This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on the MedNLI dataset for clinical natural language inference.

## Model Description

- **Base model:** ModernBERT-base (149M parameters)
- **Task:** Natural Language Inference (3-class: entailment, neutral, contradiction)
- **Domain:** Clinical/Medical text
- **Training data:** [MedNLI](https://physionet.org/content/mednli/1.0.0/) (11,232 training examples)

## Performance

| Dataset | Accuracy | Macro F1 | Macro Precision | Macro Recall |
|---------|----------|----------|-----------------|--------------|
| **MedNLI test** | 81.4% | 81.4% | 81.5% | 81.4% |
| RadNLI test (zero-shot) | 52.1% | 54.9% | 63.1% | 66.9% |

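The macro-averaged numbers above weight each of the three classes equally: precision, recall, and F1 are computed per class, then averaged. A minimal pure-Python sketch of that computation (illustrative only; the reported results were not produced with this snippet, and tools like scikit-learn's `classification_report` compute the same macro averages):

```python
from collections import Counter

def macro_metrics(gold, pred, labels):
    """Accuracy plus macro-averaged precision, recall, and F1 over label lists."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for g, p in zip(gold, pred):
        if g == p:
            tp[g] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[g] += 1  # missed the true label g
    precs, recs, f1s = [], [], []
    for lab in labels:
        prec = tp[lab] / (tp[lab] + fp[lab]) if (tp[lab] + fp[lab]) else 0.0
        rec = tp[lab] / (tp[lab] + fn[lab]) if (tp[lab] + fn[lab]) else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
        precs.append(prec)
        recs.append(rec)
        f1s.append(f1)
    accuracy = sum(tp.values()) / len(gold)
    n = len(labels)
    return accuracy, sum(precs) / n, sum(recs) / n, sum(f1s) / n

gold = ["entailment", "neutral", "contradiction", "entailment"]
pred = ["entailment", "neutral", "neutral", "contradiction"]
print(macro_metrics(gold, pred, ["entailment", "neutral", "contradiction"]))
# acc=0.5, macro P=0.5, macro R=0.5, macro F1≈0.444
```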
### Comparison with base model (zero-shot)

| Model | MedNLI Acc | MedNLI F1 | RadNLI Acc | RadNLI F1 |
|-------|------------|-----------|------------|-----------|
| ModernBERT-base (zero-shot) | 41.2% | 38.7% | 48.5% | 47.0% |
| **This model (fine-tuned)** | **81.4%** | **81.4%** | 52.1% | 54.9% |

## Training Details

### Hyperparameters

- **Learning rate:** 8e-5
- **Epochs:** 2
- **Batch size:** 32
- **Weight decay:** 8e-6
- **LR scheduler:** linear
- **Optimizer:** AdamW (beta1=0.9, beta2=0.98, epsilon=1e-6)
- **Max sequence length:** 256

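For reproduction, the hyperparameters above map onto Hugging Face `TrainingArguments` roughly as follows. This is a hedged config sketch: the card does not state that the HF `Trainer` was used, and `output_dir` is a placeholder.

```python
from transformers import TrainingArguments

# Hyperparameters taken from the list above; output_dir is a placeholder.
args = TrainingArguments(
    output_dir="modernbert-base-mednli",
    learning_rate=8e-5,
    num_train_epochs=2,
    per_device_train_batch_size=32,
    weight_decay=8e-6,
    lr_scheduler_type="linear",
    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-6,
)

# The max sequence length (256) is applied at tokenization time, e.g.:
# tokenizer(premises, hypotheses, truncation=True, max_length=256)
# trainer = Trainer(model=model, args=args, train_dataset=..., eval_dataset=...)
```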
### Hardware

- Single GPU training

## Usage

### Natural Language Inference

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

premise = "The patient was diagnosed with pneumonia."
hypothesis = "The patient has a respiratory infection."

inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True, max_length=256)
with torch.no_grad():  # inference only; no gradients needed
    outputs = model(**inputs)

labels = ["entailment", "neutral", "contradiction"]
prediction = labels[outputs.logits.argmax().item()]
print(prediction)  # entailment
```

### Zero-Shot Text Classification

Use entailment scores to classify clinical text without task-specific training:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

ENTAILMENT_IDX = 0  # entailment is the first class

def classify(text: str, candidate_labels: list[str]) -> dict[str, float]:
    """Zero-shot classification using NLI entailment scores."""
    scores = {}
    for label in candidate_labels:
        hypothesis = f"This patient has {label}."
        inputs = tokenizer(text, hypothesis, return_tensors="pt", truncation=True, max_length=256)

        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            scores[label] = probs[0, ENTAILMENT_IDX].item()

    # Normalize scores
    total = sum(scores.values())
    return {k: v / total for k, v in scores.items()}

# Example: Binary classification for DVT/PE
clinical_note = "Patient presents with acute shortness of breath and pleuritic chest pain. CT angiography reveals filling defects in the pulmonary arteries."
labels = ["Deep Venous Thromboembolism or Pulmonary Embolism", "No Deep Venous Thromboembolism or Pulmonary Embolism"]

result = classify(clinical_note, labels)
print(result)
# {'Deep Venous Thromboembolism or Pulmonary Embolism': 0.92, 'No Deep Venous Thromboembolism or Pulmonary Embolism': 0.08}
```

### Batch Classification

For processing multiple texts efficiently:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("ygivenx/modernbert-base-mednli")
tokenizer = AutoTokenizer.from_pretrained("ygivenx/modernbert-base-mednli")

ENTAILMENT_IDX = 0  # entailment is the first class

def classify_batch(texts: list[str], candidate_labels: list[str]) -> list[dict[str, float]]:
    """Batch zero-shot classification."""
    # Build all premise-hypothesis pairs
    premises, hypotheses = [], []
    for text in texts:
        for label in candidate_labels:
            premises.append(text)
            hypotheses.append(f"This patient has {label}.")

    inputs = tokenizer(premises, hypotheses, return_tensors="pt", truncation=True, padding=True, max_length=256)

    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        entailment_scores = probs[:, ENTAILMENT_IDX].tolist()

    # Reshape and normalize
    results = []
    n_labels = len(candidate_labels)
    for i in range(len(texts)):
        scores = dict(zip(candidate_labels, entailment_scores[i * n_labels:(i + 1) * n_labels]))
        total = sum(scores.values())
        results.append({k: v / total for k, v in scores.items()})

    return results

# Example
notes = [
    "Patient with swelling in left leg and positive D-dimer.",
    "Routine checkup, no complaints, vitals normal."
]
labels = ["DVT", "No DVT"]

results = classify_batch(notes, labels)
for note, result in zip(notes, results):
    print(f"{note} -> {result}")
```

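The reshape step in `classify_batch` assumes the pairs were built text-major: all candidate labels for the first text, then all labels for the second, and so on. That chunk-and-normalize logic in isolation, as plain Python (`chunk_scores` is a hypothetical helper for illustration, not part of this model's API):

```python
def chunk_scores(flat_scores: list[float], labels: list[str]) -> list[dict[str, float]]:
    """Group a flat, text-major list of per-pair entailment scores into one
    normalized label distribution per text."""
    n = len(labels)
    results = []
    for i in range(0, len(flat_scores), n):
        chunk = flat_scores[i:i + n]          # the n scores belonging to one text
        total = sum(chunk)
        results.append({lab: s / total for lab, s in zip(labels, chunk)})
    return results

# Two texts x two labels -> four flat scores, in text-major order
print(chunk_scores([0.6, 0.2, 0.1, 0.3], ["DVT", "No DVT"]))
# [{'DVT': 0.75, 'No DVT': 0.25}, {'DVT': 0.25, 'No DVT': 0.75}]
```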
## Intended Use

This model is intended for:
- Clinical NLI tasks
- Medical text understanding
- Research in clinical NLP

## Limitations

- Trained only on MedNLI; limited generalization to other clinical domains (e.g., radiology)
- Performance may vary on text significantly different from clinical notes
- Not validated for clinical decision-making

## Citation

If you use this model, please cite MedNLI:

```bibtex
@article{romanov2018lessons,
  title={Lessons from natural language inference in the clinical domain},
  author={Romanov, Alexey and Shivade, Chaitanya},
  journal={arXiv preprint arXiv:1808.06752},
  year={2018}
}
```

And ModernBERT:

```bibtex
@article{modernbert,
  title={Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference},
  author={Warner, Benjamin and Chaffin, Antoine and Clavi{\'e}, Benjamin and others},
  journal={arXiv preprint arXiv:2412.13663},
  year={2024}
}
```