"""Inference utilities for a BiomedBERT-based medical coding model.

``BioClinicalMedicalCoder`` is the ``nn.Module``; ``MedicalCodingPredictor``
loads a saved checkpoint directory and turns a clinical note into a list of
scored code predictions.
"""

import json
import re  # NOTE(review): not used anywhere in this module
import warnings
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Hugging Face hub id used for both the encoder weights and the tokenizer
# unless the model config supplies an "encoder_name" override.
DEFAULT_ENCODER_NAME = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"

# Token limit applied when encoding a note; longer notes are truncated.
MAX_SEQ_LENGTH = 512


class BioClinicalMedicalCoder(nn.Module):
    """Medical code classifier built on a BiomedBERT encoder.

    The first-token ([CLS]) hidden state goes through dropout and a linear
    projection into an embedding space (exposed for retrieval); a linear
    classifier on that embedding yields one logit per code.

    Args:
        config: Mapping with required key ``"num_codes"`` and optional keys
            ``"embedding_dim"`` (default 768) and ``"encoder_name"``
            (default ``DEFAULT_ENCODER_NAME``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # NOTE: from_pretrained downloads/loads hub weights; at inference time
        # MedicalCodingPredictor overwrites them with the saved state dict.
        self.encoder = AutoModel.from_pretrained(
            config.get("encoder_name", DEFAULT_ENCODER_NAME)
        )
        self.embedding_dim = config.get("embedding_dim", 768)
        self.num_codes = config["num_codes"]

        # Projection layer for retrieval embeddings.
        self.projection = nn.Linear(self.encoder.config.hidden_size, self.embedding_dim)
        self.dropout = nn.Dropout(0.1)

        # Classification head: one output logit per code.
        self.classifier = nn.Linear(self.embedding_dim, self.num_codes)

    def forward(self, input_ids, attention_mask, return_embeddings=False):
        """Encode a batch and return embeddings (and logits).

        Args:
            input_ids: Token ids as produced by the matching tokenizer.
            attention_mask: Attention mask as produced by the tokenizer.
            return_embeddings: If True, return only the projected embeddings.

        Returns:
            ``embeddings`` when ``return_embeddings`` is True, otherwise the
            tuple ``(embeddings, logits)``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first ([CLS]) token's final hidden state as the note summary.
        cls_output = outputs.last_hidden_state[:, 0, :]
        embeddings = self.projection(self.dropout(cls_output))
        if return_embeddings:
            return embeddings
        logits = self.classifier(embeddings)
        return embeddings, logits


def _read_json(path):
    """Read and parse a UTF-8 JSON file."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _read_optional_json(path):
    """Read an optional JSON file, returning ``{}`` if absent or unreadable."""
    try:
        return _read_json(path)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Stay best-effort, but don't hide a corrupt file silently.
        warnings.warn(f"Could not load {path}: {err}; continuing without it.")
        return {}


def _load_state_dict(path, device):
    """Load a checkpoint state dict, avoiding arbitrary pickle execution when possible."""
    try:
        # weights_only=True restricts unpickling to tensors/primitive containers.
        return torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` kwarg. SECURITY: this path
        # performs a full pickle load; only use checkpoints you trust.
        return torch.load(path, map_location=device)


class MedicalCodingPredictor:
    """Inference wrapper for medical coding.

    ``model_path`` must contain ``config.json``, ``code_to_idx.json``,
    ``idx_to_code.json``, ``code_descriptions.json`` and ``pytorch_model.bin``.
    ``code_f1_scores.json`` is optional; when missing, every code's
    ``f1_score`` is reported as 0.0.
    """

    def __init__(self, model_path="."):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_dir = Path(model_path)

        # Configuration and label mappings.
        self.config = _read_json(model_dir / "config.json")
        self.code_to_idx = _read_json(model_dir / "code_to_idx.json")
        # JSON object keys are always strings; restore integer class indices.
        self.idx_to_code = {
            int(k): v for k, v in _read_json(model_dir / "idx_to_code.json").items()
        }
        self.descriptions = _read_json(model_dir / "code_descriptions.json")
        self.f1_scores = _read_optional_json(model_dir / "code_f1_scores.json")

        # Model with fine-tuned weights, in eval mode (disables dropout).
        self.model = BioClinicalMedicalCoder(self.config)
        state_dict = _load_state_dict(model_dir / "pytorch_model.bin", self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.get("encoder_name", DEFAULT_ENCODER_NAME)
        )

    def predict(self, clinical_note, threshold=0.5, max_codes=20):
        """
        Predict medical codes from clinical note

        Args:
            clinical_note (str): Clinical note text
            threshold (float): Confidence threshold; only codes with
                probability strictly greater than this are returned
            max_codes (int): Maximum number of codes to return

        Returns:
            List of dictionaries with code, type, description, confidence,
            f1_score, sorted by descending confidence.
        """
        probabilities = self._predict_probabilities(clinical_note)
        # flatnonzero yields indices in ascending order, so the stable sort
        # below breaks confidence ties by index exactly as a linear scan would.
        predictions = [
            self._format_prediction(idx, probabilities[idx])
            for idx in np.flatnonzero(probabilities > threshold)
        ]
        predictions.sort(key=lambda p: p["confidence"], reverse=True)
        return predictions[:max_codes]

    def predict_top_k(self, clinical_note, k=10):
        """Get the top ``k`` predictions regardless of threshold.

        Args:
            clinical_note (str): Clinical note text.
            k (int): Number of codes to return; ``k <= 0`` yields ``[]``.

        Returns:
            Up to ``k`` prediction dicts in descending confidence order.
        """
        # Guard: argsort(...)[-0:] would select *every* code, not none.
        if k <= 0:
            return []
        probabilities = self._predict_probabilities(clinical_note)
        top_indices = np.argsort(probabilities)[-k:][::-1]
        return [self._format_prediction(idx, probabilities[idx]) for idx in top_indices]

    def _predict_probabilities(self, clinical_note):
        """Tokenize one note, run the model, and return per-code probabilities (1-D array)."""
        inputs = self.tokenizer(
            clinical_note,
            truncation=True,
            padding=True,
            max_length=MAX_SEQ_LENGTH,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            _, logits = self.model(inputs["input_ids"], inputs["attention_mask"])
        # Sigmoid scores each code independently; [0] takes the single
        # example of this batch of one.
        return torch.sigmoid(logits).cpu().numpy()[0]

    def _format_prediction(self, idx, probability):
        """Build the public prediction dict for class index ``idx``."""
        code = self.idx_to_code[int(idx)]
        code_info = self.descriptions.get(code, {})
        return {
            "code": code,
            "type": code_info.get("type", "Unknown"),
            "description": code_info.get("description", f"Medical code {code}"),
            "confidence": float(probability),
            "f1_score": self.f1_scores.get(code, 0.0),
        }


# Example usage
if __name__ == "__main__":
    predictor = MedicalCodingPredictor()

    # Example clinical note
    clinical_note = """
    Patient presents with chest pain and shortness of breath.
    ECG shows ST elevation in leads II, III, aVF suggesting inferior wall MI.
    Cardiac enzymes elevated. Started on dual antiplatelet therapy.
    """

    # Get predictions
    predictions = predictor.predict(clinical_note, threshold=0.5)

    for pred in predictions:
        print(f"Code: {pred['code']}")
        print(f"Type: {pred['type']}")
        print(f"Description: {pred['description']}")
        print(f"Confidence: {pred['confidence']:.3f}")
        print(f"F1 Score: {pred['f1_score']:.3f}")
        print("-" * 50)