haiquanchen commited on
Commit
9574a74
·
verified ·
1 Parent(s): aa6599d

Upload 3 files

Browse files
app.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pip install kaleido
2
+ #pip install gradio
3
+ import gradio as gr
4
+
5
+ #import os
6
+ #import random
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torchvision
12
+ import torchvision.transforms as transforms
13
+ from tqdm.auto import tqdm
14
+ #!pip install einops
15
+
16
+ # Device configuration
17
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
+
19
+ #pip install captum
20
+
21
+ import seaborn as sns
22
+ from captum.attr import LayerConductance
23
+
24
+ from captum.attr import IntegratedGradients
25
+ from captum.attr import configure_interpretable_embedding_layer
26
+
27
+ import matplotlib.pyplot as plt
28
+ from captum.attr import remove_interpretable_embedding_layer
29
+ import torch.nn.functional as F
30
+
31
+ # @title
32
+ import pandas as pd
33
+ import numpy as np
34
+ import tensorflow as tf
35
+
36
+ Raw_data = pd.read_excel('./STS Data with Up to dated AF 09-18-2022 (1).xlsx',usecols=lambda x: 'Unnamed' not in x)
37
+ pd.set_option('display.max_columns', None)
38
+
39
+ Raw_data['Aortic_Insufficiency']=Raw_data['Aortic_Insufficiency'].astype(np.int64)
40
+
41
+
42
+ Postop_columns = ['PostOpMedCoumadin',
43
+ 'PostOpMedLipidLowering',
44
+ 'PostOpMedAspirin',
45
+ 'PostOpMedADPInhibitors',
46
+ 'PostOpMedACE_ARBInhibitors',
47
+ 'STS_PostOp.Renal_Failure',
48
+ 'Oth_Cardiac_Arrest',
49
+ 'Complications_Any',
50
+ 'Neuro_Stroke_Permanent',
51
+ 'Neuro_Stroke_Permanent',
52
+ 'Neuro_Continuous_Coma',
53
+ 'Neuro_Delirium',
54
+ 'PostOpSepsis',
55
+ 'Reop_Bleeding',
56
+ 'Oth_OtherComplication',
57
+ 'PostOpNeuroStrokeTransientTIA',
58
+ 'Infect_Sternum_Deep',
59
+ 'Infect_Thoracotomy',
60
+ 'Pulm_Ventilator_Prolonged',
61
+ 'Pulm_Pneumonia',
62
+ 'Oth_Tamponade',
63
+ 'Oth_Anticoagulant',
64
+ 'Oth_MultiSystem_Failure',
65
+ 'Oth_GI',
66
+ 'Vasc_Ao_Dissection',
67
+ 'Infect_Leg',
68
+ 'PostOpInfectionArm',
69
+ 'OthCard_Pacemaker',
70
+ 'PostOpCreatinineLevel',
71
+ 'Renal_Dialysis_Required',
72
+ 'PostOpBloodRBCUnits',
73
+ 'PostOpBloodFFPUnits',
74
+ 'PostOpBloodCryoUnits',
75
+ 'PostOpBloodPlateletUnits',
76
+ 'ExtubatedI0R',
77
+ 'InitHrsVentilated',
78
+ 'ReIntubated',
79
+ 'No Add Hrs Ventilator',
80
+ 'PostOpVentHoursTotal',
81
+ 'InitHrsICU',
82
+ 'ReadmitICU',
83
+ 'AddICUHours',
84
+ 'TotHrsICU',
85
+ 'DCMed_AntiPlate',
86
+ 'Readmit_LessThan30Days',
87
+ 'Blood_Bank_Products_Used',
88
+ 'PostOpMedAntiarrhythmics',
89
+ 'PostOpMedBetaBlockers']
90
+
91
+
92
+ #Dropping columns
93
+ preop_oper_data = Raw_data.drop(columns=Postop_columns)
94
+ preop_oper_data = preop_oper_data[preop_oper_data.Oth_Afib != -1]
95
+ preop_oper_data=preop_oper_data.drop(['Date_of_Birth','Surgery_Date','Discharge_Date', 'Death_Date','Death-Surery(y)','Mortality30d','Mortality1y','Mortality2y','Mortality3y','Mortality4y','Mortality5y'], axis=1)
96
+ preop_oper_data=preop_oper_data.drop(['EF','Height(cm)','Weight(kg)','CVA_When','Category','Race'], axis=1)
97
+ preop_oper_data=preop_oper_data.drop(['IABP_Indication','IABP_When'],axis=1)
98
+
99
+
100
+ #seperating continuous and Categorical Data
101
+ continous_df = preop_oper_data[['Oth_Afib','Age','BMI','LastCreatinineLevel','Cross_Clamp_Time','Perfusion_Time']]
102
+ categorical_col=list(set(preop_oper_data.columns) - set(continous_df.columns))
103
+ categorical_col.sort()
104
+ #Setting the target
105
+ df_AF=preop_oper_data['Oth_Afib']
106
+ continous_col=continous_df.drop('Oth_Afib',axis=1)
107
+ preop_oper_data=preop_oper_data.drop('Oth_Afib',axis=1)
108
+
109
+
110
+ #Creating categorical df
111
+ preop_oper_data=preop_oper_data[categorical_col]
112
+
113
+ # Label encoding
114
+ from sklearn import preprocessing
115
+ from sklearn.preprocessing import LabelEncoder
116
+ def encode_text_index(df, name):
117
+ le = preprocessing.LabelEncoder()
118
+ df[name] = le.fit_transform(df[name])
119
+ return le.classes_
120
+
121
+ encode_text_index(preop_oper_data,'Introp DEX or nDEX')
122
+ encode_text_index(preop_oper_data,'Status')
123
+ encode_text_index(preop_oper_data,'Gender')
124
+
125
+ #Calculating len of each categorical column
126
+ label_in_each = tuple(len(preop_oper_data[col].unique()) for col in preop_oper_data.columns)
127
+ categorical_col_with_ordinal = preop_oper_data.columns
128
+
129
+ #Making the final data frame
130
+ final_frame=pd.concat([continous_col,preop_oper_data],axis=1)
131
+ final_frame=pd.concat([final_frame,df_AF],axis=1)
132
+
133
+ # Encode a numeric column as zscores
134
+ def encode_numeric_zscore(df, name, mean=None, sd=None):
135
+ if mean is None:
136
+ mean = df[name].mean()
137
+ print(f'mean:{mean}')
138
+
139
+ if sd is None:
140
+ sd = df[name].std()
141
+ print(f'sd:{sd}')
142
+
143
+ df[name] = (df[name] - mean) / sd
144
+
145
+
146
+ for col in continous_col.columns:
147
+ encode_numeric_zscore(final_frame,col)
148
+
149
+
150
+
151
+ #Train test split
152
+ from sklearn.model_selection import train_test_split
153
+ x_train, x_temp, y_train, y_temp = train_test_split(final_frame.iloc[:,:-1], final_frame.iloc[:,-1], test_size=0.25, random_state=42, stratify=final_frame.iloc[:,-1])
154
+
155
+
156
+ print(x_train.shape)
157
+ print(y_train.shape)
158
+ print(x_temp.shape)
159
+ print(y_temp.shape)
160
+
161
+ # Duplicating class 1 records to balance dataset for training
162
+ training_frame = pd.concat([x_train, y_train],axis=1)
163
+ training_frame_ana = training_frame
164
+ class_1_rows = training_frame[training_frame['Oth_Afib'] == 1]
165
+ duplicated_class_1 = class_1_rows.copy()
166
+ training_frame= pd.concat([training_frame, duplicated_class_1,duplicated_class_1], ignore_index=True)
167
+ training_frame['Oth_Afib'].value_counts()
168
+
169
+ # Creating testing df
170
+ testing_frame = pd.concat([x_temp, y_temp],axis=1)
171
+
172
+
173
+ continous_df= continous_df.drop('Oth_Afib', axis=1)
174
+ continous_col=continous_df.columns
175
+ continous_col
176
+
177
+ training_frame_without_label=training_frame.iloc[:,:-1]
178
+ testing_frame_without_label=testing_frame.iloc[:,:-1]
179
+ training_frame=pd.concat([training_frame_without_label,pd.get_dummies(training_frame.iloc[:,-1],prefix='Oth_Afib',dtype=np.int64)],axis=1)
180
+ testing_frame=pd.concat([testing_frame_without_label,pd.get_dummies(testing_frame.iloc[:,-1],prefix='Oth_Afib',dtype=np.int64)],axis=1)
181
+ testing_frame
182
+
183
+ # @title
184
+
185
+
186
+
187
+ import torch
188
+ import torch.nn.functional as F
189
+ from torch import nn, einsum
190
+
191
+ from einops import rearrange
192
+
193
+ # helpers
194
+
195
+ def exists(val):
196
+ return val is not None
197
+
198
+ def default(val, d):
199
+ return val if exists(val) else d
200
+
201
+ # classes
202
+
203
+ class Residual(nn.Module):
204
+ def __init__(self, fn):
205
+ super().__init__()
206
+ self.fn = fn
207
+
208
+ def forward(self, x, **kwargs):
209
+ return self.fn(x, **kwargs) + x
210
+
211
+ class PreNorm(nn.Module):
212
+ def __init__(self, dim, fn):
213
+ super().__init__()
214
+ self.norm = nn.LayerNorm(dim)
215
+ self.fn = fn
216
+
217
+ def forward(self, x, **kwargs):
218
+ return self.fn(self.norm(x), **kwargs)
219
+
220
+ # attention
221
+
222
+ class GEGLU(nn.Module):
223
+ def forward(self, x):
224
+ x, gates = x.chunk(2, dim = -1)
225
+ return x * F.gelu(gates)
226
+
227
+ class FeedForward(nn.Module):
228
+ def __init__(self, dim, mult = 4, dropout = 0.):
229
+ super().__init__()
230
+ self.net = nn.Sequential(
231
+ nn.Linear(dim, dim * mult * 2),
232
+ GEGLU(),
233
+ nn.Dropout(dropout),
234
+ nn.Linear(dim * mult, dim)
235
+ )
236
+
237
+ def forward(self, x, **kwargs):
238
+ return self.net(x)
239
+
240
+ class Attention(nn.Module):
241
+ def __init__(
242
+ self,
243
+ dim,
244
+ heads = 8,
245
+ dim_head = 16,
246
+ dropout = 0.
247
+ ):
248
+ super().__init__()
249
+ inner_dim = dim_head * heads
250
+ self.heads = heads
251
+ self.scale = dim_head ** -0.5
252
+
253
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
254
+ self.to_out = nn.Linear(inner_dim, dim)
255
+
256
+ self.dropout = nn.Dropout(dropout)
257
+
258
+ def forward(self, x):
259
+ h = self.heads
260
+ q, k, v = self.to_qkv(x).chunk(3, dim = -1)
261
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
262
+ sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
263
+
264
+ attn = sim.softmax(dim = -1)
265
+ dropped_attn = self.dropout(attn)
266
+
267
+ out = einsum('b h i j, b h j d -> b h i d', dropped_attn, v)
268
+ out = rearrange(out, 'b h n d -> b n (h d)', h = h)
269
+ return self.to_out(out), attn
270
+
271
+ # transformer
272
+
273
+ class Transformer(nn.Module):
274
+ def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
275
+ super().__init__()
276
+ # torch.manual_seed(1)
277
+ # self.embeds = nn.Embedding(num_tokens, dim)
278
+
279
+
280
+ self.layers = nn.ModuleList([])
281
+
282
+ for _ in range(depth):
283
+ self.layers.append(nn.ModuleList([
284
+ PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
285
+ PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
286
+ ]))
287
+
288
+ def forward(self, x, return_attn = False):
289
+ # x = self.embeds(x)
290
+
291
+ post_softmax_attns = []
292
+
293
+ for attn, ff in self.layers:
294
+ attn_out, post_softmax_attn = attn(x)
295
+ post_softmax_attns.append(post_softmax_attn)
296
+
297
+ x = x + attn_out
298
+ x = ff(x) + x
299
+
300
+ if not return_attn:
301
+ return x
302
+
303
+ # return x, torch.stack(post_softmax_attns)
304
+ return x
305
+ # mlp
306
+
307
+ class MLP(nn.Module):
308
+ def __init__(self, dims, act = None):
309
+ super().__init__()
310
+ dims_pairs = list(zip(dims[:-1], dims[1:]))
311
+ layers = []
312
+ for ind, (dim_in, dim_out) in enumerate(dims_pairs):
313
+ is_last = ind >= (len(dims_pairs) - 1)
314
+ linear = nn.Linear(dim_in, dim_out)
315
+ layers.append(linear)
316
+
317
+ if is_last:
318
+ continue
319
+
320
+ act = default(act, nn.ReLU())
321
+ layers.append(act)
322
+
323
+ self.mlp = nn.Sequential(*layers)
324
+
325
+ def forward(self, x):
326
+ return self.mlp(x)
327
+
328
+
329
+ class NumericalEmbedder(nn.Module):
330
+ def __init__(self, dim, num_numerical_types):
331
+ super().__init__()
332
+ self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
333
+ self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))
334
+
335
+ def forward(self, x):
336
+ x = rearrange(x, 'b n -> b n 1')
337
+ return x * self.weights + self.biases
338
+
339
+ class CategoricalEmbedder(nn.Module):
340
+ def __init__(self, total_tokens,dim):
341
+ super().__init__()
342
+ self.embeds = nn.Embedding(total_tokens, dim)
343
+
344
+
345
+ def forward(self, x):
346
+ x_embed = self.embeds(x)
347
+ return x_embed
348
+
349
+
350
+ class CatConLayer(nn.Module):
351
+
352
+ def __init__(self, dim , heads ):
353
+ super().__init__()
354
+
355
+ self.cat_con_multihead_attn = torch.nn.MultiheadAttention(dim , heads , dropout = 0.8)
356
+ self.con_cat_multihead_attn = torch.nn.MultiheadAttention(dim , heads , dropout = 0.8)
357
+
358
+
359
+ def forward(self,attn_cat,attn_con,need_weights=False):
360
+
361
+
362
+ cat_Q,_ = self.cat_con_multihead_attn(attn_cat,attn_con,attn_con)
363
+
364
+ con_Q,_ = self.con_cat_multihead_attn(attn_con,attn_cat,attn_cat)
365
+
366
+ cat_Q=cat_Q.permute(1, 0, 2)
367
+ con_Q=con_Q.permute(1, 0, 2)
368
+ # output_concat = torch.cat([cat_Q, con_Q], dim=0)
369
+ return cat_Q,con_Q
370
+
371
+ # main class
372
+
373
+ class Co_Transformer(nn.Module):
374
+ def __init__(
375
+ self,
376
+ *,
377
+ categories,
378
+ num_continuous,
379
+ dim,
380
+ depth,
381
+ heads,
382
+ dim_head = 16,
383
+ dim_out = 1,
384
+ mlp_hidden_mults = (2,1),
385
+ mlp_act = None,
386
+
387
+ num_special_tokens = 0,
388
+ continuous_mean_std = None,
389
+ attn_dropout = 0.,
390
+ ff_dropout = 0.
391
+ ):
392
+ super().__init__()
393
+ assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
394
+ assert len(categories) + num_continuous > 0, 'input shape must not be null'
395
+
396
+
397
+
398
+ self.num_categories = len(categories)
399
+
400
+
401
+ self.num_unique_categories = sum(categories)
402
+
403
+
404
+
405
+ self.num_special_tokens = num_special_tokens #0
406
+ total_tokens = self.num_unique_categories + num_special_tokens
407
+
408
+ # for automatically offsetting unique category ids to the correct position in the categories embedding table
409
+
410
+ if self.num_unique_categories > 0:
411
+ categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
412
+
413
+ categories_offset = categories_offset.cumsum(dim = -1)[:-1]
414
+
415
+ self.register_buffer('categories_offset', categories_offset)
416
+ self.embeds = CategoricalEmbedder(total_tokens, dim)
417
+
418
+ # continuous
419
+ self.num_continuous = num_continuous
420
+
421
+
422
+
423
+ if self.num_continuous > 0:
424
+
425
+ self.numerical_embedder = NumericalEmbedder(dim, self.num_continuous)
426
+
427
+
428
+ # transformer
429
+
430
+ self.transformer_cat = Transformer(
431
+
432
+ dim = dim,
433
+ depth = depth,
434
+ heads = heads,
435
+ dim_head = dim_head,
436
+ attn_dropout = attn_dropout,
437
+ ff_dropout = ff_dropout
438
+ )
439
+
440
+ self.transformer_con = Transformer(
441
+
442
+ dim = dim,
443
+ depth = depth,
444
+ heads = heads,
445
+ dim_head = dim_head,
446
+ attn_dropout = attn_dropout,
447
+ ff_dropout = ff_dropout
448
+ )
449
+ # fusion-part
450
+
451
+ self.catconlayer = CatConLayer(
452
+ dim=dim ,
453
+ heads=heads
454
+ )
455
+
456
+ # mlp to logits
457
+
458
+ input_size = dim * (self.num_categories + num_continuous)
459
+ print(f'input size{input_size}')
460
+ l = input_size // 5
461
+ hidden_dimensions = list(map(lambda t: int(l * t), mlp_hidden_mults))
462
+ all_dimensions = [input_size, *hidden_dimensions, dim_out]
463
+ self.mlp = MLP(all_dimensions, act = mlp_act)
464
+ # print(f" mlp {self.mlp}")
465
+
466
+ def forward(self, x_categ, x_cont, return_attn = True):
467
+
468
+
469
+
470
+
471
+ x_categ = self.embeds(x_categ)
472
+ # x_cat_, attns_cat = self.transformer(x_categ, return_attn = True)
473
+ x_cat_ = self.transformer_cat(x_categ, return_attn = True)
474
+ permuted_x_cat_= x_cat_.permute(1, 0, 2)
475
+
476
+
477
+
478
+ x_numer = self.numerical_embedder(x_cont)
479
+ # x_con_, attns_con = self.transformer(x_numer, return_attn = True)
480
+ x_con_ = self.transformer_con(x_numer, return_attn = True)
481
+ permuted_x_con_= x_con_.permute(1, 0, 2)
482
+
483
+
484
+
485
+ cat_Q,con_Q = self.catconlayer(permuted_x_cat_,permuted_x_con_ )
486
+
487
+ can_con_attn_output = torch.cat([cat_Q, con_Q], dim=1)
488
+ # permuted_can_con_attn_output= can_con_attn_output.permute(1, 0, 2)
489
+
490
+
491
+ can_con_attn_output_flattend= can_con_attn_output.flatten(1)
492
+
493
+
494
+ logits=self.mlp(can_con_attn_output_flattend)
495
+
496
+ return logits
497
+
498
+
499
+ def build_network(depth,heads,dim):
500
+
501
+ model = Co_Transformer(
502
+ categories = label_in_each , # tuple containing the number of unique values within each category
503
+ num_continuous = final_frame[continous_col].shape[1], # number of continuous values
504
+ dim = dim, # dimension, paper set at 32
505
+ dim_out = 2, # binary prediction, but could be anything
506
+ depth = depth, # depth, paper recommended 6
507
+ heads = heads, # heads, paper recommends 8
508
+ attn_dropout = 0.1, # post-attention dropout
509
+ ff_dropout = 0.1, # feed forward dropout
510
+ mlp_hidden_mults =((2,1,0.5,0.25)), # relative multiples of each hidden dimension of the last mlp to logits
511
+ mlp_act = nn.ReLU(), # activation for final mlp, defaults to relu, but could be anything else (selu etc)
512
+
513
+ continuous_mean_std = torch.tensor(continous_df.agg(['mean','std']).transpose().values, dtype=torch.float32) # (optional) - normalize the continuous values before layer norm
514
+
515
+
516
+ )
517
+
518
+ return model
519
+
520
+ model = build_network(8,8,64)
521
+
522
+ model.load_state_dict(torch.load('./co_attention_transformer_model_trained.pth',map_location=torch.device('cpu')))
523
+ # print("Model Loaded!")
524
+
525
+ sample_Df=pd.read_csv('./sample_data.csv')
526
+
527
+ # @title
528
+
529
+ def run_inference(num0,num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82):
530
+
531
+ mean1=62.63038219641993
532
+ sd1=12.280987675098727
533
+ mean2=29.238482057573915
534
+ sd2=6.511823065330923
535
+ mean3=1.2360909530720852
536
+ sd3=1.0307534004581604
537
+ mean4=137.98209966134493
538
+ sd4=58.71032609323997
539
+ mean5=193.39671020803095
540
+ sd5=79.63715724430536
541
+
542
+ num1 = (num1 - mean1)/sd1
543
+ num2 = (num2 - mean2)/sd2
544
+ num3 = (num3 - mean3)/sd3
545
+ num4 = (num4 - mean4)/sd4
546
+ num5 = (num5 - mean5)/sd5
547
+
548
+ list__inputs = [num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82]
549
+ print(list__inputs)
550
+
551
+ if (num0 == 'First_non_AFib' or num0 == 'Second_non_AFib'):
552
+ target_set = 0
553
+ else:
554
+ target_set = 1
555
+
556
+ # Remove specific elements from nested lists at specified indices
557
+ result_list =[item for sublist in list__inputs for item in (sublist if isinstance(sublist, list) else [sublist])]
558
+ print(result_list)
559
+ con = torch.tensor(result_list[0:5],dtype=torch.float32).reshape(1,-1)
560
+ print(con,con.shape,con.device.type)
561
+ cat = torch.tensor(result_list[5:82],dtype=torch.long).reshape(1,-1)
562
+ print(cat,cat.shape,cat.device.type)
563
+ model.eval()
564
+ output_tup=model(cat,con)
565
+ prob=F.softmax(output_tup,dim=-1)
566
+ print(prob[0][0].detach(),prob[0][1].detach())
567
+
568
+ # Categories for the bar plot
569
+ categories = ['Non-AFIB', 'A-FIB']
570
+ # Values for the bar plot
571
+ values = [(prob[0][0]).detach().numpy(), (prob[0][1]).detach().numpy()]
572
+ fig1 = plt.figure()
573
+ plt.barh(categories, values, color=['green', 'red'])
574
+ plt.xlabel('Values')
575
+ plt.ylabel('Labels')
576
+ # cal embedding attributes
577
+ ig = IntegratedGradients(model)
578
+
579
+ interpretable_embedding_cat = configure_interpretable_embedding_layer(model, 'embeds')
580
+ interpretable_embedding_con = configure_interpretable_embedding_layer(model, 'numerical_embedder')
581
+
582
+
583
+
584
+ emb_cat = interpretable_embedding_cat.indices_to_embeddings(cat)
585
+ emb_con = interpretable_embedding_con.indices_to_embeddings(con)
586
+ print(emb_cat.device.type)
587
+ baseline_cat = torch.zeros_like(emb_cat) # Set numerical baseline to zero
588
+ baseline_con = torch.zeros_like(emb_con)
589
+ emb_cat.requires_grad_
590
+ emb_cat.requires_grad_
591
+ attr, delta =ig.attribute((emb_cat, emb_con),baselines = (baseline_cat,baseline_con) ,target=target_set, return_convergence_delta=True, n_steps=50)
592
+ print("calculating attr")
593
+ print(attr[0].shape)
594
+ print(attr[1].shape)
595
+
596
+ categ_attr = (attr[0]).sum(dim=-1).squeeze(0)
597
+ cond_attr = (attr[1]).sum(dim=-1).squeeze(0)
598
+
599
+ concatenated_tensor = torch.cat([cond_attr, categ_attr],dim=0)
600
+ print(concatenated_tensor.device.type)
601
+
602
+ x_pos = (np.arange(len(testing_frame.iloc[:,0:-2].columns)))
603
+
604
+ fig2 = plt.figure(figsize=(30,6))
605
+
606
+ plt.bar(x_pos,concatenated_tensor.squeeze().cpu().detach().numpy(), align='center', color = 'red')
607
+ plt.xticks(x_pos,testing_frame.iloc[:,0:-2].columns, wrap=True)
608
+ plt.xticks(rotation=45)
609
+ plt.xlabel('Features')
610
+ plt.title('Embedded layer attributes')
611
+
612
+ # layer attributes
613
+
614
+ attn_con_cat = []
615
+ attn_con_cat.append(concatenated_tensor.detach().cpu())
616
+
617
+ for i in range(len(model.transformer_con.layers)):
618
+ con_module = [module for module in model.transformer_con.layers[i]]
619
+ layeroutput_con = []
620
+ for j in range(len(con_module)):
621
+ lc_con = LayerConductance(model, con_module[j])
622
+ layer_attributions_con= lc_con.attribute((emb_cat,emb_con), baselines=(baseline_cat,baseline_con),target=target_set,n_steps=50)
623
+ if(type(layer_attributions_con) == "tuple"):
624
+ layeroutput_con.append(layer_attributions_con[0])
625
+ else:
626
+ layeroutput_con.append(layer_attributions_con)
627
+ attn_out_con = emb_con + layeroutput_con[0][0] + layeroutput_con[1]
628
+ attn_out_con=attn_out_con.sum(dim=-1).squeeze(0)
629
+
630
+ cat_module = [module for module in model.transformer_cat.layers[i]]
631
+ layeroutput_cat = []
632
+ for j in range(len(cat_module)):
633
+ lc_cat = LayerConductance(model, cat_module[j])
634
+ layer_attributions_cat= lc_cat.attribute((emb_cat,emb_con), baselines=(baseline_cat,baseline_con),target=target_set,n_steps=50)
635
+ if(type(layer_attributions_cat) == "tuple"):
636
+ layeroutput_cat.append(layer_attributions_cat[0])
637
+ else:
638
+ layeroutput_cat.append(layer_attributions_cat)
639
+ attn_out_cat = emb_cat + layeroutput_cat[0][0] + layeroutput_cat[1]
640
+ attn_out_cat=attn_out_cat.sum(dim=-1).squeeze(0)
641
+
642
+ attn_con_cat.append((torch.cat([attn_out_con,attn_out_cat])).detach().cpu())
643
+
644
+ lc = LayerConductance(model, model.catconlayer)
645
+ layer_attributions_start = lc.attribute((emb_cat,emb_con), baselines=(baseline_cat,baseline_con),target=target_set,n_steps=50)
646
+ value_coattn_cat=layer_attributions_start[0].sum(dim=-1).squeeze(0)
647
+ value_coattn_con=layer_attributions_start[1].sum(dim=-1).squeeze(0)
648
+ attn_con_cat.append(torch.cat([value_coattn_cat,value_coattn_con]).detach().cpu())
649
+ # fig 3
650
+ fig3, axes = plt.subplots(figsize=(15, 12),frameon=False)
651
+
652
+ for spine in plt.gca().spines.values():
653
+ spine.set_visible(False)
654
+
655
+ axes.xaxis.set_major_locator(plt.NullLocator())
656
+ axes.yaxis.set_major_locator(plt.NullLocator())
657
+
658
+ for i,k in enumerate(testing_frame.iloc[:,:-60].columns):
659
+
660
+ cmap = sns.color_palette("Reds")
661
+ # cmap = sns.cm.rocket_r
662
+ ax = fig3.add_subplot(5,5, i+1)
663
+
664
+ xticklabels=[k]
665
+ yticklabels=list(range(1,9))
666
+ ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:,i].reshape(-1,1),ax=ax,xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap)
667
+ plt.xlabel('features')
668
+ plt.ylabel('Layers')
669
+ plt.tight_layout()
670
+
671
+ # fig 4
672
+ fig4, axes = plt.subplots(figsize=(15, 12),frameon=False)
673
+
674
+ for spine in plt.gca().spines.values():
675
+ spine.set_visible(False)
676
+
677
+ axes.xaxis.set_major_locator(plt.NullLocator())
678
+ axes.yaxis.set_major_locator(plt.NullLocator())
679
+
680
+ for i,k in enumerate(testing_frame.iloc[:,24:-30].columns):
681
+
682
+ cmap = sns.color_palette("Reds")
683
+ # cmap = sns.cm.rocket_r
684
+ ax = fig4.add_subplot(6,5, i+1)
685
+
686
+ xticklabels=[k]
687
+ yticklabels=list(range(1,9))
688
+ ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:,i].reshape(-1,1),ax=ax,xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap)
689
+ plt.xlabel('features')
690
+ plt.ylabel('Layers')
691
+ plt.tight_layout()
692
+
693
+ # fig 5
694
+ fig5, axes = plt.subplots(figsize=(15, 12),frameon=False)
695
+
696
+ for spine in plt.gca().spines.values():
697
+ spine.set_visible(False)
698
+
699
+ axes.xaxis.set_major_locator(plt.NullLocator())
700
+ axes.yaxis.set_major_locator(plt.NullLocator())
701
+
702
+ for i,k in enumerate(testing_frame.iloc[:,54:-2].columns):
703
+
704
+ cmap = sns.color_palette("Reds")
705
+ # cmap = sns.cm.rocket_r
706
+ ax = fig5.add_subplot(6,5, i+1)
707
+
708
+ xticklabels=[k]
709
+ yticklabels=list(range(1,9))
710
+ ax = sns.heatmap(np.array(torch.stack(attn_con_cat)[1:9])[:,i].reshape(-1,1),ax=ax,xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap=cmap)
711
+ plt.xlabel('features')
712
+ plt.ylabel('Layers')
713
+ plt.tight_layout()
714
+
715
+ #fig6
716
+ x_pos = (np.arange(len(testing_frame.iloc[:,:-2].columns)))
717
+
718
+ fig6 = plt.figure(figsize=(30,6))
719
+
720
+ plt.bar(x_pos, torch.stack(attn_con_cat)[9], align='center', color = 'red')
721
+ plt.xticks(x_pos,testing_frame.iloc[:,:-2].columns, wrap=True)
722
+ plt.xticks(rotation=45)
723
+ plt.xlabel('features')
724
+ plt.title('Attribution of co-attention layer')
725
+
726
+
727
+
728
+
729
+ remove_interpretable_embedding_layer(model, interpretable_embedding_con)
730
+ remove_interpretable_embedding_layer(model, interpretable_embedding_cat)
731
+ return fig1 , fig2 , fig3 , fig4, fig5, fig6
732
+
733
+
734
+ demo = gr.Blocks()
735
+
736
+
737
+
738
+ with demo:
739
+
740
+ gr.Markdown(
741
+ """
742
+ # Post-Operative Artrial Fibrillation Demo
743
+
744
+ Select values for the following and click submit to see the results:
745
+ """)
746
+ num0=gr.Textbox(visible = False)
747
+ num1=gr.Slider(0,100,label='Age',step=1)
748
+ num2=gr.Slider(0,100,label='BMI')
749
+ num3=gr.Slider(0,20,label='LastCreatinineLevel')
750
+ num4=gr.Slider(0,1000,label='Cross_Clamp_Time',step=1)
751
+ num5=gr.Slider(0,1000,label='Perfusion_Time',step=1)
752
+ num6=gr.Slider(0,8,label='# of coronary vessels corrected',step=1)
753
+ num7=gr.CheckboxGroup([0,1],label='Aortic stenosis')
754
+ num8=gr.CheckboxGroup([0,1,2,3,4],label='Aortic_Insufficiency')
755
+ num9=gr.CheckboxGroup([0,1],label='Aortic_Procedure')
756
+ num10=gr.CheckboxGroup([0,1],label='Arrhythmia')
757
+ num11=gr.CheckboxGroup([0,1],label='ArrhythmiaAfibAflutter')
758
+ num12=gr.CheckboxGroup([0,1],label='CABG')
759
+ num13=gr.CheckboxGroup([0,1],label='CHF')
760
+ num14=gr.CheckboxGroup([0,1],label='CVA')
761
+ num15=gr.CheckboxGroup([0,1],label='Cardiogenic_Shock')
762
+ num16=gr.CheckboxGroup([0,1],label='Cerebrovascular_Disease')
763
+ num17=gr.CheckboxGroup([0,1,2,3],label='ChronicLungDisease')
764
+ num18=gr.CheckboxGroup([0,1],label='Diabetes')
765
+ num19=gr.CheckboxGroup([0,1],label='Dialysis')
766
+ num20=gr.Slider(0,7,label='DistAnasVein',step=1)
767
+ num21=gr.Slider(0,6,label='DistAnastArt',step=1)
768
+ num22=gr.CheckboxGroup([0,1],label='Family_History_CAD')
769
+ num23=gr.CheckboxGroup([0,1],label='Gender')
770
+ num24=gr.CheckboxGroup([0,1],label='Hypercholesterolemia')
771
+ num25=gr.CheckboxGroup([0,1],label='Hypertension')
772
+ num26=gr.CheckboxGroup([0,1],label='IABP')
773
+ num27=gr.CheckboxGroup([0,1],label='Infectious_Endocarditis')
774
+ num28=gr.Slider(0,6,label='IntraopBloodCryo',step=1)
775
+ num29=gr.Slider(0,8,label='IntraopBloodFFP',step=1)
776
+ num30=gr.CheckboxGroup([0,1,2],label='IntraopBloodFactorVII')
777
+ num31=gr.Slider(0,8,label='IntraopBloodPlatelet',step=1)
778
+ num32=gr.CheckboxGroup([0,1],label='IntraopBloodProducts')
779
+ num33=gr.Slider(0,20,label='IntraopBloodRBC',step=1)
780
+ num34=gr.CheckboxGroup([0,1],label='IntraopMedEpsilonAmi0Caproic')
781
+ num35=gr.CheckboxGroup([0,1],label='IntraopMedTranexamicAcid')
782
+ num36=gr.CheckboxGroup([0,1],label='Introp DEX or nDEX')
783
+ num37=gr.CheckboxGroup([0,1],label='Left_Main_Disease')
784
+ num38=gr.CheckboxGroup([0,1],label='MACE')
785
+ num39=gr.CheckboxGroup([0,1],label='MedsG2b3aInhibitorMed')
786
+ num40=gr.CheckboxGroup([0,1,2,3,4],label='Mitral_Insufficiency')
787
+ num41=gr.CheckboxGroup([0,1],label='OthCard_AICD')
788
+ num42=gr.CheckboxGroup([0,1],label='Oth_Heart_Block')
789
+ num43=gr.CheckboxGroup([0,1],label='Other_Cardiac_Intervention')
790
+ num44=gr.CheckboxGroup([0,1],label='Peri_Op_MI')
791
+ num45=gr.CheckboxGroup([0,1],label='Peripheral_Vasc_Disease')
792
+ num46=gr.CheckboxGroup([0,1],label='PreOpMed Antiplatelets')
793
+ num47=gr.CheckboxGroup([0,1],label='PreOpMedACE_ARBInhibitors')
794
+ num48=gr.CheckboxGroup([0,1],label='PreOpMedADPInhibitors5Days')
795
+ num49=gr.CheckboxGroup([0,1],label='PreOpMedAntiarrhythmics')
796
+ num50=gr.CheckboxGroup([0,1],label='PreOpMedAnticoagulants')
797
+ num51=gr.CheckboxGroup([0,1],label='PreOpMedAspirin')
798
+ num52=gr.CheckboxGroup([0,1],label='PreOpMedCoumadin')
799
+ num53=gr.CheckboxGroup([0,1],label='PreOpMedGPIIbIIIaInhibitor')
800
+ num54=gr.CheckboxGroup([0,1],label='PreOpMedINotropes')
801
+ num55=gr.CheckboxGroup([0,1],label='PreOpMedLipidLowering')
802
+ num56=gr.CheckboxGroup([0,1],label='PreOpMedNitratesIV')
803
+ num57=gr.CheckboxGroup([0,1],label='PreOpMedSteroids')
804
+ num58=gr.CheckboxGroup([0,1],label='PreOp_BetaBlockers')
805
+ num59=gr.CheckboxGroup([0,1],label='PreOp_Ca_Antagonists')
806
+ num60=gr.CheckboxGroup([0,1],label='PreOp_Digitalis')
807
+ num61=gr.CheckboxGroup([0,1],label='PreOp_Diuretics')
808
+ num62=gr.CheckboxGroup([0,1],label='PrevArrhythmiaSurgery')
809
+ num63=gr.CheckboxGroup([0,1],label='PrevOthCardPCI')
810
+ num64=gr.CheckboxGroup([0,1],label='Previous_CABG')
811
+ num65=gr.CheckboxGroup([0,1],label='Previous_CV_Intervention')
812
+ num66=gr.CheckboxGroup([0,1],label='Previous_Valve')
813
+ num67=gr.CheckboxGroup([0,1],label='PriorHeartFailure')
814
+ num68=gr.CheckboxGroup([0,1],label='Pulmonic_Procedure')
815
+ num69=gr.CheckboxGroup([0,1],label='Pulmonic_Stenosis')
816
+ num70=gr.CheckboxGroup([0,1],label='STS_History.Renal_Failure')
817
+ num71=gr.CheckboxGroup([0,1],label='Smoking')
818
+ num72=gr.CheckboxGroup([0,1],label='Status')
819
+ num73=gr.CheckboxGroup([0,1,2,3,4],label='Tricuspid_Insufficiency')
820
+ num74=gr.CheckboxGroup([0,1],label='Tricuspid_Procedure')
821
+ num75=gr.CheckboxGroup([0,1],label='VSMitral')
822
+ num76=gr.CheckboxGroup([0,1],label='Valve')
823
+ num77=gr.CheckboxGroup([0,1],label='ValveDisAortic')
824
+ num78=gr.CheckboxGroup([0,1],label='ValveDisMitral')
825
+ num79=gr.CheckboxGroup([0,1],label='ValveDisPulmonic')
826
+ num80=gr.CheckboxGroup([0,1],label='ValveDisTricuspid')
827
+ num81=gr.CheckboxGroup([0,1],label='_MI')
828
+ num82=gr.CheckboxGroup([0,1],label='mitral stenosis')
829
+ num83=gr.CheckboxGroup([0,1],label='Oth_Afib_0',visible= False)
830
+ num84=gr.CheckboxGroup([0,1],label='Oth_Afib_1',visible= False)
831
+
832
+
833
+
834
+ b1 = gr.Button("Submit")
835
+ example =sample_Df.values.tolist()
836
+ gr.Examples(example,inputs=[num0,num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82])
837
+
838
+ output = [gr.Plot(),gr.Plot(),gr.Plot(),gr.Plot(),gr.Plot(),gr.Plot()]
839
+
840
+ b1.click(run_inference, inputs=[num0,num1,num2,num3,num4,num5,num6,num7,num8,num9,num10,num11,num12,num13,num14,num15,num16,num17,num18,num19,num20,num21,num22,num23,num24,num25,num26,num27,num28,num29,num30,num31,num32,num33,num34,num35,num36,num37,num38,num39,num40,num41,num42,num43,num44,num45,num46,num47,num48,num49,num50,num51,num52,num53,num54,num55,num56,num57,num58,num59,num60,num61,num62,num63,num64,num65,num66,num67,num68,num69,num70,num71,num72,num73,num74,num75,num76,num77,num78,num79,num80,num81,num82], outputs=output)
841
+
842
+
843
+
844
+ demo.launch(share=True,auth=('poaf-users','dshrebs__324'))
co_attention_transformer_model_trained.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99b2456a353c08c7fcab9a4f6f764c79c9995e0c703954f2a78174e81e7a7fdc
3
+ size 61186764
sample_data.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Records,Age_,BMI_,LastCreatineLevel_,CrossClampTime_,PerfusionTime_,# of coronary vessels corrected,Aortic stenosis,Aortic_Insufficiency,Aortic_Procedure,Arrhythmia,ArrhythmiaAfibAflutter,CABG,CHF,CVA,Cardiogenic_Shock,Cerebrovascular_Disease,ChronicLungDisease,Diabetes,Dialysis,DistAnasVein,DistAnastArt,Family_History_CAD,Gender,Hypercholesterolemia,Hypertension,IABP,Infectious_Endocarditis,IntraopBloodCryo,IntraopBloodFFP,IntraopBloodFactorVII,IntraopBloodPlatelet,IntraopBloodProducts,IntraopBloodRBC,IntraopMedEpsilonAmi0Caproic,IntraopMedTranexamicAcid,Introp DEX or nDEX,Left_Main_Disease,MACE,MedsG2b3aInhibitorMed,Mitral_Insufficiency,OthCard_AICD,Oth_Heart_Block,Other_Cardiac_Intervention,Peri_Op_MI,Peripheral_Vasc_Disease,PreOpMed Antiplatelets,PreOpMedACE_ARBInhibitors,PreOpMedADPInhibitors5Days,PreOpMedAntiarrhythmics,PreOpMedAnticoagulants,PreOpMedAspirin,PreOpMedCoumadin,PreOpMedGPIIbIIIaInhibitor,PreOpMedINotropes,PreOpMedLipidLowering,PreOpMedNitratesIV,PreOpMedSteroids,PreOp_BetaBlockers,PreOp_Ca_Antagonists,PreOp_Digitalis,PreOp_Diuretics,PrevArrhythmiaSurgery,PrevOthCardPCI,Previous_CABG,Previous_CV_Intervention,Previous_Valve,PriorHeartFailure,Pulmonic_Procedure,Pulmonic_Stenosis,STS_History.Renal_Failure,Smoking,Status,Tricuspid_Insufficiency,Tricuspid_Procedure,VSMitral,Valve,ValveDisAortic,ValveDisMitral,ValveDisPulmonic,ValveDisTricuspid,_MI,mitral stenosis
2
+ First_non_AFib,73,28.73,1.3,164,199,2.0,1.0,2.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
3
+ Second_non_AFib,84,24.63,1.5,220,278,5.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,1.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,2.0,1.0,7.0,1.0,0.0,1.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,2.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
4
+ First_Afib,63,41.03,0.9,119,152,0.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
5
+ Second_Afib,52,26.74,4.5,36,59,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0