'''
# default parameters
python scripts/search.py --query_embedding data/inputs/queries_embeddings.npy --query_fasta data/inputs/rcsb_pdb_4CS4.fasta --lookup_embedding data/lookup/scope_lookup_embeddings.npy --lookup_fasta data/lookup/scope_lookup.fasta --fdr --output results/search_results.csv --k 100
# lower lambda
python scripts/search.py --query_embedding data/inputs/queries_embeddings.npy --query_fasta data/inputs/rcsb_pdb_4CS4.fasta --lookup_embedding data/lookup/scope_lookup_embeddings.npy --lookup_fasta data/lookup/scope_lookup.fasta --fdr --fdr_lambda 0.5 --output results/search_results.csv --k 100 --save_inter
'''
import argparse
import os

import numpy as np
import pandas as pd

from protein_conformal.util import *
def main(args):
    """Run a k-NN similarity search of query embeddings against a lookup set,
    then filter the hits with a conformal risk-control threshold.

    Reads query/lookup embeddings (.npy) and sequences (FASTA, or TSV with
    "Sequence"/"Entry"/"Pfam"/"Protein names" columns for the lookup set),
    queries a FAISS index for the top-``args.k`` neighbors, optionally saves
    the unfiltered hits, filters by an FDR or FNR lambda threshold on the
    similarity score, and writes the result table to ``args.output``.

    Raises:
        ValueError: if both --fdr and --fnr are requested, or if --fnr is
            requested without a precomputed --fnr_lambda.
    """
    query_embeddings = np.load(args.query_embedding, allow_pickle=True)
    lookup_embeddings = np.load(args.lookup_embedding, allow_pickle=True)
    query_seqs, query_meta = read_fasta(args.query_fasta)

    is_tsv = args.lookup_fasta.endswith(".tsv")
    if is_tsv:
        print("Loading lookup sequences and metadata from TSV")
        lookup_df = pd.read_csv(args.lookup_fasta, sep="\t")
        lookup_seqs = lookup_df["Sequence"].values
        # One (Entry, Pfam, Protein names) tuple per lookup row.
        metadata_columns = ["Entry", "Pfam", "Protein names"]
        lookup_meta = lookup_df[metadata_columns].apply(tuple, axis=1).tolist()
    else:
        lookup_seqs, lookup_meta = read_fasta(args.lookup_fasta)
    print("Loaded data")

    lookup_database = load_database(lookup_embeddings)
    print("Loaded database")
    # D: similarity scores, I: lookup indices, each of shape (n_queries, k).
    D, I = query(lookup_database, query_embeddings, args.k)

    results = []
    for i, (indices, distances) in enumerate(zip(I, D)):
        for idx, distance in zip(indices, distances):
            result = {
                "query_seq": query_seqs[i],
                "query_meta": query_meta[i],
                "lookup_seq": lookup_seqs[idx],
                "D_score": distance,
            }
            if is_tsv:
                (
                    result["lookup_entry"],
                    result["lookup_pfam"],
                    result["lookup_protein_names"],
                ) = lookup_meta[idx]
            else:
                result["lookup_meta"] = lookup_meta[idx]
            results.append(result)
    results = pd.DataFrame(results)

    if args.save_inter:
        # Prefix the basename only, so a directory component in --output
        # (e.g. results/search.csv) still points at an existing directory.
        out_dir, out_name = os.path.split(args.output)
        results.to_csv(os.path.join(out_dir, "inter_" + out_name), index=False)

    # Filter results based off of conformal guarantees.
    if args.fdr and args.fnr:
        raise ValueError("Cannot control both FDR and FNR")
    if args.fdr:
        # `is not None` (not truthiness) so an explicit lambda of 0.0 is honored.
        if args.fdr_lambda is not None:
            lhat = args.fdr_lambda
        else:
            # TODO: compute lhat via conformal risk control (get_thresh_FDR with
            # y_cal/X_cal at the requested alpha/delta), or look it up in a
            # precomputed lambda-vs-alpha table so it need not be recomputed
            # on every run. Placeholder until then:
            lhat = 0.1
        results = results[results["D_score"] >= lhat]  # cosine similarity
    elif args.fnr:
        if args.fnr_lambda is None:
            # Original code fell through with lhat undefined (NameError);
            # fail loudly until an on-the-fly FNR calibration is implemented.
            raise ValueError("--fnr requires a precomputed --fnr_lambda")
        lhat = args.fnr_lambda
        results = results[results["D_score"] >= lhat]
    results.to_csv(args.output, index=False)
def parse_args(argv=None):
    """Parse command-line options for the conformal search script.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``
            (passing a list makes the parser testable/scriptable).

    Returns:
        argparse.Namespace with the parsed options.
    """

    def _str2bool(value):
        # argparse's `type=bool` treats any non-empty string (including
        # "False") as True; parse common false-y spellings explicitly.
        return value.strip().lower() not in ("false", "0", "no", "")

    parser = argparse.ArgumentParser(
        description="Process data with conformal guarantees"
    )
    parser.add_argument("--fnr", action='store_true', default=False, help="FNR risk control")
    parser.add_argument("--fdr", action='store_true', default=False, help="FPR risk control")
    parser.add_argument(
        "--fdr_lambda",
        type=float,
        default=0.999980225003127,
        help="FDR lambda hat value if precomputed",
    )
    parser.add_argument(
        "--fnr_lambda",
        type=float,
        # No precomputed default yet; None means "not provided".
        help="FNR lambda hat value if precomputed",
    )
    parser.add_argument(
        "--k", type=int, default=1000, help="maximal number of neighbors with FAISS"
    )
    parser.add_argument(
        "--save_inter", action='store_true', help="save intermediate results"
    )
    parser.add_argument(
        "--alpha", type=float, default=0.1, help="Alpha value for the algorithm"
    )
    parser.add_argument(
        "--num_trials", type=int, default=100, help="Number of trials to run"
    )
    parser.add_argument(
        "--n_calib", type=int, default=1000, help="Number of calibration data points"
    )
    parser.add_argument(
        "--delta", type=float, default=0.5, help="Delta value for the algorithm"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="results.csv",
        help="Output file for the results",
    )
    parser.add_argument(
        "--add_date", type=_str2bool, default=True, help="Add date to output file name"
    )
    parser.add_argument(
        "--query_embedding", type=str, default="", help="Query file with the embeddings"
    )
    parser.add_argument(
        "--query_fasta", type=str, default="", help="Input file for the query sequences and metadata"
    )  # TODO: add an option to grab more metadata than just from the fasta file
    parser.add_argument(
        "--lookup_embedding", type=str, default="", help="Lookup embeddings file"
    )
    parser.add_argument(
        "--lookup_fasta", type=str, default="", help="Input file for the lookup sequences and metadata"
    )
    return parser.parse_args(argv)
# Script entry point: parse CLI options, then run the search pipeline.
if __name__ == "__main__":
    main(parse_args())