Chiquitin committed on
Commit
482fd8d
·
1 Parent(s): 3cca845

upload source code and train configurations

Browse files
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==2.3.5
2
+ torch==2.5.1+cu121
3
+ torchaudio==2.5.1+cu121
4
+ torchvision==0.20.1+cu121
5
+ tensorboard==2.20.0
6
+ matplotlib==3.10.7
7
+ datasets==4.4.1
8
+ psutil==7.1.3
9
+ spacy==3.8.11
10
+ tqdm==4.67.1
src/dataset/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ from .tokenizer import SegmentationTokenizer, SentenceSegmenter
8
+ from .dataset import SegmentationDataset
9
+ from .tokenized_dataset import TokenizedSegmentationDataset
10
+ from .config import DatasetConfig
11
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
12
+ # END OF FILE #
13
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dataset/config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ from dataclasses import dataclass
9
+
10
+
11
+ @dataclass
12
+ class DatasetConfig:
13
+ # Paths:
14
+ train_data_path: str = None
15
+ val_data_path: str = None
16
+ test_data_path: str = None
17
+ # Percentages:
18
+ train_percentage: float = 1.0
19
+ val_percentage: float = 1.0
20
+ test_percentage: float = 1.0
21
+ # Other parameters:
22
+ num_workers: int = 0
23
+ shuffle_train: bool = True
24
+ shuffle_val: bool = True
25
+
26
+
27
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
28
+ # END OF FILE #
29
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dataset/dataset.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import logging
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from datasets import Dataset as HfDataset
11
+ from datasets import load_from_disk
12
+ from .tokenizer import SegmentationTokenizer, SentenceSegmenter
13
+
14
+
15
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
16
+ # #
17
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
18
class SegmentationDataset(Dataset):
    def __init__(
        self,
        huggingface_dataset: str | HfDataset,
        tokenizer: SegmentationTokenizer,
        segmenter: SentenceSegmenter,
        logger: logging.Logger = None,
        percentage: float = 1.0,
        return_type: type = dict
    ):
        """
        A segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
        wikipedia-segmentation format. It loads the dataset and prepares it for training.

        Wikipedia-segmentation format:
        - The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
        - The dataset should contain the following fields:
        >>> sample = {
        >>>     'text': ['Article 1', 'Article 2', ...],
        >>>     'titles': ['Title 1', 'Title 2', ...],
        >>>     'id': str,
        >>>     'words': int,
        >>>     'paragraphs': int,
        >>>     'sentences': int
        >>> }
        - The dataset should be a list of dictionaries, where each dictionary contains the fields above.

        Parameters
        ----------
        huggingface_dataset : str | HfDataset
            A huggingface dataset or a path to a dataset on disk with the wikipedia-segmentation format.

        tokenizer : SegmentationTokenizer
            A callable tokenizer that maps a list of sentences to input ids and attention masks.

        segmenter : SentenceSegmenter
            Sentence segmenter used to split each article into sentences and boundary labels.

        logger : logging.Logger, optional
            Logger instance. If not provided, a null logger will be used.

        percentage : float
            Percentage of the dataset to use. Default is 1.0 (100%).

        return_type : type
            The return type of __getitem__, either dict or tuple. Default is dict.

        Raises
        ------
        ValueError
            If huggingface_dataset is not a string or a HfDataset, if tokenizer is not
            callable, if segmenter is not a SentenceSegmenter instance, if percentage
            is outside (0.0, 1.0], or if return_type is not dict or tuple.
        """
        # Fall back to a no-op logger so subsequent logging calls are always safe:
        if not isinstance(logger, logging.Logger):
            self.logger = logging.getLogger("null")
            self.logger.addHandler(logging.NullHandler())
        else:
            self.logger = logger

        # Load the dataset (in-memory object or a path on disk):
        if isinstance(huggingface_dataset, HfDataset):
            self.huggingface_dataset = huggingface_dataset
        elif isinstance(huggingface_dataset, str):
            self.huggingface_dataset = load_from_disk(huggingface_dataset)
        else:
            self.logger.error('[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
            raise ValueError('[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
        self.logger.info(f'[SegmentationDataset] Loaded dataset: {self.huggingface_dataset}')
        self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.huggingface_dataset.num_rows}')

        # Tokenizer (any callable is accepted):
        if callable(tokenizer):
            self.tokenizer = tokenizer
        else:
            self.logger.error('[SegmentationDataset] Tokenizer must be a callable function.')
            raise ValueError('[SegmentationDataset] Tokenizer must be a callable function.')

        # Segmenter:
        if not isinstance(segmenter, SentenceSegmenter):
            self.logger.error('[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
            raise ValueError('[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
        self.segmenter = segmenter

        # Fraction of the dataset exposed through __len__:
        if not (0.0 < percentage <= 1.0):
            self.logger.error('[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
            raise ValueError('[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
        self.percentage = percentage

        # Return type of __getitem__:
        if not isinstance(return_type, type):
            self.logger.error('[SegmentationDataset] return_type must be a type.')
            raise ValueError('[SegmentationDataset] return_type must be a type.')
        if return_type not in (dict, tuple):
            self.logger.error('[SegmentationDataset] return_type must be either dict or tuple.')
            raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
        self.return_type = return_type

    def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
        """
        Returns a PyTorch DataLoader for this dataset.

        Parameters
        ----------
        batch_size : int
            Number of samples per batch.
        shuffle : bool
            Whether to shuffle the dataset.
        num_workers : int
            Number of worker processes.
        **kwargs
            Additional arguments for DataLoader.

        Returns
        -------
        torch.utils.data.DataLoader
            Configured DataLoader.
        """
        # pin_memory speeds up host-to-GPU transfers when training on CUDA:
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                          pin_memory=True, **kwargs)

    def __len__(self) -> int:
        """
        Returns the number of usable samples (total rows scaled by `percentage`).

        Returns
        -------
        int
            Number of samples exposed to the DataLoader.
        """
        return int(self.huggingface_dataset.num_rows * self.percentage)

    def __getitem__(self, idx) -> dict | tuple:
        """
        Retrieves a single article, segments it into sentences and tokenizes them.

        Parameters
        ----------
        idx : int
            Index of the sample.

        Returns
        -------
        dict | tuple
            Tokenized inputs, boundary labels and masks. The tuple order is
            (input_ids, boundaries, input_mask, sentence_mask, candidates);
            the dict uses the keys 'input', 'input_mask', 'labels',
            'output_mask', 'candidate_mask'. Controlled by `return_type`.
        """
        sample = self.huggingface_dataset[idx]['text']
        sentences = self.segmenter(sample)
        tokenized = self.tokenizer(sentences['sentences'])

        if self.return_type == tuple:
            return (
                tokenized['input_ids'],  # x
                sentences['sentence_boundaries'],  # y
                tokenized['attention_mask'],  # x_mask
                sentences['sentence_mask'],  # y_mask
                sentences['sentence_candidates'],  # y_prime_mask
            )
        elif self.return_type == dict:
            return {
                'input': tokenized['input_ids'],
                'input_mask': tokenized['attention_mask'],
                'labels': sentences['sentence_boundaries'],
                'output_mask': sentences['sentence_mask'],
                'candidate_mask': sentences['sentence_candidates']
            }
        # Unreachable when constructed through __init__ (return_type validated there):
        raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
195
+
196
+
197
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
198
+ # END OF FILE #
199
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dataset/tokenized_dataset.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import logging
9
+ import json
10
+ import os
11
+ import numpy as np
12
+ from torch.utils.data import Dataset, DataLoader
13
+
14
+
15
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
16
+ # #
17
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
18
+ class TokenizedSegmentationDataset(Dataset):
19
+ def __init__(
20
+ self,
21
+ tokenized_dataset: str,
22
+ logger: logging.Logger = None,
23
+ percentage: float = 1.0,
24
+ return_type: type = dict
25
+ ):
26
+ """
27
+ A tokoenized segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
28
+ wikipedia-segmentation format. It loads the dataset and prepares it for training.
29
+
30
+ Wikipedia-segmentation format:
31
+ - The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
32
+ - The dataset should contain the following fields:
33
+ >>> sample = {
34
+ >>> 'text': ['Article 1', 'Article 2', ...],
35
+ >>> 'titles': ['Title 1', 'Title 2', ...],
36
+ >>> 'id': str,
37
+ >>> 'words': int
38
+ >>> 'paragraphs': int
39
+ >>> 'sentences': int
40
+ >>> }
41
+ - The dataset should be a list of dictionaries, where each dictionary contains the fields above.
42
+
43
+ Parameters
44
+ ----------
45
+ tokenized_dataset : str
46
+ A path to a tokenized dataset on disk with the wikipedia-segmentation format.
47
+
48
+ logger : logging.Logger, optional
49
+ Logger instance. If not provided, a null logger will be used.
50
+
51
+ percentage : float
52
+ Percentage of the dataset to use. Default is 1.0 (100%).
53
+
54
+ return_type : type
55
+ The return type of __getitem__, either dict or tuple. Default is dict.
56
+
57
+ Raises
58
+ ------
59
+ ValueError
60
+ If the huggingface_dataset is not a string or a HfDataset.
61
+ ValueError
62
+ If the tokenizer is not a callable function or class.
63
+ ValueError
64
+ If the sentence_tokenizer is not a callable function or class.
65
+ ValueError
66
+ If the dtype is not a type.
67
+
68
+ """
69
+ # Null logging:
70
+ if not isinstance(logger, logging.Logger):
71
+ self.logger = logging.getLogger("null")
72
+ self.logger.addHandler(logging.NullHandler())
73
+ else:
74
+ self.logger = logger
75
+
76
+ # Loading:
77
+ if isinstance(tokenized_dataset, str):
78
+ self.metadata_path = os.path.join(tokenized_dataset, 'info.json')
79
+ if not os.path.exists(self.metadata_path):
80
+ self.logger.error(f'[SegmentationDataset] Dataset metadata file not found at {self.metadata_path}.')
81
+ raise FileNotFoundError(f'[SegmentationDataset] Dataset metadata file not found at {self.metadata_path}.')
82
+ else:
83
+ with open(self.metadata_path, 'r', encoding='utf-8') as f:
84
+ self.metadata = json.load(f)
85
+ if 'fingerprint' not in self.metadata or not self.metadata['fingerprint']:
86
+ raise ValueError(f'[SegmentationDataset] Dataset metadata file is missing fingerprint information.')
87
+ else:
88
+ self.logger.error(f'[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
89
+ raise ValueError(f'[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
90
+ self.logger.info(f'[SegmentationDataset] Loaded dataset: {tokenized_dataset}')
91
+ self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.metadata["samples"]}')
92
+
93
+ # Percentage:
94
+ if not (0.0 < percentage <= 1.0):
95
+ self.logger.error(f'[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
96
+ raise ValueError(f'[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
97
+ else:
98
+ self.percentage = percentage
99
+
100
+ # Return type:
101
+ if not isinstance(return_type, type):
102
+ self.logger.error(f'[SegmentationDataset] return_type must be a type.')
103
+ raise ValueError(f'[SegmentationDataset] return_type must be a type.')
104
+ elif return_type not in [dict, tuple]:
105
+ self.logger.error(f'[SegmentationDataset] return_type must be either dict or tuple.')
106
+ raise ValueError(f'[SegmentationDataset] return_type must be either dict or tuple.')
107
+ else:
108
+ self.return_type = return_type
109
+
110
+ self.metadata['max_sentences'] = self.metadata['x']['element_shape'][0]
111
+ self.metadata['max_tokens'] = self.metadata['x']['element_shape'][1]
112
+
113
+ # Build maps:
114
+ read_mode = 'r'
115
+ self.x_map = np.memmap(
116
+ os.path.join(tokenized_dataset, self.metadata['x']['name'] + self.metadata['x']['extension']),
117
+ dtype=self.metadata['x']['dtype'],
118
+ mode=read_mode,
119
+ shape=(self.metadata['x']['samples'], *self.metadata['x']['element_shape'])
120
+ )
121
+ self.y_map = np.memmap(
122
+ os.path.join(tokenized_dataset, self.metadata['y']['name'] + self.metadata['y']['extension']),
123
+ dtype=self.metadata['y']['dtype'],
124
+ mode=read_mode,
125
+ shape=(self.metadata['y']['samples'], *self.metadata['y']['element_shape'])
126
+ )
127
+ self.x_mask_map = np.memmap(
128
+ os.path.join(tokenized_dataset, self.metadata['x_mask']['name'] + self.metadata['x_mask']['extension']),
129
+ dtype=self.metadata['x_mask']['dtype'],
130
+ mode=read_mode,
131
+ shape=(self.metadata['x_mask']['samples'], *self.metadata['x_mask']['element_shape'])
132
+ )
133
+ self.y_mask_map = np.memmap(
134
+ os.path.join(tokenized_dataset, self.metadata['y_mask']['name'] + self.metadata['y_mask']['extension']),
135
+ dtype=self.metadata['y_mask']['dtype'],
136
+ mode=read_mode,
137
+ shape=(self.metadata['y_mask']['samples'], *self.metadata['y_mask']['element_shape'])
138
+ )
139
+ self.y_cand_map = np.memmap(
140
+ os.path.join(tokenized_dataset, self.metadata['y_cand']['name'] + self.metadata['y_cand']['extension']),
141
+ dtype=self.metadata['y_cand']['dtype'],
142
+ mode=read_mode,
143
+ shape=(self.metadata['y_cand']['samples'], *self.metadata['y_cand']['element_shape'])
144
+ )
145
+
146
+ def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
147
+ """
148
+ Returns a PyTorch DataLoader for this dataset.
149
+
150
+ Parameters
151
+ ----------
152
+ batch_size : int
153
+ Number of samples per batch.
154
+ shuffle : bool
155
+ Whether to shuffle the dataset.
156
+ num_workers : int
157
+ Number of worker processes.
158
+ **kwargs
159
+ Additional arguments for DataLoader.
160
+
161
+ Returns
162
+ -------
163
+ [torch.utils.data.DataLoader
164
+ Configured DataLoader.
165
+ """
166
+ # Size handling:
167
+ return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True,
168
+ **kwargs)
169
+
170
+ def __len__(self) -> int:
171
+ """
172
+ Returns the number of samples in the dataset.
173
+
174
+ Returns
175
+ -------
176
+ int
177
+ Total number of samples.
178
+ """
179
+ return int(self.metadata['samples'] * self.percentage)
180
+
181
+ def __getitem__(self, idx) -> dict | tuple:
182
+ """
183
+ Retrieves a single sample and generates segmentation labels.
184
+
185
+ Parameters
186
+ ----------
187
+ idx : int
188
+ Index of the sample.
189
+
190
+ Returns
191
+ -------
192
+ tuple
193
+ A tuple or dict (x_i, y_i, mask_x) with noisy input and corresponding target.
194
+ """
195
+ if self.return_type == tuple:
196
+ return (
197
+ np.array(self.x_map[idx]), # ← copia
198
+ np.array(self.y_map[idx]),
199
+ np.array(self.x_mask_map[idx]),
200
+ np.array(self.y_mask_map[idx]),
201
+ np.array(self.y_cand_map[idx]),
202
+ )
203
+ elif self.return_type == dict:
204
+ return {
205
+ 'input': np.array(self.x_map[idx]),
206
+ 'input_mask': np.array(self.x_mask_map[idx]),
207
+ 'labels': np.array(self.y_map[idx]),
208
+ 'output_mask': np.array(self.y_mask_map[idx]),
209
+ 'candidate_mask': np.array(self.y_cand_map[idx]),
210
+ }
211
+ else:
212
+ raise ValueError(f'[SegmentationDataset] return_type must be either dict or tuple.')
213
+
214
+
215
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
216
+ # END OF FILE #
217
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dataset/tokenizer.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import tokenizers
9
+ import sys
10
+ import subprocess
11
+ import logging
12
+ import spacy
13
+ import numpy as np
14
+ from tokenizers.models import BPE
15
+ from tokenizers.trainers import BpeTrainer
16
+ from tokenizers.pre_tokenizers import Whitespace
17
+ from tokenizers.normalizers import NFKC
18
+ from transformers import PreTrainedTokenizerFast
19
+
20
+ # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
21
+
22
+
23
class SegmentationTokenizer:
    """
    BPE tokenizer for the segmentation task.

    Wraps a raw `tokenizers` BPE model (used for training) and a
    `PreTrainedTokenizerFast` (used for encoding). Either train with
    `train_from_iterator` and `save`, or `load` a previously saved
    tokenizer file, before calling the instance on text.
    """

    def __init__(
        self,
        vocab_size=32_768,
        min_frequency=2,
        max_length=1024
    ):
        """
        Builds the raw (trainable) BPE tokenizer and its trainer.

        Parameters
        ----------
        vocab_size : int
            Target vocabulary size for BPE training.
        min_frequency : int
            Minimum pair frequency for a merge to be learned.
        max_length : int
            Padding/truncation length applied in __call__.
        """
        self.max_length = max_length

        # Raw tokenizer (training)
        self.raw_tokenizer = tokenizers.Tokenizer(
            BPE(unk_token="[UNK]")
        )
        # Unicode NFKC normalization, then whitespace pre-tokenization:
        self.raw_tokenizer.normalizer = NFKC()
        self.raw_tokenizer.pre_tokenizer = Whitespace()

        self.trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
        )

        self._hf_tokenizer = None  # created after training

    # ---------- TRAINING ----------
    def build_iterator(self, dataset, batch_size=1024):
        """
        Yields batches of article strings for tokenizer training.

        Each item's 'text' list is joined with newlines (double newlines
        collapsed to single ones) into one training string.
        """
        batch = []
        for item in dataset:
            batch.append("\n".join(item["text"]).replace("\n\n", "\n"))
            if len(batch) == batch_size:
                yield batch
                batch = []
        # Flush the final, possibly partial batch:
        if batch:
            yield batch

    def train_from_iterator(self, iterator):
        """Trains the raw BPE tokenizer from an iterator of text batches."""
        self.raw_tokenizer.train_from_iterator(
            iterator, trainer=self.trainer
        )

    # ---------- IO ----------
    def save(self, path):
        """Saves the raw (trained) tokenizer to a tokenizer file at `path`."""
        self.raw_tokenizer.save(path)

    def load(self, tokenizer_path):
        """
        Loads a saved tokenizer file into a fast Hugging Face tokenizer.

        Returns self so the call can be chained (e.g. `tok = SegmentationTokenizer().load(p)`).
        """
        self._hf_tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=tokenizer_path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            cls_token="[CLS]",
            sep_token="[SEP]",
            mask_token="[MASK]"
        )
        return self

    # ---------- TOKENIZATION ----------
    def compute_unk_rate(self, corpus):
        """
        Computes the fraction of [UNK] tokens produced over a corpus of strings.

        Returns 0.0 for an empty corpus (no tokens at all).
        """
        unk_id = self._hf_tokenizer.convert_tokens_to_ids("[UNK]")

        total_tokens = 0
        unk_tokens = 0

        for text in corpus:
            # Encode without special tokens so only real content is counted:
            enc = self._hf_tokenizer(
                text,
                add_special_tokens=False
            )["input_ids"]

            total_tokens += len(enc)
            unk_tokens += sum(1 for t in enc if t == unk_id)

        return unk_tokens / total_tokens if total_tokens > 0 else 0.0

    def __call__(
        self,
        text,
        return_tensors="pt",
        padding=True,
        truncation=True
    ):
        """
        Encodes text with fixed-size padding/truncation to `max_length`.

        text: str or List[str]
        returns: dict with input_ids and attention_mask (torch.long)

        Raises RuntimeError if no tokenizer has been loaded yet.
        """
        if self._hf_tokenizer is None:
            raise RuntimeError("Tokenizer not loaded. Call .load() first.")

        # padding=True pads every sequence to exactly max_length (not to the
        # longest in the batch), so batch shapes are constant:
        enc = self._hf_tokenizer(
            text,
            padding="max_length" if padding else False,
            truncation=truncation,
            max_length=self.max_length,
            return_tensors=return_tensors
        )

        return {
            "input_ids": enc["input_ids"],  # torch.LongTensor
            "attention_mask": enc["attention_mask"]  # torch.LongTensor
        }

    @property
    def vocab_size(self):
        # Size of the loaded tokenizer's vocabulary; requires .load() first.
        if self._hf_tokenizer is None:
            raise RuntimeError("Tokenizer not loaded.")
        return self._hf_tokenizer.vocab_size

    def __repr__(self):
        # Reports the *configured* (trainer) vocab size, which may differ from
        # the loaded tokenizer's actual vocab size.
        return f"<SegmentationTokenizer vocab_size={self.trainer.vocab_size}>"
131
+
132
+
133
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
134
+ # SENTENCE SEG #
135
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
136
class SentenceSegmenter:
    """
    Splits articles into sentences with spaCy and emits fixed-size label arrays.

    Each call returns exactly `max_sentences` entries per output (padded with
    empty sentences and zero labels), so downstream batching gets constant shapes.
    """

    def __init__(
        self,
        max_sentences: int,
        spacy_model: str = "es_core_news_sm",
        logger: logging.Logger | None = None
    ):
        # Hard cap on sentences per sample; outputs are padded/truncated to this length.
        self.max_sentences = max_sentences
        self.logger = self._get_logger(logger)
        # spaCy pipeline with the newline-based sentence-start component injected.
        self.nlp = self.__build_model__(spacy_model, logger=self.logger)

    @staticmethod
    def __build_model__(sentence_tokenizer_model: str, logger: logging.Logger) -> spacy.language.Language:
        """
        Download the pre-trained sentence tokenizer model.
        Loads the spaCy model, downloading it via `python -m spacy download`
        on first use, and injects the newline segmenter before the parser.
        :param sentence_tokenizer_model: The sentence tokenizer model to download.
        :param logger: Logger used to report download/validation progress.
        :return: The spacy language model.
        :raises RuntimeError: If the download fails or the model has no parser.
        """
        try:
            spacy_model = spacy.load(sentence_tokenizer_model)
        except OSError:
            # Model not installed locally; download it with the current interpreter.
            result = subprocess.run(
                [sys.executable, "-m", "spacy", "download", sentence_tokenizer_model],
                capture_output=True,
                text=True
            )

            if result.returncode != 0:
                logger.error(f'[BEAST-Tokenizer]: Loading {sentence_tokenizer_model} failed.')
                raise RuntimeError(f"[BEAST-Tokenizer]: Error while downloading '{sentence_tokenizer_model}'")

            spacy_model = spacy.load(sentence_tokenizer_model)
            logger.info('[BEAST-Tokenizer]: Successfully downloaded the pre-trained sentence tokenizer model.')

        # Sentence boundaries come from the dependency parser, so it must be present:
        if 'parser' not in spacy_model.pipe_names:
            logger.error(f'[BEAST-Tokenizer]: The SpaCy model needs a parser installed.')
            raise RuntimeError(f'[BEAST-Tokenizer]: The SpaCy model needs a parser installed.')
        else:
            # Run the newline segmenter before the parser so newline-forced
            # sentence starts are respected by downstream components:
            spacy_model.add_pipe("newline_segmenter_keep_exact", before="parser")

        return spacy_model

    @staticmethod
    def _get_logger(logger):
        # Return the given logger, or a module-level logger with a NullHandler
        # so logging calls are no-ops when no logger was supplied.
        if logger is None:
            logger = logging.getLogger(__name__)
            logger.addHandler(logging.NullHandler())
        return logger

    def __call__(self, texts: list[str]) -> dict:
        """
        Segments a list of articles into a flat, fixed-length sentence sequence.

        Returns a dict with:
        - 'sentences': list[str] of length max_sentences (padded with "").
        - 'sentence_candidates': int8 array; 1 where a boundary is possible
          (article openers and sentences ending with a newline).
        - 'sentence_boundaries': int8 array; 1 only at article openers.
        - 'sentence_mask': int8 array; 1 for real sentences, 0 for padding.
        """
        sentences = list()
        sentence_candidates = list()
        sentence_boundaries = list()
        sentence_masking = list()

        for article in texts:
            doc = self.nlp(article)
            for idx, sent in enumerate(doc.sents):

                if idx == 0:
                    # Article opener
                    sentence_candidates.append(1)
                    sentence_boundaries.append(1)
                elif sent.text.endswith("\n"):
                    # Paragraph break candidate
                    sentence_candidates.append(1)
                    sentence_boundaries.append(0)
                else:
                    sentence_candidates.append(0)
                    sentence_boundaries.append(0)

                # Store the sentence text without newlines or surrounding whitespace:
                sentences.append(sent.text.replace('\n', '').strip())
                sentence_masking.append(1)

                if len(sentences) >= self.max_sentences:
                    self.logger.warning(f"Maximum number of sentences reached: {self.max_sentences}")
                    break

            # Also stop consuming further articles once the cap is hit:
            if len(sentences) >= self.max_sentences:
                break

        # Pad with zeros:
        while len(sentences) < self.max_sentences:
            sentences.append("")
            sentence_candidates.append(0)
            sentence_boundaries.append(0)
            sentence_masking.append(0)

        return {
            "sentences": sentences,
            "sentence_candidates": np.array(sentence_candidates, dtype=np.int8),
            "sentence_boundaries": np.array(sentence_boundaries, dtype=np.int8),
            "sentence_mask": np.array(sentence_masking, dtype=np.int8)
        }
230
+
231
+
232
@spacy.Language.component("newline_segmenter_keep_exact")
def newline_segmenter_keep_exact(doc):
    """
    Custom spaCy pipeline component: mark the token after every "\\n" token
    as a sentence start, so newlines force sentence boundaries while the
    text itself is left untouched.
    """
    # Skip the last token: it has no successor to mark.
    for token in doc[:-1]:
        if token.text == "\n":
            doc[token.i + 1].is_sent_start = True
    return doc
238
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
239
+ # END OF FILE #
240
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ from .setup import Setup
8
+ from .steps import train_step, validation_step
9
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
10
+ # END OF FILE #
11
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ from .full_setup import Setup
9
+ from .hooks import HookMonitor
10
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
11
+ # END OF FILE #
12
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/clear.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import os
9
+ import shutil
10
+ import logging
11
+
12
+
13
+ # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
14
def clear_logs(log_path: str):
    """
    Deletes the log directory at `log_path` and everything inside it.

    Root-logger handlers are detached first so no open file handles keep
    the directory locked while it is being removed.

    Args:
        log_path (str): Directory whose contents should be wiped.

    Raises:
        ValueError: If `log_path` does not exist.
    """
    # Detach every handler from the root logger before touching the files:
    logging.getLogger().handlers.clear()
    # Guard clause: refuse to "clear" a path that was never there.
    if not os.path.exists(log_path):
        raise ValueError(f'Path {log_path} does not exist.')
    shutil.rmtree(log_path)
31
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
32
+ # END OF FILE #
33
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/device.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+ import logging
10
+
11
+
12
def get_device(number: int, logger: logging.Logger = None):
    """
    Configures PyTorch to use a specified GPU by its index number,
    or falls back to CPU if CUDA is not available.

    Args:
        number (int): The index number of the GPU to use.
        logger (logging.Logger, optional): Logger for logging GPU info.

    Returns:
        torch.device: The selected torch device (GPU or CPU).

    Raises:
        ValueError: If `number` is not a valid GPU index.
    """
    # Fallback to CPU if CUDA is not available
    if not torch.cuda.is_available():
        if logger:
            logger.warning("CUDA is not available. Falling back to CPU.")
        return torch.device('cpu')

    # Check if the specified GPU number is valid
    if number >= torch.cuda.device_count() or number < 0:
        raise ValueError(
            f"GPU number {number} is not valid. Available GPU indices range from 0 to {torch.cuda.device_count() - 1}.")

    # Clean up cached memory and reset the memory statistics counters
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()

    # Set the selected device as current
    torch.cuda.set_device(number)

    # BUG FIX: the memory-stats logging below was previously unconditional and
    # crashed with AttributeError when logger was None; it is now guarded.
    if logger:
        device_name = torch.cuda.get_device_name(number)
        logger.info(f"PyTorch is now configured to use GPU {number}: {device_name}")

        total_mem = torch.cuda.get_device_properties(number).total_memory / 1024 ** 2
        mem_allocated = torch.cuda.memory_allocated(number) / 1024 ** 2
        mem_reserved = torch.cuda.memory_reserved(number) / 1024 ** 2
        max_allocated = torch.cuda.max_memory_allocated(number) / 1024 ** 2
        max_reserved = torch.cuda.max_memory_reserved(number) / 1024 ** 2

        logger.info(f"[GPU {number} - {device_name}] Memory Stats:")
        logger.info(f"    Total Memory        : {total_mem:.2f} MB")
        logger.info(f"    Currently Allocated : {mem_allocated:.2f} MB")
        logger.info(f"    Currently Reserved  : {mem_reserved:.2f} MB")
        logger.info(f"    Max Allocated       : {max_allocated:.2f} MB")
        logger.info(f"    Max Reserved        : {max_reserved:.2f} MB")

    return torch.device(f'cuda:{number}')
60
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
61
+ # END OF FILE #
62
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/full_setup.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import logging
9
+ import torch
10
+ import os
11
+ import glob
12
+ import json
13
+ import matplotlib.pyplot as plt
14
+ from .logger import get_logger
15
+ from .tensorboard import get_writer
16
+ from .seeds import get_seed
17
+ from .device import get_device
18
+ from .clear import clear_logs
19
+ from .marker import register_replay, register
20
+ from .watchers import DEFAULT_WATCHER, S_WATCHER, A_WATCHER, B_WATCHER, C_WATCHER, CNN_WATCHER, AEN_WATCHER, TRA_WATCHER
21
+ from dataclasses import asdict
22
+
23
+ # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
24
+ # #
25
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
26
class Setup:
    def __init__(
            self,
            path: str,
            device: int = 0,
            seed: int = None,
            save_each: int = 1,
            reload_state: bool = False,
            tensorboard: int | bool = 6006,
            autoscaler: bool = True,
            replay_element: tuple = (-1, None)
    ):
        """
        This class is used to set up the environment for an AI experiment. It saves
        the model checkpoints, logs, and tensorboard files. It also sets the device
        and seed for reproducibility.

        Usage:

        >>> from dlutils.setup import Setup
        >>> setup = Setup(path='logs', device=0, seed=42, save_each=10)

        Inside the train loop:

        >>> model: torch.Model
        >>> loss_value: torch.Tensor
        >>> y: torch.Tensor
        >>> y_hat: torch.Tensor

        >>> setup.check(model)
        >>> setup.register('loss', loss_value)
        >>> setup.register_replay(y, y_hat)

        In case you want to reload latest checkpoint:

        >>> setup.reload(model)


        :param path: The path to the logs.
        :param device: The device to use.
        :param seed: The seed to use.
        :param save_each: The number of epochs to save the model.
        :param reload_state: Whether to reload the latest checkpoint.
        :param tensorboard: Whether to use tensorboard (port number or False).
        :param autoscaler: Whether to use an AMP gradient autoscaler for training.
        :param replay_element: The element to replay as (index, element).
        """
        self.path = path
        self.save_each = save_each
        self.tensorboard_required = tensorboard
        self.replay_id = replay_element
        self.__epoch_count = 0

        # Start from a clean log directory unless the caller asked for a resume.
        if not reload_state:
            self.clear(path)

        self.logger = self.set_logger(path)
        # When tensorboard is disabled there is no writer; the checkpoint
        # directory then defaults to <path>/checkpoints.
        self.writer, self.ch_path = self.set_writer(path, tensorboard) if tensorboard else (None, os.path.join(path, 'checkpoints'))
        self.seed = self.set_seed(seed)
        self.device = self.set_device(device)
        self.log_setup_info()

        # Per-instance copy of the default watcher flags: assigning the shared
        # module-level DEFAULT_WATCHER directly would let set_watcher() leak
        # flag changes across Setup instances.
        self.watcher = {group: dict(flags) for group, flags in DEFAULT_WATCHER.items()}
        self.autoscaler = torch.amp.GradScaler(enabled=self.device.type == 'cuda') if autoscaler else None

    def log_setup_info(self):
        """
        Log the setup information.
        """
        self.logger.info("Setup information:")
        self.logger.info(f"- Setup path: {self.path}")
        self.logger.info(f"- Setup checkpoints path: {self.ch_path}")
        self.logger.info(f"- Setup device: {self.device}")
        self.logger.info(f"- Setup seed: {self.seed}")
        self.logger.info(f"- Setup logger: {self.logger}")
        self.logger.info(f"- Setup writer: {self.writer}")
        self.logger.info(f"- Setup save each: {self.save_each}")

    def check(
            self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer | None = None,
            learning_rate: torch.optim.lr_scheduler.LRScheduler | None = None
    ) -> bool:
        """
        Check the model and save it if the epoch count is a multiple of save_each.
        :param model: The model to checkpoint and save.
        :param optimizer: The optimizer to save.
        :param learning_rate: The learning rate scheduler to save.
        :return: If the model is checkpointed.
        """
        self.__epoch_count += 1
        if self.save_each is not None and self.__epoch_count % self.save_each == 0:
            self.logger.info(f"Checkpointing model at epoch {self.__epoch_count}")
            self.save_model(
                model=model,
                optimizer=optimizer,
                learning_rate=learning_rate
            )
            self.logger.info(f"Model checkpointed at epoch {self.__epoch_count}")
            return True
        return False

    def save_model(
            self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer | None = None,
            learning_rate: torch.optim.lr_scheduler.LRScheduler | None = None
    ):
        """
        Saves the model.
        :param model: The model to save.
        :param optimizer: The optimizer to save.
        :param learning_rate: The learning rate scheduler to save.
        :return: Nothing.
        """
        torch_state = {
            'epoch': self.__epoch_count,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer else None,
            'scheduler_state_dict': learning_rate.state_dict() if learning_rate else None,
            'seed': self.seed
        }
        # The checkpoint directory is only created by the tensorboard writer;
        # ensure it exists so saving also works with tensorboard disabled.
        os.makedirs(self.ch_path, exist_ok=True)
        torch.save(torch_state, os.path.join(self.ch_path, f'model_epoch_{self.__epoch_count}.pt'))

    def reload(
            self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer | None = None,
            learning_rate: torch.optim.lr_scheduler.LRScheduler | None = None
    ) -> None:
        """
        Reloads the latest checkpoint into the given model.

        :param model: The PyTorch model to reload the state into.
        :param optimizer: The optimizer to reload the state into.
        :param learning_rate: The learning rate scheduler to reload the state into.
        """
        # Find all matching checkpoints
        checkpoints = glob.glob(os.path.join(self.ch_path, 'model_epoch_*.pt'))
        if not checkpoints:
            self.logger.warning("No checkpoint files found.")
        else:
            # Sort by modification time and get the latest
            checkpoints.sort(key=os.path.getmtime)
            latest_checkpoint = checkpoints[-1]

            try:
                # NOTE(review): relies on torch.load's default (full pickle)
                # behavior; under torch >= 2.6 this default flips to
                # weights_only=True — confirm before upgrading.
                state_dict = torch.load(latest_checkpoint, map_location=self.device)
                # Load model and info:
                model.load_state_dict(state_dict['model_state_dict'])
                model.to(self.device)
                self.__epoch_count = state_dict['epoch']
                self.seed = state_dict['seed']
                self.logger.info(f"Model reloaded from {latest_checkpoint} at epoch {self.__epoch_count} and "
                                 f"seed {self.seed}")

                # Load optimizer and learning rate scheduler if provided
                if optimizer and state_dict['optimizer_state_dict'] is not None:
                    optimizer.load_state_dict(state_dict['optimizer_state_dict'])
                    self.logger.info(f"Optimizer state_dict loaded from {latest_checkpoint}")
                if learning_rate and state_dict['scheduler_state_dict'] is not None:
                    learning_rate.load_state_dict(state_dict['scheduler_state_dict'])
                    self.logger.info(f"Scheduler state_dict loaded from {latest_checkpoint}")

            except Exception as e:
                self.logger.error(f"Failed to reload model from {latest_checkpoint}: {e}")
                raise RuntimeError(f"Failed to reload model from {latest_checkpoint}: {e}")

    def set_watcher(self, flag_names: str | list[tuple], deactivate: bool = False) -> None:
        """
        Sets up the parameter watcher to the tensorboard.
        :param flag_names: The names of the flags to watch as a list of (group, flag)
                           tuples, or a preset name ('S', 'A', 'B', 'C', 'cnn',
                           'transformer', 'ae').
        :param deactivate: Whether to deactivate the watcher.
        :return: Nothing
        """
        if isinstance(flag_names, str):
            if flag_names == 'S':
                flag_names = S_WATCHER
            elif flag_names == 'A':
                flag_names = A_WATCHER + S_WATCHER
            elif flag_names == 'B':
                flag_names = S_WATCHER + A_WATCHER + B_WATCHER
            elif flag_names == 'C':
                flag_names = S_WATCHER + A_WATCHER + B_WATCHER + C_WATCHER
            elif flag_names == 'cnn':
                flag_names = CNN_WATCHER
            elif flag_names == 'transformer':
                flag_names = TRA_WATCHER
            elif flag_names == 'ae':
                flag_names = AEN_WATCHER
            else:
                self.logger.error(f"[WATCHER] Unknown flag name '{flag_names}'")
                raise ValueError(f"[WATCHER] Unknown flag tier '{flag_names}'")

        for top_name, low_name in flag_names:
            if top_name not in self.watcher:
                self.logger.error(f"Watcher {top_name} not found in watcher.")
                raise ValueError(f"Watcher {top_name} not found in watcher.")
            elif low_name not in self.watcher[top_name]:
                self.logger.error(f"Watcher {low_name} not found in {top_name}.")
                raise ValueError(f"Watcher {low_name} not found in {top_name}.")
            else:
                self.watcher[top_name][low_name] = not deactivate

    def register_replay(self, predicted: torch.Tensor, target: torch.Tensor, mask: torch.Tensor = None) -> plt.Figure:
        """
        Visualizes predicted vs. target outputs with an optional mask.
        Only positions where mask == True are shown. Each cell displays its value with two decimal places.

        :param predicted: Tensor of shape (S) or (S, Y) representing the model's output.
        :param target: Tensor of same shape as predicted.
        :param mask: Optional boolean tensor of same shape. False positions are ignored (valid mask).
        """
        return register_replay(
            predicted=predicted,
            target=target,
            valid_mask=mask,
            element=self.replay_id[1],
            epoch=self.__epoch_count,
            writer=self.writer,
            logger=self.logger,
            tensorboard_required=self.tensorboard_required,
        )

    def register(self, name: str, parameter: float | torch.Tensor, mask: torch.Tensor = Ellipsis) -> None:
        """
        Registers a named parameter into the tensorboard.
        :param name: The name of the parameter.
        :param parameter: The parameter to register.
        :param mask: The optional boolean tensor of same shape as parameter.
        :return: Nothing.
        """
        # Tensors without an explicit mask get an all-True mask; plain floats
        # carry no mask at all (Ellipsis sentinel).
        if isinstance(parameter, torch.Tensor) and mask is Ellipsis:
            mask = torch.ones_like(parameter).bool()
        elif isinstance(parameter, float):
            mask = Ellipsis

        register(
            flags=self.watcher,
            tensor=parameter,
            valid_mask=mask,
            epoch=self.__epoch_count,
            writer=self.writer,
            logger=self.logger,
            tensorboard_required=self.tensorboard_required,
            parameter_name=name
        )

    def save_config(self, configuration):
        """
        Saves the configuration to a file.
        :param configuration: A dataclasses configuration object.
        :return: Nothing.
        """
        config_path = os.path.join(self.path, "config.json")
        with open(config_path, "w") as f:
            json.dump(asdict(configuration), f, indent=4)

    # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
    #                                                           #
    # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
    @staticmethod
    def clear(path: str) -> None:
        """
        Clear the logs.
        :param path: The path to the logs.
        """
        clear_logs(path)

    @staticmethod
    def set_logger(path: str) -> logging.Logger:
        """
        Set the logger.
        :param path: The path to the logs.
        :return: The logger.
        """
        return get_logger(path)

    def set_writer(self, path: str, tensorboard_port: int | bool) -> tuple:
        """
        Get the writer.
        :param path: The path to the logs.
        :param tensorboard_port: The port to use for tensorboard.
        :return: The writer and checkpoint path.
        """
        return get_writer(path, tensorboard_port, self.logger)

    def set_device(self, device: int) -> torch.device:
        """
        Get the device.
        :param device: The device to use.
        :return: The device.
        """
        return get_device(device, self.logger)

    def set_seed(self, seed: int) -> int:
        """
        Get the seed.
        :param seed: The seed to use.
        :return: The seed.
        """
        return get_seed(seed, self.logger)

    @property
    def epoch(self):
        """
        Get the current epoch.
        :return: The current epoch.
        """
        return self.__epoch_count

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        if self.writer:
            self.writer.close()

        # Do not kill Tensor boards - We usually want the process up to analyze the train variables:
        # for proc in psutil.process_iter(['pid', 'name']):
        #     if 'tensorboard' in proc.info['name'].lower():
        #         proc.terminate()
350
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
351
+ # END OF FILE #
352
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/functions.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+ EPS = 1e-12
10
+
11
+ # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
12
+ # REGISTER #
13
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
14
def watch_max(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Return the largest absolute value among the masked entries."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].abs().max())
25
+
26
def watch_min(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Return the smallest absolute value among the masked entries."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].abs().min())
37
+
38
def watch_mean(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Return the (signed) mean of the masked entries."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].mean())
49
+
50
def watch_var(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Return the (unbiased, torch default) variance of the masked entries."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].var())
61
+
62
def watch_std(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Return the (unbiased, torch default) standard deviation of the masked entries."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].std())
73
+
74
def watch_sparsity(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
        sparsity_threshold: float = 1e-6,
) -> float:
    """Return the fraction of masked entries whose magnitude is at most `sparsity_threshold`."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    near_zero = source[mask].abs() <= sparsity_threshold
    return float(near_zero.float().mean())
86
+
87
def watch_l1(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Return the L1 norm (sum of magnitudes) of the masked entries."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].norm(p=1))
98
+
99
def watch_l2(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Return the L2 (Euclidean) norm of the masked entries."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].norm(p=2))
110
+
111
def watch_snr(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> None | float:
    """
    Signal-to-noise ratio in dB over the masked entries: 20 * log10(|mean| / std).
    Returns None when the std is non-positive or the ratio degenerates to -inf
    (masked mean exactly zero).
    """
    deviation = watch_std(tensor, mask, grad=grad)
    if deviation <= 0:
        return None
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    ratio = float(torch.log10((source[mask].mean()).abs() / (deviation + EPS)))
    # log10 of a zero mean yields -inf; report that degenerate case as None.
    return None if ratio == float("-inf") else 20 * ratio
126
+
127
def watch_hist(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> torch.Tensor:
    """Return the masked entries themselves (raw values for a histogram)."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return source[mask]
138
+
139
def watch_rank(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
        threshold: float = 0.92,
) -> None | float | int:
    """
    Effective rank of a matrix-like tensor: the number of singular values
    needed to accumulate `threshold` of the total spectral energy.

    :param tensor: The tensor (or parameter) to analyze.
    :param mask: Boolean mask applied multiplicatively, or Ellipsis for "no mask".
    :param grad: If True, analyze `tensor.grad` instead of the values.
    :param threshold: Fraction of cumulative energy that defines the rank.
    :return: The effective rank, or None for tensors with fewer than 2 dims.
    """
    if grad:
        work_tensor = tensor.grad
    elif hasattr(tensor, 'data'):
        work_tensor = tensor.data
    else:
        work_tensor = tensor
    # Hooks call the watchers with mask=Ellipsis (meaning "all entries");
    # `Ellipsis.float()` would crash, so only apply a real mask.
    if mask is not Ellipsis:
        work_tensor = torch.multiply(work_tensor, mask.float())

    if work_tensor.ndim < 2:
        return None
    else:
        # Compute SVD and sort it (largest singular values first):
        work_tensor = torch.linalg.svdvals(work_tensor)
        work_tensor = torch.sort(work_tensor, descending=True).values
        # Cumulative energy (normalized, EPS avoids division by zero):
        work_tensor = torch.cumsum(work_tensor**2, dim=0) / (torch.sum(work_tensor**2) + EPS)
        # Effective rank: first index crossing the energy threshold.
        return float(torch.sum(work_tensor < threshold).item() + 1)
163
+
164
def watch_any(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Return the masked value as a plain float (intended for scalars such as loss/lr)."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask])
175
+
176
def watch_power(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Return the mean power of the masked entries in dB: 10 * log10(mean(x^2) + EPS)."""
    # Select the backing values: gradient, detached .data view, or the tensor itself.
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(10 * torch.log10((source[mask] ** 2).mean() + EPS))
187
+
188
+ # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
189
+ # FUNC. MAP #
190
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
191
# Dispatch table mapping a watcher flag name to the callable computing that
# statistic. Every callable takes (tensor, mask); the 'grad_*' variants apply
# the same statistic to the tensor's .grad attribute instead of its values.
REG_FUNCTION_MAP = {
    # Function mapping (statistics on the tensor values):
    'max': watch_max,
    'min': watch_min,
    'mean': watch_mean,
    'std': watch_std,
    'var': watch_var,
    'l2': watch_l2,
    'l1': watch_l1,
    'sparsity': watch_sparsity,
    'snr': watch_snr,
    'hist': watch_hist,
    'rank': watch_rank,
    'power': watch_power,

    # Gradient mapping (same statistics, routed through grad=True):
    'grad_max': lambda x, y: watch_max(x, y, grad=True),
    'grad_min': lambda x, y: watch_min(x, y, grad=True),
    'grad_mean': lambda x, y: watch_mean(x, y, grad=True),
    'grad_std': lambda x, y: watch_std(x, y, grad=True),
    'grad_var': lambda x, y: watch_var(x, y, grad=True),
    'grad_l1': lambda x, y: watch_l1(x, y, grad=True),
    'grad_l2': lambda x, y: watch_l2(x, y, grad=True),
    'grad_sparsity': lambda x, y: watch_sparsity(x, y, grad=True),
    'grad_snr': lambda x, y: watch_snr(x, y, grad=True),
    'grad_hist': lambda x, y: watch_hist(x, y, grad=True),
    'grad_rank': lambda x, y: watch_rank(x, y, grad=True),
    'grad_power': lambda x, y: watch_power(x, y, grad=True),

    # Loss / scalar pass-throughs (watch_any just unwraps the value):
    'loss': watch_any,
    'val_loss': watch_any,
    'lr': watch_any
}
225
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
226
+ # END OF FILE #
227
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/hooks.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
2
+ # START OF FILE #
3
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
4
+ import logging
5
+ import torch
6
+ from .functions import REG_FUNCTION_MAP
7
+
8
+
9
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
10
+ # #
11
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
12
class HookMonitor:
    """
    Monitors forward activations and backward gradients of a PyTorch model by
    registering hooks on all its submodules. Per-layer statistics defined in
    `REG_FUNCTION_MAP` are accumulated during forward and backward passes and
    returned normalized by `get_stats()`.

    Core behavior
    -------------
    - Forward pass: the forward hook receives (module, input, output); tensor
      outputs are detached, cast to float, and every non-'grad_' metric whose
      watcher flag is enabled is computed and accumulated. A gradient hook is
      then registered on the output tensor (if it requires grad).
    - Backward pass: the gradient hook applies every enabled 'grad_<metric>'
      to the activation's gradient and accumulates the result.
    - For each metric a parallel '<metric>/valid/' counter is kept so that
      `get_stats()` can return the mean over all accumulations.

    Parameters
    ----------
    model : torch.nn.Module
        Every submodule from `model.named_modules()` receives a forward hook.
    watcher : dict
        Flat mapping of metric name -> bool flag; keys must match
        `REG_FUNCTION_MAP` (e.g. {"mean": True, "grad_mean": True}).
        Metrics not enabled here are never computed.
    logger : logging.Logger
        Used for error/debug reporting.

    Usage
    -----
    >>> monitor = HookMonitor(model, watcher, logger)
    >>> monitor.attach()
    >>> output = model(x)
    >>> loss.backward()
    >>> stats = monitor.get_stats()
    >>> monitor.remove()

    Or as a context manager:

    >>> with HookMonitor(model, watcher, logger) as monitor:
    ...     output = model(x)
    ...     loss.backward()
    >>> stats = monitor.get_stats()

    Notes
    -----
    - The gradient hook is attached to the activation tensor (module output),
      not to model parameters.
    - Forward hooks run under @torch.no_grad(); the monitor only reads
      activations and gradients and does not interfere with training.
    - A metric missing its '/valid/' counter logs an error and is returned
      un-normalized.
    """
    def __init__(self, model: torch.nn.Module, watcher: dict, logger: logging.Logger):
        """
        Store the monitoring configuration. No hooks are installed here: the
        monitor becomes active only after `attach()` (or entering a `with`
        block), and is torn down by `remove()` (or leaving the block).

        :param model: The model whose submodules will be monitored.
        :param watcher: Metric-name -> bool flags (see class docstring).
        :param logger: Logger for errors and debug messages.
        """
        self.logger: logging.Logger = logger
        self.model: torch.nn.Module = model
        self.watcher: dict = watcher
        # Accumulated statistics per layer name; normalized in get_stats().
        self.stats: dict = dict()
        # Forward-hook handles, kept so remove() can detach them all.
        self.handles: list = list()

    def _build_hook(self, name):
        """
        Build the forward hook for the submodule registered under `name`.

        The returned hook accumulates the enabled non-gradient metrics on the
        module's output and, when the output requires grad, registers a
        per-tensor gradient hook that accumulates the enabled 'grad_*' metrics
        during backpropagation.
        """

        @torch.no_grad()
        def hook(*args):
            # Forward hooks are called as (module, input, output).
            _, _, act = args

            # Only tensor outputs are monitored (tuple/dict outputs are skipped).
            if torch.is_tensor(act):
                act_detached = act.detach().float()
                s = self.stats.setdefault(name, {})

                # Call functions:
                for function_name, compute_function in REG_FUNCTION_MAP.items():
                    if self.watcher.get(function_name, False) and not function_name.startswith('grad_'):
                        # Ellipsis mask = "all entries"; see the watch_* helpers.
                        value = compute_function(act_detached, ...)
                        if value is not None:
                            s[function_name] = s.get(function_name, 0.0) + value
                            # Track how many accumulations went in, for normalization.
                            s[function_name + '/valid/'] = s.get(function_name + '/valid/', 0.0) + 1

                # Grad hook:
                def grad_hook(grad):
                    gd = grad.detach().float()
                    # Call functions:
                    for gd_function_name, gd_compute_function in REG_FUNCTION_MAP.items():
                        # Iterate the base metrics but check their 'grad_' flag,
                        # so each base function is applied to the gradient tensor.
                        if self.watcher.get('grad_' + gd_function_name, False) and not gd_function_name.startswith('grad_'):
                            gd_function_name = 'grad_' + gd_function_name
                            gd_value = gd_compute_function(gd, ...)
                            if gd_value is not None:
                                s[gd_function_name] = s.get(gd_function_name, 0.0) + gd_value
                                s[gd_function_name + '/valid/'] = s.get(gd_function_name + '/valid/', 0.0) + 1

                if act.requires_grad:
                    act.register_hook(grad_hook)

        return hook

    def get_stats(self) -> dict:
        """
        Get the statistics of the hooks, normalized by their '/valid/' counters
        (i.e. the mean over all accumulated forward/backward passes).
        :return: A dictionary: {layer_name: {metric_name: value}}.
        """
        stats = dict()
        for layer_name, layer_stats in self.stats.items():
            sub_stats = dict()
            for key, item in layer_stats.items():
                if '/valid/' not in key:
                    if key + '/valid/' in layer_stats:
                        sub_stats[key] = item / layer_stats[key + '/valid/']
                    else:
                        # Should not happen: value and counter are written together.
                        self.logger.error(f"Key {key} has no valid count, skipping normalization.")
                        sub_stats[key] = item
            stats[layer_name] = sub_stats
        return stats

    def attach(self):
        """
        Registers all the hooks in the model (one forward hook per submodule).
        :return: The object.
        """
        for name, module in self.model.named_modules():
            h = module.register_forward_hook(self._build_hook(name))
            self.handles.append(h)
        return self

    def clear(self):
        """
        Clear stats' dictionary.
        :return: Nothing
        """
        self.stats.clear()

    def remove(self):
        """
        Remove all the hooks from the model.
        :return: Nothing.
        """
        for h in self.handles:
            h.remove()
        self.handles.clear()

    def __enter__(self):
        self.logger.debug("[Hooks] Attaching HookMonitor...")
        return self.attach()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.logger.debug("[Hooks] Removing HookMonitor...")
        self.remove()
256
+
257
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
258
+ # END OF FILE #
259
+ # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
src/dlutils/setup/logger.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import logging
9
+ import os
10
+
11
+
12
def get_logger(log_path: str, level: int | str = logging.INFO) -> logging.Logger:
    """
    Build a logger that writes colored output to the console and plain output
    to '<log_path>/logfile.log', creating the directory when it is missing.

    Args:
        log_path (str): Directory where the log file 'logfile.log' is stored.
        level (int | str): Logging level, as a name ("info") or numeric value.

    Raises:
        ValueError: If log_path points to a non-directory, or level is invalid.
    """
    # Ensure the destination exists and is a directory.
    if not os.path.exists(log_path):
        os.makedirs(log_path, exist_ok=True)
    elif not os.path.isdir(log_path):
        raise ValueError(f"Provided path '{log_path}' is not a directory.")

    logfile = os.path.join(log_path, 'logfile.log')

    # Normalize the requested level to its numeric value.
    if isinstance(level, str):
        level = level.upper()
        if not hasattr(logging, level):
            raise ValueError(f'The provided level for the logger <<{level}>> is not a valid level for logging.')
        level = getattr(logging, level)
    elif not isinstance(level, int):
        raise ValueError(f'The provided level for the logger <<{level}>> is not a string or int, '
                         f'the given type is <<{type(level)}>>.')

    # One module-level logger, rebuilt from scratch on every call.
    logger = logging.getLogger(__name__)
    logger.handlers.clear()    # drop handlers left over from a previous call
    logger.setLevel(level)
    logger.propagate = False   # keep records out of the root logger

    # Plain formatting for the file, colored formatting for the console.
    file_handler = logging.FileHandler(logfile)
    file_handler.setLevel(level)
    file_handler.setFormatter(logging.Formatter('%(asctime)s: [%(levelname)s] %(message)s'))

    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(ColoredFormatter())

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    logger.info(f'Logger initialized with writer handler at: {logfile}')

    return logger
66
+
67
+
68
class ColoredFormatter(logging.Formatter):
    """Console formatter that wraps each record in an ANSI color matching its level."""
    grey = "\x1b[38;20m"
    blue = "\x1b[34;20m"
    cyan = "\x1b[36;20m"
    orange = "\x1b[33;20m"
    red = "\x1b[31;20m"
    reset = "\x1b[0m"
    # Renamed from `format`: that name shadowed (and was then replaced by) the
    # format() method below, which made the template unreachable after class creation.
    _fmt_template = '%(asctime)s: [%(levelname)s] %(message)s'

    # Pre-built colored format string per standard level.
    FORMATS = {
        logging.DEBUG: blue + _fmt_template + reset,
        logging.INFO: cyan + _fmt_template + reset,
        logging.WARNING: orange + _fmt_template + reset,
        logging.ERROR: red + _fmt_template + reset,
        logging.CRITICAL: red + _fmt_template + reset
    }

    def format(self, record):
        """Format `record` with its level color; unknown/custom levels fall back to grey."""
        log_fmt = self.FORMATS.get(record.levelno, self.grey + self._fmt_template + self.reset)
        formatter = logging.Formatter(log_fmt, "%Y-%m-%d %H:%M:%S")
        return formatter.format(record)
89
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
90
+ # END OF FILE #
91
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/marker.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+ import logging
10
+ import numpy as np
11
+ import io
12
+ import math
13
+ import random
14
+ from PIL import Image
15
+ from matplotlib import pyplot as plt
16
+ from torch.utils.tensorboard import SummaryWriter
17
+ from torchvision import transforms
18
+ from .functions import REG_FUNCTION_MAP
19
+
20
+
21
+ # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
22
+ # REGISTER #
23
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
24
@torch.no_grad()
def register(
        flags: dict,
        tensor: float | torch.Tensor,
        valid_mask: torch.Tensor,
        epoch: int,
        writer: SummaryWriter,
        logger: logging.Logger,
        tensorboard_required: bool,
        parameter_name: str = ''
):
    """
    Registers a parameter according to the register flags (DEFAULT_WATCHER style).

    :param flags: A specific watch flag dict ('train' / 'parameters' / 'activations' groups).
    :param tensor: The tensor to register (torch.nn.Parameter, torch.Tensor or float).
    :param valid_mask: The valid mask passed through to the statistic functions.
    :param epoch: The current epoch (used as the tensorboard global step).
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param tensorboard_required: Whether the tensorboard writer is required.
    :param parameter_name: The name of the parameter (used to build the tag).
    :return:
    """
    # 1. Detect tensor type — the registration group depends on what we got:
    if isinstance(tensor, torch.nn.Parameter):
        flag_type = 'parameters'
    elif isinstance(tensor, torch.Tensor):
        # Intermediate activation:
        flag_type = 'activations'
    elif isinstance(tensor, float):
        flag_type = 'train'
    else:
        raise ValueError(f"{type(tensor)} is not a torch.nn.Parameter or torch.Tensor.")

    # 2. Build the tensorboard tag names:
    safe_names = list()
    # Check if the group is active:
    if flag_type == 'parameters':
        # One tag per active parameter-statistic flag:
        for flag_key, flag_value in flags['parameters'].items():
            # Add if active:
            if flag_value:
                safe_names.append((f'{flag_type}/{flag_key}/{parameter_name}/', flag_key))
    else:
        # Activations / train scalars get a single tag with no statistic key:
        safe_names.append((f'{flag_type}/{parameter_name}/', ''))

    # 3. Write and compute each required variable:
    for name, flag_key in safe_names:
        # Compute the value:
        transformation = None
        if isinstance(tensor, torch.nn.Parameter):
            # NOTE(review): only 'grad_*' statistics are computed here; active
            # non-gradient parameter flags (e.g. 'max', 'mean') leave
            # ``transformation`` as None and are silently skipped — confirm
            # whether REG_FUNCTION_MAP should also run on the weight data.
            if tensor.grad is not None and 'grad' in flag_key:
                transformation = REG_FUNCTION_MAP[flag_key](tensor, valid_mask)
        else:
            # Assumes a scalar here (float or 0-d tensor) — float() would raise
            # on a multi-element tensor; TODO confirm against the callers.
            transformation = float(tensor) if tensor is not None else None
        # Write the value in tensorboard (skipped when nothing was computed):
        if transformation is not None:
            write_tensorboard(
                name=name,
                value=transformation,
                epoch=epoch,
                writer=writer,
                logger=logger,
                tensorboard_required=tensorboard_required,
            )
90
+
91
+ # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
92
+ # REPLAY #
93
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
94
@torch.no_grad()
def register_replay(
        predicted: torch.Tensor,
        target: torch.Tensor,
        epoch: int,
        writer: SummaryWriter,
        logger: logging.Logger,
        valid_mask: torch.Tensor = Ellipsis,
        element: int = None,
        tensorboard_required: bool = True,
) -> plt.Figure:
    """
    Registers a replay (prediction vs. target heatmaps) as a tensorboard image.

    :param predicted: The predicted value (prediction), batch-first.
    :param target: The expected value (labels), batch-first.
    :param epoch: The current epoch (tensorboard global step).
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param valid_mask: Optional mask of the same shape; False positions are
        dropped before plotting. ``Ellipsis`` (the default) or ``None`` keep
        every position.
    :param element: The batch element to register; None chooses a random one.
    :param tensorboard_required: Whether the tensorboard writer is required.
    :return: A matplotlib figure (still open; the caller owns it).
    """
    # Choose the batch element to plot (random when not specified):
    if element is None:
        element = random.randint(0, len(predicted) - 1)
    else:
        # Clamp an explicit index into the valid batch range.
        element = min(len(predicted) - 1, max(0, element))

    # Convert the chosen element to numpy:
    predicted_np = predicted[element].detach().cpu().numpy()
    target_np = target[element].detach().cpu().numpy()

    # Expand a categorical (scalar) target into a one-hot vector:
    if not target_np.shape:
        target_np_aux = np.zeros_like(predicted_np)
        target_np_aux[target_np] = 1.
        target_np = target_np_aux
        del target_np_aux

    # Build the validity mask. BUGFIX: the ``Ellipsis`` default previously fell
    # into the subscripting branch and raised ``TypeError: 'ellipsis' object is
    # not subscriptable``; it now means "keep everything", like ``None``.
    if valid_mask is None or valid_mask is Ellipsis:
        mask_np = np.ones_like(predicted_np, dtype=bool)
    else:
        mask_np = valid_mask[element].detach().cpu().numpy().astype(bool)

    # Apply mask and flatten:
    predicted_flat = predicted_np[mask_np].flatten()
    target_flat = target_np[mask_np].flatten()

    # Compute the smallest square side B that holds every valid value:
    s = predicted_flat.shape[0]
    b = math.ceil(math.sqrt(s))
    total = b * b
    pad = total - s

    # Pad with zeros up to a full B x B square:
    predicted_padded = np.pad(predicted_flat, (0, pad), constant_values=0.0).reshape(b, b)
    target_padded = np.pad(target_flat, (0, pad), constant_values=0.0).reshape(b, b)

    # Build the side-by-side figure:
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    plot_with_values(axs[0], predicted_padded, "Predicted (y_hat)")
    plot_with_values(axs[1], target_padded, "Target (y)")
    plt.tight_layout()
    write_tensorboard(
        'replay/',
        fig,
        epoch=epoch,
        writer=writer,
        logger=logger,
        tensorboard_required=tensorboard_required,
    )
    return fig
168
+
169
def plot_with_values(ax, data, title):
    """
    Render *data* as a heatmap on *ax*, overlaying each cell's numeric value.

    :param ax: A matplotlib axes (anything exposing imshow/set_title/axis/text).
    :param data: A 2-D numpy array of values to display.
    :param title: The title of the plot.
    :return:
    """
    ax.imshow(data, cmap='viridis', interpolation='nearest')
    ax.set_title(title)
    ax.axis('off')
    rows, cols = data.shape
    for row in range(rows):
        for col in range(cols):
            cell = data[row, col]
            # Dark cells get white text for contrast; bright cells get black.
            ax.text(col, row, f"{cell:.2f}", ha="center", va="center",
                    color="white" if cell < 0.5 else "black", fontsize=8)
184
+
185
+ # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
186
+ # WRITE ON BASE #
187
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
188
def write_tensorboard(
        name: str,
        value: int | float | plt.Figure | np.ndarray | torch.Tensor,
        epoch: int,
        writer: SummaryWriter,
        logger: logging.Logger,
        tensorboard_required: bool = True,
) -> None:
    """
    Write a value to tensorboard, dispatching on the value type.

    Scalars go to ``add_scalar``, tensors/arrays/lists to ``add_histogram``,
    strings to ``add_text``, and figures / encoded image bytes to ``add_image``.

    :param name: The tensorboard tag to write under.
    :param value: The value to write.
    :param epoch: The current epoch (global step).
    :param writer: The tensorboard writer (None when tensorboard is disabled).
    :param logger: The logger.
    :param tensorboard_required: Whether a missing writer should be reported.
    :raises ValueError: If the value type is not supported.
    """
    # Missing writer: only warn when tensorboard was actually requested.
    if writer is None:
        if tensorboard_required:
            logger.warning("Writer is None. Please set the writer first.")
        return
    # Nothing to write:
    if value is None:
        logger.warning("Value is None. Please set the value first.")
        return
    # No tag to write under:
    if name is None:
        logger.warning("Name is None. Please set the name first.")
        return

    # Dispatch on the value type:
    if isinstance(value, int):
        writer.add_scalar(name, float(value), epoch)
    elif isinstance(value, float):
        writer.add_scalar(name, value, epoch)
    elif isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
        writer.add_histogram(name, value, epoch)
    elif isinstance(value, list):
        value = np.array(value)
        writer.add_histogram(name, value, epoch)
    elif isinstance(value, np.ndarray):
        writer.add_histogram(name, value, epoch)
    elif isinstance(value, str):
        writer.add_text(name, value, epoch)
    elif isinstance(value, bytes):
        # Encoded (PNG/JPEG) bytes -> CHW tensor for add_image.
        image = Image.open(io.BytesIO(value))
        transform = transforms.ToTensor()
        value = transform(image)
        writer.add_image(name, value, epoch)
    elif isinstance(value, plt.Figure):
        # Rasterize the figure into a tensor image. BUGFIX: the buffer is now
        # closed, and we close *this* figure explicitly — the previous bare
        # ``plt.close()`` only closed matplotlib's *current* figure, which is
        # not necessarily the one that was just written.
        with io.BytesIO() as buf:
            value.savefig(buf, format='png')
            buf.seek(0)
            image = transforms.ToTensor()(Image.open(buf))
        writer.add_image(name, image, epoch)
        plt.close(value)
    else:
        raise ValueError(f"Type {type(value)} not supported.")
249
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
250
+ # END OF FILE #
251
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/seeds.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import logging
9
+ import torch
10
+ import os
11
+ import numpy as np
12
+ import random
13
+ import time
14
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
15
CUBLAS_ALLOCATION = 4096
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #


def get_seed(seed: int = None, logger: logging.Logger = None) -> int:
    """
    Seed every random number generator (numpy, random, torch) for reproducibility.

    When *seed* is None a fresh seed is derived from the current time. The
    function also configures PyTorch — and cuBLAS through the
    ``CUBLAS_WORKSPACE_CONFIG`` environment variable — for deterministic
    behavior, on CPU and GPU alike.

    Args:
        seed (int, optional): Seed for the RNGs; the current time is used when None.
        logger (logging.Logger): The logger that traces the chosen seed.

    Returns:
        int: The seed used to initialize the random number generators.

    Example:
        >>> experiment_seed = get_seed()
        Seeds from the current time; all subsequent random operations are reproducible.

        >>> experiment_seed = get_seed(42)
        >>> # experiment_seed == 42
        Uses 42 as the seed for all random number generators.
    """
    # cuBLAS needs this environment variable for determinism on CUDA >= 10.2:
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = f":{CUBLAS_ALLOCATION}:8"

    # Derive a time-based seed when none was given:
    if seed is None:
        seed = int(time.time())

    # Seed the Python-level generators:
    np.random.seed(seed)
    random.seed(seed)

    # Seed torch and request deterministic kernels:
    torch.manual_seed(seed)
    torch.backends.cudnn.allow_tf32 = False
    torch.use_deterministic_algorithms(True, warn_only=True)

    # GPU determinism (only applies when CUDA is present):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Trace and return the seed actually in use:
    if logger is not None:
        logger.info(f"Initializer set up seed: {seed}")
    return seed
69
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
70
+ # END OF FILE #
71
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/tensorboard.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import logging
9
+ import os
10
+ import psutil
11
+ import time
12
+ import subprocess
13
+ from torch.utils.tensorboard import SummaryWriter
14
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
15
DEFAULT_TENSORBOARD_PORT = 6006
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #


def get_writer(path: str, tensorboard_port: int | bool, logger: logging.Logger = None):
    """
    Sets up a TensorBoard logging and checkpoint directory for PyTorch.

    This function creates subdirectories for TensorBoard logs and model
    checkpoints, launches a TensorBoard server on the requested port (killing
    any instance already listening there), and returns the writer.

    Args:
        path (str): The root directory where TensorBoard logs and checkpoints will be stored.
        tensorboard_port (int | bool): The port on which to run TensorBoard;
            True uses DEFAULT_TENSORBOARD_PORT, False disables TensorBoard.
        logger (logging.Logger): The logger that traces the logging information.

    Returns:
        tuple: (SummaryWriter or None, checkpoints path). The writer is None
        when ``tensorboard_port`` is False.

    Example:
        >>> tensor_writer, checkpoint_dir = get_writer('/path/to/tensorboard/', True)
    """
    # Resolve the tensorboard port flag:
    if tensorboard_port is True:
        tensorboard_port = DEFAULT_TENSORBOARD_PORT
    elif tensorboard_port is False:
        # Tensorboard disabled: no writer, but still report the checkpoint dir.
        return None, os.path.join(path, 'checkpoints')

    # Create subdirectories for logs and checkpoints:
    logs_path = os.path.join(path, 'logs')
    checkpoints_path = os.path.join(path, 'checkpoints')
    os.makedirs(logs_path, exist_ok=True)
    os.makedirs(checkpoints_path, exist_ok=True)

    # Set up TensorBoard logging:
    writer = SummaryWriter(log_dir=logs_path)

    # Print paths where logs and checkpoints will be stored:
    if logger is not None:
        logger.info(f"TensorBoard logs will be stored in: {logs_path}")
        logger.info(f"Model checkpoints will be stored in: {checkpoints_path}")

    # Kill any TensorBoard instance already listening on the target port:
    for conn in psutil.net_connections(kind='inet'):
        # BUGFIX: some sockets report an empty ``laddr``; guard before reading
        # ``.port`` to avoid an AttributeError while scanning connections.
        if not conn.laddr:
            continue
        if conn.laddr.port == tensorboard_port and conn.status == psutil.CONN_LISTEN:
            if logger is not None:
                logger.warning(f"Killing already running TensorBoard process with PID {conn.pid}")
            # BUGFIX: the process may exit (or be inaccessible) between the
            # scan and the terminate call; swallow the psutil race instead of
            # crashing the whole setup.
            try:
                p = psutil.Process(conn.pid)
                p.terminate()
                p.wait(timeout=3)
            except psutil.Error:
                pass
            time.sleep(5)
    # NOTE: shell=True with an interpolated path — ``path`` must come from
    # trusted configuration, never from untrusted user input.
    process = subprocess.Popen(f'tensorboard --logdir={logs_path} --host=0.0.0.0 --port={tensorboard_port}',
                               shell=True)
    if logger is not None:
        logger.info(f'TensorBoard running at http://0.0.0.0:{tensorboard_port}/ (pid={process.pid})')

    return writer, checkpoints_path
72
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
73
+ # END OF FILE #
74
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dlutils/setup/watchers.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
2
+ # DEFAULT WATCH #
3
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
4
# Default watcher: only train/val loss enabled; every statistic flag off.
DEFAULT_WATCHER = {
    'train': {
        'loss': True,
        'lr': False,
        'val_loss': True
    },
    'parameters': {
        'max': False,
        'min': False,
        'mean': False,
        'std': False,
        'var': False,
        'hist': False,
        'l2': False,
        'l1': False,
        'sparsity': False,
        'snr': False,
        'rank': False,
        'power': False,

        # Gradients:
        'grad_max': False,
        'grad_min': False,
        'grad_mean': False,
        'grad_std': False,
        'grad_var': False,
        'grad_hist': False,
        'grad_l2': False,
        'grad_l1': False,
        'grad_sparsity': False,
        'grad_snr': False,
        'grad_rank': False,
        'grad_power': False
    },
    'activations': {
        'max': False,
        'min': False,
        'mean': False,
        'std': False,
        'var': False,
        'hist': False,
        'l2': False,
        'l1': False,
        'sparsity': False,
        'snr': False,
        'rank': False,
        'power': False,

        # Gradients:
        'grad_max': False,
        'grad_min': False,
        'grad_mean': False,
        'grad_std': False,
        'grad_var': False,
        'grad_hist': False,
        'grad_l2': False,
        'grad_l1': False,
        'grad_sparsity': False,
        'grad_snr': False,
        'grad_rank': False,
        'grad_power': False
    }
}
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
#                    SPECIFIC WATCH                         #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Tag legend for the (group, flag) presets below:
# [PA] Performance analysis.
# [GF] Gradient flow.
# [AD] Activation death.
# [NT] Network topology.

S_WATCHER = [
    ('train', 'loss'),              # [TOP] [PA] Training progress.
    ('train', 'val_loss'),          # [TOP] [PA] Generalization / overfitting.
    ('parameters', 'grad_power'),   # [TOP] [GF] Gradient flow, global explosion/vanishing.
    ('parameters', 'grad_mean'),    # [TOP] [NT] Dead / useless layers (mean grad ~ 0).
    ('parameters', 'grad_max'),     # [TOP] [GF] Grad spikes -> clipping / LR.
    ('activations', 'grad_power'),  # [TOP] [GF] Per-layer grad flow (very informative).
    ('activations', 'sparsity'),    # [TOP] [AD] ReLU death / collapsed attention.
]

A_WATCHER = [
    ('train', 'lr'),                # [USEFUL] [PA] Track the scheduler / warmup.
    ('parameters', 'l2'),           # [USEFUL] [PA] Weight norm, regularization / weight decay.
    ('parameters', 'power'),        # [USEFUL] [PA] Weight scale / possible explosions.
    ('parameters', 'grad_snr'),     # [USEFUL] [GF] Signal-to-noise coherence of the grad.
    ('parameters', 'rank'),         # [USEFUL] [NT] Effective capacity / parameter collapse.
    ('activations', 'mean'),        # [USEFUL] [NT] Activation shift / bad init.
    ('activations', 'std'),         # [USEFUL] [NT] Signal propagation across layers.
    ('activations', 'snr'),         # [USEFUL] [NT] Signal coherence across layers.
    ('activations', 'grad_snr'),    # [USEFUL] [GF] Per-layer grad coherence.
]

B_WATCHER = [
    ('activations', 'hist'),        # [UTILITY] [AD] Visualize odd tails / saturations.
    ('parameters', 'snr'),          # [UTILITY] [NT] Global weight coherence (rank is usually better).
    ('parameters', 'grad_l2'),      # [UTILITY] [GF] Similar to grad_power but less intuitive.
    ('parameters', 'hist'),         # [UTILITY] [PA] Inspect weight distribution (one-off debug).
    ('activations', 'l2'),          # [UTILITY] [NT] Activation magnitude (redundant with std/power).
    ('activations', 'l1'),          # [UTILITY] [NT] Similar to l2; sometimes useful in sparse AEs.
]

C_WATCHER = [
    ('parameters', 'max'),          # [LOW] [PA] Only useful to detect occasional NaNs / inf.
    ('parameters', 'min'),          # [LOW] [PA] Same as max, little signal.
    ('parameters', 'mean'),         # [LOW] [PA] Hard to interpret without more context.
    ('parameters', 'std'),          # [LOW] [PA] Redundant with power / l2.
    ('parameters', 'var'),          # [LOW] [PA] Redundant with std.
    ('parameters', 'grad_var'),     # [LOW] [GF] Redundant with grad_std.
    ('parameters', 'grad_hist'),    # [LOW] [GF] One-off visualization, not for continuous logging.
    ('activations', 'min'),         # [LOW] [NT] Rarely says anything std/mean do not.
    ('activations', 'max'),         # [LOW] [NT] Only useful to check clamps/NaNs.
    ('activations', 'var'),         # [LOW] [NT] Redundant with std.
    ('activations', 'grad_hist'),   # [LOW] [GF] Same as parameter grad_hist, visual only.
    ('activations', 'grad_var'),    # [LOW] [GF] Redundant with grad_std/grad_power.
]

CNN_WATCHER = [
    ('train', 'loss'),              # [TOP] [PA] Training fit.
    ('train', 'val_loss'),          # [TOP] [PA] Generalization (Imagenette/ImageNet).
    ('parameters', 'grad_power'),   # [TOP] [GF] Global grad explosion/vanishing.
    ('parameters', 'grad_max'),     # [TOP] [GF] Per-layer spikes -> clipping.
    ('activations', 'grad_power'),  # [TOP] [GF] Grad per conv block / head.
    ('activations', 'sparsity'),    # [TOP] [AD] Dead ReLU / dead layers.
    ('activations', 'std'),         # [USEFUL] [NT] Signal propagation (init, BN).
    ('parameters', 'l2'),           # [USEFUL] [PA] Weight norm control / decay.
]

TRA_WATCHER = [
    ('train', 'loss'),              # [TOP] [PA] Model fit (LM / seq2seq / cls).
    ('train', 'val_loss'),          # [TOP] [PA] Generalization / overfitting.
    ('train', 'lr'),                # [USEFUL] [PA] Warmup, cosine, etc.
    ('parameters', 'grad_power'),   # [TOP] [GF] Explosion/vanishing with depth.
    ('parameters', 'grad_snr'),     # [USEFUL] [GF] Grad SNR in attention/MLP blocks.
    ('activations', 'grad_power'),  # [TOP] [GF] Grad flow per encoder/decoder layer.
    ('activations', 'mean'),        # [USEFUL] [NT] Drift in LayerNorm / RMSNorm.
    ('activations', 'std'),         # [USEFUL] [NT] Propagation with depth (residuals).
    ('parameters', 'l2'),           # [USEFUL] [PA] Weight size in attention/MLP.
]

AEN_WATCHER = [
    ('train', 'loss'),              # [TOP] [PA] Reconstruction / contrastive / VAE loss.
    ('train', 'val_loss'),          # [TOP] [PA] AE generalization.
    ('parameters', 'grad_power'),   # [TOP] [GF] Encoder/decoder grad flow.
    ('activations', 'sparsity'),    # [TOP] [AD] Sparse encoders / neuron death.
    ('activations', 'rank'),        # [USEFUL] [NT] Representation collapse / low effective dimension.
    ('parameters', 'power'),        # [USEFUL] [PA] Decoder weights exploding or collapsing.
    ('activations', 'grad_power'),  # [TOP] [GF] Per-layer grad in encoder/decoder.
    ('parameters', 'l2'),           # [USEFUL] [PA] Weight norm, especially in deep AEs.
]
src/dlutils/steps.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+ import numpy as np
10
+ import tqdm
11
+ from .setup import Setup, HookMonitor
12
+
13
+
14
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
15
+ # #
16
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
17
def train_step(
        # Always granted:
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        controller: Setup,
        # Not always granted:
        scheduler: torch.optim.lr_scheduler.LRScheduler = None,
) -> float:
    """
    Performs a single training step (one full epoch over ``data``) including
    forward pass, loss calculation, backward pass, and optimization step.

    Parameters:
        model (torch.nn.Module): The model to be trained.
        data (torch.utils.data.DataLoader): DataLoader providing the training data.
            Each element may be ``(x, y)``, ``(x, y, x_mask)``,
            ``(x, y, x_mask, y_mask)`` or longer (extra items are forwarded to the model).
        loss (torch.nn.Module): Loss function to be used.
        optimizer (torch.optim.Optimizer): Optimizer used for gradient updates.
        controller (Setup): The setup object containing configuration and state
            (device, logger, watcher flags, autoscaler, replay id, checkpointing).
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): Learning rate scheduler to adjust the learning rate.
            Stepped once per epoch (not per batch).
    Returns:
        float: The mean loss value for this training step.
    """
    # Train mode:
    model.to(controller.device)
    model.train()

    # Per-batch losses accumulated over the epoch:
    losses = list()

    # HookMonitor records activation statistics for the watched layers:
    with HookMonitor(model, controller.watcher['activations'], controller.logger) as hooks:
        with tqdm.tqdm(data, desc=f'\rTraining epoch {controller.epoch}', leave=True) as pbar:
            pbar: tqdm.tqdm
            hooks: HookMonitor

            for i, element in enumerate(pbar):

                # 1. Gather elements (supports 2..N-tuples, see docstring):
                args = tuple()
                if len(element) == 2:
                    # Prediction:
                    x, y = element
                    x_m, y_m = None, None
                elif len(element) == 3:
                    # Prediction with x_mask:
                    x, y, x_m = element
                    y_m = None
                elif len(element) == 4:
                    # Prediction with x_mask and y_mask:
                    x, y, x_m, y_m = element
                elif len(element) > 4:
                    # More input arguments:
                    x, y = element[0], element[1]
                    x_m, y_m = element[2], element[3]
                    args = element[4:]
                else:
                    raise ValueError("DataLoader elements must have at least two elements.")

                # 2. Load data to device:
                x, y = x.to(controller.device, non_blocking=True), y.to(controller.device, non_blocking=True)
                optimizer.zero_grad()
                if x_m is not None:
                    x_m = x_m.to(controller.device, non_blocking=True)
                if y_m is not None:
                    y_m = y_m.to(controller.device, non_blocking=True)

                # 3. TRAIN - Control autocast (mem-speed): mixed precision only
                # when an autoscaler is configured; autocast itself is enabled
                # only on CUDA devices.
                if controller.autoscaler is not None:
                    with torch.amp.autocast(enabled=(controller.device.type == 'cuda'), device_type=controller.device.type):
                        # Forward:
                        y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                        loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)
                        # Backward (through the gradient scaler):
                        controller.autoscaler.scale(loss_metric).backward()
                        controller.autoscaler.step(optimizer)
                        controller.autoscaler.update()
                else:
                    # Forward:
                    y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                    loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)
                    # Backward:
                    loss_metric.backward()
                    optimizer.step()

                # 4. Append to metrics:
                losses.append(loss_metric.item())

                # 5. Monitor hooks: snapshot one batch per epoch as a replay image.
                # NOTE(review): ``controller.replay_id[0]`` presumably holds the
                # batch index chosen for the replay — confirm against Setup.
                if controller.replay_id[0] == i:
                    controller.register_replay(predicted=y_hat, target=y, mask=y_m)

    # Write in summary writer (per epoch):
    losses = np.array(losses)
    mean_loss = float(np.mean(losses))

    # ================ WATCH ================
    # Register parameters (weight/grad statistics per watcher flags):
    for name, parameter in model.named_parameters():
        controller.register(name, parameter)

    # Register train loss:
    controller.register('loss', mean_loss)

    # Register activation statistics collected by the hooks:
    for layer_name, layer_stats in hooks.get_stats().items():
        for func_name, item in layer_stats.items():
            controller.register(f'{func_name}/{layer_name}', torch.Tensor([item])[0])

    # ================ CONTROL ================
    # Scheduler step (per epoch, after logging the LR that was actually used):
    if scheduler is not None:
        controller.register('lr', scheduler.get_last_lr()[0])
        scheduler.step()

    # Write for logger:
    controller.logger.info(f"Epoch [{controller.epoch}]: loss = {mean_loss:.8f}")

    # Checkpointing:
    controller.check(model, optimizer, scheduler)

    return mean_loss
139
+
140
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
141
+ # #
142
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
143
def validation_step(
        # Always granted:
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        controller: Setup,
        additional_metrics: dict = (),
) -> dict:
    """
    Performs a single validation step (one full pass over ``data``) including
    forward pass and loss calculation, without gradient tracking.

    Parameters:
        model (torch.nn.Module): The model to be validated.
        data (torch.utils.data.DataLoader): DataLoader providing the validation data.
            Each element may be ``(x, y)``, ``(x, y, x_mask)``,
            ``(x, y, x_mask, y_mask)`` or longer (extra items are forwarded to the model).
        loss (torch.nn.Module): Loss function to be used.
        controller (Setup): The setup object containing configuration and state.
        additional_metrics (dict): Mapping of metric name to callable
            ``metric(y_hat, y, y_m)``; the empty-tuple default means none.
    Returns:
        dict: Mean value per additional metric, plus the mean loss under 'loss'.
    """
    # Validation mode:
    model.to(controller.device)
    model.eval()

    # Per-batch losses and additional metric values:
    losses = list()
    metrics: dict[str, list | float] = {name: list() for name in additional_metrics}

    with torch.no_grad():
        with tqdm.tqdm(data, desc=f'\rValidation epoch {controller.epoch}', leave=True) as pbar:
            for element in pbar:
                # Gather elements. BUGFIX: ``args`` is initialized up front —
                # previously the 4-element case never assigned it, so the first
                # (x, y, x_mask, y_mask) batch raised NameError at the forward pass.
                args = tuple()
                if len(element) == 2:
                    # Prediction:
                    x, y = element
                    x_m, y_m = None, None
                elif len(element) == 3:
                    # Prediction with x_mask:
                    x, y, x_m = element
                    y_m = None
                elif len(element) == 4:
                    # Prediction with x_mask and y_mask:
                    x, y, x_m, y_m = element
                elif len(element) > 4:
                    # More input arguments:
                    x, y = element[0], element[1]
                    x_m, y_m = element[2], element[3]
                    args = element[4:]
                else:
                    raise ValueError("DataLoader elements must have at least two elements.")

                # Load data to device:
                x, y = x.to(controller.device, non_blocking=True), y.to(controller.device, non_blocking=True)
                if x_m is not None:
                    x_m = x_m.to(controller.device, non_blocking=True)
                if y_m is not None:
                    y_m = y_m.to(controller.device, non_blocking=True)

                # Control autocast (mem-speed):
                if controller.autoscaler is not None:
                    with torch.amp.autocast(enabled=(controller.device.type == 'cuda'),
                                            device_type=controller.device.type):
                        # Forward:
                        y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                        loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)

                        # Compute additional metrics (inside autocast, matching
                        # the forward dtype):
                        if additional_metrics:
                            for name, additional_metric in additional_metrics.items():
                                metrics[name].append(additional_metric(y_hat, y, y_m).item())
                else:
                    # Forward:
                    y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                    loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)

                    # Compute additional metrics:
                    if additional_metrics:
                        for name, additional_metric in additional_metrics.items():
                            metrics[name].append(additional_metric(y_hat, y, y_m).item())

                # Append to metrics:
                losses.append(loss_metric.item())

    # Convert:
    losses = np.array(losses)
    mean_loss = float(np.mean(losses))

    # Reduce the additional metrics to their means and attach the loss:
    for name, variable in metrics.items():
        metrics[name] = float(np.mean(variable))
    metrics['loss'] = mean_loss

    # Write to register:
    controller.register("val_loss", mean_loss)
    # Write for logger:
    controller.logger.info(f"Epoch [{controller.epoch}]: val_loss = {mean_loss:.8f}")

    return metrics
244
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
245
+ # END OF FILE #
246
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ from .config import ModelConfig, TransformerConfig, CoSeNetConfig
9
+ from .segmentation import SegmentationNetwork
10
+ from .loss import MaskedBCELoss
11
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
12
+ # END OF FILE #
13
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ from typing import List
9
+ from dataclasses import dataclass, field
10
+
11
+
12
@dataclass
class CoSeNetConfig:
    """Configuration for the CoSeNet segmentation head."""
    # Whether the pre-trained CoSeNet linear layer is fine-tuned
    # (forwarded to CoSeNet(trainable=...)).
    trainable: bool = True
    # Initial magnitude of the trainable sigmoid adaptation
    # (forwarded to CoSeNet(init_scale=...)).
    init_scale: float = 5.0
16
+
17
+
18
@dataclass
class TransformerConfig:
    """Per-encoder-block hyper-parameters (one instance per EncoderBlock)."""
    # Number of heads in the multi-head self-attention layer.
    attention_heads: int = 8
    # Hidden size of the feed-forward sublayer = multiplier * model_dim.
    feed_forward_multiplier: float = 4
    # Dropout probability applied on the feed-forward residual branch.
    dropout: float = 0.0
    # True -> Pre-LayerNorm variant; False -> Post-LayerNorm variant.
    pre_normalize: bool = True
24
+
25
+
26
@dataclass
class ModelConfig:
    """Top-level hyper-parameter container for the segmentation model."""
    # Size of the token vocabulary used by the embedding layer.
    vocab_size: int = 2 ** 15
    # Dimensionality of token embeddings / transformer features.
    model_dim: int = 256
    # Maximum tokens per sentence; used as the positional-encoding length.
    max_tokens: int = 382
    # Maximum sentences per document — presumably caps the sentence axis;
    # not referenced by the visible model code, confirm against the dataset.
    max_sentences: int = 384
    # Mask convention: True entries mark valid (non-padded) positions.
    valid_padding: bool = True
    # Configuration of the CoSeNet segmentation head.
    cosenet: CoSeNetConfig = field(default_factory=CoSeNetConfig)
    # One TransformerConfig per stacked encoder block, in order.
    transformers: List[TransformerConfig] = field(default_factory=lambda: [TransformerConfig()])
35
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
36
+ # END OF FILE #
37
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/cosenet/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ from .cosenet import CoSeNet
8
+ from .cosine_distance import CosineDistanceLayer
9
+ from .trainable_sigmoid import TrainableSigmoid
10
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
11
+ # END OF FILE #
12
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/cosenet/cosenet.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+ import os
10
+ import numpy as np
11
+ from .cosenet_layer import CoSeNetLayer
12
+ from .trainable_sigmoid import TrainableSigmoid
13
+
14
+
15
class CoSeNet(torch.nn.Module):
    """
    PyTorch's implementation of the CoSeNet architecture.

    This module loads pre-trained CoSeNet weights and applies a structured
    unfolding–linear–folding pipeline to the input tensor. An optional
    trainable sigmoid adaptation is applied to the input prior to the
    CoSeNet transformation.

    The architecture assumes that the input data represent structured
    matrices (e.g., similarity or distance matrices) and performs
    diagonal-based unfolding with overlapping windows.
    """

    def __init__(self, trainable: bool = False, init_scale: float = 5.0, **kwargs):
        """
        Initialize the CoSeNet model.

        Pre-trained weights and biases are loaded from disk and used to
        construct the internal CoSeNet layer. Optionally, the parameters
        can be set as trainable.

        Args:
            trainable (bool, optional): Whether the CoSeNet linear layer
                parameters should be trainable. Defaults to False.
            init_scale (float, optional): Initial scale for the trainable
                sigmoid adaptation module. Defaults to 5.0.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.

        Raises:
            FileNotFoundError: If the weight or bias files cannot be found.
        """
        super().__init__(**kwargs)

        # Load weights: the pre-trained parameters live next to this file
        # in a 'weights' sub-directory ('w.npy' / 'b.npy').
        this_file_name = os.path.dirname(os.path.abspath(__file__))
        w_path = os.path.join(this_file_name, 'weights', 'w.npy')
        b_path = os.path.join(this_file_name, 'weights', 'b.npy')

        if not os.path.exists(w_path):
            raise FileNotFoundError(f'CoSeNet weight file {w_path} does not exist.')
        if not os.path.exists(b_path):
            raise FileNotFoundError(f'CoSeNet bias file {b_path} does not exist.')

        w, b = np.load(w_path), np.load(b_path)

        # Build layers. The window size L is recovered from the flattened
        # weight length: the linear layer consumes an L x L block, so
        # L = sqrt(w.shape[-1]).
        self.matrix_shape = int(np.sqrt(w.shape[-1]))
        self.layer = CoSeNetLayer(w, b, trainable=trainable)
        self.adaptation = TrainableSigmoid(init_scale=init_scale)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass of the CoSeNet model.

        The input is first adapted using a trainable sigmoid, then padded,
        unfolded along the diagonal, processed by the CoSeNet linear layer,
        and finally folded back into its original structure. An optional
        external mask can be applied to the output.

        Args:
            x (torch.Tensor): Input tensor containing structured matrix data.
            mask (torch.Tensor, optional): Optional mask tensor applied
                element-wise to the output. Defaults to None.

        Returns:
            torch.Tensor: Output tensor with the same spatial structure as
                the input.
        """
        # check dimension:
        if x.dim() < 2:
            raise ValueError(f'CoSeNet input: at least 2 dimensions required. (got {x.dim()})')
        # Check perfect square:
        if x.shape[-1] != x.shape[-2]:
            raise ValueError(f'CoSeNet input: last two dimensions must be equal. ({x.shape[-2]} != {x.shape[-1]})')

        # Pipeline: sigmoid adaptation -> pad -> unfold -> linear -> fold.
        adapted_x = self.adaptation(x)
        pad_x, pad_mask = self.__cosenet_padding(adapted_x)
        unfold_x = self.__unfold(pad_x)
        unfold_y = self.layer(unfold_x)
        y = self.__fold(unfold_y, pad_mask)

        # Optional element-wise output masking:
        if mask is not None:
            y = torch.multiply(y, mask)

        return y

    def __unfold(self, x: torch.Tensor) -> torch.Tensor:
        """
        Unfold the input tensor into overlapping diagonal blocks.

        The unfolding is performed using a sliding window over the last
        two dimensions, followed by diagonal extraction. The stride is
        determined by half of the matrix size, so consecutive blocks
        overlap by half a window.

        Args:
            x (torch.Tensor): Padded input tensor.

        Returns:
            torch.Tensor: Tensor containing unfolded diagonal blocks with
                shape [..., K, L, L], where K is the number of extracted blocks.
        """
        # Half-window stride (at least 1 so L == 1 still advances):
        step = max(1, self.matrix_shape // 2)
        u = x.unfold(-2, self.matrix_shape, step).unfold(-2, self.matrix_shape, step)
        # Keep only blocks on the main diagonal of the window grid.
        # NOTE(review): movedim(-1, 1) places the block axis at position 1,
        # which assumes a leading batch dimension — confirm for 2-D inputs.
        y = u.diagonal(offset=0, dim1=-4, dim2=-3).movedim(-1, 1)
        return y

    @staticmethod
    def __fold(x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        """
        Fold unfolded CoSeNet outputs back into a full matrix.

        Overlapping regions are combined using an averaging strategy to
        account for multiple contributions to the same spatial location.

        Args:
            x (torch.Tensor): Tensor containing unfolded CoSeNet outputs.
            pad_mask (torch.Tensor): Boolean mask indicating valid (non-padded)
                positions.

        Returns:
            torch.Tensor: Folded tensor with padding removed and original
                structure restored.
        """
        if x.shape[-2] > 1:
            # Output length for K half-overlapping windows of size 2t:
            # t * (K + 1).
            y = torch.zeros(
                list(x.shape[:-2]) + [x.shape[-1] * (x.shape[-2] + 1) // 2],
                device=x.device,
            )
            t = x.shape[-1] // 2

            # Each window contributes half-weight to its span; window
            # starts are re-doubled so non-overlapped cells keep full weight.
            for i in range(x.shape[-2]):
                y[..., i * t + 1: t * (i + 2)] += 0.5 * x[..., i, 1:]
                y[..., i * t] *= 2

            # First/last half-windows have only one contributor each.
            y[..., :t] *= 2
            y[..., -t:] *= 2
            y[..., 0] = 1
        else:
            # Single window: nothing to merge.
            y = x[..., 0, :]

        # NOTE(review): boolean indexing returns only the valid entries
        # (flattened); the subsequent .view(pad_mask.shape) can only succeed
        # when no positions were actually padded. Confirm that callers
        # guarantee sizes divisible by the window, or that this is intended.
        return y[pad_mask].view(pad_mask.shape)

    def __cosenet_padding(self, x: torch.Tensor) -> tuple:
        """
        Pad the input tensor to match the required matrix shape.

        Padding is applied along the last two dimensions to ensure that
        their sizes are multiples of the CoSeNet matrix shape. A diagonal
        mask is generated to distinguish padded elements.

        Args:
            x (torch.Tensor): Original input tensor.

        Returns:
            tuple:
                - torch.Tensor: Padded tensor with diagonal correction.
                - torch.Tensor: Boolean mask indicating valid entries.
        """
        # Amount needed to round each of the last two dims up to a
        # multiple of the window size:
        pad_w = (self.matrix_shape - (x.shape[-1] % self.matrix_shape)) % self.matrix_shape
        pad_h = (self.matrix_shape - (x.shape[-2] % self.matrix_shape)) % self.matrix_shape

        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))

        # Padded positions are detected as zeros on the main diagonal
        # (a cosine self-similarity diagonal is 1 for real entries).
        diag = x.diagonal(dim1=-2, dim2=-1)
        mask_bool = (diag == 0)
        mask01 = mask_bool.to(x.dtype)

        # Set the padded diagonal entries to 1 so they look like valid
        # self-similarities to the linear layer.
        x = x + torch.diag_embed(mask01)

        return x, torch.logical_not(mask_bool)
187
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
188
+ # END OF FILE #
189
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/cosenet/cosenet_layer.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+ import numpy as np
10
+
11
+
12
class CoSeNetLayer(torch.nn.Module):
    """
    Fixed or fine-tunable affine projection used inside CoSeNet.

    The weight matrix and bias are supplied externally (typically loaded
    from pre-trained ``.npy`` files) and wrapped as parameters. Depending
    on ``trainable``, they are either frozen or optimized together with
    the rest of the network. The layer expects inputs whose last two
    dimensions form an already padded and segmented block, which is
    flattened before the linear mapping.
    """

    def __init__(self, coef: np.ndarray, intercept: np.ndarray, trainable: bool = False, **kwargs):
        """
        Initialize the CoSeNet layer from externally provided parameters.

        Args:
            coef (np.ndarray): Weight matrix for the linear transformation.
            intercept (np.ndarray): Bias vector added to the linear output.
            trainable (bool, optional): Whether weight and bias take part
                in gradient-based optimization. Defaults to False.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        """
        super().__init__(**kwargs)
        # Cast the numpy parameters to float32 tensors once, up front.
        weight_tensor = torch.tensor(coef, dtype=torch.float32)
        bias_tensor = torch.tensor(intercept, dtype=torch.float32)
        self.weight = torch.nn.Parameter(weight_tensor, requires_grad=trainable)
        self.bias = torch.nn.Parameter(bias_tensor, requires_grad=trainable)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the affine transformation to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape [..., *, *]; the last
                two dimensions are flattened prior to the linear mapping.

        Returns:
            torch.Tensor: Result of `flatten(x) @ weight.T + bias`.
        """
        # Collapse the trailing block into a single feature axis.
        flat = torch.flatten(x, start_dim=-2)
        return torch.nn.functional.linear(flat, self.weight, self.bias)
53
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
54
+ # END OF FILE #
55
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/cosenet/cosine_distance.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+
10
+
11
class CosineDistanceLayer(torch.nn.Module):
    """
    Pairwise cosine similarity layer.

    Computes the absolute cosine similarity between every pair of
    embedding vectors in the same input, producing one square matrix per
    leading batch dimension. Opposite directions are treated as related
    because the absolute value is taken.
    """

    def __init__(self, **kwargs):
        """
        Initialize the cosine distance layer.

        Args:
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        """
        super().__init__(**kwargs)

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        """
        Compute pairwise |cosine similarity| between embeddings.

        Args:
            x (torch.Tensor): Input of shape [..., S, D] with `S`
                embeddings of dimensionality `D`.

        Returns:
            torch.Tensor: Tensor of shape [..., S, S] with the absolute
                pair-wise cosine similarities.
        """
        # L2-normalize so the inner product becomes the cosine.
        unit = torch.nn.functional.normalize(x, p=2, dim=-1)  # [..., S, D]
        # Gram matrix of the unit vectors = cosine similarities.
        similarity = unit @ unit.transpose(-2, -1)  # [..., S, S]
        return similarity.abs()
55
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
56
+ # END OF FILE #
57
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/cosenet/trainable_sigmoid.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+
10
+
11
+ class TrainableSigmoid(torch.nn.Module):
12
+ """
13
+ Trainable sigmoid activation module with learnable scaling.
14
+
15
+ This module implements a sigmoid function whose slope is controlled by
16
+ a trainable parameter. It is designed to adaptively rescale input values
17
+ (e.g., distances or similarity scores) around a fixed midpoint (0.5),
18
+ allowing the model to learn the appropriate sharpness of the transition
19
+ during training.
20
+ """
21
+
22
+ def __init__(self, init_scale: float = 5.0, **kwargs):
23
+ """
24
+ Initialize the trainable sigmoid module.
25
+
26
+ Args:
27
+ init_scale (float, optional): Initial magnitude of the sigmoid
28
+ scaling factor. Internally, the learnable parameter is
29
+ initialized as the negative of this value. Defaults to 5.0.
30
+ **kwargs: Additional keyword arguments forwarded to
31
+ `torch.nn.Module`.
32
+ """
33
+ super().__init__(**kwargs)
34
+ self.alpha = torch.nn.Parameter(torch.tensor(-init_scale, dtype=torch.float32))
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ """
38
+ Apply the trainable sigmoid transformation to the input tensor.
39
+
40
+ The transformation is centered at 0.5 and scaled by a learnable
41
+ parameter, enabling adaptive control over the sigmoid steepness.
42
+
43
+ Args:
44
+ x (torch.Tensor): Input tensor containing values to be transformed.
45
+
46
+ Returns:
47
+ torch.Tensor: Tensor of the same shape as `x`, with the trainable
48
+ sigmoid function applied element-wise.
49
+ """
50
+ return 1 / (1 + torch.exp(self.alpha * (x - 0.5)))
51
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
52
+ # END OF FILE #
53
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/loss.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+
10
+
11
class MaskedBCELoss(torch.nn.Module):
    """
    Binary Cross-Entropy restricted to valid (non-padded) positions.

    A boolean mask selects which elements contribute to the loss, so
    padded positions never influence the gradient. Predictions may be
    given either as logits or as probabilities, and the per-element
    losses are reduced with a mean over valid elements or a plain sum.
    The mask convention is configurable to match either valid-marks-True
    or PyTorch-style padding-marks-True semantics.
    """

    def __init__(
        self,
        reduction: str = 'mean',
        valid_pad: bool = True,
        eps: float = 1e-7,
        logits: bool = True
    ):
        """
        Build the masked BCE criterion.

        Args:
            reduction (str, optional): Either `'mean'` (average over valid
                elements) or `'sum'`. Defaults to `'mean'`.
            valid_pad (bool, optional): If True, `True` mask entries mark
                valid positions; if False, they mark padding, following
                PyTorch conventions. Defaults to True.
            eps (float, optional): Clamping constant applied to probability
                inputs when `logits=False`. Defaults to 1e-7.
            logits (bool, optional): Whether predictions are raw logits
                (uses `binary_cross_entropy_with_logits`) or probabilities
                (uses `binary_cross_entropy`). Defaults to True.

        Raises:
            ValueError: If `reduction` is neither `'mean'` nor `'sum'`.
        """
        super().__init__()

        if reduction not in ('mean', 'sum'):
            raise ValueError("[MASKED-BCE] Reduction must be 'mean' or 'sum'")

        self.reduction = reduction
        self.valid_pad = valid_pad
        self.logits = logits
        self.eps = eps

        # Select the element-wise BCE backend once, at construction time.
        self.loss = (
            torch.nn.functional.binary_cross_entropy_with_logits
            if logits
            else torch.nn.functional.binary_cross_entropy
        )

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the masked binary cross-entropy loss.

        Args:
            x (torch.Tensor): Predictions of shape (B, S); interpreted as
                logits or probabilities depending on the `logits` flag.
            y (torch.Tensor): Ground-truth binary labels of shape (B, S).
            mask (torch.Tensor): Boolean mask of shape (B, S), interpreted
                according to `valid_pad` (see `__init__`).

        Returns:
            torch.Tensor: Scalar tensor with the reduced loss value.
        """
        # Resolve which positions count toward the loss:
        valid = mask if self.valid_pad else torch.logical_not(mask)

        # Clamp probabilities away from 0/1 for numerical stability:
        if not self.logits:
            x = x.clamp(self.eps, 1.0 - self.eps)

        # Element-wise BCE, then zero out the non-valid positions:
        elementwise = self.loss(
            x.float(),
            y.float(),
            reduction='none'
        )
        elementwise = elementwise * valid.float()
        total = elementwise.sum()

        if self.reduction == 'sum':
            return total
        if self.reduction == 'mean':
            # Clamp the denominator so an all-padded batch yields 0, not NaN.
            return total / valid.sum().clamp(min=1)
        # Unreachable: __init__ already validated `reduction`.
        raise ValueError("[MASKED-BCE] Reduction must be 'mean' or 'sum'")
113
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
114
+ # END OF FILE #
115
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/segmentation.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+ from .config import ModelConfig
10
+ from .cosenet import CosineDistanceLayer, CoSeNet
11
+ from .transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling
12
+
13
+
14
class SegmentationNetwork(torch.nn.Module):
    """
    Segmentation network combining Transformer encoders with CoSeNet.

    This model integrates token embeddings and positional encodings with
    a stack of Transformer encoder blocks to produce contextualized
    representations. These representations are then processed by a
    CoSeNet module to perform structured segmentation, followed by a
    cosine-based distance computation.

    The final output is a pair-wise distance matrix suitable for
    segmentation or boundary detection tasks.
    """
    def __init__(self, model_config: ModelConfig, **kwargs):
        """
        Initialize the segmentation network.

        The network is composed of an embedding layer, positional encoding,
        multiple Transformer encoder blocks, a CoSeNet segmentation module,
        and a cosine distance layer.

        Args:
            model_config (ModelConfig): Configuration object containing all
                hyperparameters required to build the model, including
                vocabulary size, model dimensionality, transformer settings,
                and CoSeNet parameters.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        """
        super().__init__(**kwargs)
        # Mask convention shared by all sub-modules: True marks valid
        # positions when this flag is set.
        self.valid_padding = model_config.valid_padding

        # Build layers:
        self.embedding = torch.nn.Embedding(
            model_config.vocab_size,
            model_config.model_dim
        )
        self.positional_encoding = PositionalEncoding(
            emb_dim=model_config.model_dim,
            max_len=model_config.max_tokens
        )
        self.cosenet = CoSeNet(
            trainable=model_config.cosenet.trainable,
            init_scale=model_config.cosenet.init_scale
        )
        self.distance_layer = CosineDistanceLayer()
        self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)

        # Build encoder blocks: one EncoderBlock per TransformerConfig,
        # stacked in configuration order.
        module_list = list()
        for transformer_config in model_config.transformers:
            encoder_block = EncoderBlock(
                feature_dim=model_config.model_dim,
                attention_heads=transformer_config.attention_heads,
                feed_forward_multiplier=transformer_config.feed_forward_multiplier,
                dropout=transformer_config.dropout,
                valid_padding=model_config.valid_padding,
                pre_normalize=transformer_config.pre_normalize
            )
            module_list.append(encoder_block)

        self.encoder_blocks = torch.nn.ModuleList(module_list)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass of the segmentation network.

        The input token indices are embedded and enriched with positional
        information, then processed by a stack of Transformer encoder
        blocks. The resulting representations are segmented using CoSeNet
        and finally transformed into a pair-wise distance representation.

        Args:
            x (torch.Tensor): Input tensor of token indices; the reshape
                below implies shape (batch, sentences, tokens) — i.e. 4-D
                after embedding.
            mask (torch.Tensor, optional): Optional mask tensor indicating
                valid or padded positions, depending on the configuration
                of the Transformer blocks. Defaults to None.

                If `valid_padding` is disabled, the mask is inverted before being
                passed to CoSeNet to match its masking convention.

            candidate_mask (torch.Tensor, optional): Optional mask tensor for
                candidate positions in CoSeNet. Defaults to None.

                If `valid_padding` is disabled, the mask is inverted before being
                passed to CoSeNet to match its masking convention.

        Returns:
            torch.Tensor: Output tensor containing pairwise distance values
                derived from the segmented representations.
        """
        # Convert to type: embedding lookup needs integer indices.
        x = x.int()

        # Embedding and positional encoding:
        x = self.embedding(x)
        x = self.positional_encoding(x)

        # Reshape x and mask: merge (batch, sentence) so each sentence is
        # encoded independently as its own sequence of tokens.
        _b, _s, _t, _d = x.shape
        x = x.reshape(_b * _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b * _s, _t).bool()

        # Encode the sequence:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask=mask)

        # Reshape x and mask: restore the (batch, sentence) split and
        # normalize the mask to valid-marks-True for pooling/CoSeNet.
        x = x.reshape(_b, _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b, _s, _t)
            mask = torch.logical_not(mask) if not self.valid_padding else mask

        # Apply pooling: one vector (and mask entry) per sentence.
        x, mask = self.pooling(x, mask=mask)

        # Compute distances: pair-wise |cosine| between sentence vectors.
        x = self.distance_layer(x)

        # Pass through CoSeNet:
        x = self.cosenet(x, mask=mask)

        # Apply candidate mask if provided: after normalization, True
        # entries mark positions to be zeroed out.
        if candidate_mask is not None:
            candidate_mask = candidate_mask.bool() if not self.valid_padding else torch.logical_not(candidate_mask.bool())
            candidate_mask = candidate_mask.to(device=x.device)
            x = x.masked_fill(candidate_mask, 0)

        return x
145
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
146
+ # END OF FILE #
147
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/transformers/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ from .attention import EncoderBlock
9
+ from .positional_encoding import PositionalEncoding
10
+ from .pooling import MaskedMeanPooling
11
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
12
+ # END OF FILE #
13
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/transformers/attention.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+
10
+
11
class EncoderBlock(torch.nn.Module):
    """
    Transformer encoder block with configurable Pre-LayerNorm or Post-LayerNorm
    architecture.

    The block consists of a multi-head self-attention sublayer followed by a
    position-wise feed-forward network, each wrapped with a residual connection.
    Layer normalization can be applied either before each sublayer (Pre-LN) or
    after each residual addition (Post-LN).

    This design allows stable training of deep Transformer stacks while retaining
    compatibility with the original Transformer formulation.
    """
    def __init__(
        self,
        feature_dim: int,
        attention_heads: int = 8,
        feed_forward_multiplier: float = 4,
        dropout: float = 0.0,
        valid_padding: bool = False,
        pre_normalize: bool = True,
        **kwargs
    ):
        """
        Initializes a Transformer encoder block.

        Parameters
        ----------
        feature_dim : int
            Dimensionality of the input and output feature representations.
        attention_heads : int, optional
            Number of attention heads used in the multi-head self-attention layer.
            Default is 8.
        feed_forward_multiplier : float, optional
            Expansion factor for the hidden dimension of the feed-forward network.
            The intermediate dimension is computed as
            `feed_forward_multiplier * feature_dim`.
            Default is 4.
        dropout : float, optional
            Dropout probability applied to the feed-forward residual connection
            only; the attention layer itself uses no dropout.
            Default is 0.0.
        valid_padding : bool, optional
            If True, the provided mask marks valid (non-padded) positions.
            If False, the mask marks padded (invalid) positions directly.
            Default is False.
        pre_normalize : bool, optional
            If True, uses the Pre-LayerNorm Transformer variant, applying layer
            normalization before each sublayer (self-attention and feed-forward).
            If False, uses the Post-LayerNorm variant, applying normalization after
            each residual connection.
            Default is True.
        **kwargs
            Additional keyword arguments passed to the parent `torch.nn.Module`.
        """
        # Module init via kwargs:
        super().__init__(**kwargs)

        # Store params:
        self.valid_padding = valid_padding
        self.pre_normalize = pre_normalize

        # Norm layers: norm_in wraps the attention sublayer, norm_out the
        # feed-forward sublayer (roles are scheduled below per variant).
        self.norm_in = torch.nn.LayerNorm(feature_dim)
        self.norm_out = torch.nn.LayerNorm(feature_dim)

        # Dropout layer (feed-forward residual branch only):
        self.dropout = torch.nn.Dropout(dropout)

        # Attention layer: dropout deliberately fixed to 0.0 here.
        self.attention = torch.nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=attention_heads,
            dropout=0.0,
            batch_first=True
        )

        # Feed-forward layer: expand -> GELU -> project back.
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, int(feed_forward_multiplier * feature_dim)),
            torch.nn.GELU(),
            torch.nn.Linear(int(feed_forward_multiplier * feature_dim), feature_dim),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Forward pass of a Transformer encoder block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, feature_dim).
        mask : torch.Tensor or None, optional
            Boolean mask indicating valid sequence positions.
            Shape: (batch_size, sequence_length).
            If `valid_padding` is True, True values denote valid tokens.
            Otherwise, True values denote masked (invalid) positions.

        Returns
        -------
        x : torch.Tensor
            Output tensor of the same shape as the input
            (batch_size, sequence_length, feature_dim).
        """

        # Convert mask: normalize to both conventions — key_padding_mask
        # (True = pad, as MultiheadAttention expects) and valid_mask
        # (True = valid).
        if mask is not None and self.valid_padding:
            key_padding_mask = ~mask.bool()  # True = pad
            valid_mask = mask.bool()
        elif mask is not None:
            key_padding_mask = mask.bool()
            valid_mask = ~mask.bool()
        else:
            key_padding_mask = None
            valid_mask = None

        # Detect fully padded sequences: attending over an all-pad row
        # would produce NaNs, so these rows are special-cased below.
        if valid_mask is not None:
            all_pad = ~valid_mask.any(dim=-1)  # (B,)
        else:
            all_pad = None

        # Pre-normalization (Pre-LN applies LayerNorm before attention):
        if self.pre_normalize:
            h = self.norm_in(x)
        else:
            h = x

        # Attention (guard against fully padded sequences): zero those
        # rows and un-mask them so softmax stays finite; their output is
        # re-zeroed at the end of the block.
        if all_pad is not None and all_pad.any():
            h_attn = h.clone()
            h_attn[all_pad] = 0.0

            if key_padding_mask is not None:
                key_padding_mask = key_padding_mask.clone()
                key_padding_mask[all_pad] = False
        else:
            h_attn = h

        attn_out, _ = self.attention(
            h_attn, h_attn, h_attn,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        # Residual connection around attention:
        x = x + attn_out

        # Post-attention normalization: in the Post-LN path norm_in is
        # reused here (after the attention residual); in the Pre-LN path
        # norm_out precedes the feed-forward sublayer.
        if not self.pre_normalize:
            z = self.norm_in(x)
        else:
            z = self.norm_out(x)

        # Feed-forward with dropout on its residual branch:
        z = self.feed_forward(z)
        x = x + self.dropout(z)

        # Post-LN: final normalization after the feed-forward residual.
        if not self.pre_normalize:
            x = self.norm_out(x)

        # Re-pad fully padded sequences: force their outputs back to zero.
        if all_pad is not None:
            x = x.masked_fill(all_pad[:, None, None], 0.0)

        return x
174
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
175
+ # END OF FILE #
176
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/transformers/pooling.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+
10
+
11
class MaskedMeanPooling(torch.nn.Module):
    """
    Mean pooling layer with explicit masking support.

    This layer computes the mean over the sequence dimension while
    ignoring padded elements according to a boolean mask. It supports
    both PyTorch-style padding masks and valid-position masks.
    """

    def __init__(self, valid_pad: bool = True, eps: float = 1e-6):
        """
        Initialize the masked mean pooling layer.

        Args:
            valid_pad (bool, optional): Mask interpretation mode. If True,
                `True` values in the mask indicate valid (non-padded) positions.
                If False, `True` values indicate padded positions, following
                PyTorch-style padding conventions. Defaults to True.
            eps (float, optional): Small constant to avoid division by zero
                when all positions are masked. Defaults to 1e-6.
        """
        super().__init__()
        self.valid_pad = valid_pad
        self.eps = eps

    def forward(
            self,
            x: torch.Tensor,
            mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply masked mean pooling.

        Args:
            x (torch.Tensor): Input tensor of shape (..., S, D), where
                S is the sequence length and D the feature dimension.
            mask (torch.Tensor): Boolean mask tensor of shape (..., S),
                or None to treat every position as valid. The
                interpretation of True/False depends on `valid_pad`.

        Returns:
            tuple:
                torch.Tensor: Pooled tensor of shape (..., D).
                torch.Tensor: Updated valid mask after pooling of shape (..., ).
        """
        # Mask handling. Fixed: the default mask must cover every leading
        # dimension plus the sequence axis (x.shape[:-1]); the previous
        # hard-coded x.shape[:3] included the feature dim for 3-D inputs
        # and broke broadcasting below.
        if mask is None:
            valid_mask = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)
        else:
            valid_mask = mask

        # Normalize to "True = valid" convention:
        if not self.valid_pad:
            valid_mask = torch.logical_not(valid_mask)

        valid_mask = valid_mask.unsqueeze(-1).to(x.dtype)   # (..., S, 1)
        summed = torch.sum(x * valid_mask, dim=-2)          # (..., D)
        # Clamp avoids division by zero when a sequence is fully padded:
        denom = valid_mask.sum(dim=-2).clamp(min=self.eps)  # (..., 1)

        # Valid mask pooling (any valid position keeps the sequence valid):
        valid_mask = valid_mask.squeeze(-1).any(dim=-1)

        return summed / denom, valid_mask
76
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
77
+ # END OF FILE #
78
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/model/transformers/positional_encoding.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+
10
+
11
class PositionalEncoding(torch.nn.Module):
    """
    Sinusoidal positional encoding module for Transformer models.

    This module injects information about the relative or absolute position of
    tokens in a sequence by adding fixed sinusoidal embeddings to the input
    embeddings. The positional encodings are non-learnable and follow the
    formulation introduced in the original Transformer architecture.
    """
    def __init__(self, emb_dim: int, max_len: int = 5000, **kwargs):
        """
        Initialize the positional encoding module.

        Parameters
        ----------
        emb_dim : int
            Dimensionality of the embedding space. Both even and odd
            values are supported.
        max_len : int, optional
            Maximum supported sequence length for which positional encodings
            are precomputed.
        """
        super().__init__(**kwargs)

        # Create positional encodings:
        pe = torch.zeros(max_len, emb_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / emb_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Fixed: slice the cosine frequencies so odd `emb_dim` works.
        # There are emb_dim // 2 cosine slots but ceil(emb_dim / 2)
        # frequencies; without the slice an odd emb_dim raised a
        # shape-mismatch error. Even dims are unaffected.
        pe[:, 1::2] = torch.cos(position * div_term[:emb_dim // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, emb_dim) for batch broadcasting

        # Register as a buffer (moves with the module, not a parameter):
        self.register_buffer('positional_encoding', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to the input embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, emb_dim).

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as the input with positional encodings added.
        """
        return x + self.positional_encoding[:, :x.size(1), :]
60
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
61
+ # END OF FILE #
62
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
train/config.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import os
9
+ from dataclasses import dataclass
10
+ from src.model import ModelConfig, CoSeNetConfig, TransformerConfig
11
+ from src.dataset import DatasetConfig
12
+
13
+
14
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
15
+ # SETUP CONFIGURATION #
16
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
17
@dataclass
class SetupConfig:
    """
    Configuration parameters related to the execution environment and logging.

    This configuration controls device selection, checkpointing behavior,
    reproducibility settings, and logging paths for an experiment.
    """
    # Index of the compute device to use:
    device_number: int = 0
    # Checkpointing frequency — presumably in epochs; confirm in the trainer:
    save_model_each: int = 0
    # RNG seed; None presumably means it is chosen at setup time — TODO confirm:
    seed: int | None = None
    # Root directory for logs/checkpoints; None until overridden:
    logging_path: str | None = None
    # Whether to resume training from an existing checkpoint:
    reload_checkpoint: bool = False
+
31
+
32
def overwrite_setup_config() -> SetupConfig:
    """
    Build the setup configuration for this experiment.

    Overrides the default environment settings with the experiment's
    checkpoint frequency, checkpoint reloading flag, and logging path.

    Returns:
        SetupConfig: The configured setup configuration object.
    """
    setup = SetupConfig()
    setup.save_model_each = 1
    setup.reload_checkpoint = True
    setup.logging_path = r'/workspace/logs'
    return setup
47
+
48
+
49
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
50
+ # TRAINING CONFIGURATION #
51
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
52
@dataclass
class TrainConfig:
    """
    Training configuration container.

    This dataclass aggregates model, dataset, and setup configurations,
    together with optimization and training hyperparameters.
    """
    # Linked configurations (populated by `configuration()`):
    model_config: ModelConfig | None = None
    dataset_config: DatasetConfig | None = None
    setup_config: SetupConfig | None = None

    # Training parameters:
    batch_size: int = 32   # samples per optimization step
    num_epochs: int = 100  # total number of training epochs

    # Optimizer parameters:
    learning_rate: float = 1e-4      # initial learning rate
    learning_rate_min: float = 1e-5  # lower bound — presumably for an LR scheduler; TODO confirm
    weight_decay: float = 1e-8       # L2-style regularization strength
    betas: tuple[float, float] = (0.5, 0.999)  # presumably Adam-style betas; verify against the optimizer
74
+
75
+
76
def overwrite_train_config() -> TrainConfig:
    """
    Build the training configuration for this experiment.

    Overrides the default batch size, epoch count, and optimizer
    hyperparameters (learning-rate range and weight decay).

    Returns:
        TrainConfig: The configured training configuration object.
    """
    train = TrainConfig()
    train.num_epochs = 200
    train.batch_size = 4
    train.weight_decay = 1e-6
    train.learning_rate = 5e-4
    train.learning_rate_min = 5e-5
    return train
93
+
94
+
95
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
96
+ # DATASET CONFIGURATION #
97
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
98
def overwrite_dataset_config() -> DatasetConfig:
    """
    Build the dataset configuration for this experiment.

    Sets the file paths and usage percentages for the training,
    validation, and test splits.

    Returns:
        DatasetConfig: The configured dataset configuration object.
    """
    dataset = DatasetConfig()
    dataset.train_data_path = r"/workspace/data/tokens-A000-segmentation"
    dataset.val_data_path = r"/workspace/data/tokens-A001-segmentation"
    dataset.test_data_path = r"/workspace/data/tokens-A002-segmentation"
    # Use the full extent of every split:
    for split in ("train", "val", "test"):
        setattr(dataset, f"{split}_percentage", 1.0)
    return dataset
116
+
117
+
118
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
119
+ # MODEL CONFIGURATION #
120
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
121
def overwrite_model_config() -> ModelConfig:
    """
    Build the model configuration for this experiment.

    Defines the architecture-level parameters: vocabulary size, embedding
    dimensionality, CoSeNet settings, and the stack of Transformer
    encoder configurations.

    Returns:
        ModelConfig: The configured model configuration object.
    """
    model = ModelConfig()

    # High-level params:
    model.vocab_size = 32_768
    model.model_dim = 256
    model.valid_padding = True

    # CoSeNet params:
    model.cosenet = CoSeNetConfig(
        trainable=True,
        init_scale=5.0
    )

    # Transformer params: two identical encoder layers built from one
    # shared keyword set (each layer gets its own config instance).
    encoder_kwargs = {
        "attention_heads": 16,
        "feed_forward_multiplier": 8,
        "dropout": 0.0,
        "pre_normalize": True
    }
    model.transformers = [TransformerConfig(**encoder_kwargs) for _ in range(2)]

    return model
165
+
166
+
167
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
168
+ # WHOLE CONFIGURATION #
169
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
170
def configuration() -> TrainConfig:
    """
    Create the experiment configuration
    :return: A TrainConfig configuration object
    """
    config = overwrite_train_config()
    config.setup_config = overwrite_setup_config()
    config.model_config = overwrite_model_config()
    config.dataset_config = overwrite_dataset_config()

    # Sanity checks on dataset paths and split percentages:
    dataset = config.dataset_config
    if not os.path.exists(dataset.train_data_path):
        raise FileNotFoundError(f"Train data path does not exist: {dataset.train_data_path}")
    if not os.path.exists(dataset.val_data_path):
        raise FileNotFoundError(f"Validation data path does not exist: {dataset.val_data_path}")
    if not 0.0 < dataset.train_percentage <= 1.0:
        raise ValueError("Train percentage must be in (0.0, 1.0]")
    if not 0.0 < dataset.val_percentage <= 1.0:
        raise ValueError("Validation percentage must be in (0.0, 1.0]")

    return config
191
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
192
+ # END OF FILE #
193
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
train/train_logs/config.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_config": {
3
+ "vocab_size": 32768,
4
+ "model_dim": 256,
5
+ "max_tokens": 382,
6
+ "max_sentences": 384,
7
+ "valid_padding": true,
8
+ "cosenet": {
9
+ "trainable": true,
10
+ "init_scale": 5.0
11
+ },
12
+ "transformers": [
13
+ {
14
+ "attention_heads": 16,
15
+ "feed_forward_multiplier": 8,
16
+ "dropout": 0.0,
17
+ "pre_normalize": true
18
+ },
19
+ {
20
+ "attention_heads": 16,
21
+ "feed_forward_multiplier": 8,
22
+ "dropout": 0.0,
23
+ "pre_normalize": true
24
+ }
25
+ ]
26
+ },
27
+ "dataset_config": {
28
+ "train_data_path": "/workspace/data/tokens-A000-segmentation",
29
+ "val_data_path": "/workspace/data/tokens-A001-segmentation",
30
+ "test_data_path": "/workspace/data/tokens-A002-segmentation",
31
+ "train_percentage": 1.0,
32
+ "val_percentage": 1.0,
33
+ "test_percentage": 1.0,
34
+ "num_workers": 0,
35
+ "shuffle_train": true,
36
+ "shuffle_val": true
37
+ },
38
+ "setup_config": {
39
+ "device_number": 0,
40
+ "save_model_each": 1,
41
+ "seed": null,
42
+ "logging_path": "/workspace/logs",
43
+ "reload_checkpoint": true
44
+ },
45
+ "batch_size": 4,
46
+ "num_epochs": 200,
47
+ "learning_rate": 0.0005,
48
+ "learning_rate_min": 5e-05,
49
+ "weight_decay": 1e-06,
50
+ "betas": [
51
+ 0.5,
52
+ 0.999
53
+ ]
54
+ }
train/train_logs/logfile.log ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2025-12-26 14:45:56,651: [INFO] Logger initialized with writer handler at: /workspace/logs/logfile.log
2
+ 2025-12-26 14:45:56,659: [INFO] TensorBoard logs will be stored in: /workspace/logs/logs
3
+ 2025-12-26 14:45:56,659: [INFO] Model checkpoints will be stored in: /workspace/logs/checkpoints
4
+ 2025-12-26 14:45:56,672: [INFO] TensorBoard running at http://0.0.0.0:6006/ (pid=76392)
5
+ 2025-12-26 14:45:56,680: [INFO] Initializer set up seed: 1766760356
6
+ 2025-12-26 14:45:56,728: [INFO] PyTorch is now configured to use GPU 0: NVIDIA A40
7
+ 2025-12-26 14:45:56,729: [INFO] [GPU 0 - NVIDIA A40] Memory Stats:
8
+ 2025-12-26 14:45:56,729: [INFO] Total Memory : 45498.00 MB
9
+ 2025-12-26 14:45:56,730: [INFO] Currently Allocated : 0.00 MB
10
+ 2025-12-26 14:45:56,730: [INFO] Currently Reserved : 0.00 MB
11
+ 2025-12-26 14:45:56,730: [INFO] Max Allocated : 0.00 MB
12
+ 2025-12-26 14:45:56,731: [INFO] Max Reserved : 0.00 MB
13
+ 2025-12-26 14:45:56,731: [INFO] Setup information:
14
+ 2025-12-26 14:45:56,732: [INFO] - Setup path: /workspace/logs
15
+ 2025-12-26 14:45:56,732: [INFO] - Setup checkpoints path: /workspace/logs/checkpoints
16
+ 2025-12-26 14:45:56,732: [INFO] - Setup device: cuda:0
17
+ 2025-12-26 14:45:56,733: [INFO] - Setup seed: 1766760356
18
+ 2025-12-26 14:45:56,733: [INFO] - Setup logger: <Logger src.dlutils.setup.logger (INFO)>
19
+ 2025-12-26 14:45:56,734: [INFO] - Setup writer: <torch.utils.tensorboard.writer.SummaryWriter object at 0x76e7ade77910>
20
+ 2025-12-26 14:45:56,734: [INFO] - Setup save each: 20
21
+ 2025-12-26 14:45:56,737: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A000-segmentation
22
+ 2025-12-26 14:45:56,737: [INFO] [SegmentationDataset] Loaded dataset length: 26510
23
+ 2025-12-26 14:45:56,745: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A001-segmentation
24
+ 2025-12-26 14:45:56,745: [INFO] [SegmentationDataset] Loaded dataset length: 3336
25
+ 2025-12-26 14:45:57,294: [INFO] [TRAIN] Model Configuration:
26
+ {'vocab_size': 32768, 'model_dim': 256, 'max_tokens': 382, 'max_sentences': 384, 'valid_padding': True, 'cosenet': CoSeNetConfig(trainable=True, init_scale=5.0), 'transformers': [TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True), TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True)]}
27
+ 2025-12-26 14:45:57,294: [INFO] [TRAIN] Model parameters: 11.022865 M
28
+ 2025-12-26 14:45:57,295: [INFO] [TRAIN] Trainable parameters: 11.022865 M
29
+ 2025-12-26 14:45:57,295: [INFO] [TRAIN] Training batches: 67
30
+ 2025-12-26 14:47:48,849: [INFO] Epoch [0]: loss = 0.53930415
31
+ 2025-12-26 14:47:50,411: [INFO] Epoch [1]: val_loss = 0.50874352
32
+ 2025-12-26 14:49:37,680: [INFO] Epoch [1]: loss = 0.52674737
33
+ 2025-12-26 14:49:39,116: [INFO] Epoch [2]: val_loss = 0.51872101
34
+ 2025-12-26 14:51:27,172: [INFO] Epoch [2]: loss = 0.52592351
35
+ 2025-12-26 14:51:28,612: [INFO] Epoch [3]: val_loss = 0.51301319
36
+ 2025-12-26 14:53:16,691: [INFO] Epoch [3]: loss = 0.52935326
37
+ 2025-12-26 14:53:18,212: [INFO] Epoch [4]: val_loss = 0.51744863
38
+ 2025-12-26 14:55:05,752: [INFO] Epoch [4]: loss = 0.52446729
39
+ 2025-12-26 14:55:07,327: [INFO] Epoch [5]: val_loss = 0.51929819
40
+ 2025-12-26 14:56:57,434: [INFO] Epoch [5]: loss = 0.52781746
41
+ 2025-12-26 14:56:58,912: [INFO] Epoch [6]: val_loss = 0.52006621
42
+ 2025-12-26 14:58:46,224: [INFO] Epoch [6]: loss = 0.52644637
43
+ 2025-12-26 14:58:47,712: [INFO] Epoch [7]: val_loss = 0.51545532
44
+ 2025-12-26 15:00:34,974: [INFO] Epoch [7]: loss = 0.52535941
45
+ 2025-12-26 15:00:36,412: [INFO] Epoch [8]: val_loss = 0.52077476
46
+ 2025-12-26 15:02:24,083: [INFO] Epoch [8]: loss = 0.52521282
47
+ 2025-12-26 15:02:25,525: [INFO] Epoch [9]: val_loss = 0.51527728
48
+ 2025-12-26 15:04:13,376: [INFO] Epoch [9]: loss = 0.52329010
49
+ 2025-12-26 15:04:14,816: [INFO] Epoch [10]: val_loss = 0.51563372
50
+ 2025-12-26 15:06:03,934: [INFO] Epoch [10]: loss = 0.52397644
51
+ 2025-12-26 15:06:05,412: [INFO] Epoch [11]: val_loss = 0.51372376
52
+ 2025-12-26 15:07:54,323: [INFO] Epoch [11]: loss = 0.52039668
53
+ 2025-12-26 15:07:55,813: [INFO] Epoch [12]: val_loss = 0.51369372
54
+ 2025-12-26 15:09:43,559: [INFO] Epoch [12]: loss = 0.51899378
55
+ 2025-12-26 15:09:45,012: [INFO] Epoch [13]: val_loss = 0.52238202
56
+ 2025-12-26 15:11:32,423: [INFO] Epoch [13]: loss = 0.51784248
57
+ 2025-12-26 15:11:33,912: [INFO] Epoch [14]: val_loss = 0.51489054
58
+ 2025-12-26 15:13:22,761: [INFO] Epoch [14]: loss = 0.50914923
59
+ 2025-12-26 15:13:24,212: [INFO] Epoch [15]: val_loss = 0.50278137
60
+ 2025-12-26 15:15:11,956: [INFO] Epoch [15]: loss = 0.50427987
61
+ 2025-12-26 15:15:13,412: [INFO] Epoch [16]: val_loss = 0.50158396
62
+ 2025-12-26 15:17:01,228: [INFO] Epoch [16]: loss = 0.50178539
63
+ 2025-12-26 15:17:02,711: [INFO] Epoch [17]: val_loss = 0.50242173
64
+ 2025-12-26 15:18:51,266: [INFO] Epoch [17]: loss = 0.49650285
65
+ 2025-12-26 15:18:52,716: [INFO] Epoch [18]: val_loss = 0.50932210
66
+ 2025-12-26 15:20:40,343: [INFO] Epoch [18]: loss = 0.49234502
67
+ 2025-12-26 15:20:41,912: [INFO] Epoch [19]: val_loss = 0.50311281
68
+ 2025-12-26 15:22:29,693: [INFO] Epoch [19]: loss = 0.48797671
69
+ 2025-12-26 15:22:29,695: [INFO] Checkpointing model at epoch 20
70
+ 2025-12-26 15:22:30,454: [INFO] Model checkpointed at epoch 20
71
+ 2025-12-26 15:22:31,912: [INFO] Epoch [20]: val_loss = 0.53549688
72
+ 2025-12-26 15:24:19,843: [INFO] Epoch [20]: loss = 0.48723968
73
+ 2025-12-26 15:24:21,312: [INFO] Epoch [21]: val_loss = 0.49818926
74
+ 2025-12-26 15:26:08,715: [INFO] Epoch [21]: loss = 0.48037165
75
+ 2025-12-26 15:26:10,212: [INFO] Epoch [22]: val_loss = 0.48961075
76
+ 2025-12-26 15:27:59,123: [INFO] Epoch [22]: loss = 0.47390062
77
+ 2025-12-26 15:28:00,911: [INFO] Epoch [23]: val_loss = 0.48781847
78
+ 2025-12-26 15:29:49,056: [INFO] Epoch [23]: loss = 0.46711668
79
+ 2025-12-26 15:29:50,511: [INFO] Epoch [24]: val_loss = 0.47708375
80
+ 2025-12-26 15:31:37,663: [INFO] Epoch [24]: loss = 0.46234217
81
+ 2025-12-26 15:31:39,112: [INFO] Epoch [25]: val_loss = 0.46084376
82
+ 2025-12-26 15:33:26,345: [INFO] Epoch [25]: loss = 0.45538114
83
+ 2025-12-26 15:33:27,812: [INFO] Epoch [26]: val_loss = 0.47136071
84
+ 2025-12-26 15:35:15,250: [INFO] Epoch [26]: loss = 0.45225392
85
+ 2025-12-26 15:35:16,711: [INFO] Epoch [27]: val_loss = 0.47011130
86
+ 2025-12-26 15:37:04,599: [INFO] Epoch [27]: loss = 0.44760030
87
+ 2025-12-26 15:37:06,112: [INFO] Epoch [28]: val_loss = 0.46140307
88
+ 2025-12-26 15:38:54,426: [INFO] Epoch [28]: loss = 0.44472487
89
+ 2025-12-26 15:38:55,912: [INFO] Epoch [29]: val_loss = 0.47098119
90
+ 2025-12-26 15:40:43,445: [INFO] Epoch [29]: loss = 0.43989357
91
+ 2025-12-26 15:40:44,911: [INFO] Epoch [30]: val_loss = 0.45539117
92
+ 2025-12-26 15:42:32,383: [INFO] Epoch [30]: loss = 0.43657149
93
+ 2025-12-26 15:42:33,816: [INFO] Epoch [31]: val_loss = 0.46862131
94
+ 2025-12-26 15:44:21,074: [INFO] Epoch [31]: loss = 0.43649050
95
+ 2025-12-26 15:44:22,511: [INFO] Epoch [32]: val_loss = 0.45548641
96
+ 2025-12-26 15:46:09,812: [INFO] Epoch [32]: loss = 0.43346542
97
+ 2025-12-26 15:46:11,312: [INFO] Epoch [33]: val_loss = 0.45997839
98
+ 2025-12-26 15:47:59,053: [INFO] Epoch [33]: loss = 0.43235683
99
+ 2025-12-26 15:48:00,511: [INFO] Epoch [34]: val_loss = 0.47154692
100
+ 2025-12-26 15:49:47,991: [INFO] Epoch [34]: loss = 0.42891757
101
+ 2025-12-26 15:49:49,416: [INFO] Epoch [35]: val_loss = 0.46223042
102
+ 2025-12-26 15:51:36,793: [INFO] Epoch [35]: loss = 0.42735399
103
+ 2025-12-26 15:51:38,216: [INFO] Epoch [36]: val_loss = 0.46173553
104
+ 2025-12-26 15:53:25,570: [INFO] Epoch [36]: loss = 0.42965186
105
+ 2025-12-26 15:53:27,016: [INFO] Epoch [37]: val_loss = 0.46098506
106
+ 2025-12-26 15:55:14,511: [INFO] Epoch [37]: loss = 0.42778122
107
+ 2025-12-26 15:55:16,012: [INFO] Epoch [38]: val_loss = 0.46018566
108
+ 2025-12-26 15:57:06,234: [INFO] Epoch [38]: loss = 0.42445267
109
+ 2025-12-26 15:57:07,711: [INFO] Epoch [39]: val_loss = 0.46550667
110
+ 2025-12-26 15:58:59,230: [INFO] Epoch [39]: loss = 0.42354161
111
+ 2025-12-26 15:58:59,232: [INFO] Checkpointing model at epoch 40
112
+ 2025-12-26 15:58:59,945: [INFO] Model checkpointed at epoch 40
113
+ 2025-12-26 15:59:01,511: [INFO] Epoch [40]: val_loss = 0.47303247
114
+ 2025-12-26 16:00:49,480: [INFO] Epoch [40]: loss = 0.42338467
115
+ 2025-12-26 16:00:50,911: [INFO] Epoch [41]: val_loss = 0.45826835
116
+ 2025-12-26 16:02:38,743: [INFO] Epoch [41]: loss = 0.41971716
117
+ 2025-12-26 16:02:40,212: [INFO] Epoch [42]: val_loss = 0.45490133
118
+ 2025-12-26 16:04:28,045: [INFO] Epoch [42]: loss = 0.41987514
119
+ 2025-12-26 16:04:29,512: [INFO] Epoch [43]: val_loss = 0.45860666
120
+ 2025-12-26 16:06:16,948: [INFO] Epoch [43]: loss = 0.41933024
121
+ 2025-12-26 16:06:18,411: [INFO] Epoch [44]: val_loss = 0.45629129
122
+ 2025-12-26 16:08:06,282: [INFO] Epoch [44]: loss = 0.41593552
123
+ 2025-12-26 16:08:07,716: [INFO] Epoch [45]: val_loss = 0.46409211
124
+ 2025-12-26 16:09:55,161: [INFO] Epoch [45]: loss = 0.41721227
125
+ 2025-12-26 16:09:56,612: [INFO] Epoch [46]: val_loss = 0.46598683
126
+ 2025-12-26 16:11:43,939: [INFO] Epoch [46]: loss = 0.41726764
127
+ 2025-12-26 16:11:45,411: [INFO] Epoch [47]: val_loss = 0.45663830
128
+ 2025-12-26 16:13:32,862: [INFO] Epoch [47]: loss = 0.41537570
129
+ 2025-12-26 16:13:34,315: [INFO] Epoch [48]: val_loss = 0.46740513
130
+ 2025-12-26 16:15:21,546: [INFO] Epoch [48]: loss = 0.41457776
131
+ 2025-12-26 16:15:23,112: [INFO] Epoch [49]: val_loss = 0.44048135
132
+ 2025-12-26 16:17:10,274: [INFO] Epoch [49]: loss = 0.41388101
133
+ 2025-12-26 16:17:11,715: [INFO] Epoch [50]: val_loss = 0.45519451
134
+ 2025-12-26 16:18:58,990: [INFO] Epoch [50]: loss = 0.41285109
135
+ 2025-12-26 16:19:00,512: [INFO] Epoch [51]: val_loss = 0.45673202
136
+ 2025-12-26 16:20:47,636: [INFO] Epoch [51]: loss = 0.41195465
137
+ 2025-12-26 16:20:49,116: [INFO] Epoch [52]: val_loss = 0.45198240
138
+ 2025-12-26 16:22:36,418: [INFO] Epoch [52]: loss = 0.40953517
139
+ 2025-12-26 16:22:37,893: [INFO] Epoch [53]: val_loss = 0.47122019
140
+ 2025-12-26 16:24:25,331: [INFO] Epoch [53]: loss = 0.40789293
141
+ 2025-12-26 16:24:26,812: [INFO] Epoch [54]: val_loss = 0.44196667
142
+ 2025-12-26 16:26:14,026: [INFO] Epoch [54]: loss = 0.40474147
143
+ 2025-12-26 16:26:15,512: [INFO] Epoch [55]: val_loss = 0.46978565
144
+ 2025-12-26 16:28:03,793: [INFO] Epoch [55]: loss = 0.40504389
145
+ 2025-12-26 16:28:05,272: [INFO] Epoch [56]: val_loss = 0.47313605
146
+ 2025-12-26 16:29:52,585: [INFO] Epoch [56]: loss = 0.40562682
147
+ 2025-12-26 16:29:54,017: [INFO] Epoch [57]: val_loss = 0.46668073
148
+ 2025-12-26 16:31:41,107: [INFO] Epoch [57]: loss = 0.40768713
149
+ 2025-12-26 16:31:42,529: [INFO] Epoch [58]: val_loss = 0.45173921
150
+ 2025-12-26 16:33:29,962: [INFO] Epoch [58]: loss = 0.40458906
151
+ 2025-12-26 16:33:31,411: [INFO] Epoch [59]: val_loss = 0.46093515
152
+ 2025-12-26 16:35:18,548: [INFO] Epoch [59]: loss = 0.40175750
153
+ 2025-12-26 16:35:18,549: [INFO] Checkpointing model at epoch 60
154
+ 2025-12-26 16:35:19,200: [INFO] Model checkpointed at epoch 60
155
+ 2025-12-26 16:35:20,711: [INFO] Epoch [60]: val_loss = 0.46230425
156
+ 2025-12-26 16:37:07,991: [INFO] Epoch [60]: loss = 0.40113024
157
+ 2025-12-26 16:37:09,427: [INFO] Epoch [61]: val_loss = 0.46095682
158
+ 2025-12-26 16:38:57,025: [INFO] Epoch [61]: loss = 0.40212381
159
+ 2025-12-26 16:38:58,604: [INFO] Epoch [62]: val_loss = 0.45801103
160
+ 2025-12-26 16:40:47,358: [INFO] Epoch [62]: loss = 0.40149038
161
+ 2025-12-26 16:40:48,911: [INFO] Epoch [63]: val_loss = 0.45971834
162
+ 2025-12-26 16:42:36,483: [INFO] Epoch [63]: loss = 0.40096813
163
+ 2025-12-26 16:42:37,917: [INFO] Epoch [64]: val_loss = 0.47312803
164
+ 2025-12-26 16:44:25,338: [INFO] Epoch [64]: loss = 0.40213521
165
+ 2025-12-26 16:44:26,895: [INFO] Epoch [65]: val_loss = 0.45463914
166
+ 2025-12-26 16:46:14,170: [INFO] Epoch [65]: loss = 0.39824201
167
+ 2025-12-26 16:46:15,615: [INFO] Epoch [66]: val_loss = 0.47252337
168
+ 2025-12-26 16:48:03,455: [INFO] Epoch [66]: loss = 0.39898236
169
+ 2025-12-26 16:48:04,912: [INFO] Epoch [67]: val_loss = 0.46137960
170
+ 2025-12-26 16:49:52,778: [INFO] Epoch [67]: loss = 0.40269130
171
+ 2025-12-26 16:49:54,216: [INFO] Epoch [68]: val_loss = 0.47056969
172
+ 2025-12-26 16:51:41,521: [INFO] Epoch [68]: loss = 0.39804779
173
+ 2025-12-26 16:51:43,012: [INFO] Epoch [69]: val_loss = 0.46284741
174
+ 2025-12-26 16:53:30,778: [INFO] Epoch [69]: loss = 0.39931213
175
+ 2025-12-26 16:53:32,211: [INFO] Epoch [70]: val_loss = 0.47174325
176
+ 2025-12-26 16:55:19,747: [INFO] Epoch [70]: loss = 0.39947561
177
+ 2025-12-26 16:55:21,211: [INFO] Epoch [71]: val_loss = 0.47359799
178
+ 2025-12-26 16:57:08,420: [INFO] Epoch [71]: loss = 0.39680641
179
+ 2025-12-26 16:57:09,912: [INFO] Epoch [72]: val_loss = 0.45985634
180
+ 2025-12-26 16:58:57,465: [INFO] Epoch [72]: loss = 0.39784966
181
+ 2025-12-26 16:58:58,911: [INFO] Epoch [73]: val_loss = 0.47379973
182
+ 2025-12-26 17:00:46,781: [INFO] Epoch [73]: loss = 0.39575548
183
+ 2025-12-26 17:00:48,217: [INFO] Epoch [74]: val_loss = 0.46827143
184
+ 2025-12-26 17:02:35,956: [INFO] Epoch [74]: loss = 0.39844352
185
+ 2025-12-26 17:02:37,411: [INFO] Epoch [75]: val_loss = 0.48436255
186
+ 2025-12-26 17:04:25,013: [INFO] Epoch [75]: loss = 0.39737436
187
+ 2025-12-26 17:04:26,512: [INFO] Epoch [76]: val_loss = 0.45234020
188
+ 2025-12-26 17:06:13,974: [INFO] Epoch [76]: loss = 0.39371587
189
+ 2025-12-26 17:06:15,415: [INFO] Epoch [77]: val_loss = 0.45753057
190
+ 2025-12-26 17:08:03,455: [INFO] Epoch [77]: loss = 0.39684283
191
+ 2025-12-26 17:08:04,916: [INFO] Epoch [78]: val_loss = 0.46107266
192
+ 2025-12-26 17:09:52,265: [INFO] Epoch [78]: loss = 0.39561052
193
+ 2025-12-26 17:09:53,711: [INFO] Epoch [79]: val_loss = 0.48726222
194
+ 2025-12-26 17:11:40,915: [INFO] Epoch [79]: loss = 0.39534942
195
+ 2025-12-26 17:11:40,917: [INFO] Checkpointing model at epoch 80
196
+ 2025-12-26 17:11:41,448: [INFO] Model checkpointed at epoch 80
197
+ 2025-12-26 17:11:42,912: [INFO] Epoch [80]: val_loss = 0.47510581
198
+ 2025-12-26 17:13:31,165: [INFO] Epoch [80]: loss = 0.39408069
199
+ 2025-12-26 17:13:32,617: [INFO] Epoch [81]: val_loss = 0.46646976
200
+ 2025-12-26 17:15:20,095: [INFO] Epoch [81]: loss = 0.39456047
201
+ 2025-12-26 17:15:21,517: [INFO] Epoch [82]: val_loss = 0.47777673
202
+ 2025-12-26 17:17:10,031: [INFO] Epoch [82]: loss = 0.39687150
203
+ 2025-12-26 17:17:11,512: [INFO] Epoch [83]: val_loss = 0.47680868
204
+ 2025-12-26 17:18:59,688: [INFO] Epoch [83]: loss = 0.39627865
205
+ 2025-12-26 17:19:01,115: [INFO] Epoch [84]: val_loss = 0.47353493
206
+ 2025-12-26 17:20:48,468: [INFO] Epoch [84]: loss = 0.39516608
207
+ 2025-12-26 17:20:49,912: [INFO] Epoch [85]: val_loss = 0.47541119
208
+ 2025-12-26 17:22:38,068: [INFO] Epoch [85]: loss = 0.39570387
209
+ 2025-12-26 17:22:39,517: [INFO] Epoch [86]: val_loss = 0.46904831
210
+ 2025-12-26 17:24:26,979: [INFO] Epoch [86]: loss = 0.39411988
211
+ 2025-12-26 17:24:28,416: [INFO] Epoch [87]: val_loss = 0.47183328
212
+ 2025-12-26 17:26:16,242: [INFO] Epoch [87]: loss = 0.39453237
213
+ 2025-12-26 17:26:17,712: [INFO] Epoch [88]: val_loss = 0.48088008
214
+ 2025-12-26 17:28:05,674: [INFO] Epoch [88]: loss = 0.39265428
215
+ 2025-12-26 17:28:07,111: [INFO] Epoch [89]: val_loss = 0.46431010
216
+ 2025-12-26 17:29:54,253: [INFO] Epoch [89]: loss = 0.39518696
217
+ 2025-12-26 17:29:55,711: [INFO] Epoch [90]: val_loss = 0.47148239
218
+ 2025-12-26 17:31:43,155: [INFO] Epoch [90]: loss = 0.39560070
219
+ 2025-12-26 17:31:44,711: [INFO] Epoch [91]: val_loss = 0.47001378
220
+ 2025-12-26 17:33:32,048: [INFO] Epoch [91]: loss = 0.39522415
221
+ 2025-12-26 17:33:33,512: [INFO] Epoch [92]: val_loss = 0.47427877
222
+ 2025-12-26 17:35:20,792: [INFO] Epoch [92]: loss = 0.39726472
223
+ 2025-12-26 17:35:22,230: [INFO] Epoch [93]: val_loss = 0.48291658
224
+ 2025-12-26 17:37:09,543: [INFO] Epoch [93]: loss = 0.39664398
225
+ 2025-12-26 17:37:11,012: [INFO] Epoch [94]: val_loss = 0.49081665
226
+ 2025-12-26 17:38:58,458: [INFO] Epoch [94]: loss = 0.39135196
227
+ 2025-12-26 17:39:00,111: [INFO] Epoch [95]: val_loss = 0.47873766
228
+ 2025-12-26 17:40:47,362: [INFO] Epoch [95]: loss = 0.39417184
229
+ 2025-12-26 17:40:48,811: [INFO] Epoch [96]: val_loss = 0.48776019
230
+ 2025-12-26 17:42:36,318: [INFO] Epoch [96]: loss = 0.39321537
231
+ 2025-12-26 17:42:37,812: [INFO] Epoch [97]: val_loss = 0.46243800
232
+ 2025-12-26 17:44:25,656: [INFO] Epoch [97]: loss = 0.39767619
233
+ 2025-12-26 17:44:27,112: [INFO] Epoch [98]: val_loss = 0.45655080
234
+ 2025-12-26 17:46:14,472: [INFO] Epoch [98]: loss = 0.39206413
235
+ 2025-12-26 17:46:16,011: [INFO] Epoch [99]: val_loss = 0.46890352
236
+ 2025-12-26 17:48:03,170: [INFO] Epoch [99]: loss = 0.39380527
237
+ 2025-12-26 17:48:03,171: [INFO] Checkpointing model at epoch 100
238
+ 2025-12-26 17:48:03,725: [INFO] Model checkpointed at epoch 100
239
+ 2025-12-26 17:48:05,212: [INFO] Epoch [100]: val_loss = 0.49273304
240
+ 2025-12-26 17:49:52,333: [INFO] Epoch [100]: loss = 0.39218957
241
+ 2025-12-26 17:49:53,811: [INFO] Epoch [101]: val_loss = 0.47090062
242
+ 2025-12-26 17:51:41,096: [INFO] Epoch [101]: loss = 0.39274643
243
+ 2025-12-26 17:51:42,528: [INFO] Epoch [102]: val_loss = 0.46902628
244
+ 2025-12-26 17:53:29,975: [INFO] Epoch [102]: loss = 0.39481238
245
+ 2025-12-26 17:53:31,411: [INFO] Epoch [103]: val_loss = 0.48112577
246
+ 2025-12-26 17:55:18,420: [INFO] Epoch [103]: loss = 0.39440405
247
+ 2025-12-26 17:55:19,912: [INFO] Epoch [104]: val_loss = 0.49355557
248
+ 2025-12-26 17:57:07,125: [INFO] Epoch [104]: loss = 0.39165780
249
+ 2025-12-26 17:57:08,611: [INFO] Epoch [105]: val_loss = 0.48409717
250
+ 2025-12-26 17:58:56,554: [INFO] Epoch [105]: loss = 0.39554418
251
+ 2025-12-26 17:58:58,011: [INFO] Epoch [106]: val_loss = 0.48656076
252
+ 2025-12-26 18:00:45,347: [INFO] Epoch [106]: loss = 0.39228787
253
+ 2025-12-26 18:00:46,812: [INFO] Epoch [107]: val_loss = 0.48810028
254
+ 2025-12-26 18:02:33,925: [INFO] Epoch [107]: loss = 0.39156697
255
+ 2025-12-26 18:02:35,412: [INFO] Epoch [108]: val_loss = 0.47222325
256
+ 2025-12-26 18:04:23,740: [INFO] Epoch [108]: loss = 0.39423798
257
+ 2025-12-26 18:04:25,212: [INFO] Epoch [109]: val_loss = 0.47254576
258
+ 2025-12-26 18:06:12,535: [INFO] Epoch [109]: loss = 0.39252056
259
+ 2025-12-26 18:06:14,012: [INFO] Epoch [110]: val_loss = 0.48817008
260
+ 2025-12-26 18:08:01,245: [INFO] Epoch [110]: loss = 0.39401426
261
+ 2025-12-26 18:08:02,815: [INFO] Epoch [111]: val_loss = 0.48511344
262
+ 2025-12-26 18:09:50,465: [INFO] Epoch [111]: loss = 0.39587875
263
+ 2025-12-26 18:09:51,912: [INFO] Epoch [112]: val_loss = 0.48533930
264
+ 2025-12-26 18:11:39,151: [INFO] Epoch [112]: loss = 0.39013777
265
+ 2025-12-26 18:11:40,611: [INFO] Epoch [113]: val_loss = 0.48621008
266
+ 2025-12-26 18:13:28,463: [INFO] Epoch [113]: loss = 0.38981328
267
+ 2025-12-26 18:13:29,992: [INFO] Epoch [114]: val_loss = 0.47124515
268
+ 2025-12-26 18:15:17,549: [INFO] Epoch [114]: loss = 0.39461992
269
+ 2025-12-26 18:15:19,012: [INFO] Epoch [115]: val_loss = 0.48522179
270
+ 2025-12-26 18:17:06,234: [INFO] Epoch [115]: loss = 0.39355375
271
+ 2025-12-26 18:17:07,711: [INFO] Epoch [116]: val_loss = 0.49023107
272
+ 2025-12-26 18:18:54,822: [INFO] Epoch [116]: loss = 0.39458753
273
+ 2025-12-26 18:18:56,312: [INFO] Epoch [117]: val_loss = 0.48466966
274
+ 2025-12-26 18:20:43,478: [INFO] Epoch [117]: loss = 0.39362156
275
+ 2025-12-26 18:20:44,916: [INFO] Epoch [118]: val_loss = 0.50641123
276
+ 2025-12-26 18:22:32,316: [INFO] Epoch [118]: loss = 0.39327146
277
+ 2025-12-26 18:22:33,812: [INFO] Epoch [119]: val_loss = 0.47998404
278
+ 2025-12-26 18:24:21,387: [INFO] Epoch [119]: loss = 0.39500003
279
+ 2025-12-26 18:24:21,388: [INFO] Checkpointing model at epoch 120
280
+ 2025-12-26 18:24:21,956: [INFO] Model checkpointed at epoch 120
281
+ 2025-12-26 18:24:23,411: [INFO] Epoch [120]: val_loss = 0.47686255
282
+ 2025-12-26 18:26:11,350: [INFO] Epoch [120]: loss = 0.39359459
283
+ 2025-12-26 18:26:12,812: [INFO] Epoch [121]: val_loss = 0.47847155
284
+ 2025-12-26 18:28:00,158: [INFO] Epoch [121]: loss = 0.39265604
285
+ 2025-12-26 18:28:01,612: [INFO] Epoch [122]: val_loss = 0.48565416
286
+ 2025-12-26 18:29:49,068: [INFO] Epoch [122]: loss = 0.39327284
287
+ 2025-12-26 18:29:50,512: [INFO] Epoch [123]: val_loss = 0.46322163
288
+ 2025-12-26 18:31:37,854: [INFO] Epoch [123]: loss = 0.39371465
289
+ 2025-12-26 18:31:39,322: [INFO] Epoch [124]: val_loss = 0.49423857
290
+ 2025-12-26 18:33:26,668: [INFO] Epoch [124]: loss = 0.39377629
291
+ 2025-12-26 18:33:28,112: [INFO] Epoch [125]: val_loss = 0.50537621
292
+ 2025-12-26 18:35:15,388: [INFO] Epoch [125]: loss = 0.39129652
293
+ 2025-12-26 18:35:16,916: [INFO] Epoch [126]: val_loss = 0.50789308
294
+ 2025-12-26 18:37:05,150: [INFO] Epoch [126]: loss = 0.39120718
295
+ 2025-12-26 18:37:06,611: [INFO] Epoch [127]: val_loss = 0.49176749
296
+ 2025-12-26 18:38:54,793: [INFO] Epoch [127]: loss = 0.39434199
297
+ 2025-12-26 18:38:56,312: [INFO] Epoch [128]: val_loss = 0.48982497
298
+ 2025-12-26 18:40:43,572: [INFO] Epoch [128]: loss = 0.39388789
299
+ 2025-12-26 18:40:45,012: [INFO] Epoch [129]: val_loss = 0.49437147
300
+ 2025-12-26 18:42:32,476: [INFO] Epoch [129]: loss = 0.39485405
301
+ 2025-12-26 18:42:34,012: [INFO] Epoch [130]: val_loss = 0.49246545
302
+ 2025-12-26 18:44:24,072: [INFO] Epoch [130]: loss = 0.39075325
303
+ 2025-12-26 18:44:25,612: [INFO] Epoch [131]: val_loss = 0.51833930
304
+ 2025-12-26 18:46:15,498: [INFO] Epoch [131]: loss = 0.39447027
305
+ 2025-12-26 18:46:16,931: [INFO] Epoch [132]: val_loss = 0.48003947
306
+ 2025-12-26 18:48:05,343: [INFO] Epoch [132]: loss = 0.39434897
307
+ 2025-12-26 18:48:06,812: [INFO] Epoch [133]: val_loss = 0.49718059
308
+ 2025-12-26 18:49:55,181: [INFO] Epoch [133]: loss = 0.39328515
309
+ 2025-12-26 18:49:56,612: [INFO] Epoch [134]: val_loss = 0.48965228
310
+ 2025-12-26 18:51:44,446: [INFO] Epoch [134]: loss = 0.39368132
311
+ 2025-12-26 18:51:45,910: [INFO] Epoch [135]: val_loss = 0.50781692
312
+ 2025-12-26 18:53:34,066: [INFO] Epoch [135]: loss = 0.39521807
313
+ 2025-12-26 18:53:35,517: [INFO] Epoch [136]: val_loss = 0.49129677
314
+ 2025-12-26 18:55:24,158: [INFO] Epoch [136]: loss = 0.39310845
315
+ 2025-12-26 18:55:25,611: [INFO] Epoch [137]: val_loss = 0.50138287
316
+ 2025-12-26 18:57:13,543: [INFO] Epoch [137]: loss = 0.39277331
317
+ 2025-12-26 18:57:15,011: [INFO] Epoch [138]: val_loss = 0.49667891
318
+ 2025-12-26 18:59:02,557: [INFO] Epoch [138]: loss = 0.39367320
319
+ 2025-12-26 18:59:04,012: [INFO] Epoch [139]: val_loss = 0.49262191
320
+ 2025-12-26 19:00:51,448: [INFO] Epoch [139]: loss = 0.39519035
321
+ 2025-12-26 19:00:51,449: [INFO] Checkpointing model at epoch 140
322
+ 2025-12-26 19:00:52,031: [INFO] Model checkpointed at epoch 140
323
+ 2025-12-26 19:00:53,512: [INFO] Epoch [140]: val_loss = 0.50657800
324
+ 2025-12-26 19:02:40,830: [INFO] Epoch [140]: loss = 0.39009643
325
+ 2025-12-26 19:02:42,311: [INFO] Epoch [141]: val_loss = 0.48729330
326
+ 2025-12-26 19:04:29,914: [INFO] Epoch [141]: loss = 0.39552269
327
+ 2025-12-26 19:04:31,412: [INFO] Epoch [142]: val_loss = 0.49246952
328
+ 2025-12-26 19:06:18,642: [INFO] Epoch [142]: loss = 0.39385797
329
+ 2025-12-26 19:06:20,111: [INFO] Epoch [143]: val_loss = 0.49985805
330
+ 2025-12-26 19:08:07,780: [INFO] Epoch [143]: loss = 0.39398983
331
+ 2025-12-26 19:08:09,216: [INFO] Epoch [144]: val_loss = 0.49885565
332
+ 2025-12-26 19:09:56,876: [INFO] Epoch [144]: loss = 0.39509994
333
+ 2025-12-26 19:09:58,312: [INFO] Epoch [145]: val_loss = 0.50998864
334
+ 2025-12-26 19:11:46,490: [INFO] Epoch [145]: loss = 0.39288819
335
+ 2025-12-26 19:11:48,006: [INFO] Epoch [146]: val_loss = 0.53312138
336
+ 2025-12-26 19:12:12,458: [WARNING] [TRAIN] Training interrupted by user. Saving model...
337
+ 2025-12-26 19:12:12,459: [INFO] [TRAIN] Saving model before exiting...
338
+ 2025-12-26 19:12:13,159: [INFO] [TRAIN] Training process finished.
339
+ 2025-12-26 19:24:52,215: [INFO] Logger initialized with writer handler at: /workspace/logs/logfile.log
340
+ 2025-12-26 19:24:52,223: [INFO] TensorBoard logs will be stored in: /workspace/logs/logs
341
+ 2025-12-26 19:24:52,223: [INFO] Model checkpoints will be stored in: /workspace/logs/checkpoints
342
+ 2025-12-26 19:24:52,236: [INFO] TensorBoard running at http://0.0.0.0:6006/ (pid=114689)
343
+ 2025-12-26 19:24:52,246: [INFO] Initializer set up seed: 1766777092
344
+ 2025-12-26 19:24:52,250: [INFO] PyTorch is now configured to use GPU 0: NVIDIA A40
345
+ 2025-12-26 19:24:52,250: [INFO] [GPU 0 - NVIDIA A40] Memory Stats:
346
+ 2025-12-26 19:24:52,251: [INFO] Total Memory : 45498.00 MB
347
+ 2025-12-26 19:24:52,251: [INFO] Currently Allocated : 0.00 MB
348
+ 2025-12-26 19:24:52,251: [INFO] Currently Reserved : 0.00 MB
349
+ 2025-12-26 19:24:52,252: [INFO] Max Allocated : 0.00 MB
350
+ 2025-12-26 19:24:52,252: [INFO] Max Reserved : 0.00 MB
351
+ 2025-12-26 19:24:52,252: [INFO] Setup information:
352
+ 2025-12-26 19:24:52,253: [INFO] - Setup path: /workspace/logs
353
+ 2025-12-26 19:24:52,253: [INFO] - Setup checkpoints path: /workspace/logs/checkpoints
354
+ 2025-12-26 19:24:52,253: [INFO] - Setup device: cuda:0
355
+ 2025-12-26 19:24:52,254: [INFO] - Setup seed: 1766777092
356
+ 2025-12-26 19:24:52,254: [INFO] - Setup logger: <Logger src.dlutils.setup.logger (INFO)>
357
+ 2025-12-26 19:24:52,254: [INFO] - Setup writer: <torch.utils.tensorboard.writer.SummaryWriter object at 0x742242ae8a10>
358
+ 2025-12-26 19:24:52,254: [INFO] - Setup save each: 20
359
+ 2025-12-26 19:24:52,258: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A000-segmentation
360
+ 2025-12-26 19:24:52,259: [INFO] [SegmentationDataset] Loaded dataset length: 26510
361
+ 2025-12-26 19:24:52,282: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A001-segmentation
362
+ 2025-12-26 19:24:52,283: [INFO] [SegmentationDataset] Loaded dataset length: 3336
363
+ 2025-12-26 19:24:52,746: [INFO] [TRAIN] Reloading model, optimizer and scheduler states...
364
+ 2025-12-26 19:24:52,988: [INFO] Model reloaded from /workspace/logs/checkpoints/model_epoch_40.pt at epoch 40 and seed 1766760356
365
+ 2025-12-26 19:24:52,989: [INFO] Optimizer state_dict loaded from /workspace/logs/checkpoints/model_epoch_40.pt
366
+ 2025-12-26 19:24:52,989: [INFO] Scheduler state_dict loaded from /workspace/logs/checkpoints/model_epoch_40.pt
367
+ 2025-12-26 19:24:52,990: [INFO] [TRAIN] Model Configuration:
368
+ {'vocab_size': 32768, 'model_dim': 256, 'max_tokens': 382, 'max_sentences': 384, 'valid_padding': True, 'cosenet': CoSeNetConfig(trainable=True, init_scale=5.0), 'transformers': [TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True), TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True)]}
369
+ 2025-12-26 19:24:52,991: [INFO] [TRAIN] Model parameters: 11.022865 M
370
+ 2025-12-26 19:24:52,991: [INFO] [TRAIN] Trainable parameters: 11.022865 M
371
+ 2025-12-26 19:24:52,992: [INFO] [TRAIN] Training batches: 2651
372
+ 2025-12-26 20:38:24,396: [INFO] Epoch [40]: loss = 0.44221879
373
+ 2025-12-26 20:39:27,131: [INFO] Epoch [41]: val_loss = 0.43405354
374
+ 2025-12-26 21:50:17,867: [INFO] Epoch [41]: loss = 0.42664148
375
+ 2025-12-26 21:51:15,331: [INFO] Epoch [42]: val_loss = 0.42706228
376
+ 2025-12-26 23:02:34,039: [INFO] Epoch [42]: loss = 0.41818042
377
+ 2025-12-26 23:03:32,932: [INFO] Epoch [43]: val_loss = 0.42457028
378
+ 2025-12-27 00:15:19,536: [INFO] Epoch [43]: loss = 0.41249070
379
+ 2025-12-27 00:16:17,232: [INFO] Epoch [44]: val_loss = 0.42501960
380
+ 2025-12-27 01:27:24,113: [INFO] Epoch [44]: loss = 0.40819168
381
+ 2025-12-27 01:28:21,631: [INFO] Epoch [45]: val_loss = 0.42466300
382
+ 2025-12-27 02:39:08,276: [INFO] Epoch [45]: loss = 0.40515616
383
+ 2025-12-27 02:40:05,744: [INFO] Epoch [46]: val_loss = 0.42623002
384
+ 2025-12-27 03:50:51,982: [INFO] Epoch [46]: loss = 0.40199136
385
+ 2025-12-27 03:51:49,542: [INFO] Epoch [47]: val_loss = 0.42870347
386
+ 2025-12-27 05:02:35,984: [INFO] Epoch [47]: loss = 0.40118613
387
+ 2025-12-27 05:03:33,643: [INFO] Epoch [48]: val_loss = 0.42845377
388
+ 2025-12-27 06:14:22,488: [INFO] Epoch [48]: loss = 0.40013945
389
+ 2025-12-27 06:15:19,931: [INFO] Epoch [49]: val_loss = 0.43143284
390
+ 2025-12-27 07:26:05,041: [INFO] Epoch [49]: loss = 0.39879406
391
+ 2025-12-27 07:27:02,632: [INFO] Epoch [50]: val_loss = 0.43078374
392
+ 2025-12-27 08:37:47,083: [INFO] Epoch [50]: loss = 0.39769130
393
+ 2025-12-27 08:38:44,631: [INFO] Epoch [51]: val_loss = 0.43248296
394
+ 2025-12-27 09:49:30,251: [INFO] Epoch [51]: loss = 0.39770448
395
+ 2025-12-27 09:50:27,831: [INFO] Epoch [52]: val_loss = 0.43522716
396
+ 2025-12-27 11:01:17,000: [INFO] Epoch [52]: loss = 0.39691210
397
+ 2025-12-27 11:02:14,669: [INFO] Epoch [53]: val_loss = 0.43381873
398
+ 2025-12-27 12:13:00,032: [INFO] Epoch [53]: loss = 0.39669390
399
+ 2025-12-27 12:13:57,331: [INFO] Epoch [54]: val_loss = 0.43624624
400
+ 2025-12-27 13:24:41,611: [INFO] Epoch [54]: loss = 0.39618159
401
+ 2025-12-27 13:25:38,932: [INFO] Epoch [55]: val_loss = 0.43923773
402
+ 2025-12-27 14:36:23,346: [INFO] Epoch [55]: loss = 0.39546988
403
+ 2025-12-27 14:37:20,669: [INFO] Epoch [56]: val_loss = 0.44194769
404
+ 2025-12-27 15:48:09,320: [INFO] Epoch [56]: loss = 0.39544456
405
+ 2025-12-27 15:49:06,843: [INFO] Epoch [57]: val_loss = 0.43803932
406
+ 2025-12-27 16:59:54,398: [INFO] Epoch [57]: loss = 0.39482789
407
+ 2025-12-27 17:00:52,031: [INFO] Epoch [58]: val_loss = 0.43835160
408
+ 2025-12-27 18:11:44,408: [INFO] Epoch [58]: loss = 0.39461407
409
+ 2025-12-27 18:12:41,958: [INFO] Epoch [59]: val_loss = 0.44492985
410
+ 2025-12-27 19:23:29,892: [INFO] Epoch [59]: loss = 0.39447898
411
+ 2025-12-27 19:23:29,894: [INFO] Checkpointing model at epoch 60
412
+ 2025-12-27 19:23:30,517: [INFO] Model checkpointed at epoch 60
413
+ 2025-12-27 19:24:27,831: [INFO] Epoch [60]: val_loss = 0.44300878
414
+ 2025-12-27 19:27:50,666: [WARNING] [TRAIN] Training interrupted by user. Saving model...
415
+ 2025-12-27 19:27:50,668: [INFO] [TRAIN] Saving model before exiting...
416
+ 2025-12-27 19:27:51,180: [INFO] [TRAIN] Training process finished.
417
+ 2025-12-27 19:33:34,471: [INFO] Logger initialized with writer handler at: /workspace/logs/logfile.log
418
+ 2025-12-27 19:33:34,478: [INFO] TensorBoard logs will be stored in: /workspace/logs/logs
419
+ 2025-12-27 19:33:34,479: [INFO] Model checkpoints will be stored in: /workspace/logs/checkpoints
420
+ 2025-12-27 19:33:34,491: [INFO] TensorBoard running at http://0.0.0.0:6006/ (pid=235187)
421
+ 2025-12-27 19:33:34,498: [INFO] Initializer set up seed: 1766864014
422
+ 2025-12-27 19:33:34,501: [INFO] PyTorch is now configured to use GPU 0: NVIDIA A40
423
+ 2025-12-27 19:33:34,507: [INFO] [GPU 0 - NVIDIA A40] Memory Stats:
424
+ 2025-12-27 19:33:34,507: [INFO] Total Memory : 45498.00 MB
425
+ 2025-12-27 19:33:34,507: [INFO] Currently Allocated : 0.00 MB
426
+ 2025-12-27 19:33:34,507: [INFO] Currently Reserved : 0.00 MB
427
+ 2025-12-27 19:33:34,508: [INFO] Max Allocated : 0.00 MB
428
+ 2025-12-27 19:33:34,508: [INFO] Max Reserved : 0.00 MB
429
+ 2025-12-27 19:33:34,508: [INFO] Setup information:
430
+ 2025-12-27 19:33:34,509: [INFO] - Setup path: /workspace/logs
431
+ 2025-12-27 19:33:34,509: [INFO] - Setup checkpoints path: /workspace/logs/checkpoints
432
+ 2025-12-27 19:33:34,509: [INFO] - Setup device: cuda:0
433
+ 2025-12-27 19:33:34,510: [INFO] - Setup seed: 1766864014
434
+ 2025-12-27 19:33:34,510: [INFO] - Setup logger: <Logger src.dlutils.setup.logger (INFO)>
435
+ 2025-12-27 19:33:34,510: [INFO] - Setup writer: <torch.utils.tensorboard.writer.SummaryWriter object at 0x74fc724f4f50>
436
+ 2025-12-27 19:33:34,511: [INFO] - Setup save each: 1
437
+ 2025-12-27 19:33:34,515: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A000-segmentation
438
+ 2025-12-27 19:33:34,515: [INFO] [SegmentationDataset] Loaded dataset length: 26510
439
+ 2025-12-27 19:33:35,262: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A001-segmentation
440
+ 2025-12-27 19:33:35,262: [INFO] [SegmentationDataset] Loaded dataset length: 3336
441
+ 2025-12-27 19:33:35,796: [INFO] [TRAIN] Reloading model, optimizer and scheduler states...
442
+ 2025-12-27 19:33:35,926: [INFO] Model reloaded from /workspace/logs/checkpoints/model_epoch_60.pt at epoch 60 and seed 1766760356
443
+ 2025-12-27 19:33:35,927: [INFO] Optimizer state_dict loaded from /workspace/logs/checkpoints/model_epoch_60.pt
444
+ 2025-12-27 19:33:35,927: [INFO] Scheduler state_dict loaded from /workspace/logs/checkpoints/model_epoch_60.pt
445
+ 2025-12-27 19:33:35,927: [INFO] [TRAIN] Model Configuration:
446
+ {'vocab_size': 32768, 'model_dim': 256, 'max_tokens': 382, 'max_sentences': 384, 'valid_padding': True, 'cosenet': CoSeNetConfig(trainable=True, init_scale=5.0), 'transformers': [TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True), TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True)]}
447
+ 2025-12-27 19:33:35,928: [INFO] [TRAIN] Model parameters: 11.022865 M
448
+ 2025-12-27 19:33:35,928: [INFO] [TRAIN] Trainable parameters: 11.022865 M
449
+ 2025-12-27 19:33:35,928: [INFO] [TRAIN] Training batches: 6628
450
+ 2025-12-27 22:37:53,178: [INFO] Epoch [60]: loss = 0.41291385
451
+ 2025-12-27 22:37:53,181: [INFO] Checkpointing model at epoch 61
452
+ 2025-12-27 22:37:53,822: [INFO] Model checkpointed at epoch 61
453
+ 2025-12-27 22:40:30,975: [INFO] Epoch [61]: val_loss = 0.42101070
454
+ 2025-12-28 01:37:44,628: [INFO] Epoch [61]: loss = 0.40628406
455
+ 2025-12-28 01:37:44,629: [INFO] Checkpointing model at epoch 62
456
+ 2025-12-28 01:37:45,260: [INFO] Model checkpointed at epoch 62
457
+ 2025-12-28 01:40:08,586: [INFO] Epoch [62]: val_loss = 0.41949525
458
+ 2025-12-28 04:37:33,182: [INFO] Epoch [62]: loss = 0.40297545
459
+ 2025-12-28 04:37:33,184: [INFO] Checkpointing model at epoch 63
460
+ 2025-12-28 04:37:33,865: [INFO] Model checkpointed at epoch 63
461
+ 2025-12-28 04:39:58,256: [INFO] Epoch [63]: val_loss = 0.42107245
462
+ 2025-12-28 07:37:53,163: [INFO] Epoch [63]: loss = 0.40075299
463
+ 2025-12-28 07:37:53,165: [INFO] Checkpointing model at epoch 64
464
+ 2025-12-28 07:37:53,812: [INFO] Model checkpointed at epoch 64
465
+ 2025-12-28 07:40:18,271: [INFO] Epoch [64]: val_loss = 0.42276877
466
+ 2025-12-28 10:37:45,887: [INFO] Epoch [64]: loss = 0.39892435
467
+ 2025-12-28 10:37:45,888: [INFO] Checkpointing model at epoch 65
468
+ 2025-12-28 10:37:46,521: [INFO] Model checkpointed at epoch 65
469
+ 2025-12-28 10:40:11,142: [INFO] Epoch [65]: val_loss = 0.42484788
470
+ 2025-12-28 13:37:21,540: [INFO] Epoch [65]: loss = 0.39751294
471
+ 2025-12-28 13:37:21,541: [INFO] Checkpointing model at epoch 66
472
+ 2025-12-28 13:37:22,128: [INFO] Model checkpointed at epoch 66
473
+ 2025-12-28 13:39:45,583: [INFO] Epoch [66]: val_loss = 0.42598702
474
+ 2025-12-28 16:36:57,870: [INFO] Epoch [66]: loss = 0.39654398
475
+ 2025-12-28 16:36:57,872: [INFO] Checkpointing model at epoch 67
476
+ 2025-12-28 16:36:58,476: [INFO] Model checkpointed at epoch 67
477
+ 2025-12-28 16:39:23,196: [INFO] Epoch [67]: val_loss = 0.42759763
478
+ 2025-12-28 19:37:48,749: [INFO] Epoch [67]: loss = 0.39627669
479
+ 2025-12-28 19:37:48,752: [INFO] Checkpointing model at epoch 68
480
+ 2025-12-28 19:37:49,475: [INFO] Model checkpointed at epoch 68
481
+ 2025-12-28 19:40:15,986: [INFO] Epoch [68]: val_loss = 0.42856705
482
+ 2025-12-28 22:39:07,980: [INFO] Epoch [68]: loss = 0.39574215
483
+ 2025-12-28 22:39:07,982: [INFO] Checkpointing model at epoch 69
484
+ 2025-12-28 22:39:08,617: [INFO] Model checkpointed at epoch 69
485
+ 2025-12-28 22:41:33,142: [INFO] Epoch [69]: val_loss = 0.43066087
486
+ 2025-12-29 01:39:35,127: [INFO] Epoch [69]: loss = 0.39522947
487
+ 2025-12-29 01:39:35,129: [INFO] Checkpointing model at epoch 70
488
+ 2025-12-29 01:39:35,774: [INFO] Model checkpointed at epoch 70
489
+ 2025-12-29 01:42:00,471: [INFO] Epoch [70]: val_loss = 0.43032571
train/train_model.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import torch
9
+ import tqdm
10
+ from train.config import configuration, TrainConfig
11
+ from src.model import SegmentationNetwork, MaskedBCELoss
12
+ from src.dataset import TokenizedSegmentationDataset
13
+ from src.dlutils import Setup, train_step, validation_step
14
+
15
+
16
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
17
+ # #
18
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
19
+ def train(controller: Setup, config: TrainConfig):
20
+ """
21
+ Main training function
22
+ :param controller: A training controller
23
+ :param config: The experiment configuration
24
+ :return: None
25
+ """
26
+
27
+ # 1. Train and val datasets:
28
+ train_dataset = TokenizedSegmentationDataset(
29
+ tokenized_dataset=config.dataset_config.train_data_path,
30
+ logger=controller.logger,
31
+ percentage=config.dataset_config.train_percentage,
32
+ return_type=tuple
33
+ ).get_loader(
34
+ config.batch_size,
35
+ shuffle=config.dataset_config.shuffle_train,
36
+ num_workers=config.dataset_config.num_workers
37
+ )
38
+ val_dataset = TokenizedSegmentationDataset(
39
+ tokenized_dataset=config.dataset_config.val_data_path,
40
+ logger=controller.logger,
41
+ percentage=config.dataset_config.val_percentage,
42
+ return_type=tuple
43
+ ).get_loader(
44
+ config.batch_size,
45
+ shuffle=config.dataset_config.shuffle_val,
46
+ num_workers=config.dataset_config.num_workers
47
+ )
48
+
49
+ # 2. Model, loss, optimizer:
50
+ model = SegmentationNetwork(config.model_config).to(controller.device)
51
+ loss_fn = MaskedBCELoss(valid_pad=config.model_config.valid_padding)
52
+ optimizer = torch.optim.AdamW(
53
+ params=model.parameters(),
54
+ lr=config.learning_rate,
55
+ weight_decay=config.weight_decay,
56
+ betas=config.betas
57
+ )
58
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
59
+ optimizer=optimizer,
60
+ T_max=config.num_epochs,
61
+ eta_min=config.learning_rate_min
62
+ )
63
+
64
+ # 3. Reload checkpoint if needed:
65
+ if config.setup_config.reload_checkpoint:
66
+ controller.logger.info("[TRAIN] Reloading model, optimizer and scheduler states...")
67
+ controller.reload(model, optimizer, lr_scheduler)
68
+
69
+ # 4. Log info:
70
+ controller.logger.info(f"[TRAIN] Model Configuration:\n{config.model_config.__dict__}")
71
+ controller.logger.info(f"[TRAIN] Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6} M")
72
+ controller.logger.info(f"[TRAIN] Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6} M")
73
+ controller.logger.info(f"[TRAIN] Training batches: {len(train_dataset)}")
74
+ controller.save_config(config)
75
+
76
+ # 5. Set watchers:
77
+ controller.set_watcher('A')
78
+ controller.set_watcher('transformer')
79
+
80
+ # 6. Train loop:
81
+ try:
82
+ for _ in tqdm.tqdm(range(controller.epoch, config.num_epochs), desc="Epochs", unit="epoch"):
83
+ # Train step:
84
+ train_step(
85
+ model=model,
86
+ data=train_dataset,
87
+ loss=loss_fn,
88
+ optimizer=optimizer,
89
+ controller=controller,
90
+ scheduler=lr_scheduler
91
+ )
92
+
93
+ validation_step(
94
+ model=model,
95
+ data=val_dataset,
96
+ loss=loss_fn,
97
+ controller=controller
98
+ )
99
+ except KeyboardInterrupt:
100
+ controller.logger.warning("[TRAIN] Training interrupted by user. Saving model...")
101
+ except Exception as e:
102
+ controller.logger.error(f"[TRAIN] An error has occurred during training: {e}")
103
+ raise e
104
+ finally:
105
+ # 7. End of training:
106
+ controller.logger.info("[TRAIN] Saving model before exiting...")
107
+ controller.save_model(model, optimizer, lr_scheduler)
108
+ controller.logger.info("[TRAIN] Training process finished.")
109
+ input("[TRAIN] Training finished. Press any key to exit...")
110
+
111
+
112
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
113
+ # #
114
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
115
+ if __name__ == "__main__":
116
+ conf = configuration()
117
+ with Setup(
118
+ path=conf.setup_config.logging_path,
119
+ device=conf.setup_config.device_number,
120
+ seed=conf.setup_config.seed,
121
+ save_each=conf.setup_config.save_model_each,
122
+ reload_state=conf.setup_config.reload_checkpoint,
123
+ replay_element=(0, None)
124
+ ) as setup:
125
+ train(setup, conf)
126
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
127
+ # END OF FILE #
128
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #