seyonec commited on
Commit
2e7fa7b
·
1 Parent(s): 4520ee8

clean benchmarking on new against max-sep

Browse files
clean_selection/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ For using retrieval-conditioned conformal prediction on the CLEAN model, we assume you have followed the steps to install the CLEAN package (see the CLEAN GitHub repository). From there, after loading the pre-trained models and datasets, copy the `data/` directory into the `clean_selection` folder and all methods should run out of the box.
clean_selection/analyze_clean_hierarchical_loss_protein_vec.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b4d4e54f9a6d53d0fa29ae72291e719a394f18bae83a55408ad867934a19657b
3
- size 49522
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:462ddc0571c5bb0200078382f301c55a81ff68ba40a0ac71187344ea287f2fb2
3
+ size 66885
clean_selection/clean_utils.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ from CLEAN.utils import *
3
+ from CLEAN.distance_map import *
4
+ from CLEAN.evaluate import *
5
+ from CLEAN.model import LayerNormNet
6
+ from sklearn.metrics import precision_score, recall_score, \
7
+ roc_auc_score, accuracy_score, f1_score, average_precision_score
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import torch
11
+ import pandas as pd
12
+ import pickle
13
+
14
def get_true_labels_test(file_name, test_idx=None):
    """Read ground-truth EC labels from a tab-separated csv.

    file_name: path (without the '.csv' suffix) to a file whose first row is a
        header and whose data rows hold an id in column 0 and ';'-joined EC
        numbers in column 1.
    test_idx: optional iterable of 0-based data-row indices; when given, both
        the returned label list and the EC universe are restricted to those
        rows. When None, every row contributes.

    Returns (true_label, all_label): a list of per-entry EC lists (ordered as
    in the file, filtered to test_idx if given) and the set of all EC numbers
    seen in the selected rows.
    """
    # set membership is O(1); the original list test was O(n) per row
    idx_set = set(test_idx) if test_idx is not None else None
    all_label = set()
    true_label_dict = {}
    with open(file_name + '.csv', 'r') as result:
        csvreader = csv.reader(result, delimiter='\t')
        count = 0
        for row_num, row in enumerate(csvreader):
            if row_num == 0:
                # skip the header row
                continue
            count += 1
            true_ec_lst = row[1].split(';')
            true_label_dict[row[0]] = true_ec_lst
            # collect the EC universe only over the evaluated subset;
            # with no subset, every row counts (the original left all_label
            # empty in that case, which broke downstream metrics)
            if idx_set is None or count - 1 in idx_set:
                all_label.update(true_ec_lst)
    true_label = list(true_label_dict.values())
    if test_idx is not None:
        true_label = [true_label[i] for i in test_idx]
    return true_label, all_label
36
+
37
+
38
def infer_conformal(train_data, test_data, thresh, report_metrics=False,
                    pretrained=True, model_name=None, test_idx=None, name_id="1"):
    """Predict EC numbers by thresholding distances to EC cluster centers
    (conformal calibration) and optionally print evaluation metrics.

    train_data: basename (under ./data/) of the training split csv.
    test_data: basename (under ./data/) of the test split csv.
    thresh: distance threshold (the calibrated lambda) selecting predicted ECs.
    report_metrics: if True, print precision / recall / F1 / AUC.
    pretrained: load ./data/pretrained/<train_data>.pth; otherwise load
        ./data/model/<model_name>.pth.
    model_name: checkpoint basename used when pretrained is False.
    test_idx: optional list of column indices restricting evaluation to a
        subset of the test set; None evaluates every test entry.
    name_id: suffix distinguishing the output file under ./results/.

    Writes results/<test_data><name_id>_conformal.csv via
    write_conformal_choices.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    dtype = torch.float32
    id_ec_train, ec_id_dict_train = get_ec_id_dict('./data/' + train_data + '.csv')
    id_ec_test, _ = get_ec_id_dict('./data/' + test_data + '.csv')
    # load checkpoints
    # NOTE: change this to LayerNormNet(512, 256, device, dtype)
    # and rebuild with [python build.py install]
    # if inferencing on model trained with supconH loss
    model = LayerNormNet(512, 128, device, dtype)

    if pretrained:
        try:
            checkpoint = torch.load('./data/pretrained/' + train_data + '.pth', map_location=device)
        except FileNotFoundError as error:
            # chain the cause so the missing path stays visible in the traceback
            raise Exception('No pretrained weights for this training data') from error
    else:
        try:
            checkpoint = torch.load('./data/model/' + model_name + '.pth', map_location=device)
        except FileNotFoundError as error:
            raise Exception('No model found!') from error

    model.load_state_dict(checkpoint)
    model.eval()
    # load precomputed EC cluster center embeddings if possible
    if train_data == "split70":
        emb_train = torch.load('./data/pretrained/70.pt', map_location=device)
    elif train_data == "split100":
        emb_train = torch.load('./data/pretrained/100.pt', map_location=device)
    else:
        emb_train = model(esm_embedding(ec_id_dict_train, device, dtype))

    emb_test = model_embedding_test(id_ec_test, model, device, dtype)
    eval_dist = get_dist_map_test(emb_train, emb_test, ec_id_dict_train, id_ec_test, device, dtype)
    seed_everything()
    eval_df = pd.DataFrame.from_dict(eval_dist)
    ensure_dirs("./results")
    out_filename = "results/" + test_data + name_id
    # default to evaluating every test column when no subset is given
    idx = test_idx if test_idx is not None else list(range(len(id_ec_test)))
    write_conformal_choices(eval_df, out_filename, threshold=thresh, test_idx=idx)
    if report_metrics:
        pred_label = get_pred_labels(out_filename, pred_type='_conformal')
        pred_probs = get_pred_probs(out_filename, pred_type='_conformal')
        true_label, all_label = get_true_labels_test('./data/' + test_data, test_idx=test_idx)
        pre, rec, f1, roc, acc = get_eval_metrics(
            pred_label, pred_probs, true_label, all_label)
        print("############ EC calling results using conformal calibration on randomly shuffled test set ############")
        print('-' * 75)
        print(f'>>> total samples: {len(true_label)} | total ec: {len(all_label)} \n'
              f'>>> precision: {pre:.3} | recall: {rec:.3}'
              f'| F1: {f1:.3} | AUC: {roc:.3} ')
        print('-' * 75)
95
+
96
+
97
## In theory, we should be able to use the lambda we find on the raw eval distance map,
## slice the test set out of it, and pass it into a small method, infer_conformal,
## that will take in the eval_df and the lambda, and write the choices using this method,
## then report all the metrics we want.
def write_conformal_choices(df, csv_name, threshold, test_idx: list):
    """
    df: dataframe containing the distances between the test set and the EC centroids
    csv_name: name of the csv file to write the choices to
    threshold: threshold to use for the choices (euclidean distance by default, so <=)
    test_idx: list of indices of the test set within the dataframe. this is how we splice the columns
              to get the ones we want to test on, not the ones calibrated on.

    Writes <csv_name>_conformal.csv with one row per test entry
    ('id,EC:<ec>/<dist>,...') and returns the per-entry lists of distances.
    """
    dists = []
    # 'with' guarantees the csv is flushed and closed even on error
    with open(csv_name + '_conformal.csv', 'w', newline='') as out_file:
        csvwriter = csv.writer(out_file, delimiter=',')
        for col in df.iloc[:, test_idx].columns:
            # grab the EC numbers whose distance is bounded by the threshold
            within_thresh = df[col][df[col] <= threshold]
            ec = [col]
            dist_lst = []
            # iterate label/value pairs instead of positional [] access,
            # which is deprecated on a label-indexed Series
            for ec_label, dist_i in within_thresh.items():
                dist_lst.append(dist_i)
                ec.append('EC:' + str(ec_label) + '/' + "{:.4f}".format(dist_i))
            dists.append(dist_lst)
            csvwriter.writerow(ec)
    return dists
127
+
128
+
129
+ ## Below code is taken from CLEAN/evaluate.py, but modified to take in the test_idx and only eval on that
130
+
131
def infer_maxsep(train_data, test_data, report_metrics=False,
                 pretrained=True, model_name=None, gmm=None, test_idx=None):
    """Predict EC numbers with the maximum-separation rule and optionally
    print evaluation metrics. Modified from CLEAN/evaluate.py to evaluate
    only the columns listed in test_idx.

    train_data: basename (under ./data/) of the training split csv.
    test_data: basename (under ./data/) of the test split csv.
    report_metrics: if True, print precision / recall / F1 / AUC.
    pretrained: load ./data/pretrained/<train_data>.pth; otherwise load
        ./data/model/<model_name>.pth.
    model_name: checkpoint basename used when pretrained is False.
    gmm: optional path to a pickled GMM list used to turn distances into
        confidences.
    test_idx: optional list of column indices restricting evaluation to a
        subset of the test set; None evaluates every test entry.

    Writes results/<test_data>test_idx_maxsep.csv via
    write_max_sep_choices_test.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    dtype = torch.float32
    id_ec_train, ec_id_dict_train = get_ec_id_dict('./data/' + train_data + '.csv')
    id_ec_test, _ = get_ec_id_dict('./data/' + test_data + '.csv')
    # load checkpoints
    # NOTE: change this to LayerNormNet(512, 256, device, dtype)
    # and rebuild with [python build.py install]
    # if inferencing on model trained with supconH loss
    model = LayerNormNet(512, 128, device, dtype)

    if pretrained:
        try:
            checkpoint = torch.load('./data/pretrained/' + train_data + '.pth', map_location=device)
        except FileNotFoundError as error:
            # chain the cause so the missing path stays visible in the traceback
            raise Exception('No pretrained weights for this training data') from error
    else:
        try:
            checkpoint = torch.load('./data/model/' + model_name + '.pth', map_location=device)
        except FileNotFoundError as error:
            raise Exception('No model found!') from error

    model.load_state_dict(checkpoint)
    model.eval()
    # load precomputed EC cluster center embeddings if possible
    if train_data == "split70":
        emb_train = torch.load('./data/pretrained/70.pt', map_location=device)
    elif train_data == "split100":
        emb_train = torch.load('./data/pretrained/100.pt', map_location=device)
    else:
        emb_train = model(esm_embedding(ec_id_dict_train, device, dtype))

    emb_test = model_embedding_test(id_ec_test, model, device, dtype)
    eval_dist = get_dist_map_test(emb_train, emb_test, ec_id_dict_train, id_ec_test, device, dtype)
    seed_everything()
    eval_df = pd.DataFrame.from_dict(eval_dist)
    ensure_dirs("./results")
    out_filename = "results/" + test_data + "test_idx"
    # default to evaluating every test column when no subset is given
    idx = test_idx if test_idx is not None else list(range(len(id_ec_test)))
    write_max_sep_choices_test(eval_df, out_filename, gmm=gmm, test_idx=idx)
    if report_metrics:
        pred_label = get_pred_labels(out_filename, pred_type='_maxsep')
        pred_probs = get_pred_probs(out_filename, pred_type='_maxsep')
        true_label, all_label = get_true_labels_test('./data/' + test_data, test_idx=test_idx)
        pre, rec, f1, roc, acc = get_eval_metrics(
            pred_label, pred_probs, true_label, all_label)
        print("############ EC calling results using maximum separation on randomly shuffled test set ############")
        print('-' * 75)
        print(f'>>> total samples: {len(true_label)} | total ec: {len(all_label)} \n'
              f'>>> precision: {pre:.3} | recall: {rec:.3}'
              f'| F1: {f1:.3} | AUC: {roc:.3} ')
        print('-' * 75)
188
+
189
+
190
def write_max_sep_choices_test(df, csv_name, test_idx, first_grad=True, use_max_grad=False, gmm=None):
    """Write max-separation EC predictions for the columns in test_idx.

    df: dataframe of distances between test entries (columns) and EC
        centroids (index).
    csv_name: output basename; writes <csv_name>_maxsep.csv with one row per
        test entry ('id,EC:<ec>/<dist_or_confidence>,...').
    test_idx: list of column indices to evaluate.
    first_grad / use_max_grad: passed through to maximum_separation to pick
        the cutoff among the 10 nearest ECs.
    gmm: optional path to a pickled GMM list; when given, each distance is
        converted to a confidence via infer_confidence_gmm.
    """
    # load the GMM once up front — the original re-opened and re-unpickled
    # the file for every predicted EC of every column
    gmm_lst = None
    if gmm is not None:
        with open(gmm, 'rb') as gmm_file:
            gmm_lst = pickle.load(gmm_file)
    with open(csv_name + '_maxsep.csv', 'w', newline='') as out_file:
        csvwriter = csv.writer(out_file, delimiter=',')
        for col in df.iloc[:, test_idx].columns:
            smallest_10_dist_df = df[col].nsmallest(10)
            dist_lst = list(smallest_10_dist_df)
            max_sep_i = maximum_separation(dist_lst, first_grad, use_max_grad)
            ec = [col]
            for i in range(max_sep_i + 1):
                EC_i = smallest_10_dist_df.index[i]
                # .iloc: positional [] access on a label index is deprecated
                dist_i = smallest_10_dist_df.iloc[i]
                if gmm_lst is not None:
                    dist_i = infer_confidence_gmm(dist_i, gmm_lst)
                ec.append('EC:' + str(EC_i) + '/' + "{:.4f}".format(dist_i))
            csvwriter.writerow(ec)
    return
211
+
212
+
clean_selection/get_clean_dists.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ff97c21ab385c419a420d0ecbad357e97f4914cb2f94cd19aba810cecc829e93
3
- size 5936
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aedd07adb2c5b2e7a599c94faceb57d292444f3b72f270bba83e26c42290492b
3
+ size 9023