# ohca-classifier-v3 / src/ohca_inference.py
# monajm36
# Update ohca_inference.py
# a5114c2 unverified
# OHCA Inference Module v3.0 - Improved with Optimal Threshold Support
# Apply pre-trained OHCA classifier to new datasets using optimal thresholds
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import os
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import warnings
# NOTE: this silences every warning for the whole process, not only this module's.
warnings.filterwarnings(action='ignore')
# =============================================================================
# CONFIGURATION
# =============================================================================
# Run on CUDA whenever torch reports it is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Inference Module v3.0 - Using device: {DEVICE}")
# =============================================================================
# INFERENCE DATASET CLASS
# =============================================================================
class OHCAInferenceDataset(Dataset):
    """
    Torch dataset that tokenizes clinical notes for OHCA inference.

    Each item is a dict holding the flattened ``input_ids`` and
    ``attention_mask`` tensors returned by ``tokenizer`` plus the row's
    ``hadm_id``, so scores can be matched back to admissions.

    Args:
        dataframe: DataFrame with at least ``hadm_id`` and ``clean_text``;
            its index is discarded so items are addressed by position.
        tokenizer: Tokenizer callable, invoked with ``return_tensors='pt'``.
        max_length: Every note is truncated / padded to this many tokens.

    Raises:
        ValueError: if ``hadm_id`` or ``clean_text`` is missing.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        columns = self.data.columns
        if not ('hadm_id' in columns and 'clean_text' in columns):
            raise ValueError("DataFrame must contain 'hadm_id' and 'clean_text' columns")

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        note = str(record['clean_text'])
        # Mirror the training-time preprocessing: notes that mention a
        # transfer get a marker token prepended.
        if 'transfer' in note.lower():
            note = "TRANSFERRED_PATIENT " + note
        tokens = self.tokenizer(
            note,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )
        return dict(
            input_ids=tokens['input_ids'].flatten(),
            attention_mask=tokens['attention_mask'].flatten(),
            hadm_id=record['hadm_id'],
        )
# =============================================================================
# IMPROVED MODEL LOADING FUNCTIONS
# =============================================================================
def load_ohca_model_with_metadata(model_path):
    """
    Load pre-trained OHCA model, tokenizer, and metadata (including optimal threshold).

    The decision threshold is read from ``model_metadata.json`` inside
    ``model_path`` so inference uses the same threshold chosen during
    training. If that file is absent the model is treated as legacy and a
    threshold of 0.5 is used.

    Args:
        model_path: Path to saved model directory

    Returns:
        tuple: (model, tokenizer, optimal_threshold, metadata). The model is
        moved to DEVICE and put in eval mode; ``optimal_threshold`` is a float.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        RuntimeError: if the tokenizer/model cannot be loaded or the metadata
            cannot be read, including a stored threshold that is not a number
            in [0, 1]. The underlying exception is attached as ``__cause__``.
    """
    print(f"Loading OHCA model with metadata from: {model_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")
    try:
        # Load tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path)
        model = model.to(DEVICE)
        model.eval()
        # Load metadata with optimal threshold
        metadata_path = os.path.join(model_path, 'model_metadata.json')
        if os.path.exists(metadata_path):
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            # JSON may carry the threshold as a string/int (or null); normalise to
            # float so the formatting below and later comparisons are well-defined,
            # and reject values that cannot act as a probability cut-off.
            optimal_threshold = float(metadata.get('optimal_threshold', 0.5))
            if not 0.0 <= optimal_threshold <= 1.0:
                raise ValueError(
                    f"optimal_threshold must be in [0, 1], got {optimal_threshold}"
                )
            print(f"Loaded optimal threshold: {optimal_threshold:.3f}")
            print(f"Model version: {metadata.get('model_version', 'unknown')}")
        else:
            print("Warning: No metadata file found. Using default threshold of 0.5")
            optimal_threshold = 0.5
            metadata = {'optimal_threshold': 0.5, 'model_version': 'legacy'}
    except Exception as e:
        # Chain the original error so the real cause stays in the traceback.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
    print("Model loaded successfully")
    print(f" Device: {DEVICE}")
    print(f" Model type: {type(model).__name__}")
    print(f" Optimal threshold: {optimal_threshold:.3f}")
    return model, tokenizer, optimal_threshold, metadata
def load_ohca_model(model_path):
    """
    Backward compatibility function - loads model without metadata.

    Any ``model_metadata.json`` in the directory is ignored; use
    load_ohca_model_with_metadata() to also obtain the stored threshold.

    Args:
        model_path: Path to saved model directory

    Returns:
        tuple: (model, tokenizer), with the model moved to DEVICE and in eval mode

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        RuntimeError: if the tokenizer or model cannot be loaded; the
            underlying exception is attached as ``__cause__``.
    """
    print(f"Loading OHCA model from: {model_path}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")
    # Keep the try body to the calls that can actually fail to load.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path)
        model = model.to(DEVICE)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
    print("Model loaded successfully (legacy mode)")
    print(f" Device: {DEVICE}")
    print(f" Model type: {type(model).__name__}")
    return model, tokenizer
# =============================================================================
# IMPROVED INFERENCE FUNCTIONS
# =============================================================================
def categorize_confidence_with_optimal_threshold(prob, optimal_threshold):
    """
    Categorize confidence levels relative to optimal threshold.

    Probabilities at or above ``optimal_threshold`` (the cases predicted as
    OHCA) are split into Very High (>= 0.9), High (>= 0.7) and Medium-High.
    Probabilities below the threshold fall into Medium (>= 0.3), Low (>= 0.1)
    or Very Low.

    The escalated "Immediate Review" / "Priority Review" priorities are only
    given to predicted-positive cases. With a high threshold (e.g. 0.8), a case
    predicted "No OHCA" is therefore never also flagged for priority review.
    For thresholds <= 0.7 the result matches fixed 0.9/0.7 cut-offs.

    Args:
        prob: Probability score
        optimal_threshold: Optimal threshold from training

    Returns:
        tuple: (confidence_category, clinical_priority)
    """
    if prob >= optimal_threshold:
        if prob >= 0.9:
            return "Very High", "Immediate Review"
        if prob >= 0.7:
            return "High", "Priority Review"
        return "Medium-High", "Clinical Review"
    if prob >= 0.3:
        return "Medium", "Consider Review"
    if prob >= 0.1:
        return "Low", "Routine Processing"
    return "Very Low", "Routine Processing"
def _collect_ohca_probabilities(model, dataloader):
    """Score every batch in ``dataloader``; return (hadm_ids, OHCA probabilities) in loader order."""
    model.eval()
    hadm_ids = []
    probabilities = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Processing batches"):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probs = F.softmax(outputs.logits, dim=1)
            # Column 1 holds the OHCA (positive-class) probability.
            probabilities.extend(probs[:, 1].cpu().numpy())
            batch_ids = batch['hadm_id']
            # The DataLoader's default collate stacks numeric IDs into one tensor.
            # Extending with it would store 0-d tensors (written to CSV as
            # "tensor(123)"), so unwrap to plain Python scalars. String IDs
            # already arrive as a list.
            if torch.is_tensor(batch_ids):
                batch_ids = batch_ids.tolist()
            hadm_ids.extend(batch_ids)
    return hadm_ids, probabilities


def _annotate_ohca_predictions(results_df, optimal_threshold):
    """Add prediction, confidence, priority and interpretation columns to ``results_df`` in place."""
    probs = results_df['ohca_probability']
    # Primary prediction uses the threshold selected during training.
    results_df['ohca_prediction'] = (probs >= optimal_threshold).astype(int)
    results_df['optimal_threshold_used'] = optimal_threshold
    # Fixed-threshold predictions kept for comparison with older outputs.
    for column, cutoff in (('prediction_050', 0.5), ('prediction_070', 0.7), ('prediction_090', 0.9)):
        results_df[column] = (probs >= cutoff).astype(int)
    confidence_info = [categorize_confidence_with_optimal_threshold(p, optimal_threshold)
                       for p in probs]
    results_df['confidence_category'] = [category for category, _ in confidence_info]
    results_df['clinical_priority'] = [priority for _, priority in confidence_info]
    # A comprehension avoids DataFrame.apply(axis=1), which builds a Series per row.
    results_df['interpretation'] = [
        f"OHCA detected (p={p:.3f})" if predicted == 1 else f"No OHCA (p={p:.3f})"
        for p, predicted in zip(probs, results_df['ohca_prediction'])
    ]


def _print_inference_summary(results_df):
    """Print totals plus the clinical-priority and confidence distributions of ``results_df``."""
    total = len(results_df)
    print(f"\nInference Results Summary:")
    print(f" Total cases processed: {total:,}")
    print(f" Mean OHCA probability: {results_df['ohca_probability'].mean():.4f}")
    print(f" OHCA detected (optimal threshold): {results_df['ohca_prediction'].sum():,}")
    print(f" Detection rate: {results_df['ohca_prediction'].mean()*100:.2f}%")
    for heading, column in (("Clinical Priority Distribution", 'clinical_priority'),
                            ("Confidence Distribution", 'confidence_category')):
        print(f"\n{heading}:")
        for label, count in results_df[column].value_counts().items():
            pct = count / total * 100
            print(f" {label}: {count:,} cases ({pct:.1f}%)")


def run_inference_with_optimal_threshold(model, tokenizer, inference_df,
                                         optimal_threshold=0.5, batch_size=16,
                                         output_path=None):
    """
    Run OHCA inference using the optimal threshold from training.

    Rows missing ``hadm_id`` or ``clean_text`` are dropped (with a warning).
    Each remaining note is scored, predicted positive when its probability is
    >= ``optimal_threshold``, and given a confidence category and clinical
    priority. Fixed 0.5/0.7/0.9 predictions are added for comparison.

    Args:
        model: Pre-trained OHCA model (inputs are sent to DEVICE)
        tokenizer: Model tokenizer
        inference_df: DataFrame with columns ['hadm_id', 'clean_text']
        optimal_threshold: Optimal threshold from validation set
        batch_size: Batch size for inference
        output_path: Optional path to save results CSV; when given, an
            ``inference_date`` column is added before saving

    Returns:
        DataFrame: One row per scored note, sorted by ``ohca_probability``
        (highest first; ties keep input order), with ``hadm_id`` values as
        plain Python scalars or strings.

    Raises:
        ValueError: if a required column is missing.
    """
    print(f"Running OHCA inference on {len(inference_df):,} cases...")
    print(f"Using optimal threshold: {optimal_threshold:.3f}")
    # Validate input data
    required_cols = ['hadm_id', 'clean_text']
    missing_cols = [col for col in required_cols if col not in inference_df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")
    # Remove any rows with missing data
    clean_df = inference_df.dropna(subset=required_cols).copy()
    if len(clean_df) < len(inference_df):
        print(f"Warning: Removed {len(inference_df) - len(clean_df)} rows with missing data")

    inference_dataset = OHCAInferenceDataset(clean_df, tokenizer)
    inference_dataloader = DataLoader(inference_dataset, batch_size=batch_size, shuffle=False)
    hadm_ids, probabilities = _collect_ohca_probabilities(model, inference_dataloader)

    results_df = pd.DataFrame({
        'hadm_id': hadm_ids,
        # Explicit float dtype keeps the column numeric even when nothing was scored.
        'ohca_probability': np.asarray(probabilities, dtype=float),
    })
    _annotate_ohca_predictions(results_df, optimal_threshold)

    # Stable sort so equal probabilities keep their input order.
    results_df = (results_df
                  .sort_values('ohca_probability', ascending=False, kind='mergesort')
                  .reset_index(drop=True))

    _print_inference_summary(results_df)

    if output_path:
        results_df['inference_date'] = pd.Timestamp.now().isoformat()
        results_df.to_csv(output_path, index=False)
        print(f"\nResults saved to: {output_path}")
    return results_df
def run_inference(model, tokenizer, inference_df, batch_size=16,
                  output_path=None, probability_threshold=0.5):
    """
    Legacy inference entry point kept for backward compatibility.

    Prints an upgrade notice, then delegates to
    run_inference_with_optimal_threshold() with ``probability_threshold`` as
    the decision threshold.
    """
    print("Warning: Using legacy inference function. Consider upgrading to run_inference_with_optimal_threshold()")
    return run_inference_with_optimal_threshold(
        model,
        tokenizer,
        inference_df,
        optimal_threshold=probability_threshold,
        batch_size=batch_size,
        output_path=output_path,
    )
# =============================================================================
# IMPROVED CONVENIENCE FUNCTIONS
# =============================================================================
def quick_inference_with_optimal_threshold(model_path, data_path, output_path=None):
    """
    Quick inference function that automatically uses the optimal threshold.
    This is the recommended way to run inference with v3.0 models.

    Args:
        model_path: Path to trained model. If it has no metadata file, the
            loader falls back to a 0.5 threshold.
        data_path: Path to an input CSV (``str`` or ``os.PathLike``) or a
            DataFrame (copied, never modified)
        output_path: Optional output path

    Returns:
        DataFrame: Inference results using optimal threshold
    """
    print("Quick OHCA Inference v3.0 with Optimal Threshold")
    # Load model with metadata
    model, tokenizer, optimal_threshold, metadata = load_ohca_model_with_metadata(model_path)
    # Accept pathlib.Path as well as str; a Path previously fell through to the
    # DataFrame branch and failed on .copy().
    if isinstance(data_path, (str, os.PathLike)):
        df = pd.read_csv(data_path)
        print(f"Loaded {len(df):,} cases from {data_path}")
    else:
        df = data_path.copy()
        print(f"Processing {len(df):,} cases from DataFrame")
    results = run_inference_with_optimal_threshold(
        model, tokenizer, df, optimal_threshold, output_path=output_path
    )
    # Enhanced summary
    ohca_cases = results['ohca_prediction'].sum()
    high_priority = (results['clinical_priority'] == 'Immediate Review').sum()
    priority = (results['clinical_priority'] == 'Priority Review').sum()
    print(f"\nEnhanced Summary:")
    print(f" OHCA detected (optimal threshold): {ohca_cases:,}")
    print(f" Immediate review needed: {high_priority:,}")
    print(f" Priority review needed: {priority:,}")
    print(f" Model version: {metadata.get('model_version', 'unknown')}")
    print(f" Optimal threshold used: {optimal_threshold:.3f}")
    return results
def quick_inference(model_path, data_path, output_path=None):
    """
    Backward compatible quick inference function.

    If ``model_metadata.json`` exists in ``model_path``, this delegates to
    quick_inference_with_optimal_threshold(). Otherwise it loads the model in
    legacy mode and uses a fixed 0.5 threshold.

    Args:
        model_path: Path to trained model
        data_path: Path to an input CSV (``str`` or ``os.PathLike``) or a
            DataFrame (copied, never modified)
        output_path: Optional output path

    Returns:
        DataFrame: Inference results
    """
    print("Quick OHCA Inference")
    metadata_path = os.path.join(model_path, 'model_metadata.json')
    if os.path.exists(metadata_path):
        print("Detected v3.0 model with metadata - using optimal threshold")
        return quick_inference_with_optimal_threshold(model_path, data_path, output_path)

    print("Detected legacy model - using default threshold 0.5")
    model, tokenizer = load_ohca_model(model_path)
    # Accept pathlib.Path as well as str (a Path used to fail on .copy()).
    if isinstance(data_path, (str, os.PathLike)):
        df = pd.read_csv(data_path)
        print(f"Loaded {len(df):,} cases from {data_path}")
    else:
        df = data_path.copy()
        print(f"Processing {len(df):,} cases from DataFrame")
    results = run_inference_with_optimal_threshold(
        model, tokenizer, df, optimal_threshold=0.5, output_path=output_path
    )
    # Quick summary
    ohca_cases = results['ohca_prediction'].sum()
    high_conf = (results['ohca_probability'] >= 0.8).sum()
    print(f"\nQuick Summary:")
    print(f" Predicted OHCA cases: {ohca_cases:,}")
    print(f" High confidence: {high_conf:,}")
    return results
def analyze_predictions_enhanced(results_df):
    """
    Enhanced prediction analysis with optimal threshold insights.

    Only the ``ohca_probability`` column is required. If a 0/1 prediction
    column (``ohca_prediction``, ``prediction_050/070/090``) is absent, its
    count is derived from the probabilities with the matching ``>=`` cutoff.
    The threshold comes from ``optimal_threshold_used``, and defaults to 0.5
    when that column is absent or the frame is empty.

    Args:
        results_df: Results from inference with optimal threshold

    Returns:
        dict: Keys 'statistics', 'clinical_priority_distribution',
        'confidence_distribution', 'optimal_threshold' and
        'high_confidence_cases' (the rows with probability >= 0.8).
    """
    print("Analyzing prediction patterns with optimal threshold insights...")
    probs = results_df['ohca_probability']
    # .iloc[0] would raise IndexError on an empty frame, so fall back to 0.5 there too.
    if 'optimal_threshold_used' in results_df.columns and len(results_df) > 0:
        optimal_threshold = float(results_df['optimal_threshold_used'].iloc[0])
    else:
        optimal_threshold = 0.5

    def count_positive(column, cutoff):
        # Use the stored 0/1 column when present; otherwise apply the same
        # ">= cutoff" rule used to build it. (A missing column must not end up
        # calling .sum() on a plain list.)
        if column in results_df.columns:
            return int(results_df[column].sum())
        return int((probs >= cutoff).sum())

    # Basic statistics
    stats = {
        'total_cases': len(results_df),
        'optimal_threshold_used': optimal_threshold,
        'mean_probability': probs.mean(),
        'std_probability': probs.std(),
        'median_probability': probs.median(),
        'ohca_detected_optimal': count_positive('ohca_prediction', optimal_threshold),
        'high_confidence_cases': int((probs >= 0.8).sum()),
        'predicted_ohca_050': count_positive('prediction_050', 0.5),
        'predicted_ohca_070': count_positive('prediction_070', 0.7),
        'predicted_ohca_090': count_positive('prediction_090', 0.9),
    }
    # Clinical priority / confidence distributions (empty when the column is absent)
    if 'clinical_priority' in results_df.columns:
        priority_dist = results_df['clinical_priority'].value_counts().to_dict()
    else:
        priority_dist = {}
    if 'confidence_category' in results_df.columns:
        conf_dist = results_df['confidence_category'].value_counts().to_dict()
    else:
        conf_dist = {}

    # Print enhanced analysis
    print(f"\nEnhanced Prediction Analysis:")
    print(f" Total cases: {stats['total_cases']:,}")
    print(f" Optimal threshold used: {stats['optimal_threshold_used']:.3f}")
    print(f" Mean probability: {stats['mean_probability']:.4f}")
    print(f" OHCA detected (optimal): {stats['ohca_detected_optimal']:,}")
    if stats['ohca_detected_optimal'] > 0:
        prevalence = stats['ohca_detected_optimal'] / stats['total_cases'] * 100
        print(f" Estimated OHCA prevalence: {prevalence:.2f}%")
    # Comparison with static thresholds
    print(f"\nThreshold Comparison:")
    print(f" Optimal threshold ({optimal_threshold:.3f}): {stats['ohca_detected_optimal']:,} cases")
    print(f" Static threshold (0.5): {stats['predicted_ohca_050']:,} cases")
    print(f" Static threshold (0.7): {stats['predicted_ohca_070']:,} cases")
    # Clinical recommendations
    print(f"\nClinical Recommendations:")
    for priority, count in priority_dist.items():
        if count > 0:
            print(f" {priority}: {count:,} cases")
    return {
        'statistics': stats,
        'clinical_priority_distribution': priority_dist,
        'confidence_distribution': conf_dist,
        'optimal_threshold': optimal_threshold,
        'high_confidence_cases': results_df[probs >= 0.8],
    }
# =============================================================================
# ENHANCED BATCH PROCESSING
# =============================================================================
def process_large_dataset_with_optimal_threshold(model_path, data_path, output_path,
                                                 chunk_size=10000, batch_size=16):
    """
    Score a large CSV chunk by chunk using the model's stored optimal threshold.

    The model and threshold are loaded once. Each chunk of ``chunk_size`` rows
    is scored and a copy is written to ``<output_path>.chunk_<n>.csv``. All
    scored chunks are also kept in memory, then concatenated, sorted by
    ``ohca_probability`` (highest first), and stamped with ``model_version``
    and ``processing_date`` columns. The result is written to ``output_path``.
    Per-chunk files are deleted only after that final write, so a run that
    fails part-way leaves them on disk.

    Args:
        model_path: Path to trained model with metadata
        data_path: Path to input CSV file
        output_path: Path for output results
        chunk_size: Number of rows per chunk
        batch_size: Batch size for inference

    Returns:
        str: Path to completed results file
    """
    print(f"Processing large dataset in chunks of {chunk_size:,} with optimal threshold...")
    model, tokenizer, optimal_threshold, metadata = load_ohca_model_with_metadata(model_path)

    scored_chunks = []
    temp_paths = []
    for chunk_idx, chunk_df in enumerate(pd.read_csv(data_path, chunksize=chunk_size), start=1):
        print(f"\nProcessing chunk {chunk_idx} ({len(chunk_df):,} rows)...")
        scored = run_inference_with_optimal_threshold(
            model, tokenizer, chunk_df, optimal_threshold,
            batch_size=batch_size, output_path=None
        )
        scored_chunks.append(scored)
        temp_path = f"{output_path}.chunk_{chunk_idx}.csv"
        scored.to_csv(temp_path, index=False)
        temp_paths.append(temp_path)
        print(f"Chunk {chunk_idx} saved to: {temp_path}")

    print(f"\nCombining {len(scored_chunks)} chunks...")
    final_results = (
        pd.concat(scored_chunks, ignore_index=True)
        .sort_values('ohca_probability', ascending=False)
    )
    final_results['model_version'] = metadata.get('model_version', 'unknown')
    final_results['processing_date'] = pd.Timestamp.now().isoformat()
    final_results.to_csv(output_path, index=False)
    print(f"Complete results saved to: {output_path}")
    print(f"Total cases processed: {len(final_results):,}")
    print(f"OHCA detected with optimal threshold: {final_results['ohca_prediction'].sum():,}")

    # Remove the per-chunk copies now that the combined file exists.
    for temp_path in temp_paths:
        if os.path.exists(temp_path):
            os.remove(temp_path)
    return output_path
# Legacy batch processing function
def process_large_dataset(model_path, data_path, output_path,
                          chunk_size=10000, batch_size=16):
    """
    Legacy chunked-processing entry point kept for backward compatibility.

    Models that ship ``model_metadata.json`` are routed to
    process_large_dataset_with_optimal_threshold(). Older models are scored
    chunk by chunk with a fixed 0.5 threshold; the combined results are sorted
    by probability (highest first) and written to ``output_path``.

    Returns:
        str: ``output_path``
    """
    has_metadata = os.path.exists(os.path.join(model_path, 'model_metadata.json'))
    if has_metadata:
        return process_large_dataset_with_optimal_threshold(
            model_path, data_path, output_path, chunk_size, batch_size
        )

    print("Warning: Legacy model detected. Using default threshold processing.")
    model, tokenizer = load_ohca_model(model_path)
    scored_chunks = []
    for chunk_idx, chunk_df in enumerate(pd.read_csv(data_path, chunksize=chunk_size), start=1):
        print(f"\nProcessing chunk {chunk_idx} ({len(chunk_df):,} rows)...")
        scored_chunks.append(
            run_inference_with_optimal_threshold(
                model, tokenizer, chunk_df, optimal_threshold=0.5,
                batch_size=batch_size, output_path=None,
            )
        )
    combined = pd.concat(scored_chunks, ignore_index=True)
    combined.sort_values('ohca_probability', ascending=False).to_csv(output_path, index=False)
    return output_path
# =============================================================================
# ENHANCED TESTING FUNCTIONS
# =============================================================================
def test_model_on_sample(model_path, sample_texts):
    """
    Test model on sample texts using optimal threshold if available.

    Args:
        model_path: Path to trained model; when it contains
            ``model_metadata.json`` the stored threshold is used, otherwise 0.5
        sample_texts: List of text strings (labelled TEST_001, TEST_002, ...)
            or dict mapping hadm_id -> text

    Returns:
        DataFrame: Test results with optimal threshold predictions
    """
    print("Testing model on sample texts...")
    if isinstance(sample_texts, dict):
        records = [{'hadm_id': hadm_id, 'clean_text': text}
                   for hadm_id, text in sample_texts.items()]
    else:
        records = [{'hadm_id': f'TEST_{i:03d}', 'clean_text': text}
                   for i, text in enumerate(sample_texts, 1)]
    # Explicit columns keep the required columns present even for empty input,
    # so the inference step does not fail with a misleading "missing columns" error.
    test_df = pd.DataFrame(records, columns=['hadm_id', 'clean_text'])

    metadata_path = os.path.join(model_path, 'model_metadata.json')
    if os.path.exists(metadata_path):
        model, tokenizer, optimal_threshold, metadata = load_ohca_model_with_metadata(model_path)
    else:
        model, tokenizer = load_ohca_model(model_path)
        optimal_threshold = 0.5
    results = run_inference_with_optimal_threshold(
        model, tokenizer, test_df, optimal_threshold, output_path=None
    )

    # Build the id -> text lookup once instead of filtering test_df for every row;
    # .get() also tolerates an id that cannot be matched back.
    text_by_id = dict(zip(test_df['hadm_id'], test_df['clean_text']))
    print(f"\nTest Results:")
    for _, row in results.iterrows():
        prob = row['ohca_probability']
        pred = "OHCA" if row['ohca_prediction'] == 1 else "Non-OHCA"
        conf = row['confidence_category']
        priority = row['clinical_priority']
        print(f" {row['hadm_id']}: {pred} (p={prob:.3f}, {conf}, {priority})")
        text_preview = str(text_by_id.get(row['hadm_id'], ''))
        print(f" Text: {text_preview[:100]}...")
        print()
    return results
# =============================================================================
# LEGACY FUNCTIONS FOR BACKWARD COMPATIBILITY
# =============================================================================
def get_high_confidence_cases(results_df, threshold=0.8, max_cases=100):
    """
    Extract high-confidence OHCA predictions for manual review.

    Rows with ``ohca_probability >= threshold`` are ordered from most to least
    probable before truncating to ``max_cases``. The strongest candidates are
    therefore kept even when ``results_df`` is not already sorted. The sort is
    stable, so input already sorted by probability keeps its row order.

    Args:
        results_df: Inference results with an ``ohca_probability`` column
        threshold: Minimum probability to include
        max_cases: Maximum number of rows returned

    Returns:
        DataFrame: A copy of the selected rows.
    """
    high_conf = results_df[results_df['ohca_probability'] >= threshold]
    high_conf = high_conf.sort_values('ohca_probability', ascending=False, kind='mergesort')
    high_conf = high_conf.head(max_cases).copy()
    print(f"Found {len(high_conf)} high-confidence cases (≥{threshold})")
    return high_conf
def analyze_predictions(results_df, original_df=None):
    """
    Legacy analysis entry point; forwards to analyze_predictions_enhanced().

    ``original_df`` is accepted only so older call sites keep working; it is ignored.
    """
    summary = analyze_predictions_enhanced(results_df)
    return summary
# =============================================================================
# EXAMPLE USAGE
# =============================================================================
if __name__ == "__main__":
    # Print a usage banner listing the module's main and legacy entry points.
    _banner_rule = "=" * 75
    _banner_lines = [
        "OHCA Inference Module v3.0 - Enhanced with Optimal Threshold Support",
        _banner_rule,
        "Key improvements:",
        "✅ Automatic optimal threshold loading and usage",
        "✅ Enhanced confidence categories based on optimal threshold",
        "✅ Clinical priority recommendations",
        "✅ Backward compatibility with legacy models",
        "✅ Enhanced analysis and reporting",
        "",
        "Main functions:",
        "• quick_inference_with_optimal_threshold() - Recommended for v3.0 models",
        "• load_ohca_model_with_metadata() - Load model with optimal threshold",
        "• run_inference_with_optimal_threshold() - Enhanced inference",
        "• process_large_dataset_with_optimal_threshold() - Batch processing",
        "• analyze_predictions_enhanced() - Enhanced prediction analysis",
        "",
        "Legacy functions (maintained for compatibility):",
        "• quick_inference() - Auto-detects model version",
        "• load_ohca_model() - Basic model loading",
        "• run_inference() - Basic inference",
        "",
        "See examples/ folder for detailed usage examples.",
        _banner_rule,
    ]
    for _banner_line in _banner_lines:
        print(_banner_line)