import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import json
import numpy as np


class BioClinicalMedicalCoder(nn.Module):
    """Pretrained biomedical BERT encoder with a multi-label classification head for medical coding."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Pretrained biomedical encoder (BiomedBERT, formerly PubMedBERT).
        self.encoder = AutoModel.from_pretrained(
            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
        )
        self.embedding_dim = config.get("embedding_dim", 768)
        self.num_codes = config["num_codes"]

        # Project the encoder's [CLS] representation into the embedding space.
        self.projection = nn.Linear(self.encoder.config.hidden_size, self.embedding_dim)
        self.dropout = nn.Dropout(0.1)

        # One logit per medical code (multi-label classification).
        self.classifier = nn.Linear(self.embedding_dim, self.num_codes)

    def forward(self, input_ids, attention_mask, return_embeddings=False):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Use the [CLS] token representation as the note-level embedding.
        cls_output = outputs.last_hidden_state[:, 0, :]

        embeddings = self.projection(self.dropout(cls_output))

        if return_embeddings:
            return embeddings

        logits = self.classifier(embeddings)
        return embeddings, logits


class MedicalCodingPredictor:
    """Inference wrapper for medical coding."""

    def __init__(self, model_path="."):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Model hyperparameters (e.g. embedding_dim, num_codes).
        with open(f"{model_path}/config.json", "r") as f:
            self.config = json.load(f)

        # Mappings between code strings and classifier output indices.
        with open(f"{model_path}/code_to_idx.json", "r") as f:
            self.code_to_idx = json.load(f)
        with open(f"{model_path}/idx_to_code.json", "r") as f:
            self.idx_to_code = {int(k): v for k, v in json.load(f).items()}

        # Human-readable code metadata (type, description).
        with open(f"{model_path}/code_descriptions.json", "r") as f:
            self.descriptions = json.load(f)

        # Per-code validation F1 scores are optional; fall back to an empty dict.
        try:
            with open(f"{model_path}/code_f1_scores.json", "r") as f:
                self.f1_scores = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            self.f1_scores = {}

        # Rebuild the architecture and load the trained weights.
        self.model = BioClinicalMedicalCoder(self.config)
        state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
        )

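    # Note on the artifacts expected in model_path (illustrative sketch; the exact
    # contents depend on how the model was exported, only the keys read above are
    # guaranteed):
    #   config.json            -> {"embedding_dim": 768, "num_codes": 1234, ...}
    #   code_to_idx.json       -> {"I21.19": 0, ...}
    #   idx_to_code.json       -> {"0": "I21.19", ...}
    #   code_descriptions.json -> {"I21.19": {"type": "ICD-10-CM", "description": "..."}}
    #   code_f1_scores.json    -> {"I21.19": 0.72, ...}  (optional)
    #   pytorch_model.bin      -> state dict saved with torch.save()
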
    def predict(self, clinical_note, threshold=0.5, max_codes=20):
        """
        Predict medical codes from a clinical note.

        Args:
            clinical_note (str): Clinical note text.
            threshold (float): Confidence threshold for predictions.
            max_codes (int): Maximum number of codes to return.

        Returns:
            List of dictionaries with code, type, description, confidence, f1_score.
        """
        inputs = self.tokenizer(
            clinical_note,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            _, logits = self.model(inputs["input_ids"], inputs["attention_mask"])
            # Sigmoid per code: each code is an independent binary decision.
            probabilities = torch.sigmoid(logits).cpu().numpy()[0]

        predictions = []
        for idx, prob in enumerate(probabilities):
            if prob > threshold:
                code = self.idx_to_code[idx]
                code_info = self.descriptions.get(code, {})

                predictions.append({
                    "code": code,
                    "type": code_info.get("type", "Unknown"),
                    "description": code_info.get("description", f"Medical code {code}"),
                    "confidence": float(prob),
                    "f1_score": self.f1_scores.get(code, 0.0),
                })

        # Highest-confidence codes first, capped at max_codes.
        predictions.sort(key=lambda x: x["confidence"], reverse=True)
        return predictions[:max_codes]

    def predict_top_k(self, clinical_note, k=10):
        """Return the top k predictions regardless of threshold."""
        inputs = self.tokenizer(
            clinical_note,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            _, logits = self.model(inputs["input_ids"], inputs["attention_mask"])
            probabilities = torch.sigmoid(logits).cpu().numpy()[0]

        # Indices of the k largest probabilities, highest first.
        top_indices = np.argsort(probabilities)[-k:][::-1]

        predictions = []
        for idx in top_indices:
            code = self.idx_to_code[int(idx)]
            prob = probabilities[idx]
            code_info = self.descriptions.get(code, {})

            predictions.append({
                "code": code,
                "type": code_info.get("type", "Unknown"),
                "description": code_info.get("description", f"Medical code {code}"),
                "confidence": float(prob),
                "f1_score": self.f1_scores.get(code, 0.0),
            })

        return predictions

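
# Minimal sketch (an addition, not part of the original pipeline): a quick check of
# the forward() contract using a dummy config and random token ids. The values
# embedding_dim=768 and num_codes=50 are illustrative assumptions, not the shipped
# config.json.
def _shape_smoke_test():
    dummy_config = {"embedding_dim": 768, "num_codes": 50}  # assumed values
    model = BioClinicalMedicalCoder(dummy_config)
    model.eval()

    # Two fake notes of 16 tokens each; ids only need to fall within the vocab.
    input_ids = torch.randint(0, 1000, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        embeddings, logits = model(input_ids, attention_mask)

    assert embeddings.shape == (2, 768)  # [batch, embedding_dim]
    assert logits.shape == (2, 50)       # [batch, num_codes]: one logit per code
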

if __name__ == "__main__":
    predictor = MedicalCodingPredictor()

    clinical_note = """
    Patient presents with chest pain and shortness of breath.
    ECG shows ST elevation in leads II, III, aVF suggesting inferior wall MI.
    Cardiac enzymes elevated. Started on dual antiplatelet therapy.
    """

    predictions = predictor.predict(clinical_note, threshold=0.5)

    for pred in predictions:
        print(f"Code: {pred['code']}")
        print(f"Type: {pred['type']}")
        print(f"Description: {pred['description']}")
        print(f"Confidence: {pred['confidence']:.3f}")
        print(f"F1 Score: {pred['f1_score']:.3f}")
        print("-" * 50)
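
    # Hedged follow-up example: the same note scored with predict_top_k(), which
    # returns the k highest-probability codes even if none clear the threshold.
    # k=5 is an arbitrary choice for this demo.
    top_predictions = predictor.predict_top_k(clinical_note, k=5)
    print("\nTop-5 codes by confidence:")
    for pred in top_predictions:
        print(f"{pred['code']}: {pred['confidence']:.3f} ({pred['description']})")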
|