alverciito commited on
Commit
dbd79bd
·
1 Parent(s): 4c7684b

upload safetensors and refactor research files

Browse files
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50) hide show
  1. config.json +16 -0
  2. configurations.py +0 -0
  3. model.py +176 -0
  4. model.safetensors +3 -0
  5. requirements.txt +1 -0
  6. bench.py → research_files/bench.py +0 -0
  7. research_files/benchmark/results/binseg_bert-base-multilingual-cased.json +0 -0
  8. {benchmark → research_files/benchmark}/results/binseg_paraphrase-multilingual-MiniLM-L12-v2.json +0 -0
  9. {benchmark → research_files/benchmark}/results/binseg_sentence_similarity_spanish_es.json +0 -0
  10. research_files/benchmark/results/csim_bert-base-multilingual-cased.json +0 -0
  11. {benchmark → research_files/benchmark}/results/csim_paraphrase-multilingual-MiniLM-L12-v2.json +0 -0
  12. {benchmark → research_files/benchmark}/results/csim_sentence_similarity_spanish_es.json +0 -0
  13. research_files/benchmark/results/pelt_LaBSE.json +0 -0
  14. {benchmark → research_files/benchmark}/results/pelt_bert-base-multilingual-cased.json +0 -0
  15. {benchmark → research_files/benchmark}/results/pelt_paraphrase-multilingual-MiniLM-L12-v2.json +0 -0
  16. {benchmark → research_files/benchmark}/results/pelt_sentence_similarity_spanish_es.json +0 -0
  17. {benchmark → research_files/benchmark}/results/proposed_method.json +0 -0
  18. {benchmark → research_files/benchmark}/results/textile_baseline.json +0 -0
  19. {benchmark → research_files/benchmark}/segmentation_benchmark/__init__.py +0 -0
  20. {benchmark → research_files/benchmark}/segmentation_benchmark/heuristic.py +0 -0
  21. {benchmark → research_files/benchmark}/segmentation_benchmark/load_dataset.py +0 -0
  22. {benchmark → research_files/benchmark}/segmentation_benchmark/metrics.py +0 -0
  23. {benchmark → research_files/benchmark}/segmentation_benchmark/proposed.py +2 -2
  24. {benchmark → research_files/benchmark}/segmentation_benchmark/transformers.py +1 -1
  25. {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_1.json +0 -0
  26. {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_2.json +0 -0
  27. {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_3.json +0 -0
  28. {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_4.json +0 -0
  29. {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_5.json +0 -0
  30. {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_threshold.py +2 -2
  31. {benchmark → research_files/benchmark}/thresholding_benchmark/print_results.py +0 -0
  32. {benchmark → research_files/benchmark}/wikipedia-es-A002/data-00000-of-00001.arrow +0 -0
  33. {benchmark → research_files/benchmark}/wikipedia-es-A002/dataset_info.json +0 -0
  34. {benchmark → research_files/benchmark}/wikipedia-es-A002/state.json +0 -0
  35. {inference → research_files/inference}/__init__.py +0 -0
  36. {inference → research_files/inference}/config.py +181 -181
  37. {inference → research_files/inference}/load.py +0 -0
  38. {inference → research_files/inference}/model_state.pt +0 -0
  39. {inference → research_files/inference}/pipeline.py +1 -1
  40. {inference → research_files/inference}/tokenizer_32768.json +0 -0
  41. research_files/torch_to_hf.py +27 -0
  42. {train → research_files/train}/config.py +2 -2
  43. {train → research_files/train}/train_logs/config.json +0 -0
  44. {train → research_files/train}/train_logs/logfile.log +0 -0
  45. {train → research_files/train}/train_logs/tensorboard_logs.zip +0 -0
  46. {train → research_files/train}/train_model.py +3 -3
  47. special_tokens_map.json +7 -0
  48. src/dataset/__init__.py +13 -13
  49. src/dataset/config.py +29 -29
  50. src/dataset/dataset.py +199 -199
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CoseNetTransformer"
4
+ ],
5
+ "dropout": 0.0,
6
+ "emb_dim": 256,
7
+ "model_type": "sentence_transformer",
8
+ "seq_len": ...,
9
+ "torch_dtype": "float32",
10
+ "transformers_version": "4.57.3",
11
+ "vocab_size": 32768,
12
+ "auto_map": {
13
+ "AutoConfig": "configurations.SentenceCoseNetConfig",
14
+ "AutoModel": "model.SentenceCoseNet"
15
+ }
16
+ }
configurations.py ADDED
File without changes
model.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#      This file was created by: Alberto Palomo Alonso      #
#    Universidad de Alcalá - Escuela Politécnica Superior   #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
from src.model.config import ModelConfig
from src.model.cosenet import CosineDistanceLayer, CoSeNet
from src.model.transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling


class CoseNetTransformer(torch.nn.Module):
    """
    Segmentation network combining Transformer encoders with CoSeNet.

    This model integrates token embeddings and positional encodings with
    a stack of Transformer encoder blocks to produce contextualized
    representations. These representations are then processed by a
    CoSeNet module to perform structured segmentation, followed by a
    cosine-based distance computation.

    The final output is a pair-wise distance matrix suitable for
    segmentation or boundary detection tasks.
    """

    def __init__(self, model_config: ModelConfig, **kwargs):
        """
        Initialize the segmentation network.

        The network is composed of an embedding layer, positional encoding,
        multiple Transformer encoder blocks, a CoSeNet segmentation module,
        and a cosine distance layer.

        Args:
            model_config (ModelConfig): Configuration object containing all
                hyperparameters required to build the model, including
                vocabulary size, model dimensionality, transformer settings,
                and CoSeNet parameters.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        """
        super().__init__(**kwargs)
        # True when mask convention is "1 = valid token"; used to decide
        # whether masks must be inverted before pooling / CoSeNet.
        self.valid_padding = model_config.valid_padding

        # Build layers:
        self.embedding = torch.nn.Embedding(
            model_config.vocab_size,
            model_config.model_dim
        )
        self.positional_encoding = PositionalEncoding(
            emb_dim=model_config.model_dim,
            max_len=model_config.max_tokens
        )
        self.cosenet = CoSeNet(
            trainable=model_config.cosenet.trainable,
            init_scale=model_config.cosenet.init_scale
        )
        self.distance_layer = CosineDistanceLayer()
        self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)

        # Build encoder blocks (one per transformer config entry):
        self.encoder_blocks = torch.nn.ModuleList([
            EncoderBlock(
                feature_dim=model_config.model_dim,
                attention_heads=transformer_config.attention_heads,
                feed_forward_multiplier=transformer_config.feed_forward_multiplier,
                dropout=transformer_config.dropout,
                valid_padding=model_config.valid_padding,
                pre_normalize=transformer_config.pre_normalize
            )
            for transformer_config in model_config.transformers
        ])

    def encode(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Encode input sequences into contextualized representations.

        The input token indices are embedded and enriched with positional
        information, then processed by a stack of Transformer encoder
        blocks.

        Args:
            x (torch.Tensor): Input tensor of token indices with shape
                (batch_size, max_tokens).
            mask (torch.Tensor, optional): Optional mask tensor indicating
                valid or padded positions, depending on the configuration
                of the Transformer blocks. Defaults to None. Dimensions should be
                (batch_size, max_tokens).

        Returns:
            torch.Tensor: Contextualized representations with shape
                (batch_size, max_tokens, model_dim).
        """
        # Convert to type:
        x = x.int()
        # Embedding and positional encoding:
        x = self.embedding(x)
        x = self.positional_encoding(x)
        # Check mask inversion.
        # FIX: the original indexed `mask[0, 0]` unconditionally, so calling
        # encode() with the default mask=None raised a TypeError. Guard first.
        # NOTE(review): the `mask[0, 0] == 0` heuristic assumes position
        # (0, 0) of a correctly-oriented mask is always truthy — confirm
        # that convention holds for all callers.
        if mask is not None and mask[0, 0] == 0:
            mask = torch.logical_not(mask)
        # Encode:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask=mask)
        return x

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass of the segmentation network.

        The input token indices are embedded and enriched with positional
        information, then processed by a stack of Transformer encoder
        blocks. The resulting representations are segmented using CoSeNet
        and finally transformed into a pair-wise distance representation.

        Args:
            x (torch.Tensor): Input tensor of token indices with shape
                (batch_size, num_sentences, tokens_per_sentence).
            mask (torch.Tensor, optional): Optional mask tensor indicating
                valid or padded positions, depending on the configuration
                of the Transformer blocks. Defaults to None.

                If `valid_padding` is disabled, the mask is inverted before
                being passed to CoSeNet to match its masking convention.

            candidate_mask (torch.Tensor, optional): Optional mask tensor for
                candidate positions in CoSeNet. Defaults to None.

                If `valid_padding` is disabled, the mask is inverted before
                being passed to CoSeNet to match its masking convention.

        Returns:
            torch.Tensor: Output tensor containing pairwise distance values
                derived from the segmented representations.
        """
        # Convert to type:
        x = x.int()

        # Embedding and positional encoding:
        x = self.embedding(x)
        x = self.positional_encoding(x)

        # Flatten (batch, sentences) so the encoders see a 3D tensor:
        _b, _s, _t, _d = x.shape
        x = x.reshape(_b * _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b * _s, _t).bool()

        # Encode the sequence:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask=mask)

        # Restore the (batch, sentences) layout and normalize the mask
        # convention for the pooling layer:
        x = x.reshape(_b, _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b, _s, _t)
            mask = torch.logical_not(mask) if not self.valid_padding else mask

        # Apply pooling:
        x, mask = self.pooling(x, mask=mask)

        # Compute distances:
        x = self.distance_layer(x)

        # Pass through CoSeNet:
        x = self.cosenet(x, mask=mask)

        # Apply candidate mask if provided (zero out non-candidate positions):
        if candidate_mask is not None:
            candidate_mask = candidate_mask.bool() if not self.valid_padding else torch.logical_not(candidate_mask.bool())
            candidate_mask = candidate_mask.to(device=x.device)
            x = x.masked_fill(candidate_mask, 0)

        return x
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6db78280c80f27b94434a1d1e17296ecddc1d21705ec6be3b8bd0bc49991f27f
3
+ size 44485604
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  ruptures
2
  sentence-transformers
3
  numpy==2.3.5
 
1
+ safetensors
2
  ruptures
3
  sentence-transformers
4
  numpy==2.3.5
bench.py → research_files/bench.py RENAMED
File without changes
research_files/benchmark/results/binseg_bert-base-multilingual-cased.json ADDED
The diff for this file is too large to render. See raw diff
 
{benchmark → research_files/benchmark}/results/binseg_paraphrase-multilingual-MiniLM-L12-v2.json RENAMED
File without changes
{benchmark → research_files/benchmark}/results/binseg_sentence_similarity_spanish_es.json RENAMED
File without changes
research_files/benchmark/results/csim_bert-base-multilingual-cased.json ADDED
The diff for this file is too large to render. See raw diff
 
{benchmark → research_files/benchmark}/results/csim_paraphrase-multilingual-MiniLM-L12-v2.json RENAMED
File without changes
{benchmark → research_files/benchmark}/results/csim_sentence_similarity_spanish_es.json RENAMED
File without changes
research_files/benchmark/results/pelt_LaBSE.json ADDED
The diff for this file is too large to render. See raw diff
 
{benchmark → research_files/benchmark}/results/pelt_bert-base-multilingual-cased.json RENAMED
File without changes
{benchmark → research_files/benchmark}/results/pelt_paraphrase-multilingual-MiniLM-L12-v2.json RENAMED
File without changes
{benchmark → research_files/benchmark}/results/pelt_sentence_similarity_spanish_es.json RENAMED
File without changes
{benchmark → research_files/benchmark}/results/proposed_method.json RENAMED
File without changes
{benchmark → research_files/benchmark}/results/textile_baseline.json RENAMED
File without changes
{benchmark → research_files/benchmark}/segmentation_benchmark/__init__.py RENAMED
File without changes
{benchmark → research_files/benchmark}/segmentation_benchmark/heuristic.py RENAMED
File without changes
{benchmark → research_files/benchmark}/segmentation_benchmark/load_dataset.py RENAMED
File without changes
{benchmark → research_files/benchmark}/segmentation_benchmark/metrics.py RENAMED
File without changes
{benchmark → research_files/benchmark}/segmentation_benchmark/proposed.py RENAMED
@@ -9,10 +9,10 @@ import os
9
  import json
10
  import numpy as np
11
  import torch
12
- from datasets import tqdm
13
  from .metrics import precision_recall_f1_wd
14
  from .load_dataset import load_dataset
15
- from inference import load_model
16
 
17
 
18
  def evaluate_proposed(
 
9
  import json
10
  import numpy as np
11
  import torch
12
+ from tqdm import tqdm
13
  from .metrics import precision_recall_f1_wd
14
  from .load_dataset import load_dataset
15
+ from research_files.inference import load_model
16
 
17
 
18
  def evaluate_proposed(
{benchmark → research_files/benchmark}/segmentation_benchmark/transformers.py RENAMED
@@ -9,7 +9,7 @@ import os
9
  import json
10
  import numpy as np
11
  import torch
12
- from datasets import tqdm
13
  from .metrics import precision_recall_f1_wd
14
  from .load_dataset import load_dataset
15
  from sentence_transformers import SentenceTransformer
 
9
  import json
10
  import numpy as np
11
  import torch
12
+ from tqdm import tqdm
13
  from .metrics import precision_recall_f1_wd
14
  from .load_dataset import load_dataset
15
  from sentence_transformers import SentenceTransformer
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_1.json RENAMED
File without changes
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_2.json RENAMED
File without changes
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_3.json RENAMED
File without changes
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_4.json RENAMED
File without changes
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_5.json RENAMED
File without changes
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_threshold.py RENAMED
@@ -10,8 +10,8 @@ import tqdm
10
  import json
11
  from datasets import load_from_disk
12
  from src.model import SegmentationNetwork, MaskedBCELoss, WindowDiffLoss
13
- from src.dataset import SegmentationTokenizer, SentenceSegmenter, TokenizedSegmentationDataset
14
- from train.config import configuration
15
 
16
 
17
  # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
 
10
  import json
11
  from datasets import load_from_disk
12
  from src.model import SegmentationNetwork, MaskedBCELoss, WindowDiffLoss
13
+ from dataset import SegmentationTokenizer, SentenceSegmenter, TokenizedSegmentationDataset
14
+ from research_files.train.config import configuration
15
 
16
 
17
  # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
{benchmark → research_files/benchmark}/thresholding_benchmark/print_results.py RENAMED
File without changes
{benchmark → research_files/benchmark}/wikipedia-es-A002/data-00000-of-00001.arrow RENAMED
File without changes
{benchmark → research_files/benchmark}/wikipedia-es-A002/dataset_info.json RENAMED
File without changes
{benchmark → research_files/benchmark}/wikipedia-es-A002/state.json RENAMED
File without changes
{inference → research_files/inference}/__init__.py RENAMED
File without changes
{inference → research_files/inference}/config.py RENAMED
@@ -1,181 +1,181 @@
1
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
- # #
3
- # This file was created by: Alberto Palomo Alonso #
4
- # Universidad de Alcalá - Escuela Politécnica Superior #
5
- # #
6
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
- # Import statements:
8
- from dataclasses import dataclass
9
- from src.model import ModelConfig, CoSeNetConfig, TransformerConfig
10
- from src.dataset import DatasetConfig
11
-
12
-
13
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
14
- # SETUP CONFIGURATION #
15
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
16
- @dataclass
17
- class SetupConfig:
18
- """
19
- Configuration parameters related to the execution environment and logging.
20
-
21
- This configuration controls device selection, checkpointing behavior,
22
- reproducibility settings, and logging paths for an experiment.
23
- """
24
- device_number: int = 0
25
- save_model_each: int = 0
26
- seed: int = None
27
- logging_path: str = None
28
- reload_checkpoint: bool = False
29
-
30
-
31
- def overwrite_setup_config() -> SetupConfig:
32
- """
33
- Create and override the default setup configuration.
34
-
35
- This function customizes execution-level parameters such as logging
36
- paths, checkpoint reloading, and model saving frequency.
37
-
38
- Returns:
39
- SetupConfig: The configured setup configuration object.
40
- """
41
- config = SetupConfig()
42
- config.logging_path = r'/workspace/logs'
43
- config.reload_checkpoint = True
44
- config.save_model_each = 1
45
- return config
46
-
47
-
48
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
49
- # TRAINING CONFIGURATION #
50
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
51
- @dataclass
52
- class TrainConfig:
53
- """
54
- Training configuration container.
55
-
56
- This dataclass aggregates model, dataset, and setup configurations,
57
- together with optimization and training hyperparameters.
58
- """
59
- # Linked configurations:
60
- model_config: ModelConfig | None = None
61
- dataset_config: DatasetConfig | None = None
62
- setup_config: SetupConfig | None = None
63
-
64
- # Training parameters:
65
- batch_size: int = 32
66
- num_epochs: int = 100
67
-
68
- # Optimizer parameters:
69
- learning_rate: float = 1e-4
70
- learning_rate_min: float = 1e-5
71
- weight_decay: float = 1e-8
72
- betas: tuple[float, float] = (0.5, 0.999)
73
-
74
-
75
- def overwrite_train_config() -> TrainConfig:
76
- """
77
- Create and override the default training configuration.
78
-
79
- This function customizes batch size, number of epochs, and optimizer
80
- hyperparameters for the training process.
81
-
82
- Returns:
83
- TrainConfig: The configured training configuration object.
84
- """
85
- config = TrainConfig()
86
- config.batch_size = 4
87
- config.num_epochs = 200
88
- config.learning_rate = 5e-4
89
- config.learning_rate_min = 5e-5
90
- config.weight_decay = 1e-6
91
- return config
92
-
93
-
94
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
95
- # DATASET CONFIGURATION #
96
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
97
- def overwrite_dataset_config() -> DatasetConfig:
98
- """
99
- Create and override the dataset configuration.
100
-
101
- This function sets the file paths and usage percentages for training,
102
- validation, and test datasets.
103
-
104
- Returns:
105
- DatasetConfig: The configured dataset configuration object.
106
- """
107
- config = DatasetConfig()
108
- config.train_data_path = r"/workspace/data/tokens-A000-segmentation"
109
- config.val_data_path = r"/workspace/data/tokens-A001-segmentation"
110
- config.test_data_path = r"/workspace/data/tokens-A002-segmentation"
111
- config.train_percentage = 0.4
112
- config.val_percentage = 0.4
113
- config.test_percentage = 1.0
114
- return config
115
-
116
-
117
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
118
- # MODEL CONFIGURATION #
119
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
120
- def overwrite_model_config() -> ModelConfig:
121
- """
122
- Create and override the model configuration.
123
-
124
- This function defines the architecture-level parameters, including
125
- vocabulary size, embedding dimensionality, CoSeNet settings, and
126
- the stack of Transformer encoder configurations.
127
-
128
- Returns:
129
- ModelConfig: The configured model configuration object.
130
- """
131
- config = ModelConfig()
132
-
133
- # High-level params:
134
- config.vocab_size = 32_768
135
- config.model_dim = 256
136
- config.valid_padding = True
137
-
138
- # CoSeNet params:
139
- config.cosenet = CoSeNetConfig(
140
- trainable=True,
141
- init_scale=5.0
142
- )
143
-
144
- # Transformer params:
145
- config.transformers = [
146
- TransformerConfig(**cfg)
147
- for cfg in [
148
- {
149
- "attention_heads": 16,
150
- "feed_forward_multiplier": 8,
151
- "dropout": 0.0,
152
- "pre_normalize": True
153
- },
154
- {
155
- "attention_heads": 16,
156
- "feed_forward_multiplier": 8,
157
- "dropout": 0.0,
158
- "pre_normalize": True
159
- }
160
- ]
161
- ]
162
-
163
- return config
164
-
165
-
166
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
167
- # WHOLE CONFIGURATION #
168
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
169
- def configuration() -> TrainConfig:
170
- """
171
- Create the experiment configuration
172
- :return: A TrainConfig configuration object
173
- """
174
- config = overwrite_train_config()
175
- config.setup_config = overwrite_setup_config()
176
- config.model_config = overwrite_model_config()
177
- config.dataset_config = overwrite_dataset_config()
178
- return config
179
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
180
- # END OF FILE #
181
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
1
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#      This file was created by: Alberto Palomo Alonso      #
#    Universidad de Alcalá - Escuela Politécnica Superior   #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
from dataclasses import dataclass
from src.model import ModelConfig, CoSeNetConfig, TransformerConfig
from src.dataset import DatasetConfig


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                    SETUP CONFIGURATION                    #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
@dataclass
class SetupConfig:
    """
    Configuration parameters related to the execution environment and logging.

    This configuration controls device selection, checkpointing behavior,
    reproducibility settings, and logging paths for an experiment.
    """
    device_number: int = 0
    save_model_each: int = 0
    # FIX: these two fields default to None, so their annotations must be
    # Optional — matching the `X | None` convention already used below.
    seed: int | None = None
    logging_path: str | None = None
    reload_checkpoint: bool = False


def overwrite_setup_config() -> SetupConfig:
    """
    Create and override the default setup configuration.

    This function customizes execution-level parameters such as logging
    paths, checkpoint reloading, and model saving frequency.

    Returns:
        SetupConfig: The configured setup configuration object.
    """
    config = SetupConfig()
    config.logging_path = r'/workspace/logs'
    config.reload_checkpoint = True
    config.save_model_each = 1
    return config


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                   TRAINING CONFIGURATION                  #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
@dataclass
class TrainConfig:
    """
    Training configuration container.

    This dataclass aggregates model, dataset, and setup configurations,
    together with optimization and training hyperparameters.
    """
    # Linked configurations:
    model_config: ModelConfig | None = None
    dataset_config: DatasetConfig | None = None
    setup_config: SetupConfig | None = None

    # Training parameters:
    batch_size: int = 32
    num_epochs: int = 100

    # Optimizer parameters:
    learning_rate: float = 1e-4
    learning_rate_min: float = 1e-5
    weight_decay: float = 1e-8
    betas: tuple[float, float] = (0.5, 0.999)


def overwrite_train_config() -> TrainConfig:
    """
    Create and override the default training configuration.

    This function customizes batch size, number of epochs, and optimizer
    hyperparameters for the training process.

    Returns:
        TrainConfig: The configured training configuration object.
    """
    config = TrainConfig()
    config.batch_size = 4
    config.num_epochs = 200
    config.learning_rate = 5e-4
    config.learning_rate_min = 5e-5
    config.weight_decay = 1e-6
    return config


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                    DATASET CONFIGURATION                  #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def overwrite_dataset_config() -> DatasetConfig:
    """
    Create and override the dataset configuration.

    This function sets the file paths and usage percentages for training,
    validation, and test datasets.

    Returns:
        DatasetConfig: The configured dataset configuration object.
    """
    config = DatasetConfig()
    config.train_data_path = r"/workspace/data/tokens-A000-segmentation"
    config.val_data_path = r"/workspace/data/tokens-A001-segmentation"
    config.test_data_path = r"/workspace/data/tokens-A002-segmentation"
    config.train_percentage = 0.4
    config.val_percentage = 0.4
    config.test_percentage = 1.0
    return config


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                     MODEL CONFIGURATION                   #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def overwrite_model_config() -> ModelConfig:
    """
    Create and override the model configuration.

    This function defines the architecture-level parameters, including
    vocabulary size, embedding dimensionality, CoSeNet settings, and
    the stack of Transformer encoder configurations.

    Returns:
        ModelConfig: The configured model configuration object.
    """
    config = ModelConfig()

    # High-level params:
    config.vocab_size = 32_768
    config.model_dim = 256
    config.valid_padding = True

    # CoSeNet params:
    config.cosenet = CoSeNetConfig(
        trainable=True,
        init_scale=5.0
    )

    # Transformer params (two identical pre-norm encoder blocks):
    config.transformers = [
        TransformerConfig(**cfg)
        for cfg in [
            {
                "attention_heads": 16,
                "feed_forward_multiplier": 8,
                "dropout": 0.0,
                "pre_normalize": True
            },
            {
                "attention_heads": 16,
                "feed_forward_multiplier": 8,
                "dropout": 0.0,
                "pre_normalize": True
            }
        ]
    ]

    return config


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                     WHOLE CONFIGURATION                   #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def configuration() -> TrainConfig:
    """
    Create the experiment configuration.

    Returns:
        TrainConfig: A fully-populated configuration object linking the
            setup, model, and dataset configurations.
    """
    config = overwrite_train_config()
    config.setup_config = overwrite_setup_config()
    config.model_config = overwrite_model_config()
    config.dataset_config = overwrite_dataset_config()
    return config
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
{inference → research_files/inference}/load.py RENAMED
File without changes
{inference → research_files/inference}/model_state.pt RENAMED
File without changes
{inference → research_files/inference}/pipeline.py RENAMED
@@ -8,7 +8,7 @@
8
  import numpy as np
9
  import torch
10
  from src.model import SegmentationNetwork
11
- from src.dataset import SegmentationTokenizer, SentenceSegmenter
12
 
13
 
14
  # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
 
8
  import numpy as np
9
  import torch
10
  from src.model import SegmentationNetwork
11
+ from dataset import SegmentationTokenizer, SentenceSegmenter
12
 
13
 
14
  # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
{inference → research_files/inference}/tokenizer_32768.json RENAMED
File without changes
research_files/torch_to_hf.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#      This file was created by: Alberto Palomo Alonso      #
#    Universidad de Alcalá - Escuela Politécnica Superior   #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import os
from research_files.inference import load_model
from safetensors.torch import save_file


def convert_model(save_path: str, model_path: str = None, tokenizer_path: str = None):
    """
    Export the trained model and its tokenizer to Hugging Face format.

    Loads the model via the inference helper, dumps its weights as a
    `model.safetensors` file under `save_path`, and writes the tokenizer
    files alongside it.

    Args:
        save_path (str): Directory where the converted artifacts are written.
        model_path (str, optional): Path of the model checkpoint to load.
            Defaults to None (inference helper's default).
        tokenizer_path (str, optional): Path of the tokenizer to load.
            Defaults to None (inference helper's default).
    """
    loaded_model, loaded_tokenizer, _segmenter = load_model(model_path, tokenizer_path)

    # Serialize the weights to safetensors:
    weights = loaded_model.state_dict()
    target_file = os.path.join(save_path, "model.safetensors")
    save_file(weights, target_file)

    # Persist the tokenizer next to the weights.
    # NOTE(review): reaches into the private `_hf_tokenizer` attribute —
    # consider exposing a public save method on the tokenizer wrapper.
    loaded_tokenizer._hf_tokenizer.save_pretrained(os.path.join(save_path))


if __name__ == "__main__":
    # Convert and save into the current directory:
    convert_model("./")
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
{train → research_files/train}/config.py RENAMED
@@ -7,8 +7,8 @@
7
  # Import statements:
8
  import os
9
  from dataclasses import dataclass
10
- from src.model import ModelConfig, CoSeNetConfig, TransformerConfig
11
- from src.dataset import DatasetConfig
12
 
13
 
14
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
7
  # Import statements:
8
  import os
9
  from dataclasses import dataclass
10
+ from model import ModelConfig, CoSeNetConfig, TransformerConfig
11
+ from dataset import DatasetConfig
12
 
13
 
14
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
{train → research_files/train}/train_logs/config.json RENAMED
File without changes
{train → research_files/train}/train_logs/logfile.log RENAMED
File without changes
{train → research_files/train}/train_logs/tensorboard_logs.zip RENAMED
File without changes
{train → research_files/train}/train_model.py RENAMED
@@ -7,10 +7,10 @@
7
  # Import statements:
8
  import torch
9
  import tqdm
10
- from train.config import configuration, TrainConfig
11
  from src.model import SegmentationNetwork, MaskedBCELoss
12
- from src.dataset import TokenizedSegmentationDataset
13
- from src.dlutils import Setup, train_step, validation_step
14
 
15
 
16
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
7
  # Import statements:
8
  import torch
9
  import tqdm
10
+ from research_files.train.config import configuration, TrainConfig
11
  from src.model import SegmentationNetwork, MaskedBCELoss
12
+ from dataset import TokenizedSegmentationDataset
13
+ from dlutils import Setup, train_step, validation_step
14
 
15
 
16
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
src/dataset/__init__.py CHANGED
@@ -1,13 +1,13 @@
1
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
- # #
3
- # This file was created by: Alberto Palomo Alonso #
4
- # Universidad de Alcalá - Escuela Politécnica Superior #
5
- # #
6
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
- from .tokenizer import SegmentationTokenizer, SentenceSegmenter
8
- from .dataset import SegmentationDataset
9
- from .tokenized_dataset import TokenizedSegmentationDataset
10
- from .config import DatasetConfig
11
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
12
- # END OF FILE #
13
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ from .tokenizer import SegmentationTokenizer, SentenceSegmenter
8
+ from .dataset import SegmentationDataset
9
+ from .tokenized_dataset import TokenizedSegmentationDataset
10
+ from .config import DatasetConfig
11
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
12
+ # END OF FILE #
13
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dataset/config.py CHANGED
@@ -1,29 +1,29 @@
1
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
- # #
3
- # This file was created by: Alberto Palomo Alonso #
4
- # Universidad de Alcalá - Escuela Politécnica Superior #
5
- # #
6
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
- # Import statements:
8
- from dataclasses import dataclass
9
-
10
-
11
- @dataclass
12
- class DatasetConfig:
13
- # Paths:
14
- train_data_path: str = None
15
- val_data_path: str = None
16
- test_data_path: str = None
17
- # Percentages:
18
- train_percentage: float = 1.0
19
- val_percentage: float = 1.0
20
- test_percentage: float = 1.0
21
- # Other parameters:
22
- num_workers: int = 0
23
- shuffle_train: bool = True
24
- shuffle_val: bool = True
25
-
26
-
27
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
28
- # END OF FILE #
29
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ from dataclasses import dataclass
9
+
10
+
11
+ @dataclass
12
+ class DatasetConfig:
13
+ # Paths:
14
+ train_data_path: str = None
15
+ val_data_path: str = None
16
+ test_data_path: str = None
17
+ # Percentages:
18
+ train_percentage: float = 1.0
19
+ val_percentage: float = 1.0
20
+ test_percentage: float = 1.0
21
+ # Other parameters:
22
+ num_workers: int = 0
23
+ shuffle_train: bool = True
24
+ shuffle_val: bool = True
25
+
26
+
27
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
28
+ # END OF FILE #
29
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
src/dataset/dataset.py CHANGED
@@ -1,199 +1,199 @@
1
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
- # #
3
- # This file was created by: Alberto Palomo Alonso #
4
- # Universidad de Alcalá - Escuela Politécnica Superior #
5
- # #
6
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
- # Import statements:
8
- import logging
9
- from torch.utils.data import Dataset, DataLoader
10
- from datasets import Dataset as HfDataset
11
- from datasets import load_from_disk
12
- from .tokenizer import SegmentationTokenizer, SentenceSegmenter
13
-
14
-
15
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
16
- # #
17
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
18
- class SegmentationDataset(Dataset):
19
- def __init__(
20
- self,
21
- huggingface_dataset: str | HfDataset,
22
- tokenizer: SegmentationTokenizer,
23
- segmenter: SentenceSegmenter,
24
- logger: logging.Logger = None,
25
- percentage: float = 1.0,
26
- return_type: type = dict
27
- ):
28
- """
29
- A segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
30
- wikipedia-segmentation format. It loads the dataset and prepares it for training.
31
-
32
- Wikipedia-segmentation format:
33
- - The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
34
- - The dataset should contain the following fields:
35
- >>> sample = {
36
- >>> 'text': ['Article 1', 'Article 2', ...],
37
- >>> 'titles': ['Title 1', 'Title 2', ...],
38
- >>> 'id': str,
39
- >>> 'words': int
40
- >>> 'paragraphs': int
41
- >>> 'sentences': int
42
- >>> }
43
- - The dataset should be a list of dictionaries, where each dictionary contains the fields above.
44
-
45
- Parameters
46
- ----------
47
- huggingface_dataset : str | HfDataset
48
- A huggingface dataset or a path to a dataset on disk with the wikipedia-segmentation format.
49
-
50
- tokenizer : callable
51
- A tokenizer function that takes a string and returns a list of tokens.
52
-
53
- logger : logging.Logger, optional
54
- Logger instance. If not provided, a null logger will be used.
55
-
56
- percentage : float
57
- Percentage of the dataset to use. Default is 1.0 (100%).
58
-
59
- return_type : type
60
- The return type of __getitem__, either dict or tuple. Default is dict.
61
-
62
- Raises
63
- ------
64
- ValueError
65
- If the huggingface_dataset is not a string or a HfDataset.
66
- ValueError
67
- If the tokenizer is not a callable function or class.
68
- ValueError
69
- If the sentence_tokenizer is not a callable function or class.
70
- ValueError
71
- If the dtype is not a type.
72
-
73
- """
74
- # Null logging:
75
- if not isinstance(logger, logging.Logger):
76
- self.logger = logging.getLogger("null")
77
- self.logger.addHandler(logging.NullHandler())
78
- else:
79
- self.logger = logger
80
-
81
- # Loading:
82
- if isinstance(huggingface_dataset, HfDataset):
83
- self.huggingface_dataset = huggingface_dataset
84
- elif isinstance(huggingface_dataset, str):
85
- self.huggingface_dataset = load_from_disk(huggingface_dataset)
86
- else:
87
- self.logger.error(f'[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
88
- raise ValueError(f'[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
89
- self.logger.info(f'[SegmentationDataset] Loaded dataset: {self.huggingface_dataset}')
90
- self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.huggingface_dataset.num_rows}')
91
-
92
- # Tokenizer:
93
- if callable(tokenizer):
94
- self.tokenizer = tokenizer
95
- else:
96
- self.logger.error(f'[SegmentationDataset] Tokenizer must be a callable function.')
97
- raise ValueError(f'[SegmentationDataset] Tokenizer must be a callable function.')
98
-
99
- # Segmenter:
100
- if not isinstance(segmenter, SentenceSegmenter):
101
- self.logger.error(f'[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
102
- raise ValueError(f'[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
103
- else:
104
- self.segmenter = segmenter
105
-
106
- # Percentage:
107
- if not (0.0 < percentage <= 1.0):
108
- self.logger.error(f'[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
109
- raise ValueError(f'[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
110
- else:
111
- self.percentage = percentage
112
-
113
- # Return type:
114
- if not isinstance(return_type, type):
115
- self.logger.error(f'[SegmentationDataset] return_type must be a type.')
116
- raise ValueError(f'[SegmentationDataset] return_type must be a type.')
117
- elif return_type not in [dict, tuple]:
118
- self.logger.error(f'[SegmentationDataset] return_type must be either dict or tuple.')
119
- raise ValueError(f'[SegmentationDataset] return_type must be either dict or tuple.')
120
- else:
121
- self.return_type = return_type
122
-
123
- def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
124
- """
125
- Returns a PyTorch DataLoader for this dataset.
126
-
127
- Parameters
128
- ----------
129
- batch_size : int
130
- Number of samples per batch.
131
- shuffle : bool
132
- Whether to shuffle the dataset.
133
- num_workers : int
134
- Number of worker processes.
135
- **kwargs
136
- Additional arguments for DataLoader.
137
-
138
- Returns
139
- -------
140
- [torch.utils.data.DataLoader
141
- Configured DataLoader.
142
- """
143
- # Size handling:
144
- return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
145
- pin_memory=True, **kwargs)
146
-
147
- def __len__(self) -> int:
148
- """
149
- Returns the number of samples in the dataset.
150
-
151
- Returns
152
- -------
153
- int
154
- Total number of samples.
155
- """
156
- return int(self.huggingface_dataset.num_rows * self.percentage)
157
-
158
- def __getitem__(self, idx) -> dict | tuple:
159
- """
160
- Retrieves a single sample and generates segmentation labels.
161
-
162
- Parameters
163
- ----------
164
- idx : int
165
- Index of the sample.
166
-
167
- Returns
168
- -------
169
- tuple
170
- A tuple or dict (x_i, y_i, mask_x) with noisy input and corresponding target.
171
- """
172
- sample = self.huggingface_dataset[idx]['text']
173
- sentences = self.segmenter(sample)
174
- tokenized = self.tokenizer(sentences['sentences'])
175
-
176
- if self.return_type == tuple:
177
- return (
178
- tokenized['input_ids'], # x
179
- sentences['sentence_boundaries'], # y
180
- tokenized['attention_mask'], # x_mask
181
- sentences['sentence_mask'], # y_mask
182
- sentences['sentence_candidates'], # y_prime_mask
183
- )
184
- elif self.return_type == dict:
185
- return_value = {
186
- 'input': tokenized['input_ids'],
187
- 'input_mask': tokenized['attention_mask'],
188
- 'labels': sentences['sentence_boundaries'],
189
- 'output_mask': sentences['sentence_mask'],
190
- 'candidate_mask': sentences['sentence_candidates']
191
- }
192
- else:
193
- raise ValueError(f'[SegmentationDataset] return_type must be either dict or tuple.')
194
- return return_value
195
-
196
-
197
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
198
- # END OF FILE #
199
- # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ # Import statements:
8
+ import logging
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from datasets import Dataset as HfDataset
11
+ from datasets import load_from_disk
12
+ from .tokenizer import SegmentationTokenizer, SentenceSegmenter
13
+
14
+
15
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
16
+ # #
17
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
18
class SegmentationDataset(Dataset):
    def __init__(
            self,
            huggingface_dataset: str | HfDataset,
            tokenizer: SegmentationTokenizer,
            segmenter: SentenceSegmenter,
            logger: logging.Logger | None = None,
            percentage: float = 1.0,
            return_type: type = dict
    ):
        """
        A segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
        wikipedia-segmentation format. It loads the dataset and prepares it for training.

        Wikipedia-segmentation format:
        - The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
        - The dataset should contain the following fields:
            >>> sample = {
            >>>     'text': ['Article 1', 'Article 2', ...],
            >>>     'titles': ['Title 1', 'Title 2', ...],
            >>>     'id': str,
            >>>     'words': int
            >>>     'paragraphs': int
            >>>     'sentences': int
            >>> }
        - The dataset should be a list of dictionaries, where each dictionary contains the fields above.

        Parameters
        ----------
        huggingface_dataset : str | HfDataset
            A huggingface dataset or a path to a dataset on disk with the wikipedia-segmentation format.

        tokenizer : SegmentationTokenizer
            A callable that takes a list of sentences and returns a mapping with
            'input_ids' and 'attention_mask'.

        segmenter : SentenceSegmenter
            Splits an article into sentences and produces boundary / mask / candidate arrays.

        logger : logging.Logger, optional
            Logger instance. If not provided, a null logger will be used.

        percentage : float
            Percentage of the dataset to use, in (0.0, 1.0]. Default is 1.0 (100%).

        return_type : type
            The return type of __getitem__, either dict or tuple. Default is dict.

        Raises
        ------
        ValueError
            If huggingface_dataset is not a string or a HfDataset, tokenizer is not callable,
            segmenter is not a SentenceSegmenter, percentage is outside (0.0, 1.0], or
            return_type is neither dict nor tuple.
        """
        # Fall back to a null logger so later logging calls are always safe:
        if isinstance(logger, logging.Logger):
            self.logger = logger
        else:
            self.logger = logging.getLogger("null")
            self.logger.addHandler(logging.NullHandler())

        # Loading: accept an in-memory dataset, or a path that is loaded from disk.
        if isinstance(huggingface_dataset, HfDataset):
            self.huggingface_dataset = huggingface_dataset
        elif isinstance(huggingface_dataset, str):
            self.huggingface_dataset = load_from_disk(huggingface_dataset)
        else:
            # Log and raise the same message (previously duplicated literals):
            message = '[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.'
            self.logger.error(message)
            raise ValueError(message)
        self.logger.info(f'[SegmentationDataset] Loaded dataset: {self.huggingface_dataset}')
        self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.huggingface_dataset.num_rows}')

        # Tokenizer must be callable:
        if callable(tokenizer):
            self.tokenizer = tokenizer
        else:
            message = '[SegmentationDataset] Tokenizer must be a callable function.'
            self.logger.error(message)
            raise ValueError(message)

        # Segmenter must be a SentenceSegmenter instance:
        if isinstance(segmenter, SentenceSegmenter):
            self.segmenter = segmenter
        else:
            message = '[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.'
            self.logger.error(message)
            raise ValueError(message)

        # Percentage must lie in (0.0, 1.0]:
        if 0.0 < percentage <= 1.0:
            self.percentage = percentage
        else:
            message = '[SegmentationDataset] Percentage must be between 0.0 and 1.0.'
            self.logger.error(message)
            raise ValueError(message)

        # Return type must be exactly dict or tuple:
        if not isinstance(return_type, type):
            message = '[SegmentationDataset] return_type must be a type.'
            self.logger.error(message)
            raise ValueError(message)
        elif return_type not in (dict, tuple):
            message = '[SegmentationDataset] return_type must be either dict or tuple.'
            self.logger.error(message)
            raise ValueError(message)
        else:
            self.return_type = return_type

    def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
        """
        Returns a PyTorch DataLoader for this dataset.

        Parameters
        ----------
        batch_size : int
            Number of samples per batch.
        shuffle : bool
            Whether to shuffle the dataset.
        num_workers : int
            Number of worker processes.
        **kwargs
            Additional arguments for DataLoader.

        Returns
        -------
        torch.utils.data.DataLoader
            Configured DataLoader (pin_memory is always enabled).
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                          pin_memory=True, **kwargs)

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset, scaled by `percentage`.

        Returns
        -------
        int
            Total number of usable samples.
        """
        return int(self.huggingface_dataset.num_rows * self.percentage)

    def __getitem__(self, idx: int) -> dict | tuple:
        """
        Retrieves a single article, splits it into sentences and tokenizes them.

        Parameters
        ----------
        idx : int
            Index of the sample.

        Returns
        -------
        dict | tuple
            Tokenized input ids and attention mask together with the sentence
            boundary labels, sentence mask and candidate mask produced by the
            segmenter — as a dict or a 5-tuple depending on `return_type`.
        """
        sample = self.huggingface_dataset[idx]['text']
        sentences = self.segmenter(sample)
        tokenized = self.tokenizer(sentences['sentences'])

        if self.return_type is tuple:
            return (
                tokenized['input_ids'],              # x
                sentences['sentence_boundaries'],    # y
                tokenized['attention_mask'],         # x_mask
                sentences['sentence_mask'],          # y_mask
                sentences['sentence_candidates'],    # y_prime_mask
            )
        if self.return_type is dict:
            return {
                'input': tokenized['input_ids'],
                'input_mask': tokenized['attention_mask'],
                'labels': sentences['sentence_boundaries'],
                'output_mask': sentences['sentence_mask'],
                'candidate_mask': sentences['sentence_candidates']
            }
        raise ValueError(f'[SegmentationDataset] return_type must be either dict or tuple.')


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                       END OF FILE                         #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #