Chiquitin
commited on
Commit
·
482fd8d
1
Parent(s):
3cca845
upload source code and train configurations
Browse files- requirements.txt +10 -0
- src/dataset/__init__.py +13 -0
- src/dataset/config.py +29 -0
- src/dataset/dataset.py +199 -0
- src/dataset/tokenized_dataset.py +217 -0
- src/dataset/tokenizer.py +240 -0
- src/dlutils/__init__.py +11 -0
- src/dlutils/setup/__init__.py +12 -0
- src/dlutils/setup/clear.py +33 -0
- src/dlutils/setup/device.py +62 -0
- src/dlutils/setup/full_setup.py +352 -0
- src/dlutils/setup/functions.py +227 -0
- src/dlutils/setup/hooks.py +259 -0
- src/dlutils/setup/logger.py +91 -0
- src/dlutils/setup/marker.py +251 -0
- src/dlutils/setup/seeds.py +71 -0
- src/dlutils/setup/tensorboard.py +74 -0
- src/dlutils/setup/watchers.py +153 -0
- src/dlutils/steps.py +246 -0
- src/model/__init__.py +13 -0
- src/model/config.py +37 -0
- src/model/cosenet/__init__.py +12 -0
- src/model/cosenet/cosenet.py +189 -0
- src/model/cosenet/cosenet_layer.py +55 -0
- src/model/cosenet/cosine_distance.py +57 -0
- src/model/cosenet/trainable_sigmoid.py +53 -0
- src/model/loss.py +115 -0
- src/model/segmentation.py +147 -0
- src/model/transformers/__init__.py +13 -0
- src/model/transformers/attention.py +176 -0
- src/model/transformers/pooling.py +78 -0
- src/model/transformers/positional_encoding.py +62 -0
- train/config.py +193 -0
- train/train_logs/config.json +54 -0
- train/train_logs/logfile.log +489 -0
- train/train_model.py +128 -0
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==2.3.5
|
| 2 |
+
torch==2.5.1+cu121
|
| 3 |
+
torchaudio==2.5.1+cu121
|
| 4 |
+
torchvision==0.20.1+cu121
|
| 5 |
+
tensorboard==2.20.0
|
| 6 |
+
matplotlib==3.10.7
|
| 7 |
+
datasets==4.4.1
|
| 8 |
+
psutil==7.1.3
|
| 9 |
+
spacy==3.8.11
|
| 10 |
+
tqdm==4.67.1
|
src/dataset/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
from .tokenizer import SegmentationTokenizer, SentenceSegmenter
|
| 8 |
+
from .dataset import SegmentationDataset
|
| 9 |
+
from .tokenized_dataset import TokenizedSegmentationDataset
|
| 10 |
+
from .config import DatasetConfig
|
| 11 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 12 |
+
# END OF FILE #
|
| 13 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dataset/config.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class DatasetConfig:
|
| 13 |
+
# Paths:
|
| 14 |
+
train_data_path: str = None
|
| 15 |
+
val_data_path: str = None
|
| 16 |
+
test_data_path: str = None
|
| 17 |
+
# Percentages:
|
| 18 |
+
train_percentage: float = 1.0
|
| 19 |
+
val_percentage: float = 1.0
|
| 20 |
+
test_percentage: float = 1.0
|
| 21 |
+
# Other parameters:
|
| 22 |
+
num_workers: int = 0
|
| 23 |
+
shuffle_train: bool = True
|
| 24 |
+
shuffle_val: bool = True
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 28 |
+
# END OF FILE #
|
| 29 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dataset/dataset.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import logging
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader
|
| 10 |
+
from datasets import Dataset as HfDataset
|
| 11 |
+
from datasets import load_from_disk
|
| 12 |
+
from .tokenizer import SegmentationTokenizer, SentenceSegmenter
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 16 |
+
# #
|
| 17 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 18 |
+
class SegmentationDataset(Dataset):
    def __init__(
            self,
            huggingface_dataset: str | HfDataset,
            tokenizer: SegmentationTokenizer,
            segmenter: SentenceSegmenter,
            logger: logging.Logger = None,
            percentage: float = 1.0,
            return_type: type = dict
    ):
        """
        A segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
        wikipedia-segmentation format. It loads the dataset and prepares it for training.

        Wikipedia-segmentation format:
            - The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
            - The dataset should contain the following fields:
            >>> sample = {
            >>>     'text': ['Article 1', 'Article 2', ...],
            >>>     'titles': ['Title 1', 'Title 2', ...],
            >>>     'id': str,
            >>>     'words': int,
            >>>     'paragraphs': int,
            >>>     'sentences': int
            >>> }
            - The dataset should be a list of dictionaries, where each dictionary contains the fields above.

        Parameters
        ----------
        huggingface_dataset : str | HfDataset
            A huggingface dataset or a path to a dataset on disk with the wikipedia-segmentation format.

        tokenizer : callable
            A tokenizer callable that takes the list of sentences and returns a dict with
            'input_ids' and 'attention_mask'.

        segmenter : SentenceSegmenter
            Sentence segmenter that splits each article into sentences and boundary labels.

        logger : logging.Logger, optional
            Logger instance. If not provided, a null logger will be used.

        percentage : float
            Percentage of the dataset to use. Default is 1.0 (100%).

        return_type : type
            The return type of __getitem__, either dict or tuple. Default is dict.

        Raises
        ------
        ValueError
            If huggingface_dataset is not a string or HfDataset, tokenizer is not callable,
            segmenter is not a SentenceSegmenter, percentage is outside (0, 1], or
            return_type is not dict or tuple.
        """
        # Fall back to a no-op logger so later logging calls are always safe:
        if not isinstance(logger, logging.Logger):
            self.logger = logging.getLogger("null")
            self.logger.addHandler(logging.NullHandler())
        else:
            self.logger = logger

        # Accept an in-memory HfDataset or a path on disk:
        if isinstance(huggingface_dataset, HfDataset):
            self.huggingface_dataset = huggingface_dataset
        elif isinstance(huggingface_dataset, str):
            self.huggingface_dataset = load_from_disk(huggingface_dataset)
        else:
            self.logger.error('[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
            raise ValueError('[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
        self.logger.info(f'[SegmentationDataset] Loaded dataset: {self.huggingface_dataset}')
        self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.huggingface_dataset.num_rows}')

        # Tokenizer must be callable:
        if callable(tokenizer):
            self.tokenizer = tokenizer
        else:
            self.logger.error('[SegmentationDataset] Tokenizer must be a callable function.')
            raise ValueError('[SegmentationDataset] Tokenizer must be a callable function.')

        # Segmenter:
        if not isinstance(segmenter, SentenceSegmenter):
            self.logger.error('[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
            raise ValueError('[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
        self.segmenter = segmenter

        # Percentage must be in (0, 1]:
        if not (0.0 < percentage <= 1.0):
            self.logger.error('[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
            raise ValueError('[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
        self.percentage = percentage

        # Return type must be the dict or tuple type itself:
        if not isinstance(return_type, type):
            self.logger.error('[SegmentationDataset] return_type must be a type.')
            raise ValueError('[SegmentationDataset] return_type must be a type.')
        if return_type not in (dict, tuple):
            self.logger.error('[SegmentationDataset] return_type must be either dict or tuple.')
            raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
        self.return_type = return_type

    def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
        """
        Returns a PyTorch DataLoader for this dataset.

        Parameters
        ----------
        batch_size : int
            Number of samples per batch.
        shuffle : bool
            Whether to shuffle the dataset.
        num_workers : int
            Number of worker processes.
        **kwargs
            Additional arguments for DataLoader.

        Returns
        -------
        torch.utils.data.DataLoader
            Configured DataLoader.
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                          pin_memory=True, **kwargs)

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.

        Returns
        -------
        int
            Total number of samples, scaled by the configured percentage.
        """
        return int(self.huggingface_dataset.num_rows * self.percentage)

    def __getitem__(self, idx) -> dict | tuple:
        """
        Retrieves a single sample and generates segmentation labels.

        Parameters
        ----------
        idx : int
            Index of the sample.

        Returns
        -------
        dict | tuple
            Input token ids and attention mask plus sentence boundary, mask
            and candidate arrays, as a dict or a 5-tuple depending on
            ``return_type``.
        """
        sample = self.huggingface_dataset[idx]['text']
        sentences = self.segmenter(sample)
        tokenized = self.tokenizer(sentences['sentences'])

        if self.return_type == tuple:
            return (
                tokenized['input_ids'],             # x
                sentences['sentence_boundaries'],   # y
                tokenized['attention_mask'],        # x_mask
                sentences['sentence_mask'],         # y_mask
                sentences['sentence_candidates'],   # y_prime_mask
            )
        if self.return_type == dict:
            return {
                'input': tokenized['input_ids'],
                'input_mask': tokenized['attention_mask'],
                'labels': sentences['sentence_boundaries'],
                'output_mask': sentences['sentence_mask'],
                'candidate_mask': sentences['sentence_candidates']
            }
        # Unreachable in practice (return_type is validated in __init__); kept as a guard:
        raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 198 |
+
# END OF FILE #
|
| 199 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dataset/tokenized_dataset.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import logging
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
from torch.utils.data import Dataset, DataLoader
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 16 |
+
# #
|
| 17 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 18 |
+
class TokenizedSegmentationDataset(Dataset):
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
tokenized_dataset: str,
|
| 22 |
+
logger: logging.Logger = None,
|
| 23 |
+
percentage: float = 1.0,
|
| 24 |
+
return_type: type = dict
|
| 25 |
+
):
|
| 26 |
+
"""
|
| 27 |
+
A tokoenized segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
|
| 28 |
+
wikipedia-segmentation format. It loads the dataset and prepares it for training.
|
| 29 |
+
|
| 30 |
+
Wikipedia-segmentation format:
|
| 31 |
+
- The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
|
| 32 |
+
- The dataset should contain the following fields:
|
| 33 |
+
>>> sample = {
|
| 34 |
+
>>> 'text': ['Article 1', 'Article 2', ...],
|
| 35 |
+
>>> 'titles': ['Title 1', 'Title 2', ...],
|
| 36 |
+
>>> 'id': str,
|
| 37 |
+
>>> 'words': int
|
| 38 |
+
>>> 'paragraphs': int
|
| 39 |
+
>>> 'sentences': int
|
| 40 |
+
>>> }
|
| 41 |
+
- The dataset should be a list of dictionaries, where each dictionary contains the fields above.
|
| 42 |
+
|
| 43 |
+
Parameters
|
| 44 |
+
----------
|
| 45 |
+
tokenized_dataset : str
|
| 46 |
+
A path to a tokenized dataset on disk with the wikipedia-segmentation format.
|
| 47 |
+
|
| 48 |
+
logger : logging.Logger, optional
|
| 49 |
+
Logger instance. If not provided, a null logger will be used.
|
| 50 |
+
|
| 51 |
+
percentage : float
|
| 52 |
+
Percentage of the dataset to use. Default is 1.0 (100%).
|
| 53 |
+
|
| 54 |
+
return_type : type
|
| 55 |
+
The return type of __getitem__, either dict or tuple. Default is dict.
|
| 56 |
+
|
| 57 |
+
Raises
|
| 58 |
+
------
|
| 59 |
+
ValueError
|
| 60 |
+
If the huggingface_dataset is not a string or a HfDataset.
|
| 61 |
+
ValueError
|
| 62 |
+
If the tokenizer is not a callable function or class.
|
| 63 |
+
ValueError
|
| 64 |
+
If the sentence_tokenizer is not a callable function or class.
|
| 65 |
+
ValueError
|
| 66 |
+
If the dtype is not a type.
|
| 67 |
+
|
| 68 |
+
"""
|
| 69 |
+
# Null logging:
|
| 70 |
+
if not isinstance(logger, logging.Logger):
|
| 71 |
+
self.logger = logging.getLogger("null")
|
| 72 |
+
self.logger.addHandler(logging.NullHandler())
|
| 73 |
+
else:
|
| 74 |
+
self.logger = logger
|
| 75 |
+
|
| 76 |
+
# Loading:
|
| 77 |
+
if isinstance(tokenized_dataset, str):
|
| 78 |
+
self.metadata_path = os.path.join(tokenized_dataset, 'info.json')
|
| 79 |
+
if not os.path.exists(self.metadata_path):
|
| 80 |
+
self.logger.error(f'[SegmentationDataset] Dataset metadata file not found at {self.metadata_path}.')
|
| 81 |
+
raise FileNotFoundError(f'[SegmentationDataset] Dataset metadata file not found at {self.metadata_path}.')
|
| 82 |
+
else:
|
| 83 |
+
with open(self.metadata_path, 'r', encoding='utf-8') as f:
|
| 84 |
+
self.metadata = json.load(f)
|
| 85 |
+
if 'fingerprint' not in self.metadata or not self.metadata['fingerprint']:
|
| 86 |
+
raise ValueError(f'[SegmentationDataset] Dataset metadata file is missing fingerprint information.')
|
| 87 |
+
else:
|
| 88 |
+
self.logger.error(f'[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
|
| 89 |
+
raise ValueError(f'[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
|
| 90 |
+
self.logger.info(f'[SegmentationDataset] Loaded dataset: {tokenized_dataset}')
|
| 91 |
+
self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.metadata["samples"]}')
|
| 92 |
+
|
| 93 |
+
# Percentage:
|
| 94 |
+
if not (0.0 < percentage <= 1.0):
|
| 95 |
+
self.logger.error(f'[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
|
| 96 |
+
raise ValueError(f'[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
|
| 97 |
+
else:
|
| 98 |
+
self.percentage = percentage
|
| 99 |
+
|
| 100 |
+
# Return type:
|
| 101 |
+
if not isinstance(return_type, type):
|
| 102 |
+
self.logger.error(f'[SegmentationDataset] return_type must be a type.')
|
| 103 |
+
raise ValueError(f'[SegmentationDataset] return_type must be a type.')
|
| 104 |
+
elif return_type not in [dict, tuple]:
|
| 105 |
+
self.logger.error(f'[SegmentationDataset] return_type must be either dict or tuple.')
|
| 106 |
+
raise ValueError(f'[SegmentationDataset] return_type must be either dict or tuple.')
|
| 107 |
+
else:
|
| 108 |
+
self.return_type = return_type
|
| 109 |
+
|
| 110 |
+
self.metadata['max_sentences'] = self.metadata['x']['element_shape'][0]
|
| 111 |
+
self.metadata['max_tokens'] = self.metadata['x']['element_shape'][1]
|
| 112 |
+
|
| 113 |
+
# Build maps:
|
| 114 |
+
read_mode = 'r'
|
| 115 |
+
self.x_map = np.memmap(
|
| 116 |
+
os.path.join(tokenized_dataset, self.metadata['x']['name'] + self.metadata['x']['extension']),
|
| 117 |
+
dtype=self.metadata['x']['dtype'],
|
| 118 |
+
mode=read_mode,
|
| 119 |
+
shape=(self.metadata['x']['samples'], *self.metadata['x']['element_shape'])
|
| 120 |
+
)
|
| 121 |
+
self.y_map = np.memmap(
|
| 122 |
+
os.path.join(tokenized_dataset, self.metadata['y']['name'] + self.metadata['y']['extension']),
|
| 123 |
+
dtype=self.metadata['y']['dtype'],
|
| 124 |
+
mode=read_mode,
|
| 125 |
+
shape=(self.metadata['y']['samples'], *self.metadata['y']['element_shape'])
|
| 126 |
+
)
|
| 127 |
+
self.x_mask_map = np.memmap(
|
| 128 |
+
os.path.join(tokenized_dataset, self.metadata['x_mask']['name'] + self.metadata['x_mask']['extension']),
|
| 129 |
+
dtype=self.metadata['x_mask']['dtype'],
|
| 130 |
+
mode=read_mode,
|
| 131 |
+
shape=(self.metadata['x_mask']['samples'], *self.metadata['x_mask']['element_shape'])
|
| 132 |
+
)
|
| 133 |
+
self.y_mask_map = np.memmap(
|
| 134 |
+
os.path.join(tokenized_dataset, self.metadata['y_mask']['name'] + self.metadata['y_mask']['extension']),
|
| 135 |
+
dtype=self.metadata['y_mask']['dtype'],
|
| 136 |
+
mode=read_mode,
|
| 137 |
+
shape=(self.metadata['y_mask']['samples'], *self.metadata['y_mask']['element_shape'])
|
| 138 |
+
)
|
| 139 |
+
self.y_cand_map = np.memmap(
|
| 140 |
+
os.path.join(tokenized_dataset, self.metadata['y_cand']['name'] + self.metadata['y_cand']['extension']),
|
| 141 |
+
dtype=self.metadata['y_cand']['dtype'],
|
| 142 |
+
mode=read_mode,
|
| 143 |
+
shape=(self.metadata['y_cand']['samples'], *self.metadata['y_cand']['element_shape'])
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
|
| 147 |
+
"""
|
| 148 |
+
Returns a PyTorch DataLoader for this dataset.
|
| 149 |
+
|
| 150 |
+
Parameters
|
| 151 |
+
----------
|
| 152 |
+
batch_size : int
|
| 153 |
+
Number of samples per batch.
|
| 154 |
+
shuffle : bool
|
| 155 |
+
Whether to shuffle the dataset.
|
| 156 |
+
num_workers : int
|
| 157 |
+
Number of worker processes.
|
| 158 |
+
**kwargs
|
| 159 |
+
Additional arguments for DataLoader.
|
| 160 |
+
|
| 161 |
+
Returns
|
| 162 |
+
-------
|
| 163 |
+
[torch.utils.data.DataLoader
|
| 164 |
+
Configured DataLoader.
|
| 165 |
+
"""
|
| 166 |
+
# Size handling:
|
| 167 |
+
return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True,
|
| 168 |
+
**kwargs)
|
| 169 |
+
|
| 170 |
+
def __len__(self) -> int:
|
| 171 |
+
"""
|
| 172 |
+
Returns the number of samples in the dataset.
|
| 173 |
+
|
| 174 |
+
Returns
|
| 175 |
+
-------
|
| 176 |
+
int
|
| 177 |
+
Total number of samples.
|
| 178 |
+
"""
|
| 179 |
+
return int(self.metadata['samples'] * self.percentage)
|
| 180 |
+
|
| 181 |
+
def __getitem__(self, idx) -> dict | tuple:
|
| 182 |
+
"""
|
| 183 |
+
Retrieves a single sample and generates segmentation labels.
|
| 184 |
+
|
| 185 |
+
Parameters
|
| 186 |
+
----------
|
| 187 |
+
idx : int
|
| 188 |
+
Index of the sample.
|
| 189 |
+
|
| 190 |
+
Returns
|
| 191 |
+
-------
|
| 192 |
+
tuple
|
| 193 |
+
A tuple or dict (x_i, y_i, mask_x) with noisy input and corresponding target.
|
| 194 |
+
"""
|
| 195 |
+
if self.return_type == tuple:
|
| 196 |
+
return (
|
| 197 |
+
np.array(self.x_map[idx]), # ← copia
|
| 198 |
+
np.array(self.y_map[idx]),
|
| 199 |
+
np.array(self.x_mask_map[idx]),
|
| 200 |
+
np.array(self.y_mask_map[idx]),
|
| 201 |
+
np.array(self.y_cand_map[idx]),
|
| 202 |
+
)
|
| 203 |
+
elif self.return_type == dict:
|
| 204 |
+
return {
|
| 205 |
+
'input': np.array(self.x_map[idx]),
|
| 206 |
+
'input_mask': np.array(self.x_mask_map[idx]),
|
| 207 |
+
'labels': np.array(self.y_map[idx]),
|
| 208 |
+
'output_mask': np.array(self.y_mask_map[idx]),
|
| 209 |
+
'candidate_mask': np.array(self.y_cand_map[idx]),
|
| 210 |
+
}
|
| 211 |
+
else:
|
| 212 |
+
raise ValueError(f'[SegmentationDataset] return_type must be either dict or tuple.')
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 216 |
+
# END OF FILE #
|
| 217 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dataset/tokenizer.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import tokenizers
|
| 9 |
+
import sys
|
| 10 |
+
import subprocess
|
| 11 |
+
import logging
|
| 12 |
+
import spacy
|
| 13 |
+
import numpy as np
|
| 14 |
+
from tokenizers.models import BPE
|
| 15 |
+
from tokenizers.trainers import BpeTrainer
|
| 16 |
+
from tokenizers.pre_tokenizers import Whitespace
|
| 17 |
+
from tokenizers.normalizers import NFKC
|
| 18 |
+
from transformers import PreTrainedTokenizerFast
|
| 19 |
+
|
| 20 |
+
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SegmentationTokenizer:
    """Wrapper around a BPE tokenizer with a training side (raw tokenizers
    object) and an inference side (Hugging Face fast tokenizer)."""

    def __init__(
            self,
            vocab_size=32_768,
            min_frequency=2,
            max_length=1024
    ):
        self.max_length = max_length

        # Raw tokenizer, used only while training the vocabulary:
        self.raw_tokenizer = tokenizers.Tokenizer(BPE(unk_token="[UNK]"))
        self.raw_tokenizer.normalizer = NFKC()
        self.raw_tokenizer.pre_tokenizer = Whitespace()

        self.trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
        )

        # Inference-side tokenizer; populated by .load() after training:
        self._hf_tokenizer = None

    # ---------- TRAINING ----------
    def build_iterator(self, dataset, batch_size=1024):
        """Yield lists of up to ``batch_size`` cleaned article strings."""
        pending = []
        for record in dataset:
            # Join the article paragraphs and collapse blank lines:
            pending.append("\n".join(record["text"]).replace("\n\n", "\n"))
            if len(pending) == batch_size:
                yield pending
                pending = []
        if pending:
            yield pending

    def train_from_iterator(self, iterator):
        """Train the raw BPE vocabulary from an iterator of text batches."""
        self.raw_tokenizer.train_from_iterator(iterator, trainer=self.trainer)

    # ---------- IO ----------
    def save(self, path):
        """Serialize the raw (trained) tokenizer to ``path``."""
        self.raw_tokenizer.save(path)

    def load(self, tokenizer_path):
        """Load a trained tokenizer file into the inference-side HF tokenizer."""
        self._hf_tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=tokenizer_path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            cls_token="[CLS]",
            sep_token="[SEP]",
            mask_token="[MASK]"
        )
        return self

    # ---------- TOKENIZATION ----------
    def compute_unk_rate(self, corpus):
        """Return the fraction of [UNK] tokens over the whole corpus (0.0 if empty)."""
        unknown_id = self._hf_tokenizer.convert_tokens_to_ids("[UNK]")

        seen = 0
        unknown = 0
        for document in corpus:
            ids = self._hf_tokenizer(
                document,
                add_special_tokens=False
            )["input_ids"]
            seen += len(ids)
            unknown += sum(1 for token_id in ids if token_id == unknown_id)

        if seen > 0:
            return unknown / seen
        return 0.0

    def __call__(
            self,
            text,
            return_tensors="pt",
            padding=True,
            truncation=True
    ):
        """
        text: str or List[str]
        returns: dict with input_ids and attention_mask (torch.long)
        """
        if self._hf_tokenizer is None:
            raise RuntimeError("Tokenizer not loaded. Call .load() first.")

        # Pad every sequence to max_length when padding is requested:
        pad_strategy = "max_length" if padding else False
        encoded = self._hf_tokenizer(
            text,
            padding=pad_strategy,
            truncation=truncation,
            max_length=self.max_length,
            return_tensors=return_tensors
        )

        return {
            "input_ids": encoded["input_ids"],          # torch.LongTensor
            "attention_mask": encoded["attention_mask"] # torch.LongTensor
        }

    @property
    def vocab_size(self):
        """Vocabulary size of the loaded HF tokenizer (raises if not loaded)."""
        if self._hf_tokenizer is None:
            raise RuntimeError("Tokenizer not loaded.")
        return self._hf_tokenizer.vocab_size

    def __repr__(self):
        return f"<SegmentationTokenizer vocab_size={self.trainer.vocab_size}>"
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 134 |
+
# SENTENCE SEG #
|
| 135 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 136 |
+
class SentenceSegmenter:
    def __init__(
            self,
            max_sentences: int,
            spacy_model: str = "es_core_news_sm",
            logger: logging.Logger | None = None
    ):
        """
        Split articles into sentences with spaCy and emit boundary/candidate labels.

        Parameters
        ----------
        max_sentences : int
            Hard cap on the number of sentences returned; output arrays are
            padded (with zeros / empty strings) up to this length.
        spacy_model : str
            Name of the spaCy model to load (downloaded on demand if missing).
        logger : logging.Logger, optional
            Logger instance. If not provided, a module logger with a NullHandler is used.
        """
        self.max_sentences = max_sentences
        self.logger = self._get_logger(logger)
        self.nlp = self.__build_model__(spacy_model, logger=self.logger)

    @staticmethod
    def __build_model__(sentence_tokenizer_model: str, logger: logging.Logger) -> spacy.language.Language:
        """
        Download the pre-trained sentence tokenizer model.
        :param sentence_tokenizer_model: The sentence tokenizer model to download.
        :param logger: Logger used to report download / validation problems.
        :return: The spacy language model.
        """
        try:
            spacy_model = spacy.load(sentence_tokenizer_model)
        except OSError:
            # Model not installed locally; download it with the current interpreter:
            result = subprocess.run(
                [sys.executable, "-m", "spacy", "download", sentence_tokenizer_model],
                capture_output=True,
                text=True
            )

            if result.returncode != 0:
                logger.error(f'[BEAST-Tokenizer]: Loading {sentence_tokenizer_model} failed.')
                raise RuntimeError(f"[BEAST-Tokenizer]: Error while downloading '{sentence_tokenizer_model}'")

            spacy_model = spacy.load(sentence_tokenizer_model)
            logger.info('[BEAST-Tokenizer]: Successfully downloaded the pre-trained sentence tokenizer model.')

        # Sentence boundaries come from the dependency parser, so it is required:
        if 'parser' not in spacy_model.pipe_names:
            logger.error('[BEAST-Tokenizer]: The SpaCy model needs a parser installed.')
            raise RuntimeError('[BEAST-Tokenizer]: The SpaCy model needs a parser installed.')
        # NOTE: this component is registered elsewhere in this module — it must be
        # imported before the model is built.
        spacy_model.add_pipe("newline_segmenter_keep_exact", before="parser")

        return spacy_model

    @staticmethod
    def _get_logger(logger):
        """Return the given logger, or a module logger with a NullHandler attached."""
        if logger is None:
            logger = logging.getLogger(__name__)
            # Only attach once — attaching a new NullHandler on every
            # instantiation accumulates handlers on the shared module logger:
            if not logger.handlers:
                logger.addHandler(logging.NullHandler())
        return logger

    def __call__(self, texts: list[str]) -> dict:
        """
        Segment a list of articles into a flat, padded sentence sequence.

        Parameters
        ----------
        texts : list[str]
            One string per article.

        Returns
        -------
        dict
            'sentences': list[str] of length max_sentences (padded with '');
            'sentence_candidates': int8 array, 1 where a boundary may occur
            (article openers and sentences ending with a newline);
            'sentence_boundaries': int8 array, 1 only at article openers;
            'sentence_mask': int8 array, 1 for real sentences, 0 for padding.
        """
        sentences = list()
        sentence_candidates = list()
        sentence_boundaries = list()
        sentence_masking = list()

        for article in texts:
            doc = self.nlp(article)
            for idx, sent in enumerate(doc.sents):

                if idx == 0:
                    # Article opener: always a candidate and a true boundary.
                    sentence_candidates.append(1)
                    sentence_boundaries.append(1)
                elif sent.text.endswith("\n"):
                    # Paragraph break candidate (not a labeled boundary).
                    sentence_candidates.append(1)
                    sentence_boundaries.append(0)
                else:
                    sentence_candidates.append(0)
                    sentence_boundaries.append(0)

                sentences.append(sent.text.replace('\n', '').strip())
                sentence_masking.append(1)

                if len(sentences) >= self.max_sentences:
                    self.logger.warning(f"Maximum number of sentences reached: {self.max_sentences}")
                    break

            # Re-check after the inner break so remaining articles are skipped too:
            if len(sentences) >= self.max_sentences:
                break

        # Pad every output up to max_sentences:
        while len(sentences) < self.max_sentences:
            sentences.append("")
            sentence_candidates.append(0)
            sentence_boundaries.append(0)
            sentence_masking.append(0)

        return {
            "sentences": sentences,
            "sentence_candidates": np.array(sentence_candidates, dtype=np.int8),
            "sentence_boundaries": np.array(sentence_boundaries, dtype=np.int8),
            "sentence_mask": np.array(sentence_masking, dtype=np.int8)
        }
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
@spacy.Language.component("newline_segmenter_keep_exact")
def newline_segmenter_keep_exact(doc):
    """Mark the token right after every bare newline token as a sentence start."""
    newline_positions = [token.i for token in doc[:-1] if token.text == "\n"]
    for position in newline_positions:
        doc[position + 1].is_sent_start = True
    return doc
|
| 238 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 239 |
+
# END OF FILE #
|
| 240 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
from .setup import Setup
|
| 8 |
+
from .steps import train_step, validation_step
|
| 9 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 10 |
+
# END OF FILE #
|
| 11 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
from .full_setup import Setup
|
| 9 |
+
from .hooks import HookMonitor
|
| 10 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 11 |
+
# END OF FILE #
|
| 12 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/clear.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import os
|
| 9 |
+
import shutil
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
| 14 |
+
def clear_logs(log_path: str):
    """
    Clears all the files inside log_path path.

    Args:
        log_path (str): The file path to be clean.

    Raises:
        ValueError: If the log_path is not valid.
    """
    # Close root handlers before detaching them: a plain handlers.clear()
    # leaves file handles open, which leaks resources and makes the
    # rmtree below fail on Windows when a FileHandler points inside log_path.
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        handler.close()
        root_logger.removeHandler(handler)
    if os.path.exists(log_path):
        # Clear the directory if it exists
        shutil.rmtree(log_path)
    else:
        raise ValueError(f'Path {log_path} does not exist.')
|
| 31 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 32 |
+
# END OF FILE #
|
| 33 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/device.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_device(number: int, logger: logging.Logger = None):
    """
    Configures PyTorch to use a specified GPU by its index number,
    or falls back to CPU if CUDA is not available.

    Args:
        number (int): The index number of the GPU to use.
        logger (logging.Logger, optional): Logger for logging GPU info.

    Returns:
        torch.device: The selected torch device (GPU or CPU).

    Raises:
        ValueError: If ``number`` is not a valid GPU index.
    """
    # Fallback to CPU if CUDA is not available
    if not torch.cuda.is_available():
        if logger:
            logger.warning("CUDA is not available. Falling back to CPU.")
        return torch.device('cpu')

    # Check if the specified GPU number is valid
    if number >= torch.cuda.device_count() or number < 0:
        raise ValueError(
            f"GPU number {number} is not valid. Available GPU indices range from 0 to {torch.cuda.device_count() - 1}.")

    # Clean up memory and stats so the report below reflects a fresh state.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()

    # Set and log device
    torch.cuda.set_device(number)
    if logger:
        # BUG FIX: the memory-stat report used to run unconditionally and
        # crashed with AttributeError when logger was None on CUDA machines.
        logger.info(f"PyTorch is now configured to use GPU {number}: {torch.cuda.get_device_name(number)}")

        device_name = torch.cuda.get_device_name(number)
        total_mem = torch.cuda.get_device_properties(number).total_memory / 1024 ** 2
        mem_allocated = torch.cuda.memory_allocated(number) / 1024 ** 2
        mem_reserved = torch.cuda.memory_reserved(number) / 1024 ** 2
        max_allocated = torch.cuda.max_memory_allocated(number) / 1024 ** 2
        max_reserved = torch.cuda.max_memory_reserved(number) / 1024 ** 2

        logger.info(f"[GPU {number} - {device_name}] Memory Stats:")
        logger.info(f" Total Memory : {total_mem:.2f} MB")
        logger.info(f" Currently Allocated : {mem_allocated:.2f} MB")
        logger.info(f" Currently Reserved : {mem_reserved:.2f} MB")
        logger.info(f" Max Allocated : {max_allocated:.2f} MB")
        logger.info(f" Max Reserved : {max_reserved:.2f} MB")

    return torch.device(f'cuda:{number}')
|
| 60 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 61 |
+
# END OF FILE #
|
| 62 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/full_setup.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import logging
|
| 9 |
+
import torch
|
| 10 |
+
import os
|
| 11 |
+
import glob
|
| 12 |
+
import json
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
from .logger import get_logger
|
| 15 |
+
from .tensorboard import get_writer
|
| 16 |
+
from .seeds import get_seed
|
| 17 |
+
from .device import get_device
|
| 18 |
+
from .clear import clear_logs
|
| 19 |
+
from .marker import register_replay, register
|
| 20 |
+
from .watchers import DEFAULT_WATCHER, S_WATCHER, A_WATCHER, B_WATCHER, C_WATCHER, CNN_WATCHER, AEN_WATCHER, TRA_WATCHER
|
| 21 |
+
from dataclasses import asdict
|
| 22 |
+
|
| 23 |
+
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
| 24 |
+
# #
|
| 25 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 26 |
+
class Setup:
    def __init__(
            self,
            path: str,
            device: int = 0,
            seed: int = None,
            save_each: int = 1,
            reload_state: bool = False,
            tensorboard: int | bool = 6006,
            autoscaler: bool = True,
            replay_element: tuple = (-1, None)
    ):
        """
        This class is used to set up the environment for an AI experiment. It saves
        the model checkpoints, logs, and tensorboard files. It also sets the device
        and seed for reproducibility.

        Usage:

        >>> from *** import Setup
        >>> setup = Setup(path='logs', device=0, seed=42, save_each=10)

        Inside the train loop:

        >>> model: torch.Model
        >>> loss_value: torch.Tensor
        >>> y: torch.Tensor
        >>> y_hat: torch.Tensor

        >>> setup.check(model)
        >>> setup.register('loss', loss_value)
        >>> setup.register_replay(y, y_hat)

        In case you want to reload latest checkpoint:

        >>> setup.reload(model)


        :param path: The path to the logs.
        :param device: The device to use.
        :param seed: The seed to use.
        :param save_each: The number of epochs to save the model.
        :param reload_state: Whether to reload the latest checkpoint.
        :param tensorboard: Whether to use tensorboard (falsy disables the writer;
            an int is used as the tensorboard port).
        :param autoscaler: Whether to use autoscaler for training.
        :param replay_element: The element to replay.
        """
        # Clear logs:
        self.path = path
        self.save_each = save_each
        self.tensorboard_required = tensorboard
        self.replay_id = replay_element
        # Epoch counter, incremented by check(); name-mangled to keep it private.
        self.__epoch_count = 0

        # A fresh run wipes the previous log directory; a reload keeps it.
        if not reload_state:
            self.clear(path)

        self.logger = self.set_logger(path)
        # When tensorboard is disabled, no writer is created and checkpoints
        # default to <path>/checkpoints.
        self.writer, self.ch_path = self.set_writer(path, tensorboard) if tensorboard else (None, os.path.join(path, 'checkpoints'))
        self.seed = self.set_seed(seed)
        self.device = self.set_device(device)
        self.log_setup_info()

        self.watcher = DEFAULT_WATCHER
        # AMP gradient scaler is only enabled on CUDA devices.
        self.autoscaler = torch.amp.GradScaler(enabled=self.device.type == 'cuda') if autoscaler else None

    def log_setup_info(self) -> None:
        """
        Log the setup information.
        """
        self.logger.info("Setup information:")
        self.logger.info(f"- Setup path: {self.path}")
        self.logger.info(f"- Setup checkpoints path: {self.ch_path}")
        self.logger.info(f"- Setup device: {self.device}")
        self.logger.info(f"- Setup seed: {self.seed}")
        self.logger.info(f"- Setup logger: {self.logger}")
        self.logger.info(f"- Setup writer: {self.writer}")
        self.logger.info(f"- Setup save each: {self.save_each}")

    def check(
            self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer | None = None,
            learning_rate: torch.optim.lr_scheduler.LRScheduler | None = None
    ) -> bool:
        """
        Check the model and save it if the epoch count is a multiple of save_each.
        Call once per epoch: this is also what advances the internal epoch counter.
        :param model: The model to checkpoint and save.
        :param optimizer: The optimizer to save.
        :param learning_rate: The learning rate scheduler to save.
        :return: If the model is checkpointed.
        """
        self.__epoch_count += 1
        if self.save_each is not None and self.__epoch_count % self.save_each == 0:
            self.logger.info(f"Checkpointing model at epoch {self.__epoch_count}")
            self.save_model(
                model=model,
                optimizer=optimizer,
                learning_rate=learning_rate
            )
            self.logger.info(f"Model checkpointed at epoch {self.__epoch_count}")
            return True
        return False

    def save_model(
            self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer | None = None,
            learning_rate: torch.optim.lr_scheduler.LRScheduler | None = None
    ) -> None:
        """
        Saves the model.
        :param model: The model to save.
        :param optimizer: The optimizer to save.
        :param learning_rate: The learning rate scheduler to save.
        :return: Nothing.
        """
        # Epoch, seed and (optional) optimizer/scheduler states are bundled so
        # reload() can restore the full training state.
        torch_state = {
            'epoch': self.__epoch_count,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer else None,
            'scheduler_state_dict': learning_rate.state_dict() if learning_rate else None,
            'seed': self.seed
        }
        torch.save(torch_state, self.ch_path + f'/model_epoch_{self.__epoch_count}.pt')

    def reload(
            self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer | None = None,
            learning_rate: torch.optim.lr_scheduler.LRScheduler | None = None
    ) -> None:
        """
        Reloads the latest checkpoint into the given model.

        :param model: The PyTorch model to reload the state into.
        :param optimizer: The optimizer to reload the state into.
        :param learning_rate: The learning rate scheduler to reload the state into.
        """
        # Find all matching checkpoints
        checkpoints = glob.glob(os.path.join(self.ch_path, 'model_epoch_*.pt'))
        if not checkpoints:
            self.logger.warning("No checkpoint files found.")
        else:
            # Sort by modification time and get the latest
            checkpoints.sort(key=os.path.getmtime)
            latest_checkpoint = checkpoints[-1]

            try:
                state_dict = torch.load(latest_checkpoint, map_location=self.device)
                # Load model and info:
                model.load_state_dict(state_dict['model_state_dict'])
                model.to(self.device)
                self.__epoch_count = state_dict['epoch']
                self.seed = state_dict['seed']
                self.logger.info(f"Model reloaded from {latest_checkpoint} at epoch {self.__epoch_count} and "
                                 f"seed {self.seed}")

                # Load optimizer and learning rate scheduler if provided
                if optimizer and state_dict['optimizer_state_dict'] is not None:
                    optimizer.load_state_dict(state_dict['optimizer_state_dict'])
                    self.logger.info(f"Optimizer state_dict loaded from {latest_checkpoint}")
                if learning_rate and state_dict['scheduler_state_dict'] is not None:
                    learning_rate.load_state_dict(state_dict['scheduler_state_dict'])
                    self.logger.info(f"Scheduler state_dict loaded from {latest_checkpoint}")

            except Exception as e:
                self.logger.error(f"Failed to reload model from {latest_checkpoint}: {e}")
                raise RuntimeError(f"Failed to reload model from {latest_checkpoint}: {e}")

    def set_watcher(self, flag_names: str | list[tuple], deactivate: bool = False) -> None:
        """
        Sets up the parameter watcher to the tensorboard.
        :param flag_names: The names of the flags to watch as a tuple of strings,
            or a preset tier name ('S', 'A', 'B', 'C', 'cnn', 'transformer', 'ae').
        :param deactivate: Whether to deactivate the watcher.
        :return: Nothing
        """
        # Translate preset tier names into their (top, low) flag lists.
        if isinstance(flag_names, str):
            if flag_names == 'S':
                flag_names = S_WATCHER
            elif flag_names == 'A':
                flag_names = A_WATCHER + S_WATCHER
            elif flag_names == 'B':
                flag_names = S_WATCHER + A_WATCHER + B_WATCHER
            elif flag_names == 'C':
                flag_names = S_WATCHER + A_WATCHER + B_WATCHER + C_WATCHER
            elif flag_names == 'cnn':
                flag_names = CNN_WATCHER
            elif flag_names == 'transformer':
                flag_names = TRA_WATCHER
            elif flag_names == 'ae':
                flag_names = AEN_WATCHER
            else:
                self.logger.error(f"[WATCHER] Unknown flag name '{flag_names}'")
                raise ValueError(f"[WATCHER] Unknown flag tier '{flag_names}'")

        for top_name, low_name in flag_names:
            if top_name not in self.watcher:
                self.logger.error(f"Watcher {top_name} not found in watcher.")
                raise ValueError(f"Watcher {top_name} not found in watcher.")
            elif low_name not in self.watcher[top_name]:
                self.logger.error(f"Watcher {low_name} not found in {top_name}.")
                raise ValueError(f"Watcher {low_name} not found in {top_name}.")
            else:
                self.watcher[top_name][low_name] = not deactivate

    def register_replay(self, predicted: torch.Tensor, target: torch.Tensor, mask: torch.Tensor = None) -> plt.Figure:
        """
        Visualizes predicted vs. target outputs with an optional mask.
        Only positions where mask == True are shown. Each cell displays its value with two decimal places.

        :param predicted: Tensor of shape (S) or (S, Y) representing the model's output.
        :param target: Tensor of same shape as predicted.
        :param mask: Optional boolean tensor of same shape. False positions are ignored (valid mask).
        """
        return register_replay(
            predicted=predicted,
            target=target,
            valid_mask=mask,
            element=self.replay_id[1],
            epoch=self.__epoch_count,
            writer=self.writer,
            logger=self.logger,
            tensorboard_required=self.tensorboard_required,
        )

    def register(self, name: str, parameter: float | torch.Tensor, mask: torch.Tensor = Ellipsis) -> None:
        """
        Registers a named parameter into the tensorboard.
        :param name: The name of the parameter.
        :param parameter: The parameter to register.
        :param mask: The optional boolean tensor of same shape as parameter.
            Ellipsis is used as the "no mask supplied" sentinel.
        :return: Nothing.
        """
        # Default the mask: full coverage for tensors, Ellipsis for scalars.
        if isinstance(parameter, torch.Tensor) and mask is Ellipsis:
            mask = torch.ones_like(parameter).bool()
        elif isinstance(parameter, float):
            mask = Ellipsis

        register(
            flags=self.watcher,
            tensor=parameter,
            valid_mask=mask,
            epoch=self.__epoch_count,
            writer=self.writer,
            logger=self.logger,
            tensorboard_required=self.tensorboard_required,
            parameter_name=name
        )

    def save_config(self, configuration) -> None:
        """
        Saves the configuration to a file (<path>/config.json).
        :param configuration: A dataclasses configuration object.
        :return: Nothing.
        """
        config_path = os.path.join(self.path, "config.json")
        with open(config_path, "w") as f:
            json.dump(asdict(configuration), f, indent=4)

    # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
    #                                                           #
    # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
    @staticmethod
    def clear(path: str) -> None:
        """
        Clear the logs.
        :param path: The path to the logs.
        """
        clear_logs(path)

    @staticmethod
    def set_logger(path: str) -> logging.Logger:
        """
        Set the logger.
        :param path: The path to the logs.
        :return: The logger.
        """
        return get_logger(path)

    def set_writer(self, path: str, tensorboard_port: int | bool) -> tuple:
        """
        Get the writer.
        :param path: The path to the logs.
        :param tensorboard_port: The port to use for tensorboard.
        :return: The (writer, checkpoints_path) tuple.
        """
        return get_writer(path, tensorboard_port, self.logger)

    def set_device(self, device: int) -> torch.device:
        """
        Get the device.
        :param device: The device to use.
        :return: The device.
        """
        return get_device(device, self.logger)

    def set_seed(self, seed: int) -> int:
        """
        Get the seed.
        :param seed: The seed to use.
        :return: The seed.
        """
        return get_seed(seed, self.logger)

    @property
    def epoch(self) -> int:
        """
        Get the current epoch.
        :return: The current epoch.
        """
        return self.__epoch_count

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Only the writer is closed on exit; see note below about tensorboard.
        if self.writer:
            self.writer.close()

        # Do not kill Tensor boards - We usually want the process up to analyze the train variables:
        # for proc in psutil.process_iter(['pid', 'name']):
        #     if 'tensorboard' in proc.info['name'].lower():
        #         proc.terminate()
|
| 350 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 351 |
+
# END OF FILE #
|
| 352 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/functions.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
EPS = 1e-12
|
| 10 |
+
|
| 11 |
+
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
| 12 |
+
# REGISTER #
|
| 13 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 14 |
+
def watch_max(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Largest absolute masked value (from grad, .data, or the raw tensor)."""
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].abs().max())
|
| 25 |
+
|
| 26 |
+
def watch_min(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Smallest absolute masked value (from grad, .data, or the raw tensor)."""
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].abs().min())
|
| 37 |
+
|
| 38 |
+
def watch_mean(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Mean of the masked values (from grad, .data, or the raw tensor)."""
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].mean())
|
| 49 |
+
|
| 50 |
+
def watch_var(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Variance of the masked values (from grad, .data, or the raw tensor)."""
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].var())
|
| 61 |
+
|
| 62 |
+
def watch_std(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Standard deviation of the masked values (grad, .data, or raw tensor)."""
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].std())
|
| 73 |
+
|
| 74 |
+
def watch_sparsity(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
        sparsity_threshold: float = 1e-6,
) -> float:
    """Fraction of masked values whose magnitude is at most the threshold."""
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    near_zero = source[mask].abs() <= sparsity_threshold
    return float(near_zero.float().mean())
|
| 86 |
+
|
| 87 |
+
def watch_l1(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """L1 norm of the masked values (from grad, .data, or the raw tensor)."""
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].norm(p=1))
|
| 98 |
+
|
| 99 |
+
def watch_l2(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """L2 norm of the masked values (from grad, .data, or the raw tensor)."""
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(source[mask].norm(p=2))
|
| 110 |
+
|
| 111 |
+
def watch_snr(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> None | float:
    """
    Signal-to-noise ratio in dB of the masked values: 20*log10(|mean| / std).
    Returns None when the ratio is undefined (zero/invalid std) or non-finite.
    """
    std = watch_std(tensor, mask, grad=grad)
    if std <= 0:
        return None
    elif grad:
        val = float(torch.log10((tensor.grad[mask].mean()).abs() / (std + EPS)))
    elif hasattr(tensor, 'data'):
        val = float(torch.log10((tensor.data[mask].mean()).abs() / (std + EPS)))
    else:
        val = float(torch.log10((tensor[mask].mean()).abs() / (std + EPS)))
    # BUG FIX: the old guard only compared against -inf, so NaN (and +inf)
    # leaked through. The chained comparison is False for NaN and both infs.
    return 20 * val if float("-inf") < val < float("inf") else None
|
| 126 |
+
|
| 127 |
+
def watch_hist(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> torch.Tensor:
    """Masked values as a flat tensor, for histogram logging."""
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return source[mask]
|
| 138 |
+
|
| 139 |
+
def watch_rank(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
    threshold: float = 0.92,
) -> None | float | int:
    """Effective rank of the masked tensor via its singular-value spectrum.

    The masked-out positions are zeroed (multiplicative mask), the singular
    values are computed and the result is the number of components needed to
    reach *threshold* of the cumulative spectral energy.

    NOTE(review): unlike the other watchers, *mask* is used multiplicatively
    (``mask.float()``), so an Ellipsis mask would fail here — confirm callers
    always pass a tensor mask to 'rank'.

    :param tensor: Tensor (or parameter) to inspect.
    :param mask: Boolean tensor mask; False positions are zeroed out.
    :param grad: When True, read ``tensor.grad`` instead of the values.
    :param threshold: Fraction of spectral energy defining the rank cut-off.
    :return: The effective rank as a float, or None for tensors with fewer
        than two dimensions.
    """
    if grad:
        work = tensor.grad
    elif hasattr(tensor, 'data'):
        work = tensor.data
    else:
        work = tensor
    work = torch.multiply(work, mask.float())

    if work.ndim < 2:
        return None
    # Singular values, largest first:
    singular = torch.sort(torch.linalg.svdvals(work), descending=True).values
    # Normalized cumulative spectral energy:
    energy = torch.cumsum(singular ** 2, dim=0) / (torch.sum(singular ** 2) + EPS)
    # Smallest number of components whose energy reaches the threshold:
    return float(torch.sum(energy < threshold).item() + 1)
|
| 163 |
+
|
| 164 |
+
def watch_any(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """Cast the (single-element) masked selection to a plain float.

    Used for scalar series such as 'loss', 'val_loss' and 'lr'.

    :param tensor: Tensor (or parameter) to inspect.
    :param mask: Index selecting exactly one entry.
    :param grad: When True, read ``tensor.grad`` instead of the values.
    :return: The selected value as a float.
    """
    if grad:
        source = tensor.grad
    else:
        source = tensor.data if hasattr(tensor, 'data') else tensor
    return float(source[mask])
|
| 175 |
+
|
| 176 |
+
def watch_power(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """Mean power of the masked entries, in decibels.

    Computed as ``10 * log10(mean(x**2) + EPS)``; EPS guards log10(0).

    :param tensor: Tensor (or parameter) to inspect.
    :param mask: Index selecting the entries to include.
    :param grad: When True, read ``tensor.grad`` instead of the values.
    :return: The mean power in dB as a plain float.
    """
    if grad:
        source = tensor.grad
    elif hasattr(tensor, 'data'):
        source = tensor.data
    else:
        source = tensor
    return float(10 * torch.log10((source[mask] ** 2).mean() + EPS))
|
| 187 |
+
|
| 188 |
+
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
| 189 |
+
# FUNC. MAP #
|
| 190 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 191 |
+
# Maps a watcher flag name to the function computing that statistic. Every
# value has the signature fn(tensor, mask) and returns a float, a tensor
# (for 'hist') or None when the metric is undefined for the input.
REG_FUNCTION_MAP = {
    # Value statistics (computed on the tensor / parameter data):
    'max': watch_max,
    'min': watch_min,
    'mean': watch_mean,
    'std': watch_std,
    'var': watch_var,
    'l2': watch_l2,
    'l1': watch_l1,
    'sparsity': watch_sparsity,
    'snr': watch_snr,
    'hist': watch_hist,
    'rank': watch_rank,
    'power': watch_power,

    # Gradient statistics: same functions re-applied with grad=True, so the
    # metric is computed on `tensor.grad` instead of the values:
    'grad_max': lambda x, y: watch_max(x, y, grad=True),
    'grad_min': lambda x, y: watch_min(x, y, grad=True),
    'grad_mean': lambda x, y: watch_mean(x, y, grad=True),
    'grad_std': lambda x, y: watch_std(x, y, grad=True),
    'grad_var': lambda x, y: watch_var(x, y, grad=True),
    'grad_l1': lambda x, y: watch_l1(x, y, grad=True),
    'grad_l2': lambda x, y: watch_l2(x, y, grad=True),
    'grad_sparsity': lambda x, y: watch_sparsity(x, y, grad=True),
    'grad_snr': lambda x, y: watch_snr(x, y, grad=True),
    'grad_hist': lambda x, y: watch_hist(x, y, grad=True),
    'grad_rank': lambda x, y: watch_rank(x, y, grad=True),
    'grad_power': lambda x, y: watch_power(x, y, grad=True),

    # Scalar training series (identity-style extraction via watch_any):
    'loss': watch_any,
    'val_loss': watch_any,
    'lr': watch_any
}
|
| 225 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 226 |
+
# END OF FILE #
|
| 227 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/hooks.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
| 2 |
+
# START OF FILE #
|
| 3 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
| 4 |
+
import logging
|
| 5 |
+
import torch
|
| 6 |
+
from .functions import REG_FUNCTION_MAP
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
| 10 |
+
# #
|
| 11 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
| 12 |
+
class HookMonitor:
    """
    Collects per-layer statistics from a model's forward activations and
    backward gradients by registering a forward hook on every submodule.

    For each module output that is a tensor, every metric from
    `REG_FUNCTION_MAP` whose flag is enabled in `watcher` is computed and
    accumulated into `stats[layer_name]`, together with a companion
    '<metric>/valid/' counter used later for normalization. A gradient hook
    is also attached to the activation so that 'grad_<metric>' flags are
    accumulated during backpropagation.

    Usage:

    >>> monitor = HookMonitor(model, watcher, logger)
    >>> monitor.attach()
    >>> output = model(x)
    >>> loss.backward()
    >>> stats = monitor.get_stats()
    >>> monitor.remove()

    Or as a context manager:

    >>> with HookMonitor(model, watcher, logger) as monitor:
    ...     output = model(x)
    ...     loss.backward()
    >>> stats = monitor.get_stats()

    Notes
    -----
    - Gradient hooks are attached to module outputs (activations), not to
      model parameters.
    - Forward-side computation runs under @torch.no_grad(), so monitoring
      only reads activations/gradients and does not interfere with training.
    - Metrics missing their '/valid/' counter are logged as an error and
      returned unnormalized by `get_stats()`.
    """
    def __init__(self, model: torch.nn.Module, watcher: dict, logger: logging.Logger):
        """
        Store the monitoring configuration. No hooks are installed here; the
        monitor becomes active only after `attach()` (or on entering a
        `with` block).

        :param model: Model whose submodules (from `model.named_modules()`)
            will receive forward hooks.
        :param watcher: Mapping of metric name (keys of `REG_FUNCTION_MAP`,
            plus 'grad_'-prefixed variants) to a boolean enabling flag.
            Metrics not enabled here are never computed.
        :param logger: Logger used for debug/error reporting.
        """
        self.logger: logging.Logger = logger
        self.model: torch.nn.Module = model
        self.watcher: dict = watcher
        # Accumulated per-layer sums and their '/valid/' counters:
        self.stats: dict = dict()
        # Hook handles kept so `remove()` can detach them later:
        self.handles: list = list()

    def _build_hook(self, name):
        """
        Build the forward hook for the submodule registered under *name*.

        The returned hook accumulates every enabled non-gradient metric on
        the module's output tensor and, when that output requires grad,
        registers an inner tensor hook that accumulates the enabled
        'grad_<metric>' statistics during the backward pass.
        """
        @torch.no_grad()
        def hook(*args):
            # Forward hooks receive (module, input, output):
            _, _, act = args

            if torch.is_tensor(act):
                act_detached = act.detach().float()
                s = self.stats.setdefault(name, {})

                # Accumulate every enabled activation metric (Ellipsis acts
                # as an "all elements" mask for the watcher functions):
                for function_name, compute_function in REG_FUNCTION_MAP.items():
                    if self.watcher.get(function_name, False) and not function_name.startswith('grad_'):
                        value = compute_function(act_detached, ...)
                        if value is not None:
                            s[function_name] = s.get(function_name, 0.0) + value
                            s[function_name + '/valid/'] = s.get(function_name + '/valid/', 0.0) + 1

                # Gradient hook: the same base metric functions are applied
                # directly to the incoming gradient tensor whenever the
                # corresponding 'grad_<metric>' flag is enabled.
                def grad_hook(grad):
                    gd = grad.detach().float()
                    for gd_function_name, gd_compute_function in REG_FUNCTION_MAP.items():
                        if self.watcher.get('grad_' + gd_function_name, False) and not gd_function_name.startswith('grad_'):
                            gd_function_name = 'grad_' + gd_function_name
                            gd_value = gd_compute_function(gd, ...)
                            if gd_value is not None:
                                s[gd_function_name] = s.get(gd_function_name, 0.0) + gd_value
                                s[gd_function_name + '/valid/'] = s.get(gd_function_name + '/valid/', 0.0) + 1

                if act.requires_grad:
                    act.register_hook(grad_hook)

        return hook

    def get_stats(self) -> dict:
        """
        Get the accumulated statistics, normalized by their '/valid/' count.
        :return: A dictionary {layer_name: {metric: normalized_value}}.
        """
        stats = dict()
        for layer_name, layer_stats in self.stats.items():
            sub_stats = dict()
            for key, item in layer_stats.items():
                if '/valid/' not in key:
                    if key + '/valid/' in layer_stats:
                        sub_stats[key] = item / layer_stats[key + '/valid/']
                    else:
                        # Counter missing: report and keep the raw sum.
                        self.logger.error(f"Key {key} has no valid count, skipping normalization.")
                        sub_stats[key] = item
            stats[layer_name] = sub_stats
        return stats

    def attach(self):
        """
        Registers all the hooks in the model.
        :return: The object (for chaining / context-manager use).
        """
        for name, module in self.model.named_modules():
            h = module.register_forward_hook(self._build_hook(name))
            self.handles.append(h)
        return self

    def clear(self):
        """
        Clear stats' dictionary.
        :return: Nothing
        """
        self.stats.clear()

    def remove(self):
        """
        Remove all the hooks from the model.
        :return: Nothing.
        """
        for h in self.handles:
            h.remove()
        self.handles.clear()

    def __enter__(self):
        self.logger.debug("[Hooks] Attaching HookMonitor...")
        return self.attach()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.logger.debug("[Hooks] Removing HookMonitor...")
        self.remove()
|
| 256 |
+
|
| 257 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
| 258 |
+
# END OF FILE #
|
| 259 |
+
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
|
src/dlutils/setup/logger.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_logger(log_path: str, level: int | str = logging.INFO) -> logging.Logger:
    """
    Build a logger that writes colored output to the console and plain
    records to '<log_path>/logfile.log', creating the directory if needed.

    Args:
        log_path (str): Directory where the 'logfile.log' file will live.
        level (int | str): Logging level, as an int or a level name string.

    Raises:
        ValueError: If log_path exists but is not a directory, or if level
            is neither a valid level name nor an int.

    Returns:
        logging.Logger: The configured logger.
    """
    # Make sure the target directory exists and is actually a directory:
    if os.path.exists(log_path):
        if not os.path.isdir(log_path):
            raise ValueError(f"Provided path '{log_path}' is not a directory.")
    else:
        os.makedirs(log_path, exist_ok=True)

    full_log_path = os.path.join(log_path, 'logfile.log')

    # Normalize the level to an int understood by `logging`:
    if isinstance(level, str):
        level = level.upper()
        if not hasattr(logging, level):
            raise ValueError(f'The provided level for the logger <<{level}>> is not a valid level for logging.')
        level = getattr(logging, level)
    elif not isinstance(level, int):
        raise ValueError(f'The provided level for the logger <<{level}>> is not a string or int, '
                         f'the given type is <<{type(level)}>>.')

    logger = logging.getLogger(__name__)
    logger.handlers.clear()  # Avoid duplicate handlers on repeated calls
    logger.setLevel(level)
    logger.propagate = False  # Prevent duplicated output via the root logger

    # Plain file handler:
    file_handler = logging.FileHandler(full_log_path)
    file_handler.setLevel(level)
    file_handler.setFormatter(logging.Formatter('%(asctime)s: [%(levelname)s] %(message)s'))
    logger.addHandler(file_handler)

    # Colored console handler:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(ColoredFormatter())
    logger.addHandler(console_handler)

    logger.info(f'Logger initialized with writer handler at: {full_log_path}')
    return logger
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ColoredFormatter(logging.Formatter):
    """Console formatter that wraps every record in an ANSI color escape
    chosen by its severity level (critical shares the error color)."""
    grey = "\x1b[38;20m"
    blue = "\x1b[34;20m"
    cyan = "\x1b[36;20m"
    orange = "\x1b[33;20m"
    red = "\x1b[31;20m"
    reset = "\x1b[0m"
    format = '%(asctime)s: [%(levelname)s] %(message)s'

    # Per-level colorized format strings, built once at class-creation time:
    FORMATS = dict()
    for _level, _color in (
        (logging.DEBUG, blue),
        (logging.INFO, cyan),
        (logging.WARNING, orange),
        (logging.ERROR, red),
        (logging.CRITICAL, red),
    ):
        FORMATS[_level] = _color + format + reset
    del _level, _color

    def format(self, record):
        """Render *record* with the colorized format for its level."""
        colored_fmt = self.FORMATS.get(record.levelno)
        return logging.Formatter(colored_fmt, "%Y-%m-%d %H:%M:%S").format(record)
|
| 89 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 90 |
+
# END OF FILE #
|
| 91 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/marker.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
import logging
|
| 10 |
+
import numpy as np
|
| 11 |
+
import io
|
| 12 |
+
import math
|
| 13 |
+
import random
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from matplotlib import pyplot as plt
|
| 16 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 17 |
+
from torchvision import transforms
|
| 18 |
+
from .functions import REG_FUNCTION_MAP
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
| 22 |
+
# REGISTER #
|
| 23 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 24 |
+
@torch.no_grad()
def register(
    flags: dict,
    tensor: float | torch.Tensor,
    valid_mask: torch.Tensor,
    epoch: int,
    writer: SummaryWriter,
    logger: logging.Logger,
    tensorboard_required: bool,
    parameter_name: str = ''
):
    """
    Registers a parameter according to the register flags (DEFAULT_WATCHER style).

    The tensorboard tag group is derived from the value's type:
    torch.nn.Parameter -> 'parameters' (one tag per enabled flag),
    torch.Tensor -> 'activations', float -> 'train'.

    :param flags: A specific watch flag; ``flags['parameters']`` maps metric
        names to booleans.
    :param tensor: The tensor to register (parameter, activation or scalar).
    :param valid_mask: The valid mask to apply when computing metrics.
    :param epoch: The current epoch.
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param tensorboard_required: Whether the tensorboard writer is required.
    :param parameter_name: The name of the parameter (embedded in the tag).
    :return: Nothing.
    """
    # 1. Detect tensor type:
    if isinstance(tensor, torch.nn.Parameter):
        flag_type = 'parameters'
    elif isinstance(tensor, torch.Tensor):
        # Intermediate activation:
        flag_type = 'activations'
    elif isinstance(tensor, float):
        flag_type = 'train'
    else:
        raise ValueError(f"{type(tensor)} is not a torch.nn.Parameter or torch.Tensor.")

    # 2. Build the tensor names (one tag per active parameter flag):
    safe_names = list()
    # Check if the group is active:
    if flag_type == 'parameters':
        for flag_key, flag_value in flags['parameters'].items():
            # Add if active:
            if flag_value:
                safe_names.append((f'{flag_type}/{flag_key}/{parameter_name}/', flag_key))
    else:
        safe_names.append((f'{flag_type}/{parameter_name}/', ''))


    # 3. Write and compute each required variable:
    for name, flag_key in safe_names:
        # Compute the value:
        transformation = None
        if isinstance(tensor, torch.nn.Parameter):
            # NOTE(review): only flags containing 'grad' ever compute here —
            # a non-gradient parameter flag (e.g. 'mean') leaves
            # `transformation` as None and is silently skipped. Confirm this
            # is intentional.
            if tensor.grad is not None and 'grad' in flag_key:
                transformation = REG_FUNCTION_MAP[flag_key](tensor, valid_mask)
        else:
            # NOTE(review): float(tensor) raises for non-scalar activation
            # tensors — presumably activations arrive as 0-d tensors here;
            # verify against callers.
            transformation = float(tensor) if tensor is not None else None
        # Write the value in tensorboard:
        if transformation is not None:
            write_tensorboard(
                name=name,
                value=transformation,
                epoch=epoch,
                writer=writer,
                logger=logger,
                tensorboard_required=tensorboard_required,
            )
|
| 90 |
+
|
| 91 |
+
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
| 92 |
+
# REPLAY #
|
| 93 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 94 |
+
@torch.no_grad()
def register_replay(
    predicted: torch.Tensor,
    target: torch.Tensor,
    epoch: int,
    writer: SummaryWriter,
    logger: logging.Logger,
    valid_mask: torch.Tensor = Ellipsis,
    element: int = None,
    tensorboard_required: bool = True,
) -> plt.Figure:
    """
    Registers a replay (prediction vs. target of one batch element) as an image.

    :param predicted: The predicted value (prediction), batch-first.
    :param target: The expected value (labels), dense or categorical (0-d).
    :param epoch: The current epoch.
    :param writer: The tensorboard writer.
    :param logger: The logger.
    :param valid_mask: A valid mask tensor of the same shape; False positions
        are ignored. ``None`` or ``Ellipsis`` (the default) keeps every position.
    :param element: The element to register; None chooses a random batch element.
    :param tensorboard_required: Whether the tensorboard writer is required.
    :return: The matplotlib figure that was written to tensorboard.
    """
    # Choose the batch element (random when unspecified, clamped otherwise):
    if element is None:
        element = random.randint(0, len(predicted) - 1)
    else:
        element = min(len(predicted) - 1, max(0, element))

    # Convert the chosen element to numpy:
    predicted_np = predicted[element].detach().cpu().numpy()
    target_np = target[element].detach().cpu().numpy()

    # Expand a categorical (scalar) target into a one-hot vector:
    if not target_np.shape:
        target_np_aux = np.zeros_like(predicted_np)
        target_np_aux[target_np] = 1.
        target_np = target_np_aux
        del target_np_aux

    # Build the boolean validity mask. BUGFIX: the default `Ellipsis` is not
    # indexable, so `valid_mask[element]` used to raise a TypeError; Ellipsis
    # now behaves like None and selects every position.
    if valid_mask is None or valid_mask is Ellipsis:
        mask_np = np.ones_like(predicted_np, dtype=bool)
    else:
        mask_np = valid_mask[element].detach().cpu().numpy().astype(bool)

    # Apply mask and flatten:
    predicted_flat = predicted_np[mask_np].flatten()
    target_flat = target_np[mask_np].flatten()

    # Pad both vectors up to the next perfect square so they render as B x B images:
    s = predicted_flat.shape[0]
    b = math.ceil(math.sqrt(s))
    total = b * b
    pad = total - s
    predicted_padded = np.pad(predicted_flat, (0, pad), constant_values=0.0).reshape(b, b)
    target_padded = np.pad(target_flat, (0, pad), constant_values=0.0).reshape(b, b)

    # Build and log the side-by-side figure:
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    plot_with_values(axs[0], predicted_padded, "Predicted (y_hat)")
    plot_with_values(axs[1], target_padded, "Target (y)")
    plt.tight_layout()
    write_tensorboard(
        'replay/',
        fig,
        epoch=epoch,
        writer=writer,
        logger=logger,
        tensorboard_required=tensorboard_required,
    )
    return fig
|
| 168 |
+
|
| 169 |
+
def plot_with_values(ax, data, title):
    """
    Render *data* as a heat map on *ax*, printing each cell's numeric value.

    :param ax: A matplotlib axes to draw on.
    :param data: A 2-D numpy array of values.
    :param title: The title of the plot.
    :return: Nothing.
    """
    ax.imshow(data, cmap='viridis', interpolation='nearest')
    ax.set_title(title)
    ax.axis('off')
    n_rows, n_cols = data.shape
    for row in range(n_rows):
        for col in range(n_cols):
            cell = data[row, col]
            # Dark cells get white text, bright cells get black text:
            ax.text(col, row, f"{cell:.2f}", ha="center", va="center",
                    color="white" if cell < 0.5 else "black", fontsize=8)
|
| 184 |
+
|
| 185 |
+
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
| 186 |
+
# WRITE ON BASE #
|
| 187 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 188 |
+
def write_tensorboard(
    name: str,
    value: int | float | plt.Figure | np.ndarray | torch.Tensor,
    epoch: int,
    writer: SummaryWriter,
    logger: logging.Logger,
    tensorboard_required: bool = True,
) -> None:
    """
    Write a value to tensorboard, dispatching on its runtime type:
    int/float -> scalar, tensor/list/ndarray -> histogram, str -> text,
    bytes (encoded image) and matplotlib figure -> image.

    :param name: The tag name under which the value is written.
    :param value: The value to write.
    :param epoch: The current epoch (used as the global step).
    :param writer: The tensorboard writer; when None the call is a no-op
        (with a warning only if `tensorboard_required` is True).
    :param logger: The logger.
    :param tensorboard_required: Whether the tensorboard writer is required.
    :raises ValueError: If the value's type is not supported.
    """
    # Check if the writer is None (silent no-op when tensorboard is optional):
    if writer is None:
        if tensorboard_required:
            logger.warning("Writer is None. Please set the writer first.")
        return
    # Check if the value is None
    if value is None:
        logger.warning("Value is None. Please set the value first.")
        return
    # Check if the name is None
    if name is None:
        logger.warning("Name is None. Please set the name first.")
        return

    # Type check:
    if isinstance(value, int):
        writer.add_scalar(name, float(value), epoch)
    elif isinstance(value, float):
        writer.add_scalar(name, value, epoch)
    elif isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()
        writer.add_histogram(name, value, epoch)
    elif isinstance(value, list):
        value = np.array(value)
        writer.add_histogram(name, value, epoch)
    elif isinstance(value, np.ndarray):
        writer.add_histogram(name, value, epoch)
    elif isinstance(value, str):
        writer.add_text(name, value, epoch)
    elif isinstance(value, bytes):
        # Raw encoded image bytes -> decode into a CHW tensor for add_image:
        image = Image.open(io.BytesIO(value))
        transform = transforms.ToTensor()
        value = transform(image)
        writer.add_image(name, value, epoch)
    elif isinstance(value, plt.Figure):
        # Rasterize the figure through an in-memory PNG buffer:
        buf = io.BytesIO()
        value.savefig(buf, format='png')
        buf.seek(0)
        image = Image.open(buf)
        image = transforms.ToTensor()(image)
        writer.add_image(name, image, epoch)
        # NOTE(review): plt.close() closes the *current* pyplot figure, which
        # may differ from `value` — consider plt.close(value); confirm intent.
        plt.close()
    else:
        raise ValueError(f"Type {type(value)} not supported.")
|
| 249 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 250 |
+
# END OF FILE #
|
| 251 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/seeds.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import logging
|
| 9 |
+
import torch
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import random
|
| 13 |
+
import time
|
| 14 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 15 |
+
CUBLAS_ALLOCATION = 4096
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #


def get_seed(seed: int = None, logger: logging.Logger = None) -> int:
    """
    Seed every relevant RNG (numpy, random, torch, CUDA) for reproducible runs.

    When ``seed`` is omitted, the current UNIX timestamp is used instead, so a
    fresh run still gets a recorded, reusable seed. The function also switches
    PyTorch into deterministic mode and configures the cuBLAS workspace via the
    ``CUBLAS_WORKSPACE_CONFIG`` environment variable (required for determinism
    on CUDA >= 10.2).

    Args:
        seed (int, optional): Seed for all random number generators. If None,
            one is derived from the current system time.
        logger (logging.Logger, optional): Logger used to trace the chosen seed.

    Returns:
        int: The seed that was actually applied.

    Example:
        >>> experiment_seed = get_seed(42)
        >>> # experiment_seed == 42
    """
    # cuBLAS must be told its workspace size before any CUDA work for
    # deterministic kernels:
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = f":{CUBLAS_ALLOCATION}:8"

    # Derive a seed from the wall clock when the caller did not provide one:
    if seed is None:
        seed = int(time.time())

    # Python-level and numpy RNGs:
    random.seed(seed)
    np.random.seed(seed)

    # Torch RNG plus deterministic execution (warn instead of raising when an
    # op has no deterministic implementation):
    torch.manual_seed(seed)
    torch.backends.cudnn.allow_tf32 = False
    torch.use_deterministic_algorithms(True, warn_only=True)

    # CUDA RNGs and cuDNN determinism, when a GPU is present:
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Trace and hand the applied seed back to the caller:
    if logger is not None:
        logger.info(f"Initializer set up seed: {seed}")
    return seed
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 69 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 70 |
+
# END OF FILE #
|
| 71 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/tensorboard.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import psutil
|
| 11 |
+
import time
|
| 12 |
+
import subprocess
|
| 13 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 14 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 15 |
+
DEFAULT_TENSORBOARD_PORT = 6006
|
| 16 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_writer(path: str, tensorboard_port: int | bool, logger: logging.Logger = None):
|
| 20 |
+
"""
|
| 21 |
+
Sets up a TensorBoard logging and checkpoint directory for PyTorch.
|
| 22 |
+
|
| 23 |
+
This function clears the specified directory, creates subdirectories for TensorBoard logs
|
| 24 |
+
and model checkpoints, ensuring a clean environment for running new training sessions.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
path (str): The root directory where TensorBoard logs and checkpoints will be stored.
|
| 28 |
+
tensorboard_port (int): The port on which to run the TensorBoard.
|
| 29 |
+
logger (logging.Logger): The logger that traces the logging information.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
tuple: A tuple containing the TensorBoard SummaryWriter object and the path for checkpoints.
|
| 33 |
+
|
| 34 |
+
Example:
|
| 35 |
+
>>> tensor_writer, checkpoint_dir = get_writer('/path/to/tensorboard/')
|
| 36 |
+
"""
|
| 37 |
+
# Check tensorboard port:
|
| 38 |
+
if tensorboard_port is True:
|
| 39 |
+
tensorboard_port = DEFAULT_TENSORBOARD_PORT
|
| 40 |
+
elif tensorboard_port is False:
|
| 41 |
+
return None, os.path.join(path, 'checkpoints')
|
| 42 |
+
|
| 43 |
+
# Create subdirectories for logs and checkpoints
|
| 44 |
+
logs_path = os.path.join(path, 'logs')
|
| 45 |
+
checkpoints_path = os.path.join(path, 'checkpoints')
|
| 46 |
+
os.makedirs(logs_path, exist_ok=True)
|
| 47 |
+
os.makedirs(checkpoints_path, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
# Set up TensorBoard logging
|
| 50 |
+
writer = SummaryWriter(log_dir=logs_path)
|
| 51 |
+
|
| 52 |
+
# Print paths where logs and checkpoints will be stored
|
| 53 |
+
if logger is not None:
|
| 54 |
+
logger.info(f"TensorBoard logs will be stored in: {logs_path}")
|
| 55 |
+
logger.info(f"Model checkpoints will be stored in: {checkpoints_path}")
|
| 56 |
+
|
| 57 |
+
# Launch tensorboard:
|
| 58 |
+
for conn in psutil.net_connections(kind='inet'):
|
| 59 |
+
if conn.laddr.port == tensorboard_port and conn.status == psutil.CONN_LISTEN:
|
| 60 |
+
if logger is not None:
|
| 61 |
+
logger.warning(f"Killing already running TensorBoard process with PID {conn.pid}")
|
| 62 |
+
p = psutil.Process(conn.pid)
|
| 63 |
+
p.terminate()
|
| 64 |
+
p.wait(timeout=3)
|
| 65 |
+
time.sleep(5)
|
| 66 |
+
process = subprocess.Popen(f'tensorboard --logdir={logs_path} --host=0.0.0.0 --port={tensorboard_port}',
|
| 67 |
+
shell=True)
|
| 68 |
+
if logger is not None:
|
| 69 |
+
logger.info(f'TensorBoard running at http://0.0.0.0:{tensorboard_port}/ (pid={process.pid})')
|
| 70 |
+
|
| 71 |
+
return writer, checkpoints_path
|
| 72 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 73 |
+
# END OF FILE #
|
| 74 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dlutils/setup/watchers.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
#                       DEFAULT WATCH                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Watcher configuration: a nested mapping of
#   category ('train' | 'parameters' | 'activations') -> statistic -> bool.
# Only statistics flagged True are collected/logged. The curated lists below
# are (category, statistic) pairs used to switch groups of flags on.
DEFAULT_WATCHER = {
    'train': {
        'loss': True,
        'lr': False,
        'val_loss': True
    },
    'parameters': {
        'max': False,
        'min': False,
        'mean': False,
        'std': False,
        'var': False,
        'hist': False,
        'l2': False,
        'l1': False,
        'sparsity': False,
        'snr': False,
        'rank': False,
        'power': False,

        # Gradients:
        'grad_max': False,
        'grad_min': False,
        'grad_mean': False,
        'grad_std': False,
        'grad_var': False,
        'grad_hist': False,
        'grad_l2': False,
        'grad_l1': False,
        'grad_sparsity': False,
        'grad_snr': False,
        'grad_rank': False,
        'grad_power': False
    },
    'activations': {
        'max': False,
        'min': False,
        'mean': False,
        'std': False,
        'var': False,
        'hist': False,
        'l2': False,
        'l1': False,
        'sparsity': False,
        'snr': False,
        'rank': False,
        'power': False,

        # Gradients:
        'grad_max': False,
        'grad_min': False,
        'grad_mean': False,
        'grad_std': False,
        'grad_var': False,
        'grad_hist': False,
        'grad_l2': False,
        'grad_l1': False,
        'grad_sparsity': False,
        'grad_snr': False,
        'grad_rank': False,
        'grad_power': False
    }
}
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
#                       SPECIFIC WATCH                       #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Tag legend:
# [PA] Performance analysis.
# [GF] Gradient flow.
# [AD] Activation death.
# [NT] Network topology.

S_WATCHER = [
    ('train', 'loss'),              # [TOP] [PA] Training progress.
    ('train', 'val_loss'),          # [TOP] [PA] Generalization / overfitting.
    ('parameters', 'grad_power'),   # [TOP] [GF] Gradient flow, global explosion/vanishing.
    ('parameters', 'grad_mean'),    # [TOP] [NT] Dead / useless layers (mean grad ~ 0).
    ('parameters', 'grad_max'),     # [TOP] [GF] Gradient spikes -> clipping / LR.
    ('activations', 'grad_power'),  # [TOP] [GF] Per-layer gradient flow (very informative).
    ('activations', 'sparsity'),    # [TOP] [AD] ReLU death / collapsed attention.
]

A_WATCHER = [
    ('train', 'lr'),                # [USEFUL] [PA] Track the scheduler / warmup.
    ('parameters', 'l2'),           # [USEFUL] [PA] Weight norm, regularization / weight decay.
    ('parameters', 'power'),        # [USEFUL] [PA] Weight scale / possible explosions.
    ('parameters', 'grad_snr'),     # [USEFUL] [GF] Signal/noise coherence of the gradient.
    ('parameters', 'rank'),         # [USEFUL] [NT] Effective capacity / parameter collapse.
    ('activations', 'mean'),        # [USEFUL] [NT] Activation shift / bad init.
    ('activations', 'std'),         # [USEFUL] [NT] Signal propagation between layers.
    ('activations', 'snr'),         # [USEFUL] [NT] Signal coherence across layers.
    ('activations', 'grad_snr'),    # [USEFUL] [GF] Per-layer gradient coherence.
]

B_WATCHER = [
    ('activations', 'hist'),        # [UTILITY] [AD] Visualize rare tails / saturations.
    ('parameters', 'snr'),          # [UTILITY] [NT] Global weight coherence (rank is usually better).
    ('parameters', 'grad_l2'),      # [UTILITY] [GF] Similar to grad_power but less intuitive.
    ('parameters', 'hist'),         # [UTILITY] [PA] Inspect weight distribution (spot debugging).
    ('activations', 'l2'),          # [UTILITY] [NT] Activation magnitude (redundant with std/power).
    ('activations', 'l1'),          # [UTILITY] [NT] Similar to l2; sometimes useful in sparse AEs.
]

C_WATCHER = [
    ('parameters', 'max'),          # [LOW] [PA] Only useful to spot occasional NaNs / inf.
    ('parameters', 'min'),          # [LOW] [PA] Same as max, little signal.
    ('parameters', 'mean'),         # [LOW] [PA] Hardly interpretable without more context.
    ('parameters', 'std'),          # [LOW] [PA] Redundant with power / l2.
    ('parameters', 'var'),          # [LOW] [PA] Redundant with std.
    ('parameters', 'grad_var'),     # [LOW] [GF] Redundant with grad_std.
    ('parameters', 'grad_hist'),    # [LOW] [GF] One-off visualization, not for continuous logging.
    ('activations', 'min'),         # [LOW] [NT] Rarely says anything std/mean do not.
    ('activations', 'max'),         # [LOW] [NT] Only useful to check clamps/NaNs.
    ('activations', 'var'),         # [LOW] [NT] Redundant with std.
    ('activations', 'grad_hist'),   # [LOW] [GF] Same as parameter grad_hist, visual only.
    ('activations', 'grad_var'),    # [LOW] [GF] Redundant with grad_std/grad_power.
]

CNN_WATCHER = [
    ('train', 'loss'),              # [TOP] [PA] Training fit.
    ('train', 'val_loss'),          # [TOP] [PA] Generalization (Imagenette/ImageNet).
    ('parameters', 'grad_power'),   # [TOP] [GF] Global gradient explosion/vanishing.
    ('parameters', 'grad_max'),     # [TOP] [GF] Per-layer spikes -> clipping.
    ('activations', 'grad_power'),  # [TOP] [GF] Gradient per conv block / head.
    ('activations', 'sparsity'),    # [TOP] [AD] Dead ReLU / dead layers.
    ('activations', 'std'),         # [USEFUL] [NT] Signal propagation (init, BN).
    ('parameters', 'l2'),           # [USEFUL] [PA] Weight-norm control / decay.
]

TRA_WATCHER = [
    ('train', 'loss'),              # [TOP] [PA] Model fit (LM / seq2seq / cls).
    ('train', 'val_loss'),          # [TOP] [PA] Generalization / overfitting.
    ('train', 'lr'),                # [USEFUL] [PA] Warmup, cosine, etc.
    ('parameters', 'grad_power'),   # [TOP] [GF] Explosion/vanishing across depth.
    ('parameters', 'grad_snr'),     # [USEFUL] [GF] Gradient SNR in attention/MLP blocks.
    ('activations', 'grad_power'),  # [TOP] [GF] Gradient flow per encoder/decoder layer.
    ('activations', 'mean'),        # [USEFUL] [NT] Drift in LayerNorm / RMSNorm.
    ('activations', 'std'),         # [USEFUL] [NT] Propagation in depth (residuals).
    ('parameters', 'l2'),           # [USEFUL] [PA] Weight size in attention/MLP.
]

AEN_WATCHER = [
    ('train', 'loss'),              # [TOP] [PA] Reconstruction / contrastive / VAE loss.
    ('train', 'val_loss'),          # [TOP] [PA] AE generalization.
    ('parameters', 'grad_power'),   # [TOP] [GF] Encoder/decoder gradient flow.
    ('activations', 'sparsity'),    # [TOP] [AD] Sparse encoders / neuron death.
    ('activations', 'rank'),        # [USEFUL] [NT] Representation collapse / low effective dimension.
    ('parameters', 'power'),        # [USEFUL] [PA] Decoder weights exploding or collapsing.
    ('activations', 'grad_power'),  # [TOP] [GF] Per-layer gradient in encoder/decoder.
    ('parameters', 'l2'),           # [USEFUL] [PA] Weight norm, especially in deep AEs.
]
|
src/dlutils/steps.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
import tqdm
|
| 11 |
+
from .setup import Setup, HookMonitor
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 15 |
+
# #
|
| 16 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 17 |
+
def train_step(
        # Always granted:
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        controller: Setup,
        # Not always granted:
        scheduler: torch.optim.lr_scheduler.LRScheduler = None,
) -> float:
    """
    Performs one full training epoch: forward pass, loss calculation, backward
    pass and optimization step for every batch, plus watcher/hook registration
    and checkpointing through the ``controller``.

    Each DataLoader element may carry 2, 3, 4 or more items, interpreted as
    ``(x, y)``, ``(x, y, x_mask)``, ``(x, y, x_mask, y_mask)``, or
    ``(x, y, x_mask, y_mask, *extra_args)`` respectively.

    Parameters:
        model (torch.nn.Module): The model to be trained.
        data (torch.utils.data.DataLoader): DataLoader providing the training data.
        loss (torch.nn.Module): Loss function to be used. Called as
            ``loss(y_hat, y, y_m)`` when a target mask is present, else
            ``loss(y_hat, y)``.
        optimizer (torch.optim.Optimizer): Optimizer used for gradient updates.
        controller (Setup): The setup object containing configuration and state
            (device, logger, watcher config, optional AMP scaler, replay ids,
            checkpointing).
        scheduler (torch.optim.lr_scheduler.LRScheduler, optional): Learning rate
            scheduler; stepped once per epoch (not per batch).

    Returns:
        float: The mean loss value over all batches of this epoch.

    Raises:
        ValueError: If a DataLoader element has fewer than two items.
    """
    # Train mode:
    model.to(controller.device)
    model.train()

    # Per-batch losses accumulated over the epoch:
    losses = list()

    # HookMonitor attaches activation hooks for the watched statistics while
    # the epoch runs:
    with HookMonitor(model, controller.watcher['activations'], controller.logger) as hooks:
        with tqdm.tqdm(data, desc=f'\rTraining epoch {controller.epoch}', leave=True) as pbar:
            # Local annotations for readability only (not evaluated at runtime):
            pbar: tqdm.tqdm
            hooks: HookMonitor

            for i, element in enumerate(pbar):

                # 1. Gather elements — batch layout depends on element arity:
                args = tuple()
                if len(element) == 2:
                    # Prediction:
                    x, y = element
                    x_m, y_m = None, None
                elif len(element) == 3:
                    # Prediction with x_mask:
                    x, y, x_m = element
                    y_m = None
                elif len(element) == 4:
                    # Prediction with x_mask and y_mask:
                    x, y, x_m, y_m = element
                elif len(element) > 4:
                    # More input arguments forwarded to the model:
                    x, y = element[0], element[1]
                    x_m, y_m = element[2], element[3]
                    args = element[4:]
                else:
                    raise ValueError("DataLoader elements must have at least two elements.")

                # 2. Load data to device (non_blocking pairs with pinned memory):
                x, y = x.to(controller.device, non_blocking=True), y.to(controller.device, non_blocking=True)
                optimizer.zero_grad()
                if x_m is not None:
                    x_m = x_m.to(controller.device, non_blocking=True)
                if y_m is not None:
                    y_m = y_m.to(controller.device, non_blocking=True)

                # 3. TRAIN - Control autocast (memory/speed trade-off). When an
                # AMP scaler is configured, forward runs under autocast and the
                # scaler handles backward/step/update:
                if controller.autoscaler is not None:
                    with torch.amp.autocast(enabled=(controller.device.type == 'cuda'), device_type=controller.device.type):
                        # Forward:
                        y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                        loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)
                    # Backward:
                    controller.autoscaler.scale(loss_metric).backward()
                    controller.autoscaler.step(optimizer)
                    controller.autoscaler.update()
                else:
                    # Forward:
                    y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                    loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)
                    # Backward:
                    loss_metric.backward()
                    optimizer.step()

                # 4. Append to metrics:
                losses.append(loss_metric.item())

                # 5. Monitor hooks — capture one designated batch for replay
                # (presumably for qualitative inspection; see Setup.register_replay):
                if controller.replay_id[0] == i:
                    controller.register_replay(predicted=y_hat, target=y, mask=y_m)

    # Write in summary writer (per epoch):
    losses = np.array(losses)
    mean_loss = float(np.mean(losses))

    # ================ WATCH ================
    # Register parameters with the controller's watcher:
    for name, parameter in model.named_parameters():
        controller.register(name, parameter)

    # Register train:
    controller.register('loss', mean_loss)

    # Register hooks — per-layer activation statistics gathered during the epoch:
    for layer_name, layer_stats in hooks.get_stats().items():
        for func_name, item in layer_stats.items():
            controller.register(f'{func_name}/{layer_name}', torch.Tensor([item])[0])

    # ================ CONTROL ================
    # Scheduler step (once per epoch; the pre-step LR is logged):
    if scheduler is not None:
        controller.register('lr', scheduler.get_last_lr()[0])
        scheduler.step()

    # Write for logger:
    controller.logger.info(f"Epoch [{controller.epoch}]: loss = {mean_loss:.8f}")

    # Checkpointing:
    controller.check(model, optimizer, scheduler)

    return mean_loss
|
| 139 |
+
|
| 140 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 141 |
+
# #
|
| 142 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 143 |
+
def validation_step(
        # Always granted:
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        controller: Setup,
        additional_metrics: dict = None,
) -> dict:
    """
    Performs one full validation epoch: forward pass and loss calculation for
    every batch, with gradients disabled.

    Each DataLoader element may carry 2, 3, 4 or more items, interpreted as
    ``(x, y)``, ``(x, y, x_mask)``, ``(x, y, x_mask, y_mask)``, or
    ``(x, y, x_mask, y_mask, *extra_args)`` respectively.

    Fixes over the previous revision:
      * ``args`` is now initialized before the arity dispatch — previously the
        4-element branch never assigned it, raising ``UnboundLocalError`` on
        the first 4-tuple batch.
      * The default for ``additional_metrics`` is ``None`` instead of the
        type-mismatched tuple ``()`` (both are falsy, so callers are unaffected).
      * The duplicated autocast / no-autocast forward paths are unified:
        ``torch.amp.autocast(enabled=False)`` is a no-op, so one code path
        suffices.

    Parameters:
        model (torch.nn.Module): The model to be validated.
        data (torch.utils.data.DataLoader): DataLoader providing the validation data.
        loss (torch.nn.Module): Loss function; called as ``loss(y_hat, y, y_m)``
            when a target mask is present, else ``loss(y_hat, y)``.
        controller (Setup): The setup object containing configuration and state.
        additional_metrics (dict, optional): Mapping ``name -> callable`` of
            extra metrics evaluated per batch as ``metric(y_hat, y, y_m)``.

    Returns:
        dict: Mean value per metric over the epoch, plus ``'loss'``.

    Raises:
        ValueError: If a DataLoader element has fewer than two items.
    """
    # Validation mode:
    model.to(controller.device)
    model.eval()

    # Per-batch accumulators:
    losses = list()
    additional_metrics = additional_metrics or {}
    metrics: dict[str, list | float] = {name: list() for name in additional_metrics}

    # Autocast is only meaningful with an AMP scaler on CUDA; with
    # enabled=False the context is a no-op, so one code path covers both cases:
    use_autocast = controller.autoscaler is not None and controller.device.type == 'cuda'

    with torch.no_grad():
        with tqdm.tqdm(data, desc=f'\rValidation epoch {controller.epoch}', leave=True) as pbar:
            for element in pbar:
                # Gather elements — batch layout depends on element arity:
                args = tuple()
                if len(element) == 2:
                    # Prediction:
                    x, y = element
                    x_m, y_m = None, None
                elif len(element) == 3:
                    # Prediction with x_mask:
                    x, y, x_m = element
                    y_m = None
                elif len(element) == 4:
                    # Prediction with x_mask and y_mask:
                    x, y, x_m, y_m = element
                elif len(element) > 4:
                    # More input arguments forwarded to the model:
                    x, y = element[0], element[1]
                    x_m, y_m = element[2], element[3]
                    args = element[4:]
                else:
                    raise ValueError("DataLoader elements must have at least two elements.")

                # Load data to device:
                x, y = x.to(controller.device, non_blocking=True), y.to(controller.device, non_blocking=True)
                if x_m is not None:
                    x_m = x_m.to(controller.device, non_blocking=True)
                if y_m is not None:
                    y_m = y_m.to(controller.device, non_blocking=True)

                # Forward (optionally under autocast) and per-batch metrics:
                with torch.amp.autocast(enabled=use_autocast, device_type=controller.device.type):
                    y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                    loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)

                    # Compute additional metrics:
                    for name, additional_metric in additional_metrics.items():
                        metrics[name].append(additional_metric(y_hat, y, y_m).item())

                # Append to metrics:
                losses.append(loss_metric.item())

    # Convert:
    losses = np.array(losses)
    mean_loss = float(np.mean(losses))

    # Additional metrics — reduce each per-batch list to its epoch mean:
    for name, variable in metrics.items():
        metrics[name] = float(np.mean(variable))
    metrics['loss'] = mean_loss

    # Write to register:
    controller.register("val_loss", mean_loss)
    # Write for logger:
    controller.logger.info(f"Epoch [{controller.epoch}]: val_loss = {mean_loss:.8f}")

    return metrics
|
| 244 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 245 |
+
# END OF FILE #
|
| 246 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
from .config import ModelConfig, TransformerConfig, CoSeNetConfig
|
| 9 |
+
from .segmentation import SegmentationNetwork
|
| 10 |
+
from .loss import MaskedBCELoss
|
| 11 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 12 |
+
# END OF FILE #
|
| 13 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/config.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
from typing import List
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class CoSeNetConfig:
    """Configuration for the CoSeNet component of the model."""
    # Whether the CoSeNet scale is a learnable parameter (presumably consumed
    # by TrainableSigmoid — TODO confirm against src/model/cosenet).
    trainable: bool = True
    # Initial value of that scale.
    init_scale: float = 5.0


@dataclass
class TransformerConfig:
    """Hyper-parameters for one transformer layer."""
    # Number of attention heads.
    attention_heads: int = 8
    # Feed-forward hidden size as a multiple of the model dimension.
    feed_forward_multiplier: float = 4
    # Dropout probability (0.0 disables dropout).
    dropout: float = 0.0
    # Pre-LayerNorm (True) vs. post-LayerNorm layout.
    pre_normalize: bool = True


@dataclass
class ModelConfig:
    """Top-level configuration for the segmentation model."""
    # Tokenizer vocabulary size (2**15 = 32768).
    vocab_size: int = 2 ** 15
    # Embedding / hidden dimension shared across the model.
    model_dim: int = 256
    # Maximum tokens per sentence — NOTE(review): 382 looks like 384 minus two
    # special tokens; confirm against the tokenizer.
    max_tokens: int = 382
    # Maximum sentences per document.
    max_sentences: int = 384
    valid_padding: bool = True
    # default_factory avoids the shared-mutable-default pitfall of dataclasses.
    cosenet: CoSeNetConfig = field(default_factory=CoSeNetConfig)
    # One TransformerConfig per transformer layer; defaults to a single layer.
    transformers: List[TransformerConfig] = field(default_factory=lambda: [TransformerConfig()])
|
| 35 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 36 |
+
# END OF FILE #
|
| 37 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/cosenet/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
from .cosenet import CoSeNet
|
| 8 |
+
from .cosine_distance import CosineDistanceLayer
|
| 9 |
+
from .trainable_sigmoid import TrainableSigmoid
|
| 10 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 11 |
+
# END OF FILE #
|
| 12 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/cosenet/cosenet.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
from .cosenet_layer import CoSeNetLayer
|
| 12 |
+
from .trainable_sigmoid import TrainableSigmoid
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class CoSeNet(torch.nn.Module):
    """
    PyTorch's implementation of the CoSeNet architecture.

    This module loads pre-trained CoSeNet weights and applies a structured
    unfolding–linear–folding pipeline to the input tensor. An optional
    trainable sigmoid adaptation is applied to the input prior to the
    CoSeNet transformation.

    The architecture assumes that the input data represent structured
    matrices (e.g., similarity or distance matrices) and performs
    diagonal-based unfolding with overlapping windows.
    """

    def __init__(self, trainable: bool = False, init_scale: float = 5.0, **kwargs):
        """
        Initialize the CoSeNet model.

        Pre-trained weights and biases are loaded from disk and used to
        construct the internal CoSeNet layer. Optionally, the parameters
        can be set as trainable.

        Args:
            trainable (bool, optional): Whether the CoSeNet linear layer
                parameters should be trainable. Defaults to False.
            init_scale (float, optional): Initial scale for the trainable
                sigmoid adaptation module. Defaults to 5.0.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.

        Raises:
            FileNotFoundError: If the weight or bias files cannot be found.
        """
        super().__init__(**kwargs)

        # Load weights: pre-trained coefficients are expected in a
        # 'weights' directory next to this source file.
        this_file_name = os.path.dirname(os.path.abspath(__file__))
        w_path = os.path.join(this_file_name, 'weights', 'w.npy')
        b_path = os.path.join(this_file_name, 'weights', 'b.npy')

        if not os.path.exists(w_path):
            raise FileNotFoundError(f'CoSeNet weight file {w_path} does not exist.')
        if not os.path.exists(b_path):
            raise FileNotFoundError(f'CoSeNet bias file {b_path} does not exist.')

        w, b = np.load(w_path), np.load(b_path)

        # Build layers:
        # The linear layer consumes flattened L x L windows, so the window
        # side L is the square root of the weight's input dimension.
        self.matrix_shape = int(np.sqrt(w.shape[-1]))
        self.layer = CoSeNetLayer(w, b, trainable=trainable)
        self.adaptation = TrainableSigmoid(init_scale=init_scale)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass of the CoSeNet model.

        The input is first adapted using a trainable sigmoid, then padded,
        unfolded along the diagonal, processed by the CoSeNet linear layer,
        and finally folded back into its original structure. An optional
        external mask can be applied to the output.

        Args:
            x (torch.Tensor): Input tensor containing structured matrix data.
            mask (torch.Tensor, optional): Optional mask tensor applied
                element-wise to the output. Defaults to None.

        Returns:
            torch.Tensor: Output tensor with the same spatial structure as
                the input.

        Raises:
            ValueError: If the input has fewer than 2 dimensions or its last
                two dimensions are not equal.
        """
        # check dimension:
        if x.dim() < 2:
            raise ValueError(f'CoSeNet input: at least 2 dimensions required. (got {x.dim()})')
        # Check perfect square:
        if x.shape[-1] != x.shape[-2]:
            raise ValueError(f'CoSeNet input: last two dimensions must be equal. ({x.shape[-2]} != {x.shape[-1]})')

        adapted_x = self.adaptation(x)
        pad_x, pad_mask = self.__cosenet_padding(adapted_x)
        unfold_x = self.__unfold(pad_x)
        unfold_y = self.layer(unfold_x)
        y = self.__fold(unfold_y, pad_mask)

        # Element-wise product keeps only mask-selected positions
        # (a boolean mask zeroes the rest via type promotion).
        if mask is not None:
            y = torch.multiply(y, mask)

        return y

    def __unfold(self, x: torch.Tensor) -> torch.Tensor:
        """
        Unfold the input tensor into overlapping diagonal blocks.

        The unfolding is performed using a sliding window over the last
        two dimensions, followed by diagonal extraction. The stride is
        determined by half of the matrix size.

        Args:
            x (torch.Tensor): Padded input tensor.

        Returns:
            torch.Tensor: Tensor containing unfolded diagonal blocks with
                shape [..., K, L, L], where K is the number of extracted blocks.
        """
        # 50 % overlap between consecutive windows (stride = L // 2).
        step = max(1, self.matrix_shape // 2)
        u = x.unfold(-2, self.matrix_shape, step).unfold(-2, self.matrix_shape, step)
        # Keep only the diagonal windows (row index == column index) and move
        # the window index next to the batch axis.
        # NOTE(review): movedim(-1, 1) yields [B, K, L, L] for batched (>=3-D)
        # inputs; a plain 2-D input would come out as (L, K, L) instead —
        # confirm callers always provide a batch dimension.
        y = u.diagonal(offset=0, dim1=-4, dim2=-3).movedim(-1, 1)
        return y

    @staticmethod
    def __fold(x: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
        """
        Fold unfolded CoSeNet outputs back into a full matrix.

        Overlapping regions are combined using an averaging strategy to
        account for multiple contributions to the same spatial location.

        Args:
            x (torch.Tensor): Tensor containing unfolded CoSeNet outputs.
            pad_mask (torch.Tensor): Boolean mask indicating valid (non-padded)
                positions.

        Returns:
            torch.Tensor: Folded tensor with padding removed and original
                structure restored.
        """
        if x.shape[-2] > 1:
            # Output length for K windows of length L at 50 % overlap:
            # L * (K + 1) / 2.
            # NOTE(review): torch.zeros uses the global default dtype (float32),
            # not x.dtype — confirm this is intended for half/double runs.
            y = torch.zeros(
                list(x.shape[:-2]) + [x.shape[-1] * (x.shape[-2] + 1) // 2],
                device=x.device,
            )
            t = x.shape[-1] // 2

            # Accumulate each window at half weight; overlapped positions
            # receive two contributions and thus average out.
            for i in range(x.shape[-2]):
                y[..., i * t + 1: t * (i + 2)] += 0.5 * x[..., i, 1:]
                y[..., i * t] *= 2

            # Edge positions are covered by a single window; restore full weight.
            y[..., :t] *= 2
            y[..., -t:] *= 2
            y[..., 0] = 1
        else:
            # Single window: nothing to merge.
            y = x[..., 0, :]

        # NOTE(review): boolean indexing followed by .view(pad_mask.shape) only
        # succeeds when pad_mask is all-True (i.e. no padding was added, so the
        # element counts match) — confirm inputs are always multiples of the
        # window size, or that padded batches never reach this path.
        return y[pad_mask].view(pad_mask.shape)

    def __cosenet_padding(self, x: torch.Tensor) -> tuple:
        """
        Pad the input tensor to match the required matrix shape.

        Padding is applied along the last two dimensions to ensure that
        their sizes are multiples of the CoSeNet matrix shape. A diagonal
        mask is generated to distinguish padded elements.

        Args:
            x (torch.Tensor): Original input tensor.

        Returns:
            tuple:
                - torch.Tensor: Padded tensor with diagonal correction.
                - torch.Tensor: Boolean mask indicating valid entries.
        """
        pad_w = (self.matrix_shape - (x.shape[-1] % self.matrix_shape)) % self.matrix_shape
        pad_h = (self.matrix_shape - (x.shape[-2] % self.matrix_shape)) % self.matrix_shape

        # Zero-pad bottom/right so both spatial dims are window-size multiples.
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))

        # Padded diagonal entries are exactly zero, so diag == 0 identifies
        # padding (assumes the sigmoid-adapted real entries never underflow
        # to exactly 0 — TODO confirm).
        diag = x.diagonal(dim1=-2, dim2=-1)
        mask_bool = (diag == 0)
        mask01 = mask_bool.to(x.dtype)

        # Set padded diagonal entries to 1.
        x = x + torch.diag_embed(mask01)

        return x, torch.logical_not(mask_bool)
|
| 187 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 188 |
+
# END OF FILE #
|
| 189 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/cosenet/cosenet_layer.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CoSeNetLayer(torch.nn.Module):
    """
    Affine layer for CoSeNet built from pre-trained coefficients.

    Wraps a single linear transformation whose weight matrix and bias come
    from a previously trained model. Incoming windows are flattened over
    their last two dimensions before the affine map is applied. Parameters
    may be frozen (the default) or registered as learnable for fine-tuning.
    """

    def __init__(self, coef: np.ndarray, intercept: np.ndarray, trainable: bool = False, **kwargs):
        """
        Build the layer from pre-trained coefficients.

        Args:
            coef (np.ndarray): Weight matrix of the affine transformation.
            intercept (np.ndarray): Bias vector of the affine transformation.
            trainable (bool, optional): Register weight and bias as learnable
                parameters. Defaults to False (frozen).
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        """
        super().__init__(**kwargs)
        weight_tensor = torch.tensor(coef, dtype=torch.float32)
        bias_tensor = torch.tensor(intercept, dtype=torch.float32)
        self.weight = torch.nn.Parameter(weight_tensor, requires_grad=trainable)
        self.bias = torch.nn.Parameter(bias_tensor, requires_grad=trainable)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the affine map to the flattened input windows.

        Args:
            x (torch.Tensor): Input of shape [..., L, L]; the trailing two
                dimensions are flattened into a single feature axis before
                the linear mapping.

        Returns:
            torch.Tensor: Result of `x_flat @ weight.T + bias`.
        """
        flat = x.flatten(start_dim=-2)
        return torch.nn.functional.linear(flat, self.weight, self.bias)
|
| 53 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 54 |
+
# END OF FILE #
|
| 55 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/cosenet/cosine_distance.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class CosineDistanceLayer(torch.nn.Module):
    """
    Pairwise cosine distance computation layer.

    Given a batch of embedding vectors, produces the matrix of absolute
    cosine similarities between every pair of vectors along the
    second-to-last axis. Taking the absolute value means that opposite
    directions count as related.
    """

    def __init__(self, **kwargs):
        """
        Initialize the (parameter-free) cosine distance layer.

        Args:
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        """
        super().__init__(**kwargs)

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        """
        Compute the absolute pairwise cosine similarity matrix.

        Each embedding is L2-normalized along the last dimension, after
        which the Gram matrix of the normalized vectors yields the cosine
        similarities; its absolute value is returned.

        Args:
            x (torch.Tensor): Embeddings of shape [..., S, D], where `S` is
                the number of embeddings and `D` their dimensionality.

        Returns:
            torch.Tensor: Similarities of shape [..., S, S].
        """
        unit = torch.nn.functional.normalize(x, p=2, dim=-1)  # [..., S, D]
        similarity = unit @ unit.transpose(-2, -1)            # [..., S, S]
        return similarity.abs()
|
| 55 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 56 |
+
# END OF FILE #
|
| 57 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/cosenet/trainable_sigmoid.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TrainableSigmoid(torch.nn.Module):
    """
    Sigmoid activation with a learnable slope, centered at 0.5.

    The steepness of the transition is governed by a single trainable
    scalar, letting the model adapt how sharply input values around the
    0.5 midpoint (e.g., distances or similarity scores) are pushed towards
    0 or 1 during training.
    """

    def __init__(self, init_scale: float = 5.0, **kwargs):
        """
        Initialize the trainable sigmoid module.

        Args:
            init_scale (float, optional): Initial magnitude of the slope.
                The learnable parameter is stored as the negative of this
                value, so the initial activation is increasing in the input.
                Defaults to 5.0.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        """
        super().__init__(**kwargs)
        initial = torch.tensor(-init_scale, dtype=torch.float32)
        self.alpha = torch.nn.Parameter(initial)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the learnable sigmoid transformation element-wise.

        Equivalent to `1 / (1 + exp(alpha * (x - 0.5)))`: a sigmoid centered
        at 0.5 whose steepness is controlled by the learnable `alpha`.

        Args:
            x (torch.Tensor): Input tensor containing values to be transformed.

        Returns:
            torch.Tensor: Tensor of the same shape as `x`.
        """
        centered = x - 0.5
        return torch.sigmoid(-self.alpha * centered)
|
| 51 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 52 |
+
# END OF FILE #
|
| 53 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/loss.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MaskedBCELoss(torch.nn.Module):
    """
    Binary cross-entropy restricted to the valid positions of a mask.

    Only the elements selected by a boolean mask contribute to the loss.
    Predictions may be given either as logits or as probabilities, and the
    masked loss can be reduced by averaging or summing over the selected
    elements. The mask polarity is configurable to match either a
    "True means valid" scheme or PyTorch's "True means padded" convention.
    """

    def __init__(
            self,
            reduction: str = 'mean',
            valid_pad: bool = True,
            eps: float = 1e-7,
            logits: bool = True
    ):
        """
        Initialize the masked binary cross-entropy loss.

        Args:
            reduction (str, optional): How the masked loss is reduced:
                'mean' averages over valid elements, 'sum' adds them up.
                Defaults to 'mean'.
            valid_pad (bool, optional): Mask polarity. If True, `True`
                entries mark valid (non-padded) positions; if False, `True`
                entries mark padded positions (PyTorch convention).
                Defaults to True.
            eps (float, optional): Clamping constant that keeps probability
                inputs away from {0, 1}; only used when `logits=False`.
                Defaults to 1e-7.
            logits (bool, optional): Interpret predictions as logits (True,
                via `binary_cross_entropy_with_logits`) or as probabilities
                (False, via `binary_cross_entropy`). Defaults to True.

        Raises:
            ValueError: If `reduction` is neither 'mean' nor 'sum'.
        """
        super().__init__()

        if reduction not in ('mean', 'sum'):
            raise ValueError("[MASKED-BCE] Reduction must be 'mean' or 'sum'")

        self.reduction = reduction
        self.valid_pad = valid_pad
        self.logits = logits
        self.eps = eps

        # Select the element-wise BCE implementation once, up front.
        self.loss = (
            torch.nn.functional.binary_cross_entropy_with_logits
            if logits
            else torch.nn.functional.binary_cross_entropy
        )

    def forward(
            self,
            x: torch.Tensor,
            y: torch.Tensor,
            mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the masked binary cross-entropy loss.

        Args:
            x (torch.Tensor): Predictions with shape (B, S); logits if
                `logits=True`, probabilities in [0, 1] otherwise.
            y (torch.Tensor): Ground-truth binary labels with shape (B, S).
            mask (torch.Tensor): Boolean mask with shape (B, S), interpreted
                according to `valid_pad` (see `__init__`).

        Returns:
            torch.Tensor: Scalar tensor containing the reduced loss value.
        """
        # Normalize polarity so that True always marks valid positions.
        valid = mask if self.valid_pad else torch.logical_not(mask)

        # Keep log() finite for probability inputs.
        if not self.logits:
            x = x.clamp(self.eps, 1.0 - self.eps)

        # Element-wise BCE, then zero out the non-valid positions.
        per_element = self.loss(
            x.float(),
            y.float(),
            reduction='none'
        )
        total = (per_element * valid.float()).sum()

        if self.reduction == 'sum':
            return total
        if self.reduction == 'mean':
            # Average over valid elements only; the clamp avoids a division
            # by zero when the mask selects nothing.
            return total / valid.sum().clamp(min=1)
        raise ValueError("[MASKED-BCE] Reduction must be 'mean' or 'sum'")
|
| 113 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 114 |
+
# END OF FILE #
|
| 115 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/segmentation.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
from .config import ModelConfig
|
| 10 |
+
from .cosenet import CosineDistanceLayer, CoSeNet
|
| 11 |
+
from .transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SegmentationNetwork(torch.nn.Module):
    """
    Segmentation network combining Transformer encoders with CoSeNet.

    This model integrates token embeddings and positional encodings with
    a stack of Transformer encoder blocks to produce contextualized
    representations. These representations are then processed by a
    CoSeNet module to perform structured segmentation, followed by a
    cosine-based distance computation.

    The final output is a pair-wise distance matrix suitable for
    segmentation or boundary detection tasks.
    """
    def __init__(self, model_config: ModelConfig, **kwargs):
        """
        Initialize the segmentation network.

        The network is composed of an embedding layer, positional encoding,
        multiple Transformer encoder blocks, a CoSeNet segmentation module,
        and a cosine distance layer.

        Args:
            model_config (ModelConfig): Configuration object containing all
                hyperparameters required to build the model, including
                vocabulary size, model dimensionality, transformer settings,
                and CoSeNet parameters.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        """
        super().__init__(**kwargs)
        # Mask polarity used throughout: True marks valid positions when set.
        self.valid_padding = model_config.valid_padding

        # Build layers:
        self.embedding = torch.nn.Embedding(
            model_config.vocab_size,
            model_config.model_dim
        )
        self.positional_encoding = PositionalEncoding(
            emb_dim=model_config.model_dim,
            max_len=model_config.max_tokens
        )
        self.cosenet = CoSeNet(
            trainable=model_config.cosenet.trainable,
            init_scale=model_config.cosenet.init_scale
        )
        self.distance_layer = CosineDistanceLayer()
        self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)

        # Build encoder blocks: one EncoderBlock per TransformerConfig entry,
        # stacked in order.
        module_list = list()
        for transformer_config in model_config.transformers:
            encoder_block = EncoderBlock(
                feature_dim=model_config.model_dim,
                attention_heads=transformer_config.attention_heads,
                feed_forward_multiplier=transformer_config.feed_forward_multiplier,
                dropout=transformer_config.dropout,
                valid_padding=model_config.valid_padding,
                pre_normalize=transformer_config.pre_normalize
            )
            module_list.append(encoder_block)

        self.encoder_blocks = torch.nn.ModuleList(module_list)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass of the segmentation network.

        The input token indices are embedded and enriched with positional
        information, then processed by a stack of Transformer encoder
        blocks. The resulting representations are segmented using CoSeNet
        and finally transformed into a pair-wise distance representation.

        Args:
            x (torch.Tensor): Input tensor of token indices with shape
                (batch_size, sentences, tokens) — the code below unpacks
                four embedding dims (batch, sentence, token, model_dim).
            mask (torch.Tensor, optional): Optional mask tensor indicating
                valid or padded positions, depending on the configuration
                of the Transformer blocks. Defaults to None.

                If `valid_padding` is disabled, the mask is inverted before
                being passed to CoSeNet to match its masking convention.

            candidate_mask (torch.Tensor, optional): Optional mask tensor for
                candidate positions in CoSeNet output. Defaults to None.

                If `valid_padding` is disabled, the mask is inverted before
                being applied, to match the masking convention.

        Returns:
            torch.Tensor: Output tensor containing pairwise distance values
                derived from the segmented representations.
        """
        # Convert to type: embedding lookup requires integer indices.
        x = x.int()

        # Embedding and positional encoding:
        x = self.embedding(x)
        x = self.positional_encoding(x)

        # Reshape x and mask: merge batch and sentence axes so each sentence
        # is encoded independently by the Transformer stack.
        _b, _s, _t, _d = x.shape
        x = x.reshape(_b * _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b * _s, _t).bool()

        # Encode the sequence:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask=mask)

        # Reshape x and mask: restore the (batch, sentence, token, dim) layout.
        x = x.reshape(_b, _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b, _s, _t)
            # Normalize polarity so that True marks valid positions.
            # NOTE(review): this inverted mask is fed to self.pooling, which
            # was constructed with valid_pad=model_config.valid_padding —
            # confirm MaskedMeanPooling expects the already-normalized mask
            # when valid_padding is False.
            mask = torch.logical_not(mask) if not self.valid_padding else mask

        # Apply pooling: collapse the token axis into one embedding per sentence.
        x, mask = self.pooling(x, mask=mask)

        # Compute distances: pair-wise |cosine| between sentence embeddings.
        x = self.distance_layer(x)

        # Pass through CoSeNet:
        x = self.cosenet(x, mask=mask)

        # Apply candidate mask if provided: after normalization, True marks
        # the positions to zero out.
        if candidate_mask is not None:
            candidate_mask = candidate_mask.bool() if not self.valid_padding else torch.logical_not(candidate_mask.bool())
            candidate_mask = candidate_mask.to(device=x.device)
            x = x.masked_fill(candidate_mask, 0)

        return x
|
| 145 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 146 |
+
# END OF FILE #
|
| 147 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/transformers/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
from .attention import EncoderBlock
|
| 9 |
+
from .positional_encoding import PositionalEncoding
|
| 10 |
+
from .pooling import MaskedMeanPooling
|
| 11 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 12 |
+
# END OF FILE #
|
| 13 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/transformers/attention.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EncoderBlock(torch.nn.Module):
    """
    Transformer encoder block with configurable Pre-LayerNorm or Post-LayerNorm
    architecture.

    The block consists of a multi-head self-attention sublayer followed by a
    position-wise feed-forward network, each wrapped with a residual connection.
    Layer normalization can be applied either before each sublayer (Pre-LN) or
    after each residual addition (Post-LN).

    This design allows stable training of deep Transformer stacks while retaining
    compatibility with the original Transformer formulation.
    """
    def __init__(
        self,
        feature_dim: int,
        attention_heads: int = 8,
        feed_forward_multiplier: float = 4,
        dropout: float = 0.0,
        valid_padding: bool = False,
        pre_normalize: bool = True,
        **kwargs
    ):
        """
        Initializes a Transformer encoder block.

        Parameters
        ----------
        feature_dim : int
            Dimensionality of the input and output feature representations.
        attention_heads : int, optional
            Number of attention heads used in the multi-head self-attention layer.
            Default is 8.
        feed_forward_multiplier : float, optional
            Expansion factor for the hidden dimension of the feed-forward network.
            The intermediate dimension is computed as
            `feed_forward_multiplier * feature_dim`.
            Default is 4.
        dropout : float, optional
            Dropout probability applied to the feed-forward residual connection.
            Default is 0.0.
        valid_padding : bool, optional
            If True, the provided mask marks valid (non-padded) positions.
            If False, the mask marks padded (invalid) positions directly.
            Default is False.
        pre_normalize : bool, optional
            If True, uses the Pre-LayerNorm Transformer variant, applying layer
            normalization before each sublayer (self-attention and feed-forward).
            If False, uses the Post-LayerNorm variant, applying normalization after
            each residual connection.
            Default is True.
        **kwargs
            Additional keyword arguments passed to the parent `torch.nn.Module`.
        """
        # Module init via kwargs:
        super().__init__(**kwargs)

        # Store params:
        self.valid_padding = valid_padding
        self.pre_normalize = pre_normalize

        # Norm layers: `norm_in` wraps the attention sublayer, `norm_out` the
        # feed-forward sublayer (their placement depends on `pre_normalize`).
        self.norm_in = torch.nn.LayerNorm(feature_dim)
        self.norm_out = torch.nn.LayerNorm(feature_dim)

        # Dropout layer: applied only to the feed-forward residual; the
        # attention module's internal dropout is fixed to 0.0 below.
        self.dropout = torch.nn.Dropout(dropout)

        # Attention layer (batch_first: inputs are (B, S, D)):
        self.attention = torch.nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=attention_heads,
            dropout=0.0,
            batch_first=True
        )

        # Feed-forward layer: D -> multiplier*D -> D with GELU activation.
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, int(feed_forward_multiplier * feature_dim)),
            torch.nn.GELU(),
            torch.nn.Linear(int(feed_forward_multiplier * feature_dim), feature_dim),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Forward pass of a Transformer encoder block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, feature_dim).
        mask : torch.Tensor or None, optional
            Boolean mask indicating valid sequence positions.
            Shape: (batch_size, sequence_length).
            If `valid_padding` is True, True values denote valid tokens.
            Otherwise, True values denote masked (invalid) positions.

        Returns
        -------
        x : torch.Tensor
            Output tensor of the same shape as the input
            (batch_size, sequence_length, feature_dim).
            Rows of fully padded sequences are forced to zero.
        """

        # Convert mask to the two conventions used below:
        # - key_padding_mask: True = pad (torch.nn.MultiheadAttention convention)
        # - valid_mask:       True = valid token
        if mask is not None and self.valid_padding:
            key_padding_mask = ~mask.bool()  # True = pad
            valid_mask = mask.bool()
        elif mask is not None:
            key_padding_mask = mask.bool()
            valid_mask = ~mask.bool()
        else:
            key_padding_mask = None
            valid_mask = None

        # Detect fully padded sequences (rows with no valid token at all):
        if valid_mask is not None:
            all_pad = ~valid_mask.any(dim=-1)  # (B,)
        else:
            all_pad = None

        # Pre-normalization (Pre-LN applies LayerNorm before attention):
        if self.pre_normalize:
            h = self.norm_in(x)
        else:
            h = x

        # Attention (guard against fully padded sequences): a key_padding_mask
        # row that is all True would make the attention softmax produce NaNs,
        # so those rows get zeroed features and an all-False padding row; their
        # (meaningless) output is overwritten with zeros at the end.
        if all_pad is not None and all_pad.any():
            h_attn = h.clone()
            h_attn[all_pad] = 0.0

            if key_padding_mask is not None:
                # Clone so the caller's mask tensor is not mutated in place.
                key_padding_mask = key_padding_mask.clone()
                key_padding_mask[all_pad] = False
        else:
            h_attn = h

        # Self-attention (query = key = value) with residual connection.
        # NOTE(review): no dropout is applied on the attention residual —
        # confirm this asymmetry with the feed-forward path is intentional.
        attn_out, _ = self.attention(
            h_attn, h_attn, h_attn,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = x + attn_out

        # Post-attention normalization:
        # Pre-LN:  normalize with `norm_out` before the feed-forward sublayer.
        # Post-LN: normalize the attention residual with `norm_in`.
        # NOTE(review): in the Post-LN branch the feed-forward residual below
        # is added to the *un-normalized* x (canonical Post-LN would replace x
        # with norm_in(x) first) — confirm this deviation is intentional.
        if not self.pre_normalize:
            z = self.norm_in(x)
        else:
            z = self.norm_out(x)

        # Feed-forward with dropout on the residual branch:
        z = self.feed_forward(z)
        x = x + self.dropout(z)

        # Post-LN: final normalization after the feed-forward residual.
        if not self.pre_normalize:
            x = self.norm_out(x)

        # Re-pad fully padded sequences so they output exact zeros:
        if all_pad is not None:
            x = x.masked_fill(all_pad[:, None, None], 0.0)

        return x
|
| 174 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 175 |
+
# END OF FILE #
|
| 176 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/transformers/pooling.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MaskedMeanPooling(torch.nn.Module):
    """
    Mean pooling layer with explicit masking support.

    This layer computes the mean over the sequence dimension (dim=-2) while
    ignoring padded elements according to a boolean mask. It supports
    both PyTorch-style padding masks (True = padded) and valid-position
    masks (True = valid).
    """

    def __init__(self, valid_pad: bool = True, eps: float = 1e-6):
        """
        Initialize the masked mean pooling layer.

        Args:
            valid_pad (bool, optional): Mask interpretation mode. If True,
                `True` values in the mask indicate valid (non-padded) positions.
                If False, `True` values indicate padded positions, following
                PyTorch-style padding conventions. Defaults to True.
            eps (float, optional): Small constant to avoid division by zero
                when all positions are masked. Defaults to 1e-6.
        """
        super().__init__()
        self.valid_pad = valid_pad
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply masked mean pooling.

        Args:
            x (torch.Tensor): Input tensor of shape (..., S, D), where
                S is the sequence length and D the feature dimension.
            mask (torch.Tensor): Boolean mask tensor of shape (..., S),
                or None to treat every position as valid.
                The interpretation depends on `valid_pad`.

        Returns:
            tuple:
                torch.Tensor: Pooled tensor of shape (..., D).
                torch.Tensor: Updated valid mask after pooling of shape (..., ).
        """
        # Mask handling: a missing mask means every position is valid.
        if mask is None:
            # BUGFIX: the fallback mask must span all but the feature
            # dimension, i.e. (..., S). The previous `x.shape[:3]` produced
            # a (B, S, D)-shaped mask for the documented 3-D input, which
            # broke the broadcasting below.
            valid_mask = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)
        else:
            valid_mask = mask.bool()

        # Convert to "True = valid" semantics:
        if not self.valid_pad:
            valid_mask = torch.logical_not(valid_mask)

        # Weighted sum over the sequence dimension; eps guards against
        # division by zero when a row has no valid position.
        weights = valid_mask.unsqueeze(-1).to(x.dtype)   # (..., S, 1)
        summed = torch.sum(x * weights, dim=-2)          # (..., D)
        denom = weights.sum(dim=-2).clamp(min=self.eps)  # (..., 1)

        # A pooled row is valid iff at least one input position was valid:
        pooled_valid = valid_mask.any(dim=-1)

        return summed / denom, pooled_valid
|
| 76 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 77 |
+
# END OF FILE #
|
| 78 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/model/transformers/positional_encoding.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PositionalEncoding(torch.nn.Module):
    """
    Sinusoidal positional encoding module for Transformer models.

    Adds fixed (non-learnable) sinusoidal embeddings to the input so the
    model can exploit absolute/relative token positions, following the
    formulation of the original Transformer architecture.
    """
    def __init__(self, emb_dim: int, max_len: int = 5000, **kwargs):
        """
        Initialize the positional encoding module.

        Parameters
        ----------
        emb_dim : int
            Dimensionality of the embedding space.
        max_len : int, optional
            Maximum supported sequence length for which positional encodings
            are precomputed.
        """
        super().__init__(**kwargs)

        # Precompute the full (max_len, emb_dim) sinusoidal table once:
        encoding_table = torch.zeros(max_len, emb_dim)
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(
            torch.arange(0, emb_dim, 2).float()
            * -(torch.log(torch.tensor(10000.0)) / emb_dim)
        )
        angles = positions * frequencies
        encoding_table[:, 0::2] = torch.sin(angles)  # even dims -> sine
        encoding_table[:, 1::2] = torch.cos(angles)  # odd dims  -> cosine

        # Buffer (not a Parameter): follows .to()/.cuda() but is not trained.
        self.register_buffer('positional_encoding', encoding_table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to the input embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, emb_dim).

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as the input with positional encodings added.
        """
        seq_len = x.size(1)
        return x + self.positional_encoding[:, :seq_len, :]
|
| 60 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 61 |
+
# END OF FILE #
|
| 62 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
train/config.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import os
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from src.model import ModelConfig, CoSeNetConfig, TransformerConfig
|
| 11 |
+
from src.dataset import DatasetConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 15 |
+
# SETUP CONFIGURATION #
|
| 16 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 17 |
+
@dataclass
class SetupConfig:
    """
    Configuration parameters related to the execution environment and logging.

    This configuration controls device selection, checkpointing behavior,
    reproducibility settings, and logging paths for an experiment.
    """
    # Index of the GPU/device to select:
    device_number: int = 0
    # Checkpoint frequency in epochs (presumably 0 disables periodic
    # saving — confirm in the training loop):
    save_model_each: int = 0
    # Random seed; None means no fixed seed was requested (the setup
    # apparently picks one at run time — see the training log):
    seed: int | None = None
    # Root directory for logs and checkpoints; None means unset:
    logging_path: str | None = None
    # Whether to resume from an existing checkpoint:
    reload_checkpoint: bool = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def overwrite_setup_config() -> SetupConfig:
    """
    Create and override the default setup configuration.

    This function customizes execution-level parameters such as logging
    paths, checkpoint reloading, and model saving frequency.

    Returns:
        SetupConfig: The configured setup configuration object.
    """
    setup = SetupConfig()
    setup.save_model_each = 1
    setup.reload_checkpoint = True
    setup.logging_path = r'/workspace/logs'
    return setup
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 50 |
+
# TRAINING CONFIGURATION #
|
| 51 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 52 |
+
@dataclass
class TrainConfig:
    """
    Training configuration container.

    This dataclass aggregates model, dataset, and setup configurations,
    together with optimization and training hyperparameters.
    """
    # Linked configurations (filled in by `configuration()`):
    model_config: ModelConfig | None = None
    dataset_config: DatasetConfig | None = None
    setup_config: SetupConfig | None = None

    # Training parameters:
    batch_size: int = 32
    num_epochs: int = 100

    # Optimizer parameters (Adam-style, see `betas`):
    learning_rate: float = 1e-4
    # Lower bound for the learning rate (presumably used by a decay
    # schedule — confirm in the training loop):
    learning_rate_min: float = 1e-5
    weight_decay: float = 1e-8
    betas: tuple[float, float] = (0.5, 0.999)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def overwrite_train_config() -> TrainConfig:
    """
    Create and override the default training configuration.

    This function customizes batch size, number of epochs, and optimizer
    hyperparameters for the training process.

    Returns:
        TrainConfig: The configured training configuration object.
    """
    train_cfg = TrainConfig()
    # Schedule:
    train_cfg.num_epochs = 200
    train_cfg.batch_size = 4
    # Optimizer:
    train_cfg.learning_rate = 5e-4
    train_cfg.learning_rate_min = 5e-5
    train_cfg.weight_decay = 1e-6
    return train_cfg
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 96 |
+
# DATASET CONFIGURATION #
|
| 97 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 98 |
+
def overwrite_dataset_config() -> DatasetConfig:
    """
    Create and override the dataset configuration.

    This function sets the file paths and usage percentages for training,
    validation, and test datasets.

    Returns:
        DatasetConfig: The configured dataset configuration object.
    """
    ds_cfg = DatasetConfig()
    # Pre-tokenized dataset locations:
    ds_cfg.train_data_path = r"/workspace/data/tokens-A000-segmentation"
    ds_cfg.val_data_path = r"/workspace/data/tokens-A001-segmentation"
    ds_cfg.test_data_path = r"/workspace/data/tokens-A002-segmentation"
    # Use every sample of every split:
    for split in ("train", "val", "test"):
        setattr(ds_cfg, f"{split}_percentage", 1.0)
    return ds_cfg
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 119 |
+
# MODEL CONFIGURATION #
|
| 120 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 121 |
+
def overwrite_model_config() -> ModelConfig:
    """
    Create and override the model configuration.

    This function defines the architecture-level parameters, including
    vocabulary size, embedding dimensionality, CoSeNet settings, and
    the stack of Transformer encoder configurations.

    Returns:
        ModelConfig: The configured model configuration object.
    """
    model_cfg = ModelConfig()

    # High-level params:
    model_cfg.vocab_size = 32_768
    model_cfg.model_dim = 256
    model_cfg.valid_padding = True

    # CoSeNet params:
    model_cfg.cosenet = CoSeNetConfig(
        trainable=True,
        init_scale=5.0
    )

    # Transformer params: a stack of two identical encoder layers.
    encoder_kwargs = {
        "attention_heads": 16,
        "feed_forward_multiplier": 8,
        "dropout": 0.0,
        "pre_normalize": True,
    }
    model_cfg.transformers = [
        TransformerConfig(**encoder_kwargs),
        TransformerConfig(**encoder_kwargs),
    ]

    return model_cfg
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 168 |
+
# WHOLE CONFIGURATION #
|
| 169 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 170 |
+
def configuration() -> TrainConfig:
    """
    Create the experiment configuration
    :return: A TrainConfig configuration object
    """
    cfg = overwrite_train_config()
    cfg.setup_config = overwrite_setup_config()
    cfg.model_config = overwrite_model_config()
    cfg.dataset_config = overwrite_dataset_config()

    # Sanity checks on the dataset configuration:
    ds = cfg.dataset_config
    if not os.path.exists(ds.train_data_path):
        raise FileNotFoundError(f"Train data path does not exist: {ds.train_data_path}")
    if not os.path.exists(ds.val_data_path):
        raise FileNotFoundError(f"Validation data path does not exist: {ds.val_data_path}")
    if not 0.0 < ds.train_percentage <= 1.0:
        raise ValueError("Train percentage must be in (0.0, 1.0]")
    if not 0.0 < ds.val_percentage <= 1.0:
        raise ValueError("Validation percentage must be in (0.0, 1.0]")

    return cfg
|
| 191 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 192 |
+
# END OF FILE #
|
| 193 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
train/train_logs/config.json
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_config": {
|
| 3 |
+
"vocab_size": 32768,
|
| 4 |
+
"model_dim": 256,
|
| 5 |
+
"max_tokens": 382,
|
| 6 |
+
"max_sentences": 384,
|
| 7 |
+
"valid_padding": true,
|
| 8 |
+
"cosenet": {
|
| 9 |
+
"trainable": true,
|
| 10 |
+
"init_scale": 5.0
|
| 11 |
+
},
|
| 12 |
+
"transformers": [
|
| 13 |
+
{
|
| 14 |
+
"attention_heads": 16,
|
| 15 |
+
"feed_forward_multiplier": 8,
|
| 16 |
+
"dropout": 0.0,
|
| 17 |
+
"pre_normalize": true
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"attention_heads": 16,
|
| 21 |
+
"feed_forward_multiplier": 8,
|
| 22 |
+
"dropout": 0.0,
|
| 23 |
+
"pre_normalize": true
|
| 24 |
+
}
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
"dataset_config": {
|
| 28 |
+
"train_data_path": "/workspace/data/tokens-A000-segmentation",
|
| 29 |
+
"val_data_path": "/workspace/data/tokens-A001-segmentation",
|
| 30 |
+
"test_data_path": "/workspace/data/tokens-A002-segmentation",
|
| 31 |
+
"train_percentage": 1.0,
|
| 32 |
+
"val_percentage": 1.0,
|
| 33 |
+
"test_percentage": 1.0,
|
| 34 |
+
"num_workers": 0,
|
| 35 |
+
"shuffle_train": true,
|
| 36 |
+
"shuffle_val": true
|
| 37 |
+
},
|
| 38 |
+
"setup_config": {
|
| 39 |
+
"device_number": 0,
|
| 40 |
+
"save_model_each": 1,
|
| 41 |
+
"seed": null,
|
| 42 |
+
"logging_path": "/workspace/logs",
|
| 43 |
+
"reload_checkpoint": true
|
| 44 |
+
},
|
| 45 |
+
"batch_size": 4,
|
| 46 |
+
"num_epochs": 200,
|
| 47 |
+
"learning_rate": 0.0005,
|
| 48 |
+
"learning_rate_min": 5e-05,
|
| 49 |
+
"weight_decay": 1e-06,
|
| 50 |
+
"betas": [
|
| 51 |
+
0.5,
|
| 52 |
+
0.999
|
| 53 |
+
]
|
| 54 |
+
}
|
train/train_logs/logfile.log
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-12-26 14:45:56,651: [INFO] Logger initialized with writer handler at: /workspace/logs/logfile.log
|
| 2 |
+
2025-12-26 14:45:56,659: [INFO] TensorBoard logs will be stored in: /workspace/logs/logs
|
| 3 |
+
2025-12-26 14:45:56,659: [INFO] Model checkpoints will be stored in: /workspace/logs/checkpoints
|
| 4 |
+
2025-12-26 14:45:56,672: [INFO] TensorBoard running at http://0.0.0.0:6006/ (pid=76392)
|
| 5 |
+
2025-12-26 14:45:56,680: [INFO] Initializer set up seed: 1766760356
|
| 6 |
+
2025-12-26 14:45:56,728: [INFO] PyTorch is now configured to use GPU 0: NVIDIA A40
|
| 7 |
+
2025-12-26 14:45:56,729: [INFO] [GPU 0 - NVIDIA A40] Memory Stats:
|
| 8 |
+
2025-12-26 14:45:56,729: [INFO] Total Memory : 45498.00 MB
|
| 9 |
+
2025-12-26 14:45:56,730: [INFO] Currently Allocated : 0.00 MB
|
| 10 |
+
2025-12-26 14:45:56,730: [INFO] Currently Reserved : 0.00 MB
|
| 11 |
+
2025-12-26 14:45:56,730: [INFO] Max Allocated : 0.00 MB
|
| 12 |
+
2025-12-26 14:45:56,731: [INFO] Max Reserved : 0.00 MB
|
| 13 |
+
2025-12-26 14:45:56,731: [INFO] Setup information:
|
| 14 |
+
2025-12-26 14:45:56,732: [INFO] - Setup path: /workspace/logs
|
| 15 |
+
2025-12-26 14:45:56,732: [INFO] - Setup checkpoints path: /workspace/logs/checkpoints
|
| 16 |
+
2025-12-26 14:45:56,732: [INFO] - Setup device: cuda:0
|
| 17 |
+
2025-12-26 14:45:56,733: [INFO] - Setup seed: 1766760356
|
| 18 |
+
2025-12-26 14:45:56,733: [INFO] - Setup logger: <Logger src.dlutils.setup.logger (INFO)>
|
| 19 |
+
2025-12-26 14:45:56,734: [INFO] - Setup writer: <torch.utils.tensorboard.writer.SummaryWriter object at 0x76e7ade77910>
|
| 20 |
+
2025-12-26 14:45:56,734: [INFO] - Setup save each: 20
|
| 21 |
+
2025-12-26 14:45:56,737: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A000-segmentation
|
| 22 |
+
2025-12-26 14:45:56,737: [INFO] [SegmentationDataset] Loaded dataset length: 26510
|
| 23 |
+
2025-12-26 14:45:56,745: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A001-segmentation
|
| 24 |
+
2025-12-26 14:45:56,745: [INFO] [SegmentationDataset] Loaded dataset length: 3336
|
| 25 |
+
2025-12-26 14:45:57,294: [INFO] [TRAIN] Model Configuration:
|
| 26 |
+
{'vocab_size': 32768, 'model_dim': 256, 'max_tokens': 382, 'max_sentences': 384, 'valid_padding': True, 'cosenet': CoSeNetConfig(trainable=True, init_scale=5.0), 'transformers': [TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True), TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True)]}
|
| 27 |
+
2025-12-26 14:45:57,294: [INFO] [TRAIN] Model parameters: 11.022865 M
|
| 28 |
+
2025-12-26 14:45:57,295: [INFO] [TRAIN] Trainable parameters: 11.022865 M
|
| 29 |
+
2025-12-26 14:45:57,295: [INFO] [TRAIN] Training batches: 67
|
| 30 |
+
2025-12-26 14:47:48,849: [INFO] Epoch [0]: loss = 0.53930415
|
| 31 |
+
2025-12-26 14:47:50,411: [INFO] Epoch [1]: val_loss = 0.50874352
|
| 32 |
+
2025-12-26 14:49:37,680: [INFO] Epoch [1]: loss = 0.52674737
|
| 33 |
+
2025-12-26 14:49:39,116: [INFO] Epoch [2]: val_loss = 0.51872101
|
| 34 |
+
2025-12-26 14:51:27,172: [INFO] Epoch [2]: loss = 0.52592351
|
| 35 |
+
2025-12-26 14:51:28,612: [INFO] Epoch [3]: val_loss = 0.51301319
|
| 36 |
+
2025-12-26 14:53:16,691: [INFO] Epoch [3]: loss = 0.52935326
|
| 37 |
+
2025-12-26 14:53:18,212: [INFO] Epoch [4]: val_loss = 0.51744863
|
| 38 |
+
2025-12-26 14:55:05,752: [INFO] Epoch [4]: loss = 0.52446729
|
| 39 |
+
2025-12-26 14:55:07,327: [INFO] Epoch [5]: val_loss = 0.51929819
|
| 40 |
+
2025-12-26 14:56:57,434: [INFO] Epoch [5]: loss = 0.52781746
|
| 41 |
+
2025-12-26 14:56:58,912: [INFO] Epoch [6]: val_loss = 0.52006621
|
| 42 |
+
2025-12-26 14:58:46,224: [INFO] Epoch [6]: loss = 0.52644637
|
| 43 |
+
2025-12-26 14:58:47,712: [INFO] Epoch [7]: val_loss = 0.51545532
|
| 44 |
+
2025-12-26 15:00:34,974: [INFO] Epoch [7]: loss = 0.52535941
|
| 45 |
+
2025-12-26 15:00:36,412: [INFO] Epoch [8]: val_loss = 0.52077476
|
| 46 |
+
2025-12-26 15:02:24,083: [INFO] Epoch [8]: loss = 0.52521282
|
| 47 |
+
2025-12-26 15:02:25,525: [INFO] Epoch [9]: val_loss = 0.51527728
|
| 48 |
+
2025-12-26 15:04:13,376: [INFO] Epoch [9]: loss = 0.52329010
|
| 49 |
+
2025-12-26 15:04:14,816: [INFO] Epoch [10]: val_loss = 0.51563372
|
| 50 |
+
2025-12-26 15:06:03,934: [INFO] Epoch [10]: loss = 0.52397644
|
| 51 |
+
2025-12-26 15:06:05,412: [INFO] Epoch [11]: val_loss = 0.51372376
|
| 52 |
+
2025-12-26 15:07:54,323: [INFO] Epoch [11]: loss = 0.52039668
|
| 53 |
+
2025-12-26 15:07:55,813: [INFO] Epoch [12]: val_loss = 0.51369372
|
| 54 |
+
2025-12-26 15:09:43,559: [INFO] Epoch [12]: loss = 0.51899378
|
| 55 |
+
2025-12-26 15:09:45,012: [INFO] Epoch [13]: val_loss = 0.52238202
|
| 56 |
+
2025-12-26 15:11:32,423: [INFO] Epoch [13]: loss = 0.51784248
|
| 57 |
+
2025-12-26 15:11:33,912: [INFO] Epoch [14]: val_loss = 0.51489054
|
| 58 |
+
2025-12-26 15:13:22,761: [INFO] Epoch [14]: loss = 0.50914923
|
| 59 |
+
2025-12-26 15:13:24,212: [INFO] Epoch [15]: val_loss = 0.50278137
|
| 60 |
+
2025-12-26 15:15:11,956: [INFO] Epoch [15]: loss = 0.50427987
|
| 61 |
+
2025-12-26 15:15:13,412: [INFO] Epoch [16]: val_loss = 0.50158396
|
| 62 |
+
2025-12-26 15:17:01,228: [INFO] Epoch [16]: loss = 0.50178539
|
| 63 |
+
2025-12-26 15:17:02,711: [INFO] Epoch [17]: val_loss = 0.50242173
|
| 64 |
+
2025-12-26 15:18:51,266: [INFO] Epoch [17]: loss = 0.49650285
|
| 65 |
+
2025-12-26 15:18:52,716: [INFO] Epoch [18]: val_loss = 0.50932210
|
| 66 |
+
2025-12-26 15:20:40,343: [INFO] Epoch [18]: loss = 0.49234502
|
| 67 |
+
2025-12-26 15:20:41,912: [INFO] Epoch [19]: val_loss = 0.50311281
|
| 68 |
+
2025-12-26 15:22:29,693: [INFO] Epoch [19]: loss = 0.48797671
|
| 69 |
+
2025-12-26 15:22:29,695: [INFO] Checkpointing model at epoch 20
|
| 70 |
+
2025-12-26 15:22:30,454: [INFO] Model checkpointed at epoch 20
|
| 71 |
+
2025-12-26 15:22:31,912: [INFO] Epoch [20]: val_loss = 0.53549688
|
| 72 |
+
2025-12-26 15:24:19,843: [INFO] Epoch [20]: loss = 0.48723968
|
| 73 |
+
2025-12-26 15:24:21,312: [INFO] Epoch [21]: val_loss = 0.49818926
|
| 74 |
+
2025-12-26 15:26:08,715: [INFO] Epoch [21]: loss = 0.48037165
|
| 75 |
+
2025-12-26 15:26:10,212: [INFO] Epoch [22]: val_loss = 0.48961075
|
| 76 |
+
2025-12-26 15:27:59,123: [INFO] Epoch [22]: loss = 0.47390062
|
| 77 |
+
2025-12-26 15:28:00,911: [INFO] Epoch [23]: val_loss = 0.48781847
|
| 78 |
+
2025-12-26 15:29:49,056: [INFO] Epoch [23]: loss = 0.46711668
|
| 79 |
+
2025-12-26 15:29:50,511: [INFO] Epoch [24]: val_loss = 0.47708375
|
| 80 |
+
2025-12-26 15:31:37,663: [INFO] Epoch [24]: loss = 0.46234217
|
| 81 |
+
2025-12-26 15:31:39,112: [INFO] Epoch [25]: val_loss = 0.46084376
|
| 82 |
+
2025-12-26 15:33:26,345: [INFO] Epoch [25]: loss = 0.45538114
|
| 83 |
+
2025-12-26 15:33:27,812: [INFO] Epoch [26]: val_loss = 0.47136071
|
| 84 |
+
2025-12-26 15:35:15,250: [INFO] Epoch [26]: loss = 0.45225392
|
| 85 |
+
2025-12-26 15:35:16,711: [INFO] Epoch [27]: val_loss = 0.47011130
|
| 86 |
+
2025-12-26 15:37:04,599: [INFO] Epoch [27]: loss = 0.44760030
|
| 87 |
+
2025-12-26 15:37:06,112: [INFO] Epoch [28]: val_loss = 0.46140307
|
| 88 |
+
2025-12-26 15:38:54,426: [INFO] Epoch [28]: loss = 0.44472487
|
| 89 |
+
2025-12-26 15:38:55,912: [INFO] Epoch [29]: val_loss = 0.47098119
|
| 90 |
+
2025-12-26 15:40:43,445: [INFO] Epoch [29]: loss = 0.43989357
|
| 91 |
+
2025-12-26 15:40:44,911: [INFO] Epoch [30]: val_loss = 0.45539117
|
| 92 |
+
2025-12-26 15:42:32,383: [INFO] Epoch [30]: loss = 0.43657149
|
| 93 |
+
2025-12-26 15:42:33,816: [INFO] Epoch [31]: val_loss = 0.46862131
|
| 94 |
+
2025-12-26 15:44:21,074: [INFO] Epoch [31]: loss = 0.43649050
|
| 95 |
+
2025-12-26 15:44:22,511: [INFO] Epoch [32]: val_loss = 0.45548641
|
| 96 |
+
2025-12-26 15:46:09,812: [INFO] Epoch [32]: loss = 0.43346542
|
| 97 |
+
2025-12-26 15:46:11,312: [INFO] Epoch [33]: val_loss = 0.45997839
|
| 98 |
+
2025-12-26 15:47:59,053: [INFO] Epoch [33]: loss = 0.43235683
|
| 99 |
+
2025-12-26 15:48:00,511: [INFO] Epoch [34]: val_loss = 0.47154692
|
| 100 |
+
2025-12-26 15:49:47,991: [INFO] Epoch [34]: loss = 0.42891757
|
| 101 |
+
2025-12-26 15:49:49,416: [INFO] Epoch [35]: val_loss = 0.46223042
|
| 102 |
+
2025-12-26 15:51:36,793: [INFO] Epoch [35]: loss = 0.42735399
|
| 103 |
+
2025-12-26 15:51:38,216: [INFO] Epoch [36]: val_loss = 0.46173553
|
| 104 |
+
2025-12-26 15:53:25,570: [INFO] Epoch [36]: loss = 0.42965186
|
| 105 |
+
2025-12-26 15:53:27,016: [INFO] Epoch [37]: val_loss = 0.46098506
|
| 106 |
+
2025-12-26 15:55:14,511: [INFO] Epoch [37]: loss = 0.42778122
|
| 107 |
+
2025-12-26 15:55:16,012: [INFO] Epoch [38]: val_loss = 0.46018566
|
| 108 |
+
2025-12-26 15:57:06,234: [INFO] Epoch [38]: loss = 0.42445267
|
| 109 |
+
2025-12-26 15:57:07,711: [INFO] Epoch [39]: val_loss = 0.46550667
|
| 110 |
+
2025-12-26 15:58:59,230: [INFO] Epoch [39]: loss = 0.42354161
|
| 111 |
+
2025-12-26 15:58:59,232: [INFO] Checkpointing model at epoch 40
|
| 112 |
+
2025-12-26 15:58:59,945: [INFO] Model checkpointed at epoch 40
|
| 113 |
+
2025-12-26 15:59:01,511: [INFO] Epoch [40]: val_loss = 0.47303247
|
| 114 |
+
2025-12-26 16:00:49,480: [INFO] Epoch [40]: loss = 0.42338467
|
| 115 |
+
2025-12-26 16:00:50,911: [INFO] Epoch [41]: val_loss = 0.45826835
|
| 116 |
+
2025-12-26 16:02:38,743: [INFO] Epoch [41]: loss = 0.41971716
|
| 117 |
+
2025-12-26 16:02:40,212: [INFO] Epoch [42]: val_loss = 0.45490133
|
| 118 |
+
2025-12-26 16:04:28,045: [INFO] Epoch [42]: loss = 0.41987514
|
| 119 |
+
2025-12-26 16:04:29,512: [INFO] Epoch [43]: val_loss = 0.45860666
|
| 120 |
+
2025-12-26 16:06:16,948: [INFO] Epoch [43]: loss = 0.41933024
|
| 121 |
+
2025-12-26 16:06:18,411: [INFO] Epoch [44]: val_loss = 0.45629129
|
| 122 |
+
2025-12-26 16:08:06,282: [INFO] Epoch [44]: loss = 0.41593552
|
| 123 |
+
2025-12-26 16:08:07,716: [INFO] Epoch [45]: val_loss = 0.46409211
|
| 124 |
+
2025-12-26 16:09:55,161: [INFO] Epoch [45]: loss = 0.41721227
|
| 125 |
+
2025-12-26 16:09:56,612: [INFO] Epoch [46]: val_loss = 0.46598683
|
| 126 |
+
2025-12-26 16:11:43,939: [INFO] Epoch [46]: loss = 0.41726764
|
| 127 |
+
2025-12-26 16:11:45,411: [INFO] Epoch [47]: val_loss = 0.45663830
|
| 128 |
+
2025-12-26 16:13:32,862: [INFO] Epoch [47]: loss = 0.41537570
|
| 129 |
+
2025-12-26 16:13:34,315: [INFO] Epoch [48]: val_loss = 0.46740513
|
| 130 |
+
2025-12-26 16:15:21,546: [INFO] Epoch [48]: loss = 0.41457776
|
| 131 |
+
2025-12-26 16:15:23,112: [INFO] Epoch [49]: val_loss = 0.44048135
|
| 132 |
+
2025-12-26 16:17:10,274: [INFO] Epoch [49]: loss = 0.41388101
|
| 133 |
+
2025-12-26 16:17:11,715: [INFO] Epoch [50]: val_loss = 0.45519451
|
| 134 |
+
2025-12-26 16:18:58,990: [INFO] Epoch [50]: loss = 0.41285109
|
| 135 |
+
2025-12-26 16:19:00,512: [INFO] Epoch [51]: val_loss = 0.45673202
|
| 136 |
+
2025-12-26 16:20:47,636: [INFO] Epoch [51]: loss = 0.41195465
|
| 137 |
+
2025-12-26 16:20:49,116: [INFO] Epoch [52]: val_loss = 0.45198240
|
| 138 |
+
2025-12-26 16:22:36,418: [INFO] Epoch [52]: loss = 0.40953517
|
| 139 |
+
2025-12-26 16:22:37,893: [INFO] Epoch [53]: val_loss = 0.47122019
|
| 140 |
+
2025-12-26 16:24:25,331: [INFO] Epoch [53]: loss = 0.40789293
|
| 141 |
+
2025-12-26 16:24:26,812: [INFO] Epoch [54]: val_loss = 0.44196667
|
| 142 |
+
2025-12-26 16:26:14,026: [INFO] Epoch [54]: loss = 0.40474147
|
| 143 |
+
2025-12-26 16:26:15,512: [INFO] Epoch [55]: val_loss = 0.46978565
|
| 144 |
+
2025-12-26 16:28:03,793: [INFO] Epoch [55]: loss = 0.40504389
|
| 145 |
+
2025-12-26 16:28:05,272: [INFO] Epoch [56]: val_loss = 0.47313605
|
| 146 |
+
2025-12-26 16:29:52,585: [INFO] Epoch [56]: loss = 0.40562682
|
| 147 |
+
2025-12-26 16:29:54,017: [INFO] Epoch [57]: val_loss = 0.46668073
|
| 148 |
+
2025-12-26 16:31:41,107: [INFO] Epoch [57]: loss = 0.40768713
|
| 149 |
+
2025-12-26 16:31:42,529: [INFO] Epoch [58]: val_loss = 0.45173921
|
| 150 |
+
2025-12-26 16:33:29,962: [INFO] Epoch [58]: loss = 0.40458906
|
| 151 |
+
2025-12-26 16:33:31,411: [INFO] Epoch [59]: val_loss = 0.46093515
|
| 152 |
+
2025-12-26 16:35:18,548: [INFO] Epoch [59]: loss = 0.40175750
|
| 153 |
+
2025-12-26 16:35:18,549: [INFO] Checkpointing model at epoch 60
|
| 154 |
+
2025-12-26 16:35:19,200: [INFO] Model checkpointed at epoch 60
|
| 155 |
+
2025-12-26 16:35:20,711: [INFO] Epoch [60]: val_loss = 0.46230425
|
| 156 |
+
2025-12-26 16:37:07,991: [INFO] Epoch [60]: loss = 0.40113024
|
| 157 |
+
2025-12-26 16:37:09,427: [INFO] Epoch [61]: val_loss = 0.46095682
|
| 158 |
+
2025-12-26 16:38:57,025: [INFO] Epoch [61]: loss = 0.40212381
|
| 159 |
+
2025-12-26 16:38:58,604: [INFO] Epoch [62]: val_loss = 0.45801103
|
| 160 |
+
2025-12-26 16:40:47,358: [INFO] Epoch [62]: loss = 0.40149038
|
| 161 |
+
2025-12-26 16:40:48,911: [INFO] Epoch [63]: val_loss = 0.45971834
|
| 162 |
+
2025-12-26 16:42:36,483: [INFO] Epoch [63]: loss = 0.40096813
|
| 163 |
+
2025-12-26 16:42:37,917: [INFO] Epoch [64]: val_loss = 0.47312803
|
| 164 |
+
2025-12-26 16:44:25,338: [INFO] Epoch [64]: loss = 0.40213521
|
| 165 |
+
2025-12-26 16:44:26,895: [INFO] Epoch [65]: val_loss = 0.45463914
|
| 166 |
+
2025-12-26 16:46:14,170: [INFO] Epoch [65]: loss = 0.39824201
|
| 167 |
+
2025-12-26 16:46:15,615: [INFO] Epoch [66]: val_loss = 0.47252337
|
| 168 |
+
2025-12-26 16:48:03,455: [INFO] Epoch [66]: loss = 0.39898236
|
| 169 |
+
2025-12-26 16:48:04,912: [INFO] Epoch [67]: val_loss = 0.46137960
|
| 170 |
+
2025-12-26 16:49:52,778: [INFO] Epoch [67]: loss = 0.40269130
|
| 171 |
+
2025-12-26 16:49:54,216: [INFO] Epoch [68]: val_loss = 0.47056969
|
| 172 |
+
2025-12-26 16:51:41,521: [INFO] Epoch [68]: loss = 0.39804779
|
| 173 |
+
2025-12-26 16:51:43,012: [INFO] Epoch [69]: val_loss = 0.46284741
|
| 174 |
+
2025-12-26 16:53:30,778: [INFO] Epoch [69]: loss = 0.39931213
|
| 175 |
+
2025-12-26 16:53:32,211: [INFO] Epoch [70]: val_loss = 0.47174325
|
| 176 |
+
2025-12-26 16:55:19,747: [INFO] Epoch [70]: loss = 0.39947561
|
| 177 |
+
2025-12-26 16:55:21,211: [INFO] Epoch [71]: val_loss = 0.47359799
|
| 178 |
+
2025-12-26 16:57:08,420: [INFO] Epoch [71]: loss = 0.39680641
|
| 179 |
+
2025-12-26 16:57:09,912: [INFO] Epoch [72]: val_loss = 0.45985634
|
| 180 |
+
2025-12-26 16:58:57,465: [INFO] Epoch [72]: loss = 0.39784966
|
| 181 |
+
2025-12-26 16:58:58,911: [INFO] Epoch [73]: val_loss = 0.47379973
|
| 182 |
+
2025-12-26 17:00:46,781: [INFO] Epoch [73]: loss = 0.39575548
|
| 183 |
+
2025-12-26 17:00:48,217: [INFO] Epoch [74]: val_loss = 0.46827143
|
| 184 |
+
2025-12-26 17:02:35,956: [INFO] Epoch [74]: loss = 0.39844352
|
| 185 |
+
2025-12-26 17:02:37,411: [INFO] Epoch [75]: val_loss = 0.48436255
|
| 186 |
+
2025-12-26 17:04:25,013: [INFO] Epoch [75]: loss = 0.39737436
|
| 187 |
+
2025-12-26 17:04:26,512: [INFO] Epoch [76]: val_loss = 0.45234020
|
| 188 |
+
2025-12-26 17:06:13,974: [INFO] Epoch [76]: loss = 0.39371587
|
| 189 |
+
2025-12-26 17:06:15,415: [INFO] Epoch [77]: val_loss = 0.45753057
|
| 190 |
+
2025-12-26 17:08:03,455: [INFO] Epoch [77]: loss = 0.39684283
|
| 191 |
+
2025-12-26 17:08:04,916: [INFO] Epoch [78]: val_loss = 0.46107266
|
| 192 |
+
2025-12-26 17:09:52,265: [INFO] Epoch [78]: loss = 0.39561052
|
| 193 |
+
2025-12-26 17:09:53,711: [INFO] Epoch [79]: val_loss = 0.48726222
|
| 194 |
+
2025-12-26 17:11:40,915: [INFO] Epoch [79]: loss = 0.39534942
|
| 195 |
+
2025-12-26 17:11:40,917: [INFO] Checkpointing model at epoch 80
|
| 196 |
+
2025-12-26 17:11:41,448: [INFO] Model checkpointed at epoch 80
|
| 197 |
+
2025-12-26 17:11:42,912: [INFO] Epoch [80]: val_loss = 0.47510581
|
| 198 |
+
2025-12-26 17:13:31,165: [INFO] Epoch [80]: loss = 0.39408069
|
| 199 |
+
2025-12-26 17:13:32,617: [INFO] Epoch [81]: val_loss = 0.46646976
|
| 200 |
+
2025-12-26 17:15:20,095: [INFO] Epoch [81]: loss = 0.39456047
|
| 201 |
+
2025-12-26 17:15:21,517: [INFO] Epoch [82]: val_loss = 0.47777673
|
| 202 |
+
2025-12-26 17:17:10,031: [INFO] Epoch [82]: loss = 0.39687150
|
| 203 |
+
2025-12-26 17:17:11,512: [INFO] Epoch [83]: val_loss = 0.47680868
|
| 204 |
+
2025-12-26 17:18:59,688: [INFO] Epoch [83]: loss = 0.39627865
|
| 205 |
+
2025-12-26 17:19:01,115: [INFO] Epoch [84]: val_loss = 0.47353493
|
| 206 |
+
2025-12-26 17:20:48,468: [INFO] Epoch [84]: loss = 0.39516608
|
| 207 |
+
2025-12-26 17:20:49,912: [INFO] Epoch [85]: val_loss = 0.47541119
|
| 208 |
+
2025-12-26 17:22:38,068: [INFO] Epoch [85]: loss = 0.39570387
|
| 209 |
+
2025-12-26 17:22:39,517: [INFO] Epoch [86]: val_loss = 0.46904831
|
| 210 |
+
2025-12-26 17:24:26,979: [INFO] Epoch [86]: loss = 0.39411988
|
| 211 |
+
2025-12-26 17:24:28,416: [INFO] Epoch [87]: val_loss = 0.47183328
|
| 212 |
+
2025-12-26 17:26:16,242: [INFO] Epoch [87]: loss = 0.39453237
|
| 213 |
+
2025-12-26 17:26:17,712: [INFO] Epoch [88]: val_loss = 0.48088008
|
| 214 |
+
2025-12-26 17:28:05,674: [INFO] Epoch [88]: loss = 0.39265428
|
| 215 |
+
2025-12-26 17:28:07,111: [INFO] Epoch [89]: val_loss = 0.46431010
|
| 216 |
+
2025-12-26 17:29:54,253: [INFO] Epoch [89]: loss = 0.39518696
|
| 217 |
+
2025-12-26 17:29:55,711: [INFO] Epoch [90]: val_loss = 0.47148239
|
| 218 |
+
2025-12-26 17:31:43,155: [INFO] Epoch [90]: loss = 0.39560070
|
| 219 |
+
2025-12-26 17:31:44,711: [INFO] Epoch [91]: val_loss = 0.47001378
|
| 220 |
+
2025-12-26 17:33:32,048: [INFO] Epoch [91]: loss = 0.39522415
|
| 221 |
+
2025-12-26 17:33:33,512: [INFO] Epoch [92]: val_loss = 0.47427877
|
| 222 |
+
2025-12-26 17:35:20,792: [INFO] Epoch [92]: loss = 0.39726472
|
| 223 |
+
2025-12-26 17:35:22,230: [INFO] Epoch [93]: val_loss = 0.48291658
|
| 224 |
+
2025-12-26 17:37:09,543: [INFO] Epoch [93]: loss = 0.39664398
|
| 225 |
+
2025-12-26 17:37:11,012: [INFO] Epoch [94]: val_loss = 0.49081665
|
| 226 |
+
2025-12-26 17:38:58,458: [INFO] Epoch [94]: loss = 0.39135196
|
| 227 |
+
2025-12-26 17:39:00,111: [INFO] Epoch [95]: val_loss = 0.47873766
|
| 228 |
+
2025-12-26 17:40:47,362: [INFO] Epoch [95]: loss = 0.39417184
|
| 229 |
+
2025-12-26 17:40:48,811: [INFO] Epoch [96]: val_loss = 0.48776019
|
| 230 |
+
2025-12-26 17:42:36,318: [INFO] Epoch [96]: loss = 0.39321537
|
| 231 |
+
2025-12-26 17:42:37,812: [INFO] Epoch [97]: val_loss = 0.46243800
|
| 232 |
+
2025-12-26 17:44:25,656: [INFO] Epoch [97]: loss = 0.39767619
|
| 233 |
+
2025-12-26 17:44:27,112: [INFO] Epoch [98]: val_loss = 0.45655080
|
| 234 |
+
2025-12-26 17:46:14,472: [INFO] Epoch [98]: loss = 0.39206413
|
| 235 |
+
2025-12-26 17:46:16,011: [INFO] Epoch [99]: val_loss = 0.46890352
|
| 236 |
+
2025-12-26 17:48:03,170: [INFO] Epoch [99]: loss = 0.39380527
|
| 237 |
+
2025-12-26 17:48:03,171: [INFO] Checkpointing model at epoch 100
|
| 238 |
+
2025-12-26 17:48:03,725: [INFO] Model checkpointed at epoch 100
|
| 239 |
+
2025-12-26 17:48:05,212: [INFO] Epoch [100]: val_loss = 0.49273304
|
| 240 |
+
2025-12-26 17:49:52,333: [INFO] Epoch [100]: loss = 0.39218957
|
| 241 |
+
2025-12-26 17:49:53,811: [INFO] Epoch [101]: val_loss = 0.47090062
|
| 242 |
+
2025-12-26 17:51:41,096: [INFO] Epoch [101]: loss = 0.39274643
|
| 243 |
+
2025-12-26 17:51:42,528: [INFO] Epoch [102]: val_loss = 0.46902628
|
| 244 |
+
2025-12-26 17:53:29,975: [INFO] Epoch [102]: loss = 0.39481238
|
| 245 |
+
2025-12-26 17:53:31,411: [INFO] Epoch [103]: val_loss = 0.48112577
|
| 246 |
+
2025-12-26 17:55:18,420: [INFO] Epoch [103]: loss = 0.39440405
|
| 247 |
+
2025-12-26 17:55:19,912: [INFO] Epoch [104]: val_loss = 0.49355557
|
| 248 |
+
2025-12-26 17:57:07,125: [INFO] Epoch [104]: loss = 0.39165780
|
| 249 |
+
2025-12-26 17:57:08,611: [INFO] Epoch [105]: val_loss = 0.48409717
|
| 250 |
+
2025-12-26 17:58:56,554: [INFO] Epoch [105]: loss = 0.39554418
|
| 251 |
+
2025-12-26 17:58:58,011: [INFO] Epoch [106]: val_loss = 0.48656076
|
| 252 |
+
2025-12-26 18:00:45,347: [INFO] Epoch [106]: loss = 0.39228787
|
| 253 |
+
2025-12-26 18:00:46,812: [INFO] Epoch [107]: val_loss = 0.48810028
|
| 254 |
+
2025-12-26 18:02:33,925: [INFO] Epoch [107]: loss = 0.39156697
|
| 255 |
+
2025-12-26 18:02:35,412: [INFO] Epoch [108]: val_loss = 0.47222325
|
| 256 |
+
2025-12-26 18:04:23,740: [INFO] Epoch [108]: loss = 0.39423798
|
| 257 |
+
2025-12-26 18:04:25,212: [INFO] Epoch [109]: val_loss = 0.47254576
|
| 258 |
+
2025-12-26 18:06:12,535: [INFO] Epoch [109]: loss = 0.39252056
|
| 259 |
+
2025-12-26 18:06:14,012: [INFO] Epoch [110]: val_loss = 0.48817008
|
| 260 |
+
2025-12-26 18:08:01,245: [INFO] Epoch [110]: loss = 0.39401426
|
| 261 |
+
2025-12-26 18:08:02,815: [INFO] Epoch [111]: val_loss = 0.48511344
|
| 262 |
+
2025-12-26 18:09:50,465: [INFO] Epoch [111]: loss = 0.39587875
|
| 263 |
+
2025-12-26 18:09:51,912: [INFO] Epoch [112]: val_loss = 0.48533930
|
| 264 |
+
2025-12-26 18:11:39,151: [INFO] Epoch [112]: loss = 0.39013777
|
| 265 |
+
2025-12-26 18:11:40,611: [INFO] Epoch [113]: val_loss = 0.48621008
|
| 266 |
+
2025-12-26 18:13:28,463: [INFO] Epoch [113]: loss = 0.38981328
|
| 267 |
+
2025-12-26 18:13:29,992: [INFO] Epoch [114]: val_loss = 0.47124515
|
| 268 |
+
2025-12-26 18:15:17,549: [INFO] Epoch [114]: loss = 0.39461992
|
| 269 |
+
2025-12-26 18:15:19,012: [INFO] Epoch [115]: val_loss = 0.48522179
|
| 270 |
+
2025-12-26 18:17:06,234: [INFO] Epoch [115]: loss = 0.39355375
|
| 271 |
+
2025-12-26 18:17:07,711: [INFO] Epoch [116]: val_loss = 0.49023107
|
| 272 |
+
2025-12-26 18:18:54,822: [INFO] Epoch [116]: loss = 0.39458753
|
| 273 |
+
2025-12-26 18:18:56,312: [INFO] Epoch [117]: val_loss = 0.48466966
|
| 274 |
+
2025-12-26 18:20:43,478: [INFO] Epoch [117]: loss = 0.39362156
|
| 275 |
+
2025-12-26 18:20:44,916: [INFO] Epoch [118]: val_loss = 0.50641123
|
| 276 |
+
2025-12-26 18:22:32,316: [INFO] Epoch [118]: loss = 0.39327146
|
| 277 |
+
2025-12-26 18:22:33,812: [INFO] Epoch [119]: val_loss = 0.47998404
|
| 278 |
+
2025-12-26 18:24:21,387: [INFO] Epoch [119]: loss = 0.39500003
|
| 279 |
+
2025-12-26 18:24:21,388: [INFO] Checkpointing model at epoch 120
|
| 280 |
+
2025-12-26 18:24:21,956: [INFO] Model checkpointed at epoch 120
|
| 281 |
+
2025-12-26 18:24:23,411: [INFO] Epoch [120]: val_loss = 0.47686255
|
| 282 |
+
2025-12-26 18:26:11,350: [INFO] Epoch [120]: loss = 0.39359459
|
| 283 |
+
2025-12-26 18:26:12,812: [INFO] Epoch [121]: val_loss = 0.47847155
|
| 284 |
+
2025-12-26 18:28:00,158: [INFO] Epoch [121]: loss = 0.39265604
|
| 285 |
+
2025-12-26 18:28:01,612: [INFO] Epoch [122]: val_loss = 0.48565416
|
| 286 |
+
2025-12-26 18:29:49,068: [INFO] Epoch [122]: loss = 0.39327284
|
| 287 |
+
2025-12-26 18:29:50,512: [INFO] Epoch [123]: val_loss = 0.46322163
|
| 288 |
+
2025-12-26 18:31:37,854: [INFO] Epoch [123]: loss = 0.39371465
|
| 289 |
+
2025-12-26 18:31:39,322: [INFO] Epoch [124]: val_loss = 0.49423857
|
| 290 |
+
2025-12-26 18:33:26,668: [INFO] Epoch [124]: loss = 0.39377629
|
| 291 |
+
2025-12-26 18:33:28,112: [INFO] Epoch [125]: val_loss = 0.50537621
|
| 292 |
+
2025-12-26 18:35:15,388: [INFO] Epoch [125]: loss = 0.39129652
|
| 293 |
+
2025-12-26 18:35:16,916: [INFO] Epoch [126]: val_loss = 0.50789308
|
| 294 |
+
2025-12-26 18:37:05,150: [INFO] Epoch [126]: loss = 0.39120718
|
| 295 |
+
2025-12-26 18:37:06,611: [INFO] Epoch [127]: val_loss = 0.49176749
|
| 296 |
+
2025-12-26 18:38:54,793: [INFO] Epoch [127]: loss = 0.39434199
|
| 297 |
+
2025-12-26 18:38:56,312: [INFO] Epoch [128]: val_loss = 0.48982497
|
| 298 |
+
2025-12-26 18:40:43,572: [INFO] Epoch [128]: loss = 0.39388789
|
| 299 |
+
2025-12-26 18:40:45,012: [INFO] Epoch [129]: val_loss = 0.49437147
|
| 300 |
+
2025-12-26 18:42:32,476: [INFO] Epoch [129]: loss = 0.39485405
|
| 301 |
+
2025-12-26 18:42:34,012: [INFO] Epoch [130]: val_loss = 0.49246545
|
| 302 |
+
2025-12-26 18:44:24,072: [INFO] Epoch [130]: loss = 0.39075325
|
| 303 |
+
2025-12-26 18:44:25,612: [INFO] Epoch [131]: val_loss = 0.51833930
|
| 304 |
+
2025-12-26 18:46:15,498: [INFO] Epoch [131]: loss = 0.39447027
|
| 305 |
+
2025-12-26 18:46:16,931: [INFO] Epoch [132]: val_loss = 0.48003947
|
| 306 |
+
2025-12-26 18:48:05,343: [INFO] Epoch [132]: loss = 0.39434897
|
| 307 |
+
2025-12-26 18:48:06,812: [INFO] Epoch [133]: val_loss = 0.49718059
|
| 308 |
+
2025-12-26 18:49:55,181: [INFO] Epoch [133]: loss = 0.39328515
|
| 309 |
+
2025-12-26 18:49:56,612: [INFO] Epoch [134]: val_loss = 0.48965228
|
| 310 |
+
2025-12-26 18:51:44,446: [INFO] Epoch [134]: loss = 0.39368132
|
| 311 |
+
2025-12-26 18:51:45,910: [INFO] Epoch [135]: val_loss = 0.50781692
|
| 312 |
+
2025-12-26 18:53:34,066: [INFO] Epoch [135]: loss = 0.39521807
|
| 313 |
+
2025-12-26 18:53:35,517: [INFO] Epoch [136]: val_loss = 0.49129677
|
| 314 |
+
2025-12-26 18:55:24,158: [INFO] Epoch [136]: loss = 0.39310845
|
| 315 |
+
2025-12-26 18:55:25,611: [INFO] Epoch [137]: val_loss = 0.50138287
|
| 316 |
+
2025-12-26 18:57:13,543: [INFO] Epoch [137]: loss = 0.39277331
|
| 317 |
+
2025-12-26 18:57:15,011: [INFO] Epoch [138]: val_loss = 0.49667891
|
| 318 |
+
2025-12-26 18:59:02,557: [INFO] Epoch [138]: loss = 0.39367320
|
| 319 |
+
2025-12-26 18:59:04,012: [INFO] Epoch [139]: val_loss = 0.49262191
|
| 320 |
+
2025-12-26 19:00:51,448: [INFO] Epoch [139]: loss = 0.39519035
|
| 321 |
+
2025-12-26 19:00:51,449: [INFO] Checkpointing model at epoch 140
|
| 322 |
+
2025-12-26 19:00:52,031: [INFO] Model checkpointed at epoch 140
|
| 323 |
+
2025-12-26 19:00:53,512: [INFO] Epoch [140]: val_loss = 0.50657800
|
| 324 |
+
2025-12-26 19:02:40,830: [INFO] Epoch [140]: loss = 0.39009643
|
| 325 |
+
2025-12-26 19:02:42,311: [INFO] Epoch [141]: val_loss = 0.48729330
|
| 326 |
+
2025-12-26 19:04:29,914: [INFO] Epoch [141]: loss = 0.39552269
|
| 327 |
+
2025-12-26 19:04:31,412: [INFO] Epoch [142]: val_loss = 0.49246952
|
| 328 |
+
2025-12-26 19:06:18,642: [INFO] Epoch [142]: loss = 0.39385797
|
| 329 |
+
2025-12-26 19:06:20,111: [INFO] Epoch [143]: val_loss = 0.49985805
|
| 330 |
+
2025-12-26 19:08:07,780: [INFO] Epoch [143]: loss = 0.39398983
|
| 331 |
+
2025-12-26 19:08:09,216: [INFO] Epoch [144]: val_loss = 0.49885565
|
| 332 |
+
2025-12-26 19:09:56,876: [INFO] Epoch [144]: loss = 0.39509994
|
| 333 |
+
2025-12-26 19:09:58,312: [INFO] Epoch [145]: val_loss = 0.50998864
|
| 334 |
+
2025-12-26 19:11:46,490: [INFO] Epoch [145]: loss = 0.39288819
|
| 335 |
+
2025-12-26 19:11:48,006: [INFO] Epoch [146]: val_loss = 0.53312138
|
| 336 |
+
2025-12-26 19:12:12,458: [WARNING] [TRAIN] Training interrupted by user. Saving model...
|
| 337 |
+
2025-12-26 19:12:12,459: [INFO] [TRAIN] Saving model before exiting...
|
| 338 |
+
2025-12-26 19:12:13,159: [INFO] [TRAIN] Training process finished.
|
| 339 |
+
2025-12-26 19:24:52,215: [INFO] Logger initialized with writer handler at: /workspace/logs/logfile.log
|
| 340 |
+
2025-12-26 19:24:52,223: [INFO] TensorBoard logs will be stored in: /workspace/logs/logs
|
| 341 |
+
2025-12-26 19:24:52,223: [INFO] Model checkpoints will be stored in: /workspace/logs/checkpoints
|
| 342 |
+
2025-12-26 19:24:52,236: [INFO] TensorBoard running at http://0.0.0.0:6006/ (pid=114689)
|
| 343 |
+
2025-12-26 19:24:52,246: [INFO] Initializer set up seed: 1766777092
|
| 344 |
+
2025-12-26 19:24:52,250: [INFO] PyTorch is now configured to use GPU 0: NVIDIA A40
|
| 345 |
+
2025-12-26 19:24:52,250: [INFO] [GPU 0 - NVIDIA A40] Memory Stats:
|
| 346 |
+
2025-12-26 19:24:52,251: [INFO] Total Memory : 45498.00 MB
|
| 347 |
+
2025-12-26 19:24:52,251: [INFO] Currently Allocated : 0.00 MB
|
| 348 |
+
2025-12-26 19:24:52,251: [INFO] Currently Reserved : 0.00 MB
|
| 349 |
+
2025-12-26 19:24:52,252: [INFO] Max Allocated : 0.00 MB
|
| 350 |
+
2025-12-26 19:24:52,252: [INFO] Max Reserved : 0.00 MB
|
| 351 |
+
2025-12-26 19:24:52,252: [INFO] Setup information:
|
| 352 |
+
2025-12-26 19:24:52,253: [INFO] - Setup path: /workspace/logs
|
| 353 |
+
2025-12-26 19:24:52,253: [INFO] - Setup checkpoints path: /workspace/logs/checkpoints
|
| 354 |
+
2025-12-26 19:24:52,253: [INFO] - Setup device: cuda:0
|
| 355 |
+
2025-12-26 19:24:52,254: [INFO] - Setup seed: 1766777092
|
| 356 |
+
2025-12-26 19:24:52,254: [INFO] - Setup logger: <Logger src.dlutils.setup.logger (INFO)>
|
| 357 |
+
2025-12-26 19:24:52,254: [INFO] - Setup writer: <torch.utils.tensorboard.writer.SummaryWriter object at 0x742242ae8a10>
|
| 358 |
+
2025-12-26 19:24:52,254: [INFO] - Setup save each: 20
|
| 359 |
+
2025-12-26 19:24:52,258: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A000-segmentation
|
| 360 |
+
2025-12-26 19:24:52,259: [INFO] [SegmentationDataset] Loaded dataset length: 26510
|
| 361 |
+
2025-12-26 19:24:52,282: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A001-segmentation
|
| 362 |
+
2025-12-26 19:24:52,283: [INFO] [SegmentationDataset] Loaded dataset length: 3336
|
| 363 |
+
2025-12-26 19:24:52,746: [INFO] [TRAIN] Reloading model, optimizer and scheduler states...
|
| 364 |
+
2025-12-26 19:24:52,988: [INFO] Model reloaded from /workspace/logs/checkpoints/model_epoch_40.pt at epoch 40 and seed 1766760356
|
| 365 |
+
2025-12-26 19:24:52,989: [INFO] Optimizer state_dict loaded from /workspace/logs/checkpoints/model_epoch_40.pt
|
| 366 |
+
2025-12-26 19:24:52,989: [INFO] Scheduler state_dict loaded from /workspace/logs/checkpoints/model_epoch_40.pt
|
| 367 |
+
2025-12-26 19:24:52,990: [INFO] [TRAIN] Model Configuration:
|
| 368 |
+
{'vocab_size': 32768, 'model_dim': 256, 'max_tokens': 382, 'max_sentences': 384, 'valid_padding': True, 'cosenet': CoSeNetConfig(trainable=True, init_scale=5.0), 'transformers': [TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True), TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True)]}
|
| 369 |
+
2025-12-26 19:24:52,991: [INFO] [TRAIN] Model parameters: 11.022865 M
|
| 370 |
+
2025-12-26 19:24:52,991: [INFO] [TRAIN] Trainable parameters: 11.022865 M
|
| 371 |
+
2025-12-26 19:24:52,992: [INFO] [TRAIN] Training batches: 2651
|
| 372 |
+
2025-12-26 20:38:24,396: [INFO] Epoch [40]: loss = 0.44221879
|
| 373 |
+
2025-12-26 20:39:27,131: [INFO] Epoch [41]: val_loss = 0.43405354
|
| 374 |
+
2025-12-26 21:50:17,867: [INFO] Epoch [41]: loss = 0.42664148
|
| 375 |
+
2025-12-26 21:51:15,331: [INFO] Epoch [42]: val_loss = 0.42706228
|
| 376 |
+
2025-12-26 23:02:34,039: [INFO] Epoch [42]: loss = 0.41818042
|
| 377 |
+
2025-12-26 23:03:32,932: [INFO] Epoch [43]: val_loss = 0.42457028
|
| 378 |
+
2025-12-27 00:15:19,536: [INFO] Epoch [43]: loss = 0.41249070
|
| 379 |
+
2025-12-27 00:16:17,232: [INFO] Epoch [44]: val_loss = 0.42501960
|
| 380 |
+
2025-12-27 01:27:24,113: [INFO] Epoch [44]: loss = 0.40819168
|
| 381 |
+
2025-12-27 01:28:21,631: [INFO] Epoch [45]: val_loss = 0.42466300
|
| 382 |
+
2025-12-27 02:39:08,276: [INFO] Epoch [45]: loss = 0.40515616
|
| 383 |
+
2025-12-27 02:40:05,744: [INFO] Epoch [46]: val_loss = 0.42623002
|
| 384 |
+
2025-12-27 03:50:51,982: [INFO] Epoch [46]: loss = 0.40199136
|
| 385 |
+
2025-12-27 03:51:49,542: [INFO] Epoch [47]: val_loss = 0.42870347
|
| 386 |
+
2025-12-27 05:02:35,984: [INFO] Epoch [47]: loss = 0.40118613
|
| 387 |
+
2025-12-27 05:03:33,643: [INFO] Epoch [48]: val_loss = 0.42845377
|
| 388 |
+
2025-12-27 06:14:22,488: [INFO] Epoch [48]: loss = 0.40013945
|
| 389 |
+
2025-12-27 06:15:19,931: [INFO] Epoch [49]: val_loss = 0.43143284
|
| 390 |
+
2025-12-27 07:26:05,041: [INFO] Epoch [49]: loss = 0.39879406
|
| 391 |
+
2025-12-27 07:27:02,632: [INFO] Epoch [50]: val_loss = 0.43078374
|
| 392 |
+
2025-12-27 08:37:47,083: [INFO] Epoch [50]: loss = 0.39769130
|
| 393 |
+
2025-12-27 08:38:44,631: [INFO] Epoch [51]: val_loss = 0.43248296
|
| 394 |
+
2025-12-27 09:49:30,251: [INFO] Epoch [51]: loss = 0.39770448
|
| 395 |
+
2025-12-27 09:50:27,831: [INFO] Epoch [52]: val_loss = 0.43522716
|
| 396 |
+
2025-12-27 11:01:17,000: [INFO] Epoch [52]: loss = 0.39691210
|
| 397 |
+
2025-12-27 11:02:14,669: [INFO] Epoch [53]: val_loss = 0.43381873
|
| 398 |
+
2025-12-27 12:13:00,032: [INFO] Epoch [53]: loss = 0.39669390
|
| 399 |
+
2025-12-27 12:13:57,331: [INFO] Epoch [54]: val_loss = 0.43624624
|
| 400 |
+
2025-12-27 13:24:41,611: [INFO] Epoch [54]: loss = 0.39618159
|
| 401 |
+
2025-12-27 13:25:38,932: [INFO] Epoch [55]: val_loss = 0.43923773
|
| 402 |
+
2025-12-27 14:36:23,346: [INFO] Epoch [55]: loss = 0.39546988
|
| 403 |
+
2025-12-27 14:37:20,669: [INFO] Epoch [56]: val_loss = 0.44194769
|
| 404 |
+
2025-12-27 15:48:09,320: [INFO] Epoch [56]: loss = 0.39544456
|
| 405 |
+
2025-12-27 15:49:06,843: [INFO] Epoch [57]: val_loss = 0.43803932
|
| 406 |
+
2025-12-27 16:59:54,398: [INFO] Epoch [57]: loss = 0.39482789
|
| 407 |
+
2025-12-27 17:00:52,031: [INFO] Epoch [58]: val_loss = 0.43835160
|
| 408 |
+
2025-12-27 18:11:44,408: [INFO] Epoch [58]: loss = 0.39461407
|
| 409 |
+
2025-12-27 18:12:41,958: [INFO] Epoch [59]: val_loss = 0.44492985
|
| 410 |
+
2025-12-27 19:23:29,892: [INFO] Epoch [59]: loss = 0.39447898
|
| 411 |
+
2025-12-27 19:23:29,894: [INFO] Checkpointing model at epoch 60
|
| 412 |
+
2025-12-27 19:23:30,517: [INFO] Model checkpointed at epoch 60
|
| 413 |
+
2025-12-27 19:24:27,831: [INFO] Epoch [60]: val_loss = 0.44300878
|
| 414 |
+
2025-12-27 19:27:50,666: [WARNING] [TRAIN] Training interrupted by user. Saving model...
|
| 415 |
+
2025-12-27 19:27:50,668: [INFO] [TRAIN] Saving model before exiting...
|
| 416 |
+
2025-12-27 19:27:51,180: [INFO] [TRAIN] Training process finished.
|
| 417 |
+
2025-12-27 19:33:34,471: [INFO] Logger initialized with writer handler at: /workspace/logs/logfile.log
|
| 418 |
+
2025-12-27 19:33:34,478: [INFO] TensorBoard logs will be stored in: /workspace/logs/logs
|
| 419 |
+
2025-12-27 19:33:34,479: [INFO] Model checkpoints will be stored in: /workspace/logs/checkpoints
|
| 420 |
+
2025-12-27 19:33:34,491: [INFO] TensorBoard running at http://0.0.0.0:6006/ (pid=235187)
|
| 421 |
+
2025-12-27 19:33:34,498: [INFO] Initializer set up seed: 1766864014
|
| 422 |
+
2025-12-27 19:33:34,501: [INFO] PyTorch is now configured to use GPU 0: NVIDIA A40
|
| 423 |
+
2025-12-27 19:33:34,507: [INFO] [GPU 0 - NVIDIA A40] Memory Stats:
|
| 424 |
+
2025-12-27 19:33:34,507: [INFO] Total Memory : 45498.00 MB
|
| 425 |
+
2025-12-27 19:33:34,507: [INFO] Currently Allocated : 0.00 MB
|
| 426 |
+
2025-12-27 19:33:34,507: [INFO] Currently Reserved : 0.00 MB
|
| 427 |
+
2025-12-27 19:33:34,508: [INFO] Max Allocated : 0.00 MB
|
| 428 |
+
2025-12-27 19:33:34,508: [INFO] Max Reserved : 0.00 MB
|
| 429 |
+
2025-12-27 19:33:34,508: [INFO] Setup information:
|
| 430 |
+
2025-12-27 19:33:34,509: [INFO] - Setup path: /workspace/logs
|
| 431 |
+
2025-12-27 19:33:34,509: [INFO] - Setup checkpoints path: /workspace/logs/checkpoints
|
| 432 |
+
2025-12-27 19:33:34,509: [INFO] - Setup device: cuda:0
|
| 433 |
+
2025-12-27 19:33:34,510: [INFO] - Setup seed: 1766864014
|
| 434 |
+
2025-12-27 19:33:34,510: [INFO] - Setup logger: <Logger src.dlutils.setup.logger (INFO)>
|
| 435 |
+
2025-12-27 19:33:34,510: [INFO] - Setup writer: <torch.utils.tensorboard.writer.SummaryWriter object at 0x74fc724f4f50>
|
| 436 |
+
2025-12-27 19:33:34,511: [INFO] - Setup save each: 1
|
| 437 |
+
2025-12-27 19:33:34,515: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A000-segmentation
|
| 438 |
+
2025-12-27 19:33:34,515: [INFO] [SegmentationDataset] Loaded dataset length: 26510
|
| 439 |
+
2025-12-27 19:33:35,262: [INFO] [SegmentationDataset] Loaded dataset: /workspace/data/tokens-A001-segmentation
|
| 440 |
+
2025-12-27 19:33:35,262: [INFO] [SegmentationDataset] Loaded dataset length: 3336
|
| 441 |
+
2025-12-27 19:33:35,796: [INFO] [TRAIN] Reloading model, optimizer and scheduler states...
|
| 442 |
+
2025-12-27 19:33:35,926: [INFO] Model reloaded from /workspace/logs/checkpoints/model_epoch_60.pt at epoch 60 and seed 1766760356
|
| 443 |
+
2025-12-27 19:33:35,927: [INFO] Optimizer state_dict loaded from /workspace/logs/checkpoints/model_epoch_60.pt
|
| 444 |
+
2025-12-27 19:33:35,927: [INFO] Scheduler state_dict loaded from /workspace/logs/checkpoints/model_epoch_60.pt
|
| 445 |
+
2025-12-27 19:33:35,927: [INFO] [TRAIN] Model Configuration:
|
| 446 |
+
{'vocab_size': 32768, 'model_dim': 256, 'max_tokens': 382, 'max_sentences': 384, 'valid_padding': True, 'cosenet': CoSeNetConfig(trainable=True, init_scale=5.0), 'transformers': [TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True), TransformerConfig(attention_heads=16, feed_forward_multiplier=8, dropout=0.0, pre_normalize=True)]}
|
| 447 |
+
2025-12-27 19:33:35,928: [INFO] [TRAIN] Model parameters: 11.022865 M
|
| 448 |
+
2025-12-27 19:33:35,928: [INFO] [TRAIN] Trainable parameters: 11.022865 M
|
| 449 |
+
2025-12-27 19:33:35,928: [INFO] [TRAIN] Training batches: 6628
|
| 450 |
+
2025-12-27 22:37:53,178: [INFO] Epoch [60]: loss = 0.41291385
|
| 451 |
+
2025-12-27 22:37:53,181: [INFO] Checkpointing model at epoch 61
|
| 452 |
+
2025-12-27 22:37:53,822: [INFO] Model checkpointed at epoch 61
|
| 453 |
+
2025-12-27 22:40:30,975: [INFO] Epoch [61]: val_loss = 0.42101070
|
| 454 |
+
2025-12-28 01:37:44,628: [INFO] Epoch [61]: loss = 0.40628406
|
| 455 |
+
2025-12-28 01:37:44,629: [INFO] Checkpointing model at epoch 62
|
| 456 |
+
2025-12-28 01:37:45,260: [INFO] Model checkpointed at epoch 62
|
| 457 |
+
2025-12-28 01:40:08,586: [INFO] Epoch [62]: val_loss = 0.41949525
|
| 458 |
+
2025-12-28 04:37:33,182: [INFO] Epoch [62]: loss = 0.40297545
|
| 459 |
+
2025-12-28 04:37:33,184: [INFO] Checkpointing model at epoch 63
|
| 460 |
+
2025-12-28 04:37:33,865: [INFO] Model checkpointed at epoch 63
|
| 461 |
+
2025-12-28 04:39:58,256: [INFO] Epoch [63]: val_loss = 0.42107245
|
| 462 |
+
2025-12-28 07:37:53,163: [INFO] Epoch [63]: loss = 0.40075299
|
| 463 |
+
2025-12-28 07:37:53,165: [INFO] Checkpointing model at epoch 64
|
| 464 |
+
2025-12-28 07:37:53,812: [INFO] Model checkpointed at epoch 64
|
| 465 |
+
2025-12-28 07:40:18,271: [INFO] Epoch [64]: val_loss = 0.42276877
|
| 466 |
+
2025-12-28 10:37:45,887: [INFO] Epoch [64]: loss = 0.39892435
|
| 467 |
+
2025-12-28 10:37:45,888: [INFO] Checkpointing model at epoch 65
|
| 468 |
+
2025-12-28 10:37:46,521: [INFO] Model checkpointed at epoch 65
|
| 469 |
+
2025-12-28 10:40:11,142: [INFO] Epoch [65]: val_loss = 0.42484788
|
| 470 |
+
2025-12-28 13:37:21,540: [INFO] Epoch [65]: loss = 0.39751294
|
| 471 |
+
2025-12-28 13:37:21,541: [INFO] Checkpointing model at epoch 66
|
| 472 |
+
2025-12-28 13:37:22,128: [INFO] Model checkpointed at epoch 66
|
| 473 |
+
2025-12-28 13:39:45,583: [INFO] Epoch [66]: val_loss = 0.42598702
|
| 474 |
+
2025-12-28 16:36:57,870: [INFO] Epoch [66]: loss = 0.39654398
|
| 475 |
+
2025-12-28 16:36:57,872: [INFO] Checkpointing model at epoch 67
|
| 476 |
+
2025-12-28 16:36:58,476: [INFO] Model checkpointed at epoch 67
|
| 477 |
+
2025-12-28 16:39:23,196: [INFO] Epoch [67]: val_loss = 0.42759763
|
| 478 |
+
2025-12-28 19:37:48,749: [INFO] Epoch [67]: loss = 0.39627669
|
| 479 |
+
2025-12-28 19:37:48,752: [INFO] Checkpointing model at epoch 68
|
| 480 |
+
2025-12-28 19:37:49,475: [INFO] Model checkpointed at epoch 68
|
| 481 |
+
2025-12-28 19:40:15,986: [INFO] Epoch [68]: val_loss = 0.42856705
|
| 482 |
+
2025-12-28 22:39:07,980: [INFO] Epoch [68]: loss = 0.39574215
|
| 483 |
+
2025-12-28 22:39:07,982: [INFO] Checkpointing model at epoch 69
|
| 484 |
+
2025-12-28 22:39:08,617: [INFO] Model checkpointed at epoch 69
|
| 485 |
+
2025-12-28 22:41:33,142: [INFO] Epoch [69]: val_loss = 0.43066087
|
| 486 |
+
2025-12-29 01:39:35,127: [INFO] Epoch [69]: loss = 0.39522947
|
| 487 |
+
2025-12-29 01:39:35,129: [INFO] Checkpointing model at epoch 70
|
| 488 |
+
2025-12-29 01:39:35,774: [INFO] Model checkpointed at epoch 70
|
| 489 |
+
2025-12-29 01:42:00,471: [INFO] Epoch [70]: val_loss = 0.43032571
|
train/train_model.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
import tqdm
|
| 10 |
+
from train.config import configuration, TrainConfig
|
| 11 |
+
from src.model import SegmentationNetwork, MaskedBCELoss
|
| 12 |
+
from src.dataset import TokenizedSegmentationDataset
|
| 13 |
+
from src.dlutils import Setup, train_step, validation_step
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 17 |
+
# #
|
| 18 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 19 |
+
def train(controller: Setup, config: TrainConfig):
    """
    Run the full training loop for the segmentation model.

    Builds the train/validation data loaders, instantiates the model, loss,
    optimizer and cosine LR scheduler, optionally restores a checkpoint,
    then alternates train/validation steps until ``config.num_epochs`` is
    reached or the user interrupts with Ctrl-C. The model, optimizer and
    scheduler states are always saved on exit (normal, interrupt or error).

    :param controller: A training controller (device, logger, checkpointing,
        tensorboard watchers).
    :param config: The experiment configuration.
    :return: None
    """

    # 1. Train and val datasets (wrapped directly into DataLoaders):
    train_dataset = TokenizedSegmentationDataset(
        tokenized_dataset=config.dataset_config.train_data_path,
        logger=controller.logger,
        percentage=config.dataset_config.train_percentage,
        return_type=tuple
    ).get_loader(
        config.batch_size,
        shuffle=config.dataset_config.shuffle_train,
        num_workers=config.dataset_config.num_workers
    )
    val_dataset = TokenizedSegmentationDataset(
        tokenized_dataset=config.dataset_config.val_data_path,
        logger=controller.logger,
        percentage=config.dataset_config.val_percentage,
        return_type=tuple
    ).get_loader(
        config.batch_size,
        shuffle=config.dataset_config.shuffle_val,
        num_workers=config.dataset_config.num_workers
    )

    # 2. Model, loss, optimizer and LR schedule:
    model = SegmentationNetwork(config.model_config).to(controller.device)
    loss_fn = MaskedBCELoss(valid_pad=config.model_config.valid_padding)
    optimizer = torch.optim.AdamW(
        params=model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=config.betas
    )
    # Cosine annealing over the whole run, decaying to learning_rate_min:
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=config.num_epochs,
        eta_min=config.learning_rate_min
    )

    # 3. Reload checkpoint if needed (model, optimizer and scheduler states):
    if config.setup_config.reload_checkpoint:
        controller.logger.info("[TRAIN] Reloading model, optimizer and scheduler states...")
        controller.reload(model, optimizer, lr_scheduler)

    # 4. Log run info and persist the configuration next to the logs:
    controller.logger.info(f"[TRAIN] Model Configuration:\n{config.model_config.__dict__}")
    controller.logger.info(f"[TRAIN] Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6} M")
    controller.logger.info(f"[TRAIN] Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6} M")
    controller.logger.info(f"[TRAIN] Training batches: {len(train_dataset)}")
    controller.save_config(config)

    # 5. Set watchers (tensorboard/parameter monitors on named submodules):
    controller.set_watcher('A')
    controller.set_watcher('transformer')

    # 6. Train loop. Starts at controller.epoch so a reloaded run resumes
    # where it left off instead of restarting from epoch 0:
    try:
        for _ in tqdm.tqdm(range(controller.epoch, config.num_epochs), desc="Epochs", unit="epoch"):
            # Train step:
            train_step(
                model=model,
                data=train_dataset,
                loss=loss_fn,
                optimizer=optimizer,
                controller=controller,
                scheduler=lr_scheduler
            )

            validation_step(
                model=model,
                data=val_dataset,
                loss=loss_fn,
                controller=controller
            )
    except KeyboardInterrupt:
        # A Ctrl-C is a deliberate stop: warn and fall through to the save.
        controller.logger.warning("[TRAIN] Training interrupted by user. Saving model...")
    except Exception as e:
        controller.logger.error(f"[TRAIN] An error has occurred during training: {e}")
        # Bare `raise` preserves the original traceback (unlike `raise e`,
        # which restarts it at this line).
        raise
    finally:
        # 7. End of training: always checkpoint before leaving, even on error.
        controller.logger.info("[TRAIN] Saving model before exiting...")
        controller.save_model(model, optimizer, lr_scheduler)
        controller.logger.info("[TRAIN] Training process finished.")
        input("[TRAIN] Training finished. Press any key to exit...")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 113 |
+
# #
|
| 114 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 115 |
+
if __name__ == "__main__":
    # Entry point: load the experiment configuration, then run training
    # inside a managed Setup context (logging, device selection, seeding
    # and checkpoint handling are all owned by the context manager).
    experiment_config = configuration()
    setup_options = experiment_config.setup_config
    with Setup(
        path=setup_options.logging_path,
        device=setup_options.device_number,
        seed=setup_options.seed,
        save_each=setup_options.save_model_each,
        reload_state=setup_options.reload_checkpoint,
        replay_element=(0, None)
    ) as training_controller:
        train(training_controller, experiment_config)
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                       END OF FILE                         #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|