Vlmbd commited on
Commit
f1804f1
·
unverified ·
1 Parent(s): 0a329d8

split script

Browse files
Files changed (1) hide show
  1. match_with_train_list_commented.py +289 -0
match_with_train_list_commented.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Author: Elodie Laine
3
+ # Usage: python script.py
4
+ # Purpose: This script performs train-validation-evaluation splits based on sequence ID clustering
5
+ # for the petimot tool, ensuring non-redundant datasets with controlled sequence similarity.
6
+
7
+ import sys
8
+ import random
9
+
10
+
11
+ def getQueries(fname):
12
+ """
13
+ Parse a file containing protein pairs and create a dictionary mapping proteins to their queries.
14
+
15
+ Args:
16
+ fname (str): Path to input file with format "protein_query" on each line
17
+
18
+ Returns:
19
+ dict: Dictionary where keys are proteins and values are lists of queries
20
+ """
21
+ # Open and read the file
22
+ with open(fname, "r") as fIN:
23
+ lines = fIN.readlines()
24
+
25
+ # Create the dictionary
26
+ d = {}
27
+ for line in lines:
28
+ # Split each line by underscore to get protein and query
29
+ prots = line.strip().split("_")
30
+ if prots[0] not in d:
31
+ d[prots[0]] = []
32
+ d[prots[0]].append(prots[1])
33
+
34
+ return d
35
+
36
+
37
+ def getMatchHigherLevel(fnameDB, dEval, dTrainVal):
38
+ """
39
+ Match higher-level clusters with their corresponding lower-level collections.
40
+
41
+ Args:
42
+ fnameDB (str): Path to the cluster database file
43
+ dEval (dict): Dictionary of evaluation set proteins and their queries
44
+ dTrainVal (dict): Dictionary of training and validation sets proteins and their queries
45
+
46
+ Returns:
47
+ dict: Dictionary where keys are higher-level cluster representatives and values are tuples of
48
+ ([train-val proteins], [eval proteins]) that belong to that cluster
49
+ """
50
+ # Open and read the database file
51
+ with open(fnameDB, "r") as fIN:
52
+ lines = fIN.readlines()
53
+
54
+ d = {}
55
+ # Each line represents a cluster of PDB chains
56
+ for line in lines:
57
+ # Get individual chains in the cluster
58
+ prots = line.strip().split()
59
+
60
+ # Use the first protein as the cluster representative
61
+ d[prots[0]] = ([], []) # Initialize tuple with empty lists for train-val and eval proteins
62
+
63
+ # For each protein in the cluster
64
+ for p in prots:
65
+ # If it belongs to the train or validation set
66
+ if p in dTrainVal:
67
+ d[prots[0]][0].append(p)
68
+ # If it belongs to the evaluation set
69
+ if p in dEval:
70
+ d[prots[0]][1].append(p)
71
+
72
+ return d
73
+
74
+
75
+ def splitHigherLevel(dEns, random_seed=42):
76
+ """
77
+ Split higher-level clusters into training, validation, and evaluation sets.
78
+
79
+ Args:
80
+ dEns (dict): Dictionary mapping higher-level clusters to their proteins
81
+ random_seed (int): Random seed for reproducibility
82
+
83
+ Returns:
84
+ tuple: (dtrainval, lval, ltrain, deval_strict) where:
85
+ - dtrainval: dictionary of clusters with proteins in train/val sets but not in eval
86
+ - lval: list of clusters selected for validation
87
+ - ltrain: list of clusters selected for training
88
+ - deval_strict: dictionary of clusters with proteins only in eval set
89
+ """
90
+ random.seed(random_seed)
91
+
92
+ # Clusters that have at least one protein in any set
93
+ dall = {k: v[0] for k, v in dEns.items() if (len(v[1]) + len(v[0])) > 0}
94
+
95
+ # Clusters with proteins only in train/val sets (not in eval)
96
+ dtrainval = {k: v[0] for k, v in dEns.items() if len(v[1]) == 0 and len(v[0]) > 0}
97
+
98
+ # Clusters with proteins only in eval set (not in train/val)
99
+ deval_strict = {k: v[1] for k, v in dEns.items() if len(v[1]) > 0 and len(v[0]) == 0}
100
+
101
+ # Clusters with at least one protein in eval set (may overlap with train/val)
102
+ deval_relax = {k: v[1] for k, v in dEns.items() if len(v[1]) > 0}
103
+
104
+ # Calculate number of validation samples (10% of total clusters)
105
+ n_val = int((len(dtrainval) + len(deval_strict)) / 10)
106
+
107
+ # Randomly sample clusters for validation
108
+ ival = random.sample(range(len(dtrainval)), n_val)
109
+ myKeys = list(dtrainval.keys())
110
+
111
+ # Create lists of cluster IDs for validation and training
112
+ lval = [myKeys[i] for i in ival]
113
+ ltrain = [myKeys[i] for i in range(len(myKeys)) if i not in ival]
114
+
115
+ # Print statistics about the split
116
+ print("Number of higher-order collections with members in train, val or eval:", len(dall))
117
+ print("Number of higher-order collections with members in train or val and nothing from eval:", len(dtrainval))
118
+ print("Number of higher-order collections with members in eval:", len(deval_relax))
119
+ print("Number of higher-order collections with members in eval and nothing from train-val:", len(deval_strict))
120
+ print("Sample validation clusters:", lval[1:5])
121
+ print("Number of higher-order collections in the new validation set:", len(lval))
122
+ print("Sample training clusters:", ltrain[1:5])
123
+ print("Number of higher-order collections in the new training set:", len(ltrain))
124
+
125
+ return dtrainval, lval, ltrain, deval_strict
126
+
127
+
128
+ def reduceRedundancyInEval(dEns, dEval, random_seed=42):
129
+ """
130
+ Reduce redundancy in the evaluation set by selecting one protein per cluster.
131
+
132
+ Args:
133
+ dEns (dict): Dictionary mapping higher-level clusters to their proteins
134
+ dEval (dict): Dictionary of evaluation set proteins and their queries
135
+ random_seed (int): Random seed for reproducibility
136
+
137
+ Returns:
138
+ dict: Dictionary of selected evaluation proteins and their queries
139
+ """
140
+ random.seed(random_seed)
141
+
142
+ # Get clusters with proteins in eval set
143
+ deval_relax = {k: v[1] for k, v in dEns.items() if len(v[1]) > 0}
144
+
145
+ # Create new evaluation dictionary with reduced redundancy
146
+ deval = {}
147
+ for k in deval_relax:
148
+ # Randomly select one protein from each cluster
149
+ myColl = random.sample(deval_relax[k], 1)[0]
150
+
151
+ # Check if the selected protein has multiple queries (warning case)
152
+ if len(dEval[myColl]) > 1:
153
+ print("warning!!", myColl, dEval[myColl])
154
+
155
+ # Add only the first query for the selected protein
156
+ deval[myColl] = [dEval[myColl][0]]
157
+
158
+ return deval
159
+
160
+
161
+ def sampleLowerLevel(dtrainval, myL, myD, random_seed=42):
162
+ """
163
+ Sample queries for lower-level collections based on a stratified approach.
164
+
165
+ Args:
166
+ dtrainval (dict): Dictionary of training and validation proteins
167
+ myL (list): List of higher-level clusters to process
168
+ myD (dict): Dictionary mapping proteins to their queries
169
+ random_seed (int): Random seed for reproducibility
170
+
171
+ Returns:
172
+ dict: Dictionary of selected proteins and their queries
173
+ """
174
+ random.seed(random_seed)
175
+
176
+ dres = {}
177
+ # Sampling strategy: number of samples to take based on available proteins
178
+ # For 1 protein, take 5 queries; for 2 proteins, take 3 and 2 queries; etc.
179
+ n_samples = ([5], [3, 2], [2, 2, 1], [2, 1, 1, 1], [1, 1, 1, 1, 1])
180
+
181
+ for higherColl in myL:
182
+ # Get lower-level collections for this cluster
183
+ lowerColl = dtrainval[higherColl]
184
+ n = len(lowerColl)
185
+ p = min(n, 5) # Cap at 5 proteins per cluster
186
+
187
+ # If we have 5 or more proteins, randomly sample 5
188
+ if p == 5:
189
+ selectedColl = random.sample(lowerColl, p)
190
+ else:
191
+ # Otherwise, use all available proteins
192
+ selectedColl = lowerColl
193
+
194
+ # Get the sampling distribution for this number of proteins
195
+ nbs = n_samples[p-1]
196
+
197
+ # Sample queries for each selected protein
198
+ for i in range(p):
199
+ j = nbs[i] # Number of queries to sample for this protein
200
+ dres[selectedColl[i]] = myD[selectedColl[i]][:j]
201
+
202
+ return dres
203
+
204
+
205
+ def writeDico(dEns):
206
+ """
207
+ Write cluster matching information to a CSV file.
208
+
209
+ Args:
210
+ dEns (dict): Dictionary of clusters and their proteins
211
+ """
212
+ with open("match_eval.csv", "w") as fOUT:
213
+ for k in dEns:
214
+ n = len(dEns[k])
215
+ if n > 0:
216
+ fOUT.write(k + "," + str(n) + "," + "-".join(dEns[k]) + "\n")
217
+
218
+
219
+ def write_queries(d, fname):
220
+ """
221
+ Write protein-query pairs to a file.
222
+
223
+ Args:
224
+ d (dict): Dictionary mapping proteins to their queries
225
+ fname (str): Output file name
226
+ """
227
+ with open(fname, "w") as fOUT:
228
+ for k in d:
229
+ for q in d[k]:
230
+ fOUT.write(k + "_" + q + "\n")
231
+
232
+
233
+ def write_values(d, fname):
234
+ """
235
+ Write only query values to a file.
236
+
237
+ Args:
238
+ d (dict): Dictionary mapping proteins to their queries
239
+ fname (str): Output file name
240
+ """
241
+ with open(fname, "w") as fOUT:
242
+ for k in d:
243
+ for q in d[k]:
244
+ fOUT.write(q + "\n")
245
+
246
+
247
+ if __name__ == "__main__":
248
+ # Load evaluation set data
249
+ evalF = "eval_list.txt"
250
+ dEval = getQueries(evalF)
251
+
252
+ # Load training and validation sets data
253
+ trvalF = "train_val_list.txt"
254
+ dTrainVal = getQueries(trvalF)
255
+
256
+ # Load full training set data
257
+ trF = "full_train_list.txt"
258
+ dTrain = getQueries(trF)
259
+
260
+ # Match the collections using 30% sequence identity, 80% coverage threshold
261
+ dEns = getMatchHigherLevel("rewrited_clusterDB_30_80.tsv", dEval, dTrainVal)
262
+
263
+ # Split into training, validation and evaluation sets
264
+ dtrainval, lval, ltrain, deval_strict = splitHigherLevel(dEns)
265
+
266
+ # Sample validation queries and write to file
267
+ dval = sampleLowerLevel(dtrainval, lval, dTrainVal)
268
+ write_queries(dval, "val_nr_list_12_05.txt")
269
+
270
+ # Sample training queries and write to file
271
+ dtrain = sampleLowerLevel(dtrainval, ltrain, dTrainVal)
272
+ write_queries(dtrain, "train_nr_list_12_05.txt")
273
+
274
+ # Reduce redundancy in evaluation set and write to file
275
+ deval = reduceRedundancyInEval(dEns, dEval)
276
+ write_queries(deval, "eval_nr_list_12_05.txt")
277
+
278
+ # Print number of strict evaluation clusters
279
+ print("Number of strict evaluation clusters:", len(deval_strict))
280
+
281
+ # Generate strict evaluation set based on full training
282
+ dEns2 = getMatchHigherLevel("rewrited_clusterDB_30_80.tsv", deval, dTrain)
283
+ dtrainval, lval, ltrain, deval_strict = splitHigherLevel(dEns2)
284
+ write_values(deval_strict, "eval_strict_list_12_05.txt")
285
+
286
+ # Generate even stricter evaluation set based on train+val
287
+ dEns3 = getMatchHigherLevel("rewrited_clusterDB_30_80.tsv", deval, dTrainVal)
288
+ dtrainval, lval, ltrain, deval_strict = splitHigherLevel(dEns3)
289
+ write_values(deval_strict, "eval_even_stricter_list_12_05.txt")