Vlmbd committed on
split script
Browse files
match_with_train_list_commented.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Author: Elodie Laine
|
| 3 |
+
# Usage: python match_with_train_list_commented.py
|
| 4 |
+
# Purpose: This script performs train-validation-evaluation splits based on sequence identity clustering
|
| 5 |
+
# for the petimot tool, ensuring non-redundant datasets with controlled sequence similarity.
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def getQueries(fname):
    """
    Parse a file of "protein_query" pairs into a protein -> queries mapping.

    Each non-blank line is stripped and split on underscores: the first field
    is the protein and the second field is the query. Any further
    underscore-separated fields on a line are ignored. Blank lines are skipped.

    Args:
        fname (str): Path to input file with format "protein_query" on each line

    Returns:
        dict: Dictionary where keys are proteins and values are lists of queries,
            in file order (repeated queries are kept)

    Raises:
        IndexError: If a non-blank line contains no underscore.
    """
    d = {}
    with open(fname, "r") as fIN:
        # Stream the file instead of materialising every line first
        for line in fIN:
            entry = line.strip()
            # A blank line carries no pair; skip it rather than failing on the
            # missing second field.
            if not entry:
                continue
            prots = entry.split("_")
            d.setdefault(prots[0], []).append(prots[1])

    return d
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def getMatchHigherLevel(fnameDB, dEval, dTrainVal):
    """
    Match higher-level clusters with their corresponding lower-level collections.

    Every non-blank line of the database is one cluster given as
    whitespace-separated identifiers; the first identifier is used as the
    cluster representative (the dictionary key). Blank lines are skipped. If a
    representative occurs on several lines, the last line wins.

    Args:
        fnameDB (str): Path to the cluster database file
        dEval (dict): Dictionary of evaluation set proteins and their queries
        dTrainVal (dict): Dictionary of training and validation sets proteins and their queries

    Returns:
        dict: Dictionary where keys are higher-level cluster representatives and values are tuples of
              ([train-val proteins], [eval proteins]) that belong to that cluster
    """
    d = {}
    with open(fnameDB, "r") as fIN:
        for line in fIN:
            # Individual chains in the cluster
            prots = line.split()
            # A blank line has no representative; skip it instead of failing
            if not prots:
                continue

            # (train-val members, eval members) for this cluster
            members = ([], [])
            d[prots[0]] = members

            for p in prots:
                # A chain may belong to both sets, so the two checks are independent
                if p in dTrainVal:
                    members[0].append(p)
                if p in dEval:
                    members[1].append(p)

    return d
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def splitHigherLevel(dEns, random_seed=42):
    """
    Split higher-level clusters into training, validation, and evaluation sets.

    The validation size is 10% (rounded down) of the number of clusters that
    are either train/val-only or eval-only. Validation clusters are drawn only
    from the train/val-only clusters, so the size is capped at the number of
    such clusters. Summary statistics are printed to stdout.

    Args:
        dEns (dict): Dictionary mapping higher-level clusters to a tuple
            ([train-val proteins], [eval proteins])
        random_seed (int): Random seed for reproducibility (seeds the global
            `random` module generator)

    Returns:
        tuple: (dtrainval, lval, ltrain, deval_strict) where:
            - dtrainval: dictionary of clusters with proteins in train/val sets but not in eval
            - lval: list of clusters selected for validation
            - ltrain: list of clusters selected for training
            - deval_strict: dictionary of clusters with proteins only in eval set
    """
    random.seed(random_seed)

    # Clusters that have at least one protein in any set
    dall = {k: v[0] for k, v in dEns.items() if v[0] or v[1]}

    # Clusters with proteins only in train/val sets (not in eval)
    dtrainval = {k: v[0] for k, v in dEns.items() if v[0] and not v[1]}

    # Clusters with proteins only in eval set (not in train/val)
    deval_strict = {k: v[1] for k, v in dEns.items() if v[1] and not v[0]}

    # Clusters with at least one protein in eval set (may overlap with train/val)
    deval_relax = {k: v[1] for k, v in dEns.items() if v[1]}

    # 10% of the non-overlapping clusters, but never more than the clusters we
    # can actually draw from: random.sample raises ValueError when asked for
    # more items than the population holds.
    n_val = min((len(dtrainval) + len(deval_strict)) // 10, len(dtrainval))

    # Randomly sample cluster positions for validation
    ival = random.sample(range(len(dtrainval)), n_val)
    myKeys = list(dtrainval)

    # Validation keeps the sampling order; training keeps the dictionary order
    lval = [myKeys[i] for i in ival]
    val_positions = set(ival)  # O(1) membership instead of scanning a list
    ltrain = [k for i, k in enumerate(myKeys) if i not in val_positions]

    # Print statistics about the split
    print("Number of higher-order collections with members in train, val or eval:", len(dall))
    print("Number of higher-order collections with members in train or val and nothing from eval:", len(dtrainval))
    print("Number of higher-order collections with members in eval:", len(deval_relax))
    print("Number of higher-order collections with members in eval and nothing from train-val:", len(deval_strict))
    print("Sample validation clusters:", lval[1:5])
    print("Number of higher-order collections in the new validation set:", len(lval))
    print("Sample training clusters:", ltrain[1:5])
    print("Number of higher-order collections in the new training set:", len(ltrain))

    return dtrainval, lval, ltrain, deval_strict
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def reduceRedundancyInEval(dEns, dEval, random_seed=42):
    """
    Keep a single, randomly chosen evaluation protein per cluster.

    Clusters without evaluation members are ignored. For the chosen protein
    only its first query is kept; a warning line is printed when the protein
    had more than one query.

    Args:
        dEns (dict): Dictionary mapping higher-level clusters to a tuple
            ([train-val proteins], [eval proteins])
        dEval (dict): Dictionary of evaluation set proteins and their queries
        random_seed (int): Random seed for reproducibility

    Returns:
        dict: Dictionary mapping each selected evaluation protein to a
            one-element list holding its first query
    """
    random.seed(random_seed)

    reduced = {}
    for members in dEns.values():
        eval_members = members[1]
        # Nothing to pick from in clusters that have no evaluation protein
        if len(eval_members) == 0:
            continue

        # One representative per cluster
        chosen = random.sample(eval_members, 1)[0]
        queries = dEval[chosen]

        # Flag proteins whose extra queries are being dropped
        if len(queries) > 1:
            print("warning!!", chosen, queries)

        reduced[chosen] = [queries[0]]

    return reduced
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def sampleLowerLevel(dtrainval, myL, myD, random_seed=42):
    """
    Select proteins and queries for the given higher-level clusters.

    At most five proteins are used per cluster (a random subset of five when
    the cluster has five or more, otherwise all of them in their stored
    order). Each selected protein contributes its first few queries, following
    a fixed quota that depends on how many proteins were selected.

    Args:
        dtrainval (dict): Dictionary mapping higher-level clusters to their proteins
        myL (list): List of higher-level clusters to process
        myD (dict): Dictionary mapping proteins to their queries
        random_seed (int): Random seed for reproducibility

    Returns:
        dict: Dictionary of selected proteins and their retained queries
    """
    random.seed(random_seed)

    max_members = 5
    # quotas[k - 1] lists how many queries each of k selected proteins keeps:
    # one protein keeps up to 5, two keep 3 and 2, ..., five keep 1 each.
    quotas = ([5], [3, 2], [2, 2, 1], [2, 1, 1, 1], [1, 1, 1, 1, 1])

    picked = {}
    for cluster in myL:
        members = dtrainval[cluster]
        count = min(len(members), max_members)

        # Only clusters that reach the cap are randomly subsampled
        if count == max_members:
            chosen = random.sample(members, count)
        else:
            chosen = members

        # Pair every selected protein with its quota and keep that many
        # leading queries.
        for member, quota in zip(chosen, quotas[count - 1]):
            picked[member] = myD[member][:quota]

    return picked
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def writeDico(dEns):
    """
    Dump non-empty clusters to "match_eval.csv" in the current directory.

    One line is written per cluster: the cluster key, the number of members,
    and the members joined with "-", separated by commas.

    Args:
        dEns (dict): Dictionary mapping each cluster to a sequence of protein
            names (strings)
    """
    with open("match_eval.csv", "w") as fOUT:
        for cluster, members in dEns.items():
            size = len(members)
            # Empty clusters are left out of the report
            if size == 0:
                continue
            fields = (cluster, str(size), "-".join(members))
            fOUT.write(",".join(fields) + "\n")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def write_queries(d, fname):
    """
    Write one "protein_query" line per (protein, query) pair.

    Args:
        d (dict): Dictionary mapping proteins to their queries
        fname (str): Output file name (overwritten if it exists)
    """
    with open(fname, "w") as fOUT:
        for protein, queries in d.items():
            # Same "protein_query" layout that the input lists use
            fOUT.writelines(protein + "_" + query + "\n" for query in queries)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def write_values(d, fname):
    """
    Write every value stored in the dictionary, one per line, without its key.

    Args:
        d (dict): Dictionary mapping keys to lists of strings
        fname (str): Output file name (overwritten if it exists)
    """
    with open(fname, "w") as fOUT:
        for values in d.values():
            fOUT.writelines(value + "\n" for value in values)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
    # Cluster database built with a 30% sequence identity, 80% coverage threshold
    clusterDB = "rewrited_clusterDB_30_80.tsv"

    # Input lists: evaluation set, train+validation set, full training set
    dEval = getQueries("eval_list.txt")
    dTrainVal = getQueries("train_val_list.txt")
    dTrain = getQueries("full_train_list.txt")

    # Assign every eval / train-val protein to its cluster, then split the
    # clusters into training, validation and evaluation groups
    dEns = getMatchHigherLevel(clusterDB, dEval, dTrainVal)
    dtrainval, lval, ltrain, deval_strict = splitHigherLevel(dEns)

    # Non-redundant validation list
    dval = sampleLowerLevel(dtrainval, lval, dTrainVal)
    write_queries(dval, "val_nr_list_12_05.txt")

    # Non-redundant training list
    dtrain = sampleLowerLevel(dtrainval, ltrain, dTrainVal)
    write_queries(dtrain, "train_nr_list_12_05.txt")

    # Non-redundant evaluation list (one protein per cluster)
    deval = reduceRedundancyInEval(dEns, dEval)
    write_queries(deval, "eval_nr_list_12_05.txt")

    print("Number of strict evaluation clusters:", len(deval_strict))

    # Strict evaluation list: reduced eval proteins whose cluster has no
    # member in the full training list
    strict_vs_train = splitHigherLevel(getMatchHigherLevel(clusterDB, deval, dTrain))[3]
    write_values(strict_vs_train, "eval_strict_list_12_05.txt")

    # Even stricter list: the same filter applied against train+val
    strict_vs_trainval = splitHigherLevel(getMatchHigherLevel(clusterDB, deval, dTrainVal))[3]
    write_values(strict_vs_trainval, "eval_even_stricter_list_12_05.txt")
|