|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from torch.optim import AdamW |
|
|
from tqdm import tqdm |
|
|
import random |
|
|
import os |
|
|
import json |
|
|
from sklearn.model_selection import train_test_split |
|
|
from sklearn.utils import compute_class_weight, resample |
|
|
from sklearn.metrics import ( |
|
|
confusion_matrix, accuracy_score, roc_auc_score, roc_curve, |
|
|
precision_recall_fscore_support |
|
|
) |
|
|
from transformers import ( |
|
|
AutoTokenizer, AutoModelForSequenceClassification, |
|
|
get_linear_schedule_with_warmup |
|
|
) |
|
|
import warnings |
|
|
# NOTE(review): this silences every warning process-wide at import time,
# including scikit-learn/pandas data-quality warnings. Kept as-is because
# callers may depend on the quiet output.
warnings.filterwarnings('ignore')

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------
# One seed shared by data splitting, annotation sampling and model training.
RANDOM_STATE = 42

# Hugging Face model identifier used for both the tokenizer and the classifier.
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Seed every random number generator the pipeline touches (the three are
# independent, so the order is irrelevant).
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

print(f"Training Pipeline v3.0 - Using device: {DEVICE}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_patient_level_splits(df, train_size=0.7, val_size=0.15, test_size=0.15, random_state=42):
    """
    Split notes into train/validation/test sets at the patient level.

    Every note of a given ``subject_id`` ends up in exactly one split, so no
    patient contributes to both training and evaluation data. If ``df`` has no
    ``subject_id`` column, each ``hadm_id`` is used as its own patient ID,
    i.e. the split degrades to admission level.

    Args:
        df: DataFrame with columns ['hadm_id', 'clean_text'] and optionally
            'subject_id'. The caller's DataFrame is never modified.
        train_size, val_size, test_size: Split proportions (must sum to 1.0).
        random_state: Seed forwarded to both ``train_test_split`` calls.

    Returns:
        train_df, val_df, test_df: Patient-disjoint subsets of ``df``, each
        re-indexed from 0.

    Raises:
        ValueError: If the three proportions do not sum to 1.0.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O`` and must not be used for argument validation.
    if not abs(train_size + val_size + test_size - 1.0) < 1e-10:
        raise ValueError("Split proportions must sum to 1.0")

    print("Creating patient-level data splits...")

    if 'subject_id' not in df.columns:
        print("⚠️ No 'subject_id' column found. Creating synthetic patient IDs from hadm_id...")
        df = df.copy()  # add the column to a copy, not the caller's frame
        df['subject_id'] = df['hadm_id']

    patients = df['subject_id'].unique()
    print(f"Found {len(patients)} unique patients with {len(df)} total notes")

    # First carve off the training patients, then split the held-out
    # remainder into validation/test in the ratio val_size : test_size.
    train_patients, temp_patients = train_test_split(
        patients, test_size=(val_size + test_size), random_state=random_state
    )
    val_patients, test_patients = train_test_split(
        temp_patients, test_size=test_size / (val_size + test_size), random_state=random_state
    )

    subject_ids = df['subject_id']
    train_df = df[subject_ids.isin(train_patients)].reset_index(drop=True)
    val_df = df[subject_ids.isin(val_patients)].reset_index(drop=True)
    test_df = df[subject_ids.isin(test_patients)].reset_index(drop=True)

    print("✅ Patient-level splits created:")
    print(f"   Training: {len(train_patients)} patients, {len(train_df)} notes")
    print(f"   Validation: {len(val_patients)} patients, {len(val_df)} notes")
    print(f"   Test: {len(test_patients)} patients, {len(test_df)} notes")

    return train_df, val_df, test_df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Markdown guidelines written next to the annotation workbooks.
_ANNOTATION_GUIDELINES_V3 = """
# OHCA Annotation Guidelines (Improved Methodology v3.0)

## IMPORTANT CHANGES IN v3.0:
- You now have **TWO separate files** to annotate
- Larger sample sizes for better model performance
- Patient-level data splits prevent data leakage
- Independent test set reserved for final evaluation

## Files to Annotate:
1. **train_annotation.xlsx** - Used for model training (larger sample)
2. **validation_annotation.xlsx** - Used for finding optimal threshold

## Definition
Out-of-Hospital Cardiac Arrest (OHCA) that occurred OUTSIDE a healthcare facility and is the PRIMARY reason for hospital admission.

## Labels:
- **1** = OHCA (cardiac arrest outside hospital, primary reason for admission)
- **0** = Not OHCA (everything else, including transfers and historical arrests)

## Include as OHCA (1):
- ✅ "Found down at home, CPR given by family"
- ✅ "Cardiac arrest at work, bystander CPR initiated"
- ✅ "Collapsed in public place, EMS resuscitation successful"
- ✅ "Out-of-hospital VF arrest, ROSC achieved"

## Exclude as OHCA (0):
- ❌ In-hospital cardiac arrests
- ❌ Historical/previous cardiac arrest (not current episode)
- ❌ Trauma-induced cardiac arrest
- ❌ Overdose-induced cardiac arrest
- ❌ Transfer patients (unless clearly OHCA as primary reason)
- ❌ Chest pain without actual arrest
- ❌ Near-syncope or syncope without arrest

## Decision Process:
1. **Did cardiac arrest happen OUTSIDE hospital?** → If No: Label = 0
2. **Is OHCA the PRIMARY reason for this admission?** → If No: Label = 0
3. **If Yes to both:** Label = 1

## Confidence Scale:
- **1** = Very uncertain, ambiguous case
- **2** = Somewhat uncertain
- **3** = Moderately confident
- **4** = Confident
- **5** = Very confident, clear-cut case

## Quality Tips:
- Read the entire discharge summary, not just chief complaint
- Look for keywords: "found down", "unresponsive", "CPR", "code blue", "ROSC"
- Pay attention to location: "at home", "in public", "at work" vs "in ED", "in hospital"
- When uncertain, use confidence score of 1-2 and add detailed notes

## Key Improvement in v3.0:
This methodology prevents data leakage and provides more reliable performance estimates by using proper train/validation/test splits at the patient level.
"""


def _sample_with_keywords(source_df, sample_size, split_name, keyword='cardiac arrest'):
    """Draw a keyword-enriched annotation sample from one split.

    Up to half of the sample (``sample_size // 2``) comes from notes whose
    'clean_text' contains ``keyword`` (case-insensitive). The remainder is
    drawn uniformly from the split's other notes. If too few notes remain,
    all of them are used. Rows are tagged with 'sampling_source'
    ('keyword_enriched' or 'random') and 'split_source' (``split_name``).
    """
    # regex=False: the keyword is a literal phrase, not a pattern.
    keyword_mask = source_df['clean_text'].str.contains(keyword, case=False, na=False, regex=False)
    keyword_candidates = source_df[keyword_mask]
    print(f"Found {len(keyword_candidates)} notes with '{keyword}' in {split_name} set")

    # Stage 1: keyword-enriched cases, capped at half the sample and at the
    # number of available candidates.
    stage1_target = min(sample_size // 2, len(keyword_candidates))
    stage1_sample = keyword_candidates.sample(n=stage1_target, random_state=RANDOM_STATE)

    # Stage 2: fill the rest uniformly from notes not already selected.
    stage2_target = sample_size - len(stage1_sample)
    remaining_notes = source_df[~source_df['hadm_id'].isin(stage1_sample['hadm_id'])]
    if len(remaining_notes) >= stage2_target:
        stage2_sample = remaining_notes.sample(n=stage2_target, random_state=RANDOM_STATE + 1)
    else:
        stage2_sample = remaining_notes.copy()
        print(f"⚠️ Only {len(remaining_notes)} additional notes available for {split_name}, using all")

    final_sample = pd.concat([stage1_sample, stage2_sample]).copy()
    # Positional tags: concat keeps stage-1 rows before stage-2 rows.
    final_sample['sampling_source'] = (['keyword_enriched'] * len(stage1_sample)
                                       + ['random'] * len(stage2_sample))
    final_sample['split_source'] = split_name
    return final_sample


def _write_annotation_file(sample_df, output_dir, filename):
    """Write a shuffled annotation workbook with blank label columns; return its path."""
    annotation_df = sample_df[['hadm_id', 'clean_text', 'sampling_source', 'split_source']].copy()

    # Empty columns for the annotator to fill in.
    for column in ('ohca_label', 'confidence', 'notes', 'annotator', 'annotation_date'):
        annotation_df[column] = ''

    # Shuffle so keyword-enriched and random cases are interleaved rather
    # than grouped in sampling order.
    annotation_df = annotation_df.sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)
    annotation_df['annotation_order'] = range(1, len(annotation_df) + 1)

    filepath = os.path.join(output_dir, filename)
    annotation_df.to_excel(filepath, index=False)
    return filepath


def create_training_sample(df, output_dir="./annotation_interface",
                           train_sample_size=800, val_sample_size=200):
    """
    Create separate annotation samples for training and validation.

    The notes are first split at patient level (``create_patient_level_splits``).
    The test split is written untouched to ``test_set_DO_NOT_ANNOTATE.csv``.
    Keyword-enriched samples are drawn independently from the training and
    validation splits and written as Excel workbooks, alongside a markdown
    annotation guide.

    Args:
        df: DataFrame with columns ['hadm_id', 'clean_text'] (optionally
            'subject_id').
        output_dir: Directory for the annotation interface (created if missing).
        train_sample_size: Number of training notes to annotate.
        val_sample_size: Number of validation notes to annotate.

    Returns:
        dict with keys 'train_file', 'val_file', 'guidelines_file',
        'test_file', 'train_sample_size', 'val_sample_size', 'test_size'.
    """
    print("Creating improved training samples for annotation...")

    train_df, val_df, test_df = create_patient_level_splits(df)

    os.makedirs(output_dir, exist_ok=True)
    test_file = os.path.join(output_dir, "test_set_DO_NOT_ANNOTATE.csv")
    test_df.to_csv(test_file, index=False)

    train_sample = _sample_with_keywords(train_df, train_sample_size, "training")
    val_sample = _sample_with_keywords(val_df, val_sample_size, "validation")

    train_file = _write_annotation_file(train_sample, output_dir, "train_annotation.xlsx")
    val_file = _write_annotation_file(val_sample, output_dir, "validation_annotation.xlsx")

    guidelines_file = os.path.join(output_dir, "annotation_guidelines_v3.md")
    # Explicit UTF-8: the guidelines contain non-ASCII symbols that the
    # platform default encoding (e.g. cp1252 on Windows) cannot encode.
    with open(guidelines_file, 'w', encoding='utf-8') as f:
        f.write(_ANNOTATION_GUIDELINES_V3)

    print("✅ Improved annotation interface created:")
    print(f"   Training file: {train_file} ({len(train_sample)} cases)")
    print(f"   Validation file: {val_file} ({len(val_sample)} cases)")
    print(f"   Guidelines: {guidelines_file}")
    print(f"   Test set: {test_file} ({len(test_df)} cases)")
    print("\n⚠️ Please manually annotate BOTH Excel files before proceeding to training!")

    return {
        'train_file': train_file,
        'val_file': val_file,
        'guidelines_file': guidelines_file,
        'test_file': test_file,
        'train_sample_size': len(train_sample),
        'val_sample_size': len(val_sample),
        'test_size': len(test_df),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OHCATrainingDataset(Dataset):
    """PyTorch Dataset serving tokenized discharge notes for OHCA training.

    Each item is a dict with ``input_ids`` and ``attention_mask`` tensors
    (padded/truncated to ``max_length``, flattened to 1-D) and a scalar
    ``labels`` tensor of dtype ``torch.long``.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        """
        Args:
            dataframe: DataFrame with a 'clean_text' column and an
                integer-like 'label' column; stored with a fresh 0..n-1 index.
            tokenizer: Callable tokenizer accepting ``truncation``,
                ``padding``, ``max_length`` and ``return_tensors='pt'``.
            max_length: Token length every example is padded/truncated to.
        """
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        text = str(row['clean_text'])
        label = int(row['label'])

        # The text is tokenized exactly as stored. The model input must not
        # depend on the label: a previous version prefixed
        # "TRANSFERRED_PATIENT " only to label-0 notes mentioning 'transfer',
        # which leaked the target into the input and was never applied at
        # inference time.
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }
|
|
|
|
|
def _load_labeled_annotations(annotation_file):
    """Load one annotation workbook and keep only rows with a 0/1 label.

    Returns a new DataFrame in which 'ohca_label' and a copy of it named
    'label' are ints, and 'clean_text' is coerced to str.

    Raises:
        ValueError: If any filled-in 'ohca_label' is not 0 or 1.
    """
    df = pd.read_excel(annotation_file)
    # Cells the annotator left blank are read back as NaN; skip those rows.
    # .copy() so the column assignments below do not act on a filtered view.
    df = df.dropna(subset=['ohca_label']).copy()

    # Validate before casting: a plain astype(int) would silently truncate
    # values such as 0.5 and accept out-of-range labels such as 2.
    numeric_labels = pd.to_numeric(df['ohca_label'], errors='coerce')
    invalid = ~numeric_labels.isin([0, 1])
    if invalid.any():
        examples = df.loc[invalid, 'ohca_label'].unique()[:5].tolist()
        raise ValueError(
            f"{annotation_file}: {int(invalid.sum())} row(s) have an invalid "
            f"ohca_label (expected 0 or 1), e.g. {examples}"
        )

    df['ohca_label'] = numeric_labels.astype(int)
    df['label'] = df['ohca_label']
    df['clean_text'] = df['clean_text'].astype(str)
    return df


def prepare_training_data(train_annotation_file, val_annotation_file):
    """
    Build training and validation datasets from the two annotation workbooks.

    Unlabeled rows are dropped. The training set is rebalanced by
    oversampling label-1 (OHCA) rows when they are the minority. The
    validation set is left untouched.

    Args:
        train_annotation_file: Path to the completed training annotation Excel file.
        val_annotation_file: Path to the completed validation annotation Excel file.

    Returns:
        tuple: (train_dataset, val_dataset, train_df_balanced, val_df, tokenizer)

    Raises:
        ValueError: If either file contains an 'ohca_label' other than 0 or 1.
    """
    print("Preparing training data from separate annotation files...")

    train_df = _load_labeled_annotations(train_annotation_file)
    val_df = _load_labeled_annotations(val_annotation_file)

    print("Training data summary:")
    print(f"   Training cases: {len(train_df)} (OHCA: {(train_df['label']==1).sum()}, Non-OHCA: {(train_df['label']==0).sum()})")
    print(f"   Validation cases: {len(val_df)} (OHCA: {(val_df['label']==1).sum()}, Non-OHCA: {(val_df['label']==0).sum()})")
    print(f"   Training OHCA prevalence: {(train_df['label']==1).mean():.1%}")
    print(f"   Validation OHCA prevalence: {(val_df['label']==1).mean():.1%}")

    # Oversample OHCA rows up to min(#negatives, 3 * #positives). Every
    # original positive is kept; only the extra copies are drawn with
    # replacement. A pure bootstrap would randomly omit some of the
    # (expensive) annotated positives.
    minority = train_df[train_df['label'] == 1]
    majority = train_df[train_df['label'] == 0]
    if 0 < len(minority) < len(majority):
        target_size = min(len(majority), len(minority) * 3)
        extra_minority = resample(
            minority, replace=True, n_samples=target_size - len(minority),
            random_state=RANDOM_STATE,
        )
        train_df_balanced = pd.concat([majority, minority, extra_minority])
    else:
        train_df_balanced = train_df

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_dataset = OHCATrainingDataset(train_df_balanced, tokenizer)
    val_dataset = OHCATrainingDataset(val_df, tokenizer)

    print("✅ Training data prepared:")
    print(f"   Training samples after balancing: {len(train_dataset)}")
    print(f"   Validation samples: {len(val_dataset)}")
    print(f"   OHCA cases in balanced training: {(train_df_balanced['label']==1).sum()}")
    print(f"   Non-OHCA cases in balanced training: {(train_df_balanced['label']==0).sum()}")

    return train_dataset, val_dataset, train_df_balanced, val_df, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_ohca_model(train_dataset, val_dataset, train_df, tokenizer,
                     num_epochs=3, save_path="./trained_ohca_model"):
    """
    Fine-tune ``MODEL_NAME`` as a two-class OHCA classifier and save it.

    Training uses:
    - AdamW (lr 2e-5, weight decay 0.01)
    - a linear LR schedule without warmup
    - gradient clipping at norm 1.0
    - batch size 8
    - class-weighted cross-entropy loss

    Args:
        train_dataset: Dataset yielding dicts with 'input_ids',
            'attention_mask' and 'labels' (e.g. ``OHCATrainingDataset``).
        val_dataset: Validation dataset. Not used during training (the
            threshold is chosen afterwards on the validation DataFrame);
            accepted for interface compatibility.
        train_df: DataFrame holding the 'label' column of ``train_dataset``;
            used to compute class weights.
        tokenizer: Tokenizer saved next to the model.
        num_epochs: Number of passes over ``train_dataset``.
        save_path: Directory the model and tokenizer are written to.

    Returns:
        tuple: (trained_model, tokenizer)

    Raises:
        ValueError: If ``train_df['label']`` does not contain exactly the
            classes 0 and 1.
    """
    # Validate before loading the model. With a single class,
    # compute_class_weight returns one weight, which matches neither the
    # 2-logit head nor the [0]/[1] indexing below.
    train_labels = train_df['label'].to_numpy()
    present_classes = set(np.unique(train_labels).tolist())
    if present_classes != {0, 1}:
        raise ValueError(
            f"Training labels must contain both classes 0 and 1; found {sorted(present_classes)}"
        )

    print(f"Training OHCA model for {num_epochs} epochs...")

    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=2
    ).to(DEVICE)

    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    num_training_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    # 'balanced' (inverse-frequency) weights. Classes are pinned to [0, 1]
    # so class_weights[i] lines up with logit column i.
    class_weights = compute_class_weight(
        class_weight='balanced', classes=np.array([0, 1]), y=train_labels
    )
    weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(DEVICE)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights_tensor)
    print(f"Class weights - Non-OHCA: {class_weights[0]:.2f}, OHCA: {class_weights[1]:.2f}")

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch in progress_bar:
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = loss_fn(outputs.logits, labels)
            epoch_loss += loss.item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_loss = epoch_loss / len(train_dataloader)
        print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    print("✅ Model training complete!")
    print(f"💾 Model saved to: {save_path}")

    return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_optimal_threshold(model, tokenizer, val_df, device=DEVICE):
    """
    Choose the decision threshold on the validation set only.

    Computes P(OHCA) (softmax column 1) for every validation note. Among the
    ROC-curve thresholds, it picks the one with the highest F1 score; ties go
    to the first (highest) threshold. Youden's J threshold is reported
    alongside for reference.

    Args:
        model: Trained 2-class sequence-classification model.
        tokenizer: Tokenizer matching ``model``.
        val_df: Validation DataFrame with 'clean_text' and integer 0/1 'label'.
        device: Device used for inference.

    Returns:
        tuple: (optimal_threshold, metrics). ``optimal_threshold`` is a Python
        float. ``metrics`` contains plain Python numbers only, so it can be
        passed to ``json.dump``.

    Raises:
        ValueError: If ``val_df`` does not contain both classes. A threshold
            cannot be chosen meaningfully from a single class.
    """
    print("🎯 Finding optimal threshold on validation set...")

    true_labels = val_df['label'].to_numpy()
    if len(np.unique(true_labels)) < 2:
        raise ValueError(
            "Validation set must contain both OHCA (1) and non-OHCA (0) cases to choose a threshold"
        )

    model.eval()
    probabilities = []
    with torch.no_grad():
        for text in tqdm(val_df['clean_text'], desc="Computing probabilities"):
            inputs = tokenizer(
                str(text), truncation=True, padding=True,
                max_length=512, return_tensors='pt'
            ).to(device)
            logits = model(**inputs).logits
            probabilities.append(F.softmax(logits, dim=-1)[0, 1].item())
    predictions = np.asarray(probabilities, dtype=float)

    fpr, tpr, thresholds = roc_curve(true_labels, predictions)

    # Youden's J = TPR - FPR; kept for reference only.
    optimal_threshold_youden = float(thresholds[np.argmax(tpr - fpr)])

    # F1 = 2TP / (2TP + FP + FN) at every candidate threshold. This form is
    # equivalent to 2PR/(P+R) but needs only one division, so equal F1
    # values compare exactly equal.
    positives = true_labels == 1
    negatives = true_labels == 0
    f1_scores = np.zeros(len(thresholds))
    for i, threshold in enumerate(thresholds):
        predicted_pos = predictions >= threshold
        tp_i = int(np.sum(predicted_pos & positives))
        fp_i = int(np.sum(predicted_pos & negatives))
        fn_i = int(np.sum(~predicted_pos & positives))
        denominator = 2 * tp_i + fp_i + fn_i
        f1_scores[i] = 2 * tp_i / denominator if denominator else 0.0

    best_idx = int(np.argmax(f1_scores))
    optimal_threshold = float(thresholds[best_idx])

    pred_binary = (predictions >= optimal_threshold).astype(int)
    # labels=[0, 1] guarantees a 2x2 matrix even if a class is absent, so
    # the four-way unpack cannot fail. int() strips numpy types for JSON.
    tn, fp, fn, tp = (int(v) for v in
                      confusion_matrix(true_labels, pred_binary, labels=[0, 1]).ravel())

    def _ratio(numerator, denominator):
        # Convention used throughout the pipeline: 0 for an empty denominator.
        return numerator / denominator if denominator else 0.0

    metrics = {
        'threshold': optimal_threshold,
        'threshold_youden': optimal_threshold_youden,
        'accuracy': _ratio(tp + tn, tp + tn + fp + fn),
        'sensitivity': _ratio(tp, tp + fn),
        'specificity': _ratio(tn, tn + fp),
        'precision': _ratio(tp, tp + fp),
        'f1_score': float(f1_scores[best_idx]),
        'npv': _ratio(tn, tn + fn),
        'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn,
    }

    print(f"✅ Optimal threshold found: {optimal_threshold:.3f}")
    print(f"   F1-Score at optimal threshold: {metrics['f1_score']:.3f}")
    print(f"   Sensitivity: {metrics['sensitivity']:.3f}")
    print(f"   Specificity: {metrics['specificity']:.3f}")

    return optimal_threshold, metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_on_test_set(model, tokenizer, test_df, optimal_threshold, device=DEVICE):
    """
    Evaluate the model on the held-out test set at a fixed threshold.

    The threshold should come from the validation set. Nothing is tuned here,
    so the metrics are not optimistically biased by the test data.

    Args:
        model: Trained 2-class sequence-classification model (column 1 of the
            softmax is P(OHCA)).
        tokenizer: Tokenizer matching ``model``.
        test_df: DataFrame with 'clean_text' and integer 0/1 'label' columns.
        optimal_threshold: Decision threshold; probabilities
            ``>= optimal_threshold`` are predicted as OHCA.
        device: Device used for inference.

    Returns:
        dict: Test metrics as plain Python numbers (JSON-serializable).
        'test_auc' is NaN when AUC cannot be computed (e.g. the test labels
        contain a single class).
    """
    print(f"Final evaluation on test set using threshold {optimal_threshold:.3f}...")

    model.eval()
    true_labels = test_df['label'].to_numpy()

    probabilities = []
    with torch.no_grad():
        for text in tqdm(test_df['clean_text'], desc="Test set inference"):
            inputs = tokenizer(
                str(text), truncation=True, padding=True,
                max_length=512, return_tensors='pt'
            ).to(device)
            logits = model(**inputs).logits
            probabilities.append(F.softmax(logits, dim=-1)[0, 1].item())
    predictions = np.asarray(probabilities, dtype=float)
    pred_binary = (predictions >= optimal_threshold).astype(int)

    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent from
    # both labels and predictions; otherwise the four-way unpack fails.
    tn, fp, fn, tp = (int(v) for v in
                      confusion_matrix(true_labels, pred_binary, labels=[0, 1]).ravel())

    try:
        auc = float(roc_auc_score(true_labels, predictions))
    except ValueError:
        # Undefined AUC (e.g. only one class present). Report NaN instead of
        # a fabricated 0.5, which would look like a real chance-level result.
        auc = float('nan')
        print("⚠️ Warning: Could not calculate AUC on test set")

    def _ratio(numerator, denominator):
        # Convention used throughout the pipeline: 0 for an empty denominator.
        return numerator / denominator if denominator else 0.0

    test_metrics = {
        'test_accuracy': _ratio(tp + tn, tp + tn + fp + fn),
        'test_sensitivity': _ratio(tp, tp + fn),
        'test_specificity': _ratio(tn, tn + fp),
        'test_precision': _ratio(tp, tp + fp),
        'test_f1_score': _ratio(2 * tp, 2 * tp + fp + fn),
        'test_npv': _ratio(tn, tn + fn),
        'test_auc': auc,
        'n_test_samples': len(test_df),
        'test_ohca_prevalence': _ratio(int(np.sum(true_labels == 1)), len(true_labels)),
        'test_tp': tp, 'test_tn': tn, 'test_fp': fp, 'test_fn': fn,
    }

    print("✅ Test set evaluation complete:")
    print(f"   Accuracy: {test_metrics['test_accuracy']:.3f}")
    print(f"   Sensitivity: {test_metrics['test_sensitivity']:.3f}")
    print(f"   Specificity: {test_metrics['test_specificity']:.3f}")
    print(f"   F1-Score: {test_metrics['test_f1_score']:.3f}")
    print(f"   AUC: {test_metrics['test_auc']:.3f}")

    return test_metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _json_default(obj):
    """``json.dump`` fallback that converts NumPy values to plain Python types.

    Metric dicts built from NumPy / scikit-learn results can contain values
    such as ``np.int64`` that the standard JSON encoder rejects.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_model_with_metadata(model, tokenizer, optimal_threshold,
                             val_metrics, test_metrics, model_save_path):
    """
    Save the model and tokenizer together with the threshold and metrics.

    Writes ``model_metadata.json`` next to the model files. It records the
    validation-chosen threshold, so inference can reuse exactly the same
    decision rule.

    Args:
        model: Model exposing ``save_pretrained``.
        tokenizer: Tokenizer exposing ``save_pretrained``.
        optimal_threshold: Decision threshold chosen on the validation set.
        val_metrics: Dict of validation metrics (NumPy scalars allowed).
        test_metrics: Dict of test metrics or a placeholder message (NumPy
            scalars allowed).
        model_save_path: Output directory (created if missing).
    """
    print(f"💾 Saving model with metadata to {model_save_path}...")

    os.makedirs(model_save_path, exist_ok=True)
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(model_save_path)

    metadata = {
        'optimal_threshold': float(optimal_threshold),
        'validation_metrics': val_metrics,
        'test_metrics': test_metrics,
        'model_version': '3.0',
        'model_name': MODEL_NAME,
        'training_date': pd.Timestamp.now().isoformat(),
        'methodology_improvements': [
            'Patient-level data splits to prevent leakage',
            'Separate train/validation/test sets',
            'Optimal threshold found on validation set only',
            'Final performance evaluated on independent test set',
            'Larger annotation samples for better generalization',
        ],
    }

    metadata_path = os.path.join(model_save_path, 'model_metadata.json')
    with open(metadata_path, 'w', encoding='utf-8') as f:
        # default=_json_default: metric values such as confusion-matrix
        # counts may be np.int64, which json cannot serialize on its own.
        json.dump(metadata, f, indent=2, default=_json_default)

    print("✅ Model and metadata saved successfully!")
    print(f"   Optimal threshold: {optimal_threshold:.3f}")
    print("   Model version: 3.0 (Improved Methodology)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def complete_improved_training_pipeline(data_path, annotation_dir="./annotation_v3",
                                        train_sample_size=800, val_sample_size=200):
    """
    First half of the v3.0 pipeline: load notes and create annotation files.

    Loads the discharge-notes CSV and checks the required columns. It then
    delegates to ``create_training_sample`` to build the patient-level splits,
    the two annotation workbooks, the guidelines and the held-out test CSV.
    Finally it prints the manual next steps.

    Args:
        data_path: Path to the discharge notes CSV.
        annotation_dir: Directory for the annotation interface.
        train_sample_size: Number of training samples to create.
        val_sample_size: Number of validation samples to create.

    Returns:
        dict: Paths of the created files, sample sizes, and the name of the
        next function to call ('complete_annotation_and_train_v3').

    Raises:
        ValueError: If the CSV lacks 'hadm_id' or 'clean_text'.
    """
    print("OHCA IMPROVED TRAINING PIPELINE v3.0 STARTING...")
    print("=" * 70)

    print("Step 1: Loading discharge notes...")
    df = pd.read_csv(data_path)
    required_cols = ['hadm_id', 'clean_text']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")
    print(f"Loaded {len(df):,} discharge notes")

    print("\nStep 2: Creating patient-level splits and annotation samples...")
    result = create_training_sample(
        df, output_dir=annotation_dir,
        train_sample_size=train_sample_size,
        val_sample_size=val_sample_size,
    )

    print("\n" + "=" * 70)
    print("MANUAL ANNOTATION REQUIRED - IMPROVED METHODOLOGY")
    print("=" * 70)
    print("KEY IMPROVEMENTS IN v3.0:")
    print("✅ Patient-level splits prevent data leakage")
    print("✅ Separate train/validation files for proper methodology")
    print("✅ Larger sample sizes for better performance")
    print("✅ Independent test set for unbiased evaluation")
    print()
    print("NEXT STEPS:")
    print(f"1. Read guidelines: {result['guidelines_file']}")
    print(f"2. Annotate TRAINING file: {result['train_file']}")
    print(f"3. Annotate VALIDATION file: {result['val_file']}")
    print("4. Run: complete_annotation_and_train_v3()")
    print("5. 🎯 Model will automatically find optimal threshold")
    print("6. Final evaluation on independent test set")
    print("=" * 70)

    return {
        'train_annotation_file': result['train_file'],
        'val_annotation_file': result['val_file'],
        'test_file': result['test_file'],
        'guidelines_file': result['guidelines_file'],
        'train_sample_size': result['train_sample_size'],
        'val_sample_size': result['val_sample_size'],
        'test_size': result['test_size'],
        'next_step': 'complete_annotation_and_train_v3',
    }
|
|
|
|
|
def complete_annotation_and_train_v3(train_annotation_file, val_annotation_file,
                                     test_file, model_save_path="./trained_ohca_model_v3",
                                     num_epochs=3):
    """
    Second half of the v3.0 pipeline, run after annotation is complete.

    Steps:
    1. Load and validate the training/validation annotations.
    2. Fine-tune the model.
    3. Choose the decision threshold on the validation set.
    4. Evaluate on the held-out test set when it is annotated.
    5. Save the model with its threshold and metrics.

    The test set is evaluated with ``evaluate_on_test_set`` only if
    ``test_file`` has an 'ohca_label' column with at least one filled-in
    label (0 or 1). This is the case, e.g., after annotating a subset with the
    same scheme as the training files. Otherwise a placeholder dict is stored
    as ``test_metrics``, as in earlier versions.

    Args:
        train_annotation_file: Path to the completed training annotation Excel file.
        val_annotation_file: Path to the completed validation annotation Excel file.
        test_file: Path to the test set CSV file.
        model_save_path: Where to save the trained model.
        num_epochs: Number of training epochs.

    Returns:
        dict with 'model_path', 'optimal_threshold', 'validation_metrics',
        'test_metrics', 'model', 'tokenizer' and 'improvements_implemented'.

    Raises:
        ValueError: If an annotated test label is not 0 or 1.
    """
    print("CONTINUING IMPROVED TRAINING PIPELINE v3.0...")
    print("=" * 70)

    print("Step 3: Loading annotations and preparing datasets...")
    train_dataset, val_dataset, train_df, val_df, tokenizer = prepare_training_data(
        train_annotation_file, val_annotation_file
    )

    print("\nStep 4: Training model...")
    model, tokenizer = train_ohca_model(
        train_dataset, val_dataset, train_df, tokenizer,
        num_epochs=num_epochs, save_path=model_save_path
    )

    print("\n🎯 Step 5: Finding optimal threshold on validation set...")
    optimal_threshold, val_metrics = find_optimal_threshold(
        model, tokenizer, val_df, device=DEVICE
    )

    print("\nStep 6: Final evaluation on independent test set...")
    test_df = pd.read_csv(test_file)

    # Use only rows a human has actually labeled; the raw test export has no
    # 'ohca_label' column at all.
    labeled_test_df = None
    if 'ohca_label' in test_df.columns:
        labeled_test_df = test_df.dropna(subset=['ohca_label']).copy()
        test_labels = pd.to_numeric(labeled_test_df['ohca_label'], errors='coerce')
        if not test_labels.isin([0, 1]).all():
            raise ValueError(f"{test_file}: ohca_label values must be 0 or 1")
        labeled_test_df['label'] = test_labels.astype(int)
        labeled_test_df['clean_text'] = labeled_test_df['clean_text'].astype(str)

    if labeled_test_df is not None and not labeled_test_df.empty:
        print(f"   Evaluating on {len(labeled_test_df)} annotated test notes")
        test_metrics = evaluate_on_test_set(
            model, tokenizer, labeled_test_df, optimal_threshold, device=DEVICE
        )
    else:
        print("⚠️ Note: Test set evaluation requires manual annotation for true unbiased results")
        print("   For demonstration, using test set without evaluation")
        test_metrics = {
            'message': 'Test set evaluation requires manual annotation of test samples',
            'test_set_size': len(test_df),
            'recommendation': 'Manually annotate 100-200 test samples for final evaluation',
        }

    print("\n💾 Step 7: Saving model with optimal threshold and metadata...")
    save_model_with_metadata(
        model, tokenizer, optimal_threshold,
        val_metrics, test_metrics, model_save_path
    )

    print("\n✅ IMPROVED TRAINING PIPELINE v3.0 COMPLETE!")
    print("=" * 70)
    print("KEY IMPROVEMENTS IMPLEMENTED:")
    print("✅ Patient-level splits prevent data leakage")
    print("✅ Proper train/validation/test methodology")
    print("✅ Optimal threshold found and saved with model")
    print("✅ Larger training samples for better generalization")
    print("✅ Unbiased evaluation framework established")
    print()
    print(f"Model saved to: {model_save_path}")
    print(f"🎯 Optimal threshold: {optimal_threshold:.3f}")
    print(f"Validation F1-Score: {val_metrics['f1_score']:.3f}")
    if 'test_f1_score' in test_metrics:
        print(f"Test F1-Score: {test_metrics['test_f1_score']:.3f}")
    print("=" * 70)

    return {
        'model_path': model_save_path,
        'optimal_threshold': optimal_threshold,
        'validation_metrics': val_metrics,
        'test_metrics': test_metrics,
        'model': model,
        'tokenizer': tokenizer,
        'improvements_implemented': True,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_training_sample_legacy(df, output_dir="./annotation_interface"):
    """Backward-compatible alias for ``create_training_sample``.

    Always uses 800 training and 200 validation samples and returns the same
    dictionary as ``create_training_sample``.
    """
    print("⚠️ Using legacy function. Redirecting to improved methodology...")
    return create_training_sample(df, output_dir, train_sample_size=800, val_sample_size=200)
|
|
|
|
|
def complete_training_pipeline(data_path, annotation_dir="./annotation_interface",
                               model_save_path="./trained_ohca_model"):
    """Backward-compatible alias for ``complete_improved_training_pipeline``.

    ``model_save_path`` is accepted for signature compatibility only and is
    ignored. The model is trained and saved later, by
    ``complete_annotation_and_train_v3``.
    """
    print("⚠️ Using legacy function. Redirecting to improved methodology...")
    return complete_improved_training_pipeline(data_path, annotation_dir)
|
|
|
|
|
def complete_annotation_and_train(annotation_file, model_save_path="./trained_ohca_model",
                                  num_epochs=3):
    """Deprecated single-file entry point; performs no training.

    All arguments are ignored. It prints an upgrade notice pointing to
    ``complete_annotation_and_train_v3`` and returns
    ``{'message': 'Legacy method - please upgrade to v3.0 methodology'}``.
    """
    print("⚠️ WARNING: Using legacy single-file annotation method")
    print("   For improved methodology, use complete_annotation_and_train_v3()")
    print("   This addresses data scientist feedback about bias and data leakage")
    return {'message': 'Legacy method - please upgrade to v3.0 methodology'}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the module directly only prints an overview; the pipeline is
    # driven by calling the functions below from user code.
    print("OHCA Training Pipeline v3.0 - Improved Methodology")
    print("=" * 55)
    print("🎯 Addresses data scientist feedback:")
    print("✅ Patient-level splits prevent data leakage")
    print("✅ Proper train/validation/test methodology")
    print("✅ Optimal threshold finding and usage")
    print("✅ Larger annotation samples")
    print("✅ Unbiased evaluation framework")
    print()
    print("Main functions:")
    print("• complete_improved_training_pipeline() - Create improved annotation samples")
    print("• complete_annotation_and_train_v3() - Train with proper methodology")
    print("• find_optimal_threshold() - Find optimal decision threshold")
    print("• evaluate_on_test_set() - Unbiased final evaluation")
    print()
    print("See examples/ folder for detailed usage examples.")
    print("=" * 55)
|
|
|