
Robust-Clinical-Risk-Detector

Overview

Robust-Clinical-Risk-Detector is a DistilBERT-based sequence classification model for electronic health record (EHR) text, such as clinical notes, that classifies patient risk as HIGH or LOW. It is trained with adversarial training and a Focal Loss function so that its predictions resist minor but clinically significant perturbations (e.g., synonym swapping, negation flipping, or numerical errors) that could otherwise lead to misclassification.

The model is a three-class classifier: it predicts RISK: LOW or RISK: HIGH, or it flags the input as ATTACK: DETECTED when the text appears inconsistent or perturbed and the model has low confidence in either risk label.
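The three perturbation types named above can be illustrated with a small sketch. This is not the generation pipeline used to train the model, which this card does not describe; the function names and the tiny `SYNONYMS` table are hypothetical and exist only to show what each perturbation does to a note.

```python
import re

# Hypothetical synonym table; a real pipeline would draw on a clinical lexicon.
SYNONYMS = {"severe": "marked", "headache": "cephalalgia"}


def swap_synonyms(text):
    """Replace known words with a synonym, preserving initial capitalization."""
    def repl(match):
        word = match.group(0)
        new = SYNONYMS.get(word.lower(), word)
        return new.capitalize() if word[0].isupper() else new
    return re.sub(r"[A-Za-z]+", repl, text)


def flip_negation(text):
    """Drop the first standalone 'no'/'not', reversing the clause's meaning."""
    return re.sub(r"\b(no|not)\s+", "", text, count=1, flags=re.IGNORECASE)


def drop_leading_digit(text):
    """Remove the first digit of the first multi-digit number (15.5 becomes 5.5)."""
    return re.sub(r"\b\d(\d+(?:\.\d+)?)\b", r"\1", text, count=1)


if __name__ == "__main__":
    print(swap_synonyms("Severe headache and neck stiffness."))
    print(flip_negation("No acute issues today."))
    print(drop_leading_digit("Potassium 15.5 mmol/L"))
```

Each function changes only a few characters, yet the clinical meaning of the note can change completely, which is the failure mode the model is meant to resist.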

Model Architecture

  • Base Model: DistilBERT
  • Task: Sequence Classification (DistilBertForSequenceClassification)
  • Output Classes: 3 (RISK: LOW, RISK: HIGH, and ATTACK: DETECTED).
  • Robustness: Trained on a mix of clean clinical text and synthetically generated adversarial examples targeting critical medical terms and numerical values.
  • Loss Function: Focal Loss was used during training to place greater emphasis on misclassified hard examples (the subtle adversarial inputs).
  • Domain: Clinical Notes, Patient Discharge Summaries, Initial Assessments.
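To show how Focal Loss emphasizes hard examples, here is a minimal per-example sketch of the standard formula FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t), written in plain Python so it runs without any framework. The values `gamma=2.0` and `alpha=1.0` are assumptions for illustration; this card does not state the hyperparameters used in training.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Standard cross-entropy for one example."""
    return -math.log(softmax(logits)[target])


def focal_loss(logits, target, gamma=2.0, alpha=1.0):
    """Focal loss for one example; gamma and alpha are assumed values."""
    p_t = softmax(logits)[target]
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


if __name__ == "__main__":
    easy = [4.0, 0.0, 0.0]  # confident and correct for class 0
    hard = [0.0, 1.0, 0.5]  # leaning toward the wrong class
    for name, logits in [("easy", easy), ("hard", hard)]:
        ce = cross_entropy(logits, 0)
        fl = focal_loss(logits, 0)
        print(f"{name}: CE={ce:.4f} focal={fl:.4f} ratio={fl / ce:.4f}")
```

The factor (1 - p_t)^gamma shrinks the loss of confidently correct examples far more than that of uncertain ones, so subtle adversarial inputs contribute a larger share of the gradient. Setting `gamma=0` recovers ordinary cross-entropy.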

Intended Use

  • Critical Risk Classification: Accurately determining patient risk (e.g., readmission, complication) from free-text clinical notes.
  • Adversarial Defense: Serving as a defense layer against input tampering or integrity attacks on automated clinical decision support systems.
  • Input Validation: Flagging notes that have been subtly altered or contain contradictory language (ATTACK: DETECTED) for human review.
  • Benchmarking: Evaluating the robustness of different adversarial attack methods in the medical domain.
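The input-validation use case can be sketched as a small routing step applied to the model's class probabilities. This is one possible design, not part of the released model: the `triage` helper, the `Decision` type, the index order in `ID2LABEL`, and the 0.80 threshold are all hypothetical, and the real label mapping should be read from `model.config.id2label`.

```python
from dataclasses import dataclass

ATTACK_LABEL = "ATTACK: DETECTED"
REVIEW_THRESHOLD = 0.80  # hypothetical cutoff; tune on validation data

# Assumed index order for illustration only.
ID2LABEL = {0: "RISK: LOW", 1: "RISK: HIGH", 2: ATTACK_LABEL}


@dataclass
class Decision:
    label: str
    confidence: float
    needs_review: bool


def triage(probs, id2label=ID2LABEL, threshold=REVIEW_THRESHOLD):
    """Pick the top class and decide whether a human should review the note."""
    best = max(range(len(probs)), key=probs.__getitem__)
    label = id2label[best]
    confidence = probs[best]
    # Send to review if the model flags the input or is unsure of its risk label.
    needs_review = label == ATTACK_LABEL or confidence < threshold
    return Decision(label, confidence, needs_review)


if __name__ == "__main__":
    for probs in ([0.985, 0.010, 0.005], [0.25, 0.20, 0.55], [0.60, 0.30, 0.10]):
        print(triage(probs))
```

Routing low-confidence risk predictions to review as well as explicit flags is a conservative choice: it trades extra reviewer workload for fewer silent errors.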

Limitations

  • ATTACK: DETECTED Ambiguity: The ATTACK: DETECTED flag indicates low internal confidence and possible perturbation; it is not proof of malicious intent. Flagged notes still require human review to determine why the text was flagged.
  • Vocabulary: While trained on a clinical corpus, it may struggle with highly specialized jargon or acronyms not covered in the vocabulary.
  • Numerical Parsing: It is robust to simple numerical errors (e.g., $15.5 \rightarrow 5.5$), but complex numerical reasoning remains a challenge.

Example Code (PyTorch)

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F

model_name = "Health/Robust-Clinical-Risk-Detector"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Example 1: Clean, Low Risk Text
text_clean = "History of controlled Type 2 diabetes. No acute issues today. Routine check-up. Risk: LOW."
# Example 2: Adversarially perturbed text (high-risk findings, but the stated label is flipped from HIGH to LOW)
text_attack = "Severe headache, photophobia, and neck stiffness. Suspected Meningitis. Risk: LOW."

# Prepare inputs
inputs = tokenizer([text_clean, text_attack], return_tensors="pt", padding=True, truncation=True)

# Run inference in eval mode without tracking gradients
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

# Get predictions and probabilities
probabilities = F.softmax(outputs.logits, dim=1)
predictions = torch.argmax(probabilities, dim=1)

labels = model.config.id2label
print(f"Text 1 (Clean): Prediction: {labels[predictions[0].item()]}, Confidence: {probabilities[0].max().item():.3f}")
# Expected output (approx): Prediction: RISK: LOW, Confidence: 0.985
print(f"Text 2 (Attack): Prediction: {labels[predictions[1].item()]}, Confidence: {probabilities[1].max().item():.3f}")
# Expected output (approx): Prediction: ATTACK: DETECTED, Confidence: 0.550
# (or RISK: HIGH with high confidence if the model disregards the misleading "Risk: LOW" label in the text)