Upload 3 files

Files changed:
- app.py (+844, -0)
- co_attention_transformer_model_trained.pth (+3, -0)
- sample_data.csv (+5, -0)
app.py
ADDED
@@ -0,0 +1,844 @@
#pip install kaleido
#pip install gradio
import gradio as gr

#import os
#import random

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm.auto import tqdm
#!pip install einops

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#pip install captum

import seaborn as sns
from captum.attr import LayerConductance

from captum.attr import IntegratedGradients
from captum.attr import configure_interpretable_embedding_layer

import matplotlib.pyplot as plt
from captum.attr import remove_interpretable_embedding_layer
import torch.nn.functional as F

# @title
import pandas as pd
import numpy as np
import tensorflow as tf

Raw_data = pd.read_excel('./STS Data with Up to dated AF 09-18-2022 (1).xlsx', usecols=lambda x: 'Unnamed' not in x)
pd.set_option('display.max_columns', None)

Raw_data['Aortic_Insufficiency'] = Raw_data['Aortic_Insufficiency'].astype(np.int64)


Postop_columns = ['PostOpMedCoumadin',
                  'PostOpMedLipidLowering',
                  'PostOpMedAspirin',
                  'PostOpMedADPInhibitors',
                  'PostOpMedACE_ARBInhibitors',
                  'STS_PostOp.Renal_Failure',
                  'Oth_Cardiac_Arrest',
                  'Complications_Any',
                  'Neuro_Stroke_Permanent',
                  'Neuro_Stroke_Permanent',
                  'Neuro_Continuous_Coma',
                  'Neuro_Delirium',
                  'PostOpSepsis',
                  'Reop_Bleeding',
                  'Oth_OtherComplication',
                  'PostOpNeuroStrokeTransientTIA',
                  'Infect_Sternum_Deep',
                  'Infect_Thoracotomy',
                  'Pulm_Ventilator_Prolonged',
                  'Pulm_Pneumonia',
                  'Oth_Tamponade',
                  'Oth_Anticoagulant',
                  'Oth_MultiSystem_Failure',
                  'Oth_GI',
                  'Vasc_Ao_Dissection',
                  'Infect_Leg',
                  'PostOpInfectionArm',
                  'OthCard_Pacemaker',
                  'PostOpCreatinineLevel',
                  'Renal_Dialysis_Required',
                  'PostOpBloodRBCUnits',
                  'PostOpBloodFFPUnits',
                  'PostOpBloodCryoUnits',
                  'PostOpBloodPlateletUnits',
                  'ExtubatedI0R',
                  'InitHrsVentilated',
                  'ReIntubated',
                  'No Add Hrs Ventilator',
                  'PostOpVentHoursTotal',
                  'InitHrsICU',
                  'ReadmitICU',
                  'AddICUHours',
                  'TotHrsICU',
                  'DCMed_AntiPlate',
                  'Readmit_LessThan30Days',
                  'Blood_Bank_Products_Used',
                  'PostOpMedAntiarrhythmics',
                  'PostOpMedBetaBlockers']


#Dropping columns
preop_oper_data = Raw_data.drop(columns=Postop_columns)
preop_oper_data = preop_oper_data[preop_oper_data.Oth_Afib != -1]
preop_oper_data = preop_oper_data.drop(['Date_of_Birth','Surgery_Date','Discharge_Date', 'Death_Date','Death-Surery(y)','Mortality30d','Mortality1y','Mortality2y','Mortality3y','Mortality4y','Mortality5y'], axis=1)
preop_oper_data = preop_oper_data.drop(['EF','Height(cm)','Weight(kg)','CVA_When','Category','Race'], axis=1)
preop_oper_data = preop_oper_data.drop(['IABP_Indication','IABP_When'], axis=1)

#Separating continuous and categorical data
continous_df = preop_oper_data[['Oth_Afib','Age','BMI','LastCreatinineLevel','Cross_Clamp_Time','Perfusion_Time']]
categorical_col = list(set(preop_oper_data.columns) - set(continous_df.columns))
categorical_col.sort()
#Setting the target
df_AF = preop_oper_data['Oth_Afib']
continous_col = continous_df.drop('Oth_Afib', axis=1)
preop_oper_data = preop_oper_data.drop('Oth_Afib', axis=1)


#Creating categorical df
preop_oper_data = preop_oper_data[categorical_col]

# Label encoding
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder
def encode_text_index(df, name):
    le = preprocessing.LabelEncoder()
    df[name] = le.fit_transform(df[name])
    return le.classes_

encode_text_index(preop_oper_data, 'Introp DEX or nDEX')
encode_text_index(preop_oper_data, 'Status')
encode_text_index(preop_oper_data, 'Gender')

#Calculating the number of unique levels in each categorical column
label_in_each = tuple(len(preop_oper_data[col].unique()) for col in preop_oper_data.columns)
categorical_col_with_ordinal = preop_oper_data.columns

#Making the final data frame
final_frame = pd.concat([continous_col, preop_oper_data], axis=1)
final_frame = pd.concat([final_frame, df_AF], axis=1)
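
# --- Illustrative sketch (not part of the original commit) ---
# How encode_text_index and the label_in_each tuple behave on a toy column,
# assuming a small stand-alone DataFrame (the real columns come from the STS file).
_toy = pd.DataFrame({'Status': ['Elective', 'Urgent', 'Elective']})
encode_text_index(_toy, 'Status')                            # 'Elective' -> 0, 'Urgent' -> 1
print(_toy['Status'].tolist())                               # [0, 1, 0]
print(tuple(len(_toy[c].unique()) for c in _toy.columns))    # (2,) -> what feeds `categories`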

# Encode a numeric column as zscores
def encode_numeric_zscore(df, name, mean=None, sd=None):
    if mean is None:
        mean = df[name].mean()
        print(f'mean:{mean}')

    if sd is None:
        sd = df[name].std()
        print(f'sd:{sd}')

    df[name] = (df[name] - mean) / sd


for col in continous_col.columns:
    encode_numeric_zscore(final_frame, col)


#Train test split
from sklearn.model_selection import train_test_split
x_train, x_temp, y_train, y_temp = train_test_split(final_frame.iloc[:,:-1], final_frame.iloc[:,-1], test_size=0.25, random_state=42, stratify=final_frame.iloc[:,-1])

print(x_train.shape)
print(y_train.shape)
print(x_temp.shape)
print(y_temp.shape)

# Duplicating class 1 records to balance dataset for training
training_frame = pd.concat([x_train, y_train], axis=1)
training_frame_ana = training_frame
class_1_rows = training_frame[training_frame['Oth_Afib'] == 1]
duplicated_class_1 = class_1_rows.copy()
training_frame = pd.concat([training_frame, duplicated_class_1, duplicated_class_1], ignore_index=True)
training_frame['Oth_Afib'].value_counts()

# Creating testing df
testing_frame = pd.concat([x_temp, y_temp], axis=1)


continous_df = continous_df.drop('Oth_Afib', axis=1)
continous_col = continous_df.columns
continous_col

training_frame_without_label = training_frame.iloc[:,:-1]
testing_frame_without_label = testing_frame.iloc[:,:-1]
training_frame = pd.concat([training_frame_without_label, pd.get_dummies(training_frame.iloc[:,-1], prefix='Oth_Afib', dtype=np.int64)], axis=1)
testing_frame = pd.concat([testing_frame_without_label, pd.get_dummies(testing_frame.iloc[:,-1], prefix='Oth_Afib', dtype=np.int64)], axis=1)
testing_frame
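
# --- Illustrative sketch (not part of the original commit) ---
# The same z-score transform is re-applied at inference time in run_inference
# using hard-coded training means/stds; the assumed values below mirror the
# Age constants used later in this file.
_age_mean, _age_sd = 62.63038219641993, 12.280987675098727
print((73 - _age_mean) / _age_sd)  # ~0.84: a 73-year-old sits ~0.84 SD above the training mean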

# @title


import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# classes

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# attention

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)

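# --- Illustrative sketch (not part of the original commit) ---
# GEGLU splits its input in half along the last dimension and gates one half
# with GELU of the other, which is why FeedForward's first Linear doubles the
# width (dim * mult * 2) and GEGLU brings it back to dim * mult.
print(GEGLU()(torch.randn(1, 4, 8)).shape)  # torch.Size([1, 4, 4])
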
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 16,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = sim.softmax(dim = -1)
        dropped_attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', dropped_attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out), attn
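
# --- Illustrative sketch (not part of the original commit) ---
# The Attention module keeps the token dimension: with dim=64 and 8 heads of
# dim_head=16, a (batch, tokens, dim) input comes back the same shape, plus a
# (batch, heads, tokens, tokens) post-softmax attention map.
_attn_out, _attn_map = Attention(dim=64, heads=8, dim_head=16)(torch.randn(2, 5, 64))
print(_attn_out.shape, _attn_map.shape)  # torch.Size([2, 5, 64]) torch.Size([2, 8, 5, 5])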

# transformer

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
        super().__init__()
        # torch.manual_seed(1)
        # self.embeds = nn.Embedding(num_tokens, dim)

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
            ]))

    def forward(self, x, return_attn = False):
        # x = self.embeds(x)

        post_softmax_attns = []

        for attn, ff in self.layers:
            attn_out, post_softmax_attn = attn(x)
            post_softmax_attns.append(post_softmax_attn)

            x = x + attn_out
            x = ff(x) + x

        if not return_attn:
            return x

        # return x, torch.stack(post_softmax_attns)
        return x

# mlp

class MLP(nn.Module):
    def __init__(self, dims, act = None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            is_last = ind >= (len(dims_pairs) - 1)
            linear = nn.Linear(dim_in, dim_out)
            layers.append(linear)

            if is_last:
                continue

            act = default(act, nn.ReLU())
            layers.append(act)

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)


class NumericalEmbedder(nn.Module):
    def __init__(self, dim, num_numerical_types):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
        self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))

    def forward(self, x):
        x = rearrange(x, 'b n -> b n 1')
        return x * self.weights + self.biases
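
# --- Illustrative sketch (not part of the original commit) ---
# NumericalEmbedder lifts each continuous scalar into a dim-sized vector with
# its own learned scale and bias, so 5 z-scored values become a 5-token sequence.
print(NumericalEmbedder(dim=64, num_numerical_types=5)(torch.randn(2, 5)).shape)  # torch.Size([2, 5, 64])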

class CategoricalEmbedder(nn.Module):
    def __init__(self, total_tokens, dim):
        super().__init__()
        self.embeds = nn.Embedding(total_tokens, dim)

    def forward(self, x):
        x_embed = self.embeds(x)
        return x_embed


class CatConLayer(nn.Module):

    def __init__(self, dim, heads):
        super().__init__()

        self.cat_con_multihead_attn = torch.nn.MultiheadAttention(dim, heads, dropout = 0.8)
        self.con_cat_multihead_attn = torch.nn.MultiheadAttention(dim, heads, dropout = 0.8)

    def forward(self, attn_cat, attn_con, need_weights=False):

        cat_Q, _ = self.cat_con_multihead_attn(attn_cat, attn_con, attn_con)
        con_Q, _ = self.con_cat_multihead_attn(attn_con, attn_cat, attn_cat)

        cat_Q = cat_Q.permute(1, 0, 2)
        con_Q = con_Q.permute(1, 0, 2)
        # output_concat = torch.cat([cat_Q, con_Q], dim=0)
        return cat_Q, con_Q
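
# --- Illustrative sketch (not part of the original commit) ---
# CatConLayer is the co-attention block: categorical tokens attend over the
# continuous tokens and vice versa. Inputs are (seq, batch, dim), as
# nn.MultiheadAttention expects here; outputs come back as (batch, seq, dim).
_cat_q, _con_q = CatConLayer(dim=64, heads=8)(torch.randn(77, 1, 64), torch.randn(5, 1, 64))
print(_cat_q.shape, _con_q.shape)  # torch.Size([1, 77, 64]) torch.Size([1, 5, 64])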

# main class

class Co_Transformer(nn.Module):
    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head = 16,
        dim_out = 1,
        mlp_hidden_mults = (2, 1),
        mlp_act = None,
        num_special_tokens = 0,
        continuous_mean_std = None,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
        assert len(categories) + num_continuous > 0, 'input shape must not be null'

        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        self.num_special_tokens = num_special_tokens  # 0
        total_tokens = self.num_unique_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position in the categories embedding table
        if self.num_unique_categories > 0:
            categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
            categories_offset = categories_offset.cumsum(dim = -1)[:-1]
            self.register_buffer('categories_offset', categories_offset)
            self.embeds = CategoricalEmbedder(total_tokens, dim)

        # continuous
        self.num_continuous = num_continuous

        if self.num_continuous > 0:
            self.numerical_embedder = NumericalEmbedder(dim, self.num_continuous)

        # transformer
        self.transformer_cat = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        self.transformer_con = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        # fusion-part
        self.catconlayer = CatConLayer(
            dim = dim,
            heads = heads
        )

        # mlp to logits
        input_size = dim * (self.num_categories + num_continuous)
        print(f'input size{input_size}')
        l = input_size // 5
        hidden_dimensions = list(map(lambda t: int(l * t), mlp_hidden_mults))
        all_dimensions = [input_size, *hidden_dimensions, dim_out]
        self.mlp = MLP(all_dimensions, act = mlp_act)
        # print(f" mlp {self.mlp}")

    def forward(self, x_categ, x_cont, return_attn = True):

        x_categ = self.embeds(x_categ)
        # x_cat_, attns_cat = self.transformer(x_categ, return_attn = True)
        x_cat_ = self.transformer_cat(x_categ, return_attn = True)
        permuted_x_cat_ = x_cat_.permute(1, 0, 2)

        x_numer = self.numerical_embedder(x_cont)
        # x_con_, attns_con = self.transformer(x_numer, return_attn = True)
        x_con_ = self.transformer_con(x_numer, return_attn = True)
        permuted_x_con_ = x_con_.permute(1, 0, 2)

        cat_Q, con_Q = self.catconlayer(permuted_x_cat_, permuted_x_con_)

        can_con_attn_output = torch.cat([cat_Q, con_Q], dim=1)
        # permuted_can_con_attn_output = can_con_attn_output.permute(1, 0, 2)

        can_con_attn_output_flattend = can_con_attn_output.flatten(1)

        logits = self.mlp(can_con_attn_output_flattend)

        return logits
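
# --- Illustrative sketch (not part of the original commit) ---
# How categories_offset is built: for toy category sizes (2, 3, 2) the cumulative
# sum of a zero-padded copy gives per-column offsets (0, 2, 5), so each column's
# local ids could be shifted into one shared embedding table of sum(categories) rows.
_toy_categories = (2, 3, 2)
_offsets = F.pad(torch.tensor(list(_toy_categories)), (1, 0), value=0).cumsum(dim=-1)[:-1]
print(_offsets)  # tensor([0, 2, 5])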


def build_network(depth, heads, dim):

    model = Co_Transformer(
        categories = label_in_each,                            # tuple containing the number of unique values within each category
        num_continuous = final_frame[continous_col].shape[1],  # number of continuous values
        dim = dim,                                             # dimension, paper set at 32
        dim_out = 2,                                           # binary prediction, but could be anything
        depth = depth,                                         # depth, paper recommended 6
        heads = heads,                                         # heads, paper recommends 8
        attn_dropout = 0.1,                                    # post-attention dropout
        ff_dropout = 0.1,                                      # feed forward dropout
        mlp_hidden_mults = (2, 1, 0.5, 0.25),                  # relative multiples of each hidden dimension of the last mlp to logits
        mlp_act = nn.ReLU(),                                   # activation for final mlp, defaults to relu, but could be anything else (selu etc)
        continuous_mean_std = torch.tensor(continous_df.agg(['mean','std']).transpose().values, dtype=torch.float32)  # (optional) - normalize the continuous values before layer norm
    )

    return model

model = build_network(8, 8, 64)

model.load_state_dict(torch.load('./co_attention_transformer_model_trained.pth', map_location=torch.device('cpu')))
# print("Model Loaded!")

sample_Df = pd.read_csv('./sample_data.csv')
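
# --- Illustrative sketch (not part of the original commit) ---
# Shape check for the loaded model: a batch of one all-zeros record
# (len(label_in_each) categorical ids plus the continuous columns) should map
# to two logits, one per class (non-AFib / AFib).
with torch.no_grad():
    _dummy_cat = torch.zeros(1, len(label_in_each), dtype=torch.long)
    _dummy_con = torch.zeros(1, final_frame[continous_col].shape[1], dtype=torch.float32)
    print(model(_dummy_cat, _dummy_con).shape)  # torch.Size([1, 2])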

# @title

def run_inference(num0,num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82):

    # Training-set means and standard deviations for the five continuous inputs
    mean1 = 62.63038219641993
    sd1 = 12.280987675098727
    mean2 = 29.238482057573915
    sd2 = 6.511823065330923
    mean3 = 1.2360909530720852
    sd3 = 1.0307534004581604
    mean4 = 137.98209966134493
    sd4 = 58.71032609323997
    mean5 = 193.39671020803095
    sd5 = 79.63715724430536

    # Re-apply the z-score normalization used during training
    num1 = (num1 - mean1) / sd1
    num2 = (num2 - mean2) / sd2
    num3 = (num3 - mean3) / sd3
    num4 = (num4 - mean4) / sd4
    num5 = (num5 - mean5) / sd5

    list__inputs = [num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82]
    print(list__inputs)

    # num0 is the hidden record label from the examples table; it selects the class to attribute
    if (num0 == 'First_non_AFib' or num0 == 'Second_non_AFib'):
        target_set = 0
    else:
        target_set = 1

    # Flatten the inputs: CheckboxGroup components return lists, sliders return scalars
    result_list = [item for sublist in list__inputs for item in (sublist if isinstance(sublist, list) else [sublist])]
    print(result_list)
    con = torch.tensor(result_list[0:5], dtype=torch.float32).reshape(1, -1)
    print(con, con.shape, con.device.type)
    cat = torch.tensor(result_list[5:82], dtype=torch.long).reshape(1, -1)
    print(cat, cat.shape, cat.device.type)
    model.eval()
    output_tup = model(cat, con)
    prob = F.softmax(output_tup, dim=-1)
    print(prob[0][0].detach(), prob[0][1].detach())

    # Categories for the bar plot
    categories = ['Non-AFIB', 'A-FIB']
    # Values for the bar plot
    values = [(prob[0][0]).detach().numpy(), (prob[0][1]).detach().numpy()]
    fig1 = plt.figure()
    plt.barh(categories, values, color=['green', 'red'])
    plt.xlabel('Values')
    plt.ylabel('Labels')

    # Compute embedding-level attributions with Integrated Gradients
    ig = IntegratedGradients(model)

    interpretable_embedding_cat = configure_interpretable_embedding_layer(model, 'embeds')
    interpretable_embedding_con = configure_interpretable_embedding_layer(model, 'numerical_embedder')

    emb_cat = interpretable_embedding_cat.indices_to_embeddings(cat)
    emb_con = interpretable_embedding_con.indices_to_embeddings(con)
    print(emb_cat.device.type)
    baseline_cat = torch.zeros_like(emb_cat)  # zero baseline for the categorical embeddings
    baseline_con = torch.zeros_like(emb_con)  # zero baseline for the continuous embeddings
    emb_cat.requires_grad_()
    emb_con.requires_grad_()
    attr, delta = ig.attribute((emb_cat, emb_con), baselines=(baseline_cat, baseline_con), target=target_set, return_convergence_delta=True, n_steps=50)
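    # Note (added explanation, not in the original commit): IntegratedGradients
    # cannot attribute through integer lookup tables, so the Captum interpretable
    # embedding wrappers above let the model accept pre-computed embedding
    # vectors; attributions are then taken on (emb_cat, emb_con) against
    # all-zero baselines for the class index in target_set, and `delta` reports
    # the convergence error of the path integral.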
    print("calculating attr")
    print(attr[0].shape)
    print(attr[1].shape)

    # Sum attributions over the embedding dimension to get one score per feature
    categ_attr = (attr[0]).sum(dim=-1).squeeze(0)
    cond_attr = (attr[1]).sum(dim=-1).squeeze(0)

    # Continuous features first, then categorical, to match testing_frame's column order
    concatenated_tensor = torch.cat([cond_attr, categ_attr], dim=0)
    print(concatenated_tensor.device.type)

    x_pos = np.arange(len(testing_frame.iloc[:, 0:-2].columns))

    fig2 = plt.figure(figsize=(30, 6))

    plt.bar(x_pos, concatenated_tensor.squeeze().cpu().detach().numpy(), align='center', color='red')
    plt.xticks(x_pos, testing_frame.iloc[:, 0:-2].columns, wrap=True)
    plt.xticks(rotation=45)
    plt.xlabel('Features')
    plt.title('Embedded layer attributes')

    # layer attributes
    attn_con_cat = []
    attn_con_cat.append(concatenated_tensor.detach().cpu())

    for i in range(len(model.transformer_con.layers)):
        con_module = [module for module in model.transformer_con.layers[i]]
        layeroutput_con = []
        for j in range(len(con_module)):
            lc_con = LayerConductance(model, con_module[j])
            layer_attributions_con = lc_con.attribute((emb_cat, emb_con), baselines=(baseline_cat, baseline_con), target=target_set, n_steps=50)
            if isinstance(layer_attributions_con, tuple):
                # the attention sub-layer returns (output, attention map): keep the output attribution
                layeroutput_con.append(layer_attributions_con[0])
            else:
                layeroutput_con.append(layer_attributions_con)
        attn_out_con = emb_con + layeroutput_con[0] + layeroutput_con[1]
        attn_out_con = attn_out_con.sum(dim=-1).squeeze(0)

        cat_module = [module for module in model.transformer_cat.layers[i]]
        layeroutput_cat = []
        for j in range(len(cat_module)):
            lc_cat = LayerConductance(model, cat_module[j])
            layer_attributions_cat = lc_cat.attribute((emb_cat, emb_con), baselines=(baseline_cat, baseline_con), target=target_set, n_steps=50)
            if isinstance(layer_attributions_cat, tuple):
                layeroutput_cat.append(layer_attributions_cat[0])
            else:
                layeroutput_cat.append(layer_attributions_cat)
        attn_out_cat = emb_cat + layeroutput_cat[0] + layeroutput_cat[1]
        attn_out_cat = attn_out_cat.sum(dim=-1).squeeze(0)

        attn_con_cat.append((torch.cat([attn_out_con, attn_out_cat])).detach().cpu())

    lc = LayerConductance(model, model.catconlayer)
    layer_attributions_start = lc.attribute((emb_cat, emb_con), baselines=(baseline_cat, baseline_con), target=target_set, n_steps=50)
    value_coattn_cat = layer_attributions_start[0].sum(dim=-1).squeeze(0)
    value_coattn_con = layer_attributions_start[1].sum(dim=-1).squeeze(0)
    # continuous first, then categorical, so this entry lines up with the plot's column labels
    attn_con_cat.append(torch.cat([value_coattn_con, value_coattn_cat]).detach().cpu())
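    # Note (added explanation, not in the original commit): the loop above runs
    # Captum's LayerConductance once per transformer block (8 depths, attention
    # and feed-forward sub-layers, in both streams), summing per-dimension
    # conductance into one score per input feature. Together with the embedding
    # row first and the co-attention row last, attn_con_cat holds 10 rows of 82
    # feature scores that feed the layer-wise heatmaps (figs 3-5) and the
    # co-attention bar plot (fig 6).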
    # fig 3: layer-wise attributions for the first block of features
    fig3, axes = plt.subplots(figsize=(15, 12), frameon=False)

    for spine in plt.gca().spines.values():
        spine.set_visible(False)

    axes.xaxis.set_major_locator(plt.NullLocator())
    axes.yaxis.set_major_locator(plt.NullLocator())

    for i, k in enumerate(testing_frame.iloc[:, :-60].columns):
        cmap = sns.color_palette("Reds")
        # cmap = sns.cm.rocket_r
        ax = fig3.add_subplot(5, 5, i + 1)

        xticklabels = [k]
        yticklabels = list(range(1, 9))
        ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:, i].reshape(-1, 1), ax=ax, xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap)
        plt.xlabel('features')
        plt.ylabel('Layers')
    plt.tight_layout()

    # fig 4: layer-wise attributions for the next block of features (from column 24)
    fig4, axes = plt.subplots(figsize=(15, 12), frameon=False)

    for spine in plt.gca().spines.values():
        spine.set_visible(False)

    axes.xaxis.set_major_locator(plt.NullLocator())
    axes.yaxis.set_major_locator(plt.NullLocator())

    for i, k in enumerate(testing_frame.iloc[:, 24:-30].columns):
        cmap = sns.color_palette("Reds")
        # cmap = sns.cm.rocket_r
        ax = fig4.add_subplot(6, 5, i + 1)

        xticklabels = [k]
        yticklabels = list(range(1, 9))
        # offset by 24 so the heatmap column matches the labelled feature
        ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:, i + 24].reshape(-1, 1), ax=ax, xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap)
        plt.xlabel('features')
        plt.ylabel('Layers')
    plt.tight_layout()

    # fig 5: layer-wise attributions for the remaining features (from column 54)
    fig5, axes = plt.subplots(figsize=(15, 12), frameon=False)

    for spine in plt.gca().spines.values():
        spine.set_visible(False)

    axes.xaxis.set_major_locator(plt.NullLocator())
    axes.yaxis.set_major_locator(plt.NullLocator())

    for i, k in enumerate(testing_frame.iloc[:, 54:-2].columns):
        cmap = sns.color_palette("Reds")
        # cmap = sns.cm.rocket_r
        ax = fig5.add_subplot(6, 5, i + 1)

        xticklabels = [k]
        yticklabels = list(range(1, 9))
        # offset by 54 so the heatmap column matches the labelled feature
        ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:, i + 54].reshape(-1, 1), ax=ax, xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap)
        plt.xlabel('features')
        plt.ylabel('Layers')
    plt.tight_layout()

    # fig 6: attribution of the co-attention layer per feature
    x_pos = np.arange(len(testing_frame.iloc[:, :-2].columns))

    fig6 = plt.figure(figsize=(30, 6))

    plt.bar(x_pos, torch.stack(attn_con_cat)[9], align='center', color='red')
    plt.xticks(x_pos, testing_frame.iloc[:, :-2].columns, wrap=True)
    plt.xticks(rotation=45)
    plt.xlabel('features')
    plt.title('Attribution of co-attention layer')

    # Restore the original embedding layers before returning
    remove_interpretable_embedding_layer(model, interpretable_embedding_con)
    remove_interpretable_embedding_layer(model, interpretable_embedding_cat)
    return fig1, fig2, fig3, fig4, fig5, fig6

demo = gr.Blocks()

with demo:

    gr.Markdown(
    """
    # Post-Operative Atrial Fibrillation Demo

    Select values for the following and click submit to see the results:
    """)
    num0=gr.Textbox(visible = False)
    num1=gr.Slider(0,100,label='Age',step=1)
    num2=gr.Slider(0,100,label='BMI')
    num3=gr.Slider(0,20,label='LastCreatinineLevel')
    num4=gr.Slider(0,1000,label='Cross_Clamp_Time',step=1)
    num5=gr.Slider(0,1000,label='Perfusion_Time',step=1)
    num6=gr.Slider(0,8,label='# of coronary vessels corrected',step=1)
    num7=gr.CheckboxGroup([0,1],label='Aortic stenosis')
    num8=gr.CheckboxGroup([0,1,2,3,4],label='Aortic_Insufficiency')
    num9=gr.CheckboxGroup([0,1],label='Aortic_Procedure')
    num10=gr.CheckboxGroup([0,1],label='Arrhythmia')
    num11=gr.CheckboxGroup([0,1],label='ArrhythmiaAfibAflutter')
    num12=gr.CheckboxGroup([0,1],label='CABG')
    num13=gr.CheckboxGroup([0,1],label='CHF')
    num14=gr.CheckboxGroup([0,1],label='CVA')
    num15=gr.CheckboxGroup([0,1],label='Cardiogenic_Shock')
    num16=gr.CheckboxGroup([0,1],label='Cerebrovascular_Disease')
    num17=gr.CheckboxGroup([0,1,2,3],label='ChronicLungDisease')
    num18=gr.CheckboxGroup([0,1],label='Diabetes')
    num19=gr.CheckboxGroup([0,1],label='Dialysis')
    num20=gr.Slider(0,7,label='DistAnasVein',step=1)
    num21=gr.Slider(0,6,label='DistAnastArt',step=1)
    num22=gr.CheckboxGroup([0,1],label='Family_History_CAD')
    num23=gr.CheckboxGroup([0,1],label='Gender')
    num24=gr.CheckboxGroup([0,1],label='Hypercholesterolemia')
    num25=gr.CheckboxGroup([0,1],label='Hypertension')
    num26=gr.CheckboxGroup([0,1],label='IABP')
    num27=gr.CheckboxGroup([0,1],label='Infectious_Endocarditis')
    num28=gr.Slider(0,6,label='IntraopBloodCryo',step=1)
    num29=gr.Slider(0,8,label='IntraopBloodFFP',step=1)
    num30=gr.CheckboxGroup([0,1,2],label='IntraopBloodFactorVII')
    num31=gr.Slider(0,8,label='IntraopBloodPlatelet',step=1)
    num32=gr.CheckboxGroup([0,1],label='IntraopBloodProducts')
    num33=gr.Slider(0,20,label='IntraopBloodRBC',step=1)
    num34=gr.CheckboxGroup([0,1],label='IntraopMedEpsilonAmi0Caproic')
    num35=gr.CheckboxGroup([0,1],label='IntraopMedTranexamicAcid')
    num36=gr.CheckboxGroup([0,1],label='Introp DEX or nDEX')
    num37=gr.CheckboxGroup([0,1],label='Left_Main_Disease')
    num38=gr.CheckboxGroup([0,1],label='MACE')
    num39=gr.CheckboxGroup([0,1],label='MedsG2b3aInhibitorMed')
    num40=gr.CheckboxGroup([0,1,2,3,4],label='Mitral_Insufficiency')
    num41=gr.CheckboxGroup([0,1],label='OthCard_AICD')
    num42=gr.CheckboxGroup([0,1],label='Oth_Heart_Block')
    num43=gr.CheckboxGroup([0,1],label='Other_Cardiac_Intervention')
    num44=gr.CheckboxGroup([0,1],label='Peri_Op_MI')
    num45=gr.CheckboxGroup([0,1],label='Peripheral_Vasc_Disease')
    num46=gr.CheckboxGroup([0,1],label='PreOpMed Antiplatelets')
    num47=gr.CheckboxGroup([0,1],label='PreOpMedACE_ARBInhibitors')
    num48=gr.CheckboxGroup([0,1],label='PreOpMedADPInhibitors5Days')
    num49=gr.CheckboxGroup([0,1],label='PreOpMedAntiarrhythmics')
    num50=gr.CheckboxGroup([0,1],label='PreOpMedAnticoagulants')
    num51=gr.CheckboxGroup([0,1],label='PreOpMedAspirin')
    num52=gr.CheckboxGroup([0,1],label='PreOpMedCoumadin')
    num53=gr.CheckboxGroup([0,1],label='PreOpMedGPIIbIIIaInhibitor')
    num54=gr.CheckboxGroup([0,1],label='PreOpMedINotropes')
    num55=gr.CheckboxGroup([0,1],label='PreOpMedLipidLowering')
    num56=gr.CheckboxGroup([0,1],label='PreOpMedNitratesIV')
    num57=gr.CheckboxGroup([0,1],label='PreOpMedSteroids')
    num58=gr.CheckboxGroup([0,1],label='PreOp_BetaBlockers')
    num59=gr.CheckboxGroup([0,1],label='PreOp_Ca_Antagonists')
    num60=gr.CheckboxGroup([0,1],label='PreOp_Digitalis')
    num61=gr.CheckboxGroup([0,1],label='PreOp_Diuretics')
    num62=gr.CheckboxGroup([0,1],label='PrevArrhythmiaSurgery')
    num63=gr.CheckboxGroup([0,1],label='PrevOthCardPCI')
    num64=gr.CheckboxGroup([0,1],label='Previous_CABG')
    num65=gr.CheckboxGroup([0,1],label='Previous_CV_Intervention')
    num66=gr.CheckboxGroup([0,1],label='Previous_Valve')
    num67=gr.CheckboxGroup([0,1],label='PriorHeartFailure')
    num68=gr.CheckboxGroup([0,1],label='Pulmonic_Procedure')
    num69=gr.CheckboxGroup([0,1],label='Pulmonic_Stenosis')
    num70=gr.CheckboxGroup([0,1],label='STS_History.Renal_Failure')
    num71=gr.CheckboxGroup([0,1],label='Smoking')
    num72=gr.CheckboxGroup([0,1],label='Status')
    num73=gr.CheckboxGroup([0,1,2,3,4],label='Tricuspid_Insufficiency')
    num74=gr.CheckboxGroup([0,1],label='Tricuspid_Procedure')
    num75=gr.CheckboxGroup([0,1],label='VSMitral')
    num76=gr.CheckboxGroup([0,1],label='Valve')
    num77=gr.CheckboxGroup([0,1],label='ValveDisAortic')
    num78=gr.CheckboxGroup([0,1],label='ValveDisMitral')
    num79=gr.CheckboxGroup([0,1],label='ValveDisPulmonic')
    num80=gr.CheckboxGroup([0,1],label='ValveDisTricuspid')
    num81=gr.CheckboxGroup([0,1],label='_MI')
    num82=gr.CheckboxGroup([0,1],label='mitral stenosis')
    num83=gr.CheckboxGroup([0,1],label='Oth_Afib_0',visible= False)
    num84=gr.CheckboxGroup([0,1],label='Oth_Afib_1',visible= False)


    b1 = gr.Button("Submit")
    example = sample_Df.values.tolist()
    gr.Examples(example,inputs=[num0,num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82])

    output = [gr.Plot(),gr.Plot(),gr.Plot(),gr.Plot(),gr.Plot(),gr.Plot()]

    b1.click(run_inference, inputs=[num0,num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82], outputs=output)


demo.launch(share=True, auth=('poaf-users','dshrebs__324'))
co_attention_transformer_model_trained.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99b2456a353c08c7fcab9a4f6f764c79c9995e0c703954f2a78174e81e7a7fdc
size 61186764
sample_data.csv
ADDED
@@ -0,0 +1,5 @@
Records,Age_,BMI_,LastCreatineLevel_,CrossClampTime_,PerfusionTime_,# of coronary vessels corrected,Aortic stenosis,Aortic_Insufficiency,Aortic_Procedure,Arrhythmia,ArrhythmiaAfibAflutter,CABG,CHF,CVA,Cardiogenic_Shock,Cerebrovascular_Disease,ChronicLungDisease,Diabetes,Dialysis,DistAnasVein,DistAnastArt,Family_History_CAD,Gender,Hypercholesterolemia,Hypertension,IABP,Infectious_Endocarditis,IntraopBloodCryo,IntraopBloodFFP,IntraopBloodFactorVII,IntraopBloodPlatelet,IntraopBloodProducts,IntraopBloodRBC,IntraopMedEpsilonAmi0Caproic,IntraopMedTranexamicAcid,Introp DEX or nDEX,Left_Main_Disease,MACE,MedsG2b3aInhibitorMed,Mitral_Insufficiency,OthCard_AICD,Oth_Heart_Block,Other_Cardiac_Intervention,Peri_Op_MI,Peripheral_Vasc_Disease,PreOpMed Antiplatelets,PreOpMedACE_ARBInhibitors,PreOpMedADPInhibitors5Days,PreOpMedAntiarrhythmics,PreOpMedAnticoagulants,PreOpMedAspirin,PreOpMedCoumadin,PreOpMedGPIIbIIIaInhibitor,PreOpMedINotropes,PreOpMedLipidLowering,PreOpMedNitratesIV,PreOpMedSteroids,PreOp_BetaBlockers,PreOp_Ca_Antagonists,PreOp_Digitalis,PreOp_Diuretics,PrevArrhythmiaSurgery,PrevOthCardPCI,Previous_CABG,Previous_CV_Intervention,Previous_Valve,PriorHeartFailure,Pulmonic_Procedure,Pulmonic_Stenosis,STS_History.Renal_Failure,Smoking,Status,Tricuspid_Insufficiency,Tricuspid_Procedure,VSMitral,Valve,ValveDisAortic,ValveDisMitral,ValveDisPulmonic,ValveDisTricuspid,_MI,mitral stenosis
First_non_AFib,73,28.73,1.3,164,199,2.0,1.0,2.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
Second_non_AFib,84,24.63,1.5,220,278,5.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,2.0,1.0,7.0,1.0,0.0,1.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,2.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
First_Afib,63,41.03,0.9,119,152,0.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
Second_Afib,52,26.74,4.5,36,59,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0