ligeti committed
Commit 0572702 · verified · 1 Parent(s): c677647

Adding LCA tokenizer source code

Files changed (5)
  1. config_utils.py +769 -0
  2. general_utils.py +309 -0
  3. sequtils.py +980 -0
  4. tokenizer.py +363 -0
  5. tokenizer_config.json +6 -0
config_utils.py ADDED
@@ -0,0 +1,769 @@
+ # Config utils
+ import yaml
+ import pathlib
+ from os.path import join
+ import os
+ import numpy as np
+ import torch
+ import argparse
+ from multiprocessing import cpu_count
+ from transformers import TrainingArguments
+ from copy import deepcopy
+ import re
+ import sys
+
+ def add_hf_args_to_parser(parser):
+     # Create a temporary TrainingArguments to access default values and descriptions
+     hf_args = TrainingArguments(output_dir="/tmp")  # Dummy output_dir
+     # Iterate over all public attributes
+     for attr in dir(hf_args):
+         if not attr.startswith("_"):
+             default = getattr(hf_args, attr)
+             # More sophisticated handling based on attribute types could be added here
+             if isinstance(default, (int, float, str, bool)):
+                 help_str = f"Auto-generated help for {attr}"
+                 parser.add_argument(f"--{attr}", type=type(default), default=default, help=help_str)
+
+     return parser
+
+ class BaseConfig:
+     """Base class for managing and validating configurations."""
+
+     numpy_dtype_mapping = {1: np.int8,
+                            2: np.int16,
+                            8: np.int64,
+                            4: np.int32}
+
+     def __init__(self):
+         super().__init__()
+
+     def cast_to_expected_type(self, parameter_class: str, parameter_name: str, value: any) -> any:
+         """
+         Cast the given value to the expected type.
+
+         :param parameter_class: The class/category of the parameter.
+         :type parameter_class: str
+         :param parameter_name: The name of the parameter.
+         :type parameter_name: str
+         :param value: The value to be cast.
+         :type value: any
+         :return: Value cast to the expected type.
+         :rtype: any
+         :raises ValueError: If casting fails.
+         """
+         expected_type = self.parameters[parameter_class][parameter_name]['type']
+
+         if expected_type in ["integer", "int"]:
+             try:
+                 return int(value)
+             except ValueError:
+                 raise ValueError(f"Failed to cast value '{value}' to integer for parameter '{parameter_name}' in class '{parameter_class}'.")
+         elif expected_type == "float":
+             try:
+                 return float(value)
+             except ValueError:
+                 raise ValueError(f"Failed to cast value '{value}' to float for parameter '{parameter_name}' in class '{parameter_class}'.")
+         elif expected_type in ["string", "str"]:
+             return str(value)
+         elif expected_type in ["boolean", "bool"]:
+             if isinstance(value, bool):
+                 return value
+             elif str(value).lower() == "true":
+                 return True
+             elif str(value).lower() == "false":
+                 return False
+             else:
+                 raise ValueError(f"Failed to cast value '{value}' to boolean for parameter '{parameter_name}' in class '{parameter_class}'.")
+         elif expected_type == "type":
+             # For this type, we simply return the value without casting.
+             # It assumes the configuration provides valid Python types.
+             return value
+         elif expected_type == "list":
+             if isinstance(value, list):
+                 return value
+             else:
+                 raise ValueError(f"Failed to validate value '{value}' as a list for parameter '{parameter_name}' in class '{parameter_class}'.")
+         elif expected_type == "tuple":
+             if isinstance(value, tuple):
+                 return value
+             else:
+                 raise ValueError(f"Failed to validate value '{value}' as a tuple for parameter '{parameter_name}' in class '{parameter_class}'.")
+         elif expected_type == "set":
+             if isinstance(value, set):
+                 return value
+             else:
+                 raise ValueError(f"Failed to validate value '{value}' as a set for parameter '{parameter_name}' in class '{parameter_class}'.")
+         elif expected_type == "dict":
+             if isinstance(value, dict):
+                 return value
+             else:
+                 raise ValueError(f"Failed to validate value '{value}' as a dict for parameter '{parameter_name}' in class '{parameter_class}'.")
+         else:
+             raise ValueError(f"Unknown expected type '{expected_type}' for parameter '{parameter_name}' in class '{parameter_class}'.")
+
+     def get_parameter(self, parameter_class: str, parameter_name: str) -> any:
+         """
+         Retrieve the default value of a specified parameter.
+
+         :param parameter_class: The class/category of the parameter (e.g., 'segmentation').
+         :type parameter_class: str
+         :param parameter_name: The name of the parameter.
+         :type parameter_name: str
+         :return: Default value of the parameter, cast to the expected type.
+         :rtype: any
+         """
+         default_value = self.parameters[parameter_class][parameter_name]['default']
+         return self.cast_to_expected_type(parameter_class, parameter_name, default_value)
+
+     def validate_type(self, parameter_class: str, parameter_name: str, value: any) -> bool:
+         """
+         Validate the type of a given value against the expected type.
+
+         :param parameter_class: The class/category of the parameter.
+         :type parameter_class: str
+         :param parameter_name: The name of the parameter.
+         :type parameter_name: str
+         :param value: The value to be validated.
+         :type value: any
+         :return: True if the value is of the expected type, otherwise False.
+         :rtype: bool
+         """
+         expected_type = self.parameters[parameter_class][parameter_name]['type']
+
+         if expected_type == "integer" and not isinstance(value, int):
+             return False
+         elif expected_type == "float" and not isinstance(value, float):
+             return False
+         elif expected_type == "string" and not isinstance(value, str):
+             return False
+         else:
+             return True
+
+     def validate_value(self, parameter_class: str, parameter_name: str, value: any) -> bool:
+         """
+         Validate the value of a parameter against its constraints.
+
+         :param parameter_class: The class/category of the parameter.
+         :type parameter_class: str
+         :param parameter_name: The name of the parameter.
+         :type parameter_name: str
+         :param value: The value to be validated.
+         :type value: any
+         :return: True if the value meets the constraints, otherwise False.
+         :rtype: bool
+         """
+         constraints = self.parameters[parameter_class][parameter_name].get('constraints', {})
+
+         if 'options' in constraints and value not in constraints['options']:
+             return False
+         if 'min' in constraints and value < constraints['min']:
+             return False
+         if 'max' in constraints and value > constraints['max']:
+             return False
+         return True
+
+     def validate(self, parameter_class: str, parameter_name: str, value: any):
+         """
+         Validate both the type and value of a parameter.
+
+         :param parameter_class: The class/category of the parameter.
+         :type parameter_class: str
+         :param parameter_name: The name of the parameter.
+         :type parameter_name: str
+         :param value: The value to be validated.
+         :type value: any
+         :raises TypeError: If the value is not of the expected type.
+         :raises ValueError: If the value does not meet the parameter's constraints.
+         """
+         if not self.validate_type(parameter_class, parameter_name, value):
+             raise TypeError(f"Invalid type for {parameter_name} for parameter class '{parameter_class}'. Expected {self.parameters[parameter_class][parameter_name]['type']}.")
+
+         if not self.validate_value(parameter_class, parameter_name, value):
+             raise ValueError(f"Invalid value for {parameter_name} for parameter class '{parameter_class}'. Constraints: {self.parameters[parameter_class][parameter_name].get('constraints', {})}.")
+
+     def describe(self, parameter_class: str, parameter_name: str) -> str:
+         """
+         Retrieve the description of a parameter.
+
+         :param parameter_class: The class/category of the parameter.
+         :type parameter_class: str
+         :param parameter_name: The name of the parameter.
+         :type parameter_name: str
+         :return: Description of the parameter.
+         :rtype: str
+         """
+         return self.parameters[parameter_class][parameter_name]['description']
+
+     @staticmethod
+     def rename_non_unique_parameters(config: dict) -> tuple[dict, dict, dict]:
+         """
+         Rename parameters in the configuration to ensure uniqueness across different groups.
+
+         This method identifies parameters with the same name across different groups and renames them
+         by prefixing the group name. This is to prevent conflicts when parameters are used in a context
+         where the group name is not specified.
+
+         :param config: A dictionary where each key is a group name and each value is a dict
+                        of parameters for that group.
+         :type config: dict
+
+         :return: A tuple containing:
+                  - renamed_config: A dictionary with the same structure as the input, but with non-unique parameter
+                    names renamed. The structure is {group_name: {param_name: param_info}}.
+                  - cmd_argument2group_param: A dictionary mapping the new parameter names to their original group
+                    and parameter name. The structure is {new_param_name: [group_name, original_param_name]}.
+                  - group2param2cmdarg: A dictionary mapping each group to a dict that maps the original parameter
+                    names to the new parameter names. The structure is {group_name: {original_param_name: new_param_name}}.
+         :rtype: tuple[dict, dict, dict]
+         """
+
+         # Identify non-unique parameter names
+         param_counts = {}
+         for group_name, parameters in config.items():
+             for param_name in parameters.keys():
+                 param_counts[param_name] = param_counts.get(param_name, 0) + 1
+
+         non_unique_params = {param for param, count in param_counts.items() if count > 1}
+
+         cmd_argument2group_param = {}
+         group2param2cmdarg = {}
+         for group_name, parameters in config.items():
+             group2param2cmdarg[group_name] = {}
+             for param_name in parameters.keys():
+                 group2param2cmdarg[group_name][param_name] = param_name
+
+         # Rename only the non-unique parameters
+         renamed_config = {}
+         for group_name, parameters in config.items():
+             renamed_group = {}
+             for param_name, param_info in parameters.items():
+                 new_param_name = f"{group_name}_{param_name}" if param_name in non_unique_params else param_name
+                 cmd_argument2group_param[new_param_name] = [group_name, param_name]
+                 group2param2cmdarg[group_name][param_name] = new_param_name
+
+                 renamed_group[new_param_name] = param_info
+             renamed_config[group_name] = renamed_group
+         return renamed_config, cmd_argument2group_param, group2param2cmdarg
+
+     @staticmethod
+     def create_parser(config: dict) -> argparse.ArgumentParser:
+         """
+         Create and configure an argparse parser based on the given configuration.
+
+         This method sets up a command-line argument parser with arguments defined in the configuration.
+         Each top-level key in the configuration represents a group of related arguments.
+
+         :param config: A dictionary where each key is a group name and each value is a dict
+                        of parameters for that group. Each parameter's information should include
+                        its type, default value, and help description.
+         :type config: dict
+
+         :return: Configured argparse.ArgumentParser instance with arguments added as specified
+                  in the configuration.
+         :rtype: argparse.ArgumentParser
+
+         :raises ValueError: If an unknown or unsupported type is specified for a parameter.
+         """
+         parser = argparse.ArgumentParser(description="Command-line parser for project settings")
+         # Mapping of type strings to Python types
+         type_mapping = {
+             'integer': int,
+             'int': int,
+             'float': float,
+             'string': str,
+             'str': str,
+             'bool': bool,
+             'boolean': bool,
+             'list': list
+             # Complex types like 'dict' and 'type' are intentionally excluded
+         }
+
+         # List of types to handle as strings
+         handle_as_string = ['dict', 'type', 'list']
+         excluded_parameters = ['vocabmap', 'np_tokentype', 'pretraining_dataset_data', 'optim']
+
+         for group_name, parameters in config.items():
+             group = parser.add_argument_group(group_name)
+             for param_name, param_info in parameters.items():
+                 param_type_str = param_info['type']
+                 description = param_info['description']
+                 escaped_description = re.sub(r"([^%])%", r"\1%%", description)
+                 if param_name in excluded_parameters:
+                     continue
+                 if param_type_str in handle_as_string:
+                     # Handle these types as strings in argparse; conversion is done later in the program
+                     param_type = str
+                 elif param_type_str not in type_mapping:
+                     raise ValueError(f"Unknown or unsupported type '{param_type_str}' for parameter '{param_name}'")
+                 else:
+                     param_type = type_mapping[param_type_str]
+
+                 kwargs = {
+                     'type': param_type,
+                     'default': param_info['default'],
+                     'help': escaped_description
+                 }
+                 # Add constraints if they exist
+                 """
+                 if 'constraints' in param_info:
+                     constraints = param_info['constraints']
+                     if 'min' in constraints:
+                         kwargs['type'] = lambda x: eval(param_type_str)(x) if eval(param_type_str)(x) >= constraints['min'] else sys.exit(f"Value for {param_name} must be at least {constraints['min']}")
+                     if 'max' in constraints:
+                         kwargs['type'] = lambda x: eval(param_type_str)(x) if eval(param_type_str)(x) <= constraints['max'] else sys.exit(f"Value for {param_name} must be at most {constraints['max']}")
+                     if 'options' in constraints:
+                         kwargs['choices'] = constraints['options']
+                 """
+                 # Add argument to the group
+                 group.add_argument(f'--{param_name}', **kwargs)
+         # parser = add_hf_args_to_parser(parser)
+
+         return parser
+
+
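To make the renaming scheme above concrete, here is a minimal sketch; the toy group and parameter names are invented for illustration and are not taken from the packaged YAML configs:

# Illustrative only: 'kmer' occurs in two groups, so it is group-prefixed,
# while the unique 'coverage' keeps its name.
toy_config = {
    'segmentation': {'coverage': {'type': 'float'}},
    'tokenization': {'kmer': {'type': 'integer'}},
    'computation': {'kmer': {'type': 'integer'}},
}
renamed, arg2group, group2arg = BaseConfig.rename_non_unique_parameters(toy_config)
# renamed   == {'segmentation': {'coverage': {'type': 'float'}},
#               'tokenization': {'tokenization_kmer': {'type': 'integer'}},
#               'computation': {'computation_kmer': {'type': 'integer'}}}
# arg2group == {'coverage': ['segmentation', 'coverage'],
#               'tokenization_kmer': ['tokenization', 'kmer'],
#               'computation_kmer': ['computation', 'kmer']}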
+ class SeqConfig(BaseConfig):
+     """Class to manage and validate sequence processing configurations."""
+
+     def __init__(self):
+         super().__init__()
+         self.default_seq_config_file = self._get_default_sequence_processing_config_file()
+         with open(self.default_seq_config_file, 'r') as file:
+             self.parameters = yaml.safe_load(file)
+
+         # Some postprocessing steps
+         self.parameters['tokenization']['shift']['constraints']['max'] = self.parameters['tokenization']['kmer']['default'] - 1
+         # If someone updates the k-mer parameter, recalculating this constraint should be triggered.
+
+         self.get_and_set_segmentation_parameters()
+         self.get_and_set_tokenization_parameters()
+         self.get_and_set_computational_parameters()
+
+     def _get_default_sequence_processing_config_file(self) -> str:
+         """
+         Retrieve the default sequence processing configuration file.
+
+         :return: Path to the configuration file.
+         :rtype: str
+         """
+         current_path = pathlib.Path(__file__).parent
+         prokbert_seq_config_file = join(current_path, 'configs', 'sequence_processing.yaml')
+         self.current_path = current_path
+
+         try:
+             # Attempt to read the environment variable
+             prokbert_seq_config_file = os.environ['SEQ_CONFIG_FILE']
+         except KeyError:
+             # The environment variable is not set; fall back to the default path
+             pass
+         return prokbert_seq_config_file
+
+     def get_and_set_segmentation_parameters(self, parameters: dict = {}) -> dict:
+         """
+         Retrieve and validate the provided parameters for segmentation.
+
+         :param parameters: A dictionary of parameters to be validated.
+         :type parameters: dict
+         :return: A dictionary of validated segmentation parameters.
+         :rtype: dict
+         :raises ValueError: If an invalid segmentation parameter is provided.
+         """
+         segmentation_params = {k: self.get_parameter('segmentation', k) for k in self.parameters['segmentation']}
+
+         for param, param_value in parameters.items():
+             if param not in segmentation_params:
+                 raise ValueError(f"The provided {param} is an INVALID segmentation parameter! The valid parameters are: {list(segmentation_params.keys())}")
+             self.validate('segmentation', param, param_value)
+             segmentation_params[param] = param_value
+         self.segmentation_params = segmentation_params
+
+         return segmentation_params
+
+     def get_and_set_tokenization_parameters(self, parameters: dict = {}) -> dict:
+         # Update the other parameters if necessary, i.e. if the k-mer has been changed, then the shift
+         # is updated as well, and a parameter check is run at the end.
+
+         tokenization_params = {k: self.get_parameter('tokenization', k) for k in self.parameters['tokenization']}
+         for param, param_value in parameters.items():
+             if param not in tokenization_params:
+                 raise ValueError(f"The provided {param} is an INVALID tokenization parameter! The valid parameters are: {list(tokenization_params.keys())}")
+             self.validate('tokenization', param, param_value)
+             tokenization_params[param] = param_value
+
+         # Load and check the vocab file. It is assumed to be an ordered dictionary.
+         vocabfile = tokenization_params['vocabfile']
+         act_kmer = tokenization_params['kmer']
+         if vocabfile == 'auto':
+             vocabfile_path = join(self.current_path, 'data/prokbert_vocabs/', f'prokbert-base-dna{act_kmer}', 'vocab.txt')
+             tokenization_params['vocabfile'] = vocabfile_path
+         else:
+             vocabfile_path = vocabfile
+         with open(vocabfile_path) as vocabfile_in:
+             vocabmap = {line.strip(): i for i, line in enumerate(vocabfile_in)}
+         tokenization_params['vocabmap'] = vocabmap
+
+         self.tokenization_params = tokenization_params
+         return tokenization_params
+
+     def get_and_set_computational_parameters(self, parameters: dict = {}) -> dict:
+         """Read and validate the computational parameters."""
+
+         computational_params = {k: self.get_parameter('computation', k) for k in self.parameters['computation']}
+         core_count = cpu_count()
+
+         if computational_params['cpu_cores_for_segmentation'] == -1:
+             computational_params['cpu_cores_for_segmentation'] = core_count
+
+         if computational_params['cpu_cores_for_tokenization'] == -1:
+             computational_params['cpu_cores_for_tokenization'] = core_count
+
+         for param, param_value in parameters.items():
+             if param not in computational_params:
+                 raise ValueError(f"The provided {param} is an INVALID computation parameter! The valid parameters are: {list(computational_params.keys())}")
+             self.validate('computation', param, param_value)
+             computational_params[param] = param_value
+
+         np_tokentype = SeqConfig.numpy_dtype_mapping[computational_params['numpy_token_integer_prec_byte']]
+         computational_params['np_tokentype'] = np_tokentype
+         self.computational_params = computational_params
+         return computational_params
+
+     def get_maximum_segment_length_from_token_count_from_params(self):
+         """Calculate the maximum segment length from the token count."""
+         max_token_counts = self.tokenization_params['token_limit']
+         shift = self.tokenization_params['shift']
+         kmer = self.tokenization_params['kmer']
+         return self.get_maximum_segment_length_from_token_count(max_token_counts, shift, kmer)
+
+     def get_maximum_token_count_from_max_length_from_params(self):
+         """Calculate the maximum token count from the maximum segment length."""
+         max_segment_length = self.tokenization_params['max_segment_length']
+         shift = self.tokenization_params['shift']
+         kmer = self.tokenization_params['kmer']
+         max_token_count = self.get_maximum_token_count_from_max_length(max_segment_length, shift, kmer)
+
+         return max_token_count
+
+     def get_cmd_arg_parser(self) -> tuple[argparse.ArgumentParser, dict, dict]:
+         """
+         Create and return a command-line argument parser for ProkBERT configurations, along with mappings
+         between command-line arguments and configuration parameters.
+
+         This method combines sequence configuration parameters with training configuration parameters
+         and sets up a command-line argument parser using these combined settings. It ensures that parameter
+         names are unique across different groups by renaming any non-unique parameters.
+
+         :return: A tuple containing:
+                  - Configured argparse.ArgumentParser instance for handling ProkBERT configurations.
+                  - A dictionary mapping new command-line arguments to their original group and parameter name.
+                  - A dictionary mapping each group to a dict that maps the original parameter names
+                    to the new command-line argument names.
+         :rtype: tuple[argparse.ArgumentParser, dict, dict]
+
+         Note: The method assumes that the configuration parameters for training and sequence configuration
+         are available within the class.
+         """
+         combined_params = deepcopy(self.parameters)
+         combined_params['Sequence'] = {}
+         combined_params['Sequence']['fasta_file_dir'] = {'default': 'None',
+                                                          'description': 'Directory where the input fasta files are located for the pretraining',
+                                                          'type': 'string'}
+         combined_params['Sequence']['out'] = {'default': 'pretrain.h5',
+                                               'description': 'Output path',
+                                               'type': 'string'}
+
+         combined_params, cmd_argument2group_param, group2param2cmdarg = BaseConfig.rename_non_unique_parameters(combined_params)
+
+         parser = BaseConfig.create_parser(combined_params)
+         return parser, cmd_argument2group_param, group2param2cmdarg
+
+     @staticmethod
+     def get_maximum_segment_length_from_token_count(max_token_counts, shift, kmer):
+         """Calculates the maximum segment length that can be covered by the given token count."""
+         max_segment_length = (max_token_counts - 3) * shift + kmer
+         return max_segment_length
+
+     @staticmethod
+     def get_maximum_token_count_from_max_length(max_segment_length, shift, kmer):
+         """Calculates the token count needed to cover a segment of the given maximum length."""
+         max_token_count = int(np.ceil((max_segment_length - kmer) / shift + 3))
+         return max_token_count
+
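A quick worked example of the two static helpers above; the k-mer and shift values are illustrative, and the constant 3 presumably reserves room for special tokens:

length = SeqConfig.get_maximum_segment_length_from_token_count(512, shift=1, kmer=6)
# (512 - 3) * 1 + 6 = 515 nucleotides of sequence
tokens = SeqConfig.get_maximum_token_count_from_max_length(515, shift=1, kmer=6)
# ceil((515 - 6) / 1 + 3) = 512 tokens, consistent with the inverse above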
+ class ProkBERTConfig(BaseConfig):
+     """Class to manage and validate pretraining configurations."""
+
+     torch_dtype_mapping = {1: torch.uint8,
+                            2: torch.int16,
+                            8: torch.int64,
+                            4: torch.int32}
+
+     def __init__(self):
+         super().__init__()
+
+         self.default_pretrain_config_file = self._get_default_pretrain_config_file()
+         with open(self.default_pretrain_config_file, 'r') as file:
+             self.parameters = yaml.safe_load(file)
+
+         # Load and validate each parameter set
+         self.data_collator_params = self.get_set_parameters('data_collator')
+         self.model_params = self.get_set_parameters('model')
+         self.dataset_params = self.get_set_parameters('dataset')
+         self.pretraining_params = self.get_set_parameters('pretraining')
+         self.finetuning_params = self.get_set_parameters('finetuning')
+
+         # Getting the sequtils params as well
+         self.def_seq_config = SeqConfig()
+         self.segmentation_params = self.def_seq_config.get_and_set_segmentation_parameters(self.parameters['segmentation'])
+         self.tokenization_params = self.def_seq_config.get_and_set_tokenization_parameters(self.parameters['tokenization'])
+         self.computation_params = self.def_seq_config.get_and_set_computational_parameters(self.parameters['computation'])
+
+         self.default_torchtype = ProkBERTConfig.torch_dtype_mapping[self.computation_params['numpy_token_integer_prec_byte']]
+
+         hf_training_args = TrainingArguments("working_dir")
+         self.hf_training_args_dict = hf_training_args.to_dict()
+
+     def _get_default_pretrain_config_file(self) -> str:
+         """
+         Retrieve the default pretraining configuration file.
+
+         :return: Path to the configuration file.
+         :rtype: str
+         """
+         current_path = pathlib.Path(__file__).parent
+         pretrain_config_file = join(current_path, 'configs', 'pretraining.yaml')
+
+         try:
+             # Attempt to read the environment variable
+             pretrain_config_file = os.environ['PRETRAIN_CONFIG_FILE']
+         except KeyError:
+             # The environment variable is not set; fall back to the default path
+             pass
+         return pretrain_config_file
+
+     def get_set_parameters(self, parameter_class: str, parameters: dict = {}) -> dict:
+         """
+         Retrieve and validate the provided parameters for a given parameter class.
+
+         :param parameter_class: The class/category of the parameter (e.g., 'data_collator').
+         :type parameter_class: str
+         :param parameters: A dictionary of parameters to be validated.
+         :type parameters: dict
+         :return: A dictionary of validated parameters.
+         :rtype: dict
+         :raises ValueError: If an invalid parameter is provided.
+         """
+         class_params = {k: self.get_parameter(parameter_class, k) for k in self.parameters[parameter_class]}
+
+         # First, validate the class parameters as well
+         for param, param_value in class_params.items():
+             self.validate(parameter_class, param, param_value)
+
+         for param, param_value in parameters.items():
+             if param not in class_params and (parameter_class != 'pretraining'):
+                 raise ValueError(f"The provided {param} is an INVALID {parameter_class} parameter! The valid parameters are: {list(class_params.keys())}")
+             else:
+                 if parameter_class == 'pretraining' or parameter_class == 'finetuning':
+                     if param in self.hf_training_args_dict or param in class_params:
+                         if param in class_params:
+                             self.validate(parameter_class, param, param_value)
+                         class_params[param] = param_value
+                     else:
+                         raise ValueError(f"The provided {param} is an INVALID {parameter_class} parameter! In addition, it is not a valid training argument.")
+                 else:
+                     self.validate(parameter_class, param, param_value)
+                     class_params[param] = param_value
+
+         return class_params
+
+     def get_and_set_model_parameters(self, parameters: dict = {}) -> dict:
+         """Setting the model parameters."""
+
+         # Here we include the additional training arguments available for the trainer
+         self.model_params = self.get_set_parameters('model', parameters)
+
+         return self.model_params
+
+     def get_and_set_dataset_parameters(self, parameters: dict = {}) -> dict:
+         """Setting the dataset parameters."""
+
+         self.dataset_params = self.get_set_parameters('dataset', parameters)
+
+         return self.dataset_params
+
+     def get_and_set_pretraining_parameters(self, parameters: dict = {}) -> dict:
+         """Setting the pretraining parameters."""
+         self.pretraining_params = self.get_set_parameters('pretraining', parameters)
+
+         return self.pretraining_params
+
+     def get_and_set_datacollator_parameters(self, parameters: dict = {}) -> dict:
+         """Setting the data collator parameters."""
+         self.data_collator_params = self.get_set_parameters('data_collator', parameters)
+         return self.data_collator_params
+
+     def get_and_set_segmentation_parameters(self, parameters: dict = {}) -> dict:
+         self.segmentation_params = self.def_seq_config.get_and_set_segmentation_parameters(parameters)
+         return self.segmentation_params
+
+     def get_and_set_tokenization_parameters(self, parameters: dict = {}) -> dict:
+         self.tokenization_params = self.def_seq_config.get_and_set_tokenization_parameters(parameters)
+         return self.tokenization_params
+
+     def get_and_set_computation_params(self, parameters: dict = {}) -> dict:
+         self.computation_params = self.def_seq_config.get_and_set_computational_parameters(parameters)
+         return self.computation_params
+
+     def get_and_set_finetuning_parameters(self, parameters: dict = {}) -> dict:
+         """Setting the finetuning parameters."""
+
+         # Here we include the additional training arguments available for the trainer
+         self.finetuning_params = self.get_set_parameters('finetuning', parameters)
+
+         return self.finetuning_params
+
+     def get_inference_parameters(self):
+         # Instantiate TrainingArguments to access default values
+         hf_defaults = TrainingArguments(output_dir="/tmp")  # Dummy output_dir for initialization
+
+         return {
+             'inference': {
+                 'fastain': {
+                     'default': None,
+                     'type': 'str',
+                     'description': 'Path to the input data for inference.'
+                 },
+                 'out': {
+                     'default': None,
+                     'type': 'str',
+                     'description': 'Output path for the inference results.'
+                 },
+                 'per_device_eval_batch_size': {
+                     'default': hf_defaults.per_device_eval_batch_size,
+                     'type': 'int',
+                     'description': 'Batch size per device during evaluation.'
+                 },
+                 'ddp_backend': {
+                     'default': hf_defaults.ddp_backend,
+                     'type': 'str',
+                     'description': 'The backend to use for distributed training.'
+                 },
+                 'dataloader_drop_last': {
+                     'default': hf_defaults.dataloader_drop_last,
+                     'type': 'bool',
+                     'description': 'Drop the last incomplete batch if it is not divisible by the batch size.'
+                 },
+                 'torch_compile': {
+                     'default': getattr(hf_defaults, 'torch_compile', False),  # Fallback for compatibility
+                     'type': 'bool',
+                     'description': 'Whether to use TorchScript’s JIT compilation to accelerate training.'
+                 },
+                 'torch_compile_mode': {
+                     'default': getattr(hf_defaults, 'torch_compile_mode', 'eager'),  # Fallback for compatibility
+                     'type': 'str',
+                     'description': 'The JIT mode to use for compiling PyTorch operations.'
+                 }
+             }
+         }
+
+     def get_cmd_arg_parser(self, keyset=[]) -> tuple[argparse.ArgumentParser, dict, dict]:
+         """
+         Create and return a command-line argument parser for ProkBERT configurations, along with mappings
+         between command-line arguments and configuration parameters.
+
+         This method combines sequence configuration parameters with training configuration parameters
+         and sets up a command-line argument parser using these combined settings. It ensures that parameter
+         names are unique across different groups by renaming any non-unique parameters.
+
+         :return: A tuple containing:
+                  - Configured argparse.ArgumentParser instance for handling ProkBERT configurations.
+                  - A dictionary mapping new command-line arguments to their original group and parameter name.
+                  - A dictionary mapping each group to a dict that maps the original parameter names
+                    to the new command-line argument names.
+         :rtype: tuple[argparse.ArgumentParser, dict, dict]
+
+         Note: The method assumes that the configuration parameters for training and sequence configuration
+         are available within the class.
+         """
+         if len(keyset) == 0:
+             training_conf_keysets = ['data_collator', 'model', 'dataset', 'pretraining', 'finetuning']
+         else:
+             training_conf_keysets = keyset
+
+         inference_params = self.get_inference_parameters()
+         seq_config = deepcopy(self.def_seq_config.parameters)
+         default_other_config = deepcopy(self.parameters)
+         combined_params = {}
+         for k, v in seq_config.items():
+             combined_params[k] = v
+         for k in training_conf_keysets:
+             combined_params[k] = default_other_config[k]
+         combined_params.update(inference_params)
+         combined_params, cmd_argument2group_param, group2param2cmdarg = BaseConfig.rename_non_unique_parameters(combined_params)
+         parser = BaseConfig.create_parser(combined_params)
+
+         return parser, cmd_argument2group_param, group2param2cmdarg
+
+
+ def get_user_provided_args(args, parser):
+     """
+     Extract arguments provided by the user from the parsed arguments.
+
+     Args:
+         args (argparse.Namespace): Parsed command-line arguments.
+         parser (argparse.ArgumentParser): The argument parser instance.
+
+     Returns:
+         dict: A dictionary of user-provided arguments and their values.
+     """
+     user_provided_args = {}
+     for action in parser._actions:
+         arg_name = action.dest
+         default_value = action.default
+         user_value = getattr(args, arg_name, None)
+         if user_value != default_value:
+             user_provided_args[arg_name] = user_value
+
+     return user_provided_args
+
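A minimal usage sketch for the classes above, assuming the packaged YAML config files (sequence_processing.yaml, pretraining.yaml) are in place; the override values are illustrative:

seq_cfg = SeqConfig()
tok_params = seq_cfg.get_and_set_tokenization_parameters({'kmer': 6, 'shift': 1})
parser, arg2group, group2arg = seq_cfg.get_cmd_arg_parser()
args = parser.parse_args([])                    # all defaults
overrides = get_user_provided_args(args, parser)  # empty dict when only defaults are used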
general_utils.py ADDED
@@ -0,0 +1,309 @@
+ # coding=utf-8
+ """ Library for general utils, such as dataframe properties checking,
+ creating directories, checking files, etc.
+ """
+
+ import pandas as pd
+ import os
+ import numpy as np
+ import subprocess
+ import shutil
+ from typing import Iterator
+
+
+ def check_expected_columns(df: pd.DataFrame, expected_columns: list) -> bool:
+     """Checks if a DataFrame contains the expected columns.
+
+     :param df: The input DataFrame to be checked.
+     :type df: pd.DataFrame
+     :param expected_columns: A list of columns that are expected to be present in the DataFrame.
+     :type expected_columns: list
+     :returns: True if all expected columns are present in the DataFrame, False otherwise.
+     :rtype: bool
+     :raises ValueError: If any of the expected columns are not present in the DataFrame.
+
+     Examples
+     --------
+     >>> df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]})
+     >>> check_expected_columns(df, ['A', 'B'])
+     True
+
+     >>> check_expected_columns(df, ['A', 'C'])
+     ValueError: The following columns are missing: ['C']
+     """
+
+     missing_columns = [col for col in expected_columns if col not in df.columns]
+
+     if missing_columns:
+         raise ValueError(f"The following columns are missing: {missing_columns}")
+
+     return True
+
+
+ def is_valid_primary_key(df: pd.DataFrame, column_name: str) -> bool:
+     """Checks if a specified column in a DataFrame can serve as a valid primary key.
+
+     :param df: The input DataFrame to be checked.
+     :type df: pd.DataFrame
+     :param column_name: The name of the column to check.
+     :type column_name: str
+     :returns: True if the column can serve as a valid primary key, False otherwise.
+     :rtype: bool
+     :raises ValueError: If the specified column does not exist in the DataFrame.
+
+     Examples
+     --------
+     >>> df = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
+     >>> is_valid_primary_key(df, 'A')
+     True
+
+     >>> df = pd.DataFrame({'A': [1, 2, 2], 'B': [4, 5, 6]})
+     >>> is_valid_primary_key(df, 'A')
+     False
+     """
+
+     if column_name not in df.columns:
+         raise ValueError(f"Column '{column_name}' does not exist in the DataFrame.")
+
+     # Check for NaN values
+     if df[column_name].isnull().any():
+         return False
+
+     # Check for unique values
+     if not df[column_name].is_unique:
+         return False
+
+     return True
+
+
+ def get_non_empty_files(start_path: str, extensions: tuple = ('.fasta', '.fna')) -> Iterator[str]:
+     """Generator that yields non-empty files from a specified directory and its subdirectories based on the given extensions.
+
+     :param start_path: The path to the directory from which to start the search.
+     :type start_path: str
+     :param extensions: A tuple of file extensions to look for (default is ('.fasta', '.fna')).
+                        The function also automatically checks for compressed versions with '.gz'.
+     :type extensions: tuple
+     :returns: Yields filenames that match the specified extensions and are non-empty.
+     :rtype: Iterator[str]
+     """
+
+     for dirpath, _, filenames in os.walk(start_path):
+         for filename in filenames:
+             filepath = os.path.join(dirpath, filename)
+             if any(filename.endswith(ext) or filename.endswith(ext + '.gz') for ext in extensions) and os.path.getsize(filepath) > 0:
+                 yield filename
+
+
+ def truncate_zero_columns(arr: np.ndarray) -> np.ndarray:
+     """Truncate all trailing columns composed entirely of zeros in a given 2D numpy array.
+
+     :param arr: Input 2D numpy array.
+     :type arr: np.ndarray
+     :returns: A new array with trailing zero columns removed.
+     :rtype: np.ndarray
+     """
+
+     # Iterate over columns from the end
+     for idx in range(arr.shape[1] - 1, -1, -1):
+         if np.any(arr[:, idx]):
+             return arr[:, :(idx + 1)]
+     return np.empty((arr.shape[0], 0))
+
+
+ def create_directory_for_filepath(filepath: str) -> None:
+     """Given a file path, creates the underlying directory structure if it doesn't already exist.
+
+     :param filepath: The path to the file for which the directory structure should be created.
+     :type filepath: str
+     :raises ValueError: If the provided path is empty or None.
+     :raises OSError: If there's an error creating the directory structure.
+     """
+
+     if not filepath:
+         raise ValueError("The provided filepath is empty or None.")
+
+     directory = os.path.dirname(filepath)
+
+     if directory and not os.path.exists(directory):
+         try:
+             os.makedirs(directory)
+             print(f"Directory structure {directory} created successfully.")
+         except OSError as e:
+             raise OSError(f"Error creating directory structure {directory}. Error: {e}")
+
+ # Example usage:
+ # create_directory_for_filepath("/path/to/directory/that/might/not/exist/filename.txt")
+
+
+ def check_file_exists(file_path: str) -> bool:
+     """Checks if the provided file path exists.
+
+     :param file_path: Path to the file.
+     :type file_path: str
+     :returns: True if the file exists, raises ValueError otherwise.
+     :rtype: bool
+     """
+     if os.path.exists(file_path):
+         return True
+     else:
+         raise ValueError(f"The provided file path '{file_path}' does not exist.")
+
+
+ def count_gpus(method="clinfo"):
+     """
+     Count the number of available GPUs using the specified method.
+
+     This function counts the number of NVIDIA and AMD GPUs using the chosen method. By default, it uses the 'clinfo'
+     method for AMD GPUs.
+
+     :param method: The method to use for GPU counting. Choose between 'clinfo' (default) and 'rocm'.
+     :type method: str, optional
+
+     :return: The total number of GPUs detected.
+     :rtype: int
+
+     :raises ValueError: If an unknown method is provided.
+     :raises Exception: If an error occurs while querying AMD GPUs using the specified method.
+
+     .. note::
+         - The 'clinfo' method queries AMD GPUs by running the 'clinfo' command.
+         - The 'rocm' method queries AMD GPUs by running the 'rocm-smi --list' command.
+     """
+     import torch
+
+     # Count NVIDIA GPUs
+     nvidia_gpu_count = torch.cuda.device_count()
+
+     # Count AMD GPUs
+     amd_gpu_count = 0
+     try:
+         if method == "clinfo":
+             clinfo_output = subprocess.check_output('clinfo').decode('utf-8')
+             amd_gpu_count = clinfo_output.lower().count('device type: gpu')
+         elif method == "rocm":
+             rocm_output = subprocess.check_output('rocm-smi --list', shell=True).decode('utf-8')
+             amd_gpu_count = len(rocm_output.strip().split('\n'))
+         else:
+             raise ValueError("Unknown method provided. Choose between 'clinfo' and 'rocm'.")
+     except Exception as e:
+         print(f"Error querying AMD GPUs using method '{method}': {e}")
+
+     total_gpus = nvidia_gpu_count + amd_gpu_count
+
+     return total_gpus
+
+
+ def create_hard_links(source_directory: str, target_directory: str, blacklist: list = []) -> str:
+     """Creates hard links for all files from the source directory to the target directory.
+
+     :param source_directory: The directory containing the original files.
+     :type source_directory: str
+     :param target_directory: The directory where hard links will be created.
+     :type target_directory: str
+     :param blacklist: List of filenames to exclude from creating hard links.
+     :type blacklist: list
+     :returns: A confirmation message.
+     :rtype: str
+     """
+
+     # Ensure the provided directories exist
+     if not os.path.exists(source_directory):
+         raise ValueError(f"The source directory '{source_directory}' does not exist.")
+     if not os.path.exists(target_directory):
+         os.makedirs(target_directory)
+
+     # Iterate through the files in the source directory
+     for filename in os.listdir(source_directory):
+         source_file_path = os.path.join(source_directory, filename)
+         target_file_path = os.path.join(target_directory, filename)
+
+         # Check for files to skip
+         if (filename.startswith('.') or
+                 filename.startswith('_') or
+                 os.path.isdir(source_file_path) or
+                 filename in blacklist):
+             continue
+
+         # Create a hard link
+         os.link(source_file_path, target_file_path)
+
+     return f"Hard links created in {target_directory} from {source_directory}."
+
+ # Example usage
+ # create_hard_links("/path/to/source_directory", "/path/to/target_directory", blacklist=["file_to_skip.txt"])
+
+
+ def create_selected_hard_links(source_directory: str, target_directory: str, filenames: list) -> str:
+     """Creates hard links for the specified files from the source directory to the target directory.
+
+     :param source_directory: The directory containing the original files.
+     :type source_directory: str
+     :param target_directory: The directory where hard links will be created.
+     :type target_directory: str
+     :param filenames: List of filenames for which hard links should be created.
+     :type filenames: list
+     :returns: A confirmation message.
+     :rtype: str
+     """
+
+     # Ensure the provided directories exist
+     if not os.path.exists(source_directory):
+         raise ValueError(f"The source directory '{source_directory}' does not exist.")
+     if not os.path.exists(target_directory):
+         os.makedirs(target_directory)
+
+     # Iterate through the specified filenames
+     for filename in filenames:
+         source_file_path = os.path.join(source_directory, filename)
+         target_file_path = os.path.join(target_directory, filename)
+
+         # Ensure the file exists in the source directory
+         if not os.path.isfile(source_file_path):
+             print(f"Warning: {filename} does not exist in the source directory. Skipping.")
+             continue
+
+         # Create a hard link
+         try:
+             os.link(source_file_path, target_file_path)
+         except FileExistsError:
+             print(f'The target hard link {target_file_path} exists. Skipping...')
+
+     return f"Hard links for specified files created in {target_directory} from {source_directory}."
+
+
+ def remove_hidden_files(directory: str) -> None:
+     """Removes all files recursively in a folder that start with '.' or '_'.
+
+     :param directory: The directory from which hidden files should be removed.
+     :type directory: str
+     :returns: None
+     """
+
+     # Ensure the directory exists
+     if not os.path.exists(directory):
+         raise ValueError(f"The directory '{directory}' does not exist.")
+
+     # Use os.walk to iterate through all subdirectories and files
+     for dirpath, dirnames, filenames in os.walk(directory, topdown=False):
+
+         # Filter out directories starting with '.' or '_'
+         dirnames[:] = [d for d in dirnames if not d.startswith('.') and not d.startswith('_')]
+
+         # Remove files starting with '.' or '_'
+         for filename in filenames:
+             if filename.startswith('.') or filename.startswith('_'):
+                 file_path = os.path.join(dirpath, filename)
+                 os.remove(file_path)
+                 print(f"Removed: {file_path}")
+
+     print(f"All hidden files removed from {directory}.")
sequtils.py ADDED
@@ -0,0 +1,980 @@
1
+
2
+ import logging
3
+
4
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
5
+ # coding=utf-8
6
+
7
+ """ Library for sequence processing """
8
+
9
+
10
+ import os
11
+ import sys
12
+ import pandas as pd
13
+ from multiprocessing import Pool
14
+ import multiprocessing
15
+ from os.path import join, isfile, splitext
16
+ from os import listdir
17
+ import random
18
+ from Bio import SeqIO
19
+ import numpy as np
20
+ import math
21
+ import gzip
22
+ from mimetypes import guess_type
23
+ from functools import partial
24
+ import operator
25
+ import pathlib
26
+ #from typing import Dict, List, Type, Tuple
27
+ from itertools import product
28
+ from typing import List, Union, Dict, Any, Optional, Tuple, Type, Set
29
+ from .general_utils import *
30
+ from Bio.Seq import Seq
31
+ from Bio.SeqRecord import SeqRecord
32
+ from scipy.ndimage import convolve1d
33
+ import h5py
34
+
35
+ def load_contigs(
36
+ fasta_files_list: Union[List[str], str],
37
+ adding_reverse_complement: bool = True,
38
+ IsAddHeader: bool = False,
39
+ AsDataFrame: bool = False,
40
+ to_uppercase: bool = False,
41
+ is_add_sequence_id: bool = False
42
+ ) -> Union[List[Union[str, List[str]]], pd.DataFrame]:
43
+ """
44
+ Loads contigs from a list of FASTA files.
45
+
46
+ :param fasta_files_list: List of paths to FASTA files or a single file path. Compressed (gz) FASTA files are accepted.
47
+ :type fasta_files_list: Union[List[str], str]
48
+ :param adding_reverse_complement: If True, adds the reverse complement of each sequence. Defaults to True.
49
+ :type adding_reverse_complement: bool
50
+ :param IsAddHeader: If True, includes the FASTA ID and description in the output. Defaults to False.
51
+ :type IsAddHeader: bool
52
+ :param AsDataFrame: If True, returns the sequences as a pandas DataFrame. Defaults to False.
53
+ :type AsDataFrame: bool
54
+ :param to_uppercase: If True, converts sequences to uppercase. Defaults to False.
55
+ :type to_uppercase: bool
56
+ :param is_add_sequence_id: If True, adds a unique integer sequence ID to each sequence. Defaults to False.
57
+ :type is_add_sequence_id: bool
58
+ :return: The loaded sequences. Each sequence is represented as a string if IsAddHeader is False, or as a list
59
+ [sequence_id, fasta_id, description, source_file, sequence, orientation] if IsAddHeader is True and is_add_sequence_id is True.
60
+ If AsDataFrame is True, the sequences are returned as a DataFrame.
61
+ :rtype: Union[List[Union[str, List[str]]], pd.DataFrame]
62
+
63
+ Example:
64
+ >>> fasta_files = ['path/to/file1.fasta', 'path/to/file2.fasta.gz']
65
+ >>> load_contigs(fasta_files, adding_reverse_complement=False, IsAddHeader=True, AsDataFrame=True, to_uppercase=True, is_add_sequence_id=True)
66
+ # Returns a DataFrame with the sequences from the specified FASTA files, all in uppercase, with unique sequence IDs.
67
+ """
68
+
69
+ logging.info('Loading sequence data into memory!')
70
+ if isinstance(fasta_files_list, str):
71
+ logging.info('Since the fasta_files_list is a string, not a list, we convert it to a list.')
72
+ fasta_files_list = [fasta_files_list]
73
+
74
+ sequences = []
75
+ sequence_id = 0
76
+ df_cols = ['sequence_id', 'fasta_id', 'description', 'source_file', 'sequence', 'orientation'] if (IsAddHeader and is_add_sequence_id) else ['fasta_id', 'description', 'source_file', 'sequence', 'orientation'] if IsAddHeader else ['sequence']
77
+ for act_assembly in fasta_files_list:
78
+ # Determine the file encoding based on the file extension
79
+ encoding = guess_type(act_assembly)[1]
80
+ _open = partial(gzip.open, mode='rt') if encoding == 'gzip' else open
81
+ with _open(act_assembly) as f_assembly:
82
+ # Parse the fasta file
83
+ contigs = list(SeqIO.parse(f_assembly, "fasta"))
84
+ for contig in contigs:
85
+ act_seq = str(contig.seq)[:] if not to_uppercase else str(contig.seq).upper()[:]
86
+ act_header = str(contig.id)
87
+ act_description = str(contig.description)
88
+ if adding_reverse_complement:
89
+ # Compute the reverse complement of the sequence
90
+ act_reverse_complement = str(contig.seq.reverse_complement()) if not to_uppercase else str(contig.seq.reverse_complement()).upper()
91
+
92
+ if IsAddHeader:
93
+ # Include sequence ID (if applicable), fasta ID, description, source file, sequence, and orientation in the output
94
+ entry = [sequence_id] if is_add_sequence_id else []
95
+ entry.extend([act_header, act_description, act_assembly, act_seq, 'forward'])
96
+ sequences.append(entry)
97
+ if adding_reverse_complement:
98
+ entry = [sequence_id + 1] if is_add_sequence_id else []
99
+ entry.extend([act_header, act_description, act_assembly, act_reverse_complement, 'reverse'])
100
+ sequences.append(entry)
101
+ if is_add_sequence_id:
102
+ sequence_id += 2
103
+ else:
104
+ sequence_id+=1
105
+ else:
106
+ # Only include the sequence in the output
107
+ sequences.append(act_seq)
108
+ if adding_reverse_complement:
109
+ sequences.append(act_reverse_complement)
110
+
111
+ if AsDataFrame:
112
+ # Convert the sequences to a DataFrame
113
+ sequences = pd.DataFrame(sequences, columns=df_cols)
114
+ return sequences
115
+
116
+
117
+ def segment_sequence_contiguous(
118
+ sequence: str,
119
+ params: Dict[str, Any],
120
+ sequence_id: Optional[Any] = np.nan
121
+ ) -> List[Dict[str, Any]]:
122
+ """
123
+ Creates end-to-end, disjoint segments of a sequence without overlaps.
124
+
125
+ Segments smaller than the predefined minimum length will be discarded.
126
+ This function returns a list of segments along with their positions in the original sequence.
127
+
128
+ :param sequence: The input nucleotide sequence to be segmented.
129
+ :type sequence: str
130
+ :param params: Dictionary containing the segmentation parameters. Must include 'min_length' and 'max_length' keys
131
+ specifying the minimum and maximum lengths of the segments, respectively. Can contain other parameters.
132
+ :type params: Dict[str, Any]
133
+ :param sequence_id: An identifier for the sequence, optional. Defaults to NaN.
134
+ :type sequence_id: Optional[Any]
135
+ :return: A list of dictionaries, each representing a segment. Each dictionary contains the segment's sequence,
136
+ start position, end position, and sequence ID.
137
+ :rtype: List[Dict[str, Any]]
138
+
139
+ Example:
140
+ >>> params = {'min_length': 0, 'max_length': 100}
141
+ >>> segment_sequence_contiguous('ATCGATCGA', params)
142
+ [{'segment': 'ATCGATCGA', 'segment_start': 0, 'segment_end': 9, 'sequence_id': np.nan}]
143
+ """
144
+
145
+ # Extract segmentation parameters
146
+ min_segment_len = params['min_length']
147
+ max_segment_len = params['max_length']
148
+
149
+ # Ensure the sequence is treated as a string
150
+ if isinstance(sequence, str):
151
+ act_seq = sequence
152
+ L = len(sequence)
153
+
154
+ segments = []
155
+ for i in range(0, L, max_segment_len):
156
+ act_start_pos = i
157
+ act_end_pos = min(i + max_segment_len, L)
158
+ act_segment = sequence[act_start_pos:act_end_pos]
159
+
160
+
161
+
162
+ # Add segment to the list if it's longer than the minimum length
163
+ if len(act_segment) >= min_segment_len:
164
+ new_record = {
165
+ 'segment': act_segment,
166
+ 'segment_start': act_start_pos,
167
+ 'segment_end': act_end_pos,
168
+ 'sequence_id': sequence_id
169
+ }
170
+ segments.append(new_record)
171
+
172
+ return segments
173
+
174
+
175
+
176
+ def segment_sequences_random(
177
+ sequences: Union[pd.DataFrame, List[str]],
178
+ params: Dict[str, Union[int, float, str, Dict, List, Tuple]]
179
+ ) -> List[Dict[str, Union[int, str]]]:
180
+ """
181
+ Randomly segments the input sequences.
182
+
183
+ This function accepts either a list of sequences or a DataFrame containing sequences.
184
+ If a DataFrame is provided, it's assumed to have preprocessed sequences with "sequence" and "sequence_id" columns,
185
+ where "sequence_id" is a valid primary key. The function returns a list of dictionaries,
186
+ each containing details of a segment including its sequence, start position, end position,
187
+ associated sequence ID, and a segment ID (not generated in this function).
188
+
189
+ :param sequences: A DataFrame containing sequences with "sequence" and "sequence_id" columns or a list of sequences.
190
+ :type sequences: Union[pd.DataFrame, List[str]]
191
+ :param params: Dictionary containing segmentation parameters such as 'coverage', 'min_length', and 'max_length'.
192
+ :type params: Dict[str, Union[int, float, str, Dict, List, Tuple]]
193
+ :return: A list of dictionaries with each containing details of a segment.
194
+ :rtype: List[Dict[str, Union[int, str]]]
195
+
196
+ Notes:
197
+ - The actual number of segments may differ from the expected number due to random sampling and sequences
198
+ being shorter than the specified segment size.
199
+ - Segment IDs are not generated by this function.
200
+ """
201
+
202
+ # Calculate sequence lengths and cumulative sum of lengths
203
+ sequences['seq_lengths'] = sequences.apply(lambda x: len(x['sequence']), axis=1)
204
+ sequences['length_cum_sum'] = sequences['seq_lengths'].cumsum()
205
+ Lseqs = sum(sequences['seq_lengths'])
206
+
207
+ # Calculate the number of segments to sample based on expected coverage.
208
+ # Note: The actual number might be biased if many sequences are "short" compared to the segment sizes.
209
+ N_segments = int(np.ceil(params['coverage'] * Lseqs / params['max_length']))
210
+ logging.info(f'Sampling {N_segments} segments from {len(sequences)} sequences.')
211
+
212
+ # Generate random starting coordinates for segments
213
+ start_coords = list(np.sort(np.int64(np.random.uniform(0, sequences['length_cum_sum'].max(), N_segments))))
214
+ segmentdb = []
215
+
216
+ for sid, act_sampling_coord in enumerate(start_coords):
217
+
218
+ diff = act_sampling_coord - sequences['length_cum_sum']
219
+
220
+ # Find the sequence in which the current segment starts
221
+ for i in range(len(sequences['length_cum_sum'])):
222
+ if diff[i] < 0:
223
+ break
224
+
225
+ act_sequence_id = sequences['sequence_id'].iloc[i]
226
+ rel_coord = act_sampling_coord - sequences['length_cum_sum'].iloc[i] + sequences['seq_lengths'].iloc[i]
227
+
228
+ segment_end = min(rel_coord + params['max_length'], sequences['seq_lengths'].iloc[i])
229
+
230
+ # Skip the segment if it's shorter than the minimum segment length
231
+ if segment_end - rel_coord < params['min_length']:
232
+ pred_segment = sequences['sequence'].iloc[i][rel_coord:segment_end]
233
+ minimum_len = params['min_length']
234
+ logging.info(f'Segment too short, skipping! Sampled segment: {pred_segment}, segment end coordinate: {segment_end}, relative coordinate: {rel_coord}, minimum length: {minimum_len}')
235
+ continue
236
+
237
+ new_segment = sequences['sequence'].iloc[i][rel_coord:segment_end]
238
+ new_record = {
239
+ 'sequence_id': act_sequence_id,
240
+ 'segment_start': rel_coord,
241
+ 'segment_end': segment_end,
242
+ 'segment': new_segment,
243
+ 'segment_id': str(sid)
244
+ }
245
+
246
+ segmentdb.append(new_record)
247
+
248
+ return segmentdb
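+
+ # Illustrative usage sketch (added for documentation; output is random by design).
+ # The expected segment count is roughly coverage * total_length / max_length:
+ # >>> df = pd.DataFrame({'sequence_id': [0, 1], 'sequence': ['ATCG' * 50, 'GGTA' * 50]})
+ # >>> params = {'coverage': 2.0, 'min_length': 10, 'max_length': 50}
+ # >>> segs = segment_sequences_random(df, params)
+ # >>> len(segs) <= 16 # ceil(2.0 * 400 / 50), minus any too-short samples
+ # True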
249
+
250
+ def segment_sequences(
251
+ sequences: Union[List[str], pd.DataFrame],
252
+ params: Dict[str, Union[int, float, str, ]],
253
+ AsDataFrame: bool = False
254
+ ) -> Union[List[str], pd.DataFrame]:
255
+ """
256
+ Segments sequences based on the provided parameters.
257
+
258
+ This function assumes that the sequence is quality controlled and preprocessed, i.e., it is a valid nucleotide sequence.
259
+ If sequences are provided as a DataFrame, then it is assumed that there is a "sequence_id" and
260
+ a "sequence" attribute. The "sequence_id" should be a valid primary key.
261
+ If the output is requested as a DataFrame, then the IDs are added as well.
262
+
263
+ :param sequences: A list of sequences or a DataFrame containing sequences.
264
+ If a DataFrame, it must have "sequence_id" and "sequence" attributes.
265
+ :type sequences: Union[List[str], pd.DataFrame]
266
+ :param params: Dictionary containing the segmentation parameters.
267
+ - 'type' (str): The type of segmentation ('contiguous' or 'random').
268
+ - 'min_length' (int): Minimum length of a segment.
269
+ - 'max_length' (int): Maximum length of a segment.
270
+ - 'coverage' (float): Coverage percentage for random segmentation.
271
+ :type params: Dict[str, Union[int, float, str, Dict[str, int], List[int], Tuple[int, int]]]
272
+ :param AsDataFrame: If True, the output will be a DataFrame. If False, it will be a list. Defaults to False.
273
+ :type AsDataFrame: bool
274
+ :return: List of segmented sequences or a DataFrame with segmented sequences and their corresponding information based on the `AsDataFrame` parameter.
275
+ :rtype: Union[List[str], pd.DataFrame]
276
+ :raises ValueError: If the provided sequences DataFrame does not have the required attributes.
277
+ :raises ValueError: If the "sequence_id" column is not a valid primary key.
278
+
279
+ Examples:
280
+ >>> segment_sequences(['AATCAATTTTATTT', 'AGCCGATTCAATTGCATTATTT'], {'type': 'contiguous', 'min_length': 1, 'max_length': 1000, 'coverage': 1.0})
281
+ """
282
+
283
+ segmentation_type = params['type']
284
+
285
+ # Check that the expected primary key and sequence attributes are present
286
+ expected_attributes = ['sequence_id', 'sequence']
287
+ return_cols = ['segment_id', 'sequence_id', 'segment_start', 'segment_end', 'segment']
288
+
289
+ if isinstance(sequences, list):
290
+ logging.info('Sequences provided as a list; ignoring ids and tracking information.')
291
+ IsSequenceId = None
292
+ IsSeqList = True
293
+ elif isinstance(sequences, pd.DataFrame):
295
+ logging.info('Checking input DataFrame!')
296
+ check_expected_columns(sequences, expected_attributes)
297
+ logging.info('Checking input sequence_id is valid primary key in the DataFrame')
298
+ is_valid_primary_key(sequences, 'sequence_id')
299
+ IsSequenceId = True
300
+ IsSeqList = False
301
+
302
+ segments = []
303
+ if segmentation_type == 'contiguous':
304
+ if IsSeqList:
305
+ if IsSequenceId:
306
+ for act_seq_id, seq in enumerate(sequences):
307
+ act_segments = segment_sequence_contiguous(seq, params, act_seq_id)
308
+ segments.extend(act_segments)
309
+ else:
310
+ for seq in sequences:
311
+ act_segments = segment_sequence_contiguous(seq, params)
312
+ segments.extend(act_segments)
313
+ else:
314
+ for _, rec in sequences.iterrows():
315
+ act_seq = rec['sequence']
316
+ act_seq_id = rec['sequence_id']
317
+ act_segments = segment_sequence_contiguous(act_seq, params, act_seq_id)
318
+ segments.extend(act_segments)
319
+
320
+ elif segmentation_type == 'random':
321
+ if IsSeqList:
322
+ sequence_df = pd.DataFrame(sequences,
323
+ columns=['sequence'])
324
+ sequence_df['sequence_id'] = list(range(len(sequences)))
325
+ segments = segment_sequences_random(sequence_df, params)
326
+
327
+ else:
328
+ segments = segment_sequences_random(sequences, params)
329
+ if AsDataFrame:
330
+ #logging.info('Creating a DataFrame from the segments. ')
331
+ segment_db = pd.DataFrame(segments)
332
+ segment_ids = list(range(len(segment_db)))
333
+ segment_db['segment_id'] = segment_ids
334
+ segment_db = segment_db[return_cols]
335
+
336
+ else:
337
+ segment_db = [seg['segment'] for seg in segments]
338
+ return segment_db
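+
+ # Illustrative usage sketch (added for documentation).
+ # Contiguous segmentation of a plain list returns a plain list of segment strings:
+ # >>> params = {'type': 'contiguous', 'min_length': 1, 'max_length': 8, 'coverage': 1.0}
+ # >>> segment_sequences(['AATCAATTTTATTT'], params)
+ # ['AATCAATT', 'TTATTT']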
339
+
340
+ def lca_kmer_tokenize_segment(segment: str, offset: int, params: Dict[str, Dict[str, int] | int | float]) -> List[str]:
341
+ # Calculate the k-mer tokenization of a segment for a single offset value.
342
+ shift = params['shift']
343
+ max_segment_length = params['max_segment_length']
344
+ max_unknown_token_proportion = params['max_unknown_token_proportion']
345
+ kmer = params['kmer']
346
+ token_limit = params['token_limit']
347
+ vocabmap = params['vocabmap']
348
+ add_special_token = params['add_special_token']
349
+ if len(segment) > max_segment_length:
350
+ raise ValueError(f'The segment ({len(segment)}) is longer than the maximum allowed segment length ({max_segment_length}).')
351
+
352
+ kmers = [segment[i:i + kmer] for i in range(offset, len(segment) - kmer + 1, shift)]
353
+
354
+ return kmers
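+
+ # Illustrative usage sketch (added for documentation): the offset selects which frame
+ # of the shifted k-mer decomposition is returned.
+ # >>> p = {'shift': 2, 'max_segment_length': 512, 'max_unknown_token_proportion': 0.2,
+ # ... 'kmer': 6, 'token_limit': 512, 'vocabmap': {}, 'add_special_token': True}
+ # >>> lca_kmer_tokenize_segment('ATCGATCGAT', 0, p)
+ # ['ATCGAT', 'CGATCG', 'ATCGAT']
+ # >>> lca_kmer_tokenize_segment('ATCGATCGAT', 1, p)
+ # ['TCGATC', 'GATCGA']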
355
+
356
+
357
+
358
+
359
+
360
+ def lca_tokenize_segment(segment: str, params: Dict[str, Dict[str, int] | int | float]) -> Tuple[List[List[int]], List[List[str]]]:
361
+ """
362
+ Tokenizes a single segment using Local Context Aware (LCA) tokenization.
363
+ The segment is first split into k-mers with specified shifts and then tokenized into token vectors.
364
+
365
+ :param segment: The input nucleotide sequence segment to be tokenized.
366
+ :type segment: str
367
+ :param params: Dictionary containing the tokenization parameters.
368
+ - 'shift' (int): The k-mer shift parameter.
369
+ - 'max_segment_length' (int): Maximum allowable segment length.
370
+ - 'max_unknown_token_proportion' (float): Maximum allowable proportion of unknown tokens in a segment.
371
+ - 'kmer' (int): Size of the k-mer.
372
+ - 'token_limit' (int): Maximum number of tokens allowed in the tokenized output.
373
+ - 'vocabmap' (dict[str, int]): Dictionary mapping k-mers to their respective token values.
374
+ :type params: dict
375
+ :returns: A tuple containing:
376
+ - list[list[int]]: List of tokenized segments (each segment as a list of integers).
377
+ - list[list[str]]: List of k-merized segments with different shifts (each segment as a list of strings).
378
+ :rtype: Tuple[List[List[int]], List[List[str]]]
379
+ :raises ValueError: If the segment length exceeds the `max_segment_length`.
380
+
381
+ Examples:
382
+ >>> vocabmap_example = {"[CLS]": 2, "[SEP]": 3, "[UNK]": 0, "TCTTT": 4, "CTTTG": 5, "TTTGC": 6, "TTGCT": 7}
383
+ >>> segment_example = 'TCTTTGCT'
384
+ >>> params_example = {'shift': 1, 'max_segment_length': 512, 'max_unknown_token_proportion': 0.2, 'kmer': 5, 'token_limit': 10, 'vocabmap': vocabmap_example}
385
+ >>> lca_tokenize_segment(segment_example, params_example)
386
+ ([[2, 4, 5, 6, 7, 3]], [['TCTTT', 'CTTTG', 'TTTGC', 'TTGCT']])
387
+ """
388
+
389
+
390
+ #logging.info('Tokenizing a segment')
391
+ shift = params['shift']
392
+ max_segment_length = params['max_segment_length']
393
+ max_unknown_token_proportion = params['max_unknown_token_proportion']
394
+ kmer = params['kmer']
395
+ token_limit = params['token_limit']
396
+ vocabmap = params['vocabmap']
397
+ add_special_token = params['add_special_token']
398
+ if len(segment) > max_segment_length:
399
+ raise ValueError(f'The segment ({len(segment)}) is longer than the maximum allowed segment length ({max_segment_length}).')
400
+
401
+ kmers_offset = []
402
+ # For every possible offset we should get a k-mer vector.
404
+ # If the segment is too short or empty the result may be degenerate, so ensure segments are validated upstream.
404
+ for offset in range(shift):
405
+ kmers = [segment[i:i + kmer] for i in range(offset, len(segment) - kmer + 1, shift)]
406
+ kmers_offset.append(kmers)
407
+ # Mapping the k-mers into numbers
408
+ tokenized_segments = tokenize_kmerized_segment_list(kmers_offset, vocabmap, token_limit, max_unknown_token_proportion, add_special_token)
409
+ return tokenized_segments, kmers_offset
410
+
411
+
412
+
413
+ def tokenize_kmerized_segment_list(kmerized_segments: List[List[str]],
414
+ vocabmap: Dict[str, int],
415
+ token_limit: int,
416
+ max_unknown_token_proportion: float,
417
+ add_special_tokens: bool = True) -> List[List[int]]:
418
+ """Tokenizes or vectorizes a list of k-merized segments into a list of token vectors. If the expected number of
419
+ tokens in a segment exceeds the maximum allowed tokens (`token_limit`), the function raises an error. For segments
420
+ where unknown k-mers exceed the proportion set by `max_unknown_token_proportion`, the output is a special token
421
+ sequence indicating an empty sentence.
422
+
423
+ :param kmerized_segments: List containing k-merized segments.
424
+ :type kmerized_segments: List[List[str]]
425
+ :param vocabmap: Dictionary that maps k-mers to their respective token values.
426
+ :type vocabmap: Dict[str, int]
427
+ :param token_limit: Maximum number of tokens allowed in the tokenized output.
428
+ :type token_limit: int
429
+ :param max_unknown_token_proportion: Maximum allowable proportion of unknown tokens in a segment.
430
+ :type max_unknown_token_proportion: float
431
+ :param add_special_tokens: Whether to add special tokens (`[CLS]` and `[SEP]`) to the tokenized segments.
432
+ :type add_special_tokens: bool, optional (default=True)
433
+ :returns: List containing tokenized segments.
434
+ :rtype: List[List[int]]
435
+ :raises ValueError: If the expected number of tokens in a segment exceeds `token_limit`.
436
+
437
+ Examples
438
+ --------
439
+
440
+ >>> vocabmap_example = {"[CLS]": 2, "[SEP]": 3, "[UNK]": 0, "TCTTTG": 4, "CTTTGC": 5, "TTTGCT": 6, "TTGCTA": 7}
441
+ >>> kmerized_segment_example = [['TCTTTG', 'CTTTGC', 'TTTGCT', 'TTGCTA']]
442
+ >>> tokenize_kmerized_segment_list(kmerized_segment_example, vocabmap_example, 10, 0.2)
443
+ [[2, 4, 5, 6, 7, 3]]
444
+ """
445
+
446
+ tokenized_segments = []
447
+ if add_special_tokens:
448
+ empty_sentence = [2, 3]
449
+ else:
450
+ empty_sentence = []
451
+
452
+ for act_kmer_list in kmerized_segments:
453
+ if add_special_tokens:
454
+ tokenized_kmerized_segment = [vocabmap['[CLS]']]
455
+ else:
456
+ tokenized_kmerized_segment = []
457
+ unkcount = 0
458
+ L_kmerized_segment = len(act_kmer_list)
459
+ unkw_tsh_count = int(L_kmerized_segment * max_unknown_token_proportion)
460
+ if len(act_kmer_list) + 2 > token_limit:
461
+ raise ValueError(f'The expected number of tokens in the segment ({L_kmerized_segment + 2}) is larger than the maximum allowed number of tokens ({token_limit}).')
462
+ if L_kmerized_segment == 0:
463
+ logging.info('Encountered an empty sentence')
464
+ tokenized_kmerized_segment = empty_sentence
465
+ tokenized_segments.append(empty_sentence)
466
+ continue
467
+ for kmer in act_kmer_list:
468
+ try:
469
+ tokenized_kmerized_segment.append(vocabmap[kmer.upper()])
470
+ except KeyError:
471
+ tokenized_kmerized_segment.append(vocabmap['[UNK]'])
472
+ unkcount += 1
473
+ if unkcount > unkw_tsh_count:
474
+ tokenized_segments.append(empty_sentence)
475
+ continue
476
+ if add_special_tokens:
477
+ tokenized_kmerized_segment.append(vocabmap['[SEP]'])
478
+ tokenized_segments.append(tokenized_kmerized_segment)
479
+
480
+ return tokenized_segments
481
+
482
+ def process_batch_tokenize_segments_with_ids(
483
+ segments: List[str],
484
+ segment_ids: List[Any],
485
+ tokenization_params: Dict[str, Any],
486
+ np_token_type: type = np.uint16
487
+ ) -> Dict[Any, List[np.ndarray]]:
488
+ """
489
+ Tokenizes a batch of segments and associates them with their provided IDs.
490
+
491
+ This function generates vector representations for a collection of segments, assuming the segments
492
+ have undergone quality control. The result is a dictionary where the keys are segment IDs, and the values
493
+ are lists of potential vector representations for the segment, with each list element corresponding to
494
+ a specific shift.
495
+
496
+ The vector representations are converted to numpy arrays. The output is not a 2D rectangular array but
497
+ a dictionary mapping each segment ID to its tokenized representations.
498
+
499
+ :param segments: A list of preprocessed and validated segments.
500
+ :type segments: List[str]
501
+ :param segment_ids: A list of segment IDs corresponding to each segment in `segments`.
502
+ :type segment_ids: List[Any]
503
+ :param tokenization_params: A dictionary containing tokenization parameters.
504
+ :type tokenization_params: Dict[str, Any]
505
+ :param np_token_type: Numpy data type for the tokenized segments. Defaults to np.uint16.
506
+ :type np_token_type: type, optional
507
+ :return: A dictionary with segment IDs as keys and lists of numpy arrays representing tokenized segments as values.
508
+ :rtype: Dict[Any, List[np.ndarray]]
509
+
510
+ Example:
511
+ >>> segments = ['ACTG', 'TGCA']
512
+ >>> segment_ids = [1, 2]
513
+ >>> tokenization_params = {'max_segment_length': 50, ...}
514
+ >>> tokenized_segments = process_batch_tokenize_segments_with_ids(
515
+ segments, segment_ids, tokenization_params
516
+ )
517
+ """
518
+ tokenized_segments_with_ids = {}
519
+ for i, segment in enumerate(segments):
520
+ act_id = segment_ids[i]
521
+ tokenized_segments_with_ids[act_id] = []
522
+ max_segment_length = tokenization_params['max_segment_length']
523
+ if len(segment) > max_segment_length:
524
+ raise ValueError(f'The segment is longer ({len(segment)}) than the maximum allowed segment length ({max_segment_length}).')
525
+
526
+ tokenized_segment, _ = lca_tokenize_segment(segment, tokenization_params)
527
+ tokenized_segment = [np.array(act_segment, dtype=np_token_type) for act_segment in tokenized_segment]
528
+ tokenized_segments_with_ids[act_id] = tokenized_segment
529
+ return tokenized_segments_with_ids
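+
+ # Illustrative usage sketch (added for documentation), reusing the toy vocabmap from
+ # the lca_tokenize_segment docstring above:
+ # >>> vm = {"[CLS]": 2, "[SEP]": 3, "[UNK]": 0, "TCTTT": 4, "CTTTG": 5, "TTTGC": 6, "TTGCT": 7}
+ # >>> tp = {'shift': 1, 'max_segment_length': 512, 'max_unknown_token_proportion': 0.2,
+ # ... 'kmer': 5, 'token_limit': 10, 'vocabmap': vm, 'add_special_token': True}
+ # >>> process_batch_tokenize_segments_with_ids(['TCTTTGCT'], ['seg_0'], tp)
+ # {'seg_0': [array([2, 4, 5, 6, 7, 3], dtype=uint16)]}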
530
+
531
+ def batch_tokenize_segments_with_ids(
532
+ segment_data: Union[Tuple[List[str], List[Any]], pd.DataFrame],
533
+ tokenization_params: Dict[str, Any],
534
+ num_cores: int = 1,
535
+ batch_size: int = 10000,
536
+ np_token_type: type = np.uint16
537
+ ) -> Dict[Any, List[np.ndarray]]:
538
+ """
539
+ Parallel tokenization of segments with associated IDs.
540
+
541
+ This function splits the input data into batches and uses multiprocessing to tokenize
542
+ the segments in parallel. It supports both list/tuple inputs and pandas DataFrames.
543
+
544
+ :param segment_data: Either a tuple/list containing two elements (segments, segment_ids),
545
+ or a pandas DataFrame with 'segment' and 'segment_id' columns.
546
+ :type segment_data: Union[Tuple[List[str], List[Any]], pd.DataFrame]
547
+ :param tokenization_params: Dictionary containing tokenization parameters.
548
+ :type tokenization_params: Dict[str, Any]
549
+ :param num_cores: Number of CPU cores to use for parallel processing. Defaults to 1.
550
+ :type num_cores: int, optional
551
+ :param batch_size: Number of segments to process in each batch. Defaults to 10,000.
552
+ :type batch_size: int, optional
553
+ :param np_token_type: Numpy data type for the tokenized segments. Defaults to np.uint16.
554
+ :type np_token_type: type, optional
555
+ :return: A dictionary where keys are segment IDs and values are lists of numpy arrays representing tokenized segments.
556
+ :rtype: Dict[Any, List[np.ndarray]]
557
+ :raises ValueError: If the input data is neither a tuple/list nor a pandas DataFrame.
558
+
559
+ Example:
560
+ >>> segments = ['ACTG', 'TGCA']
561
+ >>> segment_ids = [1, 2]
562
+ >>> tokenization_params = {'max_segment_length': 50, ...}
563
+ >>> tokenized_data = batch_tokenize_segments_with_ids(
564
+ (segments, segment_ids),
565
+ tokenization_params,
566
+ num_cores=4,
567
+ batch_size=1000
568
+ )
569
+ """
570
+ if isinstance(segment_data, tuple) or isinstance(segment_data, list):
571
+ segments = segment_data[0]
572
+ segment_ids = segment_data[1]
573
+ elif isinstance(segment_data, pd.DataFrame):
574
+ segments = list(segment_data['segment'])
575
+ segment_ids = list(segment_data['segment_id'])
576
+ else:
577
+ raise ValueError(f'The input should be either pandas DataFrame or a tuple instead of {type(segment_data)}')
578
+
579
+ Ndata = len(segments)
580
+ batch_intervals = [(i, min(i + batch_size, Ndata)) for i in range(0, Ndata, batch_size)]
581
+ params = [
582
+ (segments[interval[0]:interval[1]],
583
+ segment_ids[interval[0]:interval[1]],
584
+ tokenization_params,
585
+ np_token_type)
586
+ for interval in batch_intervals
587
+ ]
588
+ with Pool(processes=num_cores) as pool:
589
+ result_list = pool.starmap(process_batch_tokenize_segments_with_ids, params)
590
+
591
+ tokenized_sets = {}
592
+ for d in result_list:
593
+ tokenized_sets.update(d)
594
+
595
+ return tokenized_sets
596
+
597
+
598
+ def get_rectangular_array_from_tokenized_dataset(tokenized_segments_data: Dict[int, List[np.ndarray]], shift: int, max_token_count: int, truncate_zeros: bool = True, randomize: bool = True, numpy_dtype: Type = np.uint16) -> Tuple[np.ndarray, pd.DataFrame]:
599
+ """Create a rectangular numpy array that can be used as input to a Language Model (LM) from tokenized segment data.
600
+
601
+ :param tokenized_segments_data: A dictionary where keys are segment ids and values are lists of possible LCA tokenized vectors.
602
+ :type tokenized_segments_data: Dict[int, List[np.ndarray]]
603
+
604
+ :param shift: Number of LCA offsets.
605
+ :type shift: int
606
+
607
+ :param max_token_count: Maximum allowed token count in the output numpy array.
608
+ :type max_token_count: int
609
+
610
+ :param truncate_zeros: If True, truncate columns from the end of the numpy array that only contain zeros. (default=True)
611
+ :type truncate_zeros: bool, optional
612
+
613
+ :param randomize: If True, randomize the order of the rows in the output numpy array. (default=True)
614
+ :type randomize: bool, optional
615
+
616
+ :param numpy_dtype: Data type of the values in the output numpy array. (default=np.uint16)
617
+ :type numpy_dtype: Type, optional
618
+
619
+ :returns: A rectangular numpy array suitable for input to an LM.
620
+ :rtype: np.ndarray
621
+
622
+ :returns: A dataframe that describes which row in the numpy array corresponds to which segment and its LCA offset.
623
+ Columns are: ['torch_id', 'segment_id', 'offset']
624
+ :rtype: pd.DataFrame
625
+
626
+ """
627
+
628
+
629
+ expected_length = len(tokenized_segments_data) * shift
630
+ X = np.full((expected_length, max_token_count), 0, dtype=numpy_dtype)
631
+ torch_db = []
632
+ torch_id = 0
633
+ for segment_id, tokenized_vectors in tokenized_segments_data.items():
634
+ for offset in range(shift):
635
+ segment_vector = tokenized_vectors[offset]
636
+ X[torch_id, 0:segment_vector.shape[0]] = segment_vector
637
+ torch_db.append([torch_id, segment_id, offset])
638
+ torch_id += 1
639
+ torch_tokenized_segment_db = pd.DataFrame(torch_db,
640
+ columns = ['torch_id', 'segment_id', 'offset'])
641
+
642
+ if randomize:
643
+ logging.info('Doing randomization!')
644
+ perm = np.random.permutation(expected_length)
645
+ X = X[perm, :]
646
+ torch_tokenized_segment_db.rename({'torch_id': 'original_torch_id'}, axis=1, inplace=True)
647
+ torch_tokenized_segment_db = torch_tokenized_segment_db.iloc[perm, :].reset_index(drop=True).reset_index().rename({'index': 'torch_id'}, axis=1)
648
+
649
+ if truncate_zeros:
650
+ logging.info('Truncating trailing all-zero columns')
651
+ X = truncate_zero_columns(X)
652
+ return X, torch_tokenized_segment_db
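+
+ # Illustrative usage sketch (added for documentation; assumes truncate_zero_columns,
+ # defined elsewhere in this module, drops trailing all-zero columns):
+ # >>> data = {0: [np.array([2, 4, 3], dtype=np.uint16)]}
+ # >>> X, db = get_rectangular_array_from_tokenized_dataset(data, shift=1,
+ # ... max_token_count=5, randomize=False)
+ # >>> X
+ # array([[2, 4, 3]], dtype=uint16)
+ # >>> list(db.columns)
+ # ['torch_id', 'segment_id', 'offset']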
653
+
654
+
655
+ def pretty_print_overlapping_sequence(segment, segment_kmers, tokenizer_params):
656
+ """
657
+ Format the sequence for pretty printing with overlapping k-mers.
658
+
659
+ :param segment: DNA sequence.
660
+ :type segment: str
661
+
662
+ :param segment_kmers: List of k-mers in the segment.
663
+ :type segment_kmers: list
664
+
665
+ :param tokenizer_params: Dictionary containing tokenization parameters.
666
+ :type tokenizer_params: dict
667
+
668
+ :return: List of formatted strings representing the sequence with overlapping k-mers.
669
+ :rtype: list
670
+ """
671
+
672
+ shift = tokenizer_params['shift']
673
+ k = tokenizer_params['kmer']
674
+ sep_c = 2
675
+ lines = []
676
+ base_offset = len(str(int((k + 3) / shift))) + 3
677
+ first_line = ' ' * base_offset + segment
678
+ lines.append(first_line)
679
+ nr_lines = int(np.ceil((k + sep_c) / shift))
680
+ logging.info('Number of lines to cover the sequence: {0}'.format(nr_lines))
681
+
682
+ for line_id in range(nr_lines):
683
+
684
+ line_mers = [k_mer for j, k_mer in enumerate(segment_kmers) if j % nr_lines == line_id]
685
+ act_line = str(line_id) + '. ' + ' ' * (line_id * shift) + (' ' * sep_c).join(line_mers)
686
+ lines.append(act_line)
687
+ lines = '\n'.join(lines)
688
+ return lines
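+
+ # Illustrative usage sketch (added for documentation): each numbered row holds every
+ # nr_lines-th k-mer, indented by its offset so the k-mers line up under the sequence.
+ # >>> seg = 'TCTTTGCT'
+ # >>> kmers = ['TCTTT', 'CTTTG', 'TTTGC', 'TTGCT']
+ # >>> print(pretty_print_overlapping_sequence(seg, kmers, {'shift': 1, 'kmer': 5}))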
689
+
690
+
691
+ def generate_kmers(abc: Set[str], k: int) -> List[str]:
692
+ """
693
+ Generates all possible k-mers from a given alphabet.
694
+
695
+ :param abc: The alphabet.
696
+ :type abc: Set[str]
697
+ :param k: Length of the k-mers.
698
+ :type k: int
699
+ :return: List of all possible k-mers.
700
+ :rtype: List[str]
701
+ """
702
+ return [''.join(p) for p in product(abc, repeat=k)]
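+
+ # Illustrative usage sketch (added for documentation; ordering follows set iteration order):
+ # >>> sorted(generate_kmers({'A', 'C'}, 2))
+ # ['AA', 'AC', 'CA', 'CC']
+ # >>> len(generate_kmers({'A', 'C', 'G', 'T'}, 6))
+ # 4096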
703
+
704
+ def save_to_hdf(X: np.ndarray, hdf_file_path: str, database: pd.DataFrame = None, compression: bool = False, pd_chunksize: int = 10_000_000) -> None:
705
+ """Save a numpy array and an optional pandas DataFrame to an HDF5 file.
706
+
707
+ :param X: 2D numpy array to be saved.
708
+ :type X: np.ndarray
709
+ :param hdf_file_path: Path to the HDF5 file.
710
+ :type hdf_file_path: str
711
+ :param database: Pandas DataFrame to be saved. Defaults to None.
712
+ :type database: pd.DataFrame
713
+ :param compression: Whether to apply compression. Defaults to False.
714
+ :type compression: bool
715
+ :param pd_chunksize: Number of rows per chunk for saving the DataFrame. Defaults to 10,000,000.
716
+ :type pd_chunksize: int
717
+ :raises ValueError: If the provided numpy array is not 2D.
718
+ :raises OSError: If there's an error creating the directory structure or removing an existing HDF5 file.
719
+ Example:
720
+
721
+ >>> import numpy as np
722
+ >>> import pandas as pd
723
+ >>> array = np.random.random((100, 100))
724
+ >>> df = pd.DataFrame({'A': range(1, 101), 'B': range(101, 201)})
725
+ >>> save_to_hdf(array, "sample.hdf5", database=df, compression=True)
726
+ """
727
+
728
+ # Check if X is a 2D numpy array
729
+ if len(X.shape) != 2:
730
+ raise ValueError("The provided numpy array is not 2D.")
731
+
732
+ # If HDF5 file exists, attempt to delete it
733
+ if os.path.exists(hdf_file_path):
734
+ try:
735
+ os.remove(hdf_file_path)
736
+ logging.info(f"Existing HDF5 file {hdf_file_path} removed successfully.")
737
+ except Exception as e:
738
+ raise OSError(f"Error removing existing HDF5 file {hdf_file_path}. Error: {e}")
739
+
740
+ # Create directory structure for HDF5 file
741
+ create_directory_for_filepath(hdf_file_path)
742
+
743
+ # Save the numpy array to HDF5
744
+ with h5py.File(hdf_file_path, 'w') as hdf:
745
+ try:
746
+ grp = hdf.create_group("training_data")
747
+ except ValueError:
748
+ # The group already exists: remove it and recreate it so `grp` is always bound
+ del hdf['training_data']
+ grp = hdf.create_group("training_data")
749
+
750
+ if compression:
751
+ grp.create_dataset("X", data=X, compression="lzf", chunks=True)
752
+ else:
753
+ grp.create_dataset("X", data=X, chunks=True)
754
+
755
+ logging.info(f"Numpy array saved to {hdf_file_path} successfully.")
756
+
757
+ # Save the pandas DataFrame to HDF5, if provided
758
+ if database is not None:
759
+ logging.info("Adding database into the HDF5 file!")
760
+ num_chunks = int(np.ceil(len(database) / pd_chunksize))
761
+ logging.info(f'Number of chunks: {num_chunks}')
762
+ chunk_grouping = np.arange(len(database)) // pd_chunksize
763
+ chunkseqs = database.groupby(chunk_grouping)
764
+ for i, (_, chunk) in enumerate(chunkseqs):
765
+ logging.info(f'Writing database chunk {i} into {hdf_file_path}')
766
+ if compression:
767
+ chunk.to_hdf(hdf_file_path, f'database_{i}', format='table', data_columns=True, mode='a', complib='lzo')
768
+ else:
769
+ chunk.to_hdf(hdf_file_path, f'database_{i}', format='table', data_columns=True, mode='a')
770
+
771
+ logging.info('Database addition finished!')
772
+
773
+
774
+
775
+ def dataframe_to_seqrecords(
776
+ df: pd.DataFrame,
777
+ fastaidcol: str = 'test_fastaid',
778
+ sequencecol: str = 'sequence'
779
+ ) -> List[SeqRecord]:
780
+ """
781
+ Convert a DataFrame with sequence information into a list of SeqRecord objects.
782
+
783
+ :param df: DataFrame containing at least two columns: one for sequence IDs and one for sequences.
784
+ :type df: pd.DataFrame
785
+ :param fastaidcol: Name of the column in `df` that contains sequence IDs. Defaults to 'test_fastaid'.
786
+ :type fastaidcol: str, optional
787
+ :param sequencecol: Name of the column in `df` that contains nucleotide sequences. Defaults to 'sequence'.
788
+ :type sequencecol: str, optional
789
+ :return: A list of SeqRecord objects constructed from the DataFrame.
790
+ :rtype: List[SeqRecord]
791
+
792
+ Example:
793
+ >>> import pandas as pd
794
+ >>> data = {'test_fastaid': ['seq1', 'seq2'], 'sequence': ['ATCG', 'GGTA']}
795
+ >>> df = pd.DataFrame(data)
796
+ >>> seq_records = dataframe_to_seqrecords(df)
797
+ >>> seq_records[0].id
798
+ 'seq1'
799
+ """
800
+ seq_records = []
801
+ for _, row in df.iterrows():
802
+ seq = Seq(row[sequencecol])
803
+ record = SeqRecord(seq, id=str(row[fastaidcol]), description="")
804
+ seq_records.append(record)
805
+ return seq_records
806
+
807
+
808
+ def write_seqrecords_to_fasta(
809
+ seq_records: List[SeqRecord],
810
+ file_name: str
811
+ ) -> None:
812
+ """
813
+ Write a list of SeqRecord objects to a FASTA file.
814
+
815
+ :param seq_records: List of SeqRecord objects to be written to file.
816
+ :type seq_records: List[SeqRecord]
817
+ :param file_name: Name or path of the file to write the FASTA records.
818
+ :type file_name: str
819
+ :return: None
820
+ :rtype: None
821
+
822
+ Example:
823
+ >>> from Bio.Seq import Seq
824
+ >>> from Bio.SeqRecord import SeqRecord
825
+ >>> seq_records = [SeqRecord(Seq('ATCG'), id='seq1'), SeqRecord(Seq('GGTA'), id='seq2')]
826
+ >>> write_seqrecords_to_fasta(seq_records, 'output.fasta')
827
+ """
828
+ SeqIO.write(seq_records, file_name, "fasta")
829
+
830
+
831
+ def dump_records_to_files(
832
+ seq_records: List[SeqRecord],
833
+ folder_path: str
834
+ ) -> None:
835
+ """
836
+ Write each SeqRecord to a separate FASTA file in the specified folder.
837
+
838
+ :param seq_records: List of SeqRecord objects to be written individually.
839
+ :type seq_records: List[SeqRecord]
840
+ :param folder_path: Path to the folder where the files should be saved.
841
+ The folder will be created if it does not exist.
842
+ :type folder_path: str
843
+ :return: None
844
+ :rtype: None
845
+
846
+ Example:
847
+ >>> from Bio.Seq import Seq
848
+ >>> from Bio.SeqRecord import SeqRecord
849
+ >>> seq_records = [SeqRecord(Seq('ATCG'), id='seq1'), SeqRecord(Seq('GGTA'), id='seq2')]
850
+ >>> dump_records_to_files(seq_records, 'sequences_folder')
851
+ """
852
+ # Ensure the folder exists
853
+ os.makedirs(folder_path, exist_ok=True)
854
+
855
+ for record in seq_records:
856
+ file_path = os.path.join(folder_path, f"{record.id}.fasta")
857
+ SeqIO.write(record, file_path, "fasta")
858
+
859
+
860
+ def split_seqrecords_to_fasta_chunks(
861
+ seq_records: List[SeqRecord],
862
+ output_folder: str,
863
+ chunk_size_mb: int = 10
864
+ ) -> None:
865
+ """
866
+ Splits a list of SeqRecord objects into multiple FASTA files, each less than a specified size in MB.
867
+
868
+ :param seq_records: List of SeqRecord objects to be split into chunks.
869
+ :type seq_records: List[SeqRecord]
870
+ :param output_folder: The output folder where the FASTA files will be saved.
871
+ :type output_folder: str
872
+ :param chunk_size_mb: Maximum size of each FASTA file in megabytes. Defaults to 10 MB.
873
+ :type chunk_size_mb: int, optional
874
+ :return: None
875
+ :rtype: None
876
+
877
+ Example:
878
+ >>> seq_records = [...] # A list of SeqRecord objects
879
+ >>> split_seqrecords_to_fasta_chunks(seq_records, 'output_chunks', chunk_size_mb=5)
880
+
881
+ Notes:
882
+ - The last chunk may be smaller than the specified `chunk_size_mb`.
883
+ - The function approximates the size of each record for chunking.
884
+ """
885
+ # Ensure output folder exists
886
+ os.makedirs(output_folder, exist_ok=True)
887
+
888
+ current_chunk = []
889
+ current_chunk_size = 0 # in bytes
890
+ chunk_id = 1 # Identifier for chunks/files
891
+ for record in seq_records:
892
+ # Approximate size of the record in bytes
893
+ record_size = len(str(record.seq)) + len(record.id) + 2 # Adding buffer for '>' and '\n'
894
+
895
+ # Check if adding this record exceeds the chunk size
896
+ if current_chunk_size + record_size > chunk_size_mb * 1024 * 1024:
897
+ file_path = os.path.join(output_folder, f"chunk_{chunk_id}.fasta")
898
+ SeqIO.write(current_chunk, file_path, "fasta")
899
+ current_chunk = []
900
+ current_chunk_size = 0
901
+ chunk_id += 1
902
+
903
+ current_chunk.append(record)
904
+ current_chunk_size += record_size
905
+
906
+ # Write any remaining records to the last chunk
907
+ if current_chunk:
908
+ file_path = os.path.join(output_folder, f"chunk_{chunk_id}.fasta")
909
+ SeqIO.write(current_chunk, file_path, "fasta")
910
+
911
+
912
+ def filter_short_sequences(
913
+ seq_records: List[SeqRecord],
914
+ length_threshold: int
915
+ ) -> List[SeqRecord]:
916
+ """
917
+ Filters out SeqRecord objects with sequences shorter than a specified threshold.
918
+
919
+ :param seq_records: List of SeqRecord objects.
920
+ :type seq_records: List[SeqRecord]
921
+ :param length_threshold: The minimum length of sequences to be retained.
922
+ :type length_threshold: int
923
+ :return: A list of SeqRecord objects that meet or exceed the length threshold.
924
+ :rtype: List[SeqRecord]
925
+
926
+ Example:
927
+ >>> from Bio.Seq import Seq
928
+ >>> from Bio.SeqRecord import SeqRecord
929
+ >>> records = [
930
+ ... SeqRecord(Seq('ATCG'), id='seq1'),
931
+ ... SeqRecord(Seq('AT'), id='seq2')
932
+ ... ]
933
+ >>> filtered_records = filter_short_sequences(records, 3)
934
+ >>> len(filtered_records)
935
+ 1
936
+ >>> filtered_records[0].id
937
+ 'seq1'
938
+ """
939
+ filtered_records = [record for record in seq_records if len(record.seq) >= length_threshold]
940
+ return filtered_records
941
+
942
+
943
+
944
+ def get_token_counts_for_segment(Lseg, kmer, shift, offset):
945
+ nr_tokens = int((Lseg - kmer) / shift + 1)
946
+ return nr_tokens
947
+
948
+ def get_seq_coordinates(token_pos, kmer, shift, offset):
949
+ seq_start = int(token_pos * shift + offset)
950
+ seq_end = int(token_pos * shift + kmer + offset)
951
+ return seq_start, seq_end
952
+
953
+ def get_token_coordinates(seq_pos, kmer, shift, offset, Lseg):
954
+
955
+ nrtokens = get_token_counts_for_segment(Lseg, kmer, shift, offset)
956
+
957
+ token_pos_end = int((seq_pos + offset - kmer) / shift)
958
+ token_pos_start = int((seq_pos + offset) / shift)
959
+
960
+ if token_pos_end < 0:
961
+ token_pos_end = 0
962
+ if token_pos_start >= nrtokens:
963
+ token_pos_start = nrtokens - 1
964
+
965
+ return token_pos_start, token_pos_end
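+
+ # Worked example (added for documentation) for the three coordinate helpers above,
+ # using a 100 bp segment with 6-mers, shift 2, offset 0:
+ # >>> get_token_counts_for_segment(100, 6, 2, 0)
+ # 48
+ # >>> get_seq_coordinates(10, 6, 2, 0) # token 10 spans bases [20, 26)
+ # (20, 26)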
966
+
967
+ def sliding_window_average(arr, window_size=6):
968
+ # Create a window for averaging
969
+ window = np.ones(window_size) / window_size
970
+ # Use 'valid' mode to slide the window over the array without padding
971
+ result = np.convolve(arr, window, mode='valid')
972
+ return result
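+
+ # Illustrative usage sketch (added for documentation): a moving average that shortens
+ # the array by window_size - 1 because of 'valid' convolution.
+ # >>> sliding_window_average(np.array([1., 2., 3., 4., 5., 6.]), window_size=3)
+ # array([2., 3., 4., 5.])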
973
+
974
+ def convolve_expression_array(expression_array, window_size=6, step=2):
975
+ # Define the averaging window
976
+ window = np.ones(window_size) / window_size
977
+ # Apply convolution along each row (axis=1)
978
+ convolved_array = convolve1d(expression_array, window, axis=1, mode='reflect')
979
+ # Downsample by step size
980
+ return convolved_array[:, ::step]
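+
+ # Illustrative usage sketch (added for documentation; assumes scipy.ndimage's convolve1d
+ # is imported at the top of this module):
+ # >>> arr = np.arange(20, dtype=float).reshape(2, 10)
+ # >>> convolve_expression_array(arr, window_size=2, step=2).shape
+ # (2, 5)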
tokenizer.py ADDED
@@ -0,0 +1,363 @@
1
+ import collections
2
+ import os
3
+ import json
4
+ from copy import deepcopy
5
+ from typing import List, Optional, Tuple, Dict
6
+ from transformers import PreTrainedTokenizer
7
+ from transformers.utils.hub import cached_file
8
+
9
+ from .config_utils import SeqConfig
10
+ from .sequtils import generate_kmers, lca_kmer_tokenize_segment
11
+
12
+ # Define the names of the vocabulary files
13
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
14
+
15
+ # Define the mapping for pretrained vocabulary files
16
+ PRETRAINED_VOCAB_FILES_MAP = {
17
+ "vocab_file": {
18
+ "lca-mini-k6s1": "lca-base-dna6/vocab.txt",
19
+ "lca-mini-k6s2": "lca-base-dna6/vocab.txt",
20
+ "lca-mini-k1s1": "lca-base-dna1/vocab.txt",
21
+ }
22
+ }
23
+
24
+ # Define positional embedding sizes for pretrained models
25
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
26
+ "lca-mini-k6s1": 1024,
27
+ "lca-mini-k1s1": 1024,
28
+ "lca-mini-k6s2": 2048,
29
+ }
30
+
31
+ # Define initial configuration for pretrained models
32
+ PRETRAINED_INIT_CONFIGURATION = {
33
+ "lca-mini-k6s1": {"do_upper_case": True},
34
+ "lca-mini-k1s1": {"do_upper_case": True},
35
+ "lca-mini-k6s2": {"do_upper_case": True},
36
+ }
37
+
38
+ # Utility function to load vocabulary from a file
39
+ def load_vocab(vocab_file):
40
+ """Loads a vocabulary file into a dictionary."""
41
+ vocab = collections.OrderedDict()
42
+ with open(vocab_file, "r", encoding="utf-8") as reader:
43
+ tokens = reader.readlines()
44
+ for index, token in enumerate(tokens):
45
+ vocab[token.rstrip("\n")] = index
46
+ return vocab
47
+
48
+ class LCATokenizer(PreTrainedTokenizer):
49
+ """
50
+ Custom tokenizer for LCA (Local Context Aware) tasks.
51
+ Handles specific tokenization processes, including k-mer tokenization with configurable shifts.
52
+
53
+ Attributes:
54
+ vocab_files_names (dict): Mapping of vocabulary file names.
55
+ pretrained_vocab_files_map (dict): Mapping of pretrained vocabulary files.
56
+ pretrained_init_configuration (dict): Initial configuration for pretrained models.
57
+ max_model_input_sizes (dict): Maximum input sizes for pretrained models.
58
+ """
59
+
60
+ vocab_files_names = VOCAB_FILES_NAMES
61
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
62
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
63
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
64
+
65
+ nucleotide_abc = {"A", "T", "C", "G"}
66
+ extended_nucleotide_abc = {"A", "T", "C", "G", "*"}
67
+ sequence_unk_token = 'N'
68
+
69
+ default_unk_token = "[UNK]"
70
+ default_sep_token = "[SEP]"
71
+ default_pad_token = "[PAD]"
72
+ default_cls_token = "[CLS]"
73
+ default_mask_token = "[MASK]"
74
+
75
+ def __init__(
76
+ self,
77
+ config: Optional[Dict] = None,
78
+ operation_space: str = "kmer",
79
+ **kwargs,
80
+ ):
81
+ """
82
+ Initializes the LCATokenizer with configuration and operation space.
83
+
84
+ Args:
85
+ config (dict, optional): Tokenization parameters like k-mer size and shift; defaults to the SeqConfig defaults.
86
+ operation_space (str): Defines operation mode ('kmer' or 'sequence').
87
+ kwargs: Additional arguments for PreTrainedTokenizer.
88
+ """
89
+ self.defconfig = SeqConfig()
90
+ config = self.defconfig.get_and_set_tokenization_parameters(config or {})
91
+ self.config = config
92
+ self.operation_space = operation_space
93
+
94
+ # Set default tokens
95
+ kwargs.setdefault("cls_token", self.default_cls_token)
96
+ kwargs.setdefault("unk_token", self.default_unk_token)
97
+ kwargs.setdefault("sep_token", self.default_sep_token)
98
+ kwargs.setdefault("pad_token", self.default_pad_token)
99
+ kwargs.setdefault("mask_token", self.default_mask_token)
100
+
101
+ # Load vocabulary
102
+ vocab_file = self.config["vocabfile"]
103
+ self.vocab = self.config["vocabmap"]
104
+ self.id2token = {v: k for k, v in self.vocab.items()}
105
+ self.max_len = self.config["max_segment_length"]
106
+
107
+ super().__init__(**kwargs)
108
+
109
+ # Handle extended vocabulary for sequence mode
110
+ if self.operation_space == 'sequence':
111
+ token_extension = sorted(list(set(generate_kmers(LCATokenizer.extended_nucleotide_abc, self.config['kmer'])) - \
112
+ set(generate_kmers(LCATokenizer.nucleotide_abc, self.config['kmer'])) ))
113
+ self.extended_vocab = deepcopy(self.vocab)
114
+ for token in token_extension:
115
+ self.extended_vocab[token] = 4
116
+
117
+ self.unk_token = LCATokenizer.sequence_unk_token * self.config['shift']
118
+ self.mask_token = '*'
119
+ self.extended_vocab[self.mask_token] = self.vocab['[MASK]']
120
+
121
+ full_unk = 'N' * self.config['kmer']
122
+ self.vocab[full_unk] = 1
123
+ self.id2token[1] = full_unk
124
+ self.full_unk_token = full_unk
125
+
126
+ else:
127
+ self.extended_vocab = self.vocab
128
+ self.unk_token = '[UNK]'
129
+
130
+ self.unknown_token_id = self.vocab['[UNK]']
131
+ self.sep_token = '[SEP]'
132
+ self.cls_token = '[CLS]'
133
+ self.pad_token = '[PAD]'
134
+ self.mask_token = '[MASK]'
135
+ self.special_tokens = list(self.special_tokens_map.values())
136
+
137
+
138
+
139
+ def _tokenize(self, text, **kwargs):
140
+ """
141
+ Tokenizes the input text using LCA tokenization with an optional offset.
142
+
143
+ Args:
144
+ text (str): The input DNA sequence to tokenize.
145
+ kwargs: Additional arguments, including:
146
+ - offset (int): The starting position for tokenization. Default is 0.
147
+
148
+ Returns:
149
+ List[str]: A list of tokens generated from the input text.
150
+ """
151
+ offset = kwargs.get("offset", 0)
152
+ #if offset < 0 or offset >= self.config.get("shift", 1):
153
+ # raise ValueError(f"Invalid offset: {offset}. Must be between 0 and {self.config['shift'] - 1}.")
154
+
155
+ return lca_kmer_tokenize_segment(text, offset, self.config)
156
+
157
+ def _convert_token_to_id(self, token: str) -> int:
158
+ """
159
+ Converts a token to its corresponding ID using the vocabulary.
160
+
161
+ Args:
162
+ token (str): The token to convert.
163
+
164
+ Returns:
165
+ int: Token ID, or the unknown token ID if the token is not in the vocabulary.
166
+ """
167
+ return self.extended_vocab.get(token, self.unknown_token_id)
168
+
169
+ def _convert_id_to_token(self, index: int) -> str:
170
+ """
171
+ Converts an ID to its corresponding token using the vocabulary.
172
+
173
+ Args:
174
+ index (int): The ID to convert.
175
+
176
+ Returns:
177
+ str: Corresponding token, or the unknown token if the ID is not in the vocabulary.
178
+ """
179
+
180
+
181
+ return self.id2token.get(index, self.unk_token)
182
+
183
+ def __len__(self) -> int:
184
+ """
185
+ Returns the length of the tokenizer's vocabulary.
186
+
187
+ The length is simply the number of entries in the vocabulary
188
+ mapping; no offset is applied.
189
+
190
+ :return: The size of the vocabulary.
191
+ :rtype: int
192
+ """
193
+ return len(self.vocab)
194
+
195
+
196
+
197
+ def tokenize(self, text: str, **kwargs) -> List[str]:
198
+ """
199
+ Tokenizes the input text using LCA tokenization.
200
+
201
+ Args:
202
+ text (str): The input DNA sequence to tokenize.
203
+ kwargs: Additional arguments, including:
204
+ - offset (int): The starting position for tokenization. Default is 0.
205
+
206
+ Returns:
207
+ List[str]: A list of tokens generated from the input text.
208
+ """
209
+ return self._tokenize(text, **kwargs)
210
+
211
+ def encode(self, text: str, **kwargs) -> List[int]:
212
+ """
213
+ Extends the base `encode` method to support an `offset` parameter for custom tokenization logic.
214
+
215
+ Args:
216
+ text (str): Input text (DNA sequence).
217
+ offset (int): Offset parameter for the LCA tokenization. Defaults to 0.
218
+ kwargs: Additional arguments passed to the base `encode` method.
219
+
220
+ Returns:
221
+ List[int]: Encoded token IDs.
222
+ """
223
+ # Inject the offset into kwargs for the tokenizer
224
+ kwargs.setdefault("offset", 0)
226
+ return super().encode(text, **kwargs)
227
+
228
+ def build_inputs_with_special_tokens(
229
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
230
+ ) -> List[int]:
231
+ """
232
+ Builds inputs by adding special tokens to a sequence or pair of sequences.
233
+
234
+ Args:
235
+ token_ids_0 (List[int]): List of token IDs for the first sequence.
236
+ token_ids_1 (List[int], optional): List of token IDs for the second sequence.
237
+
238
+ Returns:
239
+ List[int]: Input IDs with special tokens.
240
+ """
241
+ if token_ids_1 is None:
242
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
243
+
244
+ input_ids = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1 + [self.sep_token_id]
245
+ #token_type_ids = [0 for i in range(len(input_ids))]
246
+ return input_ids
247
+
248
+ def create_token_type_ids_from_sequences(
249
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
250
+ ) -> List[int]:
251
+ """
252
+ Create the token type IDs corresponding to the sequences passed. [What are token type
253
+ IDs?](../glossary#token-type-ids)
254
+
255
+ Should be overridden in a subclass if the model has a special way of building those.
256
+
257
+ Args:
258
+ token_ids_0 (`List[int]`): The first tokenized sequence.
259
+ token_ids_1 (`List[int]`, *optional*): The second tokenized sequence.
260
+
261
+ Returns:
262
+ `List[int]`: The token type ids.
263
+ """
264
+ if token_ids_1 is None:
265
+ return (len(token_ids_0) + 2) * [0]
266
+ return [0] * len(token_ids_0) + [1] * len(token_ids_1)
267
+
268
+ def batch_encode_plus(self, *args, **kwargs):
269
+ """
270
+ Extends the base `batch_encode_plus` method to add custom functionality if needed.
271
+
272
+ Args:
273
+ *args: Positional arguments passed to the base method.
274
+ **kwargs: Keyword arguments passed to the base method.
275
+
276
+ Returns:
277
+ dict: A dictionary containing the results of batch encoding.
278
+ """
279
+ # Call the parent method to handle the batch encoding
281
+ act_outputs = super().batch_encode_plus(*args, **kwargs)
282
+ return act_outputs
283
+
284
+
285
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
286
+ """
287
+ Saves the tokenizer's vocabulary to a file.
288
+
289
+ Args:
290
+ save_directory (str): Directory to save the vocabulary file.
291
+ filename_prefix (str, optional): Prefix for the filename. Default is None.
292
+
293
+ Returns:
294
+ Tuple[str]: Path to the saved vocabulary file.
295
+ """
296
+ if filename_prefix is None:
297
+ filename_prefix = ""
298
+ vocab_file_path = os.path.join(save_directory, filename_prefix + "vocab.txt")
299
+ with open(vocab_file_path, "w") as f:
300
+ for token in self.vocab:
301
+ f.write(token + "\n")
302
+ return (vocab_file_path,)
303
+
304
+ def save_pretrained(self, save_directory: str, **kwargs):
305
+ """
306
+ Saves the tokenizer configuration and vocabulary to a directory.
307
+
308
+ Args:
309
+ save_directory (str): Directory to save the tokenizer files.
310
+ """
311
+ if not os.path.exists(save_directory):
312
+ os.makedirs(save_directory)
313
+ super().save_pretrained(save_directory, **kwargs)
314
+
315
+ tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")
316
+ if os.path.exists(tokenizer_config_path):
317
+ with open(tokenizer_config_path, "r") as f:
318
+ tokenizer_config = json.load(f)
319
+ else:
320
+ tokenizer_config = {}
321
+
322
+ tokenizer_config.update({
323
+ "kmer": self.config.get("kmer", 6),
324
+ "shift": self.config.get("shift", 1),
325
+ })
326
+
327
+ with open(tokenizer_config_path, "w") as f:
328
+ json.dump(tokenizer_config, f, indent=2)
329
+
330
+ @classmethod
331
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
332
+ """
333
+ Loads a tokenizer from the pretrained model directory or Hugging Face Hub.
334
+
335
+ Args:
336
+ pretrained_model_name_or_path (str): Path or model name on Hugging Face Hub.
337
+ kwargs: Additional arguments for initialization.
338
+
339
+ Returns:
340
+ LCATokenizer: The loaded tokenizer instance.
341
+ """
342
345
+ resolved_tokenizer_config_file = cached_file(
346
+ pretrained_model_name_or_path, filename="tokenizer_config.json"
347
+ )
348
+
349
+ with open(resolved_tokenizer_config_file, "r") as f:
350
+ tokenizer_config = json.load(f)
351
+
352
+ kmer = tokenizer_config.pop("kmer", 6)
353
+ shift = tokenizer_config.pop("shift", 1)
354
+ base_tokenization_config = {'kmer': kmer, 'shift': shift}
355
+ defconfig = SeqConfig()
356
+ config = defconfig.get_and_set_tokenization_parameters(base_tokenization_config)
357
+
358
+ tokenizer = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
359
+ tokenizer.config = config
360
+
361
+ return tokenizer
362
+
363
+
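+
+ # Illustrative usage sketch (added for documentation; assumes a SeqConfig that supplies
+ # a default vocabulary for the requested k-mer size):
+ # >>> tok = LCATokenizer(config={'kmer': 6, 'shift': 1})
+ # >>> tok.tokenize('TCTTTGCTAAG')
+ # ['TCTTTG', 'CTTTGC', 'TTTGCT', 'TTGCTA', 'TGCTAA', 'GCTAAG']
+ # >>> ids = tok.encode('TCTTTGCTAAG') # [CLS] + one id per 6-mer + [SEP]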
tokenizer_config.json CHANGED
@@ -1,4 +1,10 @@
1
  {
2
  "clean_up_tokenization_spaces": true,
3
  "cls_token": "[CLS]",
4
  "mask_token": "[MASK]",
 
1
  {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "tokenizer.LCATokenizer",
5
+ null
6
+ ]
7
+ },
8
  "clean_up_tokenization_spaces": true,
9
  "cls_token": "[CLS]",
10
  "mask_token": "[MASK]",