File size: 6,490 Bytes
8ba82af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import json
import re
import numpy as np
class BioClinicalMedicalCoder(nn.Module):
    """BioClinicalModernBERT for Medical Coding.

    Wraps a pretrained biomedical BERT encoder with a linear projection
    (retrieval embedding space) and a linear head producing one logit per
    medical code (multi-label classification).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Pretrained biomedical encoder backbone.
        self.encoder = AutoModel.from_pretrained(
            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
        )
        self.embedding_dim = config.get("embedding_dim", 768)
        self.num_codes = config["num_codes"]
        # Maps encoder hidden size down/up to the retrieval embedding space.
        self.projection = nn.Linear(self.encoder.config.hidden_size, self.embedding_dim)
        self.dropout = nn.Dropout(0.1)
        # One logit per code, computed from the projected embedding.
        self.classifier = nn.Linear(self.embedding_dim, self.num_codes)

    def forward(self, input_ids, attention_mask, return_embeddings=False):
        """Encode a batch; return embeddings only, or (embeddings, logits)."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # The [CLS] (first) token representation summarizes the sequence.
        pooled = encoded.last_hidden_state[:, 0, :]
        embeddings = self.projection(self.dropout(pooled))
        if return_embeddings:
            return embeddings
        return embeddings, self.classifier(embeddings)
class MedicalCodingPredictor:
    """Inference wrapper for medical coding.

    Loads a trained BioClinicalMedicalCoder checkpoint and its code
    mappings/descriptions from ``model_path``, and exposes threshold-based
    (`predict`) and top-k (`predict_top_k`) prediction APIs.
    """

    def __init__(self, model_path="."):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load configuration
        with open(f"{model_path}/config.json", 'r') as f:
            self.config = json.load(f)
        # Load code <-> index mappings
        with open(f"{model_path}/code_to_idx.json", 'r') as f:
            self.code_to_idx = json.load(f)
        with open(f"{model_path}/idx_to_code.json", 'r') as f:
            # JSON object keys are strings; restore the integer indices.
            self.idx_to_code = {int(k): v for k, v in json.load(f).items()}
        # Load human-readable code descriptions
        with open(f"{model_path}/code_descriptions.json", 'r') as f:
            self.descriptions = json.load(f)
        # Per-code F1 scores are optional; fall back to an empty mapping.
        # Catch only the expected failures (missing or malformed file) rather
        # than a bare `except:` that would also swallow KeyboardInterrupt etc.
        try:
            with open(f"{model_path}/code_f1_scores.json", 'r') as f:
                self.f1_scores = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            self.f1_scores = {}
        # Initialize model from the saved state dict.
        # NOTE(review): torch.load on an untrusted checkpoint can execute
        # arbitrary code via pickle — only load checkpoints you trust.
        self.model = BioClinicalMedicalCoder(self.config)
        state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        # Initialize tokenizer (must match the encoder checkpoint).
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
        )

    def _score(self, clinical_note):
        """Tokenize `clinical_note` and return per-code sigmoid probabilities.

        Returns a 1-D numpy array of length `num_codes` (shared by both
        `predict` and `predict_top_k`, which previously duplicated this code).
        """
        inputs = self.tokenizer(
            clinical_note,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            _, logits = self.model(inputs["input_ids"], inputs["attention_mask"])
            return torch.sigmoid(logits).cpu().numpy()[0]

    def _format_prediction(self, idx, prob):
        """Build the result dict for code index `idx` with probability `prob`."""
        code = self.idx_to_code[idx]
        code_info = self.descriptions.get(code, {})
        return {
            "code": code,
            "type": code_info.get("type", "Unknown"),
            "description": code_info.get("description", f"Medical code {code}"),
            "confidence": float(prob),
            "f1_score": self.f1_scores.get(code, 0.0)
        }

    def predict(self, clinical_note, threshold=0.5, max_codes=20):
        """
        Predict medical codes from clinical note

        Args:
            clinical_note (str): Clinical note text
            threshold (float): Confidence threshold for predictions
            max_codes (int): Maximum number of codes to return

        Returns:
            List of dictionaries with code, type, description, confidence,
            f1_score, sorted by descending confidence.
        """
        probabilities = self._score(clinical_note)
        # Keep only predictions strictly above the threshold.
        predictions = [
            self._format_prediction(idx, prob)
            for idx, prob in enumerate(probabilities)
            if prob > threshold
        ]
        # Sort by confidence and limit results
        predictions.sort(key=lambda x: x["confidence"], reverse=True)
        return predictions[:max_codes]

    def predict_top_k(self, clinical_note, k=10):
        """Get top k predictions regardless of threshold (highest first)."""
        probabilities = self._score(clinical_note)
        # argsort ascending; take the last k and reverse for descending order.
        top_indices = np.argsort(probabilities)[-k:][::-1]
        return [self._format_prediction(idx, probabilities[idx]) for idx in top_indices]
# Example usage
if __name__ == "__main__":
    predictor = MedicalCodingPredictor()

    # Example clinical note (suspected inferior-wall MI presentation).
    clinical_note = """
Patient presents with chest pain and shortness of breath.
ECG shows ST elevation in leads II, III, aVF suggesting inferior wall MI.
Cardiac enzymes elevated. Started on dual antiplatelet therapy.
"""

    # Print each predicted code with its metadata and scores.
    for result in predictor.predict(clinical_note, threshold=0.5):
        print(f"Code: {result['code']}")
        print(f"Type: {result['type']}")
        print(f"Description: {result['description']}")
        print(f"Confidence: {result['confidence']:.3f}")
        print(f"F1 Score: {result['f1_score']:.3f}")
        print("-" * 50)
|