monajm36
commited on
Update ohca_inference.py
Browse files- src/ohca_inference.py +379 -158
src/ohca_inference.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
# OHCA Inference Module
|
| 2 |
-
# Apply pre-trained OHCA classifier to new datasets
|
| 3 |
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
|
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|
| 8 |
from torch.utils.data import DataLoader, Dataset
|
| 9 |
from tqdm import tqdm
|
| 10 |
import os
|
|
|
|
| 11 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 12 |
import warnings
|
| 13 |
warnings.filterwarnings('ignore')
|
|
@@ -17,7 +18,7 @@ warnings.filterwarnings('ignore')
|
|
| 17 |
# =============================================================================
|
| 18 |
|
| 19 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
-
print(f"Inference Module - Using device: {DEVICE}")
|
| 21 |
|
| 22 |
# =============================================================================
|
| 23 |
# INFERENCE DATASET CLASS
|
|
@@ -61,12 +62,58 @@ class OHCAInferenceDataset(Dataset):
|
|
| 61 |
}
|
| 62 |
|
| 63 |
# =============================================================================
|
| 64 |
-
# MODEL LOADING FUNCTIONS
|
| 65 |
# =============================================================================
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def load_ohca_model(model_path):
|
| 68 |
"""
|
| 69 |
-
|
| 70 |
|
| 71 |
Args:
|
| 72 |
model_path: Path to saved model directory
|
|
@@ -74,7 +121,7 @@ def load_ohca_model(model_path):
|
|
| 74 |
Returns:
|
| 75 |
tuple: (model, tokenizer)
|
| 76 |
"""
|
| 77 |
-
print(f"
|
| 78 |
|
| 79 |
if not os.path.exists(model_path):
|
| 80 |
raise FileNotFoundError(f"Model not found at: {model_path}")
|
|
@@ -86,7 +133,7 @@ def load_ohca_model(model_path):
|
|
| 86 |
model = model.to(DEVICE)
|
| 87 |
model.eval()
|
| 88 |
|
| 89 |
-
print("
|
| 90 |
print(f" Device: {DEVICE}")
|
| 91 |
print(f" Model type: {type(model).__name__}")
|
| 92 |
|
|
@@ -96,26 +143,53 @@ def load_ohca_model(model_path):
|
|
| 96 |
raise RuntimeError(f"Failed to load model: {str(e)}")
|
| 97 |
|
| 98 |
# =============================================================================
|
| 99 |
-
# INFERENCE FUNCTIONS
|
| 100 |
# =============================================================================
|
| 101 |
|
| 102 |
-
def
|
| 103 |
-
output_path=None, probability_threshold=0.5):
|
| 104 |
"""
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
Args:
|
| 108 |
model: Pre-trained OHCA model
|
| 109 |
tokenizer: Model tokenizer
|
| 110 |
inference_df: DataFrame with columns ['hadm_id', 'clean_text']
|
|
|
|
| 111 |
batch_size: Batch size for inference
|
| 112 |
output_path: Optional path to save results CSV
|
| 113 |
-
probability_threshold: Threshold for binary predictions
|
| 114 |
|
| 115 |
Returns:
|
| 116 |
-
DataFrame: Results with probabilities and predictions
|
| 117 |
"""
|
| 118 |
-
print(f"
|
|
|
|
| 119 |
|
| 120 |
# Validate input data
|
| 121 |
required_cols = ['hadm_id', 'clean_text']
|
|
@@ -126,7 +200,7 @@ def run_inference(model, tokenizer, inference_df, batch_size=16,
|
|
| 126 |
# Remove any rows with missing data
|
| 127 |
clean_df = inference_df.dropna(subset=required_cols).copy()
|
| 128 |
if len(clean_df) < len(inference_df):
|
| 129 |
-
print(f"
|
| 130 |
|
| 131 |
# Create dataset and dataloader
|
| 132 |
inference_dataset = OHCAInferenceDataset(clean_df, tokenizer)
|
|
@@ -158,47 +232,47 @@ def run_inference(model, tokenizer, inference_df, batch_size=16,
|
|
| 158 |
'ohca_probability': all_probabilities
|
| 159 |
})
|
| 160 |
|
| 161 |
-
# Add
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
results_df['prediction_050'] = (results_df['ohca_probability'] >= 0.5).astype(int)
|
| 163 |
results_df['prediction_070'] = (results_df['ohca_probability'] >= 0.7).astype(int)
|
| 164 |
results_df['prediction_090'] = (results_df['ohca_probability'] >= 0.9).astype(int)
|
| 165 |
-
results_df['prediction_custom'] = (results_df['ohca_probability'] >= probability_threshold).astype(int)
|
| 166 |
-
|
| 167 |
-
# Add confidence categories
|
| 168 |
-
def categorize_confidence(prob):
|
| 169 |
-
if prob >= 0.9:
|
| 170 |
-
return "Very High"
|
| 171 |
-
elif prob >= 0.7:
|
| 172 |
-
return "High"
|
| 173 |
-
elif prob >= 0.3:
|
| 174 |
-
return "Medium"
|
| 175 |
-
elif prob >= 0.1:
|
| 176 |
-
return "Low"
|
| 177 |
-
else:
|
| 178 |
-
return "Very Low"
|
| 179 |
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
# Sort by probability (highest first)
|
| 183 |
results_df = results_df.sort_values('ohca_probability', ascending=False).reset_index(drop=True)
|
| 184 |
|
| 185 |
-
# Print summary
|
| 186 |
-
print(f"\
|
| 187 |
print(f" Total cases processed: {len(results_df):,}")
|
| 188 |
print(f" Mean OHCA probability: {results_df['ohca_probability'].mean():.4f}")
|
| 189 |
-
print(f"
|
| 190 |
-
print(f"
|
| 191 |
-
|
| 192 |
-
#
|
| 193 |
-
print(f"\
|
| 194 |
-
|
| 195 |
-
for
|
| 196 |
-
count = (results_df['ohca_probability'] >= threshold).sum()
|
| 197 |
pct = count / len(results_df) * 100
|
| 198 |
-
print(f"
|
| 199 |
|
| 200 |
-
# Confidence
|
| 201 |
-
print(f"\
|
| 202 |
conf_dist = results_df['confidence_category'].value_counts()
|
| 203 |
for category, count in conf_dist.items():
|
| 204 |
pct = count / len(results_df) * 100
|
|
@@ -206,98 +280,203 @@ def run_inference(model, tokenizer, inference_df, batch_size=16,
|
|
| 206 |
|
| 207 |
# Save results if path provided
|
| 208 |
if output_path:
|
|
|
|
|
|
|
| 209 |
results_df.to_csv(output_path, index=False)
|
| 210 |
-
print(f"\
|
| 211 |
|
| 212 |
return results_df
|
| 213 |
|
| 214 |
-
def
|
|
|
|
|
|
|
|
|
|
| 215 |
"""
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
Args:
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
|
| 223 |
Returns:
|
| 224 |
-
DataFrame:
|
| 225 |
"""
|
| 226 |
-
|
| 227 |
-
high_conf = high_conf.head(max_cases)
|
| 228 |
|
| 229 |
-
|
|
|
|
| 230 |
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
-
def
|
| 234 |
"""
|
| 235 |
-
|
|
|
|
| 236 |
|
| 237 |
Args:
|
| 238 |
-
|
| 239 |
-
|
|
|
|
| 240 |
|
| 241 |
Returns:
|
| 242 |
-
|
| 243 |
"""
|
| 244 |
-
print("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
# Basic statistics
|
| 247 |
stats = {
|
| 248 |
'total_cases': len(results_df),
|
|
|
|
| 249 |
'mean_probability': results_df['ohca_probability'].mean(),
|
| 250 |
'std_probability': results_df['ohca_probability'].std(),
|
| 251 |
'median_probability': results_df['ohca_probability'].median(),
|
|
|
|
| 252 |
'high_confidence_cases': (results_df['ohca_probability'] >= 0.8).sum(),
|
| 253 |
-
'predicted_ohca_050': results_df
|
| 254 |
-
'predicted_ohca_070': results_df
|
| 255 |
-
'predicted_ohca_090': results_df
|
| 256 |
}
|
| 257 |
|
| 258 |
-
#
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
-
#
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
print(f" Total cases: {stats['total_cases']:,}")
|
|
|
|
| 264 |
print(f" Mean probability: {stats['mean_probability']:.4f}")
|
| 265 |
-
print(f"
|
| 266 |
-
print(f" High confidence (≥0.8): {stats['high_confidence_cases']:,}")
|
| 267 |
|
| 268 |
-
if stats['
|
| 269 |
-
prevalence = stats['
|
| 270 |
print(f" Estimated OHCA prevalence: {prevalence:.2f}%")
|
| 271 |
|
| 272 |
-
#
|
| 273 |
-
print(f"\
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
print(f" • Clinical review: {stats['predicted_ohca_070']} cases ≥0.7 probability")
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
if
|
| 282 |
-
|
|
|
|
|
|
|
| 283 |
|
| 284 |
return {
|
| 285 |
'statistics': stats,
|
|
|
|
| 286 |
'confidence_distribution': conf_dist,
|
| 287 |
-
'
|
|
|
|
| 288 |
}
|
| 289 |
|
| 290 |
# =============================================================================
|
| 291 |
-
# BATCH PROCESSING
|
| 292 |
# =============================================================================
|
| 293 |
|
| 294 |
-
def
|
| 295 |
-
|
| 296 |
"""
|
| 297 |
-
Process large datasets
|
| 298 |
|
| 299 |
Args:
|
| 300 |
-
model_path: Path to trained model
|
| 301 |
data_path: Path to input CSV file
|
| 302 |
output_path: Path for output results
|
| 303 |
chunk_size: Number of rows per chunk
|
|
@@ -306,10 +485,10 @@ def process_large_dataset(model_path, data_path, output_path,
|
|
| 306 |
Returns:
|
| 307 |
str: Path to completed results file
|
| 308 |
"""
|
| 309 |
-
print(f"
|
| 310 |
|
| 311 |
-
# Load model once
|
| 312 |
-
model, tokenizer =
|
| 313 |
|
| 314 |
# Read data in chunks
|
| 315 |
chunk_results = []
|
|
@@ -317,11 +496,11 @@ def process_large_dataset(model_path, data_path, output_path,
|
|
| 317 |
|
| 318 |
for chunk_df in pd.read_csv(data_path, chunksize=chunk_size):
|
| 319 |
chunk_num += 1
|
| 320 |
-
print(f"\
|
| 321 |
|
| 322 |
-
# Run inference on chunk
|
| 323 |
-
chunk_result =
|
| 324 |
-
model, tokenizer, chunk_df,
|
| 325 |
batch_size=batch_size, output_path=None
|
| 326 |
)
|
| 327 |
|
|
@@ -330,18 +509,24 @@ def process_large_dataset(model_path, data_path, output_path,
|
|
| 330 |
# Save intermediate results
|
| 331 |
temp_path = f"{output_path}.chunk_{chunk_num}.csv"
|
| 332 |
chunk_result.to_csv(temp_path, index=False)
|
| 333 |
-
print(f"
|
| 334 |
|
| 335 |
# Combine all chunks
|
| 336 |
-
print(f"\
|
| 337 |
final_results = pd.concat(chunk_results, ignore_index=True)
|
| 338 |
|
| 339 |
# Sort by probability and save
|
| 340 |
final_results = final_results.sort_values('ohca_probability', ascending=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
final_results.to_csv(output_path, index=False)
|
| 342 |
|
| 343 |
-
print(f"
|
| 344 |
-
print(f"
|
|
|
|
| 345 |
|
| 346 |
# Clean up intermediate files
|
| 347 |
for i in range(1, chunk_num + 1):
|
|
@@ -351,60 +536,56 @@ def process_large_dataset(model_path, data_path, output_path,
|
|
| 351 |
|
| 352 |
return output_path
|
| 353 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
# =============================================================================
|
| 355 |
-
#
|
| 356 |
# =============================================================================
|
| 357 |
|
| 358 |
-
def quick_inference(model_path, data_path, output_path=None):
|
| 359 |
-
"""
|
| 360 |
-
Quick inference function for simple use cases
|
| 361 |
-
|
| 362 |
-
Args:
|
| 363 |
-
model_path: Path to trained model
|
| 364 |
-
data_path: Path to input CSV (or DataFrame)
|
| 365 |
-
output_path: Optional output path
|
| 366 |
-
|
| 367 |
-
Returns:
|
| 368 |
-
DataFrame: Inference results
|
| 369 |
-
"""
|
| 370 |
-
print("🚀 Quick OHCA Inference")
|
| 371 |
-
|
| 372 |
-
# Load model
|
| 373 |
-
model, tokenizer = load_ohca_model(model_path)
|
| 374 |
-
|
| 375 |
-
# Load data
|
| 376 |
-
if isinstance(data_path, str):
|
| 377 |
-
df = pd.read_csv(data_path)
|
| 378 |
-
print(f"📂 Loaded {len(df):,} cases from {data_path}")
|
| 379 |
-
else:
|
| 380 |
-
df = data_path.copy()
|
| 381 |
-
print(f"📊 Processing {len(df):,} cases from DataFrame")
|
| 382 |
-
|
| 383 |
-
# Run inference
|
| 384 |
-
results = run_inference(model, tokenizer, df, output_path=output_path)
|
| 385 |
-
|
| 386 |
-
# Quick summary
|
| 387 |
-
ohca_cases = (results['ohca_probability'] >= 0.5).sum()
|
| 388 |
-
high_conf = (results['ohca_probability'] >= 0.8).sum()
|
| 389 |
-
|
| 390 |
-
print(f"\n✅ Quick Summary:")
|
| 391 |
-
print(f" Predicted OHCA cases: {ohca_cases:,}")
|
| 392 |
-
print(f" High confidence: {high_conf:,}")
|
| 393 |
-
|
| 394 |
-
return results
|
| 395 |
-
|
| 396 |
def test_model_on_sample(model_path, sample_texts):
|
| 397 |
"""
|
| 398 |
-
Test model on
|
| 399 |
|
| 400 |
Args:
|
| 401 |
model_path: Path to trained model
|
| 402 |
sample_texts: List of text strings or dict with hadm_id: text
|
| 403 |
|
| 404 |
Returns:
|
| 405 |
-
DataFrame: Test results
|
| 406 |
"""
|
| 407 |
-
print("
|
| 408 |
|
| 409 |
# Prepare test data
|
| 410 |
if isinstance(sample_texts, dict):
|
|
@@ -418,18 +599,28 @@ def test_model_on_sample(model_path, sample_texts):
|
|
| 418 |
for i, text in enumerate(sample_texts, 1)
|
| 419 |
])
|
| 420 |
|
| 421 |
-
#
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
|
| 425 |
-
# Print results
|
| 426 |
-
print(f"\
|
| 427 |
for _, row in results.iterrows():
|
| 428 |
prob = row['ohca_probability']
|
| 429 |
-
pred = "OHCA" if
|
| 430 |
conf = row['confidence_category']
|
|
|
|
| 431 |
|
| 432 |
-
print(f" {row['hadm_id']}: {pred} (
|
| 433 |
|
| 434 |
# Show text preview
|
| 435 |
text_preview = test_df[test_df['hadm_id']==row['hadm_id']]['clean_text'].iloc[0]
|
|
@@ -438,18 +629,48 @@ def test_model_on_sample(model_path, sample_texts):
|
|
| 438 |
|
| 439 |
return results
|
| 440 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
# =============================================================================
|
| 442 |
# EXAMPLE USAGE
|
| 443 |
# =============================================================================
|
| 444 |
|
| 445 |
if __name__ == "__main__":
|
| 446 |
-
print("OHCA Inference Module")
|
| 447 |
-
print("="*
|
| 448 |
-
print("
|
| 449 |
-
print("
|
| 450 |
-
print("
|
| 451 |
-
print("
|
| 452 |
-
print("
|
| 453 |
-
print("
|
| 454 |
-
print(
|
| 455 |
-
print("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OHCA Inference Module v3.0 - Improved with Optimal Threshold Support
|
| 2 |
+
# Apply pre-trained OHCA classifier to new datasets using optimal thresholds
|
| 3 |
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
|
|
|
| 8 |
from torch.utils.data import DataLoader, Dataset
|
| 9 |
from tqdm import tqdm
|
| 10 |
import os
|
| 11 |
+
import json
|
| 12 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 13 |
import warnings
|
| 14 |
warnings.filterwarnings('ignore')
|
|
|
|
| 18 |
# =============================================================================
|
| 19 |
|
| 20 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
print(f"Inference Module v3.0 - Using device: {DEVICE}")
|
| 22 |
|
| 23 |
# =============================================================================
|
| 24 |
# INFERENCE DATASET CLASS
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
# =============================================================================
|
| 65 |
+
# IMPROVED MODEL LOADING FUNCTIONS
|
| 66 |
# =============================================================================
|
| 67 |
|
| 68 |
+
def load_ohca_model_with_metadata(model_path):
|
| 69 |
+
"""
|
| 70 |
+
Load pre-trained OHCA model, tokenizer, and metadata (including optimal threshold).
|
| 71 |
+
This addresses the data scientist's feedback about using consistent thresholds.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
model_path: Path to saved model directory
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
tuple: (model, tokenizer, optimal_threshold, metadata)
|
| 78 |
+
"""
|
| 79 |
+
print(f"Loading OHCA model with metadata from: {model_path}")
|
| 80 |
+
|
| 81 |
+
if not os.path.exists(model_path):
|
| 82 |
+
raise FileNotFoundError(f"Model not found at: {model_path}")
|
| 83 |
+
|
| 84 |
+
try:
|
| 85 |
+
# Load tokenizer and model
|
| 86 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 87 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
| 88 |
+
model = model.to(DEVICE)
|
| 89 |
+
model.eval()
|
| 90 |
+
|
| 91 |
+
# Load metadata with optimal threshold
|
| 92 |
+
metadata_path = os.path.join(model_path, 'model_metadata.json')
|
| 93 |
+
if os.path.exists(metadata_path):
|
| 94 |
+
with open(metadata_path, 'r') as f:
|
| 95 |
+
metadata = json.load(f)
|
| 96 |
+
optimal_threshold = metadata.get('optimal_threshold', 0.5)
|
| 97 |
+
print(f"Loaded optimal threshold: {optimal_threshold:.3f}")
|
| 98 |
+
print(f"Model version: {metadata.get('model_version', 'unknown')}")
|
| 99 |
+
else:
|
| 100 |
+
print("Warning: No metadata file found. Using default threshold of 0.5")
|
| 101 |
+
optimal_threshold = 0.5
|
| 102 |
+
metadata = {'optimal_threshold': 0.5, 'model_version': 'legacy'}
|
| 103 |
+
|
| 104 |
+
print("Model loaded successfully")
|
| 105 |
+
print(f" Device: {DEVICE}")
|
| 106 |
+
print(f" Model type: {type(model).__name__}")
|
| 107 |
+
print(f" Optimal threshold: {optimal_threshold:.3f}")
|
| 108 |
+
|
| 109 |
+
return model, tokenizer, optimal_threshold, metadata
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
raise RuntimeError(f"Failed to load model: {str(e)}")
|
| 113 |
+
|
| 114 |
def load_ohca_model(model_path):
|
| 115 |
"""
|
| 116 |
+
Backward compatibility function - loads model without metadata
|
| 117 |
|
| 118 |
Args:
|
| 119 |
model_path: Path to saved model directory
|
|
|
|
| 121 |
Returns:
|
| 122 |
tuple: (model, tokenizer)
|
| 123 |
"""
|
| 124 |
+
print(f"Loading OHCA model from: {model_path}")
|
| 125 |
|
| 126 |
if not os.path.exists(model_path):
|
| 127 |
raise FileNotFoundError(f"Model not found at: {model_path}")
|
|
|
|
| 133 |
model = model.to(DEVICE)
|
| 134 |
model.eval()
|
| 135 |
|
| 136 |
+
print("Model loaded successfully (legacy mode)")
|
| 137 |
print(f" Device: {DEVICE}")
|
| 138 |
print(f" Model type: {type(model).__name__}")
|
| 139 |
|
|
|
|
| 143 |
raise RuntimeError(f"Failed to load model: {str(e)}")
|
| 144 |
|
| 145 |
# =============================================================================
|
| 146 |
+
# IMPROVED INFERENCE FUNCTIONS
|
| 147 |
# =============================================================================
|
| 148 |
|
| 149 |
+
def categorize_confidence_with_optimal_threshold(prob, optimal_threshold):
|
|
|
|
| 150 |
"""
|
| 151 |
+
Categorize confidence levels relative to optimal threshold
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
prob: Probability score
|
| 155 |
+
optimal_threshold: Optimal threshold from training
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
tuple: (confidence_category, clinical_priority)
|
| 159 |
+
"""
|
| 160 |
+
if prob >= 0.9:
|
| 161 |
+
return "Very High", "Immediate Review"
|
| 162 |
+
elif prob >= 0.7:
|
| 163 |
+
return "High", "Priority Review"
|
| 164 |
+
elif prob >= optimal_threshold:
|
| 165 |
+
return "Medium-High", "Clinical Review"
|
| 166 |
+
elif prob >= 0.3:
|
| 167 |
+
return "Medium", "Consider Review"
|
| 168 |
+
elif prob >= 0.1:
|
| 169 |
+
return "Low", "Routine Processing"
|
| 170 |
+
else:
|
| 171 |
+
return "Very Low", "Routine Processing"
|
| 172 |
+
|
| 173 |
+
def run_inference_with_optimal_threshold(model, tokenizer, inference_df,
|
| 174 |
+
optimal_threshold=0.5, batch_size=16,
|
| 175 |
+
output_path=None):
|
| 176 |
+
"""
|
| 177 |
+
Run OHCA inference using the optimal threshold from training.
|
| 178 |
+
This addresses the data scientist's feedback about threshold consistency.
|
| 179 |
|
| 180 |
Args:
|
| 181 |
model: Pre-trained OHCA model
|
| 182 |
tokenizer: Model tokenizer
|
| 183 |
inference_df: DataFrame with columns ['hadm_id', 'clean_text']
|
| 184 |
+
optimal_threshold: Optimal threshold from validation set
|
| 185 |
batch_size: Batch size for inference
|
| 186 |
output_path: Optional path to save results CSV
|
|
|
|
| 187 |
|
| 188 |
Returns:
|
| 189 |
+
DataFrame: Results with probabilities and predictions using optimal threshold
|
| 190 |
"""
|
| 191 |
+
print(f"Running OHCA inference on {len(inference_df):,} cases...")
|
| 192 |
+
print(f"Using optimal threshold: {optimal_threshold:.3f}")
|
| 193 |
|
| 194 |
# Validate input data
|
| 195 |
required_cols = ['hadm_id', 'clean_text']
|
|
|
|
| 200 |
# Remove any rows with missing data
|
| 201 |
clean_df = inference_df.dropna(subset=required_cols).copy()
|
| 202 |
if len(clean_df) < len(inference_df):
|
| 203 |
+
print(f"Warning: Removed {len(inference_df) - len(clean_df)} rows with missing data")
|
| 204 |
|
| 205 |
# Create dataset and dataloader
|
| 206 |
inference_dataset = OHCAInferenceDataset(clean_df, tokenizer)
|
|
|
|
| 232 |
'ohca_probability': all_probabilities
|
| 233 |
})
|
| 234 |
|
| 235 |
+
# Add prediction using optimal threshold (primary prediction)
|
| 236 |
+
results_df['ohca_prediction'] = (results_df['ohca_probability'] >= optimal_threshold).astype(int)
|
| 237 |
+
results_df['optimal_threshold_used'] = optimal_threshold
|
| 238 |
+
|
| 239 |
+
# Add legacy predictions for comparison
|
| 240 |
results_df['prediction_050'] = (results_df['ohca_probability'] >= 0.5).astype(int)
|
| 241 |
results_df['prediction_070'] = (results_df['ohca_probability'] >= 0.7).astype(int)
|
| 242 |
results_df['prediction_090'] = (results_df['ohca_probability'] >= 0.9).astype(int)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
+
# Add improved confidence categories and clinical priorities
|
| 245 |
+
confidence_info = [categorize_confidence_with_optimal_threshold(prob, optimal_threshold)
|
| 246 |
+
for prob in results_df['ohca_probability']]
|
| 247 |
+
results_df['confidence_category'] = [info[0] for info in confidence_info]
|
| 248 |
+
results_df['clinical_priority'] = [info[1] for info in confidence_info]
|
| 249 |
+
|
| 250 |
+
# Add interpretation column
|
| 251 |
+
results_df['interpretation'] = results_df.apply(
|
| 252 |
+
lambda row: f"OHCA detected (p={row['ohca_probability']:.3f})"
|
| 253 |
+
if row['ohca_prediction'] == 1
|
| 254 |
+
else f"No OHCA (p={row['ohca_probability']:.3f})", axis=1
|
| 255 |
+
)
|
| 256 |
|
| 257 |
# Sort by probability (highest first)
|
| 258 |
results_df = results_df.sort_values('ohca_probability', ascending=False).reset_index(drop=True)
|
| 259 |
|
| 260 |
+
# Print improved summary
|
| 261 |
+
print(f"\nInference Results Summary:")
|
| 262 |
print(f" Total cases processed: {len(results_df):,}")
|
| 263 |
print(f" Mean OHCA probability: {results_df['ohca_probability'].mean():.4f}")
|
| 264 |
+
print(f" OHCA detected (optimal threshold): {results_df['ohca_prediction'].sum():,}")
|
| 265 |
+
print(f" Detection rate: {results_df['ohca_prediction'].mean()*100:.2f}%")
|
| 266 |
+
|
| 267 |
+
# Clinical priority distribution
|
| 268 |
+
print(f"\nClinical Priority Distribution:")
|
| 269 |
+
priority_dist = results_df['clinical_priority'].value_counts()
|
| 270 |
+
for priority, count in priority_dist.items():
|
|
|
|
| 271 |
pct = count / len(results_df) * 100
|
| 272 |
+
print(f" {priority}: {count:,} cases ({pct:.1f}%)")
|
| 273 |
|
| 274 |
+
# Confidence distribution
|
| 275 |
+
print(f"\nConfidence Distribution:")
|
| 276 |
conf_dist = results_df['confidence_category'].value_counts()
|
| 277 |
for category, count in conf_dist.items():
|
| 278 |
pct = count / len(results_df) * 100
|
|
|
|
| 280 |
|
| 281 |
# Save results if path provided
|
| 282 |
if output_path:
|
| 283 |
+
# Add metadata to the saved file
|
| 284 |
+
results_df['inference_date'] = pd.Timestamp.now().isoformat()
|
| 285 |
results_df.to_csv(output_path, index=False)
|
| 286 |
+
print(f"\nResults saved to: {output_path}")
|
| 287 |
|
| 288 |
return results_df
|
| 289 |
|
| 290 |
+
def run_inference(model, tokenizer, inference_df, batch_size=16,
|
| 291 |
+
output_path=None, probability_threshold=0.5):
|
| 292 |
+
"""
|
| 293 |
+
Legacy inference function for backward compatibility
|
| 294 |
"""
|
| 295 |
+
print("Warning: Using legacy inference function. Consider upgrading to run_inference_with_optimal_threshold()")
|
| 296 |
+
return run_inference_with_optimal_threshold(
|
| 297 |
+
model, tokenizer, inference_df, probability_threshold, batch_size, output_path
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
# =============================================================================
|
| 301 |
+
# IMPROVED CONVENIENCE FUNCTIONS
|
| 302 |
+
# =============================================================================
|
| 303 |
+
|
| 304 |
+
def quick_inference_with_optimal_threshold(model_path, data_path, output_path=None):
|
| 305 |
+
"""
|
| 306 |
+
Quick inference function that automatically uses the optimal threshold.
|
| 307 |
+
This is the recommended way to run inference with v3.0 models.
|
| 308 |
|
| 309 |
Args:
|
| 310 |
+
model_path: Path to trained model (must include metadata)
|
| 311 |
+
data_path: Path to input CSV (or DataFrame)
|
| 312 |
+
output_path: Optional output path
|
| 313 |
|
| 314 |
Returns:
|
| 315 |
+
DataFrame: Inference results using optimal threshold
|
| 316 |
"""
|
| 317 |
+
print("Quick OHCA Inference v3.0 with Optimal Threshold")
|
|
|
|
| 318 |
|
| 319 |
+
# Load model with metadata
|
| 320 |
+
model, tokenizer, optimal_threshold, metadata = load_ohca_model_with_metadata(model_path)
|
| 321 |
|
| 322 |
+
# Load data
|
| 323 |
+
if isinstance(data_path, str):
|
| 324 |
+
df = pd.read_csv(data_path)
|
| 325 |
+
print(f"Loaded {len(df):,} cases from {data_path}")
|
| 326 |
+
else:
|
| 327 |
+
df = data_path.copy()
|
| 328 |
+
print(f"Processing {len(df):,} cases from DataFrame")
|
| 329 |
+
|
| 330 |
+
# Run inference with optimal threshold
|
| 331 |
+
results = run_inference_with_optimal_threshold(
|
| 332 |
+
model, tokenizer, df, optimal_threshold, output_path=output_path
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
# Enhanced summary
|
| 336 |
+
ohca_cases = results['ohca_prediction'].sum()
|
| 337 |
+
high_priority = (results['clinical_priority'] == 'Immediate Review').sum()
|
| 338 |
+
priority = (results['clinical_priority'] == 'Priority Review').sum()
|
| 339 |
+
|
| 340 |
+
print(f"\nEnhanced Summary:")
|
| 341 |
+
print(f" OHCA detected (optimal threshold): {ohca_cases:,}")
|
| 342 |
+
print(f" Immediate review needed: {high_priority:,}")
|
| 343 |
+
print(f" Priority review needed: {priority:,}")
|
| 344 |
+
print(f" Model version: {metadata.get('model_version', 'unknown')}")
|
| 345 |
+
print(f" Optimal threshold used: {optimal_threshold:.3f}")
|
| 346 |
+
|
| 347 |
+
return results
|
| 348 |
|
| 349 |
+
def quick_inference(model_path, data_path, output_path=None):
|
| 350 |
"""
|
| 351 |
+
Backward compatible quick inference function.
|
| 352 |
+
Automatically detects if model has metadata and uses optimal threshold if available.
|
| 353 |
|
| 354 |
Args:
|
| 355 |
+
model_path: Path to trained model
|
| 356 |
+
data_path: Path to input CSV (or DataFrame)
|
| 357 |
+
output_path: Optional output path
|
| 358 |
|
| 359 |
Returns:
|
| 360 |
+
DataFrame: Inference results
|
| 361 |
"""
|
| 362 |
+
print("Quick OHCA Inference")
|
| 363 |
+
|
| 364 |
+
# Try to load with metadata first
|
| 365 |
+
metadata_path = os.path.join(model_path, 'model_metadata.json')
|
| 366 |
+
if os.path.exists(metadata_path):
|
| 367 |
+
print("Detected v3.0 model with metadata - using optimal threshold")
|
| 368 |
+
return quick_inference_with_optimal_threshold(model_path, data_path, output_path)
|
| 369 |
+
else:
|
| 370 |
+
print("Detected legacy model - using default threshold 0.5")
|
| 371 |
+
# Load model without metadata
|
| 372 |
+
model, tokenizer = load_ohca_model(model_path)
|
| 373 |
+
|
| 374 |
+
# Load data
|
| 375 |
+
if isinstance(data_path, str):
|
| 376 |
+
df = pd.read_csv(data_path)
|
| 377 |
+
print(f"Loaded {len(df):,} cases from {data_path}")
|
| 378 |
+
else:
|
| 379 |
+
df = data_path.copy()
|
| 380 |
+
print(f"Processing {len(df):,} cases from DataFrame")
|
| 381 |
+
|
| 382 |
+
# Run inference with default threshold
|
| 383 |
+
results = run_inference_with_optimal_threshold(
|
| 384 |
+
model, tokenizer, df, optimal_threshold=0.5, output_path=output_path
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
# Quick summary
|
| 388 |
+
ohca_cases = results['ohca_prediction'].sum()
|
| 389 |
+
high_conf = (results['ohca_probability'] >= 0.8).sum()
|
| 390 |
+
|
| 391 |
+
print(f"\nQuick Summary:")
|
| 392 |
+
print(f" Predicted OHCA cases: {ohca_cases:,}")
|
| 393 |
+
print(f" High confidence: {high_conf:,}")
|
| 394 |
+
|
| 395 |
+
return results
|
| 396 |
+
|
| 397 |
+
def analyze_predictions_enhanced(results_df):
|
| 398 |
+
"""
|
| 399 |
+
Enhanced prediction analysis with optimal threshold insights
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
results_df: Results from inference with optimal threshold
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
dict: Enhanced analysis summary
|
| 406 |
+
"""
|
| 407 |
+
print("Analyzing prediction patterns with optimal threshold insights...")
|
| 408 |
+
|
| 409 |
+
optimal_threshold = results_df['optimal_threshold_used'].iloc[0] if 'optimal_threshold_used' in results_df.columns else 0.5
|
| 410 |
|
| 411 |
# Basic statistics
|
| 412 |
stats = {
|
| 413 |
'total_cases': len(results_df),
|
| 414 |
+
'optimal_threshold_used': optimal_threshold,
|
| 415 |
'mean_probability': results_df['ohca_probability'].mean(),
|
| 416 |
'std_probability': results_df['ohca_probability'].std(),
|
| 417 |
'median_probability': results_df['ohca_probability'].median(),
|
| 418 |
+
'ohca_detected_optimal': results_df.get('ohca_prediction', []).sum(),
|
| 419 |
'high_confidence_cases': (results_df['ohca_probability'] >= 0.8).sum(),
|
| 420 |
+
'predicted_ohca_050': results_df.get('prediction_050', []).sum(),
|
| 421 |
+
'predicted_ohca_070': results_df.get('prediction_070', []).sum(),
|
| 422 |
+
'predicted_ohca_090': results_df.get('prediction_090', []).sum(),
|
| 423 |
}
|
| 424 |
|
| 425 |
+
# Clinical priority distribution
|
| 426 |
+
if 'clinical_priority' in results_df.columns:
|
| 427 |
+
priority_dist = results_df['clinical_priority'].value_counts().to_dict()
|
| 428 |
+
else:
|
| 429 |
+
priority_dist = {}
|
| 430 |
|
| 431 |
+
# Confidence distribution
|
| 432 |
+
if 'confidence_category' in results_df.columns:
|
| 433 |
+
conf_dist = results_df['confidence_category'].value_counts().to_dict()
|
| 434 |
+
else:
|
| 435 |
+
conf_dist = {}
|
| 436 |
+
|
| 437 |
+
# Print enhanced analysis
|
| 438 |
+
print(f"\nEnhanced Prediction Analysis:")
|
| 439 |
print(f" Total cases: {stats['total_cases']:,}")
|
| 440 |
+
print(f" Optimal threshold used: {stats['optimal_threshold_used']:.3f}")
|
| 441 |
print(f" Mean probability: {stats['mean_probability']:.4f}")
|
| 442 |
+
print(f" OHCA detected (optimal): {stats['ohca_detected_optimal']:,}")
|
|
|
|
| 443 |
|
| 444 |
+
if stats['ohca_detected_optimal'] > 0:
|
| 445 |
+
prevalence = stats['ohca_detected_optimal'] / stats['total_cases'] * 100
|
| 446 |
print(f" Estimated OHCA prevalence: {prevalence:.2f}%")
|
| 447 |
|
| 448 |
+
# Comparison with static thresholds
|
| 449 |
+
print(f"\nThreshold Comparison:")
|
| 450 |
+
print(f" Optimal threshold ({optimal_threshold:.3f}): {stats['ohca_detected_optimal']:,} cases")
|
| 451 |
+
print(f" Static threshold (0.5): {stats['predicted_ohca_050']:,} cases")
|
| 452 |
+
print(f" Static threshold (0.7): {stats['predicted_ohca_070']:,} cases")
|
|
|
|
| 453 |
|
| 454 |
+
# Clinical recommendations
|
| 455 |
+
print(f"\nClinical Recommendations:")
|
| 456 |
+
if priority_dist:
|
| 457 |
+
for priority, count in priority_dist.items():
|
| 458 |
+
if count > 0:
|
| 459 |
+
print(f" {priority}: {count:,} cases")
|
| 460 |
|
| 461 |
return {
|
| 462 |
'statistics': stats,
|
| 463 |
+
'clinical_priority_distribution': priority_dist,
|
| 464 |
'confidence_distribution': conf_dist,
|
| 465 |
+
'optimal_threshold': optimal_threshold,
|
| 466 |
+
'high_confidence_cases': results_df[results_df['ohca_probability'] >= 0.8] if len(results_df) > 0 else pd.DataFrame()
|
| 467 |
}
|
| 468 |
|
| 469 |
# =============================================================================
|
| 470 |
+
# ENHANCED BATCH PROCESSING
|
| 471 |
# =============================================================================
|
| 472 |
|
| 473 |
+
def process_large_dataset_with_optimal_threshold(model_path, data_path, output_path,
|
| 474 |
+
chunk_size=10000, batch_size=16):
|
| 475 |
"""
|
| 476 |
+
Process large datasets using optimal threshold from model metadata
|
| 477 |
|
| 478 |
Args:
|
| 479 |
+
model_path: Path to trained model with metadata
|
| 480 |
data_path: Path to input CSV file
|
| 481 |
output_path: Path for output results
|
| 482 |
chunk_size: Number of rows per chunk
|
|
|
|
| 485 |
Returns:
|
| 486 |
str: Path to completed results file
|
| 487 |
"""
|
| 488 |
+
print(f"Processing large dataset in chunks of {chunk_size:,} with optimal threshold...")
|
| 489 |
|
| 490 |
+
# Load model with metadata once
|
| 491 |
+
model, tokenizer, optimal_threshold, metadata = load_ohca_model_with_metadata(model_path)
|
| 492 |
|
| 493 |
# Read data in chunks
|
| 494 |
chunk_results = []
|
|
|
|
| 496 |
|
| 497 |
for chunk_df in pd.read_csv(data_path, chunksize=chunk_size):
|
| 498 |
chunk_num += 1
|
| 499 |
+
print(f"\nProcessing chunk {chunk_num} ({len(chunk_df):,} rows)...")
|
| 500 |
|
| 501 |
+
# Run inference on chunk with optimal threshold
|
| 502 |
+
chunk_result = run_inference_with_optimal_threshold(
|
| 503 |
+
model, tokenizer, chunk_df, optimal_threshold,
|
| 504 |
batch_size=batch_size, output_path=None
|
| 505 |
)
|
| 506 |
|
|
|
|
| 509 |
# Save intermediate results
|
| 510 |
temp_path = f"{output_path}.chunk_{chunk_num}.csv"
|
| 511 |
chunk_result.to_csv(temp_path, index=False)
|
| 512 |
+
print(f"Chunk {chunk_num} saved to: {temp_path}")
|
| 513 |
|
| 514 |
# Combine all chunks
|
| 515 |
+
print(f"\nCombining {len(chunk_results)} chunks...")
|
| 516 |
final_results = pd.concat(chunk_results, ignore_index=True)
|
| 517 |
|
| 518 |
# Sort by probability and save
|
| 519 |
final_results = final_results.sort_values('ohca_probability', ascending=False)
|
| 520 |
+
|
| 521 |
+
# Add final metadata
|
| 522 |
+
final_results['model_version'] = metadata.get('model_version', 'unknown')
|
| 523 |
+
final_results['processing_date'] = pd.Timestamp.now().isoformat()
|
| 524 |
+
|
| 525 |
final_results.to_csv(output_path, index=False)
|
| 526 |
|
| 527 |
+
print(f"Complete results saved to: {output_path}")
|
| 528 |
+
print(f"Total cases processed: {len(final_results):,}")
|
| 529 |
+
print(f"OHCA detected with optimal threshold: {final_results['ohca_prediction'].sum():,}")
|
| 530 |
|
| 531 |
# Clean up intermediate files
|
| 532 |
for i in range(1, chunk_num + 1):
|
|
|
|
| 536 |
|
| 537 |
return output_path
|
| 538 |
|
| 539 |
+
# Legacy batch processing function
|
| 540 |
+
def process_large_dataset(model_path, data_path, output_path,
|
| 541 |
+
chunk_size=10000, batch_size=16):
|
| 542 |
+
"""Legacy function for backward compatibility"""
|
| 543 |
+
metadata_path = os.path.join(model_path, 'model_metadata.json')
|
| 544 |
+
if os.path.exists(metadata_path):
|
| 545 |
+
return process_large_dataset_with_optimal_threshold(
|
| 546 |
+
model_path, data_path, output_path, chunk_size, batch_size
|
| 547 |
+
)
|
| 548 |
+
else:
|
| 549 |
+
print("Warning: Legacy model detected. Using default threshold processing.")
|
| 550 |
+
# Fall back to original implementation
|
| 551 |
+
model, tokenizer = load_ohca_model(model_path)
|
| 552 |
+
|
| 553 |
+
chunk_results = []
|
| 554 |
+
chunk_num = 0
|
| 555 |
+
|
| 556 |
+
for chunk_df in pd.read_csv(data_path, chunksize=chunk_size):
|
| 557 |
+
chunk_num += 1
|
| 558 |
+
print(f"\nProcessing chunk {chunk_num} ({len(chunk_df):,} rows)...")
|
| 559 |
+
|
| 560 |
+
chunk_result = run_inference_with_optimal_threshold(
|
| 561 |
+
model, tokenizer, chunk_df, optimal_threshold=0.5,
|
| 562 |
+
batch_size=batch_size, output_path=None
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
chunk_results.append(chunk_result)
|
| 566 |
+
|
| 567 |
+
final_results = pd.concat(chunk_results, ignore_index=True)
|
| 568 |
+
final_results = final_results.sort_values('ohca_probability', ascending=False)
|
| 569 |
+
final_results.to_csv(output_path, index=False)
|
| 570 |
+
|
| 571 |
+
return output_path
|
| 572 |
+
|
| 573 |
# =============================================================================
|
| 574 |
+
# ENHANCED TESTING FUNCTIONS
|
| 575 |
# =============================================================================
|
| 576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
def test_model_on_sample(model_path, sample_texts):
|
| 578 |
"""
|
| 579 |
+
Test model on sample texts using optimal threshold if available
|
| 580 |
|
| 581 |
Args:
|
| 582 |
model_path: Path to trained model
|
| 583 |
sample_texts: List of text strings or dict with hadm_id: text
|
| 584 |
|
| 585 |
Returns:
|
| 586 |
+
DataFrame: Test results with optimal threshold predictions
|
| 587 |
"""
|
| 588 |
+
print("Testing model on sample texts...")
|
| 589 |
|
| 590 |
# Prepare test data
|
| 591 |
if isinstance(sample_texts, dict):
|
|
|
|
| 599 |
for i, text in enumerate(sample_texts, 1)
|
| 600 |
])
|
| 601 |
|
| 602 |
+
# Try to load with metadata
|
| 603 |
+
metadata_path = os.path.join(model_path, 'model_metadata.json')
|
| 604 |
+
if os.path.exists(metadata_path):
|
| 605 |
+
model, tokenizer, optimal_threshold, metadata = load_ohca_model_with_metadata(model_path)
|
| 606 |
+
results = run_inference_with_optimal_threshold(
|
| 607 |
+
model, tokenizer, test_df, optimal_threshold, output_path=None
|
| 608 |
+
)
|
| 609 |
+
else:
|
| 610 |
+
model, tokenizer = load_ohca_model(model_path)
|
| 611 |
+
results = run_inference_with_optimal_threshold(
|
| 612 |
+
model, tokenizer, test_df, optimal_threshold=0.5, output_path=None
|
| 613 |
+
)
|
| 614 |
|
| 615 |
+
# Print enhanced results
|
| 616 |
+
print(f"\nTest Results:")
|
| 617 |
for _, row in results.iterrows():
|
| 618 |
prob = row['ohca_probability']
|
| 619 |
+
pred = "OHCA" if row['ohca_prediction'] == 1 else "Non-OHCA"
|
| 620 |
conf = row['confidence_category']
|
| 621 |
+
priority = row['clinical_priority']
|
| 622 |
|
| 623 |
+
print(f" {row['hadm_id']}: {pred} (p={prob:.3f}, {conf}, {priority})")
|
| 624 |
|
| 625 |
# Show text preview
|
| 626 |
text_preview = test_df[test_df['hadm_id']==row['hadm_id']]['clean_text'].iloc[0]
|
|
|
|
| 629 |
|
| 630 |
return results
|
| 631 |
|
| 632 |
+
# =============================================================================
|
| 633 |
+
# LEGACY FUNCTIONS FOR BACKWARD COMPATIBILITY
|
| 634 |
+
# =============================================================================
|
| 635 |
+
|
| 636 |
+
def get_high_confidence_cases(results_df, threshold=0.8, max_cases=100):
|
| 637 |
+
"""Extract high-confidence OHCA predictions for manual review"""
|
| 638 |
+
high_conf = results_df[results_df['ohca_probability'] >= threshold].copy()
|
| 639 |
+
high_conf = high_conf.head(max_cases)
|
| 640 |
+
|
| 641 |
+
print(f"Found {len(high_conf)} high-confidence cases (≥{threshold})")
|
| 642 |
+
|
| 643 |
+
return high_conf
|
| 644 |
+
|
| 645 |
+
def analyze_predictions(results_df, original_df=None):
|
| 646 |
+
"""Legacy analysis function - redirects to enhanced version"""
|
| 647 |
+
return analyze_predictions_enhanced(results_df)
|
| 648 |
+
|
| 649 |
# =============================================================================
|
| 650 |
# EXAMPLE USAGE
|
| 651 |
# =============================================================================
|
| 652 |
|
| 653 |
if __name__ == "__main__":
|
| 654 |
+
print("OHCA Inference Module v3.0 - Enhanced with Optimal Threshold Support")
|
| 655 |
+
print("="*75)
|
| 656 |
+
print("Key improvements:")
|
| 657 |
+
print("✅ Automatic optimal threshold loading and usage")
|
| 658 |
+
print("✅ Enhanced confidence categories based on optimal threshold")
|
| 659 |
+
print("✅ Clinical priority recommendations")
|
| 660 |
+
print("✅ Backward compatibility with legacy models")
|
| 661 |
+
print("✅ Enhanced analysis and reporting")
|
| 662 |
+
print()
|
| 663 |
+
print("Main functions:")
|
| 664 |
+
print("• quick_inference_with_optimal_threshold() - Recommended for v3.0 models")
|
| 665 |
+
print("• load_ohca_model_with_metadata() - Load model with optimal threshold")
|
| 666 |
+
print("• run_inference_with_optimal_threshold() - Enhanced inference")
|
| 667 |
+
print("• process_large_dataset_with_optimal_threshold() - Batch processing")
|
| 668 |
+
print("• analyze_predictions_enhanced() - Enhanced prediction analysis")
|
| 669 |
+
print()
|
| 670 |
+
print("Legacy functions (maintained for compatibility):")
|
| 671 |
+
print("• quick_inference() - Auto-detects model version")
|
| 672 |
+
print("• load_ohca_model() - Basic model loading")
|
| 673 |
+
print("• run_inference() - Basic inference")
|
| 674 |
+
print()
|
| 675 |
+
print("See examples/ folder for detailed usage examples.")
|
| 676 |
+
print("="*75)
|