File size: 2,423 Bytes
2c0063e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from massspecgym.utils import MyopicMCES
import numpy as np
import tqdm
from multiprocessing import Pool
from scipy.stats import bootstrap
import os
import pandas as pd

class Compute_Myopic_MCES:
    """Namespace for computing MyopicMCES distances over (target, candidate) pairs.

    The class is never instantiated by the code in this file; it groups a
    shared ``MyopicMCES`` callable with two static helpers so the worker
    function can be referenced by a picklable, module-level qualified name.
    """

    # Single MyopicMCES callable, built once when the class body is executed
    # and shared by every call to ``compute_mces``.
    mces_compute = MyopicMCES()

    @staticmethod
    def compute_mces(tar_cand):
        """Compute the distance for a single ``(target, candidate)`` pair.

        Args:
            tar_cand: Two-element iterable unpacked as ``(target, candidate)``.
                Both elements are forwarded unchanged to ``mces_compute``
                (presumably molecule representations such as SMILES strings;
                verify against ``MyopicMCES``).

        Returns:
            Tuple ``(tar_cand, dist)``: the input pair as given, together with
            whatever ``mces_compute`` returned for it.
        """
        target, cand = tar_cand
        dist = Compute_Myopic_MCES.mces_compute(target, cand)
        return (tar_cand, dist)

    @staticmethod
    def compute_mces_parallel(target_cand_list, n_processes=25):
        """Compute distances for many pairs using a pool of worker processes.

        Args:
            target_cand_list: Sized collection of ``(target, candidate)`` pairs;
                ``len()`` is taken to size the progress bar.
            n_processes: Number of worker processes in the pool.

        Returns:
            List of ``(tar_cand, dist)`` tuples, in the same order as
            ``target_cand_list`` (``Pool.imap`` yields results in input order).
        """
        with Pool(processes=n_processes) as pool:
            ordered_results = pool.imap(
                Compute_Myopic_MCES.compute_mces, target_cand_list
            )
            # Consuming the iterator through tqdm drives the progress bar as
            # results arrive.
            results = list(tqdm.tqdm(ordered_results, total=len(target_cand_list)))
        return results

class Compute_Myopic_MCES_timeout:
    """MyopicMCES distance helpers with a per-result wait timeout.

    Each pair is submitted as its own pool task so that a slow or failing pair
    yields a placeholder result instead of blocking or aborting the whole run.
    """

    # Single MyopicMCES callable, built once when the class body is executed
    # and used by this class's own ``compute_mces``.
    mces_compute = MyopicMCES()

    @staticmethod
    def compute_mces(tar_cand):
        """Compute the distance for a single ``(target, candidate)`` pair.

        Args:
            tar_cand: Two-element iterable unpacked as ``(target, candidate)``;
                both elements are forwarded unchanged to ``mces_compute``.

        Returns:
            Tuple ``(tar_cand, dist)``: the input pair and the value returned
            by ``mces_compute`` for it.
        """
        target, cand = tar_cand
        # Use this class's own callable rather than reaching into the sibling
        # class, so the class works on its own.
        dist = Compute_Myopic_MCES_timeout.mces_compute(target, cand)
        return (tar_cand, dist)

    @staticmethod
    def compute_mces_parallel(target_cand_list, n_processes=35, timeout=60):  # timeout in seconds
        """Compute distances for many pairs in parallel, tolerating slow pairs.

        Args:
            target_cand_list: Iterable of ``(target, candidate)`` pairs.
            n_processes: Number of worker processes in the pool.
            timeout: Maximum number of seconds to wait in each
                ``AsyncResult.get`` call.

        Returns:
            List with one entry per input pair, in input order. A successful
            entry is ``(tar_cand, dist)``. If waiting for the result timed out
            or the worker raised, the entry is ``(None, "Timeout or error")``.

        Note:
            The timeout bounds how long the parent waits for a result once it
            starts waiting on it; it does not interrupt the worker. A task that
            timed out keeps running until the pool is terminated when the
            ``with`` block exits.
        """
        results = []

        with Pool(processes=n_processes) as pool:
            # Submit everything up front so the workers stay busy while the
            # parent waits on results one by one.
            async_results = [
                pool.apply_async(
                    Compute_Myopic_MCES_timeout.compute_mces, args=(tar_cand,)
                )
                for tar_cand in target_cand_list
            ]
            # Size the progress bar from the submitted tasks so inputs without
            # ``len`` (e.g. generators) are accepted as well.
            for async_res in tqdm.tqdm(async_results, total=len(async_results)):
                try:
                    result = async_res.get(timeout=timeout)
                except Exception:
                    # Deliberate best effort: both multiprocessing.TimeoutError
                    # and any exception re-raised from the worker map to the
                    # same placeholder, so one bad pair cannot abort the batch.
                    result = (None, "Timeout or error")
                results.append(result)

        return results


# Pick the target out of a candidate list.
def get_target(candidates, labels):
    """Return the first candidate selected by ``labels``.

    ``candidates`` is converted to a NumPy array and indexed with ``labels``
    (presumably a boolean mask marking the true target; verify against the
    caller). The first selected element is returned, and an ``IndexError`` is
    raised if the selection is empty.
    """
    candidate_arr = np.array(candidates)
    selected = candidate_arr[labels]
    return selected[0]

# Pick the candidate ranked first by score.
def get_top_cand(candidates, scores):
    """Return the element of ``candidates`` at the position of the highest score.

    The position comes from ``np.argmax(scores)``, so when several scores tie
    for the maximum the earliest one is used.
    """
    best_idx = np.argmax(scores)
    return candidates[best_idx]

# Split a rank into per-cutoff hit indicators.
def convert_rank_to_hit_rates(row, rank_col, top_k=(1, 5, 20)):
    """Turn the rank stored in ``row`` into 0/1 hit flags, one per cutoff.

    Args:
        row: Mapping-like object indexed with ``rank_col`` (for example a dict
            or a DataFrame row).
        rank_col: Key under which the rank is stored in ``row``.
        top_k: Iterable of cutoffs ``k``. The default is an immutable tuple so
            it cannot be shared and mutated across calls.

    Returns:
        List with one entry per cutoff, in ``top_k`` order: ``1`` if
        ``rank <= k``, otherwise ``0``.
    """
    rank = row[rank_col]
    return [1 if rank <= k else 0 for k in top_k]


def get_ci(col_vals, confidence_level=0.999, n_resamples=20_000, seed=0):
    """Format a bootstrap confidence interval of the mean as ``"low-high"``.

    ``col_vals`` is passed to ``scipy.stats.bootstrap`` as the single sample,
    with ``np.mean`` as the statistic. ``confidence_level`` and ``n_resamples``
    are forwarded as given, and ``seed`` is passed as ``random_state``. Both
    interval bounds are rendered with two decimal places and joined by a
    hyphen.
    """
    result = bootstrap(
        (col_vals,),
        np.mean,
        confidence_level=confidence_level,
        n_resamples=n_resamples,
        random_state=seed,
    )
    interval = result.confidence_interval
    lower, upper = interval.low, interval.high
    return f'{lower:.2f}-{upper:.2f}'