# pip install kaleido
# pip install gradio
import gradio as gr
# import os
# import random
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm.auto import tqdm
# !pip install einops

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# pip install captum
import seaborn as sns
from captum.attr import LayerConductance
from captum.attr import IntegratedGradients
from captum.attr import configure_interpretable_embedding_layer
import matplotlib.pyplot as plt
from captum.attr import remove_interpretable_embedding_layer
import torch.nn.functional as F
import pandas as pd
import numpy as np
import tensorflow as tf
Raw_data = pd.read_excel('./STS Data with Up to dated AF 09-18-2022 (1).xlsx', usecols=lambda x: 'Unnamed' not in x)
pd.set_option('display.max_columns', None)
Raw_data['Aortic_Insufficiency'] = Raw_data['Aortic_Insufficiency'].astype(np.int64)
Postop_columns = ['PostOpMedCoumadin',
                  'PostOpMedLipidLowering',
                  'PostOpMedAspirin',
                  'PostOpMedADPInhibitors',
                  'PostOpMedACE_ARBInhibitors',
                  'STS_PostOp.Renal_Failure',
                  'Oth_Cardiac_Arrest',
                  'Complications_Any',
                  'Neuro_Stroke_Permanent',
                  'Neuro_Continuous_Coma',
                  'Neuro_Delirium',
                  'PostOpSepsis',
                  'Reop_Bleeding',
                  'Oth_OtherComplication',
                  'PostOpNeuroStrokeTransientTIA',
                  'Infect_Sternum_Deep',
                  'Infect_Thoracotomy',
                  'Pulm_Ventilator_Prolonged',
                  'Pulm_Pneumonia',
                  'Oth_Tamponade',
                  'Oth_Anticoagulant',
                  'Oth_MultiSystem_Failure',
                  'Oth_GI',
                  'Vasc_Ao_Dissection',
                  'Infect_Leg',
                  'PostOpInfectionArm',
                  'OthCard_Pacemaker',
                  'PostOpCreatinineLevel',
                  'Renal_Dialysis_Required',
                  'PostOpBloodRBCUnits',
                  'PostOpBloodFFPUnits',
                  'PostOpBloodCryoUnits',
                  'PostOpBloodPlateletUnits',
                  'ExtubatedI0R',
                  'InitHrsVentilated',
                  'ReIntubated',
                  'No Add Hrs Ventilator',
                  'PostOpVentHoursTotal',
                  'InitHrsICU',
                  'ReadmitICU',
                  'AddICUHours',
                  'TotHrsICU',
                  'DCMed_AntiPlate',
                  'Readmit_LessThan30Days',
                  'Blood_Bank_Products_Used',
                  'PostOpMedAntiarrhythmics',
                  'PostOpMedBetaBlockers']
# Dropping post-operative (outcome) columns
preop_oper_data = Raw_data.drop(columns=Postop_columns)
preop_oper_data = preop_oper_data[preop_oper_data.Oth_Afib != -1]
preop_oper_data = preop_oper_data.drop(['Date_of_Birth', 'Surgery_Date', 'Discharge_Date', 'Death_Date', 'Death-Surery(y)',
                                        'Mortality30d', 'Mortality1y', 'Mortality2y', 'Mortality3y', 'Mortality4y', 'Mortality5y'], axis=1)
preop_oper_data = preop_oper_data.drop(['EF', 'Height(cm)', 'Weight(kg)', 'CVA_When', 'Category', 'Race'], axis=1)
preop_oper_data = preop_oper_data.drop(['IABP_Indication', 'IABP_When'], axis=1)

# Separating continuous and categorical data
continous_df = preop_oper_data[['Oth_Afib', 'Age', 'BMI', 'LastCreatinineLevel', 'Cross_Clamp_Time', 'Perfusion_Time']]
categorical_col = list(set(preop_oper_data.columns) - set(continous_df.columns))
categorical_col.sort()

# Setting the target
df_AF = preop_oper_data['Oth_Afib']
continous_col = continous_df.drop('Oth_Afib', axis=1)
preop_oper_data = preop_oper_data.drop('Oth_Afib', axis=1)

# Creating the categorical dataframe
preop_oper_data = preop_oper_data[categorical_col]

# Label encoding
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder

def encode_text_index(df, name):
    le = preprocessing.LabelEncoder()
    df[name] = le.fit_transform(df[name])
    return le.classes_

encode_text_index(preop_oper_data, 'Introp DEX or nDEX')
encode_text_index(preop_oper_data, 'Status')
encode_text_index(preop_oper_data, 'Gender')

# Counting the number of distinct labels in each categorical column
label_in_each = tuple(len(preop_oper_data[col].unique()) for col in preop_oper_data.columns)
categorical_col_with_ordinal = preop_oper_data.columns
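# Illustrative note: label_in_each holds the vocabulary size of each encoded
# categorical column (a binary column contributes 2). The sum of these counts
# later sizes the shared categorical embedding table, e.g.:
# print(sum(label_in_each))  # total embedding rows when num_special_tokens == 0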
# Making the final dataframe
final_frame = pd.concat([continous_col, preop_oper_data], axis=1)
final_frame = pd.concat([final_frame, df_AF], axis=1)

# Encode a numeric column as z-scores
def encode_numeric_zscore(df, name, mean=None, sd=None):
    if mean is None:
        mean = df[name].mean()
        print(f'mean:{mean}')
    if sd is None:
        sd = df[name].std()
        print(f'sd:{sd}')
    df[name] = (df[name] - mean) / sd

for col in continous_col.columns:
    encode_numeric_zscore(final_frame, col)
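# Worked example of the z-score transform above (using the Age mean/sd that are
# hard-coded later in run_inference): a patient aged 75 is encoded as
# (75 - 62.63) / 12.28 ≈ 1.01, about one standard deviation above the mean.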
# Train/test split (stratified on the target)
from sklearn.model_selection import train_test_split
x_train, x_temp, y_train, y_temp = train_test_split(final_frame.iloc[:, :-1], final_frame.iloc[:, -1],
                                                    test_size=0.25, random_state=42, stratify=final_frame.iloc[:, -1])
print(x_train.shape)
print(y_train.shape)
print(x_temp.shape)
print(y_temp.shape)

# Duplicating class-1 records to balance the training set
# (two extra copies are appended, so the minority class is tripled)
training_frame = pd.concat([x_train, y_train], axis=1)
training_frame_ana = training_frame
class_1_rows = training_frame[training_frame['Oth_Afib'] == 1]
duplicated_class_1 = class_1_rows.copy()
training_frame = pd.concat([training_frame, duplicated_class_1, duplicated_class_1], ignore_index=True)
print(training_frame['Oth_Afib'].value_counts())
# Creating the testing dataframe
testing_frame = pd.concat([x_temp, y_temp], axis=1)
continous_df = continous_df.drop('Oth_Afib', axis=1)
continous_col = continous_df.columns

training_frame_without_label = training_frame.iloc[:, :-1]
testing_frame_without_label = testing_frame.iloc[:, :-1]

# One-hot encode the target so the frames end with Oth_Afib_0 / Oth_Afib_1
training_frame = pd.concat([training_frame_without_label,
                            pd.get_dummies(training_frame.iloc[:, -1], prefix='Oth_Afib', dtype=np.int64)], axis=1)
testing_frame = pd.concat([testing_frame_without_label,
                           pd.get_dummies(testing_frame.iloc[:, -1], prefix='Oth_Afib', dtype=np.int64)], axis=1)
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange

# helpers
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# classes
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# attention
class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)
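# Sanity check for GEGLU (illustrative, commented out): the last dimension is
# split into a value half and a gate half, so width 2*d maps to width d.
# assert GEGLU()(torch.randn(1, 4, 8)).shape == (1, 4, 4)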
class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x, **kwargs):
        return self.net(x)

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads=8,
        dim_head=16,
        dropout=0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim=-1)
        dropped_attn = self.dropout(attn)
        out = einsum('b h i j, b h j d -> b h i d', dropped_attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h=h)
        return self.to_out(out), attn
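# Shape sketch for Attention (illustrative, commented out): the module returns
# the projected output plus the post-softmax attention map.
# out, attn = Attention(dim=32, heads=8, dim_head=16)(torch.randn(2, 10, 32))
# out: (2, 10, 32); attn: (2, 8, 10, 10)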
# transformer
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
        super().__init__()
        # torch.manual_seed(1)
        # self.embeds = nn.Embedding(num_tokens, dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=attn_dropout)),
                PreNorm(dim, FeedForward(dim, dropout=ff_dropout)),
            ]))

    def forward(self, x, return_attn=False):
        # x = self.embeds(x)
        post_softmax_attns = []
        for attn, ff in self.layers:
            attn_out, post_softmax_attn = attn(x)
            post_softmax_attns.append(post_softmax_attn)
            x = x + attn_out
            x = ff(x) + x
        if not return_attn:
            return x
        # return x, torch.stack(post_softmax_attns)
        return x
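# Note (illustrative): each block applies pre-norm attention and a GEGLU
# feed-forward, both with residual connections, so token shape is preserved:
# Transformer(dim=64, depth=2, heads=8, dim_head=16, attn_dropout=0., ff_dropout=0.)
# maps (batch, tokens, 64) -> (batch, tokens, 64).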
# mlp
class MLP(nn.Module):
    def __init__(self, dims, act=None):
        super().__init__()
        dims_pairs = list(zip(dims[:-1], dims[1:]))
        layers = []
        for ind, (dim_in, dim_out) in enumerate(dims_pairs):
            is_last = ind >= (len(dims_pairs) - 1)
            linear = nn.Linear(dim_in, dim_out)
            layers.append(linear)
            if is_last:
                continue
            act = default(act, nn.ReLU())
            layers.append(act)
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

class NumericalEmbedder(nn.Module):
    def __init__(self, dim, num_numerical_types):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
        self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))

    def forward(self, x):
        x = rearrange(x, 'b n -> b n 1')
        return x * self.weights + self.biases
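# NumericalEmbedder sketch (illustrative, commented out): each continuous
# feature gets its own learned affine map from a scalar to a dim-vector.
# NumericalEmbedder(dim=64, num_numerical_types=5)(torch.randn(2, 5))  # -> (2, 5, 64)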
class CategoricalEmbedder(nn.Module):
    def __init__(self, total_tokens, dim):
        super().__init__()
        self.embeds = nn.Embedding(total_tokens, dim)

    def forward(self, x):
        x_embed = self.embeds(x)
        return x_embed

class CatConLayer(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.cat_con_multihead_attn = torch.nn.MultiheadAttention(dim, heads, dropout=0.8)
        self.con_cat_multihead_attn = torch.nn.MultiheadAttention(dim, heads, dropout=0.8)

    def forward(self, attn_cat, attn_con, need_weights=False):
        # Cross-attention in both directions: categorical tokens query the
        # continuous tokens, and vice versa
        cat_Q, _ = self.cat_con_multihead_attn(attn_cat, attn_con, attn_con)
        con_Q, _ = self.con_cat_multihead_attn(attn_con, attn_cat, attn_cat)
        cat_Q = cat_Q.permute(1, 0, 2)
        con_Q = con_Q.permute(1, 0, 2)
        # output_concat = torch.cat([cat_Q, con_Q], dim=0)
        return cat_Q, con_Q
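# Shape sketch for CatConLayer (illustrative, commented out). nn.MultiheadAttention
# defaults to batch_first=False, so inputs are (seq_len, batch, dim) and the
# outputs are permuted back to batch-first:
# cat_Q, con_Q = CatConLayer(dim=64, heads=8)(torch.randn(77, 2, 64), torch.randn(5, 2, 64))
# cat_Q: (2, 77, 64); con_Q: (2, 5, 64)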
# main class
class Co_Transformer(nn.Module):
    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head=16,
        dim_out=1,
        mlp_hidden_mults=(2, 1),
        mlp_act=None,
        num_special_tokens=0,
        continuous_mean_std=None,
        attn_dropout=0.,
        ff_dropout=0.
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
        assert len(categories) + num_continuous > 0, 'input shape must not be null'
        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)
        self.num_special_tokens = num_special_tokens  # 0
        total_tokens = self.num_unique_categories + num_special_tokens
        # for automatically offsetting unique category ids to the correct position in the categories embedding table
        if self.num_unique_categories > 0:
            categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value=num_special_tokens)
            categories_offset = categories_offset.cumsum(dim=-1)[:-1]
            self.register_buffer('categories_offset', categories_offset)
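            # Offset sketch (illustrative): for categories=(2, 3, 2) and
            # num_special_tokens=0 the buffer is [0, 2, 5], mapping feature 0
            # to rows 0-1, feature 1 to rows 2-4, feature 2 to rows 5-6. Note
            # that in this variant the buffer is registered but never added to
            # the inputs in forward(); raw category ids index the table directly.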
            self.embeds = CategoricalEmbedder(total_tokens, dim)

        # continuous
        self.num_continuous = num_continuous
        if self.num_continuous > 0:
            self.numerical_embedder = NumericalEmbedder(dim, self.num_continuous)

        # transformers (one per modality)
        self.transformer_cat = Transformer(
            dim=dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout
        )
        self.transformer_con = Transformer(
            dim=dim,
            depth=depth,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout
        )

        # fusion part
        self.catconlayer = CatConLayer(
            dim=dim,
            heads=heads
        )

        # mlp to logits
        input_size = dim * (self.num_categories + num_continuous)
        print(f'input size: {input_size}')
        l = input_size // 5
        hidden_dimensions = list(map(lambda t: int(l * t), mlp_hidden_mults))
        all_dimensions = [input_size, *hidden_dimensions, dim_out]
        self.mlp = MLP(all_dimensions, act=mlp_act)
        # print(f" mlp {self.mlp}")

    def forward(self, x_categ, x_cont, return_attn=True):
        x_categ = self.embeds(x_categ)
        # x_cat_, attns_cat = self.transformer(x_categ, return_attn = True)
        x_cat_ = self.transformer_cat(x_categ, return_attn=True)
        permuted_x_cat_ = x_cat_.permute(1, 0, 2)
        x_numer = self.numerical_embedder(x_cont)
        # x_con_, attns_con = self.transformer(x_numer, return_attn = True)
        x_con_ = self.transformer_con(x_numer, return_attn=True)
        permuted_x_con_ = x_con_.permute(1, 0, 2)
        cat_Q, con_Q = self.catconlayer(permuted_x_cat_, permuted_x_con_)
        can_con_attn_output = torch.cat([cat_Q, con_Q], dim=1)
        # permuted_can_con_attn_output = can_con_attn_output.permute(1, 0, 2)
        can_con_attn_output_flattend = can_con_attn_output.flatten(1)
        logits = self.mlp(can_con_attn_output_flattend)
        return logits

def build_network(depth, heads, dim):
    model = Co_Transformer(
        categories=label_in_each,  # tuple containing the number of unique values within each category
        num_continuous=final_frame[continous_col].shape[1],  # number of continuous values
        dim=dim,                   # dimension, paper set at 32
        dim_out=2,                 # binary prediction, but could be anything
        depth=depth,               # depth, paper recommended 6
        heads=heads,               # heads, paper recommends 8
        attn_dropout=0.1,          # post-attention dropout
        ff_dropout=0.1,            # feed-forward dropout
        mlp_hidden_mults=(2, 1, 0.5, 0.25),  # relative multiples of each hidden dimension of the final mlp to logits
        mlp_act=nn.ReLU(),         # activation for final mlp, defaults to relu, but could be anything else (selu etc)
        continuous_mean_std=torch.tensor(continous_df.agg(['mean', 'std']).transpose().values, dtype=torch.float32)  # (optional) normalize the continuous values before layer norm
    )
    return model
model = build_network(8, 8, 64)
model.load_state_dict(torch.load('./co_attention_transformer_model_trained.pth', map_location=torch.device('cpu')))
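# Optional smoke test (illustrative, commented out; assumes the 77 categorical /
# 5 continuous feature split used above):
# model.eval()
# with torch.no_grad():
#     print(model(torch.zeros(1, 77, dtype=torch.long), torch.zeros(1, 5)).shape)  # torch.Size([1, 2])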
| # print("Model Loaded!") | |
| sample_Df=pd.read_csv('./sample_data.csv') | |
def run_inference(num0,num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82):
    # Hard-coded training-set means/sds for the five continuous features,
    # in column order: Age, BMI, LastCreatinineLevel, Cross_Clamp_Time, Perfusion_Time
    mean1 = 62.63038219641993
    sd1 = 12.280987675098727
    mean2 = 29.238482057573915
    sd2 = 6.511823065330923
    mean3 = 1.2360909530720852
    sd3 = 1.0307534004581604
    mean4 = 137.98209966134493
    sd4 = 58.71032609323997
    mean5 = 193.39671020803095
    sd5 = 79.63715724430536
    num1 = (num1 - mean1) / sd1
    num2 = (num2 - mean2) / sd2
    num3 = (num3 - mean3) / sd3
    num4 = (num4 - mean4) / sd4
    num5 = (num5 - mean5) / sd5
    list__inputs = [num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82]
    print(list__inputs)
    # num0 is the hidden textbox carrying the selected example's label;
    # it picks the class used as the attribution target
    if num0 == 'First_non_AFib' or num0 == 'Second_non_AFib':
        target_set = 0
    else:
        target_set = 1
    # Flatten nested lists (CheckboxGroup values arrive as lists)
    result_list = [item for sublist in list__inputs for item in (sublist if isinstance(sublist, list) else [sublist])]
    print(result_list)
    con = torch.tensor(result_list[0:5], dtype=torch.float32).reshape(1, -1)
    print(con, con.shape, con.device.type)
    cat = torch.tensor(result_list[5:82], dtype=torch.long).reshape(1, -1)
    print(cat, cat.shape, cat.device.type)
    model.eval()
    output_tup = model(cat, con)
    prob = F.softmax(output_tup, dim=-1)
    print(prob[0][0].detach(), prob[0][1].detach())
    # Categories for the bar plot
    categories = ['Non-AFIB', 'A-FIB']
    # Values for the bar plot
    values = [(prob[0][0]).detach().numpy(), (prob[0][1]).detach().numpy()]
    fig1 = plt.figure()
    plt.barh(categories, values, color=['green', 'red'])
    plt.xlabel('Values')
    plt.ylabel('Labels')
    # Embedding-level attributions via Integrated Gradients
    ig = IntegratedGradients(model)
    interpretable_embedding_cat = configure_interpretable_embedding_layer(model, 'embeds')
    interpretable_embedding_con = configure_interpretable_embedding_layer(model, 'numerical_embedder')
    emb_cat = interpretable_embedding_cat.indices_to_embeddings(cat)
    emb_con = interpretable_embedding_con.indices_to_embeddings(con)
    print(emb_cat.device.type)
    baseline_cat = torch.zeros_like(emb_cat)  # zero embeddings as the baseline
    baseline_con = torch.zeros_like(emb_con)
    emb_cat.requires_grad_()
    emb_con.requires_grad_()
    attr, delta = ig.attribute((emb_cat, emb_con), baselines=(baseline_cat, baseline_con), target=target_set, return_convergence_delta=True, n_steps=50)
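    # The convergence delta reports how far the summed attributions are from
    # f(input) - f(baseline); values near zero mean the 50-step approximation
    # is faithful. A quick check (illustrative, commented out):
    # print(f'IG completeness gap: {delta.abs().max().item():.4e}')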
| print("calculating attr") | |
| print(attr[0].shape) | |
| print(attr[1].shape) | |
| categ_attr = (attr[0]).sum(dim=-1).squeeze(0) | |
| cond_attr = (attr[1]).sum(dim=-1).squeeze(0) | |
| concatenated_tensor = torch.cat([cond_attr, categ_attr],dim=0) | |
| print(concatenated_tensor.device.type) | |
| x_pos = (np.arange(len(testing_frame.iloc[:,0:-2].columns))) | |
| fig2 = plt.figure(figsize=(30,6)) | |
| plt.bar(x_pos,concatenated_tensor.squeeze().cpu().detach().numpy(), align='center', color = 'red') | |
| plt.xticks(x_pos,testing_frame.iloc[:,0:-2].columns, wrap=True) | |
| plt.xticks(rotation=45) | |
| plt.xlabel('Features') | |
| plt.title('Embedded layer attributes') | |
| # layer attributes | |
| attn_con_cat = [] | |
| attn_con_cat.append(concatenated_tensor.detach().cpu()) | |
| for i in range(len(model.transformer_con.layers)): | |
| con_module = [module for module in model.transformer_con.layers[i]] | |
| layeroutput_con = [] | |
| for j in range(len(con_module)): | |
| lc_con = LayerConductance(model, con_module[j]) | |
| layer_attributions_con= lc_con.attribute((emb_cat,emb_con), baselines=(baseline_cat,baseline_con),target=target_set,n_steps=50) | |
| if(type(layer_attributions_con) == "tuple"): | |
| layeroutput_con.append(layer_attributions_con[0]) | |
| else: | |
| layeroutput_con.append(layer_attributions_con) | |
| attn_out_con = emb_con + layeroutput_con[0][0] + layeroutput_con[1] | |
| attn_out_con=attn_out_con.sum(dim=-1).squeeze(0) | |
| cat_module = [module for module in model.transformer_cat.layers[i]] | |
| layeroutput_cat = [] | |
| for j in range(len(cat_module)): | |
| lc_cat = LayerConductance(model, cat_module[j]) | |
| layer_attributions_cat= lc_cat.attribute((emb_cat,emb_con), baselines=(baseline_cat,baseline_con),target=target_set,n_steps=50) | |
| if(type(layer_attributions_cat) == "tuple"): | |
| layeroutput_cat.append(layer_attributions_cat[0]) | |
| else: | |
| layeroutput_cat.append(layer_attributions_cat) | |
| attn_out_cat = emb_cat + layeroutput_cat[0][0] + layeroutput_cat[1] | |
| attn_out_cat=attn_out_cat.sum(dim=-1).squeeze(0) | |
| attn_con_cat.append((torch.cat([attn_out_con,attn_out_cat])).detach().cpu()) | |
| lc = LayerConductance(model, model.catconlayer) | |
| layer_attributions_start = lc.attribute((emb_cat,emb_con), baselines=(baseline_cat,baseline_con),target=target_set,n_steps=50) | |
| value_coattn_cat=layer_attributions_start[0].sum(dim=-1).squeeze(0) | |
| value_coattn_con=layer_attributions_start[1].sum(dim=-1).squeeze(0) | |
| attn_con_cat.append(torch.cat([value_coattn_cat,value_coattn_con]).detach().cpu()) | |
| # fig 3 | |
| fig3, axes = plt.subplots(figsize=(15, 12),frameon=False) | |
| for spine in plt.gca().spines.values(): | |
| spine.set_visible(False) | |
| axes.xaxis.set_major_locator(plt.NullLocator()) | |
| axes.yaxis.set_major_locator(plt.NullLocator()) | |
| for i,k in enumerate(testing_frame.iloc[:,:-60].columns): | |
| cmap = sns.color_palette("Reds") | |
| # cmap = sns.cm.rocket_r | |
| ax = fig3.add_subplot(5,5, i+1) | |
| xticklabels=[k] | |
| yticklabels=list(range(1,9)) | |
| ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:,i].reshape(-1,1),ax=ax,xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap) | |
| plt.xlabel('features') | |
| plt.ylabel('Layers') | |
| plt.tight_layout() | |
| # fig 4 | |
| fig4, axes = plt.subplots(figsize=(15, 12),frameon=False) | |
| for spine in plt.gca().spines.values(): | |
| spine.set_visible(False) | |
| axes.xaxis.set_major_locator(plt.NullLocator()) | |
| axes.yaxis.set_major_locator(plt.NullLocator()) | |
| for i,k in enumerate(testing_frame.iloc[:,24:-30].columns): | |
| cmap = sns.color_palette("Reds") | |
| # cmap = sns.cm.rocket_r | |
| ax = fig4.add_subplot(6,5, i+1) | |
| xticklabels=[k] | |
| yticklabels=list(range(1,9)) | |
| ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:,i].reshape(-1,1),ax=ax,xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap) | |
| plt.xlabel('features') | |
| plt.ylabel('Layers') | |
| plt.tight_layout() | |
| # fig 5 | |
| fig5, axes = plt.subplots(figsize=(15, 12),frameon=False) | |
| for spine in plt.gca().spines.values(): | |
| spine.set_visible(False) | |
| axes.xaxis.set_major_locator(plt.NullLocator()) | |
| axes.yaxis.set_major_locator(plt.NullLocator()) | |
| for i,k in enumerate(testing_frame.iloc[:,54:-2].columns): | |
| cmap = sns.color_palette("Reds") | |
| # cmap = sns.cm.rocket_r | |
| ax = fig5.add_subplot(6,5, i+1) | |
| xticklabels=[k] | |
| yticklabels=list(range(1,9)) | |
| ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:,i].reshape(-1,1),ax=ax,xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap) | |
| plt.xlabel('features') | |
| plt.ylabel('Layers') | |
| plt.tight_layout() | |
| #fig6 | |
| x_pos = (np.arange(len(testing_frame.iloc[:,:-2].columns))) | |
| fig6 = plt.figure(figsize=(30,6)) | |
| plt.bar(x_pos, torch.stack(attn_con_cat)[9], align='center', color = 'red') | |
| plt.xticks(x_pos,testing_frame.iloc[:,:-2].columns, wrap=True) | |
| plt.xticks(rotation=45) | |
| plt.xlabel('features') | |
| plt.title('Attribution of co-attention layer') | |
| remove_interpretable_embedding_layer(model, interpretable_embedding_con) | |
| remove_interpretable_embedding_layer(model, interpretable_embedding_cat) | |
| return fig1 , fig2 , fig3 , fig4, fig5, fig6 | |
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
# Post-Operative Atrial Fibrillation Demo
Select values for the following and click submit to see the results:
""")
    num0 = gr.Textbox(visible=False)
    num1 = gr.Slider(0, 100, label='Age', step=1)
    num2 = gr.Slider(0, 100, label='BMI')
    num3 = gr.Slider(0, 20, label='LastCreatinineLevel')
    num4 = gr.Slider(0, 1000, label='Cross_Clamp_Time', step=1)
    num5 = gr.Slider(0, 1000, label='Perfusion_Time', step=1)
    num6 = gr.Slider(0, 8, label='# of coronary vessels corrected', step=1)
    num7 = gr.CheckboxGroup([0, 1], label='Aortic stenosis')
    num8 = gr.CheckboxGroup([0, 1, 2, 3, 4], label='Aortic_Insufficiency')
    num9 = gr.CheckboxGroup([0, 1], label='Aortic_Procedure')
    num10 = gr.CheckboxGroup([0, 1], label='Arrhythmia')
    num11 = gr.CheckboxGroup([0, 1], label='ArrhythmiaAfibAflutter')
    num12 = gr.CheckboxGroup([0, 1], label='CABG')
    num13 = gr.CheckboxGroup([0, 1], label='CHF')
    num14 = gr.CheckboxGroup([0, 1], label='CVA')
    num15 = gr.CheckboxGroup([0, 1], label='Cardiogenic_Shock')
    num16 = gr.CheckboxGroup([0, 1], label='Cerebrovascular_Disease')
    num17 = gr.CheckboxGroup([0, 1, 2, 3], label='ChronicLungDisease')
    num18 = gr.CheckboxGroup([0, 1], label='Diabetes')
    num19 = gr.CheckboxGroup([0, 1], label='Dialysis')
    num20 = gr.Slider(0, 7, label='DistAnasVein', step=1)
    num21 = gr.Slider(0, 6, label='DistAnastArt', step=1)
    num22 = gr.CheckboxGroup([0, 1], label='Family_History_CAD')
    num23 = gr.CheckboxGroup([0, 1], label='Gender')
    num24 = gr.CheckboxGroup([0, 1], label='Hypercholesterolemia')
    num25 = gr.CheckboxGroup([0, 1], label='Hypertension')
    num26 = gr.CheckboxGroup([0, 1], label='IABP')
    num27 = gr.CheckboxGroup([0, 1], label='Infectious_Endocarditis')
    num28 = gr.Slider(0, 6, label='IntraopBloodCryo', step=1)
    num29 = gr.Slider(0, 8, label='IntraopBloodFFP', step=1)
    num30 = gr.CheckboxGroup([0, 1, 2], label='IntraopBloodFactorVII')
    num31 = gr.Slider(0, 8, label='IntraopBloodPlatelet', step=1)
    num32 = gr.CheckboxGroup([0, 1], label='IntraopBloodProducts')
    num33 = gr.Slider(0, 20, label='IntraopBloodRBC', step=1)
    num34 = gr.CheckboxGroup([0, 1], label='IntraopMedEpsilonAmi0Caproic')
    num35 = gr.CheckboxGroup([0, 1], label='IntraopMedTranexamicAcid')
    num36 = gr.CheckboxGroup([0, 1], label='Introp DEX or nDEX')
    num37 = gr.CheckboxGroup([0, 1], label='Left_Main_Disease')
    num38 = gr.CheckboxGroup([0, 1], label='MACE')
    num39 = gr.CheckboxGroup([0, 1], label='MedsG2b3aInhibitorMed')
    num40 = gr.CheckboxGroup([0, 1, 2, 3, 4], label='Mitral_Insufficiency')
    num41 = gr.CheckboxGroup([0, 1], label='OthCard_AICD')
    num42 = gr.CheckboxGroup([0, 1], label='Oth_Heart_Block')
    num43 = gr.CheckboxGroup([0, 1], label='Other_Cardiac_Intervention')
    num44 = gr.CheckboxGroup([0, 1], label='Peri_Op_MI')
    num45 = gr.CheckboxGroup([0, 1], label='Peripheral_Vasc_Disease')
    num46 = gr.CheckboxGroup([0, 1], label='PreOpMed Antiplatelets')
    num47 = gr.CheckboxGroup([0, 1], label='PreOpMedACE_ARBInhibitors')
    num48 = gr.CheckboxGroup([0, 1], label='PreOpMedADPInhibitors5Days')
    num49 = gr.CheckboxGroup([0, 1], label='PreOpMedAntiarrhythmics')
    num50 = gr.CheckboxGroup([0, 1], label='PreOpMedAnticoagulants')
    num51 = gr.CheckboxGroup([0, 1], label='PreOpMedAspirin')
    num52 = gr.CheckboxGroup([0, 1], label='PreOpMedCoumadin')
    num53 = gr.CheckboxGroup([0, 1], label='PreOpMedGPIIbIIIaInhibitor')
    num54 = gr.CheckboxGroup([0, 1], label='PreOpMedINotropes')
    num55 = gr.CheckboxGroup([0, 1], label='PreOpMedLipidLowering')
    num56 = gr.CheckboxGroup([0, 1], label='PreOpMedNitratesIV')
    num57 = gr.CheckboxGroup([0, 1], label='PreOpMedSteroids')
    num58 = gr.CheckboxGroup([0, 1], label='PreOp_BetaBlockers')
    num59 = gr.CheckboxGroup([0, 1], label='PreOp_Ca_Antagonists')
    num60 = gr.CheckboxGroup([0, 1], label='PreOp_Digitalis')
    num61 = gr.CheckboxGroup([0, 1], label='PreOp_Diuretics')
    num62 = gr.CheckboxGroup([0, 1], label='PrevArrhythmiaSurgery')
    num63 = gr.CheckboxGroup([0, 1], label='PrevOthCardPCI')
    num64 = gr.CheckboxGroup([0, 1], label='Previous_CABG')
    num65 = gr.CheckboxGroup([0, 1], label='Previous_CV_Intervention')
    num66 = gr.CheckboxGroup([0, 1], label='Previous_Valve')
    num67 = gr.CheckboxGroup([0, 1], label='PriorHeartFailure')
    num68 = gr.CheckboxGroup([0, 1], label='Pulmonic_Procedure')
    num69 = gr.CheckboxGroup([0, 1], label='Pulmonic_Stenosis')
    num70 = gr.CheckboxGroup([0, 1], label='STS_History.Renal_Failure')
    num71 = gr.CheckboxGroup([0, 1], label='Smoking')
    num72 = gr.CheckboxGroup([0, 1], label='Status')
    num73 = gr.CheckboxGroup([0, 1, 2, 3, 4], label='Tricuspid_Insufficiency')
    num74 = gr.CheckboxGroup([0, 1], label='Tricuspid_Procedure')
    num75 = gr.CheckboxGroup([0, 1], label='VSMitral')
    num76 = gr.CheckboxGroup([0, 1], label='Valve')
    num77 = gr.CheckboxGroup([0, 1], label='ValveDisAortic')
    num78 = gr.CheckboxGroup([0, 1], label='ValveDisMitral')
    num79 = gr.CheckboxGroup([0, 1], label='ValveDisPulmonic')
    num80 = gr.CheckboxGroup([0, 1], label='ValveDisTricuspid')
    num81 = gr.CheckboxGroup([0, 1], label='_MI')
    num82 = gr.CheckboxGroup([0, 1], label='mitral stenosis')
    num83 = gr.CheckboxGroup([0, 1], label='Oth_Afib_0', visible=False)
    num84 = gr.CheckboxGroup([0, 1], label='Oth_Afib_1', visible=False)
    b1 = gr.Button("Submit")
    example = sample_Df.values.tolist()
    # The same ordered input list feeds both the examples and the click handler
    all_inputs = [num0,num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82]
    gr.Examples(example, inputs=all_inputs)
    output = [gr.Plot(), gr.Plot(), gr.Plot(), gr.Plot(), gr.Plot(), gr.Plot()]
    b1.click(run_inference, inputs=all_inputs, outputs=output)
demo.launch(share=True)