File size: 10,590 Bytes
1c02b99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
#!/usr/bin/env python3
# Author: Elodie Laine
# Usage: python script.py
# Purpose: This script performs train-validation-evaluation splits based on sequence ID clustering 
#          for the petimot tool, ensuring non-redundant datasets with controlled sequence similarity.

import sys
import random


def getQueries(fname):
    """
    Parse a file of "protein_query" identifiers into a protein -> queries mapping.

    Args:
        fname (str): Path to input file with one "protein_query" identifier per line.

    Returns:
        dict: Keys are protein IDs, values are lists of query IDs in file order.
    """
    d = {}
    with open(fname, "r") as fIN:
        # Iterate lazily instead of loading the whole file into memory
        for line in fIN:
            line = line.strip()
            # Skip blank lines so a trailing newline does not crash the parse
            if not line:
                continue
            # Split on the FIRST underscore only: the original split("_") silently
            # truncated query names containing underscores, which would lose data
            # on the round-trip through write_queries()
            prot, query = line.split("_", 1)
            d.setdefault(prot, []).append(query)
    return d


def getMatchHigherLevel(fnameDB, dEval, dTrainVal):
    """
    Match higher-level clusters with their corresponding lower-level collections.

    Args:
        fnameDB (str): Path to the cluster database file; each line is a
                       whitespace-separated cluster of PDB chains.
        dEval (dict): Evaluation-set proteins mapped to their queries.
        dTrainVal (dict): Train/validation-set proteins mapped to their queries.

    Returns:
        dict: Keys are higher-level cluster representatives (first chain on the
              line); values are tuples ([train-val proteins], [eval proteins])
              listing the cluster members found in each set.
    """
    d = {}
    with open(fnameDB, "r") as fIN:
        # Iterate lazily; each line represents one cluster of PDB chains
        for line in fIN:
            prots = line.strip().split()
            # Guard against blank lines, which would crash on prots[0]
            if not prots:
                continue

            # The first chain on the line acts as the cluster representative
            d[prots[0]] = ([], [])

            # Sort every cluster member into the train/val and/or eval buckets.
            # A protein present in both input sets is recorded in both lists.
            for p in prots:
                if p in dTrainVal:
                    d[prots[0]][0].append(p)
                if p in dEval:
                    d[prots[0]][1].append(p)

    return d


def splitHigherLevel(dEns, random_seed=42):
    """
    Split higher-level clusters into training, validation, and evaluation sets.

    Args:
        dEns (dict): Maps a cluster representative to a tuple
                     ([train/val proteins], [eval proteins]).
        random_seed (int): Random seed for reproducibility.

    Returns:
        tuple: (dtrainval, lval, ltrain, deval_strict) where:
               - dtrainval: dictionary of clusters with proteins in train/val sets but not in eval
               - lval: list of clusters selected for validation
               - ltrain: list of clusters selected for training
               - deval_strict: dictionary of clusters with proteins only in eval set
    """
    random.seed(random_seed)

    # Clusters that have at least one protein in any set
    dall = {k: v[0] for k, v in dEns.items() if (len(v[1]) + len(v[0])) > 0}

    # Clusters with proteins only in train/val sets (not in eval)
    dtrainval = {k: v[0] for k, v in dEns.items() if len(v[1]) == 0 and len(v[0]) > 0}

    # Clusters with proteins only in eval set (not in train/val)
    deval_strict = {k: v[1] for k, v in dEns.items() if len(v[1]) > 0 and len(v[0]) == 0}

    # Clusters with at least one protein in eval set (may overlap with train/val)
    deval_relax = {k: v[1] for k, v in dEns.items() if len(v[1]) > 0}

    # Validation size: 10% of the clusters eligible for train/val or strict eval
    # NOTE(review): the sample is drawn from dtrainval only, although the 10% is
    # computed over dtrainval + deval_strict — presumably intentional; confirm.
    n_val = int((len(dtrainval) + len(deval_strict)) / 10)

    # Randomly sample cluster indices for validation
    ival = random.sample(range(len(dtrainval)), n_val)
    # Set copy for O(1) membership tests below; the original list membership
    # check made building ltrain accidentally quadratic
    ival_set = set(ival)
    myKeys = list(dtrainval.keys())

    # Create lists of cluster IDs for validation and training
    lval = [myKeys[i] for i in ival]
    ltrain = [myKeys[i] for i in range(len(myKeys)) if i not in ival_set]

    # Print statistics about the split
    print("Number of higher-order collections with members in train, val or eval:", len(dall))
    print("Number of higher-order collections with members in train or val and nothing from eval:", len(dtrainval))
    print("Number of higher-order collections with members in eval:", len(deval_relax))
    print("Number of higher-order collections with members in eval and nothing from train-val:", len(deval_strict))
    print("Sample validation clusters:", lval[1:5])
    print("Number of higher-order collections in the new validation set:", len(lval))
    print("Sample training clusters:", ltrain[1:5])
    print("Number of higher-order collections in the new training set:", len(ltrain))

    return dtrainval, lval, ltrain, deval_strict


def reduceRedundancyInEval(dEns, dEval, random_seed=42):
    """
    Reduce redundancy in the evaluation set by keeping one protein per cluster.

    Args:
        dEns (dict): Maps higher-level cluster representatives to tuples of
                     ([train/val proteins], [eval proteins]).
        dEval (dict): Evaluation-set proteins mapped to their queries.
        random_seed (int): Random seed for reproducibility.

    Returns:
        dict: One randomly chosen protein per eval-containing cluster, each
              mapped to a single-element list holding its first query.
    """
    random.seed(random_seed)

    reduced = {}
    # Walk the clusters in insertion order, considering only those that
    # contain at least one evaluation protein
    for members in dEns.values():
        evalMembers = members[1]
        if not evalMembers:
            continue

        # Pick one representative protein at random for this cluster
        # (random.sample keeps the RNG stream identical for a given seed)
        chosen = random.sample(evalMembers, 1)[0]

        # Flag proteins that carry more than one query — only the first is kept
        if len(dEval[chosen]) > 1:
            print("warning!!", chosen, dEval[chosen])

        reduced[chosen] = [dEval[chosen][0]]

    return reduced


def sampleLowerLevel(dtrainval, myL, myD, random_seed=42):
    """
    Sample queries for lower-level collections with a stratified quota scheme.

    Args:
        dtrainval (dict): Maps cluster representatives to their train/val proteins.
        myL (list): Higher-level cluster representatives to process.
        myD (dict): Maps proteins to their queries.
        random_seed (int): Random seed for reproducibility.

    Returns:
        dict: Selected proteins mapped to their sampled query lists.
    """
    random.seed(random_seed)

    # Per-protein query quotas, indexed by (number of proteins used - 1):
    # 1 protein -> 5 queries; 2 -> 3+2; 3 -> 2+2+1; 4 -> 2+1+1+1; 5 -> 1 each
    quotas = ([5], [3, 2], [2, 2, 1], [2, 1, 1, 1], [1, 1, 1, 1, 1])

    selection = {}
    for rep in myL:
        members = dtrainval[rep]
        # Use at most 5 proteins from each cluster
        count = min(len(members), 5)

        # When the cluster holds 5+ proteins, draw 5 at random;
        # otherwise keep every member
        chosen = random.sample(members, count) if count == 5 else members

        # Pair each chosen protein with its quota and keep that many
        # queries (taken from the front of its query list)
        for protein, quota in zip(chosen, quotas[count - 1]):
            selection[protein] = myD[protein][:quota]

    return selection


def writeDico(dEns, fname="match_eval.csv"):
    """
    Write cluster matching information to a CSV file.

    Each non-empty cluster produces one line: "<key>,<count>,<m1-m2-...>".

    Args:
        dEns (dict): Dictionary of clusters and their proteins.
        fname (str): Output file path (default "match_eval.csv", kept for
                     backward compatibility with existing callers).
    """
    with open(fname, "w") as fOUT:
        for k in dEns:
            n = len(dEns[k])
            # Skip clusters with no members so no empty rows are emitted
            if n > 0:
                fOUT.write(k + "," + str(n) + "," + "-".join(dEns[k]) + "\n")


def write_queries(d, fname):
    """
    Write one "<protein>_<query>" identifier per line to a file.

    Args:
        d (dict): Dictionary mapping proteins to their queries.
        fname (str): Output file name.
    """
    # Build every output line first, then write them in a single call
    rows = [f"{prot}_{query}\n" for prot, queries in d.items() for query in queries]
    with open(fname, "w") as fOUT:
        fOUT.writelines(rows)


def write_values(d, fname):
    """
    Write only the query values (one per line) to a file.

    Args:
        d (dict): Dictionary mapping proteins to their queries.
        fname (str): Output file name.
    """
    # Flatten all query lists (dict insertion order) into one line per query
    rows = [f"{query}\n" for queries in d.values() for query in queries]
    with open(fname, "w") as fOUT:
        fOUT.writelines(rows)


if __name__ == "__main__":
    # Protein -> query mappings for the three starting sets
    dEval = getQueries("eval_list.txt")
    dTrainVal = getQueries("train_val_list.txt")
    dTrain = getQueries("full_train_list.txt")

    # Cluster database built at 30% sequence identity / 80% coverage
    clusterDB = "rewrited_clusterDB_30_80.tsv"

    # Match eval against train+val at the higher cluster level, then split
    dEns = getMatchHigherLevel(clusterDB, dEval, dTrainVal)
    dtrainval, lval, ltrain, deval_strict = splitHigherLevel(dEns)

    # Stratified query sampling for the validation and training sets
    dval = sampleLowerLevel(dtrainval, lval, dTrainVal)
    write_queries(dval, "val_nr_list_12_05.txt")
    dtrain = sampleLowerLevel(dtrainval, ltrain, dTrainVal)
    write_queries(dtrain, "train_nr_list_12_05.txt")

    # Keep a single protein (with a single query) per evaluation cluster
    deval = reduceRedundancyInEval(dEns, dEval)
    write_queries(deval, "eval_nr_list_12_05.txt")

    print("Number of strict evaluation clusters:", len(deval_strict))

    # Strict evaluation set: no cluster overlap with the full training set
    dEns2 = getMatchHigherLevel(clusterDB, deval, dTrain)
    dtrainval, lval, ltrain, deval_strict = splitHigherLevel(dEns2)
    write_values(deval_strict, "eval_strict_list_12_05.txt")

    # Even stricter: no cluster overlap with train + validation either
    dEns3 = getMatchHigherLevel(clusterDB, deval, dTrainVal)
    dtrainval, lval, ltrain, deval_strict = splitHigherLevel(dEns3)
    write_values(deval_strict, "eval_even_stricter_list_12_05.txt")