File size: 7,881 Bytes
6e347aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
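prep_data.py — Run once to generate the data files for the Patient Portal app.
Saves 3 parquets (patient_stats, patient_probability, step_features) plus
sampled_stay_ids.json and state_cols.json to the streamlit/ root (the parent of
this script's directory).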

    cd streamlit/app
    python prep_data.py
"""
import json

import numpy as np
import pandas as pd
from datasets import load_dataset
from pathlib import Path

# Output directory: the streamlit/ root, i.e. the parent of this script's directory.
# resolve() first makes the path absolute. On Python < 3.9 a script launched as
# `python prep_data.py` has a relative __file__, and Path('prep_data.py').parent.parent
# is Path('.') (the current directory), not the streamlit/ root.
OUT_DIR = Path(__file__).resolve().parent.parent  # streamlit/ root

# HuggingFace dataset repos read below: cohort/weight/height/triage tables come from
# PATIENT_REPO, the per-step prediction table ('sbs_preds') from DATA_REPO.
PATIENT_REPO = "ADS599-Capstone/interim_data"
DATA_REPO = "ADS599-Capstone/sbs_predictions"

# Download every source table from the HuggingFace Hub and convert it to pandas.
# (The console message previously contained a mis-encoded dash; plain ASCII avoids
# any console-encoding issues.)
print("Loading HuggingFace datasets - this may take a few minutes...")

# Patient-level tables (PATIENT_REPO).
patient_df = load_dataset(PATIENT_REPO, name='cohort_full', split='cohort_base').to_pandas()
# Weight / height measurements; later averaged per subject_id from 'result_value'.
weight_df = load_dataset(PATIENT_REPO, name='weight', split='weight_full').to_pandas()
height_df = load_dataset(PATIENT_REPO, name='height', split='height_full').to_pandas()
# Vitals table; later filtered to rows whose 'source' is 'triage'.
tv_df = load_dataset(PATIENT_REPO, name='triage_vitals', split='triage_vitals_full').to_pandas()
# Per-step modelling rows (ed_stay_id, step_idx, p_icu, ...) from DATA_REPO.
data_df = load_dataset(DATA_REPO, name='sbs_preds', split='sbs_preds_full').to_pandas()

print(f"  cohort: {len(patient_df):,} rows") # type: ignore
print(f"  data_df: {data_df.shape}") # type: ignore

# ── Sample 5k stays (1k ICU, 4k discharge) ───────────────────────────────
# Seeded generator: the same data_df yields the same sample on every run.
rng = np.random.default_rng(10)

# One row per stay (first occurrence), carrying that row's terminal_event label.
stay_meta = data_df.drop_duplicates('ed_stay_id').loc[:, ['ed_stay_id', 'terminal_event']] # type: ignore
is_icu = stay_meta['terminal_event'] == 'transfer_icu'
is_dis = stay_meta['terminal_event'] == 'discharge'
icu_stays = stay_meta.loc[is_icu, 'ed_stay_id'].tolist()
dis_stays = stay_meta.loc[is_dis, 'ed_stay_id'].tolist()

# Cap each class at its target size; take every stay when fewer are available.
n_icu, n_dis = min(1000, len(icu_stays)), min(4000, len(dis_stays))

# Both draws consume the same generator, so the ICU draw must stay first for
# the sample to be reproducible.
icu_pick = rng.choice(icu_stays, n_icu, replace=False).tolist()
dis_pick = rng.choice(dis_stays, n_dis, replace=False).tolist()
sampled_ids = set(icu_pick).union(dis_pick)

# Keep every step row that belongs to a sampled stay.
keep = data_df['ed_stay_id'].isin(sampled_ids) # type: ignore
data_df = data_df.loc[keep].copy() # type: ignore
print(f"  Sampled {len(sampled_ids):,} stays ({n_icu:,} ICU, {n_dis:,} discharge)")

# Persist the sampled ids (sorted) alongside the parquets.
with open(OUT_DIR / 'sampled_stay_ids.json', 'w') as f:
    json.dump(sorted(sampled_ids), f)

# ── Derive and save state_cols ────────────────────────────────────────────
# Feature columns are picked out of data_df by naming convention. The group order
# below fixes the column order written to state_cols.json.


def _cols_where(pred):
    """Return data_df's column names, in frame order, for which ``pred(name)`` is true."""
    return [name for name in data_df.columns if pred(name)] # type: ignore


# Medication flags: the inclusive label slice 'ACE Inhibitor'..'Other', or nothing
# when the 'ACE Inhibitor' column is absent.
if 'ACE Inhibitor' in data_df.columns: # type: ignore
    dispensed_meds = list(data_df.loc[:, 'ACE Inhibitor':'Other'].columns) # type: ignore
else:
    dispensed_meds = []

recon        = _cols_where(lambda c: c.startswith('recon_'))
vitals       = _cols_where(lambda c: c.startswith('current_'))
vital_change = _cols_where(lambda c: c.endswith(('_rolling1h', '_delta', '_rate_per_min')))
# Lab panel columns contain a '-' (e.g. 'Chemistry-Blood_Normal'); micro-style
# columns (e.g. 'BLOOD CULTURE_Pending') do not, and ECG/radiology are excluded.
lab_ohe      = _cols_where(lambda c: c.endswith(('_Normal', '_Pending', '_Abnormal')) and '-' in c)
micro_ohe    = _cols_where(
    lambda c: c.endswith(('_Pending', '_Positive', '_Negative', '_Other'))
    and '-' not in c
    and not c.startswith(('ecg_status', 'rad_status'))
)
ecg_ohe      = _cols_where(lambda c: c.startswith('ecg_status'))
rad_ohe      = _cols_where(lambda c: c.startswith('rad_status'))
arrival      = _cols_where(lambda c: c.startswith('arrival_'))
missing      = _cols_where(lambda c: c.endswith('_missing'))

# Fixed per-stay fields first, then the pattern-matched groups; only columns that
# actually exist in data_df are kept.
# NOTE(review): a name matching more than one pattern (e.g. 'recon_..._missing')
# would be listed more than once — confirm that cannot happen.
feature_groups = [
    ['gender', 'anchor_age', 'acuity', 'height', 'weight', 'time_since_last_min'],
    dispensed_meds, recon, vitals, vital_change,
    lab_ohe, micro_ohe, ecg_ohe, rad_ohe, arrival, missing,
]
state_cols = [c for group in feature_groups for c in group if c in data_df.columns] # type: ignore

with open(OUT_DIR / 'state_cols.json', 'w') as f:
    json.dump(state_cols, f)
print(f"  state_cols.json: {len(state_cols)} features")

# Stays that survived sampling; patient_stats is restricted to these.
eds = set(data_df['ed_stay_id'].unique()) # type: ignore

# ── Height / weight — single value per patient ────────────────────────────
# Average all recorded measurements per subject so each patient contributes exactly
# one weight row and one height row to the joins below.
avg_weight = weight_df.groupby('subject_id')['result_value'].mean().rename('weight').reset_index() # type: ignore
avg_height = height_df.groupby('subject_id')['result_value'].mean().rename('height').reset_index() # type: ignore

# ── patient_stats.parquet ─────────────────────────────────────────────────
# One row per ED stay (filtered to modeling data). Used by Tab 1.
cohort_cols = [
    'ed_stay_id', 'subject_id', 'ed_intime',
    'arrival_transport', 'gender', 'anchor_age', 'language',
]

# Triage-time vital columns ('current*'), excluding current_mean_arterial_pressure.
triage_vital_cols = [
    c for c in tv_df.columns # type: ignore
    if c.startswith('current') and c != 'current_mean_arterial_pressure'
]
triage_cols = triage_vital_cols + ['acuity', 'chiefcomplaint', 'ed_stay_id']
# Keep at most one triage row per stay (the first). A left merge on a non-unique
# key would duplicate stays and break the one-row-per-stay contract above.
triage_baseline = (
    tv_df[tv_df['source'] == 'triage'][triage_cols] # type: ignore
    .drop_duplicates('ed_stay_id')
)

# Restrict the cohort to sampled stays *before* joining so the merges only touch
# rows that are kept. Left merges preserve the left frame's row order, so the result
# matches filtering after the joins.
cohort = patient_df.loc[patient_df['ed_stay_id'].isin(eds), cohort_cols] # type: ignore
patient_stats = (
    cohort
    .merge(avg_weight, on='subject_id', how='left')
    .merge(avg_height, on='subject_id', how='left')
    .merge(triage_baseline, on='ed_stay_id', how='left')
)
patient_stats['ed_intime'] = pd.to_datetime(patient_stats['ed_intime'])

patient_stats.to_parquet(OUT_DIR / 'patient_stats.parquet', index=False)
print(f"  patient_stats.parquet: {len(patient_stats):,} rows, {len(patient_stats.columns)} cols")

# ── patient_probability.parquet ───────────────────────────────────────────
# Model output per step: one row per (ed_stay_id, step_idx). Used by Tab 2 trajectory chart.
# Selecting a missing column raises KeyError, so the required columns fail loudly.
prob_cols = [
    'ed_stay_id', 'in_ed', 'in_ward', 'terminal_event',
    'terminal_code', 'step_idx', 'time', 'p_icu',
]
patient_probability = data_df.loc[:, prob_cols].copy() # type: ignore

prob_path = OUT_DIR / 'patient_probability.parquet'
patient_probability.to_parquet(prob_path, index=False)
n_prob_rows = len(patient_probability)
print(f"  patient_probability.parquet: {n_prob_rows:,} rows")

# ── step_features.parquet ─────────────────────────────────────────────────
# One row per (ed_stay_id, step_idx). Used by Tab 2 "what changed" and Tab 3 waterfall.
# Unlike patient_probability, absent columns are skipped with a warning, not an error.
step_feature_cols = [
    # Keys and location flags
    'ed_stay_id', 'step_idx', 'time', 'in_ed', 'in_ward', 'terminal_code',
    # Vitals
    'current_temperature', 'current_heartrate', 'current_resprate', 'current_o2sat',
    'current_sbp', 'current_dbp', 'current_pain', 'current_map',
    # Labs — most clinically common panels, one-hot by result status
    'Chemistry-Blood_Pending', 'Chemistry-Blood_Normal', 'Chemistry-Blood_Abnormal',
    'Hematology-Blood_Pending', 'Hematology-Blood_Normal', 'Hematology-Blood_Abnormal',
    'Blood Gas-Blood_Pending', 'Blood Gas-Blood_Normal', 'Blood Gas-Blood_Abnormal',
    'BLOOD CULTURE_Pending', 'BLOOD CULTURE_Negative', 'BLOOD CULTURE_Positive',
    # ECG and radiology status
    'ecg_status_Normal', 'ecg_status_Moderate', 'ecg_status_Acute',
    'rad_status_Normal', 'rad_status_Moderate', 'rad_status_Acute',
    # Medications dispensed during the stay
    'Antibiotic', 'IV Fluid', 'Analgesic - Opioid/NSAID', 'Analgesic - Acetaminophen',
    'Antiemetic', 'Anticoagulant', 'Corticosteroid',
    'Benzodiazepine - Sedative/Anxiolytic', 'Beta Blocker', 'Diuretic', 'Bronchodilator',
    # Model output
    'p_icu',
]

# Split the wanted columns into present / absent in one pass, preserving order.
available_cols, missing_cols = [], []
for col in step_feature_cols:
    bucket = available_cols if col in data_df.columns else missing_cols # type: ignore
    bucket.append(col)

if missing_cols:
    print(f"  WARNING: columns not found in data_df: {missing_cols}")

step_features = data_df.loc[:, available_cols].copy() # type: ignore
step_features.to_parquet(OUT_DIR / 'step_features.parquet', index=False)
print(f"  step_features.parquet: {len(step_features):,} rows, {len(available_cols)} cols")

print("\nDone. Parquets saved to:", OUT_DIR)