IamGrooooot's picture
Model E: Unsupervised PCA + clustering risk stratification
53a6def
"""
Validation process
"""
import sys
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mlflow
from matplotlib import rcParams
from tableone import TableOne
# Figure defaults: wide canvas, no top/right axis spines
rcParams.update({
    'figure.figsize': (20, 5),
    'axes.spines.top': False,
    'axes.spines.right': False,
})
def plot_cluster_size(df, data_type):
    """
    Produce a bar plot of cluster size and log it to MLflow
    --------
    :param df: dataframe to plot; must contain a 'cluster' column
    :param data_type: type of data - train, test, val, rec, sup
    """
    title = "Patient Cohorts"
    fig, ax = plt.subplots()
    try:
        # Number of patients per cluster
        df.groupby('cluster').size().plot(ax=ax, kind='barh')
        ax.set_title(title)
        ax.set_xlabel("Number of Patients", size=20)
        ax.set_ylabel("Cluster")
        plt.tight_layout()
        mlflow.log_figure(fig, 'fig/' + title.replace(' ', '_') + '_' + data_type + '.png')
    finally:
        # pyplot keeps a figure registered until it is closed explicitly
        plt.close(fig)
def plot_feature_hist(df, col, data_type):
    """
    Produce a histogram plot for a chosen feature and log it to MLflow
    --------
    :param df: dataframe to plot; must contain 'cluster' and `col` columns
    :param col: feature column to plot
    :param data_type: type of data - train, test, val, rec, sup
    """
    title = col + ' Histogram'
    fig, ax = plt.subplots()
    try:
        # One semi-transparent histogram per cluster, overlaid on the same axes
        df.groupby('cluster')[col].plot(ax=ax, kind='hist', alpha=0.5)
        ax.set_xlabel(col)
        ax.set_title(title, size=20)
        ax.legend()
        plt.tight_layout()
        mlflow.log_figure(fig, 'fig/' + title.replace(' ', '_') + '_' + data_type + '.png')
    finally:
        # pyplot keeps a figure registered until it is closed explicitly
        plt.close(fig)
def plot_feature_bar(data, col, typ, data_type):
    """
    Produce a bar plot for a chosen feature and log it to MLflow
    --------
    :param data: dataframe to plot; must contain 'cluster' and `col` columns
    :param col: feature column to plot
    :param typ: 'count' plots the number of patients per feature value;
        any other value (e.g. 'percentage') plots the within-cluster percentage
    :param data_type: type of data - train, test, val, rec, sup
    """
    if typ == 'count':
        to_plot = data.groupby(['cluster']).apply(
            lambda x: x.groupby(col).size())
        x_label = "Number"
    else:
        to_plot = data.groupby(['cluster']).apply(
            lambda x: 100 * x.groupby(col).size() / len(x))
        x_label = "Percentage"

    title = "Patient Cohorts"
    fig, ax = plt.subplots()
    try:
        to_plot.plot(ax=ax, kind='barh')
        ax.set_title(title, size=20)
        ax.set_xlabel(x_label + " of patients")
        ax.set_ylabel("Cluster")
        plt.tight_layout()
        mlflow.log_figure(fig, 'fig/' + '_'.join((title.replace(' ', '_'), col, data_type + '.png')))
    finally:
        # pyplot keeps a figure registered until it is closed explicitly
        plt.close(fig)
def plot_cluster_bar(data, typ, data_type):
    """
    Produce a per-cluster percentage bar plot and log it to MLflow
    --------
    :param data: data to plot
    :param typ: label for the plot; used as the title and in the logged file name
    :param data_type: type of data - train, test, val, rec, sup
    """
    fig, ax = plt.subplots()
    try:
        data.plot(ax=ax, kind='bar')
        ax.set_title(typ, size=20)
        ax.set_xlabel("Cluster")
        ax.set_ylabel("Percentage")
        # Fixed 0-100 axis so plots are comparable across runs
        ax.set_ylim(0, 100)
        plt.tight_layout()
        mlflow.log_figure(fig, 'fig/' + typ + '_' + data_type + '.png')
    finally:
        # pyplot keeps a figure registered until it is closed explicitly
        plt.close(fig)
def plot_events(df, data_type):
    """
    Plot, per cluster, the percentage of rows flagged 1 in each metric column
    --------
    :param df: metric table with 'SafeHavenID', 'cluster' and event columns
    :param data_type: type of data - train, test, val, rec, sup
    """
    def _fraction_flagged(column):
        # Share of entries in the column that are equal to 1
        return len(column[column == 1]) / len(column)

    def _cluster_percentages(group):
        # Percentage of flagged entries for every event column in the cluster
        return 100 * group.apply(_fraction_flagged)

    by_cluster = df.drop('SafeHavenID', axis=1).set_index('cluster')
    events = by_cluster.groupby('cluster').apply(_cluster_percentages)
    plot_cluster_bar(events, 'events', data_type)
def process_deceased_metrics(col):
    """
    Process deceased column for plotting
    -------
    :param col: column of time-to-death bands as strings, where '12+' marks
        patients with no death recorded within 12 months
    :return: single-row dataframe with 'alive' and 'deceased' percentages
    """
    # Test against the '12+' sentinel instead of using `<`: the bands are
    # strings, so an ordering comparison is lexicographic and bands such as
    # '3' and '6' sort after '12+' and would be counted as alive.
    # Missing values are not counted as deceased.
    died = (col != '12+') & col.notna()
    n_deceased = 100 * died.sum() / len(col)
    res = pd.DataFrame({'alive': [100 - n_deceased], 'deceased': [n_deceased]})
    return res
def plot_deceased(df, data_type):
    """
    Plot the alive/deceased percentage split per cluster
    --------
    :param df: metric table with 'cluster' and 'time_to_death' columns
    :param data_type: type of data - train, test, val, rec, sup
    """
    per_cluster = df.groupby('cluster')['time_to_death'].apply(
        process_deceased_metrics)
    # Each group yields a one-row frame, so the result carries an extra
    # unnamed index level; flatten it and keep only the cluster as index
    flattened = per_cluster.reset_index()
    flattened = flattened.drop('level_1', axis=1)
    survival = flattened.set_index('cluster')
    plot_cluster_bar(survival, 'survival', data_type)
def plot_therapies(df_year, results, data_type):
    """
    Plot the percentage of patients on each inhaler therapy per cluster
    --------
    :param df_year: unscaled data for current year
    :param results: cluster results and safehaven id
    :param data_type: type of data - train, test, val, rec, sup
    """
    inhaler_cols = ['single_inhaler', 'double_inhaler', 'triple_inhaler']

    def _pct_positive(column):
        # Percentage of entries in the column that are greater than zero
        return 100 * (column[column > 0].count()) / len(column)

    def _pct_without_inhaler(group):
        # Percentage of rows where every inhaler column is zero
        none_prescribed = (group[inhaler_cols] == 0).all(axis=1)
        return 100 * len(group[none_prescribed]) / len(group)

    # Attach the cluster label to each patient's inhaler data
    therapies = df_year[['SafeHavenID'] + inhaler_cols]
    res_therapies = therapies.merge(results, on='SafeHavenID', how='inner')

    # Percentage per cluster for each inhaler type
    inhals = res_therapies[['cluster'] + inhaler_cols].set_index('cluster')
    in_res = inhals.groupby('cluster').apply(
        lambda group: group.apply(_pct_positive))

    # Percentage per cluster with no inhaler prescription
    no_in = res_therapies.groupby('cluster').apply(_pct_without_inhaler).values

    # Shorten column names for plotting, e.g. 'single_inhaler' -> 'single'
    in_res.columns = [name.split('_')[0] for name in in_res.columns]

    # Add those with no inhaler
    in_res['no_inhaler'] = no_in
    plot_cluster_bar(in_res, 'therapies', data_type)
def main():
    """
    Generate validation tables and plots for one clustering run and log them to MLflow
    --------
    Positional command line arguments:
    1. data_type: type of data - train, test, val, rec, sup
    2. run_name: prefix of the saved column and cluster files
    3. run_id: id of the existing MLflow run to log into
    """
    # Get datatype and run details from cmd line
    if len(sys.argv) < 4:
        sys.exit('Usage: python <script> <data_type> <run_name> <run_id>')
    data_type, run_name, run_id = sys.argv[1:4]

    # Load in config items
    with open('../../../config.json') as json_config_file:
        config = json.load(json_config_file)
    data_path = config['model_data_path']

    # Set MLFlow parameters
    model_type = 'hierarchical'
    experiment_name = 'Model E - Date Specific: ' + model_type
    mlflow.set_tracking_uri('http://127.0.0.1:5000/')
    mlflow.set_experiment(experiment_name)

    # The context manager ends the run even if a step below raises
    with mlflow.start_run(run_id=run_id):
        # Read in unscaled data, results and column names used to train model
        # NOTE: pickled files are executed on load - only read trusted data
        columns = np.load(data_path + run_name + '_cols.npy', allow_pickle=True)
        df_clusters = pd.read_pickle(data_path + "_".join((run_name, data_type, 'clusters.pkl')))
        df_reduced = df_clusters[list(columns) + ['cluster']]

        # Number of patients
        plot_cluster_size(df_reduced, data_type)

        # Generate mean/std table
        t1_year = TableOne(df_reduced, categorical=[], groupby='cluster', pval=True)
        t1yr_file = data_path + 't1_year_' + run_name + '_' + data_type + '.html'
        t1_year.to_html(t1yr_file)
        mlflow.log_artifact(t1yr_file)

        # Histogram feature plots
        plot_feature_hist(df_clusters, 'age', data_type)
        plot_feature_hist(df_clusters, 'albumin_med_2yr', data_type)

        # Bar plots
        df_clusters['sex'] = df_clusters['sex_bin'].map({0: 'Male', 1: 'Female'})
        plot_feature_bar(df_clusters, 'sex', 'percentage', data_type)
        plot_feature_bar(df_clusters, 'simd_decile', 'percentage', data_type)

        # Metrics for following 12 months
        df_events = pd.read_pickle(data_path + 'metric_table_events.pkl')
        df_counts = pd.read_pickle(data_path + 'metric_table_counts.pkl')
        df_next = pd.read_pickle(data_path + 'metric_table_next.pkl')

        # Merge cluster number with SafeHavenID and metrics; patients missing
        # from a metric table get 0 events / the '12+' band
        clusters = df_clusters[['SafeHavenID', 'cluster']]
        df_events = clusters.merge(df_events, on='SafeHavenID', how='left').fillna(0)
        df_counts = clusters.merge(df_counts, on='SafeHavenID', how='left').fillna(0)
        df_next = clusters.merge(df_next, on='SafeHavenID', how='left').fillna('12+')

        # Generate TableOne for events
        # Build the limit/order dicts from the columns themselves so every
        # metric column is covered, however many there are
        cat_cols = df_events.columns[2:]
        df_events[cat_cols] = df_events[cat_cols].astype('int')
        event_limit = {c: 1 for c in cat_cols}
        event_order = {c: [1, 0] for c in cat_cols}
        t1_events = TableOne(df_events[df_events.columns[1:]], groupby='cluster',
                             limit=event_limit, order=event_order)
        t1_events_file = data_path + '_'.join(('t1', data_type, 'events', run_name + '.html'))
        t1_events.to_html(t1_events_file)
        mlflow.log_artifact(t1_events_file)

        # Generate TableOne for event counts
        count_cols = df_counts.columns[2:]
        df_counts[count_cols] = df_counts[count_cols].astype('int')
        t1_counts = TableOne(df_counts[df_counts.columns[1:]], categorical=[], groupby='cluster')
        t1_counts_file = data_path + '_'.join(('t1', data_type, 'counts', run_name + '.html'))
        t1_counts.to_html(t1_counts_file)
        mlflow.log_artifact(t1_counts_file)

        # Generate TableOne for time to next events
        next_cols = df_next.columns[2:]
        next_event_order = {c: ['1', '3', '6', '12', '12+'] for c in next_cols}
        t1_next = TableOne(df_next[df_next.columns[1:]], groupby='cluster',
                           order=next_event_order)
        t1_next_file = data_path + '_'.join(('t1', data_type, 'next', run_name + '.html'))
        t1_next.to_html(t1_next_file)
        mlflow.log_artifact(t1_next_file)

        # Plot metrics
        plot_events(df_events, data_type)
        plot_deceased(df_next, data_type)
        plot_therapies(df_clusters, clusters, data_type)
# Run the pipeline only when executed as a script, not on import
if __name__ == '__main__':
    main()