haiquanchen's picture
Upload app.py
472f3e1 verified
#pip install kaleido
#pip install gradio
import gradio as gr
#import os
#import random
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm.auto import tqdm
#!pip install einops
# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#pip install captum
import seaborn as sns
from captum.attr import LayerConductance
from captum.attr import IntegratedGradients
from captum.attr import configure_interpretable_embedding_layer
import matplotlib.pyplot as plt
from captum.attr import remove_interpretable_embedding_layer
import torch.nn.functional as F
# @title
import pandas as pd
import numpy as np
import tensorflow as tf
# Load the raw STS spreadsheet, skipping pandas' auto-named "Unnamed: N" columns.
Raw_data = pd.read_excel('./STS Data with Up to dated AF 09-18-2022 (1).xlsx', usecols=lambda x: 'Unnamed' not in x)
pd.set_option('display.max_columns', None)
Raw_data['Aortic_Insufficiency'] = Raw_data['Aortic_Insufficiency'].astype(np.int64)

# Post-operative columns; they are removed from the feature set below.
# (Each name is listed once; the previous list repeated 'Neuro_Stroke_Permanent'.)
Postop_columns = [
    'PostOpMedCoumadin',
    'PostOpMedLipidLowering',
    'PostOpMedAspirin',
    'PostOpMedADPInhibitors',
    'PostOpMedACE_ARBInhibitors',
    'STS_PostOp.Renal_Failure',
    'Oth_Cardiac_Arrest',
    'Complications_Any',
    'Neuro_Stroke_Permanent',
    'Neuro_Continuous_Coma',
    'Neuro_Delirium',
    'PostOpSepsis',
    'Reop_Bleeding',
    'Oth_OtherComplication',
    'PostOpNeuroStrokeTransientTIA',
    'Infect_Sternum_Deep',
    'Infect_Thoracotomy',
    'Pulm_Ventilator_Prolonged',
    'Pulm_Pneumonia',
    'Oth_Tamponade',
    'Oth_Anticoagulant',
    'Oth_MultiSystem_Failure',
    'Oth_GI',
    'Vasc_Ao_Dissection',
    'Infect_Leg',
    'PostOpInfectionArm',
    'OthCard_Pacemaker',
    'PostOpCreatinineLevel',
    'Renal_Dialysis_Required',
    'PostOpBloodRBCUnits',
    'PostOpBloodFFPUnits',
    'PostOpBloodCryoUnits',
    'PostOpBloodPlateletUnits',
    'ExtubatedI0R',
    'InitHrsVentilated',
    'ReIntubated',
    'No Add Hrs Ventilator',
    'PostOpVentHoursTotal',
    'InitHrsICU',
    'ReadmitICU',
    'AddICUHours',
    'TotHrsICU',
    'DCMed_AntiPlate',
    'Readmit_LessThan30Days',
    'Blood_Bank_Products_Used',
    'PostOpMedAntiarrhythmics',
    'PostOpMedBetaBlockers',
]

# Drop post-operative columns, then rows whose target Oth_Afib is -1.
preop_oper_data = Raw_data.drop(columns=Postop_columns)
preop_oper_data = preop_oper_data[preop_oper_data.Oth_Afib != -1]
# Drop dates/mortality outcomes, then other columns excluded from the model,
# then the IABP detail columns.
preop_oper_data = preop_oper_data.drop(columns=[
    'Date_of_Birth', 'Surgery_Date', 'Discharge_Date', 'Death_Date', 'Death-Surery(y)',
    'Mortality30d', 'Mortality1y', 'Mortality2y', 'Mortality3y', 'Mortality4y', 'Mortality5y',
    'EF', 'Height(cm)', 'Weight(kg)', 'CVA_When', 'Category', 'Race',
    'IABP_Indication', 'IABP_When',
])

# Separate continuous and categorical columns (the target rides along with the
# continuous frame here and is split off just below).
continous_df = preop_oper_data[['Oth_Afib', 'Age', 'BMI', 'LastCreatinineLevel', 'Cross_Clamp_Time', 'Perfusion_Time']]
categorical_col = sorted(set(preop_oper_data.columns) - set(continous_df.columns))

# Setting the target
df_AF = preop_oper_data['Oth_Afib']
continous_col = continous_df.drop('Oth_Afib', axis=1)
preop_oper_data = preop_oper_data.drop('Oth_Afib', axis=1)
# Keep only the categorical columns, in sorted column-name order.
preop_oper_data = preop_oper_data[categorical_col]
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder
def encode_text_index(df, name):
    """Label-encode column *name* of *df* in place.

    The column's values are replaced with the integer codes produced by
    sklearn's ``LabelEncoder``.

    Returns:
        The encoder's ``classes_`` array, mapping each code back to the
        original value.
    """
    encoder = LabelEncoder()
    codes = encoder.fit_transform(df[name])
    df[name] = codes
    return encoder.classes_
# Turn the three text-valued categorical columns into integer codes (in place).
for _text_column in ('Introp DEX or nDEX', 'Status', 'Gender'):
    encode_text_index(preop_oper_data, _text_column)

# Number of distinct values in every categorical column, in column order;
# build_network passes this tuple to the model as `categories`.
label_in_each = tuple(
    len(preop_oper_data[column].unique()) for column in preop_oper_data.columns
)
categorical_col_with_ordinal = preop_oper_data.columns

# Final modelling frame: continuous features, categorical features, target last.
final_frame = pd.concat([continous_col, preop_oper_data, df_AF], axis=1)
def encode_numeric_zscore(df, name, mean=None, sd=None):
    """Z-score column *name* of *df* in place: ``(x - mean) / sd``.

    Args:
        df: DataFrame whose column is overwritten.
        name: column to normalise.
        mean: centre to subtract; computed from the column (and printed)
            when omitted.
        sd: scale to divide by; computed with pandas ``Series.std`` (sample
            standard deviation) and printed when omitted.

    Returns:
        ``(mean, sd)`` actually applied, so the identical transform can be
        reused on new data instead of hard-coding the printed values.
        Callers that ignore the return value are unaffected.
    """
    if mean is None:
        mean = df[name].mean()
        print(f'mean:{mean}')
    if sd is None:
        sd = df[name].std()
        print(f'sd:{sd}')
    df[name] = (df[name] - mean) / sd
    return mean, sd
# Standardise each continuous feature column of final_frame in place.
for col in continous_col.columns:
    encode_numeric_zscore(final_frame, col)

# Stratified 75/25 train/test split; the Oth_Afib target is the last column.
from sklearn.model_selection import train_test_split
_target = final_frame.iloc[:, -1]
x_train, x_temp, y_train, y_temp = train_test_split(
    final_frame.iloc[:, :-1],
    _target,
    test_size=0.25,
    random_state=42,
    stratify=_target,
)
for _part in (x_train, y_train, x_temp, y_temp):
    print(_part.shape)

# Oversample the positive class for training: the class-1 rows are appended
# twice more, so each appears three times in training_frame.
training_frame = pd.concat([x_train, y_train], axis=1)
training_frame_ana = training_frame  # training split before oversampling
class_1_rows = training_frame[training_frame['Oth_Afib'] == 1]
duplicated_class_1 = class_1_rows.copy()
training_frame = pd.concat([training_frame] + [duplicated_class_1] * 2, ignore_index=True)

# Held-out frame built from the test portion of the split.
testing_frame = pd.concat([x_temp, y_temp], axis=1)

# From here on continous_col is the Index of continuous feature names.
continous_df = continous_df.drop('Oth_Afib', axis=1)
continous_col = continous_df.columns

# Replace each frame's trailing label column with int64 one-hot columns
# named Oth_Afib_<value>.
training_frame_without_label = training_frame.iloc[:, :-1]
testing_frame_without_label = testing_frame.iloc[:, :-1]
training_frame = pd.concat(
    [
        training_frame_without_label,
        pd.get_dummies(training_frame.iloc[:, -1], prefix='Oth_Afib', dtype=np.int64),
    ],
    axis=1,
)
testing_frame = pd.concat(
    [
        testing_frame_without_label,
        pd.get_dummies(testing_frame.iloc[:, -1], prefix='Oth_Afib', dtype=np.int64),
    ],
    axis=1,
)
# @title
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
# helpers
def exists(val):
    """Return True for any value other than None (falsy values count as existing)."""
    return not (val is None)
def default(val, d):
    """Return *val*, or the fallback *d* when *val* is None."""
    if val is None:
        return d
    return val
# classes
class Residual(nn.Module):
    """Skip connection: return ``fn(x, **kwargs) + x``."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        transformed = self.fn(x, **kwargs)
        return transformed + x
class PreNorm(nn.Module):
    """LayerNorm over the last dimension of size *dim*, then apply *fn*.

    Extra keyword arguments are forwarded to *fn*; its return value (tensor
    or tuple) is passed through unchanged.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)
# attention
class GEGLU(nn.Module):
    """Gated GELU: split the last dim in half and return ``value * gelu(gate)``."""
    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        gated = F.gelu(gate)
        return value * gated
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> GEGLU -> Dropout -> Linear.

    The first projection is 2x the hidden width because GEGLU consumes half
    of it as the gate. The Sequential indices (net.0 / net.3 hold weights)
    define the state_dict keys, so the layer order must stay as is.
    """
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
        )
    def forward(self, x, **kwargs):
        # kwargs are accepted for interface compatibility and ignored.
        return self.net(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    forward(x) returns ``(output, attn)``. ``output`` has the same last
    dimension as ``x``. ``attn`` holds the post-softmax attention weights,
    taken before dropout is applied.
    """
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 16,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        n_heads = self.heads

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = n_heads)

        queries, keys, values = (split_heads(t) for t in self.to_qkv(x).chunk(3, dim = -1))
        scores = einsum('b h i d, b h j d -> b h i j', queries, keys) * self.scale
        weights = scores.softmax(dim = -1)
        mixed = einsum('b h i j, b h j d -> b h i d', self.dropout(weights), values)
        merged = rearrange(mixed, 'b h n d -> b n (h d)', h = n_heads)
        return self.to_out(merged), weights
# transformer
class Transformer(nn.Module):
    """Stack of *depth* pre-norm self-attention + feed-forward layers.

    Each layer applies ``x = x + attn(norm(x))`` followed by
    ``x = ff(norm(x)) + x``.

    Args:
        dim: token embedding width.
        depth: number of layers.
        heads, dim_head: attention head count and per-head width.
        attn_dropout: dropout applied to attention weights.
        ff_dropout: dropout inside the feed-forward block.
    """
    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
        super().__init__()
        # Each entry is ModuleList([PreNorm(Attention), PreNorm(FeedForward)]).
        # This nesting fixes the state_dict keys ("layers.<i>.<0|1>.…") that the
        # trained checkpoint is loaded against. It must not change.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
            ])
            for _ in range(depth)
        ])
    def forward(self, x, return_attn = False):
        """Run all layers and return the transformed tokens.

        ``return_attn`` is accepted for backward compatibility only. Both
        values return just the tensor, as before; the attention weights were
        previously collected but never returned, so they are no longer kept.
        """
        for attn, ff in self.layers:
            attn_out, _ = attn(x)
            x = x + attn_out
            x = ff(x) + x
        return x
# mlp
class MLP(nn.Module):
    """Fully connected stack through the widths in *dims*.

    An activation is inserted between consecutive Linear layers, but not
    after the last one. When *act* is None a single ``nn.ReLU()`` is created.
    The same activation instance is reused at every position.
    """
    def __init__(self, dims, act = None):
        super().__init__()
        pairs = list(zip(dims[:-1], dims[1:]))
        final_index = len(pairs) - 1
        layers = []
        for position, (width_in, width_out) in enumerate(pairs):
            layers.append(nn.Linear(width_in, width_out))
            if position < final_index:
                if act is None:
                    act = nn.ReLU()
                layers.append(act)
        self.mlp = nn.Sequential(*layers)
    def forward(self, x):
        return self.mlp(x)
class NumericalEmbedder(nn.Module):
    """Embed each continuous feature i as ``x_i * weights[i] + biases[i]``.

    Input is rank-2 (enforced by the rearrange pattern ``'b n'``). Output
    gains a trailing embedding axis of size *dim*.
    """
    def __init__(self, dim, num_numerical_types):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
        self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))
    def forward(self, x):
        as_column = rearrange(x, 'b n -> b n 1')
        scaled = as_column * self.weights
        return scaled + self.biases
class CategoricalEmbedder(nn.Module):
    """Look up an embedding of width *dim* for each integer token index.

    The wrapped table is stored as ``self.embeds``; that name is part of the
    state_dict keys.
    """
    def __init__(self, total_tokens,dim):
        super().__init__()
        self.embeds = nn.Embedding(total_tokens, dim)
    def forward(self, x):
        return self.embeds(x)
class CatConLayer(nn.Module):
    """Cross-attention in both directions between two token sequences.

    Categorical tokens attend to continuous tokens, and vice versa.
    ``nn.MultiheadAttention`` is built with its default ``batch_first=False``,
    so inputs are (seq, batch, dim). Outputs are permuted to (batch, seq, dim).
    ``need_weights`` is accepted but unused; attention weights are discarded.
    """
    def __init__(self, dim , heads ):
        super().__init__()
        self.cat_con_multihead_attn = torch.nn.MultiheadAttention(dim, heads, dropout = 0.8)
        self.con_cat_multihead_attn = torch.nn.MultiheadAttention(dim, heads, dropout = 0.8)
    def forward(self,attn_cat,attn_con,need_weights=False):
        cat_attends_con = self.cat_con_multihead_attn(attn_cat, attn_con, attn_con)[0]
        con_attends_cat = self.con_cat_multihead_attn(attn_con, attn_cat, attn_cat)[0]
        return cat_attends_con.permute(1, 0, 2), con_attends_cat.permute(1, 0, 2)
# main class
class Co_Transformer(nn.Module):
    """Co-attention transformer over mixed categorical/continuous tabular input.

    Pipeline:
    1. Categorical codes are embedded (``self.embeds``) and run through
       ``transformer_cat``.
    2. Continuous values are embedded (``self.numerical_embedder``) and run
       through ``transformer_con``.
    3. The two sequences are fused by ``CatConLayer`` (cross attention in
       both directions).
    4. The fused sequences are concatenated (categorical first), flattened,
       and mapped to ``dim_out`` logits by an MLP.

    Keyword-only args:
        categories: number of distinct values in each categorical column.
        num_continuous: number of continuous columns.
        dim: embedding width shared by every token.
        depth, heads, dim_head: size of each Transformer stack.
        dim_out: number of output logits.
        mlp_hidden_mults: hidden widths of the final MLP, as multiples of
            ``input_size // 5``.
        mlp_act: activation module for the MLP (MLP falls back to ReLU).
        num_special_tokens: extra embedding rows counted into the table size.
        continuous_mean_std: accepted but not used anywhere in this class.
        attn_dropout, ff_dropout: dropout inside both Transformer stacks.
    """
    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head = 16,
        dim_out = 1,
        mlp_hidden_mults = (2,1),
        mlp_act = None,
        num_special_tokens = 0,
        continuous_mean_std = None,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
        assert len(categories) + num_continuous > 0, 'input shape must not be null'
        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)
        self.num_special_tokens = num_special_tokens # defaults to 0
        total_tokens = self.num_unique_categories + num_special_tokens
        # Per-column start offsets into a shared embedding table.
        # NOTE(review): the offsets are registered as a buffer (so they are part
        # of the state_dict) but forward() never adds them to x_categ. Every
        # categorical column therefore looks up rows starting at 0, i.e.
        # code k of any column maps to the same embedding row. Changing this
        # would invalidate the trained checkpoint.
        if self.num_unique_categories > 0:
            categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
            categories_offset = categories_offset.cumsum(dim = -1)[:-1]
            self.register_buffer('categories_offset', categories_offset)
            # Attribute name 'embeds' is also what captum's
            # configure_interpretable_embedding_layer patches in run_inference.
            self.embeds = CategoricalEmbedder(total_tokens, dim)
        # continuous
        self.num_continuous = num_continuous
        if self.num_continuous > 0:
            self.numerical_embedder = NumericalEmbedder(dim, self.num_continuous)
        # Separate self-attention stacks for the categorical and continuous tokens.
        self.transformer_cat = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )
        self.transformer_con = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )
        # fusion-part: cross attention between the two token sequences
        self.catconlayer = CatConLayer(
            dim=dim ,
            heads=heads
        )
        # MLP to logits: the input is every fused token flattened, so its size
        # is dim * (number of categorical + continuous tokens).
        input_size = dim * (self.num_categories + num_continuous)
        print(f'input size{input_size}')
        l = input_size // 5
        hidden_dimensions = list(map(lambda t: int(l * t), mlp_hidden_mults))
        all_dimensions = [input_size, *hidden_dimensions, dim_out]
        self.mlp = MLP(all_dimensions, act = mlp_act)
        # print(f" mlp {self.mlp}")
    def forward(self, x_categ, x_cont, return_attn = True):
        """Return logits of shape (batch, dim_out).

        Args:
            x_categ: categorical codes. NOTE(review): presumably a
                (batch, num_categories) long tensor. When captum's
                interpretable embedding layer is configured on 'embeds', this
                is already the embedded tensor instead.
            x_cont: continuous values, (batch, num_continuous) — the
                rearrange in NumericalEmbedder requires rank 2. When the
                interpretable layer is configured on 'numerical_embedder',
                these are precomputed embeddings.
            return_attn: unused; both Transformer stacks return only tensors.
        """
        x_categ = self.embeds(x_categ)
        # x_cat_, attns_cat = self.transformer(x_categ, return_attn = True)
        x_cat_ = self.transformer_cat(x_categ, return_attn = True)
        # CatConLayer's MultiheadAttention expects (seq, batch, dim).
        permuted_x_cat_= x_cat_.permute(1, 0, 2)
        x_numer = self.numerical_embedder(x_cont)
        # x_con_, attns_con = self.transformer(x_numer, return_attn = True)
        x_con_ = self.transformer_con(x_numer, return_attn = True)
        permuted_x_con_= x_con_.permute(1, 0, 2)
        # cat_Q / con_Q come back as (batch, seq, dim).
        cat_Q,con_Q = self.catconlayer(permuted_x_cat_,permuted_x_con_ )
        # Join along the token axis: categorical tokens first, then continuous.
        can_con_attn_output = torch.cat([cat_Q, con_Q], dim=1)
        # permuted_can_con_attn_output= can_con_attn_output.permute(1, 0, 2)
        can_con_attn_output_flattend= can_con_attn_output.flatten(1)
        logits=self.mlp(can_con_attn_output_flattend)
        return logits
def build_network(depth,heads,dim):
    """Construct an untrained Co_Transformer sized for this dataset.

    Args:
        depth: layers in each Transformer stack.
        heads: attention heads.
        dim: token embedding width.

    Categorical cardinalities come from the module-level ``label_in_each``,
    and the continuous feature count from ``final_frame[continous_col]``.
    The model emits 2 logits (non-AFib / AFib). Weights are random until a
    checkpoint is loaded.
    """
    # Mean/std of the raw continuous columns. Co_Transformer accepts
    # `continuous_mean_std` but its __init__ does not use it.
    continuous_stats = torch.tensor(
        continous_df.agg(['mean', 'std']).transpose().values,
        dtype=torch.float32,
    )
    return Co_Transformer(
        categories = label_in_each,
        num_continuous = final_frame[continous_col].shape[1],
        dim = dim,
        dim_out = 2,
        depth = depth,
        heads = heads,
        attn_dropout = 0.1,
        ff_dropout = 0.1,
        # Hidden MLP widths as multiples of (flattened input size // 5).
        mlp_hidden_mults = (2, 1, 0.5, 0.25),
        mlp_act = nn.ReLU(),
        continuous_mean_std = continuous_stats,
    )
# Rebuild the architecture and load the trained weights onto the CPU.
# The sizes must match the checkpoint, since load_state_dict is strict by default.
model = build_network(depth=8, heads=8, dim=64)
_checkpoint_state = torch.load('./co_attention_transformer_model_trained.pth', map_location='cpu')
model.load_state_dict(_checkpoint_state)
# Rows offered as clickable examples in the Gradio UI.
sample_Df = pd.read_csv('./sample_data.csv')
# @title
def run_inference(num0,num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82):
mean1=62.63038219641993
sd1=12.280987675098727
mean2=29.238482057573915
sd2=6.511823065330923
mean3=1.2360909530720852
sd3=1.0307534004581604
mean4=137.98209966134493
sd4=58.71032609323997
mean5=193.39671020803095
sd5=79.63715724430536
num1 = (num1 - mean1)/sd1
num2 = (num2 - mean2)/sd2
num3 = (num3 - mean3)/sd3
num4 = (num4 - mean4)/sd4
num5 = (num5 - mean5)/sd5
list__inputs = [num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82]
print(list__inputs)
if (num0 == 'First_non_AFib' or num0 == 'Second_non_AFib'):
target_set = 0
else:
target_set = 1
# Remove specific elements from nested lists at specified indices
result_list =[item for sublist in list__inputs for item in (sublist if isinstance(sublist, list) else [sublist])]
print(result_list)
con = torch.tensor(result_list[0:5],dtype=torch.float32).reshape(1,-1)
print(con,con.shape,con.device.type)
cat = torch.tensor(result_list[5:82],dtype=torch.long).reshape(1,-1)
print(cat,cat.shape,cat.device.type)
model.eval()
output_tup=model(cat,con)
prob=F.softmax(output_tup,dim=-1)
print(prob[0][0].detach(),prob[0][1].detach())
# Categories for the bar plot
categories = ['Non-AFIB', 'A-FIB']
# Values for the bar plot
values = [(prob[0][0]).detach().numpy(), (prob[0][1]).detach().numpy()]
fig1 = plt.figure()
plt.barh(categories, values, color=['green', 'red'])
plt.xlabel('Values')
plt.ylabel('Labels')
# cal embedding attributes
ig = IntegratedGradients(model)
interpretable_embedding_cat = configure_interpretable_embedding_layer(model, 'embeds')
interpretable_embedding_con = configure_interpretable_embedding_layer(model, 'numerical_embedder')
emb_cat = interpretable_embedding_cat.indices_to_embeddings(cat)
emb_con = interpretable_embedding_con.indices_to_embeddings(con)
print(emb_cat.device.type)
baseline_cat = torch.zeros_like(emb_cat) # Set numerical baseline to zero
baseline_con = torch.zeros_like(emb_con)
emb_cat.requires_grad_
emb_cat.requires_grad_
attr, delta =ig.attribute((emb_cat, emb_con),baselines = (baseline_cat,baseline_con) ,target=target_set, return_convergence_delta=True, n_steps=50)
print("calculating attr")
print(attr[0].shape)
print(attr[1].shape)
categ_attr = (attr[0]).sum(dim=-1).squeeze(0)
cond_attr = (attr[1]).sum(dim=-1).squeeze(0)
concatenated_tensor = torch.cat([cond_attr, categ_attr],dim=0)
print(concatenated_tensor.device.type)
x_pos = (np.arange(len(testing_frame.iloc[:,0:-2].columns)))
fig2 = plt.figure(figsize=(30,6))
plt.bar(x_pos,concatenated_tensor.squeeze().cpu().detach().numpy(), align='center', color = 'red')
plt.xticks(x_pos,testing_frame.iloc[:,0:-2].columns, wrap=True)
plt.xticks(rotation=45)
plt.xlabel('Features')
plt.title('Embedded layer attributes')
# layer attributes
attn_con_cat = []
attn_con_cat.append(concatenated_tensor.detach().cpu())
for i in range(len(model.transformer_con.layers)):
con_module = [module for module in model.transformer_con.layers[i]]
layeroutput_con = []
for j in range(len(con_module)):
lc_con = LayerConductance(model, con_module[j])
layer_attributions_con= lc_con.attribute((emb_cat,emb_con), baselines=(baseline_cat,baseline_con),target=target_set,n_steps=50)
if(type(layer_attributions_con) == "tuple"):
layeroutput_con.append(layer_attributions_con[0])
else:
layeroutput_con.append(layer_attributions_con)
attn_out_con = emb_con + layeroutput_con[0][0] + layeroutput_con[1]
attn_out_con=attn_out_con.sum(dim=-1).squeeze(0)
cat_module = [module for module in model.transformer_cat.layers[i]]
layeroutput_cat = []
for j in range(len(cat_module)):
lc_cat = LayerConductance(model, cat_module[j])
layer_attributions_cat= lc_cat.attribute((emb_cat,emb_con), baselines=(baseline_cat,baseline_con),target=target_set,n_steps=50)
if(type(layer_attributions_cat) == "tuple"):
layeroutput_cat.append(layer_attributions_cat[0])
else:
layeroutput_cat.append(layer_attributions_cat)
attn_out_cat = emb_cat + layeroutput_cat[0][0] + layeroutput_cat[1]
attn_out_cat=attn_out_cat.sum(dim=-1).squeeze(0)
attn_con_cat.append((torch.cat([attn_out_con,attn_out_cat])).detach().cpu())
lc = LayerConductance(model, model.catconlayer)
layer_attributions_start = lc.attribute((emb_cat,emb_con), baselines=(baseline_cat,baseline_con),target=target_set,n_steps=50)
value_coattn_cat=layer_attributions_start[0].sum(dim=-1).squeeze(0)
value_coattn_con=layer_attributions_start[1].sum(dim=-1).squeeze(0)
attn_con_cat.append(torch.cat([value_coattn_cat,value_coattn_con]).detach().cpu())
# fig 3
fig3, axes = plt.subplots(figsize=(15, 12),frameon=False)
for spine in plt.gca().spines.values():
spine.set_visible(False)
axes.xaxis.set_major_locator(plt.NullLocator())
axes.yaxis.set_major_locator(plt.NullLocator())
for i,k in enumerate(testing_frame.iloc[:,:-60].columns):
cmap = sns.color_palette("Reds")
# cmap = sns.cm.rocket_r
ax = fig3.add_subplot(5,5, i+1)
xticklabels=[k]
yticklabels=list(range(1,9))
ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:,i].reshape(-1,1),ax=ax,xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap)
plt.xlabel('features')
plt.ylabel('Layers')
plt.tight_layout()
# fig 4
fig4, axes = plt.subplots(figsize=(15, 12),frameon=False)
for spine in plt.gca().spines.values():
spine.set_visible(False)
axes.xaxis.set_major_locator(plt.NullLocator())
axes.yaxis.set_major_locator(plt.NullLocator())
for i,k in enumerate(testing_frame.iloc[:,24:-30].columns):
cmap = sns.color_palette("Reds")
# cmap = sns.cm.rocket_r
ax = fig4.add_subplot(6,5, i+1)
xticklabels=[k]
yticklabels=list(range(1,9))
ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:,i].reshape(-1,1),ax=ax,xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap)
plt.xlabel('features')
plt.ylabel('Layers')
plt.tight_layout()
# fig 5
fig5, axes = plt.subplots(figsize=(15, 12),frameon=False)
for spine in plt.gca().spines.values():
spine.set_visible(False)
axes.xaxis.set_major_locator(plt.NullLocator())
axes.yaxis.set_major_locator(plt.NullLocator())
for i,k in enumerate(testing_frame.iloc[:,54:-2].columns):
cmap = sns.color_palette("Reds")
# cmap = sns.cm.rocket_r
ax = fig5.add_subplot(6,5, i+1)
xticklabels=[k]
yticklabels=list(range(1,9))
ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:,i].reshape(-1,1),ax=ax,xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap)
plt.xlabel('features')
plt.ylabel('Layers')
plt.tight_layout()
#fig6
x_pos = (np.arange(len(testing_frame.iloc[:,:-2].columns)))
fig6 = plt.figure(figsize=(30,6))
plt.bar(x_pos, torch.stack(attn_con_cat)[9], align='center', color = 'red')
plt.xticks(x_pos,testing_frame.iloc[:,:-2].columns, wrap=True)
plt.xticks(rotation=45)
plt.xlabel('features')
plt.title('Attribution of co-attention layer')
remove_interpretable_embedding_layer(model, interpretable_embedding_con)
remove_interpretable_embedding_layer(model, interpretable_embedding_cat)
return fig1 , fig2 , fig3 , fig4, fig5, fig6
# Gradio front end: one widget per model feature (in model column order),
# six plot outputs, and example rows from sample_data.csv.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
    # Post-Operative Atrial Fibrillation Demo
    Select values for the following and click submit to see the results:
    """)
    # Hidden example label; run_inference uses it to pick the attribution target.
    num0=gr.Textbox(visible = False)
    num1=gr.Slider(0,100,label='Age',step=1)
    num2=gr.Slider(0,100,label='BMI')
    num3=gr.Slider(0,20,label='LastCreatinineLevel')
    num4=gr.Slider(0,1000,label='Cross_Clamp_Time',step=1)
    num5=gr.Slider(0,1000,label='Perfusion_Time',step=1)
    num6=gr.Slider(0,8,label='# of coronary vessels corrected',step=1)
    num7=gr.CheckboxGroup([0,1],label='Aortic stenosis')
    num8=gr.CheckboxGroup([0,1,2,3,4],label='Aortic_Insufficiency')
    num9=gr.CheckboxGroup([0,1],label='Aortic_Procedure')
    num10=gr.CheckboxGroup([0,1],label='Arrhythmia')
    num11=gr.CheckboxGroup([0,1],label='ArrhythmiaAfibAflutter')
    num12=gr.CheckboxGroup([0,1],label='CABG')
    num13=gr.CheckboxGroup([0,1],label='CHF')
    num14=gr.CheckboxGroup([0,1],label='CVA')
    num15=gr.CheckboxGroup([0,1],label='Cardiogenic_Shock')
    num16=gr.CheckboxGroup([0,1],label='Cerebrovascular_Disease')
    num17=gr.CheckboxGroup([0,1,2,3],label='ChronicLungDisease')
    num18=gr.CheckboxGroup([0,1],label='Diabetes')
    num19=gr.CheckboxGroup([0,1],label='Dialysis')
    num20=gr.Slider(0,7,label='DistAnasVein',step=1)
    num21=gr.Slider(0,6,label='DistAnastArt',step=1)
    num22=gr.CheckboxGroup([0,1],label='Family_History_CAD')
    num23=gr.CheckboxGroup([0,1],label='Gender')
    num24=gr.CheckboxGroup([0,1],label='Hypercholesterolemia')
    num25=gr.CheckboxGroup([0,1],label='Hypertension')
    num26=gr.CheckboxGroup([0,1],label='IABP')
    num27=gr.CheckboxGroup([0,1],label='Infectious_Endocarditis')
    num28=gr.Slider(0,6,label='IntraopBloodCryo',step=1)
    num29=gr.Slider(0,8,label='IntraopBloodFFP',step=1)
    num30=gr.CheckboxGroup([0,1,2],label='IntraopBloodFactorVII')
    num31=gr.Slider(0,8,label='IntraopBloodPlatelet',step=1)
    num32=gr.CheckboxGroup([0,1],label='IntraopBloodProducts')
    num33=gr.Slider(0,20,label='IntraopBloodRBC',step=1)
    num34=gr.CheckboxGroup([0,1],label='IntraopMedEpsilonAmi0Caproic')
    num35=gr.CheckboxGroup([0,1],label='IntraopMedTranexamicAcid')
    num36=gr.CheckboxGroup([0,1],label='Introp DEX or nDEX')
    num37=gr.CheckboxGroup([0,1],label='Left_Main_Disease')
    num38=gr.CheckboxGroup([0,1],label='MACE')
    num39=gr.CheckboxGroup([0,1],label='MedsG2b3aInhibitorMed')
    num40=gr.CheckboxGroup([0,1,2,3,4],label='Mitral_Insufficiency')
    num41=gr.CheckboxGroup([0,1],label='OthCard_AICD')
    num42=gr.CheckboxGroup([0,1],label='Oth_Heart_Block')
    num43=gr.CheckboxGroup([0,1],label='Other_Cardiac_Intervention')
    num44=gr.CheckboxGroup([0,1],label='Peri_Op_MI')
    num45=gr.CheckboxGroup([0,1],label='Peripheral_Vasc_Disease')
    num46=gr.CheckboxGroup([0,1],label='PreOpMed Antiplatelets')
    num47=gr.CheckboxGroup([0,1],label='PreOpMedACE_ARBInhibitors')
    num48=gr.CheckboxGroup([0,1],label='PreOpMedADPInhibitors5Days')
    num49=gr.CheckboxGroup([0,1],label='PreOpMedAntiarrhythmics')
    num50=gr.CheckboxGroup([0,1],label='PreOpMedAnticoagulants')
    num51=gr.CheckboxGroup([0,1],label='PreOpMedAspirin')
    num52=gr.CheckboxGroup([0,1],label='PreOpMedCoumadin')
    num53=gr.CheckboxGroup([0,1],label='PreOpMedGPIIbIIIaInhibitor')
    num54=gr.CheckboxGroup([0,1],label='PreOpMedINotropes')
    num55=gr.CheckboxGroup([0,1],label='PreOpMedLipidLowering')
    num56=gr.CheckboxGroup([0,1],label='PreOpMedNitratesIV')
    num57=gr.CheckboxGroup([0,1],label='PreOpMedSteroids')
    num58=gr.CheckboxGroup([0,1],label='PreOp_BetaBlockers')
    num59=gr.CheckboxGroup([0,1],label='PreOp_Ca_Antagonists')
    num60=gr.CheckboxGroup([0,1],label='PreOp_Digitalis')
    num61=gr.CheckboxGroup([0,1],label='PreOp_Diuretics')
    num62=gr.CheckboxGroup([0,1],label='PrevArrhythmiaSurgery')
    num63=gr.CheckboxGroup([0,1],label='PrevOthCardPCI')
    num64=gr.CheckboxGroup([0,1],label='Previous_CABG')
    num65=gr.CheckboxGroup([0,1],label='Previous_CV_Intervention')
    num66=gr.CheckboxGroup([0,1],label='Previous_Valve')
    num67=gr.CheckboxGroup([0,1],label='PriorHeartFailure')
    num68=gr.CheckboxGroup([0,1],label='Pulmonic_Procedure')
    num69=gr.CheckboxGroup([0,1],label='Pulmonic_Stenosis')
    num70=gr.CheckboxGroup([0,1],label='STS_History.Renal_Failure')
    num71=gr.CheckboxGroup([0,1],label='Smoking')
    num72=gr.CheckboxGroup([0,1],label='Status')
    num73=gr.CheckboxGroup([0,1,2,3,4],label='Tricuspid_Insufficiency')
    num74=gr.CheckboxGroup([0,1],label='Tricuspid_Procedure')
    num75=gr.CheckboxGroup([0,1],label='VSMitral')
    num76=gr.CheckboxGroup([0,1],label='Valve')
    num77=gr.CheckboxGroup([0,1],label='ValveDisAortic')
    num78=gr.CheckboxGroup([0,1],label='ValveDisMitral')
    num79=gr.CheckboxGroup([0,1],label='ValveDisPulmonic')
    num80=gr.CheckboxGroup([0,1],label='ValveDisTricuspid')
    num81=gr.CheckboxGroup([0,1],label='_MI')
    num82=gr.CheckboxGroup([0,1],label='mitral stenosis')
    # Hidden label columns; not wired to any event.
    num83=gr.CheckboxGroup([0,1],label='Oth_Afib_0',visible= False)
    num84=gr.CheckboxGroup([0,1],label='Oth_Afib_1',visible= False)
    b1 = gr.Button("Submit")
    # Same component order as run_inference's parameters; shared by the
    # examples table and the submit handler so the two cannot drift apart.
    inference_inputs = [
        num0, num1, num2, num3, num4, num5, num6, num7, num8, num9,
        num10, num11, num12, num13, num14, num15, num16, num17, num18, num19,
        num20, num21, num22, num23, num24, num25, num26, num27, num28, num29,
        num30, num31, num32, num33, num34, num35, num36, num37, num38, num39,
        num40, num41, num42, num43, num44, num45, num46, num47, num48, num49,
        num50, num51, num52, num53, num54, num55, num56, num57, num58, num59,
        num60, num61, num62, num63, num64, num65, num66, num67, num68, num69,
        num70, num71, num72, num73, num74, num75, num76, num77, num78, num79,
        num80, num81, num82,
    ]
    example = sample_Df.values.tolist()
    gr.Examples(example, inputs=inference_inputs)
    output = [gr.Plot() for _ in range(6)]
    b1.click(run_inference, inputs=inference_inputs, outputs=output)
demo.launch(share=True)