# Source: Hugging Face repository upload by sshan95
# ("Upload folder using huggingface_hub"), commit 8ba82af (verified).
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import json
import re
import numpy as np
class BioClinicalMedicalCoder(nn.Module):
    """Medical-code classifier built on a pretrained Hugging Face encoder.

    Despite the historical "BioClinicalModernBERT" name, the default encoder is
    ``microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext``. The first
    ([CLS]) token of the encoder output is passed through dropout, projected into
    an ``embedding_dim``-sized embedding space (usable for retrieval), and then
    mapped to ``num_codes`` classification logits.

    Args:
        config (dict): Must contain ``"num_codes"``; may contain
            ``"embedding_dim"`` (defaults to 768).
        encoder_name (str): Hugging Face model id or local path of the encoder
            checkpoint. Defaults to ``DEFAULT_ENCODER``, which preserves the
            original behaviour.
    """

    DEFAULT_ENCODER = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"

    def __init__(self, config, encoder_name=DEFAULT_ENCODER):
        super().__init__()
        self.config = config
        # NOTE: from_pretrained downloads/loads encoder weights; when a full
        # state dict is loaded afterwards these weights are overwritten.
        self.encoder = AutoModel.from_pretrained(encoder_name)
        self.embedding_dim = config.get("embedding_dim", 768)
        self.num_codes = config["num_codes"]

        # Projection layer for retrieval: encoder hidden size -> embedding_dim.
        self.projection = nn.Linear(self.encoder.config.hidden_size, self.embedding_dim)
        self.dropout = nn.Dropout(0.1)

        # Classification head: one logit per code.
        self.classifier = nn.Linear(self.embedding_dim, self.num_codes)

    def forward(self, input_ids, attention_mask, return_embeddings=False):
        """Encode a batch and return embeddings and (optionally) code logits.

        Args:
            input_ids: Token ids, assumed to be shaped (batch, seq_len) as
                produced by the tokenizer with ``return_tensors="pt"``.
            attention_mask: Mask with the same shape as ``input_ids``.
            return_embeddings (bool): If True, return only the projected
                embeddings and skip the classification head.

        Returns:
            ``embeddings`` of shape (batch, embedding_dim) when
            ``return_embeddings`` is True, otherwise the tuple
            ``(embeddings, logits)`` with logits of shape (batch, num_codes).
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Pool by taking the first ([CLS]) token of every sequence.
        cls_output = outputs.last_hidden_state[:, 0, :]

        # Dropout is applied before the projection (a no-op in eval mode).
        embeddings = self.projection(self.dropout(cls_output))
        if return_embeddings:
            return embeddings

        logits = self.classifier(embeddings)
        return embeddings, logits
class MedicalCodingPredictor:
    """Inference wrapper around a trained ``BioClinicalMedicalCoder``.

    Loads the model weights, label mappings and code metadata from
    ``model_path`` and turns a free-text clinical note into a ranked list of
    code predictions.

    Files read from ``model_path``:
        config.json            -- config passed to BioClinicalMedicalCoder
        code_to_idx.json       -- presumably code -> output index
        idx_to_code.json       -- output index (stored as a string key) -> code
        code_descriptions.json -- code -> dict read for "type"/"description"
        code_f1_scores.json    -- optional; code -> F1 score
        pytorch_model.bin      -- state dict for BioClinicalMedicalCoder
    """

    def __init__(self, model_path="."):
        """Load config, mappings, metadata, model weights and tokenizer.

        Args:
            model_path (str): Directory containing the files listed above.

        Raises:
            OSError: if a required file is missing or unreadable.
            ValueError: if a required JSON file cannot be decoded.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.config = self._load_json(model_path, "config.json")
        self.code_to_idx = self._load_json(model_path, "code_to_idx.json")
        # JSON object keys are always strings; convert back to int so they match
        # positions in the model's probability vector.
        self.idx_to_code = {
            int(k): v for k, v in self._load_json(model_path, "idx_to_code.json").items()
        }
        self.descriptions = self._load_json(model_path, "code_descriptions.json")

        # F1 scores are optional metadata. If the file is missing or undecodable,
        # every code reports 0.0. Only I/O and decoding errors are tolerated
        # (JSONDecodeError/UnicodeDecodeError are ValueError subclasses); anything
        # else, e.g. KeyboardInterrupt, still propagates.
        try:
            self.f1_scores = self._load_json(model_path, "code_f1_scores.json")
        except (OSError, ValueError):
            self.f1_scores = {}

        self.model = BioClinicalMedicalCoder(self.config)
        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only load trusted files, or pass weights_only=True
        # (supported on torch >= 1.13).
        state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()  # disable dropout for inference

        # Same checkpoint id as the default encoder used by BioClinicalMedicalCoder.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
        )

    @staticmethod
    def _load_json(model_path, filename):
        """Read ``<model_path>/<filename>`` and decode it as UTF-8 JSON."""
        with open(f"{model_path}/{filename}", "r", encoding="utf-8") as f:
            return json.load(f)

    def _predict_probabilities(self, clinical_note):
        """Return the per-code sigmoid probabilities for a single note.

        Notes longer than 512 tokens are truncated. The result is the first
        (and only) row of the model's logits, moved to CPU as a numpy array.
        """
        inputs = self.tokenizer(
            clinical_note,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            _, logits = self.model(inputs["input_ids"], inputs["attention_mask"])
        return torch.sigmoid(logits).cpu().numpy()[0]

    def _build_prediction(self, idx, prob):
        """Build the result dict for output index ``idx`` with probability ``prob``."""
        code = self.idx_to_code[int(idx)]
        code_info = self.descriptions.get(code, {})
        return {
            "code": code,
            "type": code_info.get("type", "Unknown"),
            "description": code_info.get("description", f"Medical code {code}"),
            "confidence": float(prob),
            "f1_score": self.f1_scores.get(code, 0.0),
        }

    def predict(self, clinical_note, threshold=0.5, max_codes=20):
        """
        Predict medical codes from a clinical note.

        Args:
            clinical_note (str): Clinical note text.
            threshold (float): A code is kept only if its probability is
                strictly greater than this value.
            max_codes (int): Maximum number of codes to return.

        Returns:
            List of dicts with keys code, type, description, confidence and
            f1_score, sorted by descending confidence. Ties keep index order.
        """
        probabilities = self._predict_probabilities(clinical_note)
        predictions = [
            self._build_prediction(idx, prob)
            for idx, prob in enumerate(probabilities)
            if prob > threshold
        ]
        # list.sort is stable, so equal confidences keep ascending index order.
        predictions.sort(key=lambda p: p["confidence"], reverse=True)
        return predictions[:max_codes]

    def predict_top_k(self, clinical_note, k=10):
        """Return the ``k`` most probable codes, regardless of any threshold.

        Args:
            clinical_note (str): Clinical note text.
            k (int): Number of codes to return. ``k <= 0`` returns an empty
                list without running the model.

        Returns:
            List of prediction dicts (same keys as ``predict``), ordered by
            descending confidence.
        """
        # Guard: argsort(...)[-0:] would select *every* index, so k == 0 (and
        # negative k) must be handled explicitly.
        if k <= 0:
            return []
        probabilities = self._predict_probabilities(clinical_note)
        top_indices = np.argsort(probabilities)[-k:][::-1]
        return [self._build_prediction(idx, probabilities[idx]) for idx in top_indices]
# Example usage
def main():
    """Run the predictor on a sample note and print the predicted codes.

    ``MedicalCodingPredictor()`` uses its default ``model_path="."``, so the
    model artifacts are read from the current working directory.
    """
    predictor = MedicalCodingPredictor()

    clinical_note = (
        "Patient presents with chest pain and shortness of breath.\n"
        "ECG shows ST elevation in leads II, III, aVF suggesting inferior wall MI.\n"
        "Cardiac enzymes elevated. Started on dual antiplatelet therapy.\n"
    )

    threshold = 0.5
    predictions = predictor.predict(clinical_note, threshold=threshold)

    # Make an empty result visible instead of silently printing nothing.
    if not predictions:
        print(f"No codes predicted above threshold {threshold}.")
        return

    for pred in predictions:
        print(f"Code: {pred['code']}")
        print(f"Type: {pred['type']}")
        print(f"Description: {pred['description']}")
        print(f"Confidence: {pred['confidence']:.3f}")
        print(f"F1 Score: {pred['f1_score']:.3f}")
        print("-" * 50)


if __name__ == "__main__":
    main()