CRAG: Causal Reasoning for Adversomics Graphs
This is the enhanced ADE-trained model in the CRAG dual-encoder family. It incorporates multiple architectural and training improvements over the base model, achieving 97.5% F1 and 99.1% AUC on drug-ADR relation extraction.
CRAG-dual-encoder-ade builds upon the base architecture with several key improvements:
```
┌─────────────────────────────────────────────────────────────────┐
│ CRAG Dual-Encoder ADE │
├─────────────────────────────────────────────────────────────────┤
│ │
│ "[DRUG] aspirin [/DRUG] "[ADR] bleeding [/ADR] │
│ caused bleeding..." from aspirin..." │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ BioLinkBERT │ │ BioLinkBERT │ (separate) │
│ │ Drug │ │ ADR │ │
│ │ Encoder │ │ Encoder │ │
│ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Attention │ │ Attention │ │
│ │ Pooling │ │ Pooling │ │
│ │ (4 heads) │ │ (4 heads) │ │
│ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Projection │ │ Projection │ │
│ │ 768→256 │ │ 768→256 │ │
│ │ +LayerNorm │ │ +LayerNorm │ │
│ │ +GELU+Drop │ │ +GELU+Drop │ │
│ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │
│ └───────────┬────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Bilinear │ │
│ │ Fusion │ │
│ │ (256×256) │ │
│ └──────┬───────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Classifier │ │
│ │ 512→256→128→1│ │
│ │ +LayerNorm │ │
│ └──────┬───────┘ │
│ │ │
│ ▼ │
│ P(causal) │
└─────────────────────────────────────────────────────────────────┘
```
Key improvements over the base model:

| Feature | Base | ADE (this model) |
|---|---|---|
| Base Encoder | PubMedBERT | BioLinkBERT |
| Pooling | [CLS] token | Multi-head Attention |
| Entity Marking | None | [DRUG]/[ADR] tokens |
| Negative Sampling | Random | 50% Hard negatives |
| Loss Function | BCE | Focal Loss (γ=2.0) |
| LR Schedule | Linear warmup | Cosine + layer-wise decay |
| Gradient Accumulation | None | 4 steps (effective batch=64) |
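
The projection, fusion, and classifier stages from the diagram are not reproduced in the usage snippet further down. The following is a minimal sketch of one plausible head consistent with the diagram; the dropout rate and the way the bilinear term feeds the classifier are assumptions, not taken from the released weights.

```python
import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    """768 -> 256 projection with LayerNorm, GELU and dropout, as in the diagram."""
    def __init__(self, in_dim=768, out_dim=256, dropout=0.1):  # dropout rate is an assumption
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class FusionClassifier(nn.Module):
    """Bilinear (256 x 256) interaction plus a 512 -> 256 -> 128 -> 1 classifier.

    Assumption: the two 256-d vectors are concatenated for the MLP and the
    bilinear interaction is added to the final logit; the actual wiring may differ.
    """
    def __init__(self, dim=256, dropout=0.1):
        super().__init__()
        self.bilinear = nn.Bilinear(dim, dim, 1)
        self.mlp = nn.Sequential(
            nn.Linear(2 * dim, 256), nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.GELU(),
            nn.Linear(128, 1),
        )

    def forward(self, drug_vec, adr_vec):
        return self.mlp(torch.cat([drug_vec, adr_vec], dim=-1)) + self.bilinear(drug_vec, adr_vec)
```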
The base encoder is `michiyasunaga/BioLinkBERT-base`, with the special entity-marker tokens `[DRUG]`, `[/DRUG]`, `[ADR]`, `[/ADR]` added to the vocabulary.

Contrastive training configuration:

```python
CONFIG = {
    "temperature": 0.05,               # Sharper similarity distribution
    "hard_negative_ratio": 0.5,        # 50% hard negatives
    "batch_size": 16,
    "gradient_accumulation_steps": 4,  # Effective batch = 64
}
```
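
The contrastive objective itself is not shown in the card; below is a minimal InfoNCE-style sketch with in-batch negatives, illustrating how the temperature of 0.05 sharpens the similarity distribution (the exact loss form is an assumption).

```python
import torch
import torch.nn.functional as F

def contrastive_loss(drug_repr, adr_repr, temperature=0.05):
    """Symmetric InfoNCE: row i of drug_repr should match row i of adr_repr."""
    drug = F.normalize(drug_repr, dim=-1)
    adr = F.normalize(adr_repr, dim=-1)
    logits = drug @ adr.T / temperature                     # (batch, batch) similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2
```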
Hard Negative Mining Strategy:
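The mining procedure is summarised rather than spelled out; a common approach, sketched below purely for illustration, is to score candidate non-matching pairs with the current encoders and keep the most similar ones as the hard half of the negatives.

```python
import torch
import torch.nn.functional as F

def mine_negatives(drug_reprs, adr_reprs, num_negatives=8, hard_ratio=0.5):
    """For each drug, pick 50% hard (most similar non-matching ADRs) and 50% random negatives.

    drug_reprs, adr_reprs: (N, dim) tensors where pair i is the true positive pair.
    Illustrative sketch only; the training script's exact procedure may differ.
    """
    sims = F.normalize(drug_reprs, dim=-1) @ F.normalize(adr_reprs, dim=-1).T
    sims.fill_diagonal_(float("-inf"))                       # never pick the true pair as "hard"
    n_hard = int(num_negatives * hard_ratio)
    hard_idx = sims.topk(n_hard, dim=-1).indices             # hardest negatives
    rand_idx = torch.randint(0, adr_reprs.size(0),
                             (drug_reprs.size(0), num_negatives - n_hard))
    return torch.cat([hard_idx, rand_idx], dim=-1)           # (N, num_negatives) ADR indices
```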
Classification fine-tuning configuration:

```python
CONFIG = {
    "learning_rate": 2e-5,
    "warmup_ratio": 0.1,
    "layerwise_lr_decay": 0.9,   # Lower layers get 0.9× LR
    "focal_gamma": 2.0,          # Focus on hard examples
    "focal_alpha": 0.75,         # Positive class weight
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
}
```
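
With layer-wise decay of 0.9, each transformer layer below the top receives 0.9× the learning rate of the layer above it. A minimal sketch of building the optimiser parameter groups (the helper and its grouping logic are assumptions, not the card's training script):

```python
from torch.optim import AdamW

def layerwise_param_groups(encoder, base_lr=2e-5, decay=0.9, num_layers=12):
    """One parameter group per depth, lower layers getting progressively smaller LRs."""
    groups = []
    for name, param in encoder.named_parameters():
        if ".layer." in name:                       # e.g. encoder.layer.7.attention...
            depth = num_layers - int(name.split(".layer.")[1].split(".")[0])
        elif "embeddings" in name:
            depth = num_layers + 1                  # embeddings sit below layer 0
        else:
            depth = 0                               # pooler / task head
        groups.append({"params": [param], "lr": base_lr * decay ** depth})
    return groups

# optimizer = AdamW(layerwise_param_groups(model.drug_encoder)
#                   + layerwise_param_groups(model.adr_encoder), weight_decay=0.01)
```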
Focal Loss:

$$\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t)$$

where γ = 2.0 down-weights easy examples, focusing learning on hard cases, and α = 0.75 weights the positive class.
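
A minimal PyTorch sketch of this loss with the card's γ = 2.0 and α = 0.75 (an illustration, not the training code):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0, alpha=0.75):
    """Binary focal loss; well-classified examples (high p_t) contribute little."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)                                   # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()
```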
The model is trained on the `Ade_corpus_v2_drug_ade_relation` configuration of the ADE Corpus V2 dataset, with entity markers inserted into each example, e.g. `"[DRUG] aspirin [/DRUG] caused [ADR] bleeding [/ADR]"`.
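
A sketch of producing marked examples from the source dataset, assuming the `text`, `drug` and `effect` fields of the Hugging Face `ade_corpus_v2` dataset (field names are the standard ones for that dataset, not quoted from this card):

```python
from datasets import load_dataset

ds = load_dataset("ade_corpus_v2", "Ade_corpus_v2_drug_ade_relation", split="train")

def add_markers(example):
    """Wrap the drug and effect mentions in [DRUG]/[ADR] markers."""
    text = example["text"]
    return {
        "drug_text": text.replace(example["drug"], f"[DRUG] {example['drug']} [/DRUG]"),
        "adr_text": text.replace(example["effect"], f"[ADR] {example['effect']} [/ADR]"),
    }

marked = ds.map(add_markers)
```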
Evaluation results:

| Metric | Value |
|---|---|
| F1 Score | 97.5% |
| ROC-AUC | 99.1% |
| Optimal Threshold | 0.55 |

Training proceeded in two phases:

| Phase | Epochs | Final Loss | Final Metric |
|---|---|---|---|
| Contrastive | 5 | 0.021 | - |
| Classification | 8 | 0.008 | F1: 97.5% |

Comparison with the other models in the CRAG dual-encoder family:

| Model | F1 | AUC | Improvement |
|---|---|---|---|
| CRAG-dual-encoder-base | 88.3% | - | Baseline |
| CRAG-dual-encoder-ade | 97.5% | 99.1% | +9.2% F1 |
| CRAG-dual-encoder-mimicause | 98.9% | 99.8% | +10.6% F1 |
```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

# Define the model architecture
class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim, num_heads=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))

    def forward(self, hidden_states, attention_mask):
        batch_size = hidden_states.size(0)
        query = self.query.expand(batch_size, -1, -1)
        key_padding_mask = ~attention_mask.bool()
        pooled, _ = self.attention(query, hidden_states, hidden_states,
                                   key_padding_mask=key_padding_mask)
        return pooled.squeeze(1)


class DualEncoderADE(nn.Module):
    def __init__(self, model_name="michiyasunaga/BioLinkBERT-base"):
        super().__init__()
        self.drug_encoder = AutoModel.from_pretrained(model_name)
        self.adr_encoder = AutoModel.from_pretrained(model_name)
        self.drug_pooler = AttentionPooling(768)
        self.adr_pooler = AttentionPooling(768)
        # ... (see full architecture in training script)


# Load tokenizer with special tokens
tokenizer = AutoTokenizer.from_pretrained("chrisvoncsefalvay/CRAG-dual-encoder-ade")

# Load model weights
model = DualEncoderADE()
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```
```python
def score_drug_adr_pair(model, tokenizer, drug_text, adr_text, drug_entity, adr_entity):
    """Score a drug-ADR pair for causal relationship."""
    # Add entity markers
    drug_context = drug_text.replace(
        drug_entity,
        f"[DRUG] {drug_entity} [/DRUG]"
    )
    adr_context = adr_text.replace(
        adr_entity,
        f"[ADR] {adr_entity} [/ADR]"
    )

    # Tokenize
    drug_inputs = tokenizer(
        drug_context,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    adr_inputs = tokenizer(
        adr_context,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )

    # Get prediction
    with torch.no_grad():
        drug_repr = model.encode_drug(**drug_inputs)
        adr_repr = model.encode_adr(**adr_inputs)
        logit = model.classify(drug_repr, adr_repr)
        prob = torch.sigmoid(logit).item()

    return prob


# Example usage
prob = score_drug_adr_pair(
    model, tokenizer,
    drug_text="Patient was started on metformin 500mg twice daily.",
    adr_text="She developed lactic acidosis requiring ICU admission.",
    drug_entity="metformin",
    adr_entity="lactic acidosis"
)
print(f"Causal probability: {prob:.3f}")
# Output: Causal probability: 0.923
```
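
Continuing the example, the score can be binarised with the reported optimal threshold of 0.55:

```python
is_causal = prob >= 0.55   # optimal decision threshold from evaluation
print("Causal relationship detected" if is_causal else "No causal relationship detected")
```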
Entity mentions should be wrapped in `[DRUG]`/`[ADR]` markers for optimal performance.

| Specification | Value |
|---|---|
| Framework | PyTorch |
| Base Model | BioLinkBERT-base |
| Model Size | ~955 MB |
| Vocabulary Size | 30,522 + 4 special |
| Max Sequence Length | 128 tokens |
| Inference Speed | ~100 pairs/sec (GPU) |
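
The throughput figure assumes batched inference on GPU; a minimal batching sketch reusing the model's `encode_drug`/`encode_adr`/`classify` methods from the usage example above (batch size and device handling are assumptions):

```python
import torch

def score_pairs(model, tokenizer, drug_contexts, adr_contexts, batch_size=32, device="cuda"):
    """Score lists of already-marked drug/ADR contexts in batches."""
    model.to(device)
    probs = []
    for i in range(0, len(drug_contexts), batch_size):
        drug_inputs = tokenizer(drug_contexts[i:i + batch_size], return_tensors="pt",
                                max_length=128, truncation=True, padding="max_length").to(device)
        adr_inputs = tokenizer(adr_contexts[i:i + batch_size], return_tensors="pt",
                               max_length=128, truncation=True, padding="max_length").to(device)
        with torch.no_grad():
            logits = model.classify(model.encode_drug(**drug_inputs),
                                    model.encode_adr(**adr_inputs))
        probs.extend(torch.sigmoid(logits).flatten().tolist())
    return probs
```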
```bibtex
@misc{crag-dual-encoder-ade-2024,
  title={CRAG: Causal Reasoning for Adversomics Graphs - Enhanced Dual-Encoder with Hard Negative Mining},
  author={von Csefalvay, Chris},
  year={2024},
  publisher={Hugging Face},
  url={https://huggingface.co/chrisvoncsefalvay/CRAG-dual-encoder-ade}
}
```
Chris von Csefalvay (@chrisvoncsefalvay)
For questions or issues, please open a discussion on this model's repository or contact chris@chrisvoncsefalvay.com.