CRAG: Causal Reasoning for Adversomics Graphs

This is the enhanced ADE-trained model in the CRAG dual-encoder family. It incorporates multiple architectural and training improvements over the base model, achieving 97.5% F1 and 99.1% AUC on drug-ADR relation extraction.

Model Description

CRAG-dual-encoder-ade builds upon the base architecture with several key improvements:

  1. BioLinkBERT backbone (pre-trained with link prediction for better relation understanding)
  2. Entity markers ([DRUG]...[/DRUG], [ADR]...[/ADR]) for explicit entity boundary signaling
  3. Hard negative mining (semantically similar but unrelated pairs)
  4. Focal loss for handling class imbalance
  5. Attention pooling instead of [CLS] token
  6. Layer-wise learning rate decay for stable fine-tuning

Architecture

┌─────────────────────────────────────────────────────────────────┐
│                    CRAG Dual-Encoder ADE                        │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   "[DRUG] aspirin [/DRUG]       "[ADR] bleeding [/ADR]          │
│    caused bleeding..."           from aspirin..."               │
│           │                            │                        │
│           ▼                            ▼                        │
│   ┌─────────────┐              ┌─────────────┐                  │
│   │ BioLinkBERT │              │ BioLinkBERT │  (separate)      │
│   │   Drug      │              │    ADR      │                  │
│   │  Encoder    │              │   Encoder   │                  │
│   └──────┬──────┘              └──────┬──────┘                  │
│          │                            │                         │
│          ▼                            ▼                         │
│   ┌─────────────┐              ┌─────────────┐                  │
│   │  Attention  │              │  Attention  │                  │
│   │   Pooling   │              │   Pooling   │                  │
│   │  (4 heads)  │              │  (4 heads)  │                  │
│   └──────┬──────┘              └──────┬──────┘                  │
│          │                            │                         │
│          ▼                            ▼                         │
│   ┌─────────────┐              ┌─────────────┐                  │
│   │ Projection  │              │ Projection  │                  │
│   │  768→256    │              │  768→256    │                  │
│   │ +LayerNorm  │              │ +LayerNorm  │                  │
│   │ +GELU+Drop  │              │ +GELU+Drop  │                  │
│   └──────┬──────┘              └──────┬──────┘                  │
│          │                            │                         │
│          └───────────┬────────────────┘                         │
│                      │                                          │
│                      ▼                                          │
│              ┌──────────────┐                                   │
│              │   Bilinear   │                                   │
│              │   Fusion     │                                   │
│              │  (256×256)   │                                   │
│              └──────┬───────┘                                   │
│                     │                                           │
│                     ▼                                           │
│              ┌──────────────┐                                   │
│              │  Classifier  │                                   │
│              │ 512→256→128→1│                                   │
│              │ +LayerNorm   │                                   │
│              └──────┬───────┘                                   │
│                     │                                           │
│                     ▼                                           │
│                 P(causal)                                       │
└─────────────────────────────────────────────────────────────────┘
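
The fusion and classifier stages in the diagram can be sketched in PyTorch. This is one plausible reading of the dimensions shown, not the actual implementation: the diagram fixes the bilinear size (256×256) and the classifier widths (512→256→128→1), but not how the 256-dim bilinear output becomes a 512-dim classifier input, so the concatenation with the elementwise product below is an assumption.

```python
import torch
import torch.nn as nn

class BilinearFusionClassifier(nn.Module):
    """Sketch of the fusion/classifier stages from the diagram.

    The bilinear layer models multiplicative drug-ADR interactions; its
    256-dim output is concatenated with the elementwise product of the two
    embeddings to form the 512-dim classifier input. That wiring is an
    assumption -- the diagram only fixes the dimensions.
    """
    def __init__(self, dim=256):
        super().__init__()
        self.bilinear = nn.Bilinear(dim, dim, dim)
        self.classifier = nn.Sequential(
            nn.Linear(2 * dim, 256), nn.LayerNorm(256), nn.GELU(),
            nn.Linear(256, 128), nn.GELU(),
            nn.Linear(128, 1),
        )

    def forward(self, drug, adr):
        interaction = self.bilinear(drug, adr)             # (B, 256)
        fused = torch.cat([interaction, drug * adr], -1)   # (B, 512)
        return self.classifier(fused).squeeze(-1)          # one logit per pair

head = BilinearFusionClassifier()
logit = head(torch.randn(2, 256), torch.randn(2, 256))
print(logit.shape)  # torch.Size([2])
```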

Key Improvements Over Base Model

Feature               | Base          | ADE (this model)
Base Encoder          | PubMedBERT    | BioLinkBERT
Pooling               | [CLS] token   | Multi-head attention
Entity Marking        | None          | [DRUG]/[ADR] tokens
Negative Sampling     | Random 50%    | Hard negatives
Loss Function         | BCE           | Focal loss (γ=2.0)
LR Schedule           | Linear warmup | Cosine + layer-wise decay
Gradient Accumulation | None          | 4 steps (effective batch = 64)
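
The gradient-accumulation setting in the table (4 steps, effective batch 64) works as follows; the actual training loop is not shown in the card, so this is a minimal sketch:

```python
import torch
from torch import nn

def train_with_accumulation(model, loss_fn, batches, optimizer, accum_steps=4):
    """Gradient accumulation sketch: losses from `accum_steps` micro-batches
    are averaged before a single optimizer step, so 4 micro-batches of 16
    behave like one batch of 64."""
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches, start=1):
        loss = loss_fn(model(x), y) / accum_steps  # scale so gradients average
        loss.backward()
        if i % accum_steps == 0:  # step only every accum_steps micro-batches
            optimizer.step()
            optimizer.zero_grad()

# Toy demonstration: four micro-batches trigger exactly one parameter update.
model = nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.weight.detach().clone()
batches = [(torch.randn(16, 2), torch.randn(16, 1)) for _ in range(4)]
train_with_accumulation(model, nn.MSELoss(), batches, opt)
print(torch.equal(before, model.weight))  # False: one step occurred
```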

Model Specifications

  • Base Model: michiyasunaga/BioLinkBERT-base
  • Hidden Dimension: 768
  • Fusion Dimension: 256
  • Attention Heads (pooling): 4
  • Total Parameters: ~238M
  • Special Tokens: [DRUG], [/DRUG], [ADR], [/ADR]

Training Procedure

Phase 1: Contrastive Pre-training (5 epochs)

CONFIG = {
    "temperature": 0.05,           # Sharper similarity distribution
    "hard_negative_ratio": 0.5,    # 50% hard negatives
    "batch_size": 16,
    "gradient_accumulation_steps": 4,  # Effective batch = 64
}
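
The card does not spell out the Phase 1 objective, but with a temperature and in-batch hard negatives it is presumably an InfoNCE-style contrastive loss. A hedged sketch showing how the temperature of 0.05 sharpens the similarity distribution:

```python
import torch
import torch.nn.functional as F

def info_nce(drug_emb, adr_emb, temperature=0.05):
    """In-batch contrastive loss (sketch; the exact Phase 1 objective is
    not given in the card). Row i of each matrix is a matched drug/ADR
    pair; every other row serves as a negative. A lower temperature
    sharpens the softmax over cosine similarities."""
    drug = F.normalize(drug_emb, dim=-1)
    adr = F.normalize(adr_emb, dim=-1)
    logits = drug @ adr.t() / temperature   # (B, B) similarity matrix
    targets = torch.arange(drug.size(0))    # diagonal entries are positives
    return F.cross_entropy(logits, targets)

loss = info_nce(torch.randn(16, 256), torch.randn(16, 256))
print(loss.item())
```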

Hard Negative Mining Strategy:

  • Same drug, different ADR (tests ADR discrimination)
  • Same ADR, different drug (tests drug discrimination)
  • Semantically similar but unrelated pairs
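
The first two mining strategies above can be sketched directly from the positive pairs; the third (semantically similar but unrelated pairs) needs an embedding model and is omitted here:

```python
def mine_hard_negatives(positives):
    """Sketch of the mining strategy above: pair each drug with ADRs it is
    not known to cause (same drug, different ADR) and each ADR with other
    drugs (same ADR, different drug). The embedding-based similarity
    filtering step is omitted."""
    known = set(positives)
    drugs = sorted({d for d, _ in positives})
    adrs = sorted({a for _, a in positives})
    negatives = set()
    for drug, adr in positives:
        negatives.update((drug, a) for a in adrs if (drug, a) not in known)
        negatives.update((d, adr) for d in drugs if (d, adr) not in known)
    return sorted(negatives)

pos = [("aspirin", "bleeding"), ("metformin", "lactic acidosis"),
       ("warfarin", "bleeding")]
negs = mine_hard_negatives(pos)
for neg in negs:
    print(neg)
# ('aspirin', 'lactic acidosis')
# ('metformin', 'bleeding')
# ('warfarin', 'lactic acidosis')
```

Each negative shares either its drug or its ADR with a true pair, which is exactly what makes it hard: the model cannot rely on entity identity alone and must learn the relation.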

Phase 2: Classification Fine-tuning (8 epochs)

CONFIG = {
    "learning_rate": 2e-5,
    "warmup_ratio": 0.1,
    "layerwise_lr_decay": 0.9,     # Lower layers get 0.9× LR
    "focal_gamma": 2.0,            # Focus on hard examples
    "focal_alpha": 0.75,           # Positive class weight
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
}
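
The `layerwise_lr_decay` setting means transformer layer k trains at `learning_rate * 0.9 ** (num_layers - k)`, so the embeddings move slowest and the top layer at the full rate. A sketch of the corresponding optimizer parameter groups, assuming BERT-style parameter names ("embeddings…", "encoder.layer.k…") — the actual grouping in the training script may differ:

```python
import torch
from torch import nn

def layerwise_lr_groups(encoder, base_lr=2e-5, decay=0.9, num_layers=12):
    """Build per-parameter optimizer groups implementing layer-wise LR
    decay. Name matching assumes BERT-style parameter names; this is a
    sketch, not the training script's exact grouping."""
    groups = []
    for name, param in encoder.named_parameters():
        depth = num_layers  # default: top-of-stack modules (pooler etc.)
        if name.startswith("embeddings"):
            depth = 0
        else:
            for k in range(num_layers):
                if f"encoder.layer.{k}." in name:
                    depth = k + 1
                    break
        groups.append({"params": [param],
                       "lr": base_lr * decay ** (num_layers - depth)})
    return groups

# Toy module mimicking BERT naming, just to show the schedule.
class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Linear(4, 4)
        self.encoder = nn.Module()
        self.encoder.layer = nn.ModuleList(nn.Linear(4, 4) for _ in range(12))

groups = layerwise_lr_groups(ToyEncoder())
optimizer = torch.optim.AdamW(groups)
print(max(g["lr"] for g in groups))  # top layer: 2e-5
print(min(g["lr"] for g in groups))  # embeddings: 2e-5 * 0.9**12
```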

Focal Loss: FL(p_t) = -α_t (1 - p_t)^γ log(p_t)

Where γ=2.0 down-weights easy examples, focusing learning on hard cases.
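
The focal loss above translates directly into code. A minimal implementation matching the configured γ = 2.0 and α = 0.75:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0, alpha=0.75):
    """Binary focal loss: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t).
    Well-classified examples have p_t near 1, so (1 - p_t)^gamma shrinks
    their contribution and training concentrates on hard cases."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = torch.where(targets == 1, p, 1 - p)          # prob of true class
    alpha_t = torch.where(targets == 1,
                          torch.full_like(p, alpha),
                          torch.full_like(p, 1 - alpha))
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()  # bce = -log(p_t)

# An easy positive (logit 5.0) contributes far less than a hard one (logit 0.0).
easy = focal_loss(torch.tensor([5.0]), torch.tensor([1.0]))
hard = focal_loss(torch.tensor([0.0]), torch.tensor([1.0]))
print(easy.item(), hard.item())
```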

Training Data

  • Dataset: ADE Corpus V2
  • Configuration: Ade_corpus_v2_drug_ade_relation
  • Training Examples: 13,642 (balanced positive/negative with hard mining)
  • Validation Examples: 2,047
  • Entity Marker Format: "[DRUG] aspirin [/DRUG] caused [ADR] bleeding [/ADR]"

Performance

Metrics

Metric            | Value
F1 Score          | 97.5%
ROC-AUC           | 99.1%
Optimal Threshold | 0.55

Training Curves

Phase          | Epochs | Final Loss | Final Metric
Contrastive    | 5      | 0.021      | -
Classification | 8      | 0.008      | F1: 97.5%

Comparison with CRAG Family

Model                       | F1    | AUC   | Improvement
CRAG-dual-encoder-base      | 88.3% | -     | Baseline
CRAG-dual-encoder-ade       | 97.5% | 99.1% | +9.2% F1
CRAG-dual-encoder-mimicause | 98.9% | 99.8% | +10.6% F1

Usage

Loading the Model

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

# Define the model architecture
class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim, num_heads=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))

    def forward(self, hidden_states, attention_mask):
        batch_size = hidden_states.size(0)
        query = self.query.expand(batch_size, -1, -1)
        key_padding_mask = ~attention_mask.bool()
        pooled, _ = self.attention(query, hidden_states, hidden_states,
                                   key_padding_mask=key_padding_mask)
        return pooled.squeeze(1)

class DualEncoderADE(nn.Module):
    def __init__(self, model_name="michiyasunaga/BioLinkBERT-base"):
        super().__init__()
        self.drug_encoder = AutoModel.from_pretrained(model_name)
        self.adr_encoder = AutoModel.from_pretrained(model_name)
        self.drug_pooler = AttentionPooling(768)
        self.adr_pooler = AttentionPooling(768)
        # Note: both encoders need resize_token_embeddings(len(tokenizer))
        # to cover the four special tokens added to the vocabulary.
        # ... (see full architecture in training script)

# Load tokenizer with special tokens
tokenizer = AutoTokenizer.from_pretrained("chrisvoncsefalvay/CRAG-dual-encoder-ade")

# Load model weights
model = DualEncoderADE()
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

Inference Example

def score_drug_adr_pair(model, tokenizer, drug_text, adr_text, drug_entity, adr_entity):
    """Score a drug-ADR pair for causal relationship."""

    # Add entity markers
    drug_context = drug_text.replace(
        drug_entity,
        f"[DRUG] {drug_entity} [/DRUG]"
    )
    adr_context = adr_text.replace(
        adr_entity,
        f"[ADR] {adr_entity} [/ADR]"
    )

    # Tokenize
    drug_inputs = tokenizer(
        drug_context,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    adr_inputs = tokenizer(
        adr_context,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )

    # Get prediction. encode_drug / encode_adr / classify are methods of
    # the full DualEncoderADE class (elided above; see the training script).
    with torch.no_grad():
        drug_repr = model.encode_drug(**drug_inputs)
        adr_repr = model.encode_adr(**adr_inputs)
        logit = model.classify(drug_repr, adr_repr)
        prob = torch.sigmoid(logit).item()

    return prob

# Example usage
prob = score_drug_adr_pair(
    model, tokenizer,
    drug_text="Patient was started on metformin 500mg twice daily.",
    adr_text="She developed lactic acidosis requiring ICU admission.",
    drug_entity="metformin",
    adr_entity="lactic acidosis"
)
print(f"Causal probability: {prob:.3f}")
# Output: Causal probability: 0.923

Intended Uses

Primary Use Cases

  • Pharmacovigilance Systems: Automated ADR detection in medical literature
  • Drug Safety Databases: Populating causal knowledge graphs
  • Clinical Trial Analysis: Mining safety signals from trial reports
  • Regulatory Submission Review: Screening documents for ADR mentions
  • Post-Market Surveillance: Monitoring real-world drug safety

Best Practices

  1. Use entity markers [DRUG]/[ADR] for optimal performance
  2. Apply threshold of 0.55 for balanced precision/recall
  3. Validate high-confidence predictions with domain experts
  4. Consider ensemble with CRAG-dual-encoder-mimicause for critical applications
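
Practice 4 (ensembling with CRAG-dual-encoder-mimicause) can be as simple as a weighted average of the two models' probabilities; the card does not prescribe a scheme, so this is only a sketch:

```python
def ensemble_score(probs, weights=None):
    """Weighted average of per-model probabilities (sketch; the card does
    not prescribe an ensembling scheme). Equal weights by default."""
    if weights is None:
        weights = [1.0 / len(probs)] * len(probs)
    return sum(p * w for p, w in zip(probs, weights))

# Equal-weight average of the two CRAG models' scores for one pair
print(ensemble_score([0.92, 0.88]))
# Weighting the stronger mimicause model more heavily
print(ensemble_score([0.92, 0.88], weights=[0.4, 0.6]))
```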

Limitations

  1. English Only: Trained on English biomedical text
  2. Explicit Mentions Required: Both drug and ADR must appear in the text
  3. Binary Classification: Does not distinguish causation types (e.g., dose-dependent)
  4. Training Data Bias: Reflects the drug/ADR distribution in ADE Corpus V2
  5. Context Window: Maximum 128 tokens per input

Ethical Considerations

  • Not for Direct Clinical Use: Predictions require expert validation
  • Bias in Coverage: Common drugs/ADRs better represented than rare ones
  • Automation Risks: Over-reliance may miss nuanced relationships
  • Transparency: Model confidence should be communicated to end users

Technical Specifications

Specification       | Value
Framework           | PyTorch
Base Model          | BioLinkBERT-base
Model Size          | ~955 MB
Vocabulary Size     | 30,522 + 4 special tokens
Max Sequence Length | 128 tokens
Inference Speed     | ~100 pairs/sec (GPU)

Citation

@misc{crag-dual-encoder-ade-2024,
  title={CRAG: Causal Reasoning for Adversomics Graphs - Enhanced Dual-Encoder with Hard Negative Mining},
  author={von Csefalvay, Chris},
  year={2024},
  publisher={Hugging Face},
  url={https://huggingface.co/chrisvoncsefalvay/CRAG-dual-encoder-ade}
}

References

  • Gurulingappa, H., et al. (2012). Development of a benchmark corpus to support the automatic extraction of drug-related adverse effects from medical case reports. Journal of Biomedical Informatics.
  • Yasunaga, M., et al. (2022). LinkBERT: Pretraining Language Models with Document Links. ACL.
  • Lin, T.Y., et al. (2017). Focal Loss for Dense Object Detection. ICCV.

Model Card Authors

Chris von Csefalvay (@chrisvoncsefalvay)

Model Card Contact

For questions or issues, please open a discussion on this model's repository or contact chris@chrisvoncsefalvay.com.
