monajm36 commited on
Commit
3130918
Β·
unverified Β·
1 Parent(s): 39a2c30

Create ohca_inference.py

Browse files
Files changed (1) hide show
  1. src/ohca_inference.py +455 -0
src/ohca_inference.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OHCA Inference Module
2
+ # Apply pre-trained OHCA classifier to new datasets
3
+
4
+ import pandas as pd
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch.utils.data import DataLoader, Dataset
9
+ from tqdm import tqdm
10
+ import os
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
+
15
+ # =============================================================================
16
+ # CONFIGURATION
17
+ # =============================================================================
18
+
19
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ print(f"Inference Module - Using device: {DEVICE}")
21
+
22
+ # =============================================================================
23
+ # INFERENCE DATASET CLASS
24
+ # =============================================================================
25
+
26
+ class OHCAInferenceDataset(Dataset):
27
+ """Dataset for OHCA inference on new data"""
28
+
29
+ def __init__(self, dataframe, tokenizer, max_length=512):
30
+ self.data = dataframe.reset_index(drop=True)
31
+ self.tokenizer = tokenizer
32
+ self.max_length = max_length
33
+
34
+ # Validate required columns
35
+ if 'hadm_id' not in self.data.columns or 'clean_text' not in self.data.columns:
36
+ raise ValueError("DataFrame must contain 'hadm_id' and 'clean_text' columns")
37
+
38
+ def __len__(self):
39
+ return len(self.data)
40
+
41
+ def __getitem__(self, idx):
42
+ row = self.data.iloc[idx]
43
+ text = str(row['clean_text'])
44
+
45
+ # Apply preprocessing consistent with training
46
+ if 'transfer' in text.lower():
47
+ text = "TRANSFERRED_PATIENT " + text
48
+
49
+ encoding = self.tokenizer(
50
+ text,
51
+ truncation=True,
52
+ padding='max_length',
53
+ max_length=self.max_length,
54
+ return_tensors='pt'
55
+ )
56
+
57
+ return {
58
+ 'input_ids': encoding['input_ids'].flatten(),
59
+ 'attention_mask': encoding['attention_mask'].flatten(),
60
+ 'hadm_id': row['hadm_id']
61
+ }
62
+
63
+ # =============================================================================
64
+ # MODEL LOADING FUNCTIONS
65
+ # =============================================================================
66
+
67
+ def load_ohca_model(model_path):
68
+ """
69
+ Load pre-trained OHCA model and tokenizer
70
+
71
+ Args:
72
+ model_path: Path to saved model directory
73
+
74
+ Returns:
75
+ tuple: (model, tokenizer)
76
+ """
77
+ print(f"πŸ“‚ Loading OHCA model from: {model_path}")
78
+
79
+ if not os.path.exists(model_path):
80
+ raise FileNotFoundError(f"Model not found at: {model_path}")
81
+
82
+ try:
83
+ # Load tokenizer and model
84
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
85
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
86
+ model = model.to(DEVICE)
87
+ model.eval()
88
+
89
+ print("βœ… Model loaded successfully")
90
+ print(f" Device: {DEVICE}")
91
+ print(f" Model type: {type(model).__name__}")
92
+
93
+ return model, tokenizer
94
+
95
+ except Exception as e:
96
+ raise RuntimeError(f"Failed to load model: {str(e)}")
97
+
98
+ # =============================================================================
99
+ # INFERENCE FUNCTIONS
100
+ # =============================================================================
101
+
102
+ def run_inference(model, tokenizer, inference_df, batch_size=16,
103
+ output_path=None, probability_threshold=0.5):
104
+ """
105
+ Run OHCA inference on new data
106
+
107
+ Args:
108
+ model: Pre-trained OHCA model
109
+ tokenizer: Model tokenizer
110
+ inference_df: DataFrame with columns ['hadm_id', 'clean_text']
111
+ batch_size: Batch size for inference
112
+ output_path: Optional path to save results CSV
113
+ probability_threshold: Threshold for binary predictions
114
+
115
+ Returns:
116
+ DataFrame: Results with probabilities and predictions
117
+ """
118
+ print(f"πŸ” Running OHCA inference on {len(inference_df):,} cases...")
119
+
120
+ # Validate input data
121
+ required_cols = ['hadm_id', 'clean_text']
122
+ missing_cols = [col for col in required_cols if col not in inference_df.columns]
123
+ if missing_cols:
124
+ raise ValueError(f"Missing required columns: {missing_cols}")
125
+
126
+ # Remove any rows with missing data
127
+ clean_df = inference_df.dropna(subset=required_cols).copy()
128
+ if len(clean_df) < len(inference_df):
129
+ print(f"⚠️ Removed {len(inference_df) - len(clean_df)} rows with missing data")
130
+
131
+ # Create dataset and dataloader
132
+ inference_dataset = OHCAInferenceDataset(clean_df, tokenizer)
133
+ inference_dataloader = DataLoader(inference_dataset, batch_size=batch_size, shuffle=False)
134
+
135
+ # Run inference
136
+ model.eval()
137
+ all_probabilities = []
138
+ all_hadm_ids = []
139
+
140
+ with torch.no_grad():
141
+ for batch in tqdm(inference_dataloader, desc="Processing batches"):
142
+ input_ids = batch['input_ids'].to(DEVICE)
143
+ attention_mask = batch['attention_mask'].to(DEVICE)
144
+ hadm_ids = batch['hadm_id']
145
+
146
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
147
+ probs = F.softmax(outputs.logits, dim=1)
148
+
149
+ # Get OHCA probabilities (class 1)
150
+ ohca_probs = probs[:, 1].cpu().numpy()
151
+
152
+ all_probabilities.extend(ohca_probs)
153
+ all_hadm_ids.extend(hadm_ids)
154
+
155
+ # Create results dataframe
156
+ results_df = pd.DataFrame({
157
+ 'hadm_id': all_hadm_ids,
158
+ 'ohca_probability': all_probabilities
159
+ })
160
+
161
+ # Add predictions with different thresholds
162
+ results_df['prediction_050'] = (results_df['ohca_probability'] >= 0.5).astype(int)
163
+ results_df['prediction_070'] = (results_df['ohca_probability'] >= 0.7).astype(int)
164
+ results_df['prediction_090'] = (results_df['ohca_probability'] >= 0.9).astype(int)
165
+ results_df['prediction_custom'] = (results_df['ohca_probability'] >= probability_threshold).astype(int)
166
+
167
+ # Add confidence categories
168
+ def categorize_confidence(prob):
169
+ if prob >= 0.9:
170
+ return "Very High"
171
+ elif prob >= 0.7:
172
+ return "High"
173
+ elif prob >= 0.3:
174
+ return "Medium"
175
+ elif prob >= 0.1:
176
+ return "Low"
177
+ else:
178
+ return "Very Low"
179
+
180
+ results_df['confidence_category'] = results_df['ohca_probability'].apply(categorize_confidence)
181
+
182
+ # Sort by probability (highest first)
183
+ results_df = results_df.sort_values('ohca_probability', ascending=False).reset_index(drop=True)
184
+
185
+ # Print summary
186
+ print(f"\nπŸ“Š Inference Results Summary:")
187
+ print(f" Total cases processed: {len(results_df):,}")
188
+ print(f" Mean OHCA probability: {results_df['ohca_probability'].mean():.4f}")
189
+ print(f" Max OHCA probability: {results_df['ohca_probability'].max():.3f}")
190
+ print(f" Min OHCA probability: {results_df['ohca_probability'].min():.3f}")
191
+
192
+ # Probability distribution
193
+ print(f"\n🎯 Probability Distribution:")
194
+ thresholds = [0.9, 0.8, 0.7, 0.6, 0.5, 0.3, 0.1]
195
+ for threshold in thresholds:
196
+ count = (results_df['ohca_probability'] >= threshold).sum()
197
+ pct = count / len(results_df) * 100
198
+ print(f" β‰₯{threshold}: {count:,} cases ({pct:.2f}%)")
199
+
200
+ # Confidence categories
201
+ print(f"\nπŸ“ˆ Confidence Distribution:")
202
+ conf_dist = results_df['confidence_category'].value_counts()
203
+ for category, count in conf_dist.items():
204
+ pct = count / len(results_df) * 100
205
+ print(f" {category}: {count:,} cases ({pct:.1f}%)")
206
+
207
+ # Save results if path provided
208
+ if output_path:
209
+ results_df.to_csv(output_path, index=False)
210
+ print(f"\nπŸ’Ύ Results saved to: {output_path}")
211
+
212
+ return results_df
213
+
214
+ def get_high_confidence_cases(results_df, threshold=0.8, max_cases=100):
215
+ """
216
+ Extract high-confidence OHCA predictions for manual review
217
+
218
+ Args:
219
+ results_df: Results from run_inference()
220
+ threshold: Minimum probability threshold
221
+ max_cases: Maximum number of cases to return
222
+
223
+ Returns:
224
+ DataFrame: High-confidence cases sorted by probability
225
+ """
226
+ high_conf = results_df[results_df['ohca_probability'] >= threshold].copy()
227
+ high_conf = high_conf.head(max_cases)
228
+
229
+ print(f"🎯 Found {len(high_conf)} high-confidence cases (β‰₯{threshold})")
230
+
231
+ return high_conf
232
+
233
+ def analyze_predictions(results_df, original_df=None):
234
+ """
235
+ Analyze prediction patterns and provide clinical insights
236
+
237
+ Args:
238
+ results_df: Results from run_inference()
239
+ original_df: Optional original dataframe to merge with results
240
+
241
+ Returns:
242
+ dict: Analysis summary
243
+ """
244
+ print("πŸ“‹ Analyzing prediction patterns...")
245
+
246
+ # Basic statistics
247
+ stats = {
248
+ 'total_cases': len(results_df),
249
+ 'mean_probability': results_df['ohca_probability'].mean(),
250
+ 'std_probability': results_df['ohca_probability'].std(),
251
+ 'median_probability': results_df['ohca_probability'].median(),
252
+ 'high_confidence_cases': (results_df['ohca_probability'] >= 0.8).sum(),
253
+ 'predicted_ohca_050': results_df['prediction_050'].sum(),
254
+ 'predicted_ohca_070': results_df['prediction_070'].sum(),
255
+ 'predicted_ohca_090': results_df['prediction_090'].sum(),
256
+ }
257
+
258
+ # Confidence distribution
259
+ conf_dist = results_df['confidence_category'].value_counts().to_dict()
260
+
261
+ # Print analysis
262
+ print(f"\nπŸ“Š Prediction Analysis:")
263
+ print(f" Total cases: {stats['total_cases']:,}")
264
+ print(f" Mean probability: {stats['mean_probability']:.4f}")
265
+ print(f" Predicted OHCA (β‰₯0.5): {stats['predicted_ohca_050']:,}")
266
+ print(f" High confidence (β‰₯0.8): {stats['high_confidence_cases']:,}")
267
+
268
+ if stats['predicted_ohca_050'] > 0:
269
+ prevalence = stats['predicted_ohca_050'] / stats['total_cases'] * 100
270
+ print(f" Estimated OHCA prevalence: {prevalence:.2f}%")
271
+
272
+ # Clinical recommendations
273
+ print(f"\nπŸ₯ Clinical Recommendations:")
274
+ if stats['high_confidence_cases'] > 0:
275
+ print(f" β€’ Priority review: {stats['high_confidence_cases']} high-confidence cases")
276
+ if stats['predicted_ohca_070'] > 0:
277
+ print(f" β€’ Clinical review: {stats['predicted_ohca_070']} cases β‰₯0.7 probability")
278
+
279
+ uncertain_cases = ((results_df['ohca_probability'] >= 0.3) &
280
+ (results_df['ohca_probability'] < 0.7)).sum()
281
+ if uncertain_cases > 0:
282
+ print(f" β€’ Manual review suggested: {uncertain_cases} uncertain cases")
283
+
284
+ return {
285
+ 'statistics': stats,
286
+ 'confidence_distribution': conf_dist,
287
+ 'high_confidence_cases': results_df[results_df['ohca_probability'] >= 0.8]
288
+ }
289
+
290
+ # =============================================================================
291
+ # BATCH PROCESSING FUNCTIONS
292
+ # =============================================================================
293
+
294
+ def process_large_dataset(model_path, data_path, output_path,
295
+ chunk_size=10000, batch_size=16):
296
+ """
297
+ Process large datasets in chunks to avoid memory issues
298
+
299
+ Args:
300
+ model_path: Path to trained model
301
+ data_path: Path to input CSV file
302
+ output_path: Path for output results
303
+ chunk_size: Number of rows per chunk
304
+ batch_size: Batch size for inference
305
+
306
+ Returns:
307
+ str: Path to completed results file
308
+ """
309
+ print(f"πŸ”„ Processing large dataset in chunks of {chunk_size:,}...")
310
+
311
+ # Load model once
312
+ model, tokenizer = load_ohca_model(model_path)
313
+
314
+ # Read data in chunks
315
+ chunk_results = []
316
+ chunk_num = 0
317
+
318
+ for chunk_df in pd.read_csv(data_path, chunksize=chunk_size):
319
+ chunk_num += 1
320
+ print(f"\nπŸ“¦ Processing chunk {chunk_num} ({len(chunk_df):,} rows)...")
321
+
322
+ # Run inference on chunk
323
+ chunk_result = run_inference(
324
+ model, tokenizer, chunk_df,
325
+ batch_size=batch_size, output_path=None
326
+ )
327
+
328
+ chunk_results.append(chunk_result)
329
+
330
+ # Save intermediate results
331
+ temp_path = f"{output_path}.chunk_{chunk_num}.csv"
332
+ chunk_result.to_csv(temp_path, index=False)
333
+ print(f"πŸ’Ύ Chunk {chunk_num} saved to: {temp_path}")
334
+
335
+ # Combine all chunks
336
+ print(f"\nπŸ”— Combining {len(chunk_results)} chunks...")
337
+ final_results = pd.concat(chunk_results, ignore_index=True)
338
+
339
+ # Sort by probability and save
340
+ final_results = final_results.sort_values('ohca_probability', ascending=False)
341
+ final_results.to_csv(output_path, index=False)
342
+
343
+ print(f"βœ… Complete results saved to: {output_path}")
344
+ print(f"πŸ“Š Total cases processed: {len(final_results):,}")
345
+
346
+ # Clean up intermediate files
347
+ for i in range(1, chunk_num + 1):
348
+ temp_path = f"{output_path}.chunk_{i}.csv"
349
+ if os.path.exists(temp_path):
350
+ os.remove(temp_path)
351
+
352
+ return output_path
353
+
354
+ # =============================================================================
355
+ # CONVENIENCE FUNCTIONS
356
+ # =============================================================================
357
+
358
+ def quick_inference(model_path, data_path, output_path=None):
359
+ """
360
+ Quick inference function for simple use cases
361
+
362
+ Args:
363
+ model_path: Path to trained model
364
+ data_path: Path to input CSV (or DataFrame)
365
+ output_path: Optional output path
366
+
367
+ Returns:
368
+ DataFrame: Inference results
369
+ """
370
+ print("πŸš€ Quick OHCA Inference")
371
+
372
+ # Load model
373
+ model, tokenizer = load_ohca_model(model_path)
374
+
375
+ # Load data
376
+ if isinstance(data_path, str):
377
+ df = pd.read_csv(data_path)
378
+ print(f"πŸ“‚ Loaded {len(df):,} cases from {data_path}")
379
+ else:
380
+ df = data_path.copy()
381
+ print(f"πŸ“Š Processing {len(df):,} cases from DataFrame")
382
+
383
+ # Run inference
384
+ results = run_inference(model, tokenizer, df, output_path=output_path)
385
+
386
+ # Quick summary
387
+ ohca_cases = (results['ohca_probability'] >= 0.5).sum()
388
+ high_conf = (results['ohca_probability'] >= 0.8).sum()
389
+
390
+ print(f"\nβœ… Quick Summary:")
391
+ print(f" Predicted OHCA cases: {ohca_cases:,}")
392
+ print(f" High confidence: {high_conf:,}")
393
+
394
+ return results
395
+
396
+ def test_model_on_sample(model_path, sample_texts):
397
+ """
398
+ Test model on a few sample texts for quick validation
399
+
400
+ Args:
401
+ model_path: Path to trained model
402
+ sample_texts: List of text strings or dict with hadm_id: text
403
+
404
+ Returns:
405
+ DataFrame: Test results
406
+ """
407
+ print("πŸ§ͺ Testing model on sample texts...")
408
+
409
+ # Prepare test data
410
+ if isinstance(sample_texts, dict):
411
+ test_df = pd.DataFrame([
412
+ {'hadm_id': hadm_id, 'clean_text': text}
413
+ for hadm_id, text in sample_texts.items()
414
+ ])
415
+ else:
416
+ test_df = pd.DataFrame([
417
+ {'hadm_id': f'TEST_{i:03d}', 'clean_text': text}
418
+ for i, text in enumerate(sample_texts, 1)
419
+ ])
420
+
421
+ # Run inference
422
+ model, tokenizer = load_ohca_model(model_path)
423
+ results = run_inference(model, tokenizer, test_df, output_path=None)
424
+
425
+ # Print results
426
+ print(f"\nπŸ” Test Results:")
427
+ for _, row in results.iterrows():
428
+ prob = row['ohca_probability']
429
+ pred = "OHCA" if prob >= 0.5 else "Non-OHCA"
430
+ conf = row['confidence_category']
431
+
432
+ print(f" {row['hadm_id']}: {pred} (prob={prob:.3f}, {conf})")
433
+
434
+ # Show text preview
435
+ text_preview = test_df[test_df['hadm_id']==row['hadm_id']]['clean_text'].iloc[0]
436
+ print(f" Text: {text_preview[:100]}...")
437
+ print()
438
+
439
+ return results
440
+
441
+ # =============================================================================
442
+ # EXAMPLE USAGE
443
+ # =============================================================================
444
+
445
+ if __name__ == "__main__":
446
+ print("OHCA Inference Module")
447
+ print("="*25)
448
+ print("This module provides inference capabilities for pre-trained OHCA models.")
449
+ print("\nMain functions:")
450
+ print("β€’ load_ohca_model() - Load pre-trained model")
451
+ print("β€’ run_inference() - Run inference on new data")
452
+ print("β€’ quick_inference() - Simple inference function")
453
+ print("β€’ process_large_dataset() - Handle large datasets")
454
+ print("β€’ test_model_on_sample() - Test on sample texts")
455
+ print("\nSee examples/ folder for detailed usage examples.")