'''
Search query protein embeddings against a lookup database and optionally filter
the hits with conformal FDR/FNR thresholds.

Example invocations:

# default parameters
python scripts/search.py --query_embedding data/inputs/queries_embeddings.npy --query_fasta data/inputs/rcsb_pdb_4CS4.fasta --lookup_embedding data/lookup/scope_lookup_embeddings.npy --lookup_fasta data/lookup/scope_lookup.fasta --fdr --output results/search_results.csv --k 100
# lower FDR lambda, saving intermediate (unfiltered) results
python scripts/search.py --query_embedding data/inputs/queries_embeddings.npy --query_fasta data/inputs/rcsb_pdb_4CS4.fasta --lookup_embedding data/lookup/scope_lookup_embeddings.npy --lookup_fasta data/lookup/scope_lookup.fasta --fdr --fdr_lambda 0.5 --output results/search_results.csv --k 100 --save_inter
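
# FNR risk control with a precomputed lambda (illustrative value; the --fnr and --fnr_lambda flags are defined in parse_args below)
python scripts/search.py --query_embedding data/inputs/queries_embeddings.npy --query_fasta data/inputs/rcsb_pdb_4CS4.fasta --lookup_embedding data/lookup/scope_lookup_embeddings.npy --lookup_fasta data/lookup/scope_lookup.fasta --fnr --fnr_lambda 0.95 --output results/search_results.csv --k 100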
'''
import numpy as np
import pandas as pd
import argparse
import os
from protein_conformal.util import *


def main(args):
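    """Search query embeddings against the lookup database and write the hits
    (optionally filtered with a conformal FDR/FNR threshold) to a CSV."""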
    query_embeddings = np.load(args.query_embedding, allow_pickle=True)
    lookup_embeddings = np.load(args.lookup_embedding, allow_pickle=True)
    query_fasta = read_fasta(args.query_fasta)
    if args.lookup_fasta.endswith(".tsv"):
        print("Loading lookup sequences and metadata from csv")
        lookup_df = pd.read_csv(args.lookup_fasta, sep="\t")
        # extract sequences in column "Sequence", and metadata in columns "Pfam" and "Protein names"
        lookup_seqs = lookup_df["Sequence"].values
        metadata_columns = ["Entry", "Pfam", "Protein names"]
        # Construct `lookup_meta` as a list of tuples for each row
        lookup_meta = lookup_df[metadata_columns].apply(tuple, axis=1).tolist()
    else:
        lookup_fasta = read_fasta(args.lookup_fasta)
        lookup_seqs, lookup_meta = lookup_fasta
    print("Loaded data")
    # Extract sequences and metadata
    query_seqs, query_meta = query_fasta

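    # Build the nearest-neighbor search index (FAISS) over the lookup embeddings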
    lookup_database = load_database(lookup_embeddings)
    print("Loaded database")
    k = args.k
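    # D holds the top-k similarity scores and I the corresponding lookup indices for each query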
    D, I = query(lookup_database, query_embeddings, k)

    # Create DataFrame to store results
    results = []
    for i, (indices, distances) in enumerate(zip(I, D)):
        for idx, distance in zip(indices, distances):
            # Base columns for every hit; lookup metadata columns are added below
            result = {
                "query_seq": query_seqs[i],
                "query_meta": query_meta[i],
                "lookup_seq": lookup_seqs[idx],
                "D_score": distance,
            }
            if args.lookup_fasta.endswith(".tsv"):
                result["lookup_entry"] = lookup_meta[idx][0]
                result["lookup_pfam"] = lookup_meta[idx][1]
                result["lookup_protein_names"] = lookup_meta[idx][2]
            else:
                result["lookup_meta"] = lookup_meta[idx]
            results.append(result)
    results = pd.DataFrame(results)
    if args.save_inter:
        # Save the unfiltered hits next to the final output, with an "inter_" filename prefix
        inter_path = os.path.join(os.path.dirname(args.output), "inter_" + os.path.basename(args.output))
        results.to_csv(inter_path, index=False)

    # Filter results using conformal risk-control thresholds
    if args.fdr and args.fnr:
        raise ValueError("Cannot control both FDR and FNR")
    if args.fdr:
        if args.fdr_lambda is not None:
            lhat = args.fdr_lambda
        else:
            # TODO: compute the FDR threshold as in the Pfam example, e.g.
            # lhat, fdr_cal = get_thresh_FDR(y_cal, X_cal, args.alpha, args.delta, N=100)
            # The threshold search (lambda.py) already exists but is slow: for each new alpha
            # the lambda has to be recomputed, so a table of precomputed lambdas (analogous to
            # the calibrated probabilities) is planned but not yet available; see where lambda
            # is calibrated against alpha in the conformal risk control code.
            lhat = 0.1
        # Keep hits whose cosine similarity is at least the threshold
        results = results[results["D_score"] >= lhat]
    elif args.fnr:
        if args.fnr_lambda is not None:
            lhat = args.fnr_lambda
        else:
            # TODO: compute the FNR threshold from calibration data
            raise ValueError("--fnr requires --fnr_lambda until calibration is implemented")
        results = results[results["D_score"] >= lhat]

    results.to_csv(args.output, index=False)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Search protein embeddings against a lookup database with conformal guarantees"
    )
    parser.add_argument("--fnr", action='store_true', default=False, help="FNR risk control")
    parser.add_argument("--fdr", action='store_true', default=False, help="FPR risk control")
    parser.add_argument(
        "--fdr_lambda",
        type=float,
        default=0.999980225003127,
        help="FDR lambda hat value if precomputed",
    )
    parser.add_argument(
        "--fnr_lambda",
        type=float,
        # default=0.999980225003127,
        help="FNR lambda hat value if precomputed",
    )
    parser.add_argument(
        "--k", type=int, default=1000, help="maximal number of neighbors with FAISS"
    )
    parser.add_argument(
        "--save_inter", action='store_true', help="save intermediate results"
    )
    parser.add_argument(
        "--alpha", type=float, default=0.1, help="Alpha value for the algorithm"
    )

    parser.add_argument(
        "--num_trials", type=int, default=100, help="Number of trials to run"
    )
    parser.add_argument(
        "--n_calib", type=int, default=1000, help="Number of calibration data points"
    )
    parser.add_argument(
        "--delta", type=float, default=0.5, help="Delta value for the algorithm"
    )

    parser.add_argument(
        "--output",
        type=str,
        default="results.csv",
        help="Output file for the results",
    )
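    # NOTE: argparse's type=bool treats any non-empty string as True, so "--add_date False" still yields True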
    parser.add_argument(
        "--add_date", type=bool, default=True, help="Add date to output file name"
    )
    parser.add_argument(
        "--query_embedding", type=str, default="", help="Query file with the embeddings"
    )
    parser.add_argument(
        "--query_fasta", type=str, default="", help="Input file for the query sequences and metadata"
    )  # TODO: add an option to grab more metadata than just from the fasta file
    parser.add_argument(
	"--lookup_embedding", type=str, default="", help="Lookup embeddings file"
    )
    parser.add_argument(
        "--lookup_fasta", type=str, default="", help="Input file for the lookup sequences and metadata"
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)