# FLARE / flare/utils/eval.py
# Page-header residue: uploaded by yzhouchen001; commit 2c0063e ("cleaned up").
from massspecgym.utils import MyopicMCES
import numpy as np
import tqdm
from multiprocessing import Pool
from scipy.stats import bootstrap
import os
import pandas as pd
class Compute_Myopic_MCES:
    """Namespace for computing MyopicMCES values, one pair at a time or in a pool.

    The class is never expected to hold per-instance state: both methods are
    static and use the single ``MyopicMCES`` callable stored on the class.
    """

    # One shared MyopicMCES callable, created when the class body is executed.
    mces_compute = MyopicMCES()

    @staticmethod
    def compute_mces(tar_cand):
        """Evaluate the shared ``MyopicMCES`` callable on one (target, candidate) pair.

        Args:
            tar_cand: A two-element sequence ``(target, cand)``; both elements
                are passed straight to ``MyopicMCES`` (presumably molecule
                representations such as SMILES -- confirm against callers).

        Returns:
            ``(tar_cand, dist)`` -- the input pair, unchanged, together with
            the value returned by ``MyopicMCES`` so results can be matched
            back to their inputs.
        """
        target, cand = tar_cand
        dist = Compute_Myopic_MCES.mces_compute(target, cand)
        return (tar_cand, dist)

    @staticmethod
    def compute_mces_parallel(target_cand_list, n_processes=25):
        """Run :meth:`compute_mces` over many pairs in a process pool.

        Args:
            target_cand_list: Sized sequence of ``(target, cand)`` pairs
                (``len()`` is used for the progress bar total).
            n_processes: Number of worker processes in the pool.

        Returns:
            A list of ``(tar_cand, dist)`` tuples in the same order as
            ``target_cand_list`` (``Pool.imap`` preserves input order).
        """
        with Pool(processes=n_processes) as pool:
            # imap yields results lazily, which lets tqdm show live progress;
            # list() drains it before the pool is torn down.
            progress = tqdm.tqdm(
                pool.imap(Compute_Myopic_MCES.compute_mces, target_cand_list),
                total=len(target_cand_list),
            )
            results = list(progress)
        return results
class Compute_Myopic_MCES_timeout:
    """Pool-based MyopicMCES computation with a per-result wait limit.

    Unlike a plain ``imap``, every pair is submitted with ``apply_async`` and
    each result is awaited for at most ``timeout`` seconds, so one slow pair
    cannot block collection of the remaining results indefinitely.
    """

    # One shared MyopicMCES callable, created when the class body is executed.
    mces_compute = MyopicMCES()

    @staticmethod
    def compute_mces(tar_cand):
        """Evaluate the shared ``MyopicMCES`` callable on one (target, candidate) pair.

        Args:
            tar_cand: A two-element sequence ``(target, cand)``; both elements
                are passed straight to ``MyopicMCES`` (presumably molecule
                representations such as SMILES -- confirm against callers).

        Returns:
            ``(tar_cand, dist)`` -- the input pair together with the value
            returned by ``MyopicMCES``.
        """
        target, cand = tar_cand
        # Use this class's own callable rather than the sibling class's one.
        dist = Compute_Myopic_MCES_timeout.mces_compute(target, cand)
        return (tar_cand, dist)

    @staticmethod
    def compute_mces_parallel(target_cand_list, n_processes=35, timeout=60):
        """Compute all pairs in a process pool, giving up on slow or failing ones.

        Args:
            target_cand_list: Sized sequence of ``(target, cand)`` pairs.
            n_processes: Number of worker processes in the pool.
            timeout: Maximum number of seconds to wait in each
                ``AsyncResult.get`` call. The wait is measured from the moment
                the result is requested, not from when the task started, and
                results are requested in submission order.

        Returns:
            A list with one entry per input pair, in input order. Successful
            entries are ``(tar_cand, dist)``; an entry whose result timed out
            or whose worker raised is the sentinel
            ``(None, "Timeout or error")``.
        """
        results = []
        with Pool(processes=n_processes) as pool:
            # Submit everything up front so workers stay busy while earlier
            # results are being awaited.
            async_results = [
                pool.apply_async(
                    Compute_Myopic_MCES_timeout.compute_mces, args=(tar_cand,)
                )
                for tar_cand in target_cand_list
            ]
            for async_res in tqdm.tqdm(async_results, total=len(target_cand_list)):
                try:
                    result = async_res.get(timeout=timeout)
                except Exception:
                    # Deliberate best-effort: a timeout (multiprocessing.TimeoutError)
                    # or any exception re-raised from the worker is replaced by a
                    # sentinel so one bad pair does not abort the whole batch.
                    result = (None, "Timeout or error")
                results.append(result)
        # Leaving the ``with`` block terminates the pool, which also stops any
        # worker still busy with a pair that timed out.
        return results
def get_target(candidates, labels):
    """Return the first candidate selected by ``labels``.

    ``candidates`` is converted to a NumPy array and indexed with ``labels``
    (presumably a boolean mask marking the true molecule -- confirm against
    callers); the first selected element is returned.

    Raises:
        IndexError: If ``labels`` selects no element.
    """
    candidate_arr = np.array(candidates)
    selected = candidate_arr[labels]
    return selected[0]
def get_top_cand(candidates, scores):
    """Return the rank-1 candidate, i.e. the one with the highest score.

    ``np.argmax`` returns the first position of the maximum, so on ties the
    earliest candidate wins.
    """
    best_position = np.argmax(scores)
    return candidates[best_position]
def convert_rank_to_hit_rates(row, rank_col, top_k=(1, 5, 20)):
    """Turn a rank into 0/1 hit indicators, one per cutoff in ``top_k``.

    Args:
        row: Mapping-like object (e.g. a DataFrame row or dict) indexable by
            ``rank_col``.
        rank_col: Key under which the rank is stored in ``row``.
        top_k: Iterable of cutoffs. The default is an immutable tuple so it
            cannot be mutated across calls; any sequence may be passed.

    Returns:
        A list with one entry per cutoff ``k``, in the order given: ``1`` if
        ``rank <= k``, otherwise ``0``.
    """
    rank = row[rank_col]
    return [1 if rank <= k else 0 for k in top_k]
def get_ci(col_vals, confidence_level=0.999, n_resamples=20_000, seed=0):
    """Format a bootstrap confidence interval for the mean of ``col_vals``.

    Args:
        col_vals: Sample of values whose mean is bootstrapped.
        confidence_level: Confidence level passed to ``scipy.stats.bootstrap``.
        n_resamples: Number of bootstrap resamples.
        seed: Value passed as ``random_state`` so the interval is reproducible.

    Returns:
        A string ``"<low>-<high>"`` with both bounds rounded to two decimals.
    """
    interval = bootstrap(
        (col_vals,),
        np.mean,
        confidence_level=confidence_level,
        n_resamples=n_resamples,
        random_state=seed,
    ).confidence_interval
    lower, upper = interval.low, interval.high
    return f'{lower:.2f}-{upper:.2f}'