File size: 6,490 Bytes
8ba82af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import json
import re
import numpy as np

class BioClinicalMedicalCoder(nn.Module):
    """Dual-head medical-coding model on top of a biomedical BERT encoder.

    Produces a projected embedding (for retrieval) and, unless asked for
    embeddings only, multi-label classification logits over the code set.
    """

    def __init__(self, config):
        """
        Args:
            config (dict): must contain "num_codes"; may contain
                "embedding_dim" (defaults to 768).
        """
        super().__init__()
        self.config = config
        self.encoder = AutoModel.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
        self.embedding_dim = config.get("embedding_dim", 768)
        self.num_codes = config["num_codes"]

        # Retrieval projection: encoder hidden size -> shared embedding space.
        self.projection = nn.Linear(self.encoder.config.hidden_size, self.embedding_dim)
        self.dropout = nn.Dropout(0.1)

        # Multi-label classification head over the projected embeddings.
        self.classifier = nn.Linear(self.embedding_dim, self.num_codes)

    def forward(self, input_ids, attention_mask, return_embeddings=False):
        """Encode tokens; return embeddings, or (embeddings, logits)."""
        encoder_out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # The [CLS] position (index 0) summarizes the whole sequence.
        pooled = encoder_out.last_hidden_state[:, 0, :]
        embeddings = self.projection(self.dropout(pooled))

        if return_embeddings:
            return embeddings

        return embeddings, self.classifier(embeddings)

class MedicalCodingPredictor:
    """Inference wrapper for medical coding.

    Loads model weights, code/index mappings, code descriptions, and
    optional per-code F1 scores from ``model_path``, then exposes
    threshold-based and top-k prediction over raw clinical-note text.
    """

    def __init__(self, model_path="."):
        """
        Args:
            model_path (str): directory containing config.json, the JSON
                mapping/description files, and pytorch_model.bin.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load configuration
        with open(f"{model_path}/config.json", 'r') as f:
            self.config = json.load(f)

        # Load mappings; JSON object keys are strings, so index keys are
        # converted back to int for idx_to_code.
        with open(f"{model_path}/code_to_idx.json", 'r') as f:
            self.code_to_idx = json.load(f)
        with open(f"{model_path}/idx_to_code.json", 'r') as f:
            self.idx_to_code = {int(k): v for k, v in json.load(f).items()}

        # Load human-readable code descriptions.
        with open(f"{model_path}/code_descriptions.json", 'r') as f:
            self.descriptions = json.load(f)

        # F1 scores are optional: fall back to an empty mapping when the
        # file is missing or unreadable. Narrowed from a bare `except:`,
        # which also swallowed KeyboardInterrupt and real bugs.
        try:
            with open(f"{model_path}/code_f1_scores.json", 'r') as f:
                self.f1_scores = json.load(f)
        except (OSError, json.JSONDecodeError):
            self.f1_scores = {}

        # Initialize model and load weights.
        # NOTE(review): torch.load unpickles arbitrary objects; consider
        # weights_only=True (torch >= 1.13) for untrusted checkpoints.
        self.model = BioClinicalMedicalCoder(self.config)
        state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        # Tokenizer must match the encoder checkpoint used by the model.
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")

    def _score(self, clinical_note):
        """Tokenize the note and return per-code sigmoid probabilities.

        Shared by predict() and predict_top_k(); returns a 1-D numpy array
        of length num_codes.
        """
        inputs = self.tokenizer(
            clinical_note,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            _, logits = self.model(inputs["input_ids"], inputs["attention_mask"])
            return torch.sigmoid(logits).cpu().numpy()[0]

    def _describe(self, idx, prob):
        """Build one prediction record for code index `idx` at probability `prob`."""
        # int() also normalizes numpy integer indices from argsort.
        code = self.idx_to_code[int(idx)]
        code_info = self.descriptions.get(code, {})
        return {
            "code": code,
            "type": code_info.get("type", "Unknown"),
            "description": code_info.get("description", f"Medical code {code}"),
            "confidence": float(prob),
            "f1_score": self.f1_scores.get(code, 0.0)
        }

    def predict(self, clinical_note, threshold=0.5, max_codes=20):
        """
        Predict medical codes from clinical note

        Args:
            clinical_note (str): Clinical note text
            threshold (float): Confidence threshold for predictions
            max_codes (int): Maximum number of codes to return

        Returns:
            List of dictionaries with code, type, description, confidence,
            f1_score — sorted by descending confidence, at most max_codes.
        """
        probabilities = self._score(clinical_note)

        predictions = [
            self._describe(idx, prob)
            for idx, prob in enumerate(probabilities)
            if prob > threshold
        ]

        # Sort by confidence and limit results
        predictions.sort(key=lambda p: p["confidence"], reverse=True)
        return predictions[:max_codes]

    def predict_top_k(self, clinical_note, k=10):
        """Get top k predictions regardless of threshold"""
        probabilities = self._score(clinical_note)

        # Highest-probability indices, in descending order.
        top_indices = np.argsort(probabilities)[-k:][::-1]
        return [self._describe(idx, probabilities[idx]) for idx in top_indices]

# Example usage
if __name__ == "__main__":
    predictor = MedicalCodingPredictor()

    # Example clinical note: inferior-wall myocardial infarction workup.
    clinical_note = """
    Patient presents with chest pain and shortness of breath. 
    ECG shows ST elevation in leads II, III, aVF suggesting inferior wall MI.
    Cardiac enzymes elevated. Started on dual antiplatelet therapy.
    """

    # Print one formatted record per predicted code.
    for pred in predictor.predict(clinical_note, threshold=0.5):
        summary = "\n".join([
            f"Code: {pred['code']}",
            f"Type: {pred['type']}",
            f"Description: {pred['description']}",
            f"Confidence: {pred['confidence']:.3f}",
            f"F1 Score: {pred['f1_score']:.3f}",
            "-" * 50,
        ])
        print(summary)