| |
| |
| |
| |
| |
|
|
| import sys |
| import random |
|
|
|
|
def getQueries(fname):
    """
    Parse a file of protein/query pairs into a protein -> queries mapping.

    Each non-blank line is stripped and split on "_". The first field is
    used as the protein and the second field as the query; any further
    "_"-separated fields on the line are ignored. Blank (or whitespace-only)
    lines are skipped instead of crashing the parse.

    Args:
        fname (str): Path to input file with format "protein_query" on each line

    Returns:
        dict: Dictionary where keys are proteins and values are lists of
            queries, in the order they appear in the file

    Raises:
        IndexError: If a non-blank line contains no "_" separator.
    """
    d = {}

    # Stream the file line by line rather than loading it all with readlines().
    with open(fname, "r") as fIN:
        for line in fIN:
            stripped = line.strip()

            # An empty line would split to [""] and fail on the query lookup.
            if not stripped:
                continue

            prots = stripped.split("_")
            d.setdefault(prots[0], []).append(prots[1])

    return d
|
|
|
|
def getMatchHigherLevel(fnameDB, dEval, dTrainVal):
    """
    Match higher-level clusters with their corresponding lower-level collections.

    Each non-blank line of the database file is split on whitespace. The
    first token is used as the cluster key, and every token on the line
    (the first one included) is tested for membership in dTrainVal and dEval.
    If the same first token appears on several lines, the last line wins.
    Blank lines are skipped instead of raising IndexError.

    Args:
        fnameDB (str): Path to the cluster database file
        dEval (dict): Dictionary of evaluation set proteins and their queries
        dTrainVal (dict): Dictionary of training and validation sets proteins and their queries

    Returns:
        dict: Dictionary where keys are higher-level cluster representatives and values are tuples of
            ([train-val proteins], [eval proteins]) that belong to that cluster
    """
    d = {}

    with open(fnameDB, "r") as fIN:
        for line in fIN:
            prots = line.split()

            # A blank line has no representative token; nothing to record.
            if not prots:
                continue

            in_trainval = [p for p in prots if p in dTrainVal]
            in_eval = [p for p in prots if p in dEval]
            d[prots[0]] = (in_trainval, in_eval)

    return d
|
|
|
|
def splitHigherLevel(dEns, random_seed=42):
    """
    Split higher-level clusters into training, validation, and evaluation sets.

    Clusters are classified by which of their two member lists are non-empty
    (index 0: train/val proteins, index 1: eval proteins). About one tenth of
    the clusters counted as (train/val-only + eval-only) are drawn at random
    from the train/val-only clusters to form the validation set; the remaining
    train/val-only clusters form the training set. Summary statistics are
    printed to stdout.

    Args:
        dEns (dict): Dictionary mapping higher-level clusters to a pair
            ([train-val proteins], [eval proteins])
        random_seed (int): Random seed for reproducibility (seeds the global
            `random` module state)

    Returns:
        tuple: (dtrainval, lval, ltrain, deval_strict) where:
            - dtrainval: dictionary of clusters with proteins in train/val sets but not in eval
            - lval: list of clusters selected for validation
            - ltrain: list of clusters selected for training
            - deval_strict: dictionary of clusters with proteins only in eval set
    """
    random.seed(random_seed)

    # Clusters with at least one member in any set (only the count is reported).
    dall = {k: v[0] for k, v in dEns.items() if (len(v[1]) + len(v[0])) > 0}

    # Clusters with train/val members and no eval member.
    dtrainval = {k: v[0] for k, v in dEns.items() if len(v[1]) == 0 and len(v[0]) > 0}

    # Clusters with eval members and no train/val member.
    deval_strict = {k: v[1] for k, v in dEns.items() if len(v[1]) > 0 and len(v[0]) == 0}

    # Clusters with at least one eval member (only the count is reported).
    deval_relax = {k: v[1] for k, v in dEns.items() if len(v[1]) > 0}

    # Target size is 10% of (train/val-only + eval-only) clusters, but the
    # draw is made from the train/val-only clusters alone, so clamp to the
    # population size: random.sample raises ValueError on an oversized request.
    n_val = min((len(dtrainval) + len(deval_strict)) // 10, len(dtrainval))

    ival = random.sample(range(len(dtrainval)), n_val)
    myKeys = list(dtrainval.keys())

    # lval keeps the sampling order; ltrain keeps the dictionary order.
    lval = [myKeys[i] for i in ival]
    ival_set = set(ival)  # O(1) membership instead of scanning the list
    ltrain = [k for i, k in enumerate(myKeys) if i not in ival_set]

    print("Number of higher-order collections with members in train, val or eval:", len(dall))
    print("Number of higher-order collections with members in train or val and nothing from eval:", len(dtrainval))
    print("Number of higher-order collections with members in eval:", len(deval_relax))
    print("Number of higher-order collections with members in eval and nothing from train-val:", len(deval_strict))
    print("Sample validation clusters:", lval[1:5])
    print("Number of higher-order collections in the new validation set:", len(lval))
    print("Sample training clusters:", ltrain[1:5])
    print("Number of higher-order collections in the new training set:", len(ltrain))

    return dtrainval, lval, ltrain, deval_strict
|
|
|
|
def reduceRedundancyInEval(dEns, dEval, random_seed=42):
    """
    Keep a single, randomly chosen evaluation protein per cluster.

    For every cluster whose eval member list (index 1 of its value) is
    non-empty, one member is drawn at random and only its first query is
    kept. A warning is printed when the drawn protein has more than one query.

    Args:
        dEns (dict): Dictionary mapping higher-level clusters to a pair
            ([train-val proteins], [eval proteins])
        dEval (dict): Dictionary of evaluation set proteins and their queries
        random_seed (int): Random seed for reproducibility (seeds the global
            `random` module state)

    Returns:
        dict: Dictionary mapping each selected evaluation protein to a
            one-element list holding its first query
    """
    random.seed(random_seed)

    selected = {}
    for members in dEns.values():
        eval_members = members[1]

        # Clusters without any eval protein contribute nothing.
        if len(eval_members) == 0:
            continue

        # Draw exactly one representative for this cluster.
        (chosen,) = random.sample(eval_members, 1)
        queries = dEval[chosen]

        # Extra queries for the chosen protein are dropped; report them.
        if len(queries) > 1:
            print("warning!!", chosen, queries)

        selected[chosen] = [queries[0]]

    return selected
|
|
|
|
def sampleLowerLevel(dtrainval, myL, myD, random_seed=42):
    """
    Pick up to five queries per higher-level cluster, spread over its members.

    A cluster with five or more lower-level members has five of them drawn at
    random, one query each. A smaller cluster keeps all its members in their
    stored order, and the five-query budget is split among them according to
    a fixed table (e.g. two members -> 3 and 2 queries). A member's queries
    are taken from the front of its list in myD, so fewer may be returned
    than the quota if the list is short.

    Args:
        dtrainval (dict): Dictionary mapping higher-level clusters to their
            lower-level member proteins
        myL (list): List of higher-level clusters to process
        myD (dict): Dictionary mapping proteins to their queries
        random_seed (int): Random seed for reproducibility (seeds the global
            `random` module state)

    Returns:
        dict: Dictionary of selected proteins and their selected queries
    """
    random.seed(random_seed)

    # quota_table[k - 1] lists how many queries each of k members receives.
    quota_table = ([5], [3, 2], [2, 2, 1], [2, 1, 1, 1], [1, 1, 1, 1, 1])

    picked = {}
    for cluster in myL:
        members = dtrainval[cluster]
        size = min(len(members), 5)

        # Only clusters that reach the cap are subsampled (and thereby shuffled).
        chosen = random.sample(members, size) if size == 5 else members

        for member, quota in zip(chosen, quota_table[size - 1]):
            picked[member] = myD[member][:quota]

    return picked
|
|
|
|
def writeDico(dEns, fname="match_eval.csv"):
    """
    Write cluster matching information to a CSV file.

    One line is written per cluster with a non-empty value, formatted as
    "<cluster>,<number of entries>,<entries joined by '-'>". Clusters with an
    empty value are omitted. The output file is overwritten if it exists.

    Args:
        dEns (dict): Dictionary mapping cluster names (str) to sequences of
            strings
        fname (str): Output file name; defaults to "match_eval.csv", the path
            that used to be hard-coded
    """
    with open(fname, "w") as fOUT:
        for k, members in dEns.items():
            n = len(members)
            if n > 0:
                fOUT.write(",".join((k, str(n), "-".join(members))) + "\n")
|
|
|
|
def write_queries(d, fname):
    """
    Dump a protein -> queries mapping as "protein_query" lines.

    One line is written per (protein, query) pair, in dictionary order and
    then in list order. The output file is overwritten if it exists.

    Args:
        d (dict): Dictionary mapping proteins to their queries
        fname (str): Output file name
    """
    with open(fname, "w") as fOUT:
        for protein, queries in d.items():
            fOUT.writelines(protein + "_" + query + "\n" for query in queries)
|
|
|
|
def write_values(d, fname):
    """
    Dump only the values of a mapping, one entry per line.

    Keys are ignored; every element of every value list is written on its
    own line, in dictionary order and then in list order. The output file is
    overwritten if it exists.

    Args:
        d (dict): Dictionary mapping proteins to their queries
        fname (str): Output file name
    """
    with open(fname, "w") as fOUT:
        for queries in d.values():
            fOUT.writelines(query + "\n" for query in queries)
|
|
|
|
if __name__ == "__main__":
    # Cluster database shared by every matching step below.
    clusterDB = "rewrited_clusterDB_30_80.tsv"

    # Load the protein -> queries tables for the three input sets.
    dEval = getQueries("eval_list.txt")
    dTrainVal = getQueries("train_val_list.txt")
    dTrain = getQueries("full_train_list.txt")

    # Attach train/val and eval proteins to their higher-level clusters,
    # then split the train/val-only clusters into training and validation.
    dEns = getMatchHigherLevel(clusterDB, dEval, dTrainVal)
    dtrainval, lval, ltrain, deval_strict = splitHigherLevel(dEns)

    # Non-redundant validation set.
    dval = sampleLowerLevel(dtrainval, lval, dTrainVal)
    write_queries(dval, "val_nr_list_12_05.txt")

    # Non-redundant training set.
    dtrain = sampleLowerLevel(dtrainval, ltrain, dTrainVal)
    write_queries(dtrain, "train_nr_list_12_05.txt")

    # Non-redundant evaluation set: one protein per cluster.
    deval = reduceRedundancyInEval(dEns, dEval)
    write_queries(deval, "eval_nr_list_12_05.txt")

    print("Number of strict evaluation clusters:", len(deval_strict))

    # Eval clusters sharing no member with the full training list.
    dEns2 = getMatchHigherLevel(clusterDB, deval, dTrain)
    deval_strict = splitHigherLevel(dEns2)[3]
    write_values(deval_strict, "eval_strict_list_12_05.txt")

    # Eval clusters sharing no member with the train+val list.
    dEns3 = getMatchHigherLevel(clusterDB, deval, dTrainVal)
    deval_strict = splitHigherLevel(dEns3)[3]
    write_values(deval_strict, "eval_even_stricter_list_12_05.txt")
|
|