ronboger commited on
Commit
c993f80
·
1 Parent(s): a22ca2c

refactoring + sva

Browse files
.gitignore CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  # Byte-compiled / optimized / DLL files
2
  __pycache__/
3
  *.py[cod]
 
1
+ scratch/
2
+ # data/
3
+ data/*.npy
4
+ data/*.tsv
5
+ data/*.pkl
6
+
7
  # Byte-compiled / optimized / DLL files
8
  __pycache__/
9
  *.py[cod]
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Protein conformal retrieval
2
+
3
+ Code and notebooks from 2024 paper on conformal retrieval with proteins
4
+
5
+ ## Installation
6
+
7
+ `pip install -e .`
8
+
9
+ ## Structure
10
+
11
+ `./protein_conformal`: utility functions to creating confidence sets and assigning probabilities to any protein machine learning model for search
12
+ `./scope`: experiments pertraining to SCOPe
13
+ `./pfam`: notebooks demonstrating how to use our techniques to calibrate false discovery and false negative rates for different pfam classes
14
+ `./ec`: experiments pertraining to EC number classification on uniprot
15
+ `./data`: scripts and notebooks used to process data
16
+ `./clean_selection`: scripts and notebooks used to process data
clean_selection/analyze_clean_hierarchical_loss_protein_vec.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:17792a81e3b66606738ca77aa001077daa7f5c0063b3f956902996871561817c
3
- size 86942
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:732b364272b85d567d6bb29dc0b1df8dcbd2afcc31a578a5e650e2a689ed9e55
3
+ size 86932
clean_selection/clean_utils.py CHANGED
@@ -39,8 +39,8 @@ def get_true_labels_test(file_name, test_idx: None):
39
  return true_label, all_label
40
 
41
 
42
- def infer_conformal(train_data, test_data, thresh, report_metrics = False,
43
- pretrained=True, model_name=None, test_idx=None, name_id="1"):
44
  use_cuda = torch.cuda.is_available()
45
  device = torch.device("cuda:0" if use_cuda else "cpu")
46
  dtype = torch.float32
 
39
  return true_label, all_label
40
 
41
 
42
+ def infer_conformal(train_data, test_data, thresh, report_metrics=False,
43
+ pretrained=True, model_name=None, test_idx=None, name_id="1"):
44
  use_cuda = torch.cuda.is_available()
45
  device = torch.device("cuda:0" if use_cuda else "cpu")
46
  dtype = torch.float32
clean_selection/process_clean_ec.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b6a7b8853fa3df8fe380cea2391bdace7b76605c3e80770b397c1ba5776aee6c
3
- size 5860
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbd2453a7cca1d1420476aa527873c065cff16e25af7296bfb18e1ba510ed7b8
3
+ size 5853
create_learn_then_test.ipynb → data/create_learn_then_test.ipynb RENAMED
File without changes
create_new_learn_then_test.ipynb → data/create_new_learn_then_test.ipynb RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:26cc7e1c33f21c4a1a66f0997d6f1a1b7dad7b5ada452e7742554d15385fa103
3
- size 10022
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c321219b09969d691f09fad1c8cab75e0f4735d811c32871e69ee7e074ca8f10
3
+ size 34530
Interactive-1.ipynb → data/create_smaller_ltt.ipynb RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ad0aca44189d80ee7752b2fa0509d027d12e7d18d47bfe9737a50131054c1d96
3
- size 402714
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25f446c429ebf8cc67fa2bd395ac9fb3b363f24f01eef43e6213c38a42cc7e3f
3
+ size 36171
ec/analyze_ec_hierarchical_loss_protein_vec.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:136cfeed57224aa55620ad3ad776c89de7a7b24fd2012cc20fc2b1c915ebe6ce
3
- size 87826
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6af5a8524e4b92b94ceea5a925a4621760194f95639ddb3d964c39322ab14f90
3
+ size 87812
ec/process_pfam_ec.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8762050ecbdc532dd9c962d8dc28ad54a6aa8fafb5058b99c1e2aab19ea594f8
3
- size 114148
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a10ed21e5ed16e2de4871a50c53bf32cb0ea104c8f97b92a9b39970b7b2aece
3
+ size 114134
analyze_protein_vec_results.ipynb → pfam/analyze_protein_vec_results.ipynb RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:10e65ba7a87a4a3bfebbd30007f0b8ed4173ac59f6d47e14cf8c850f1910a021
3
- size 13863852
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51a908707ef800db0c8be4aae8adc45f441123b511da5c9d272f88a48de2f607
3
+ size 399639
pfam/multidomain_search.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fa68613561b4b7386628dd78f5f06b655cdc69bc493a517b79e92669d909a83
3
+ size 2222
protein_conformal/{scope_utils.py → util.py} RENAMED
@@ -1,23 +1,32 @@
1
- import numpy as np
2
- import matplotlib.pyplot as plt
3
- import pdb
4
- import ipdb
5
  from sklearn.isotonic import IsotonicRegression
6
- import seaborn as sns
7
  from scipy.stats import binom, norm
8
 
 
 
9
  def get_sims_labels(data, partial=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  sims = []
11
  labels = []
12
  for query in data:
13
- similarity = query['S_i']
14
  sims += similarity.tolist()
15
  if partial:
16
- #labels_to_append = np.logical_or.reduce(query['partial'], axis=1).tolist()
17
- # NOTE: no need to do the above for scope - i already handle the pre-processing
18
- labels_to_append = query['partial']
19
  else:
20
- labels_to_append = query['exact']
21
  labels += labels_to_append
22
  return sims, labels
23
 
@@ -29,67 +38,81 @@ def get_arbitrary_attribute(data, attribute: str):
29
  attributes += attribute.tolist()
30
  return attributes
31
 
32
-
 
33
  def get_thresh(data, alpha):
34
  # conformal risk control
35
  all_sim_exact = []
36
  for query in data:
37
- idx = query['exact']
38
- similarity = query['S_i']
39
  sims_to_append = similarity[idx]
40
  all_sim_exact += list(sims_to_append)
41
  n = len(all_sim_exact)
42
  if n > 0:
43
- lhat = np.quantile(all_sim_exact, np.maximum(alpha-(1-alpha)/n, 0), interpolation='lower')
 
 
 
 
44
  else:
45
  lhat = 0
46
  return lhat
47
 
 
48
  # Bentkus p value
49
- def bentkus_p_value(r_hat,n,alpha):
50
- return binom.cdf(np.ceil(n*r_hat),n,alpha/np.e)
 
51
 
52
  # def clt_p_value(r_hat,n,alpha):
 
53
 
54
  def clt_p_value(r_hat, std_hat, n, alpha):
55
  z = (r_hat - alpha) / (std_hat / np.sqrt(n))
56
  p_value = norm.cdf(z)
57
  return p_value
58
-
 
59
  def percentage_of_discoveries(sims, labels, lam):
60
  # FDR: Number of false matches / number of matches
61
  total_discoveries = (sims >= lam).sum(axis=1)
62
- return total_discoveries.mean() / len(labels) # or sims.shape[1]
 
63
 
64
  def risk(sims, labels, lam):
65
  # FDR: Number of false matches / number of matches
66
  total_discoveries = (sims >= lam).sum(axis=1)
67
- false_discoveries = ((1-labels)*(sims >= lam)).sum(axis=1)
68
  total_discoveries = np.maximum(total_discoveries, 1)
69
- return (false_discoveries/total_discoveries).mean()
 
70
 
71
  def calculate_false_negatives(sims, labels, lam):
72
  # FNR: Number of false non-matches / number of non-matches
73
  total_non_matches = labels.sum(axis=1)
74
  false_non_matches = (labels & (sims < lam)).sum(axis=1)
75
  total_non_matches = np.maximum(total_non_matches, 1)
76
- return (false_non_matches/total_non_matches).mean()
 
77
 
78
  def risk_no_empties(sims, labels, lam):
79
  # FDR: Number of false matches / number of matches
80
  total_discoveries = (sims >= lam).sum(axis=1)
81
- false_discoveries = ((1-labels)*(sims >= lam)).sum(axis=1)
82
  idx = total_discoveries > 0
83
  total_discoveries = total_discoveries[idx]
84
  false_discoveries = false_discoveries[idx]
85
- return (false_discoveries/total_discoveries).mean()
 
86
 
87
  def std_loss(sims, labels, lam):
88
  # FDR: Number of false matches / number of matches
89
  total_discoveries = (sims >= lam).sum(axis=1)
90
- false_discoveries = ((1-labels)*(sims >= lam)).sum(axis=1)
91
  total_discoveries = np.maximum(total_discoveries, 1)
92
- return (false_discoveries/total_discoveries).std()
 
93
 
94
  def get_thresh_FDR(labels, sims, alpha, delta=0.5, N=5000):
95
  # FDR control with LTT
@@ -97,114 +120,66 @@ def get_thresh_FDR(labels, sims, alpha, delta=0.5, N=5000):
97
  # sims = np.stack([query['S_i'] for query in data], axis=0)
98
  print(f"sims.max: {sims.max()}")
99
  n = len(labels)
100
- lambdas = np.linspace(sims.min(),sims.max(),N)
101
- risks = np.array( [risk(sims, labels, lam) for lam in lambdas] )
102
- stds = np.array( [std_loss(sims, labels, lam) for lam in lambdas] )
103
- #pvals = np.array( [bentkus_p_value(r,n,alpha) for r in risks] )
104
- pvals = np.array( [clt_p_value(r,s,n,alpha) for r, s in zip(risks, stds)] )
105
  below = pvals <= delta
106
  # Pick the smallest lambda such that all lambda above it have p-value below delta
107
- pvals_satisfy_condition = np.array([ np.all(below[i:])for i in range(N) ])
108
  lhat = lambdas[np.argmax(pvals_satisfy_condition)]
109
  print(f"lhat: {lhat}")
110
  print(f"risk: {risk(sims, labels, lhat)}")
111
  return lhat
112
 
113
- def get_isotone_regression(data):
114
- sims, labels = get_sims_labels(data, partial=True)
115
- ir = IsotonicRegression(out_of_bounds='clip')
 
 
116
  ir.fit(sims, labels)
117
  return ir
118
 
119
- def scope_hierarchical_loss(y_sccs, y_hat_sccs):
120
- """
121
- Find the common ancestor of two sets of SCCs (0 if family, 1 if superfamily, 2 if fold, 3 if class)
122
- """
123
- y_sccs, y_hat_sccs = y_sccs.split('.'), y_hat_sccs.split('.')
124
- first_non_matching_index = next((i for i, (x, y) in enumerate(zip(y_sccs, y_hat_sccs)) if x != y), len(y_sccs))
125
-
126
- loss = len(y_sccs) - first_non_matching_index # ex if the first mismatch is at idx 2 (0-indexed), that means that the second last label (superfamily) is wrong, which is a loss of two (wrong family & superfamily)
127
- exact = True if len(y_sccs) == first_non_matching_index else False
128
-
129
- return loss, exact
130
-
131
- """
132
- def scope_hierarchical_loss(y_sccs, y_hat_sccs, slack = 0):
133
-
134
- #Find the common ancestor of two sets of SCCs (0 if family, 1 if superfamily, 2 if fold, 3 if class)
135
-
136
- # Find the common ancestor of the two sets of SCCs
137
- y_sccs, y_hat_sccs = y_sccs.split('.'), y_hat_sccs.split('.')
138
- assert len(y_sccs) == len(y_hat_sccs) == 4
139
-
140
- loss, count = None, 0
141
- while loss is None:
142
- if y_sccs[-1] != y_hat_sccs[-1]:
143
- count += 1
144
- y_sccs.pop()
145
- y_hat_sccs.pop()
146
- else:
147
- break
148
- loss = count - slack
149
-
150
- return loss == 0
151
- """
152
-
153
 
154
- def get_scope_dict(true_test_idcs, test_df, lookup_idcs, lookup_df, D, I):
 
155
  """
156
- true_test_idcs: indices of the test set within the scope dataframe
157
 
158
- test_df: dataframe containing the test set (indices are the same as true_test_idcs)
 
 
159
 
160
- lookup_idcs: indices of the lookup set within the larger scope dataframe
 
 
161
 
162
- lookup_df: dataframe containing the lookup set (indices are the same as lookup_idcs)
 
 
 
 
163
 
164
- D: distances matrix (400 x 14777 or test x lookup by default)
 
 
 
165
 
166
- I: indices matrix (400 x 14777 or test x lookup by default)
 
167
 
168
- NOTE: Indices computed by FAISS are not the same as the indices of the dataframe, so
169
- we use the lookup_idcs list to map FAISS indices in I to the indices of the dataframe
170
- """
171
 
172
- near_ids = []
173
- min_sim = np.min(D)
174
- max_sim = np.max(D)
175
 
176
- for i in range(len(true_test_idcs)):
177
- test_id = test_df.loc[true_test_idcs[i], 'sid']
178
- test_sccs = test_df.loc[true_test_idcs[i], 'sccs']
179
- query_ids = [lookup_df.loc[lookup_idcs[j], 'sid'] for j in I[i]]
180
- exact_loss = [scope_hierarchical_loss(test_sccs, lookup_df.loc[lookup_idcs[j], 'sccs']) for j in I[i]]
181
- # grab the 2nd element in the tuple belonging to each element of exact_loss as mask_exact
182
- mask_exact = [x[1] for x in exact_loss]
183
- loss = [x[0] for x in exact_loss]
184
-
185
- # define mask_partial as 1 for any element of loss that is <=1 (tolerate retrieving homolog with diff family but same superfamily)
186
- mask_partial = [l <= 1 for l in loss]
187
 
188
- # create a row of size len(lookup_df) where each element is the sum of all entries in S_i until that index
189
- sum = np.cumsum(D[i])
190
- norm_sim = (D[i] - min_sim) / (max_sim - min_sim) # convert similarities into a probability space (0, 1) based on (min_sim, max_sim)
191
- #mask_exact = [test_sccs == lookup_df.loc[lookup_idcs[j], 'sccs'] for j in I[i]]
192
 
193
- sum_norm_s_i = np.cumsum(norm_sim)
194
- near_ids.append({
195
- 'test_id': test_id,
196
- 'query_ids': query_ids,
197
- #'meta_query': meta_query,
198
- 'loss' : loss,
199
- 'exact': mask_exact,
200
- 'partial': mask_partial,
201
- 'S_i': D[i],
202
- 'Sum_i' : sum,
203
- 'Norm_S_i' : norm_sim,
204
- 'Sum_Norm_S_i': sum_norm_s_i,
205
- 'I_i': I[i]
206
- })
207
- return near_ids
208
 
209
  def validate_lhat(data, lhat):
210
  total_missed = 0
@@ -215,16 +190,16 @@ def validate_lhat(data, lhat):
215
  total_partial = 0
216
  total_partial_identified = 0
217
  for query in data:
218
- idx = query['exact']
219
  # if partial has multiple rows, we want to take the logical or of all of them. Otherwise just set it to the single row
220
  # check if there is one or more rows
221
  # query['partial'] = np.array(query['partial'])
222
- if len(np.array(query['partial']).shape) > 1:
223
- idx_partial = np.logical_or.reduce(query['partial'], axis=1)
224
  else:
225
- idx_partial = query['partial']
226
-
227
- sims = query['S_i']
228
  sims_exact = sims[idx]
229
  sims_partial = sims[idx_partial]
230
  total_missed += (sims_exact < lhat).sum()
@@ -237,9 +212,14 @@ def validate_lhat(data, lhat):
237
  total_exact += len(sims_exact)
238
  total_inexact_identified += (sims[~np.array(idx)] >= lhat).sum()
239
  total_identified += (sims >= lhat).sum()
240
- return total_missed/total_exact, total_inexact_identified/total_identified, total_missed_partial/total_partial, total_partial_identified/total_identified
241
-
242
-
 
 
 
 
 
243
  def load_database(lookup_database):
244
  """
245
  lookup_database: NxM matrix of embeddings
@@ -255,11 +235,39 @@ def load_database(lookup_database):
255
 
256
  return index
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  def get_thresh_hierarchical(data, lambdas, alpha):
259
  # get the worst case loss
260
- wc_loss = max([np.sum(x['loss']) for x in data])
261
  wc_loss = wc_loss / 14777
262
- loss_thresh = alpha - (wc_loss - alpha)/len(data) # normalize by size of calib set
 
 
263
  losses = []
264
  best_lam = None
265
  for lam in lambdas:
@@ -275,28 +283,37 @@ def get_thresh_hierarchical(data, lambdas, alpha):
275
 
276
  return best_lam, losses
277
 
 
278
  def get_hierarchical_loss(data_, lambda_):
279
  losses = []
280
  for query in data_:
281
- thresh_idx = query['Sum_Norm_S_i'] <= lambda_
282
  if np.sum(thresh_idx) == 0:
283
  loss = 0
284
  else:
285
- loss = np.sum(np.asarray(query['loss'])[thresh_idx]) / np.sum(thresh_idx) # NOTE: have fixed denominator, but alpha has to change
286
- losses.append(loss) # average over all queries
 
 
287
  return np.mean(losses)
288
 
289
 
290
  def get_thresh_max_hierarchical(data, lambdas, alpha, sim="cosine"):
291
  # get the worst case loss
292
- wc_loss = 4 # in the max case, the max_loss is simply retrieving a protein with different class
293
- loss_thresh = alpha - (wc_loss - alpha)/len(data) # normalize by size of calib set
 
 
294
  losses = []
295
  best_lam = None
296
  if sim == "cosine":
297
  ## reverse lambdas and return list
298
  lambdas = list(reversed(lambdas))
299
- for lam in lambdas: # start from the largest lambda since we are dealing with raw similarity scores
 
 
 
 
300
  per_lam_loss = get_hierarchical_max_loss(data, lam, sim=sim)
301
  if per_lam_loss > loss_thresh:
302
  break
@@ -309,42 +326,119 @@ def get_thresh_max_hierarchical(data, lambdas, alpha, sim="cosine"):
309
 
310
  return best_lam, losses
311
 
 
312
  def get_hierarchical_max_loss(data_, lambda_, sim="cosine"):
313
  losses = []
314
  for query in data_:
315
  if sim == "cosine":
316
- thresh_idx = query['S_i'] >= lambda_
317
  else:
318
- thresh_idx = query['S_i'] <= lambda_
319
  if np.sum(thresh_idx) == 0:
320
  loss = 0
321
  else:
322
- loss = np.max(np.asarray(query['loss'])[thresh_idx]) # monotonic loss
323
- #loss = np.sum(np.asarray(query['loss'])[thresh_idx]) / np.sum(thresh_idx) # NOTE: have fixed denominator, but alpha has to change
324
- losses.append(loss) # average over all queries
325
  return np.mean(losses)
326
 
 
 
 
327
 
328
- def query(index, queries, k=10):
329
- # On the fly FAISS import to prevent installation issues
330
- import faiss
331
 
332
- # Search indexed database
333
- faiss.normalize_L2(queries)
334
- D, I = index.search(queries, k)
335
 
336
- return (D, I)
337
 
338
- def build_scope_tree(list_sccs):
 
 
 
 
 
339
  """
340
- Build a scope tree from a list of SCCs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  """
342
- tree = {}
343
- for sccs in list_sccs:
344
- sccs = sccs.split('.')
345
- node = tree
346
- for s in sccs:
347
- if s not in node:
348
- node[s] = {}
349
- node = node[s]
350
- return tree
 
 
 
 
 
 
 
 
 
 
1
  from sklearn.isotonic import IsotonicRegression
2
+ import numpy as np
3
  from scipy.stats import binom, norm
4
 
5
+
6
+ # get similarity scores and labels for a set of proteins
7
  def get_sims_labels(data, partial=False):
8
+ """
9
+ Get the similarities and labels from the given data.
10
+
11
+ Parameters:
12
+ - data: A list of query data.
13
+ - partial: A boolean indicating whether to use partial hits or exact hits.
14
+ exact: Pfam1001 == Pfam1001
15
+ partial: Pfam1001 in [Pfam1001,Pfam1002], where [Pfam1001,Pfam1002] is the set of Pfam domains in a protein.
16
+
17
+ Returns:
18
+ - sims: A list of similarity scores.
19
+ - labels: A list of labels.
20
+ """
21
  sims = []
22
  labels = []
23
  for query in data:
24
+ similarity = query["S_i"]
25
  sims += similarity.tolist()
26
  if partial:
27
+ labels_to_append = np.logical_or.reduce(query["partial"], axis=1).tolist()
 
 
28
  else:
29
+ labels_to_append = query["exact"]
30
  labels += labels_to_append
31
  return sims, labels
32
 
 
38
  attributes += attribute.tolist()
39
  return attributes
40
 
41
+
42
+
43
  def get_thresh(data, alpha):
44
  # conformal risk control
45
  all_sim_exact = []
46
  for query in data:
47
+ idx = query["exact"]
48
+ similarity = query["S_i"]
49
  sims_to_append = similarity[idx]
50
  all_sim_exact += list(sims_to_append)
51
  n = len(all_sim_exact)
52
  if n > 0:
53
+ lhat = np.quantile(
54
+ all_sim_exact,
55
+ np.maximum(alpha - (1 - alpha) / n, 0),
56
+ interpolation="lower",
57
+ )
58
  else:
59
  lhat = 0
60
  return lhat
61
 
62
+
63
  # Bentkus p value
64
+ def bentkus_p_value(r_hat, n, alpha):
65
+ return binom.cdf(np.ceil(n * r_hat), n, alpha / np.e)
66
+
67
 
68
  # def clt_p_value(r_hat,n,alpha):
69
+ # TODO: we may have wanted to do a different implementation of this
70
 
71
  def clt_p_value(r_hat, std_hat, n, alpha):
72
  z = (r_hat - alpha) / (std_hat / np.sqrt(n))
73
  p_value = norm.cdf(z)
74
  return p_value
75
+
76
+
77
  def percentage_of_discoveries(sims, labels, lam):
78
  # FDR: Number of false matches / number of matches
79
  total_discoveries = (sims >= lam).sum(axis=1)
80
+ return total_discoveries.mean() / len(labels) # or sims.shape[1]
81
+
82
 
83
  def risk(sims, labels, lam):
84
  # FDR: Number of false matches / number of matches
85
  total_discoveries = (sims >= lam).sum(axis=1)
86
+ false_discoveries = ((1 - labels) * (sims >= lam)).sum(axis=1)
87
  total_discoveries = np.maximum(total_discoveries, 1)
88
+ return (false_discoveries / total_discoveries).mean()
89
+
90
 
91
  def calculate_false_negatives(sims, labels, lam):
92
  # FNR: Number of false non-matches / number of non-matches
93
  total_non_matches = labels.sum(axis=1)
94
  false_non_matches = (labels & (sims < lam)).sum(axis=1)
95
  total_non_matches = np.maximum(total_non_matches, 1)
96
+ return (false_non_matches / total_non_matches).mean()
97
+
98
 
99
  def risk_no_empties(sims, labels, lam):
100
  # FDR: Number of false matches / number of matches
101
  total_discoveries = (sims >= lam).sum(axis=1)
102
+ false_discoveries = ((1 - labels) * (sims >= lam)).sum(axis=1)
103
  idx = total_discoveries > 0
104
  total_discoveries = total_discoveries[idx]
105
  false_discoveries = false_discoveries[idx]
106
+ return (false_discoveries / total_discoveries).mean()
107
+
108
 
109
  def std_loss(sims, labels, lam):
110
  # FDR: Number of false matches / number of matches
111
  total_discoveries = (sims >= lam).sum(axis=1)
112
+ false_discoveries = ((1 - labels) * (sims >= lam)).sum(axis=1)
113
  total_discoveries = np.maximum(total_discoveries, 1)
114
+ return (false_discoveries / total_discoveries).std()
115
+
116
 
117
  def get_thresh_FDR(labels, sims, alpha, delta=0.5, N=5000):
118
  # FDR control with LTT
 
120
  # sims = np.stack([query['S_i'] for query in data], axis=0)
121
  print(f"sims.max: {sims.max()}")
122
  n = len(labels)
123
+ lambdas = np.linspace(sims.min(), sims.max(), N)
124
+ risks = np.array([risk(sims, labels, lam) for lam in lambdas])
125
+ stds = np.array([std_loss(sims, labels, lam) for lam in lambdas])
126
+ # pvals = np.array( [bentkus_p_value(r,n,alpha) for r in risks] )
127
+ pvals = np.array([clt_p_value(r, s, n, alpha) for r, s in zip(risks, stds)])
128
  below = pvals <= delta
129
  # Pick the smallest lambda such that all lambda above it have p-value below delta
130
+ pvals_satisfy_condition = np.array([np.all(below[i:]) for i in range(N)])
131
  lhat = lambdas[np.argmax(pvals_satisfy_condition)]
132
  print(f"lhat: {lhat}")
133
  print(f"risk: {risk(sims, labels, lhat)}")
134
  return lhat
135
 
136
+
137
+ # get the isotonic regression
138
+ def get_isotone_regression(data, partial=False):
139
+ sims, labels = get_sims_labels(data, partial=partial)
140
+ ir = IsotonicRegression(out_of_bounds="clip")
141
  ir.fit(sims, labels)
142
  return ir
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ # Simplified version of Venn Abers prediction
146
+ def simplifed_venn_abers_prediction(calib_data, test_data_point):
147
  """
148
+ Perform simplified Venn Abers prediction.
149
 
150
+ Args:
151
+ calib_data (list): List of calibration data points.
152
+ test_data_point: The test data point to be predicted.
153
 
154
+ Returns:
155
+ Tuple: A tuple containing the predicted probabilities for two isotonic regressions.
156
+ """
157
 
158
+ sims, labels = get_sims_labels(calib_data, partial=False)
159
+ print(sims)
160
+ print(labels)
161
+ print(len(sims))
162
+ print(len(labels))
163
 
164
+ sims.append(test_data_point)
165
+ labels.append(True)
166
+ print(len(sims))
167
+ print(len(labels))
168
 
169
+ ir_0 = IsotonicRegression(out_of_bounds="clip")
170
+ ir_1 = IsotonicRegression(out_of_bounds="clip")
171
 
172
+ ir_0.fit(sims, labels)
 
 
173
 
174
+ # run the second isotonic regression with the last point as a false label
175
+ labels[-1] = False
176
+ ir_1.fit(sims, labels)
177
 
178
+ p_0 = ir_0.predict([test_data_point])[0]
179
+ p_1 = ir_1.predict([test_data_point])[0]
 
 
 
 
 
 
 
 
 
180
 
181
+ return p_0, p_1
 
 
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  def validate_lhat(data, lhat):
185
  total_missed = 0
 
190
  total_partial = 0
191
  total_partial_identified = 0
192
  for query in data:
193
+ idx = query["exact"]
194
  # if partial has multiple rows, we want to take the logical or of all of them. Otherwise just set it to the single row
195
  # check if there is one or more rows
196
  # query['partial'] = np.array(query['partial'])
197
+ if len(np.array(query["partial"]).shape) > 1:
198
+ idx_partial = np.logical_or.reduce(query["partial"], axis=1)
199
  else:
200
+ idx_partial = query["partial"]
201
+
202
+ sims = query["S_i"]
203
  sims_exact = sims[idx]
204
  sims_partial = sims[idx_partial]
205
  total_missed += (sims_exact < lhat).sum()
 
212
  total_exact += len(sims_exact)
213
  total_inexact_identified += (sims[~np.array(idx)] >= lhat).sum()
214
  total_identified += (sims >= lhat).sum()
215
+ return (
216
+ total_missed / total_exact,
217
+ total_inexact_identified / total_identified,
218
+ total_missed_partial / total_partial,
219
+ total_partial_identified / total_identified,
220
+ )
221
+
222
+ ### FAISS related functions
223
  def load_database(lookup_database):
224
  """
225
  lookup_database: NxM matrix of embeddings
 
235
 
236
  return index
237
 
238
+
239
+ def query(index, queries, k=10):
240
+ # On the fly FAISS import to prevent installation issues
241
+ import faiss
242
+
243
+ # Search indexed database
244
+ faiss.normalize_L2(queries)
245
+ D, I = index.search(queries, k)
246
+
247
+ return (D, I)
248
+
249
+ ### functions for hierarchical conformal
250
+ def build_scope_tree(list_sccs):
251
+ """
252
+ Build a scope tree from a list of SCCs
253
+ """
254
+ tree = {}
255
+ for sccs in list_sccs:
256
+ sccs = sccs.split(".")
257
+ node = tree
258
+ for s in sccs:
259
+ if s not in node:
260
+ node[s] = {}
261
+ node = node[s]
262
+ return tree
263
+
264
  def get_thresh_hierarchical(data, lambdas, alpha):
265
  # get the worst case loss
266
+ wc_loss = max([np.sum(x["loss"]) for x in data])
267
  wc_loss = wc_loss / 14777
268
+ loss_thresh = alpha - (wc_loss - alpha) / len(
269
+ data
270
+ ) # normalize by size of calib set
271
  losses = []
272
  best_lam = None
273
  for lam in lambdas:
 
283
 
284
  return best_lam, losses
285
 
286
+
287
  def get_hierarchical_loss(data_, lambda_):
288
  losses = []
289
  for query in data_:
290
+ thresh_idx = query["Sum_Norm_S_i"] <= lambda_
291
  if np.sum(thresh_idx) == 0:
292
  loss = 0
293
  else:
294
+ loss = np.sum(np.asarray(query["loss"])[thresh_idx]) / np.sum(
295
+ thresh_idx
296
+ ) # NOTE: have fixed denominator, but alpha has to change
297
+ losses.append(loss) # average over all queries
298
  return np.mean(losses)
299
 
300
 
301
  def get_thresh_max_hierarchical(data, lambdas, alpha, sim="cosine"):
302
  # get the worst case loss
303
+ wc_loss = 4 # in the max case, the max_loss is simply retrieving a protein with different class
304
+ loss_thresh = alpha - (wc_loss - alpha) / len(
305
+ data
306
+ ) # normalize by size of calib set
307
  losses = []
308
  best_lam = None
309
  if sim == "cosine":
310
  ## reverse lambdas and return list
311
  lambdas = list(reversed(lambdas))
312
+ for (
313
+ lam
314
+ ) in (
315
+ lambdas
316
+ ): # start from the largest lambda since we are dealing with raw similarity scores
317
  per_lam_loss = get_hierarchical_max_loss(data, lam, sim=sim)
318
  if per_lam_loss > loss_thresh:
319
  break
 
326
 
327
  return best_lam, losses
328
 
329
+
330
  def get_hierarchical_max_loss(data_, lambda_, sim="cosine"):
331
  losses = []
332
  for query in data_:
333
  if sim == "cosine":
334
+ thresh_idx = query["S_i"] >= lambda_
335
  else:
336
+ thresh_idx = query["S_i"] <= lambda_
337
  if np.sum(thresh_idx) == 0:
338
  loss = 0
339
  else:
340
+ loss = np.max(np.asarray(query["loss"])[thresh_idx]) # monotonic loss
341
+ # loss = np.sum(np.asarray(query['loss'])[thresh_idx]) / np.sum(thresh_idx) # NOTE: have fixed denominator, but alpha has to change
342
+ losses.append(loss) # average over all queries
343
  return np.mean(losses)
344
 
345
+ def get_scope_dict(true_test_idcs, test_df, lookup_idcs, lookup_df, D, I):
346
+ """
347
+ true_test_idcs: indices of the test set within the scope dataframe
348
 
349
+ test_df: dataframe containing the test set (indices are the same as true_test_idcs)
 
 
350
 
351
+ lookup_idcs: indices of the lookup set within the larger scope dataframe
 
 
352
 
353
+ lookup_df: dataframe containing the lookup set (indices are the same as lookup_idcs)
354
 
355
+ D: distances matrix (400 x 14777 or test x lookup by default)
356
+
357
+ I: indices matrix (400 x 14777 or test x lookup by default)
358
+
359
+ NOTE: Indices computed by FAISS are not the same as the indices of the dataframe, so
360
+ we use the lookup_idcs list to map FAISS indices in I to the indices of the dataframe
361
  """
362
+
363
+ near_ids = []
364
+ min_sim = np.min(D)
365
+ max_sim = np.max(D)
366
+
367
+ for i in range(len(true_test_idcs)):
368
+ test_id = test_df.loc[true_test_idcs[i], "sid"]
369
+ test_sccs = test_df.loc[true_test_idcs[i], "sccs"]
370
+ query_ids = [lookup_df.loc[lookup_idcs[j], "sid"] for j in I[i]]
371
+ exact_loss = [
372
+ scope_hierarchical_loss(test_sccs, lookup_df.loc[lookup_idcs[j], "sccs"])
373
+ for j in I[i]
374
+ ]
375
+ # grab the 2nd element in the tuple belonging to each element of exact_loss as mask_exact
376
+ mask_exact = [x[1] for x in exact_loss]
377
+ loss = [x[0] for x in exact_loss]
378
+
379
+ # define mask_partial as 1 for any element of loss that is <=1 (tolerate retrieving homolog with diff family but same superfamily)
380
+ mask_partial = [l <= 1 for l in loss]
381
+
382
+ # create a row of size len(lookup_df) where each element is the sum of all entries in S_i until that index
383
+ sum = np.cumsum(D[i])
384
+ norm_sim = (D[i] - min_sim) / (
385
+ max_sim - min_sim
386
+ ) # convert similarities into a probability space (0, 1) based on (min_sim, max_sim)
387
+ # mask_exact = [test_sccs == lookup_df.loc[lookup_idcs[j], 'sccs'] for j in I[i]]
388
+
389
+ sum_norm_s_i = np.cumsum(norm_sim)
390
+ near_ids.append(
391
+ {
392
+ "test_id": test_id,
393
+ "query_ids": query_ids,
394
+ #'meta_query': meta_query,
395
+ "loss": loss,
396
+ "exact": mask_exact,
397
+ "partial": mask_partial,
398
+ "S_i": D[i],
399
+ "Sum_i": sum,
400
+ "Norm_S_i": norm_sim,
401
+ "Sum_Norm_S_i": sum_norm_s_i,
402
+ "I_i": I[i],
403
+ }
404
+ )
405
+ return near_ids
406
+
407
+ """
408
+ def scope_hierarchical_loss(y_sccs, y_hat_sccs, slack = 0):
409
+
410
+ #Find the common ancestor of two sets of SCCs (0 if family, 1 if superfamily, 2 if fold, 3 if class)
411
+
412
+ # Find the common ancestor of the two sets of SCCs
413
+ y_sccs, y_hat_sccs = y_sccs.split('.'), y_hat_sccs.split('.')
414
+ assert len(y_sccs) == len(y_hat_sccs) == 4
415
+
416
+ loss, count = None, 0
417
+ while loss is None:
418
+ if y_sccs[-1] != y_hat_sccs[-1]:
419
+ count += 1
420
+ y_sccs.pop()
421
+ y_hat_sccs.pop()
422
+ else:
423
+ break
424
+ loss = count - slack
425
+
426
+ return loss == 0
427
+ """
428
+
429
+ def scope_hierarchical_loss(y_sccs, y_hat_sccs):
430
  """
431
+ Find the common ancestor of two sets of SCCs (0 if family, 1 if superfamily, 2 if fold, 3 if class)
432
+ """
433
+ y_sccs, y_hat_sccs = y_sccs.split("."), y_hat_sccs.split(".")
434
+ first_non_matching_index = next(
435
+ (i for i, (x, y) in enumerate(zip(y_sccs, y_hat_sccs)) if x != y), len(y_sccs)
436
+ )
437
+
438
+ loss = (
439
+ len(y_sccs) - first_non_matching_index
440
+ ) # ex if the first mismatch is at idx 2 (0-indexed), that means that the second last label (superfamily) is wrong, which is a loss of two (wrong family & superfamily)
441
+ exact = True if len(y_sccs) == first_non_matching_index else False
442
+
443
+ return loss, exact
444
+
scope/analyze_scope_hierarchical_loss_protein_vec.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ec68b7560cd591d661a8f77f9c2df9354aa99b2bf350ad7b96803a9f9eacece
3
- size 91708
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05e1080b7854b1b02b59c8913509a599d5d65f249f15712aad9d6dc58e877092
3
+ size 91694
scope/analyze_scope_protein_vec.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5445b059ac49809729e487c50f60ed8ac2123f03126f56f77a7348313f26c11b
3
- size 449975
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15d00e9ddd6e3e23490a415f942065d9f485bac0d437f028eb400853aa75ffc2
3
+ size 449919
scope/test_scope_conformal_retrieval.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c8f4006f8108efcd9cb99a08cabde82ccfda250dc4807f7ff5fbe4a9aaccbe90
3
- size 3232293
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34d3c6c5df4cef9235c33fd0c73e80507f8ba533d495d5c1f1df39323d52cb21
3
+ size 3232279
util.py DELETED
@@ -1,183 +0,0 @@
1
- from sklearn.isotonic import IsotonicRegression
2
- import numpy as np
3
- from scipy.stats import binom, norm
4
-
5
-
6
- def get_sims_labels(data, partial=False):
7
- sims = []
8
- labels = []
9
- for query in data:
10
- similarity = query["S_i"]
11
- sims += similarity.tolist()
12
- if partial:
13
- labels_to_append = np.logical_or.reduce(query["partial"], axis=1).tolist()
14
- else:
15
- labels_to_append = query["exact"]
16
- labels += labels_to_append
17
- return sims, labels
18
-
19
-
20
- def get_thresh(data, alpha):
21
- # conformal risk control
22
- all_sim_exact = []
23
- for query in data:
24
- idx = query["exact"]
25
- similarity = query["S_i"]
26
- sims_to_append = similarity[idx]
27
- all_sim_exact += list(sims_to_append)
28
- n = len(all_sim_exact)
29
- if n > 0:
30
- lhat = np.quantile(
31
- all_sim_exact,
32
- np.maximum(alpha - (1 - alpha) / n, 0),
33
- interpolation="lower",
34
- )
35
- else:
36
- lhat = 0
37
- return lhat
38
-
39
-
40
- # Bentkus p value
41
- def bentkus_p_value(r_hat, n, alpha):
42
- return binom.cdf(np.ceil(n * r_hat), n, alpha / np.e)
43
-
44
-
45
- # def clt_p_value(r_hat,n,alpha):
46
-
47
-
48
- def clt_p_value(r_hat, std_hat, n, alpha):
49
- z = (r_hat - alpha) / (std_hat / np.sqrt(n))
50
- p_value = norm.cdf(z)
51
- return p_value
52
-
53
-
54
- def percentage_of_discoveries(sims, labels, lam):
55
- # FDR: Number of false matches / number of matches
56
- total_discoveries = (sims >= lam).sum(axis=1)
57
- return total_discoveries.mean() / len(labels) # or sims.shape[1]
58
-
59
-
60
- def risk(sims, labels, lam):
61
- # FDR: Number of false matches / number of matches
62
- total_discoveries = (sims >= lam).sum(axis=1)
63
- false_discoveries = ((1 - labels) * (sims >= lam)).sum(axis=1)
64
- total_discoveries = np.maximum(total_discoveries, 1)
65
- return (false_discoveries / total_discoveries).mean()
66
-
67
-
68
- def calculate_false_negatives(sims, labels, lam):
69
- # FNR: Number of false non-matches / number of non-matches
70
- total_non_matches = labels.sum(axis=1)
71
- false_non_matches = (labels & (sims < lam)).sum(axis=1)
72
- total_non_matches = np.maximum(total_non_matches, 1)
73
- return (false_non_matches / total_non_matches).mean()
74
-
75
-
76
- def risk_no_empties(sims, labels, lam):
77
- # FDR: Number of false matches / number of matches
78
- total_discoveries = (sims >= lam).sum(axis=1)
79
- false_discoveries = ((1 - labels) * (sims >= lam)).sum(axis=1)
80
- idx = total_discoveries > 0
81
- total_discoveries = total_discoveries[idx]
82
- false_discoveries = false_discoveries[idx]
83
- return (false_discoveries / total_discoveries).mean()
84
-
85
-
86
- def std_loss(sims, labels, lam):
87
- # FDR: Number of false matches / number of matches
88
- total_discoveries = (sims >= lam).sum(axis=1)
89
- false_discoveries = ((1 - labels) * (sims >= lam)).sum(axis=1)
90
- total_discoveries = np.maximum(total_discoveries, 1)
91
- return (false_discoveries / total_discoveries).std()
92
-
93
-
94
- def get_thresh_FDR(labels, sims, alpha, delta=0.5, N=5000):
95
- # FDR control with LTT
96
- # labels = np.stack([query['exact'] for query in data], axis=0)
97
- # sims = np.stack([query['S_i'] for query in data], axis=0)
98
- print(f"sims.max: {sims.max()}")
99
- n = len(labels)
100
- lambdas = np.linspace(sims.min(), sims.max(), N)
101
- risks = np.array([risk(sims, labels, lam) for lam in lambdas])
102
- stds = np.array([std_loss(sims, labels, lam) for lam in lambdas])
103
- # pvals = np.array( [bentkus_p_value(r,n,alpha) for r in risks] )
104
- pvals = np.array([clt_p_value(r, s, n, alpha) for r, s in zip(risks, stds)])
105
- below = pvals <= delta
106
- # Pick the smallest lambda such that all lambda above it have p-value below delta
107
- pvals_satisfy_condition = np.array([np.all(below[i:]) for i in range(N)])
108
- lhat = lambdas[np.argmax(pvals_satisfy_condition)]
109
- print(f"lhat: {lhat}")
110
- print(f"risk: {risk(sims, labels, lhat)}")
111
- return lhat
112
-
113
-
114
- def get_isotone_regression(data):
115
- sims, labels = get_sims_labels(data, partial=True)
116
- ir = IsotonicRegression(out_of_bounds="clip")
117
- ir.fit(sims, labels)
118
- return ir
119
-
120
-
121
- def simplifed_venn_abers_prediction(calib_data, test_data_point):
122
- sims, labels = get_sims_labels(calib_data, partial=False)
123
- print(sims)
124
- print(labels)
125
- print(len(sims))
126
- print(len(labels))
127
-
128
- sims.append(test_data_point)
129
- labels.append(True)
130
- print(len(sims))
131
- print(len(labels))
132
-
133
- ir_0 = IsotonicRegression(out_of_bounds="clip")
134
- ir_1 = IsotonicRegression(out_of_bounds="clip")
135
-
136
- ir_0.fit(sims, labels)
137
-
138
- labels[-1] = False
139
- ir_1.fit(sims, labels)
140
-
141
- p_0 = ir_0.predict([test_data_point])[0]
142
- p_1 = ir_1.predict([test_data_point])[0]
143
-
144
- return p_0, p_1
145
-
146
-
147
- def validate_lhat(data, lhat):
148
- total_missed = 0
149
- total_missed_partial = 0
150
- total_exact = 0
151
- total_inexact_identified = 0
152
- total_identified = 0
153
- total_partial = 0
154
- total_partial_identified = 0
155
- for query in data:
156
- idx = query["exact"]
157
- # if partial has multiple rows, we want to take the logical or of all of them. Otherwise just set it to the single row
158
- # check if there is one or more rows
159
- # query['partial'] = np.array(query['partial'])
160
- if len(np.array(query["partial"]).shape) > 1:
161
- idx_partial = np.logical_or.reduce(query["partial"], axis=1)
162
- else:
163
- idx_partial = query["partial"]
164
-
165
- sims = query["S_i"]
166
- sims_exact = sims[idx]
167
- sims_partial = sims[idx_partial]
168
- total_missed += (sims_exact < lhat).sum()
169
-
170
- # TODO: are there any divisions by zero here?
171
- total_missed_partial += (sims_partial < lhat).sum()
172
- total_partial_identified += (sims_partial >= lhat).sum()
173
- total_partial += len(sims_partial)
174
-
175
- total_exact += len(sims_exact)
176
- total_inexact_identified += (sims[~np.array(idx)] >= lhat).sum()
177
- total_identified += (sims >= lhat).sum()
178
- return (
179
- total_missed / total_exact,
180
- total_inexact_identified / total_identified,
181
- total_missed_partial / total_partial,
182
- total_partial_identified / total_identified,
183
- )