monajm36 committed on
Commit
a5114c2
·
unverified ·
1 Parent(s): e6890e6

Update ohca_inference.py

Browse files
Files changed (1) hide show
  1. src/ohca_inference.py +379 -158
src/ohca_inference.py CHANGED
@@ -1,5 +1,5 @@
1
- # OHCA Inference Module
2
- # Apply pre-trained OHCA classifier to new datasets
3
 
4
  import pandas as pd
5
  import numpy as np
@@ -8,6 +8,7 @@ import torch.nn.functional as F
8
  from torch.utils.data import DataLoader, Dataset
9
  from tqdm import tqdm
10
  import os
 
11
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
  import warnings
13
  warnings.filterwarnings('ignore')
@@ -17,7 +18,7 @@ warnings.filterwarnings('ignore')
17
  # =============================================================================
18
 
19
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
- print(f"Inference Module - Using device: {DEVICE}")
21
 
22
  # =============================================================================
23
  # INFERENCE DATASET CLASS
@@ -61,12 +62,58 @@ class OHCAInferenceDataset(Dataset):
61
  }
62
 
63
  # =============================================================================
64
- # MODEL LOADING FUNCTIONS
65
  # =============================================================================
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def load_ohca_model(model_path):
68
  """
69
- Load pre-trained OHCA model and tokenizer
70
 
71
  Args:
72
  model_path: Path to saved model directory
@@ -74,7 +121,7 @@ def load_ohca_model(model_path):
74
  Returns:
75
  tuple: (model, tokenizer)
76
  """
77
- print(f"📂 Loading OHCA model from: {model_path}")
78
 
79
  if not os.path.exists(model_path):
80
  raise FileNotFoundError(f"Model not found at: {model_path}")
@@ -86,7 +133,7 @@ def load_ohca_model(model_path):
86
  model = model.to(DEVICE)
87
  model.eval()
88
 
89
- print("Model loaded successfully")
90
  print(f" Device: {DEVICE}")
91
  print(f" Model type: {type(model).__name__}")
92
 
@@ -96,26 +143,53 @@ def load_ohca_model(model_path):
96
  raise RuntimeError(f"Failed to load model: {str(e)}")
97
 
98
  # =============================================================================
99
- # INFERENCE FUNCTIONS
100
  # =============================================================================
101
 
102
- def run_inference(model, tokenizer, inference_df, batch_size=16,
103
- output_path=None, probability_threshold=0.5):
104
  """
105
- Run OHCA inference on new data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  Args:
108
  model: Pre-trained OHCA model
109
  tokenizer: Model tokenizer
110
  inference_df: DataFrame with columns ['hadm_id', 'clean_text']
 
111
  batch_size: Batch size for inference
112
  output_path: Optional path to save results CSV
113
- probability_threshold: Threshold for binary predictions
114
 
115
  Returns:
116
- DataFrame: Results with probabilities and predictions
117
  """
118
- print(f"🔍 Running OHCA inference on {len(inference_df):,} cases...")
 
119
 
120
  # Validate input data
121
  required_cols = ['hadm_id', 'clean_text']
@@ -126,7 +200,7 @@ def run_inference(model, tokenizer, inference_df, batch_size=16,
126
  # Remove any rows with missing data
127
  clean_df = inference_df.dropna(subset=required_cols).copy()
128
  if len(clean_df) < len(inference_df):
129
- print(f"⚠️ Removed {len(inference_df) - len(clean_df)} rows with missing data")
130
 
131
  # Create dataset and dataloader
132
  inference_dataset = OHCAInferenceDataset(clean_df, tokenizer)
@@ -158,47 +232,47 @@ def run_inference(model, tokenizer, inference_df, batch_size=16,
158
  'ohca_probability': all_probabilities
159
  })
160
 
161
- # Add predictions with different thresholds
 
 
 
 
162
  results_df['prediction_050'] = (results_df['ohca_probability'] >= 0.5).astype(int)
163
  results_df['prediction_070'] = (results_df['ohca_probability'] >= 0.7).astype(int)
164
  results_df['prediction_090'] = (results_df['ohca_probability'] >= 0.9).astype(int)
165
- results_df['prediction_custom'] = (results_df['ohca_probability'] >= probability_threshold).astype(int)
166
-
167
- # Add confidence categories
168
- def categorize_confidence(prob):
169
- if prob >= 0.9:
170
- return "Very High"
171
- elif prob >= 0.7:
172
- return "High"
173
- elif prob >= 0.3:
174
- return "Medium"
175
- elif prob >= 0.1:
176
- return "Low"
177
- else:
178
- return "Very Low"
179
 
180
- results_df['confidence_category'] = results_df['ohca_probability'].apply(categorize_confidence)
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  # Sort by probability (highest first)
183
  results_df = results_df.sort_values('ohca_probability', ascending=False).reset_index(drop=True)
184
 
185
- # Print summary
186
- print(f"\n📊 Inference Results Summary:")
187
  print(f" Total cases processed: {len(results_df):,}")
188
  print(f" Mean OHCA probability: {results_df['ohca_probability'].mean():.4f}")
189
- print(f" Max OHCA probability: {results_df['ohca_probability'].max():.3f}")
190
- print(f" Min OHCA probability: {results_df['ohca_probability'].min():.3f}")
191
-
192
- # Probability distribution
193
- print(f"\n🎯 Probability Distribution:")
194
- thresholds = [0.9, 0.8, 0.7, 0.6, 0.5, 0.3, 0.1]
195
- for threshold in thresholds:
196
- count = (results_df['ohca_probability'] >= threshold).sum()
197
  pct = count / len(results_df) * 100
198
- print(f" {threshold}: {count:,} cases ({pct:.2f}%)")
199
 
200
- # Confidence categories
201
- print(f"\n📈 Confidence Distribution:")
202
  conf_dist = results_df['confidence_category'].value_counts()
203
  for category, count in conf_dist.items():
204
  pct = count / len(results_df) * 100
@@ -206,98 +280,203 @@ def run_inference(model, tokenizer, inference_df, batch_size=16,
206
 
207
  # Save results if path provided
208
  if output_path:
 
 
209
  results_df.to_csv(output_path, index=False)
210
- print(f"\n💾 Results saved to: {output_path}")
211
 
212
  return results_df
213
 
214
- def get_high_confidence_cases(results_df, threshold=0.8, max_cases=100):
 
 
 
215
  """
216
- Extract high-confidence OHCA predictions for manual review
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  Args:
219
- results_df: Results from run_inference()
220
- threshold: Minimum probability threshold
221
- max_cases: Maximum number of cases to return
222
 
223
  Returns:
224
- DataFrame: High-confidence cases sorted by probability
225
  """
226
- high_conf = results_df[results_df['ohca_probability'] >= threshold].copy()
227
- high_conf = high_conf.head(max_cases)
228
 
229
- print(f"🎯 Found {len(high_conf)} high-confidence cases (≥{threshold})")
 
230
 
231
- return high_conf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
- def analyze_predictions(results_df, original_df=None):
234
  """
235
- Analyze prediction patterns and provide clinical insights
 
236
 
237
  Args:
238
- results_df: Results from run_inference()
239
- original_df: Optional original dataframe to merge with results
 
240
 
241
  Returns:
242
- dict: Analysis summary
243
  """
244
- print("📋 Analyzing prediction patterns...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  # Basic statistics
247
  stats = {
248
  'total_cases': len(results_df),
 
249
  'mean_probability': results_df['ohca_probability'].mean(),
250
  'std_probability': results_df['ohca_probability'].std(),
251
  'median_probability': results_df['ohca_probability'].median(),
 
252
  'high_confidence_cases': (results_df['ohca_probability'] >= 0.8).sum(),
253
- 'predicted_ohca_050': results_df['prediction_050'].sum(),
254
- 'predicted_ohca_070': results_df['prediction_070'].sum(),
255
- 'predicted_ohca_090': results_df['prediction_090'].sum(),
256
  }
257
 
258
- # Confidence distribution
259
- conf_dist = results_df['confidence_category'].value_counts().to_dict()
 
 
 
260
 
261
- # Print analysis
262
- print(f"\n📊 Prediction Analysis:")
 
 
 
 
 
 
263
  print(f" Total cases: {stats['total_cases']:,}")
 
264
  print(f" Mean probability: {stats['mean_probability']:.4f}")
265
- print(f" Predicted OHCA (≥0.5): {stats['predicted_ohca_050']:,}")
266
- print(f" High confidence (≥0.8): {stats['high_confidence_cases']:,}")
267
 
268
- if stats['predicted_ohca_050'] > 0:
269
- prevalence = stats['predicted_ohca_050'] / stats['total_cases'] * 100
270
  print(f" Estimated OHCA prevalence: {prevalence:.2f}%")
271
 
272
- # Clinical recommendations
273
- print(f"\n🏥 Clinical Recommendations:")
274
- if stats['high_confidence_cases'] > 0:
275
- print(f" Priority review: {stats['high_confidence_cases']} high-confidence cases")
276
- if stats['predicted_ohca_070'] > 0:
277
- print(f" • Clinical review: {stats['predicted_ohca_070']} cases ≥0.7 probability")
278
 
279
- uncertain_cases = ((results_df['ohca_probability'] >= 0.3) &
280
- (results_df['ohca_probability'] < 0.7)).sum()
281
- if uncertain_cases > 0:
282
- print(f" • Manual review suggested: {uncertain_cases} uncertain cases")
 
 
283
 
284
  return {
285
  'statistics': stats,
 
286
  'confidence_distribution': conf_dist,
287
- 'high_confidence_cases': results_df[results_df['ohca_probability'] >= 0.8]
 
288
  }
289
 
290
  # =============================================================================
291
- # BATCH PROCESSING FUNCTIONS
292
  # =============================================================================
293
 
294
- def process_large_dataset(model_path, data_path, output_path,
295
- chunk_size=10000, batch_size=16):
296
  """
297
- Process large datasets in chunks to avoid memory issues
298
 
299
  Args:
300
- model_path: Path to trained model
301
  data_path: Path to input CSV file
302
  output_path: Path for output results
303
  chunk_size: Number of rows per chunk
@@ -306,10 +485,10 @@ def process_large_dataset(model_path, data_path, output_path,
306
  Returns:
307
  str: Path to completed results file
308
  """
309
- print(f"🔄 Processing large dataset in chunks of {chunk_size:,}...")
310
 
311
- # Load model once
312
- model, tokenizer = load_ohca_model(model_path)
313
 
314
  # Read data in chunks
315
  chunk_results = []
@@ -317,11 +496,11 @@ def process_large_dataset(model_path, data_path, output_path,
317
 
318
  for chunk_df in pd.read_csv(data_path, chunksize=chunk_size):
319
  chunk_num += 1
320
- print(f"\n📦 Processing chunk {chunk_num} ({len(chunk_df):,} rows)...")
321
 
322
- # Run inference on chunk
323
- chunk_result = run_inference(
324
- model, tokenizer, chunk_df,
325
  batch_size=batch_size, output_path=None
326
  )
327
 
@@ -330,18 +509,24 @@ def process_large_dataset(model_path, data_path, output_path,
330
  # Save intermediate results
331
  temp_path = f"{output_path}.chunk_{chunk_num}.csv"
332
  chunk_result.to_csv(temp_path, index=False)
333
- print(f"💾 Chunk {chunk_num} saved to: {temp_path}")
334
 
335
  # Combine all chunks
336
- print(f"\n🔗 Combining {len(chunk_results)} chunks...")
337
  final_results = pd.concat(chunk_results, ignore_index=True)
338
 
339
  # Sort by probability and save
340
  final_results = final_results.sort_values('ohca_probability', ascending=False)
 
 
 
 
 
341
  final_results.to_csv(output_path, index=False)
342
 
343
- print(f"Complete results saved to: {output_path}")
344
- print(f"📊 Total cases processed: {len(final_results):,}")
 
345
 
346
  # Clean up intermediate files
347
  for i in range(1, chunk_num + 1):
@@ -351,60 +536,56 @@ def process_large_dataset(model_path, data_path, output_path,
351
 
352
  return output_path
353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  # =============================================================================
355
- # CONVENIENCE FUNCTIONS
356
  # =============================================================================
357
 
358
- def quick_inference(model_path, data_path, output_path=None):
359
- """
360
- Quick inference function for simple use cases
361
-
362
- Args:
363
- model_path: Path to trained model
364
- data_path: Path to input CSV (or DataFrame)
365
- output_path: Optional output path
366
-
367
- Returns:
368
- DataFrame: Inference results
369
- """
370
- print("🚀 Quick OHCA Inference")
371
-
372
- # Load model
373
- model, tokenizer = load_ohca_model(model_path)
374
-
375
- # Load data
376
- if isinstance(data_path, str):
377
- df = pd.read_csv(data_path)
378
- print(f"📂 Loaded {len(df):,} cases from {data_path}")
379
- else:
380
- df = data_path.copy()
381
- print(f"📊 Processing {len(df):,} cases from DataFrame")
382
-
383
- # Run inference
384
- results = run_inference(model, tokenizer, df, output_path=output_path)
385
-
386
- # Quick summary
387
- ohca_cases = (results['ohca_probability'] >= 0.5).sum()
388
- high_conf = (results['ohca_probability'] >= 0.8).sum()
389
-
390
- print(f"\n✅ Quick Summary:")
391
- print(f" Predicted OHCA cases: {ohca_cases:,}")
392
- print(f" High confidence: {high_conf:,}")
393
-
394
- return results
395
-
396
  def test_model_on_sample(model_path, sample_texts):
397
  """
398
- Test model on a few sample texts for quick validation
399
 
400
  Args:
401
  model_path: Path to trained model
402
  sample_texts: List of text strings or dict with hadm_id: text
403
 
404
  Returns:
405
- DataFrame: Test results
406
  """
407
- print("🧪 Testing model on sample texts...")
408
 
409
  # Prepare test data
410
  if isinstance(sample_texts, dict):
@@ -418,18 +599,28 @@ def test_model_on_sample(model_path, sample_texts):
418
  for i, text in enumerate(sample_texts, 1)
419
  ])
420
 
421
- # Run inference
422
- model, tokenizer = load_ohca_model(model_path)
423
- results = run_inference(model, tokenizer, test_df, output_path=None)
 
 
 
 
 
 
 
 
 
424
 
425
- # Print results
426
- print(f"\n🔍 Test Results:")
427
  for _, row in results.iterrows():
428
  prob = row['ohca_probability']
429
- pred = "OHCA" if prob >= 0.5 else "Non-OHCA"
430
  conf = row['confidence_category']
 
431
 
432
- print(f" {row['hadm_id']}: {pred} (prob={prob:.3f}, {conf})")
433
 
434
  # Show text preview
435
  text_preview = test_df[test_df['hadm_id']==row['hadm_id']]['clean_text'].iloc[0]
@@ -438,18 +629,48 @@ def test_model_on_sample(model_path, sample_texts):
438
 
439
  return results
440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  # =============================================================================
442
  # EXAMPLE USAGE
443
  # =============================================================================
444
 
445
  if __name__ == "__main__":
446
- print("OHCA Inference Module")
447
- print("="*25)
448
- print("This module provides inference capabilities for pre-trained OHCA models.")
449
- print("\nMain functions:")
450
- print(" load_ohca_model() - Load pre-trained model")
451
- print(" run_inference() - Run inference on new data")
452
- print(" quick_inference() - Simple inference function")
453
- print(" process_large_dataset() - Handle large datasets")
454
- print("• test_model_on_sample() - Test on sample texts")
455
- print("\nSee examples/ folder for detailed usage examples.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OHCA Inference Module v3.0 - Improved with Optimal Threshold Support
2
+ # Apply pre-trained OHCA classifier to new datasets using optimal thresholds
3
 
4
  import pandas as pd
5
  import numpy as np
 
8
  from torch.utils.data import DataLoader, Dataset
9
  from tqdm import tqdm
10
  import os
11
+ import json
12
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
  import warnings
14
  warnings.filterwarnings('ignore')
 
18
  # =============================================================================
19
 
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ print(f"Inference Module v3.0 - Using device: {DEVICE}")
22
 
23
  # =============================================================================
24
  # INFERENCE DATASET CLASS
 
62
  }
63
 
64
  # =============================================================================
65
+ # IMPROVED MODEL LOADING FUNCTIONS
66
  # =============================================================================
67
 
68
def load_ohca_model_with_metadata(model_path):
    """
    Load a pre-trained OHCA classifier, its tokenizer, and any saved
    metadata — notably the optimal decision threshold chosen on the
    validation set.

    Args:
        model_path: Path to saved model directory

    Returns:
        tuple: (model, tokenizer, optimal_threshold, metadata)

    Raises:
        FileNotFoundError: If model_path does not exist.
        RuntimeError: If the model or tokenizer fails to load.
    """
    print(f"Loading OHCA model with metadata from: {model_path}")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")

    try:
        # Bring up tokenizer and model, then pin the model to the module
        # device in eval mode (inference only — no dropout/grad).
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path).to(DEVICE)
        model.eval()

        # Metadata (if shipped with the model) carries the tuned threshold.
        metadata_path = os.path.join(model_path, 'model_metadata.json')
        if os.path.exists(metadata_path):
            with open(metadata_path, 'r') as handle:
                metadata = json.load(handle)
            optimal_threshold = metadata.get('optimal_threshold', 0.5)
            print(f"Loaded optimal threshold: {optimal_threshold:.3f}")
            print(f"Model version: {metadata.get('model_version', 'unknown')}")
        else:
            # Older model directories have no metadata file; fall back to 0.5.
            print("Warning: No metadata file found. Using default threshold of 0.5")
            optimal_threshold = 0.5
            metadata = {'optimal_threshold': 0.5, 'model_version': 'legacy'}

        print("Model loaded successfully")
        print(f" Device: {DEVICE}")
        print(f" Model type: {type(model).__name__}")
        print(f" Optimal threshold: {optimal_threshold:.3f}")

        return model, tokenizer, optimal_threshold, metadata

    except Exception as e:
        raise RuntimeError(f"Failed to load model: {str(e)}")
113
+
114
  def load_ohca_model(model_path):
115
  """
116
+ Backward compatibility function - loads model without metadata
117
 
118
  Args:
119
  model_path: Path to saved model directory
 
121
  Returns:
122
  tuple: (model, tokenizer)
123
  """
124
+ print(f"Loading OHCA model from: {model_path}")
125
 
126
  if not os.path.exists(model_path):
127
  raise FileNotFoundError(f"Model not found at: {model_path}")
 
133
  model = model.to(DEVICE)
134
  model.eval()
135
 
136
+ print("Model loaded successfully (legacy mode)")
137
  print(f" Device: {DEVICE}")
138
  print(f" Model type: {type(model).__name__}")
139
 
 
143
  raise RuntimeError(f"Failed to load model: {str(e)}")
144
 
145
  # =============================================================================
146
+ # IMPROVED INFERENCE FUNCTIONS
147
  # =============================================================================
148
 
149
def categorize_confidence_with_optimal_threshold(prob, optimal_threshold):
    """
    Map a probability score to a (confidence_category, clinical_priority)
    pair, with the model's optimal threshold serving as one cut point.

    Args:
        prob: Probability score from the classifier.
        optimal_threshold: Optimal threshold from training.

    Returns:
        tuple: (confidence_category, clinical_priority)
    """
    # Bands are evaluated highest-first, exactly mirroring the original
    # if/elif chain; the optimal threshold sits between 0.7 and 0.3.
    bands = (
        (0.9, ("Very High", "Immediate Review")),
        (0.7, ("High", "Priority Review")),
        (optimal_threshold, ("Medium-High", "Clinical Review")),
        (0.3, ("Medium", "Consider Review")),
        (0.1, ("Low", "Routine Processing")),
    )
    for cutoff, labels in bands:
        if prob >= cutoff:
            return labels
    return "Very Low", "Routine Processing"
172
+
173
+ def run_inference_with_optimal_threshold(model, tokenizer, inference_df,
174
+ optimal_threshold=0.5, batch_size=16,
175
+ output_path=None):
176
+ """
177
+ Run OHCA inference using the optimal threshold from training.
178
+ This addresses the data scientist's feedback about threshold consistency.
179
 
180
  Args:
181
  model: Pre-trained OHCA model
182
  tokenizer: Model tokenizer
183
  inference_df: DataFrame with columns ['hadm_id', 'clean_text']
184
+ optimal_threshold: Optimal threshold from validation set
185
  batch_size: Batch size for inference
186
  output_path: Optional path to save results CSV
 
187
 
188
  Returns:
189
+ DataFrame: Results with probabilities and predictions using optimal threshold
190
  """
191
+ print(f"Running OHCA inference on {len(inference_df):,} cases...")
192
+ print(f"Using optimal threshold: {optimal_threshold:.3f}")
193
 
194
  # Validate input data
195
  required_cols = ['hadm_id', 'clean_text']
 
200
  # Remove any rows with missing data
201
  clean_df = inference_df.dropna(subset=required_cols).copy()
202
  if len(clean_df) < len(inference_df):
203
+ print(f"Warning: Removed {len(inference_df) - len(clean_df)} rows with missing data")
204
 
205
  # Create dataset and dataloader
206
  inference_dataset = OHCAInferenceDataset(clean_df, tokenizer)
 
232
  'ohca_probability': all_probabilities
233
  })
234
 
235
+ # Add prediction using optimal threshold (primary prediction)
236
+ results_df['ohca_prediction'] = (results_df['ohca_probability'] >= optimal_threshold).astype(int)
237
+ results_df['optimal_threshold_used'] = optimal_threshold
238
+
239
+ # Add legacy predictions for comparison
240
  results_df['prediction_050'] = (results_df['ohca_probability'] >= 0.5).astype(int)
241
  results_df['prediction_070'] = (results_df['ohca_probability'] >= 0.7).astype(int)
242
  results_df['prediction_090'] = (results_df['ohca_probability'] >= 0.9).astype(int)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
+ # Add improved confidence categories and clinical priorities
245
+ confidence_info = [categorize_confidence_with_optimal_threshold(prob, optimal_threshold)
246
+ for prob in results_df['ohca_probability']]
247
+ results_df['confidence_category'] = [info[0] for info in confidence_info]
248
+ results_df['clinical_priority'] = [info[1] for info in confidence_info]
249
+
250
+ # Add interpretation column
251
+ results_df['interpretation'] = results_df.apply(
252
+ lambda row: f"OHCA detected (p={row['ohca_probability']:.3f})"
253
+ if row['ohca_prediction'] == 1
254
+ else f"No OHCA (p={row['ohca_probability']:.3f})", axis=1
255
+ )
256
 
257
  # Sort by probability (highest first)
258
  results_df = results_df.sort_values('ohca_probability', ascending=False).reset_index(drop=True)
259
 
260
+ # Print improved summary
261
+ print(f"\nInference Results Summary:")
262
  print(f" Total cases processed: {len(results_df):,}")
263
  print(f" Mean OHCA probability: {results_df['ohca_probability'].mean():.4f}")
264
+ print(f" OHCA detected (optimal threshold): {results_df['ohca_prediction'].sum():,}")
265
+ print(f" Detection rate: {results_df['ohca_prediction'].mean()*100:.2f}%")
266
+
267
+ # Clinical priority distribution
268
+ print(f"\nClinical Priority Distribution:")
269
+ priority_dist = results_df['clinical_priority'].value_counts()
270
+ for priority, count in priority_dist.items():
 
271
  pct = count / len(results_df) * 100
272
+ print(f" {priority}: {count:,} cases ({pct:.1f}%)")
273
 
274
+ # Confidence distribution
275
+ print(f"\nConfidence Distribution:")
276
  conf_dist = results_df['confidence_category'].value_counts()
277
  for category, count in conf_dist.items():
278
  pct = count / len(results_df) * 100
 
280
 
281
  # Save results if path provided
282
  if output_path:
283
+ # Add metadata to the saved file
284
+ results_df['inference_date'] = pd.Timestamp.now().isoformat()
285
  results_df.to_csv(output_path, index=False)
286
+ print(f"\nResults saved to: {output_path}")
287
 
288
  return results_df
289
 
290
def run_inference(model, tokenizer, inference_df, batch_size=16,
                  output_path=None, probability_threshold=0.5):
    """
    Legacy inference entry point kept for backward compatibility.

    Delegates to run_inference_with_optimal_threshold(), using
    probability_threshold as the decision threshold.
    """
    print("Warning: Using legacy inference function. Consider upgrading to run_inference_with_optimal_threshold()")
    return run_inference_with_optimal_threshold(
        model,
        tokenizer,
        inference_df,
        optimal_threshold=probability_threshold,
        batch_size=batch_size,
        output_path=output_path,
    )
299
+
300
+ # =============================================================================
301
+ # IMPROVED CONVENIENCE FUNCTIONS
302
+ # =============================================================================
303
+
304
def quick_inference_with_optimal_threshold(model_path, data_path, output_path=None):
    """
    One-call inference helper: load a v3.0 model and score data with the
    optimal threshold stored in the model's metadata.

    Args:
        model_path: Path to trained model (must include metadata)
        data_path: Path to input CSV (or DataFrame)
        output_path: Optional output path

    Returns:
        DataFrame: Inference results using optimal threshold
    """
    print("Quick OHCA Inference v3.0 with Optimal Threshold")

    # Model + tokenizer + tuned threshold in one call.
    model, tokenizer, optimal_threshold, metadata = load_ohca_model_with_metadata(model_path)

    # Accept either a CSV path or an in-memory DataFrame.
    if isinstance(data_path, str):
        frame = pd.read_csv(data_path)
        print(f"Loaded {len(frame):,} cases from {data_path}")
    else:
        frame = data_path.copy()
        print(f"Processing {len(frame):,} cases from DataFrame")

    scored = run_inference_with_optimal_threshold(
        model, tokenizer, frame, optimal_threshold, output_path=output_path
    )

    # Headline numbers for the console summary.
    detected = scored['ohca_prediction'].sum()
    immediate = (scored['clinical_priority'] == 'Immediate Review').sum()
    to_review = (scored['clinical_priority'] == 'Priority Review').sum()

    print(f"\nEnhanced Summary:")
    print(f" OHCA detected (optimal threshold): {detected:,}")
    print(f" Immediate review needed: {immediate:,}")
    print(f" Priority review needed: {to_review:,}")
    print(f" Model version: {metadata.get('model_version', 'unknown')}")
    print(f" Optimal threshold used: {optimal_threshold:.3f}")

    return scored
348
 
349
def quick_inference(model_path, data_path, output_path=None):
    """
    Backward compatible quick inference entry point.

    Routes to the optimal-threshold pipeline when the model directory ships
    a metadata file; otherwise falls back to a fixed 0.5 threshold.

    Args:
        model_path: Path to trained model
        data_path: Path to input CSV (or DataFrame)
        output_path: Optional output path

    Returns:
        DataFrame: Inference results
    """
    print("Quick OHCA Inference")

    # A metadata file marks a v3.0 model carrying a stored optimal threshold.
    if os.path.exists(os.path.join(model_path, 'model_metadata.json')):
        print("Detected v3.0 model with metadata - using optimal threshold")
        return quick_inference_with_optimal_threshold(model_path, data_path, output_path)

    # Legacy path: no metadata, so score with the default 0.5 threshold.
    print("Detected legacy model - using default threshold 0.5")
    model, tokenizer = load_ohca_model(model_path)

    if isinstance(data_path, str):
        frame = pd.read_csv(data_path)
        print(f"Loaded {len(frame):,} cases from {data_path}")
    else:
        frame = data_path.copy()
        print(f"Processing {len(frame):,} cases from DataFrame")

    results = run_inference_with_optimal_threshold(
        model, tokenizer, frame, optimal_threshold=0.5, output_path=output_path
    )

    predicted = results['ohca_prediction'].sum()
    confident = (results['ohca_probability'] >= 0.8).sum()

    print(f"\nQuick Summary:")
    print(f" Predicted OHCA cases: {predicted:,}")
    print(f" High confidence: {confident:,}")

    return results
396
+
397
def analyze_predictions_enhanced(results_df):
    """
    Enhanced prediction analysis with optimal threshold insights.

    Args:
        results_df: Results DataFrame from inference. Must contain
            'ohca_probability'; 'ohca_prediction', 'prediction_050/070/090',
            'optimal_threshold_used', 'clinical_priority' and
            'confidence_category' are optional and default gracefully.

    Returns:
        dict: Summary with keys 'statistics',
            'clinical_priority_distribution', 'confidence_distribution',
            'optimal_threshold', and 'high_confidence_cases' (DataFrame of
            rows with probability >= 0.8).
    """
    print("Analyzing prediction patterns with optimal threshold insights...")

    optimal_threshold = results_df['optimal_threshold_used'].iloc[0] if 'optimal_threshold_used' in results_df.columns else 0.5

    def _column_sum(col):
        # BUG FIX: the previous results_df.get(col, []).sum() raised
        # AttributeError when the column was absent, because DataFrame.get
        # returns the plain-list default and a list has no .sum().
        return results_df[col].sum() if col in results_df.columns else 0

    # Basic statistics
    stats = {
        'total_cases': len(results_df),
        'optimal_threshold_used': optimal_threshold,
        'mean_probability': results_df['ohca_probability'].mean(),
        'std_probability': results_df['ohca_probability'].std(),
        'median_probability': results_df['ohca_probability'].median(),
        'ohca_detected_optimal': _column_sum('ohca_prediction'),
        'high_confidence_cases': (results_df['ohca_probability'] >= 0.8).sum(),
        'predicted_ohca_050': _column_sum('prediction_050'),
        'predicted_ohca_070': _column_sum('prediction_070'),
        'predicted_ohca_090': _column_sum('prediction_090'),
    }

    # Clinical priority distribution (optional column)
    if 'clinical_priority' in results_df.columns:
        priority_dist = results_df['clinical_priority'].value_counts().to_dict()
    else:
        priority_dist = {}

    # Confidence distribution (optional column)
    if 'confidence_category' in results_df.columns:
        conf_dist = results_df['confidence_category'].value_counts().to_dict()
    else:
        conf_dist = {}

    # Print enhanced analysis
    print(f"\nEnhanced Prediction Analysis:")
    print(f" Total cases: {stats['total_cases']:,}")
    print(f" Optimal threshold used: {stats['optimal_threshold_used']:.3f}")
    print(f" Mean probability: {stats['mean_probability']:.4f}")
    print(f" OHCA detected (optimal): {stats['ohca_detected_optimal']:,}")

    if stats['ohca_detected_optimal'] > 0:
        prevalence = stats['ohca_detected_optimal'] / stats['total_cases'] * 100
        print(f" Estimated OHCA prevalence: {prevalence:.2f}%")

    # Comparison with static thresholds
    print(f"\nThreshold Comparison:")
    print(f" Optimal threshold ({optimal_threshold:.3f}): {stats['ohca_detected_optimal']:,} cases")
    print(f" Static threshold (0.5): {stats['predicted_ohca_050']:,} cases")
    print(f" Static threshold (0.7): {stats['predicted_ohca_070']:,} cases")

    # Clinical recommendations
    print(f"\nClinical Recommendations:")
    if priority_dist:
        for priority, count in priority_dist.items():
            if count > 0:
                print(f" {priority}: {count:,} cases")

    return {
        'statistics': stats,
        'clinical_priority_distribution': priority_dist,
        'confidence_distribution': conf_dist,
        'optimal_threshold': optimal_threshold,
        'high_confidence_cases': results_df[results_df['ohca_probability'] >= 0.8] if len(results_df) > 0 else pd.DataFrame()
    }
468
 
469
  # =============================================================================
470
+ # ENHANCED BATCH PROCESSING
471
  # =============================================================================
472
 
473
def process_large_dataset_with_optimal_threshold(model_path, data_path, output_path,
                                                 chunk_size=10000, batch_size=16):
    """
    Process a large dataset in chunks using the optimal threshold stored
    in the model's metadata.

    The model is loaded once, then the input CSV is streamed chunk by chunk
    through inference. Each chunk is checkpointed to a temporary CSV so a
    crash mid-run does not lose completed work; the temporary files are
    removed after the combined result is written.

    Args:
        model_path: Path to trained model directory with model_metadata.json
        data_path: Path to input CSV file
        output_path: Path for the combined output results CSV
        chunk_size: Number of rows read per chunk
        batch_size: Inference batch size passed to the inference routine

    Returns:
        str: Path to the completed results file (== output_path)

    Raises:
        ValueError: If the input CSV yields no rows at all.
    """
    print(f"Processing large dataset in chunks of {chunk_size:,} with optimal threshold...")

    # Load model + metadata once; the optimal threshold is reused for every chunk.
    model, tokenizer, optimal_threshold, metadata = load_ohca_model_with_metadata(model_path)

    chunk_results = []
    chunk_num = 0

    # Stream the CSV so the full dataset never has to fit in memory at once.
    for chunk_df in pd.read_csv(data_path, chunksize=chunk_size):
        chunk_num += 1
        print(f"\nProcessing chunk {chunk_num} ({len(chunk_df):,} rows)...")

        # Run inference on this chunk with the model's optimal threshold.
        chunk_result = run_inference_with_optimal_threshold(
            model, tokenizer, chunk_df, optimal_threshold,
            batch_size=batch_size, output_path=None
        )
        chunk_results.append(chunk_result)

        # Checkpoint intermediate results so progress survives interruptions.
        temp_path = f"{output_path}.chunk_{chunk_num}.csv"
        chunk_result.to_csv(temp_path, index=False)
        print(f"Chunk {chunk_num} saved to: {temp_path}")

    # Guard the empty-input case explicitly: pd.concat([]) raises a cryptic
    # ValueError, so fail with an actionable message instead.
    if not chunk_results:
        raise ValueError(f"No data found in {data_path}; nothing to process.")

    # Combine all chunks into one frame, highest probability first.
    print(f"\nCombining {len(chunk_results)} chunks...")
    final_results = pd.concat(chunk_results, ignore_index=True)
    final_results = final_results.sort_values('ohca_probability', ascending=False)

    # Stamp provenance columns so downstream consumers know which model
    # produced these predictions and when.
    final_results['model_version'] = metadata.get('model_version', 'unknown')
    final_results['processing_date'] = pd.Timestamp.now().isoformat()

    final_results.to_csv(output_path, index=False)

    print(f"Complete results saved to: {output_path}")
    print(f"Total cases processed: {len(final_results):,}")
    print(f"OHCA detected with optimal threshold: {final_results['ohca_prediction'].sum():,}")

    # Clean up the per-chunk checkpoint files now that the combined file exists.
    for i in range(1, chunk_num + 1):
        temp_path = f"{output_path}.chunk_{i}.csv"
        if os.path.exists(temp_path):
            os.remove(temp_path)

    return output_path
538
 
539
def process_large_dataset(model_path, data_path, output_path,
                          chunk_size=10000, batch_size=16):
    """
    Legacy batch-processing entry point kept for backward compatibility.

    If the model directory contains model_metadata.json (a v3.0 model),
    processing is delegated to process_large_dataset_with_optimal_threshold
    so the stored optimal threshold is used. Otherwise the model is treated
    as a legacy model and processed with a default 0.5 threshold.

    Args:
        model_path: Path to trained model directory
        data_path: Path to input CSV file
        output_path: Path for the combined output results CSV
        chunk_size: Number of rows read per chunk
        batch_size: Inference batch size

    Returns:
        str: Path to the completed results file (== output_path)

    Raises:
        ValueError: If the input CSV yields no rows at all (legacy path).
    """
    metadata_path = os.path.join(model_path, 'model_metadata.json')
    if os.path.exists(metadata_path):
        # v3.0 model: use the enhanced pipeline with its stored threshold.
        return process_large_dataset_with_optimal_threshold(
            model_path, data_path, output_path, chunk_size, batch_size
        )

    # Legacy model: no metadata available, fall back to a fixed 0.5 threshold.
    print("Warning: Legacy model detected. Using default threshold processing.")
    model, tokenizer = load_ohca_model(model_path)

    chunk_results = []
    chunk_num = 0

    for chunk_df in pd.read_csv(data_path, chunksize=chunk_size):
        chunk_num += 1
        print(f"\nProcessing chunk {chunk_num} ({len(chunk_df):,} rows)...")

        chunk_result = run_inference_with_optimal_threshold(
            model, tokenizer, chunk_df, optimal_threshold=0.5,
            batch_size=batch_size, output_path=None
        )
        chunk_results.append(chunk_result)

    # Fail clearly on an empty input instead of letting pd.concat([]) raise
    # a cryptic ValueError.
    if not chunk_results:
        raise ValueError(f"No data found in {data_path}; nothing to process.")

    final_results = pd.concat(chunk_results, ignore_index=True)
    final_results = final_results.sort_values('ohca_probability', ascending=False)
    final_results.to_csv(output_path, index=False)

    return output_path
+
573
  # =============================================================================
574
+ # ENHANCED TESTING FUNCTIONS
575
  # =============================================================================
576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  def test_model_on_sample(model_path, sample_texts):
578
  """
579
+ Test model on sample texts using optimal threshold if available
580
 
581
  Args:
582
  model_path: Path to trained model
583
  sample_texts: List of text strings or dict with hadm_id: text
584
 
585
  Returns:
586
+ DataFrame: Test results with optimal threshold predictions
587
  """
588
+ print("Testing model on sample texts...")
589
 
590
  # Prepare test data
591
  if isinstance(sample_texts, dict):
 
599
  for i, text in enumerate(sample_texts, 1)
600
  ])
601
 
602
+ # Try to load with metadata
603
+ metadata_path = os.path.join(model_path, 'model_metadata.json')
604
+ if os.path.exists(metadata_path):
605
+ model, tokenizer, optimal_threshold, metadata = load_ohca_model_with_metadata(model_path)
606
+ results = run_inference_with_optimal_threshold(
607
+ model, tokenizer, test_df, optimal_threshold, output_path=None
608
+ )
609
+ else:
610
+ model, tokenizer = load_ohca_model(model_path)
611
+ results = run_inference_with_optimal_threshold(
612
+ model, tokenizer, test_df, optimal_threshold=0.5, output_path=None
613
+ )
614
 
615
+ # Print enhanced results
616
+ print(f"\nTest Results:")
617
  for _, row in results.iterrows():
618
  prob = row['ohca_probability']
619
+ pred = "OHCA" if row['ohca_prediction'] == 1 else "Non-OHCA"
620
  conf = row['confidence_category']
621
+ priority = row['clinical_priority']
622
 
623
+ print(f" {row['hadm_id']}: {pred} (p={prob:.3f}, {conf}, {priority})")
624
 
625
  # Show text preview
626
  text_preview = test_df[test_df['hadm_id']==row['hadm_id']]['clean_text'].iloc[0]
 
629
 
630
  return results
631
 
632
+ # =============================================================================
633
+ # LEGACY FUNCTIONS FOR BACKWARD COMPATIBILITY
634
+ # =============================================================================
635
+
636
def get_high_confidence_cases(results_df, threshold=0.8, max_cases=100):
    """
    Extract high-confidence OHCA predictions for manual review.

    Args:
        results_df: Inference results containing an 'ohca_probability' column.
        threshold: Minimum probability to count as high confidence.
        max_cases: Maximum number of cases to return.

    Returns:
        DataFrame: Up to max_cases rows with ohca_probability >= threshold,
        ordered by probability (highest first).
    """
    high_conf = results_df[results_df['ohca_probability'] >= threshold].copy()
    # Sort before truncating: otherwise head() keeps whatever rows happen to
    # come first and can drop the highest-probability cases when the input is
    # not already ordered by probability.
    high_conf = high_conf.sort_values('ohca_probability', ascending=False)
    high_conf = high_conf.head(max_cases)

    print(f"Found {len(high_conf)} high-confidence cases (≥{threshold})")

    return high_conf
644
+
645
def analyze_predictions(results_df, original_df=None):
    """Legacy analysis function - redirects to enhanced version.

    Note: original_df is accepted only for backward compatibility with the
    old signature and is ignored; the enhanced analysis works from
    results_df alone.
    """
    return analyze_predictions_enhanced(results_df)
648
+
649
  # =============================================================================
650
  # EXAMPLE USAGE
651
  # =============================================================================
652
 
653
if __name__ == "__main__":
    # Informational banner shown when the module is executed directly:
    # summarizes the v3.0 API and the legacy compatibility functions.
    banner_lines = [
        "OHCA Inference Module v3.0 - Enhanced with Optimal Threshold Support",
        "=" * 75,
        "Key improvements:",
        "  Automatic optimal threshold loading and usage",
        "  Enhanced confidence categories based on optimal threshold",
        "  Clinical priority recommendations",
        "  Backward compatibility with legacy models",
        "  Enhanced analysis and reporting",
        "",
        "Main functions:",
        "• quick_inference_with_optimal_threshold() - Recommended for v3.0 models",
        "• load_ohca_model_with_metadata() - Load model with optimal threshold",
        "• run_inference_with_optimal_threshold() - Enhanced inference",
        "• process_large_dataset_with_optimal_threshold() - Batch processing",
        "• analyze_predictions_enhanced() - Enhanced prediction analysis",
        "",
        "Legacy functions (maintained for compatibility):",
        "• quick_inference() - Auto-detects model version",
        "• load_ohca_model() - Basic model loading",
        "• run_inference() - Basic inference",
        "",
        "See examples/ folder for detailed usage examples.",
        "=" * 75,
    ]
    for banner_line in banner_lines:
        print(banner_line)