alverciito
committed on
Commit
·
dbd79bd
1
Parent(s):
4c7684b
upload safetensors and refactor research files
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- config.json +16 -0
- configurations.py +0 -0
- model.py +176 -0
- model.safetensors +3 -0
- requirements.txt +1 -0
- bench.py → research_files/bench.py +0 -0
- research_files/benchmark/results/binseg_bert-base-multilingual-cased.json +0 -0
- {benchmark → research_files/benchmark}/results/binseg_paraphrase-multilingual-MiniLM-L12-v2.json +0 -0
- {benchmark → research_files/benchmark}/results/binseg_sentence_similarity_spanish_es.json +0 -0
- research_files/benchmark/results/csim_bert-base-multilingual-cased.json +0 -0
- {benchmark → research_files/benchmark}/results/csim_paraphrase-multilingual-MiniLM-L12-v2.json +0 -0
- {benchmark → research_files/benchmark}/results/csim_sentence_similarity_spanish_es.json +0 -0
- research_files/benchmark/results/pelt_LaBSE.json +0 -0
- {benchmark → research_files/benchmark}/results/pelt_bert-base-multilingual-cased.json +0 -0
- {benchmark → research_files/benchmark}/results/pelt_paraphrase-multilingual-MiniLM-L12-v2.json +0 -0
- {benchmark → research_files/benchmark}/results/pelt_sentence_similarity_spanish_es.json +0 -0
- {benchmark → research_files/benchmark}/results/proposed_method.json +0 -0
- {benchmark → research_files/benchmark}/results/textile_baseline.json +0 -0
- {benchmark → research_files/benchmark}/segmentation_benchmark/__init__.py +0 -0
- {benchmark → research_files/benchmark}/segmentation_benchmark/heuristic.py +0 -0
- {benchmark → research_files/benchmark}/segmentation_benchmark/load_dataset.py +0 -0
- {benchmark → research_files/benchmark}/segmentation_benchmark/metrics.py +0 -0
- {benchmark → research_files/benchmark}/segmentation_benchmark/proposed.py +2 -2
- {benchmark → research_files/benchmark}/segmentation_benchmark/transformers.py +1 -1
- {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_1.json +0 -0
- {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_2.json +0 -0
- {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_3.json +0 -0
- {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_4.json +0 -0
- {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_5.json +0 -0
- {benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_threshold.py +2 -2
- {benchmark → research_files/benchmark}/thresholding_benchmark/print_results.py +0 -0
- {benchmark → research_files/benchmark}/wikipedia-es-A002/data-00000-of-00001.arrow +0 -0
- {benchmark → research_files/benchmark}/wikipedia-es-A002/dataset_info.json +0 -0
- {benchmark → research_files/benchmark}/wikipedia-es-A002/state.json +0 -0
- {inference → research_files/inference}/__init__.py +0 -0
- {inference → research_files/inference}/config.py +181 -181
- {inference → research_files/inference}/load.py +0 -0
- {inference → research_files/inference}/model_state.pt +0 -0
- {inference → research_files/inference}/pipeline.py +1 -1
- {inference → research_files/inference}/tokenizer_32768.json +0 -0
- research_files/torch_to_hf.py +27 -0
- {train → research_files/train}/config.py +2 -2
- {train → research_files/train}/train_logs/config.json +0 -0
- {train → research_files/train}/train_logs/logfile.log +0 -0
- {train → research_files/train}/train_logs/tensorboard_logs.zip +0 -0
- {train → research_files/train}/train_model.py +3 -3
- special_tokens_map.json +7 -0
- src/dataset/__init__.py +13 -13
- src/dataset/config.py +29 -29
- src/dataset/dataset.py +199 -199
config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"CoseNetTransformer"
|
| 4 |
+
],
|
| 5 |
+
"dropout": 0.0,
|
| 6 |
+
"emb_dim": 256,
|
| 7 |
+
"model_type": "sentence_transformer",
|
| 8 |
+
"seq_len": ...,
|
| 9 |
+
"torch_dtype": "float32",
|
| 10 |
+
"transformers_version": "4.57.3",
|
| 11 |
+
"vocab_size": 32768,
|
| 12 |
+
"auto_map": {
|
| 13 |
+
"AutoConfig": "configurations.SentenceCoseNetConfig",
|
| 14 |
+
"AutoModel": "model.SentenceCoseNet"
|
| 15 |
+
}
|
| 16 |
+
}
|
configurations.py
ADDED
|
File without changes
|
model.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import torch
|
| 9 |
+
from src.model.config import ModelConfig
|
| 10 |
+
from src.model.cosenet import CosineDistanceLayer, CoSeNet
|
| 11 |
+
from src.model.transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CoseNetTransformer(torch.nn.Module):
|
| 15 |
+
"""
|
| 16 |
+
Segmentation network combining Transformer encoders with CoSeNet.
|
| 17 |
+
|
| 18 |
+
This model integrates token embeddings and positional encodings with
|
| 19 |
+
a stack of Transformer encoder blocks to produce contextualized
|
| 20 |
+
representations. These representations are then processed by a
|
| 21 |
+
CoSeNet module to perform structured segmentation, followed by a
|
| 22 |
+
cosine-based distance computation.
|
| 23 |
+
|
| 24 |
+
The final output is a pair-wise distance matrix suitable for
|
| 25 |
+
segmentation or boundary detection tasks.
|
| 26 |
+
"""
|
| 27 |
+
def __init__(self, model_config: ModelConfig, **kwargs):
|
| 28 |
+
"""
|
| 29 |
+
Initialize the segmentation network.
|
| 30 |
+
|
| 31 |
+
The network is composed of an embedding layer, positional encoding,
|
| 32 |
+
multiple Transformer encoder blocks, a CoSeNet segmentation module,
|
| 33 |
+
and a cosine distance layer.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
model_config (ModelConfig): Configuration object containing all
|
| 37 |
+
hyperparameters required to build the model, including
|
| 38 |
+
vocabulary size, model dimensionality, transformer settings,
|
| 39 |
+
and CoSeNet parameters.
|
| 40 |
+
**kwargs: Additional keyword arguments forwarded to
|
| 41 |
+
`torch.nn.Module`.
|
| 42 |
+
"""
|
| 43 |
+
super().__init__(**kwargs)
|
| 44 |
+
self.valid_padding = model_config.valid_padding
|
| 45 |
+
|
| 46 |
+
# Build layers:
|
| 47 |
+
self.embedding = torch.nn.Embedding(
|
| 48 |
+
model_config.vocab_size,
|
| 49 |
+
model_config.model_dim
|
| 50 |
+
)
|
| 51 |
+
self.positional_encoding = PositionalEncoding(
|
| 52 |
+
emb_dim=model_config.model_dim,
|
| 53 |
+
max_len=model_config.max_tokens
|
| 54 |
+
)
|
| 55 |
+
self.cosenet = CoSeNet(
|
| 56 |
+
trainable=model_config.cosenet.trainable,
|
| 57 |
+
init_scale=model_config.cosenet.init_scale
|
| 58 |
+
)
|
| 59 |
+
self.distance_layer = CosineDistanceLayer()
|
| 60 |
+
self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)
|
| 61 |
+
|
| 62 |
+
# Build encoder blocks:
|
| 63 |
+
module_list = list()
|
| 64 |
+
for transformer_config in model_config.transformers:
|
| 65 |
+
encoder_block = EncoderBlock(
|
| 66 |
+
feature_dim=model_config.model_dim,
|
| 67 |
+
attention_heads=transformer_config.attention_heads,
|
| 68 |
+
feed_forward_multiplier=transformer_config.feed_forward_multiplier,
|
| 69 |
+
dropout=transformer_config.dropout,
|
| 70 |
+
valid_padding=model_config.valid_padding,
|
| 71 |
+
pre_normalize=transformer_config.pre_normalize
|
| 72 |
+
)
|
| 73 |
+
module_list.append(encoder_block)
|
| 74 |
+
|
| 75 |
+
self.encoder_blocks = torch.nn.ModuleList(module_list)
|
| 76 |
+
|
| 77 |
+
def encode(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
| 78 |
+
"""
|
| 79 |
+
Encode input sequences into contextualized representations.
|
| 80 |
+
The input token indices are embedded and enriched with positional
|
| 81 |
+
information, then processed by a stack of Transformer encoder
|
| 82 |
+
blocks.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
x (torch.Tensor): Input tensor of token indices with shape
|
| 86 |
+
(batch_size, max_tokens).
|
| 87 |
+
mask (torch.Tensor, optional): Optional mask tensor indicating
|
| 88 |
+
valid or padded positions, depending on the configuration
|
| 89 |
+
of the Transformer blocks. Defaults to None. Dimensions should be
|
| 90 |
+
(batch_size, max_tokens).
|
| 91 |
+
"""
|
| 92 |
+
# Convert to type:
|
| 93 |
+
x = x.int()
|
| 94 |
+
# Embedding and positional encoding:
|
| 95 |
+
x = self.embedding(x)
|
| 96 |
+
x = self.positional_encoding(x)
|
| 97 |
+
# Check mask inversion:
|
| 98 |
+
if mask[0, 0] == 0:
|
| 99 |
+
mask = torch.logical_not(mask)
|
| 100 |
+
# Encode:
|
| 101 |
+
for encoder in self.encoder_blocks:
|
| 102 |
+
x = encoder(x, mask=mask)
|
| 103 |
+
return x
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
|
| 107 |
+
"""
|
| 108 |
+
Forward pass of the segmentation network.
|
| 109 |
+
|
| 110 |
+
The input token indices are embedded and enriched with positional
|
| 111 |
+
information, then processed by a stack of Transformer encoder
|
| 112 |
+
blocks. The resulting representations are segmented using CoSeNet
|
| 113 |
+
and finally transformed into a pair-wise distance representation.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
x (torch.Tensor): Input tensor of token indices with shape
|
| 117 |
+
(batch_size, sequence_length).
|
| 118 |
+
mask (torch.Tensor, optional): Optional mask tensor indicating
|
| 119 |
+
valid or padded positions, depending on the configuration
|
| 120 |
+
of the Transformer blocks. Defaults to None.
|
| 121 |
+
|
| 122 |
+
If `valid_padding` is disabled, the mask is inverted before being
|
| 123 |
+
passed to CoSeNet to match its masking convention.
|
| 124 |
+
|
| 125 |
+
candidate_mask (torch.Tensor, optional): Optional mask tensor for
|
| 126 |
+
candidate positions in CoSeNet. Defaults to None.
|
| 127 |
+
|
| 128 |
+
If `valid_padding` is disabled, the mask is inverted before being
|
| 129 |
+
passed to CoSeNet to match its masking convention.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
torch.Tensor: Output tensor containing pairwise distance values
|
| 133 |
+
derived from the segmented representations.
|
| 134 |
+
"""
|
| 135 |
+
# Convert to type:
|
| 136 |
+
x = x.int()
|
| 137 |
+
|
| 138 |
+
# Embedding and positional encoding:
|
| 139 |
+
x = self.embedding(x)
|
| 140 |
+
x = self.positional_encoding(x)
|
| 141 |
+
|
| 142 |
+
# Reshape x and mask:
|
| 143 |
+
_b, _s, _t, _d = x.shape
|
| 144 |
+
x = x.reshape(_b * _s, _t, _d)
|
| 145 |
+
if mask is not None:
|
| 146 |
+
mask = mask.reshape(_b * _s, _t).bool()
|
| 147 |
+
|
| 148 |
+
# Encode the sequence:
|
| 149 |
+
for encoder in self.encoder_blocks:
|
| 150 |
+
x = encoder(x, mask=mask)
|
| 151 |
+
|
| 152 |
+
# Reshape x and mask:
|
| 153 |
+
x = x.reshape(_b, _s, _t, _d)
|
| 154 |
+
if mask is not None:
|
| 155 |
+
mask = mask.reshape(_b, _s, _t)
|
| 156 |
+
mask = torch.logical_not(mask) if not self.valid_padding else mask
|
| 157 |
+
|
| 158 |
+
# Apply pooling:
|
| 159 |
+
x, mask = self.pooling(x, mask=mask)
|
| 160 |
+
|
| 161 |
+
# Compute distances:
|
| 162 |
+
x = self.distance_layer(x)
|
| 163 |
+
|
| 164 |
+
# Pass through CoSeNet:
|
| 165 |
+
x = self.cosenet(x, mask=mask)
|
| 166 |
+
|
| 167 |
+
# Apply candidate mask if provided:
|
| 168 |
+
if candidate_mask is not None:
|
| 169 |
+
candidate_mask = candidate_mask.bool() if not self.valid_padding else torch.logical_not(candidate_mask.bool())
|
| 170 |
+
candidate_mask = candidate_mask.to(device=x.device)
|
| 171 |
+
x = x.masked_fill(candidate_mask, 0)
|
| 172 |
+
|
| 173 |
+
return x
|
| 174 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 175 |
+
# END OF FILE #
|
| 176 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6db78280c80f27b94434a1d1e17296ecddc1d21705ec6be3b8bd0bc49991f27f
|
| 3 |
+
size 44485604
|
requirements.txt
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
ruptures
|
| 2 |
sentence-transformers
|
| 3 |
numpy==2.3.5
|
|
|
|
| 1 |
+
safetensors
|
| 2 |
ruptures
|
| 3 |
sentence-transformers
|
| 4 |
numpy==2.3.5
|
bench.py → research_files/bench.py
RENAMED
|
File without changes
|
research_files/benchmark/results/binseg_bert-base-multilingual-cased.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
{benchmark → research_files/benchmark}/results/binseg_paraphrase-multilingual-MiniLM-L12-v2.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/results/binseg_sentence_similarity_spanish_es.json
RENAMED
|
File without changes
|
research_files/benchmark/results/csim_bert-base-multilingual-cased.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
{benchmark → research_files/benchmark}/results/csim_paraphrase-multilingual-MiniLM-L12-v2.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/results/csim_sentence_similarity_spanish_es.json
RENAMED
|
File without changes
|
research_files/benchmark/results/pelt_LaBSE.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
{benchmark → research_files/benchmark}/results/pelt_bert-base-multilingual-cased.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/results/pelt_paraphrase-multilingual-MiniLM-L12-v2.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/results/pelt_sentence_similarity_spanish_es.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/results/proposed_method.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/results/textile_baseline.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/segmentation_benchmark/__init__.py
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/segmentation_benchmark/heuristic.py
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/segmentation_benchmark/load_dataset.py
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/segmentation_benchmark/metrics.py
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/segmentation_benchmark/proposed.py
RENAMED
|
@@ -9,10 +9,10 @@ import os
|
|
| 9 |
import json
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
-
from
|
| 13 |
from .metrics import precision_recall_f1_wd
|
| 14 |
from .load_dataset import load_dataset
|
| 15 |
-
from inference import load_model
|
| 16 |
|
| 17 |
|
| 18 |
def evaluate_proposed(
|
|
|
|
| 9 |
import json
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
from .metrics import precision_recall_f1_wd
|
| 14 |
from .load_dataset import load_dataset
|
| 15 |
+
from research_files.inference import load_model
|
| 16 |
|
| 17 |
|
| 18 |
def evaluate_proposed(
|
{benchmark → research_files/benchmark}/segmentation_benchmark/transformers.py
RENAMED
|
@@ -9,7 +9,7 @@ import os
|
|
| 9 |
import json
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
-
from
|
| 13 |
from .metrics import precision_recall_f1_wd
|
| 14 |
from .load_dataset import load_dataset
|
| 15 |
from sentence_transformers import SentenceTransformer
|
|
|
|
| 9 |
import json
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
from .metrics import precision_recall_f1_wd
|
| 14 |
from .load_dataset import load_dataset
|
| 15 |
from sentence_transformers import SentenceTransformer
|
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_1.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_2.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_3.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_4.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_result_A001_5.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/thresholding_benchmark/benchmark_threshold.py
RENAMED
|
@@ -10,8 +10,8 @@ import tqdm
|
|
| 10 |
import json
|
| 11 |
from datasets import load_from_disk
|
| 12 |
from src.model import SegmentationNetwork, MaskedBCELoss, WindowDiffLoss
|
| 13 |
-
from
|
| 14 |
-
from train.config import configuration
|
| 15 |
|
| 16 |
|
| 17 |
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
|
|
|
| 10 |
import json
|
| 11 |
from datasets import load_from_disk
|
| 12 |
from src.model import SegmentationNetwork, MaskedBCELoss, WindowDiffLoss
|
| 13 |
+
from dataset import SegmentationTokenizer, SentenceSegmenter, TokenizedSegmentationDataset
|
| 14 |
+
from research_files.train.config import configuration
|
| 15 |
|
| 16 |
|
| 17 |
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
{benchmark → research_files/benchmark}/thresholding_benchmark/print_results.py
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/wikipedia-es-A002/data-00000-of-00001.arrow
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/wikipedia-es-A002/dataset_info.json
RENAMED
|
File without changes
|
{benchmark → research_files/benchmark}/wikipedia-es-A002/state.json
RENAMED
|
File without changes
|
{inference → research_files/inference}/__init__.py
RENAMED
|
File without changes
|
{inference → research_files/inference}/config.py
RENAMED
|
@@ -1,181 +1,181 @@
|
|
| 1 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
-
# #
|
| 3 |
-
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
-
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
-
# #
|
| 6 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
-
# Import statements:
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
-
from src.model import ModelConfig, CoSeNetConfig, TransformerConfig
|
| 10 |
-
from src.dataset import DatasetConfig
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 14 |
-
# SETUP CONFIGURATION #
|
| 15 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 16 |
-
@dataclass
|
| 17 |
-
class SetupConfig:
|
| 18 |
-
"""
|
| 19 |
-
Configuration parameters related to the execution environment and logging.
|
| 20 |
-
|
| 21 |
-
This configuration controls device selection, checkpointing behavior,
|
| 22 |
-
reproducibility settings, and logging paths for an experiment.
|
| 23 |
-
"""
|
| 24 |
-
device_number: int = 0
|
| 25 |
-
save_model_each: int = 0
|
| 26 |
-
seed: int = None
|
| 27 |
-
logging_path: str = None
|
| 28 |
-
reload_checkpoint: bool = False
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def overwrite_setup_config() -> SetupConfig:
|
| 32 |
-
"""
|
| 33 |
-
Create and override the default setup configuration.
|
| 34 |
-
|
| 35 |
-
This function customizes execution-level parameters such as logging
|
| 36 |
-
paths, checkpoint reloading, and model saving frequency.
|
| 37 |
-
|
| 38 |
-
Returns:
|
| 39 |
-
SetupConfig: The configured setup configuration object.
|
| 40 |
-
"""
|
| 41 |
-
config = SetupConfig()
|
| 42 |
-
config.logging_path = r'/workspace/logs'
|
| 43 |
-
config.reload_checkpoint = True
|
| 44 |
-
config.save_model_each = 1
|
| 45 |
-
return config
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 49 |
-
# TRAINING CONFIGURATION #
|
| 50 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 51 |
-
@dataclass
|
| 52 |
-
class TrainConfig:
|
| 53 |
-
"""
|
| 54 |
-
Training configuration container.
|
| 55 |
-
|
| 56 |
-
This dataclass aggregates model, dataset, and setup configurations,
|
| 57 |
-
together with optimization and training hyperparameters.
|
| 58 |
-
"""
|
| 59 |
-
# Linked configurations:
|
| 60 |
-
model_config: ModelConfig | None = None
|
| 61 |
-
dataset_config: DatasetConfig | None = None
|
| 62 |
-
setup_config: SetupConfig | None = None
|
| 63 |
-
|
| 64 |
-
# Training parameters:
|
| 65 |
-
batch_size: int = 32
|
| 66 |
-
num_epochs: int = 100
|
| 67 |
-
|
| 68 |
-
# Optimizer parameters:
|
| 69 |
-
learning_rate: float = 1e-4
|
| 70 |
-
learning_rate_min: float = 1e-5
|
| 71 |
-
weight_decay: float = 1e-8
|
| 72 |
-
betas: tuple[float, float] = (0.5, 0.999)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def overwrite_train_config() -> TrainConfig:
|
| 76 |
-
"""
|
| 77 |
-
Create and override the default training configuration.
|
| 78 |
-
|
| 79 |
-
This function customizes batch size, number of epochs, and optimizer
|
| 80 |
-
hyperparameters for the training process.
|
| 81 |
-
|
| 82 |
-
Returns:
|
| 83 |
-
TrainConfig: The configured training configuration object.
|
| 84 |
-
"""
|
| 85 |
-
config = TrainConfig()
|
| 86 |
-
config.batch_size = 4
|
| 87 |
-
config.num_epochs = 200
|
| 88 |
-
config.learning_rate = 5e-4
|
| 89 |
-
config.learning_rate_min = 5e-5
|
| 90 |
-
config.weight_decay = 1e-6
|
| 91 |
-
return config
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 95 |
-
# DATASET CONFIGURATION #
|
| 96 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 97 |
-
def overwrite_dataset_config() -> DatasetConfig:
|
| 98 |
-
"""
|
| 99 |
-
Create and override the dataset configuration.
|
| 100 |
-
|
| 101 |
-
This function sets the file paths and usage percentages for training,
|
| 102 |
-
validation, and test datasets.
|
| 103 |
-
|
| 104 |
-
Returns:
|
| 105 |
-
DatasetConfig: The configured dataset configuration object.
|
| 106 |
-
"""
|
| 107 |
-
config = DatasetConfig()
|
| 108 |
-
config.train_data_path = r"/workspace/data/tokens-A000-segmentation"
|
| 109 |
-
config.val_data_path = r"/workspace/data/tokens-A001-segmentation"
|
| 110 |
-
config.test_data_path = r"/workspace/data/tokens-A002-segmentation"
|
| 111 |
-
config.train_percentage = 0.4
|
| 112 |
-
config.val_percentage = 0.4
|
| 113 |
-
config.test_percentage = 1.0
|
| 114 |
-
return config
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 118 |
-
# MODEL CONFIGURATION #
|
| 119 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 120 |
-
def overwrite_model_config() -> ModelConfig:
|
| 121 |
-
"""
|
| 122 |
-
Create and override the model configuration.
|
| 123 |
-
|
| 124 |
-
This function defines the architecture-level parameters, including
|
| 125 |
-
vocabulary size, embedding dimensionality, CoSeNet settings, and
|
| 126 |
-
the stack of Transformer encoder configurations.
|
| 127 |
-
|
| 128 |
-
Returns:
|
| 129 |
-
ModelConfig: The configured model configuration object.
|
| 130 |
-
"""
|
| 131 |
-
config = ModelConfig()
|
| 132 |
-
|
| 133 |
-
# High-level params:
|
| 134 |
-
config.vocab_size = 32_768
|
| 135 |
-
config.model_dim = 256
|
| 136 |
-
config.valid_padding = True
|
| 137 |
-
|
| 138 |
-
# CoSeNet params:
|
| 139 |
-
config.cosenet = CoSeNetConfig(
|
| 140 |
-
trainable=True,
|
| 141 |
-
init_scale=5.0
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
# Transformer params:
|
| 145 |
-
config.transformers = [
|
| 146 |
-
TransformerConfig(**cfg)
|
| 147 |
-
for cfg in [
|
| 148 |
-
{
|
| 149 |
-
"attention_heads": 16,
|
| 150 |
-
"feed_forward_multiplier": 8,
|
| 151 |
-
"dropout": 0.0,
|
| 152 |
-
"pre_normalize": True
|
| 153 |
-
},
|
| 154 |
-
{
|
| 155 |
-
"attention_heads": 16,
|
| 156 |
-
"feed_forward_multiplier": 8,
|
| 157 |
-
"dropout": 0.0,
|
| 158 |
-
"pre_normalize": True
|
| 159 |
-
}
|
| 160 |
-
]
|
| 161 |
-
]
|
| 162 |
-
|
| 163 |
-
return config
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 167 |
-
# WHOLE CONFIGURATION #
|
| 168 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 169 |
-
def configuration() -> TrainConfig:
|
| 170 |
-
"""
|
| 171 |
-
Create the experiment configuration
|
| 172 |
-
:return: A TrainConfig configuration object
|
| 173 |
-
"""
|
| 174 |
-
config = overwrite_train_config()
|
| 175 |
-
config.setup_config = overwrite_setup_config()
|
| 176 |
-
config.model_config = overwrite_model_config()
|
| 177 |
-
config.dataset_config = overwrite_dataset_config()
|
| 178 |
-
return config
|
| 179 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 180 |
-
# END OF FILE #
|
| 181 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from src.model import ModelConfig, CoSeNetConfig, TransformerConfig
|
| 10 |
+
from src.dataset import DatasetConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 14 |
+
# SETUP CONFIGURATION #
|
| 15 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 16 |
+
@dataclass
|
| 17 |
+
class SetupConfig:
|
| 18 |
+
"""
|
| 19 |
+
Configuration parameters related to the execution environment and logging.
|
| 20 |
+
|
| 21 |
+
This configuration controls device selection, checkpointing behavior,
|
| 22 |
+
reproducibility settings, and logging paths for an experiment.
|
| 23 |
+
"""
|
| 24 |
+
device_number: int = 0
|
| 25 |
+
save_model_each: int = 0
|
| 26 |
+
seed: int = None
|
| 27 |
+
logging_path: str = None
|
| 28 |
+
reload_checkpoint: bool = False
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def overwrite_setup_config() -> SetupConfig:
|
| 32 |
+
"""
|
| 33 |
+
Create and override the default setup configuration.
|
| 34 |
+
|
| 35 |
+
This function customizes execution-level parameters such as logging
|
| 36 |
+
paths, checkpoint reloading, and model saving frequency.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
SetupConfig: The configured setup configuration object.
|
| 40 |
+
"""
|
| 41 |
+
config = SetupConfig()
|
| 42 |
+
config.logging_path = r'/workspace/logs'
|
| 43 |
+
config.reload_checkpoint = True
|
| 44 |
+
config.save_model_each = 1
|
| 45 |
+
return config
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 49 |
+
# TRAINING CONFIGURATION #
|
| 50 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 51 |
+
@dataclass
|
| 52 |
+
class TrainConfig:
|
| 53 |
+
"""
|
| 54 |
+
Training configuration container.
|
| 55 |
+
|
| 56 |
+
This dataclass aggregates model, dataset, and setup configurations,
|
| 57 |
+
together with optimization and training hyperparameters.
|
| 58 |
+
"""
|
| 59 |
+
# Linked configurations:
|
| 60 |
+
model_config: ModelConfig | None = None
|
| 61 |
+
dataset_config: DatasetConfig | None = None
|
| 62 |
+
setup_config: SetupConfig | None = None
|
| 63 |
+
|
| 64 |
+
# Training parameters:
|
| 65 |
+
batch_size: int = 32
|
| 66 |
+
num_epochs: int = 100
|
| 67 |
+
|
| 68 |
+
# Optimizer parameters:
|
| 69 |
+
learning_rate: float = 1e-4
|
| 70 |
+
learning_rate_min: float = 1e-5
|
| 71 |
+
weight_decay: float = 1e-8
|
| 72 |
+
betas: tuple[float, float] = (0.5, 0.999)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def overwrite_train_config() -> TrainConfig:
|
| 76 |
+
"""
|
| 77 |
+
Create and override the default training configuration.
|
| 78 |
+
|
| 79 |
+
This function customizes batch size, number of epochs, and optimizer
|
| 80 |
+
hyperparameters for the training process.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
TrainConfig: The configured training configuration object.
|
| 84 |
+
"""
|
| 85 |
+
config = TrainConfig()
|
| 86 |
+
config.batch_size = 4
|
| 87 |
+
config.num_epochs = 200
|
| 88 |
+
config.learning_rate = 5e-4
|
| 89 |
+
config.learning_rate_min = 5e-5
|
| 90 |
+
config.weight_decay = 1e-6
|
| 91 |
+
return config
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 95 |
+
# DATASET CONFIGURATION #
|
| 96 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 97 |
+
def overwrite_dataset_config() -> DatasetConfig:
|
| 98 |
+
"""
|
| 99 |
+
Create and override the dataset configuration.
|
| 100 |
+
|
| 101 |
+
This function sets the file paths and usage percentages for training,
|
| 102 |
+
validation, and test datasets.
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
DatasetConfig: The configured dataset configuration object.
|
| 106 |
+
"""
|
| 107 |
+
config = DatasetConfig()
|
| 108 |
+
config.train_data_path = r"/workspace/data/tokens-A000-segmentation"
|
| 109 |
+
config.val_data_path = r"/workspace/data/tokens-A001-segmentation"
|
| 110 |
+
config.test_data_path = r"/workspace/data/tokens-A002-segmentation"
|
| 111 |
+
config.train_percentage = 0.4
|
| 112 |
+
config.val_percentage = 0.4
|
| 113 |
+
config.test_percentage = 1.0
|
| 114 |
+
return config
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 118 |
+
# MODEL CONFIGURATION #
|
| 119 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 120 |
+
def overwrite_model_config() -> ModelConfig:
|
| 121 |
+
"""
|
| 122 |
+
Create and override the model configuration.
|
| 123 |
+
|
| 124 |
+
This function defines the architecture-level parameters, including
|
| 125 |
+
vocabulary size, embedding dimensionality, CoSeNet settings, and
|
| 126 |
+
the stack of Transformer encoder configurations.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
ModelConfig: The configured model configuration object.
|
| 130 |
+
"""
|
| 131 |
+
config = ModelConfig()
|
| 132 |
+
|
| 133 |
+
# High-level params:
|
| 134 |
+
config.vocab_size = 32_768
|
| 135 |
+
config.model_dim = 256
|
| 136 |
+
config.valid_padding = True
|
| 137 |
+
|
| 138 |
+
# CoSeNet params:
|
| 139 |
+
config.cosenet = CoSeNetConfig(
|
| 140 |
+
trainable=True,
|
| 141 |
+
init_scale=5.0
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Transformer params:
|
| 145 |
+
config.transformers = [
|
| 146 |
+
TransformerConfig(**cfg)
|
| 147 |
+
for cfg in [
|
| 148 |
+
{
|
| 149 |
+
"attention_heads": 16,
|
| 150 |
+
"feed_forward_multiplier": 8,
|
| 151 |
+
"dropout": 0.0,
|
| 152 |
+
"pre_normalize": True
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"attention_heads": 16,
|
| 156 |
+
"feed_forward_multiplier": 8,
|
| 157 |
+
"dropout": 0.0,
|
| 158 |
+
"pre_normalize": True
|
| 159 |
+
}
|
| 160 |
+
]
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
return config
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 167 |
+
# WHOLE CONFIGURATION #
|
| 168 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 169 |
+
def configuration() -> TrainConfig:
    """
    Assemble the complete experiment configuration.

    Starts from the training configuration and attaches the setup, model
    and dataset configurations to it, in that order.

    :return: A TrainConfig configuration object
    """
    experiment = overwrite_train_config()
    experiment.setup_config = overwrite_setup_config()
    experiment.model_config = overwrite_model_config()
    experiment.dataset_config = overwrite_dataset_config()
    return experiment
|
| 179 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 180 |
+
# END OF FILE #
|
| 181 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
{inference → research_files/inference}/load.py
RENAMED
|
File without changes
|
{inference → research_files/inference}/model_state.pt
RENAMED
|
File without changes
|
{inference → research_files/inference}/pipeline.py
RENAMED
|
@@ -8,7 +8,7 @@
|
|
| 8 |
import numpy as np
|
| 9 |
import torch
|
| 10 |
from src.model import SegmentationNetwork
|
| 11 |
-
from
|
| 12 |
|
| 13 |
|
| 14 |
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import torch
|
| 10 |
from src.model import SegmentationNetwork
|
| 11 |
+
from dataset import SegmentationTokenizer, SentenceSegmenter
|
| 12 |
|
| 13 |
|
| 14 |
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
|
{inference → research_files/inference}/tokenizer_32768.json
RENAMED
|
File without changes
|
research_files/torch_to_hf.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import os
|
| 9 |
+
from research_files.inference import load_model
|
| 10 |
+
from safetensors.torch import save_file
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def convert_model(save_path: str, model_path: str | None = None, tokenizer_path: str | None = None) -> None:
    """
    Export the research checkpoint in a Hugging Face friendly layout.

    Loads the model and tokenizer through ``load_model`` and writes the model
    weights to ``<save_path>/model.safetensors`` and the tokenizer files to
    ``save_path``.

    :param save_path: Directory that receives the exported files. It is
        created when it does not exist yet.
    :param model_path: Forwarded unchanged to ``load_model``; ``None`` leaves
        the choice of checkpoint to ``load_model``.
    :param tokenizer_path: Forwarded unchanged to ``load_model``; ``None``
        leaves the choice of tokenizer file to ``load_model``.
    """
    # The weights are written before the tokenizer, so the directory has to
    # exist up front. An empty path means "current directory" and is skipped
    # because os.makedirs("") raises.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # Load model (the segmenter returned by load_model is not needed here):
    model, tokenizer, _ = load_model(model_path, tokenizer_path)

    # Export the weights:
    save_file(model.state_dict(), os.path.join(save_path, "model.safetensors"))

    # Export the tokenizer.
    # NOTE(review): this reaches into the private `_hf_tokenizer` attribute of
    # the project tokenizer; a public accessor would be more robust.
    tokenizer._hf_tokenizer.save_pretrained(save_path)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
# Convert and save:
|
| 24 |
+
convert_model("./")
|
| 25 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 26 |
+
# END OF FILE #
|
| 27 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
{train → research_files/train}/config.py
RENAMED
|
@@ -7,8 +7,8 @@
|
|
| 7 |
# Import statements:
|
| 8 |
import os
|
| 9 |
from dataclasses import dataclass
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
|
| 13 |
|
| 14 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
|
|
|
| 7 |
# Import statements:
|
| 8 |
import os
|
| 9 |
from dataclasses import dataclass
|
| 10 |
+
from model import ModelConfig, CoSeNetConfig, TransformerConfig
|
| 11 |
+
from dataset import DatasetConfig
|
| 12 |
|
| 13 |
|
| 14 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
{train → research_files/train}/train_logs/config.json
RENAMED
|
File without changes
|
{train → research_files/train}/train_logs/logfile.log
RENAMED
|
File without changes
|
{train → research_files/train}/train_logs/tensorboard_logs.zip
RENAMED
|
File without changes
|
{train → research_files/train}/train_model.py
RENAMED
|
@@ -7,10 +7,10 @@
|
|
| 7 |
# Import statements:
|
| 8 |
import torch
|
| 9 |
import tqdm
|
| 10 |
-
from train.config import configuration, TrainConfig
|
| 11 |
from src.model import SegmentationNetwork, MaskedBCELoss
|
| 12 |
-
from
|
| 13 |
-
from
|
| 14 |
|
| 15 |
|
| 16 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
|
|
|
| 7 |
# Import statements:
|
| 8 |
import torch
|
| 9 |
import tqdm
|
| 10 |
+
from research_files.train.config import configuration, TrainConfig
|
| 11 |
from src.model import SegmentationNetwork, MaskedBCELoss
|
| 12 |
+
from dataset import TokenizedSegmentationDataset
|
| 13 |
+
from dlutils import Setup, train_step, validation_step
|
| 14 |
|
| 15 |
|
| 16 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": "[CLS]",
|
| 3 |
+
"mask_token": "[MASK]",
|
| 4 |
+
"pad_token": "[PAD]",
|
| 5 |
+
"sep_token": "[SEP]",
|
| 6 |
+
"unk_token": "[UNK]"
|
| 7 |
+
}
|
src/dataset/__init__.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
-
# #
|
| 3 |
-
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
-
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
-
# #
|
| 6 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
-
from .tokenizer import SegmentationTokenizer, SentenceSegmenter
|
| 8 |
-
from .dataset import SegmentationDataset
|
| 9 |
-
from .tokenized_dataset import TokenizedSegmentationDataset
|
| 10 |
-
from .config import DatasetConfig
|
| 11 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 12 |
-
# END OF FILE #
|
| 13 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
from .tokenizer import SegmentationTokenizer, SentenceSegmenter
|
| 8 |
+
from .dataset import SegmentationDataset
|
| 9 |
+
from .tokenized_dataset import TokenizedSegmentationDataset
|
| 10 |
+
from .config import DatasetConfig
|
| 11 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 12 |
+
# END OF FILE #
|
| 13 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dataset/config.py
CHANGED
|
@@ -1,29 +1,29 @@
|
|
| 1 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
-
# #
|
| 3 |
-
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
-
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
-
# #
|
| 6 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
-
# Import statements:
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
@dataclass
|
| 12 |
-
class DatasetConfig:
|
| 13 |
-
# Paths:
|
| 14 |
-
train_data_path: str = None
|
| 15 |
-
val_data_path: str = None
|
| 16 |
-
test_data_path: str = None
|
| 17 |
-
# Percentages:
|
| 18 |
-
train_percentage: float = 1.0
|
| 19 |
-
val_percentage: float = 1.0
|
| 20 |
-
test_percentage: float = 1.0
|
| 21 |
-
# Other parameters:
|
| 22 |
-
num_workers: int = 0
|
| 23 |
-
shuffle_train: bool = True
|
| 24 |
-
shuffle_val: bool = True
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 28 |
-
# END OF FILE #
|
| 29 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class DatasetConfig:
|
| 13 |
+
# Paths:
|
| 14 |
+
train_data_path: str = None
|
| 15 |
+
val_data_path: str = None
|
| 16 |
+
test_data_path: str = None
|
| 17 |
+
# Percentages:
|
| 18 |
+
train_percentage: float = 1.0
|
| 19 |
+
val_percentage: float = 1.0
|
| 20 |
+
test_percentage: float = 1.0
|
| 21 |
+
# Other parameters:
|
| 22 |
+
num_workers: int = 0
|
| 23 |
+
shuffle_train: bool = True
|
| 24 |
+
shuffle_val: bool = True
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 28 |
+
# END OF FILE #
|
| 29 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
src/dataset/dataset.py
CHANGED
|
@@ -1,199 +1,199 @@
|
|
| 1 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
-
# #
|
| 3 |
-
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
-
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
-
# #
|
| 6 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
-
# Import statements:
|
| 8 |
-
import logging
|
| 9 |
-
from torch.utils.data import Dataset, DataLoader
|
| 10 |
-
from datasets import Dataset as HfDataset
|
| 11 |
-
from datasets import load_from_disk
|
| 12 |
-
from .tokenizer import SegmentationTokenizer, SentenceSegmenter
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 16 |
-
# #
|
| 17 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 18 |
-
class SegmentationDataset(Dataset):
|
| 19 |
-
def __init__(
|
| 20 |
-
self,
|
| 21 |
-
huggingface_dataset: str | HfDataset,
|
| 22 |
-
tokenizer: SegmentationTokenizer,
|
| 23 |
-
segmenter: SentenceSegmenter,
|
| 24 |
-
logger: logging.Logger = None,
|
| 25 |
-
percentage: float = 1.0,
|
| 26 |
-
return_type: type = dict
|
| 27 |
-
):
|
| 28 |
-
"""
|
| 29 |
-
A segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
|
| 30 |
-
wikipedia-segmentation format. It loads the dataset and prepares it for training.
|
| 31 |
-
|
| 32 |
-
Wikipedia-segmentation format:
|
| 33 |
-
- The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
|
| 34 |
-
- The dataset should contain the following fields:
|
| 35 |
-
>>> sample = {
|
| 36 |
-
>>> 'text': ['Article 1', 'Article 2', ...],
|
| 37 |
-
>>> 'titles': ['Title 1', 'Title 2', ...],
|
| 38 |
-
>>> 'id': str,
|
| 39 |
-
>>> 'words': int
|
| 40 |
-
>>> 'paragraphs': int
|
| 41 |
-
>>> 'sentences': int
|
| 42 |
-
>>> }
|
| 43 |
-
- The dataset should be a list of dictionaries, where each dictionary contains the fields above.
|
| 44 |
-
|
| 45 |
-
Parameters
|
| 46 |
-
----------
|
| 47 |
-
huggingface_dataset : str | HfDataset
|
| 48 |
-
A huggingface dataset or a path to a dataset on disk with the wikipedia-segmentation format.
|
| 49 |
-
|
| 50 |
-
tokenizer : callable
|
| 51 |
-
A tokenizer function that takes a string and returns a list of tokens.
|
| 52 |
-
|
| 53 |
-
logger : logging.Logger, optional
|
| 54 |
-
Logger instance. If not provided, a null logger will be used.
|
| 55 |
-
|
| 56 |
-
percentage : float
|
| 57 |
-
Percentage of the dataset to use. Default is 1.0 (100%).
|
| 58 |
-
|
| 59 |
-
return_type : type
|
| 60 |
-
The return type of __getitem__, either dict or tuple. Default is dict.
|
| 61 |
-
|
| 62 |
-
Raises
|
| 63 |
-
------
|
| 64 |
-
ValueError
|
| 65 |
-
If the huggingface_dataset is not a string or a HfDataset.
|
| 66 |
-
ValueError
|
| 67 |
-
If the tokenizer is not a callable function or class.
|
| 68 |
-
ValueError
|
| 69 |
-
If the sentence_tokenizer is not a callable function or class.
|
| 70 |
-
ValueError
|
| 71 |
-
If the dtype is not a type.
|
| 72 |
-
|
| 73 |
-
"""
|
| 74 |
-
# Null logging:
|
| 75 |
-
if not isinstance(logger, logging.Logger):
|
| 76 |
-
self.logger = logging.getLogger("null")
|
| 77 |
-
self.logger.addHandler(logging.NullHandler())
|
| 78 |
-
else:
|
| 79 |
-
self.logger = logger
|
| 80 |
-
|
| 81 |
-
# Loading:
|
| 82 |
-
if isinstance(huggingface_dataset, HfDataset):
|
| 83 |
-
self.huggingface_dataset = huggingface_dataset
|
| 84 |
-
elif isinstance(huggingface_dataset, str):
|
| 85 |
-
self.huggingface_dataset = load_from_disk(huggingface_dataset)
|
| 86 |
-
else:
|
| 87 |
-
self.logger.error(f'[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
|
| 88 |
-
raise ValueError(f'[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
|
| 89 |
-
self.logger.info(f'[SegmentationDataset] Loaded dataset: {self.huggingface_dataset}')
|
| 90 |
-
self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.huggingface_dataset.num_rows}')
|
| 91 |
-
|
| 92 |
-
# Tokenizer:
|
| 93 |
-
if callable(tokenizer):
|
| 94 |
-
self.tokenizer = tokenizer
|
| 95 |
-
else:
|
| 96 |
-
self.logger.error(f'[SegmentationDataset] Tokenizer must be a callable function.')
|
| 97 |
-
raise ValueError(f'[SegmentationDataset] Tokenizer must be a callable function.')
|
| 98 |
-
|
| 99 |
-
# Segmenter:
|
| 100 |
-
if not isinstance(segmenter, SentenceSegmenter):
|
| 101 |
-
self.logger.error(f'[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
|
| 102 |
-
raise ValueError(f'[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
|
| 103 |
-
else:
|
| 104 |
-
self.segmenter = segmenter
|
| 105 |
-
|
| 106 |
-
# Percentage:
|
| 107 |
-
if not (0.0 < percentage <= 1.0):
|
| 108 |
-
self.logger.error(f'[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
|
| 109 |
-
raise ValueError(f'[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
|
| 110 |
-
else:
|
| 111 |
-
self.percentage = percentage
|
| 112 |
-
|
| 113 |
-
# Return type:
|
| 114 |
-
if not isinstance(return_type, type):
|
| 115 |
-
self.logger.error(f'[SegmentationDataset] return_type must be a type.')
|
| 116 |
-
raise ValueError(f'[SegmentationDataset] return_type must be a type.')
|
| 117 |
-
elif return_type not in [dict, tuple]:
|
| 118 |
-
self.logger.error(f'[SegmentationDataset] return_type must be either dict or tuple.')
|
| 119 |
-
raise ValueError(f'[SegmentationDataset] return_type must be either dict or tuple.')
|
| 120 |
-
else:
|
| 121 |
-
self.return_type = return_type
|
| 122 |
-
|
| 123 |
-
def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
|
| 124 |
-
"""
|
| 125 |
-
Returns a PyTorch DataLoader for this dataset.
|
| 126 |
-
|
| 127 |
-
Parameters
|
| 128 |
-
----------
|
| 129 |
-
batch_size : int
|
| 130 |
-
Number of samples per batch.
|
| 131 |
-
shuffle : bool
|
| 132 |
-
Whether to shuffle the dataset.
|
| 133 |
-
num_workers : int
|
| 134 |
-
Number of worker processes.
|
| 135 |
-
**kwargs
|
| 136 |
-
Additional arguments for DataLoader.
|
| 137 |
-
|
| 138 |
-
Returns
|
| 139 |
-
-------
|
| 140 |
-
[torch.utils.data.DataLoader
|
| 141 |
-
Configured DataLoader.
|
| 142 |
-
"""
|
| 143 |
-
# Size handling:
|
| 144 |
-
return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
|
| 145 |
-
pin_memory=True, **kwargs)
|
| 146 |
-
|
| 147 |
-
def __len__(self) -> int:
|
| 148 |
-
"""
|
| 149 |
-
Returns the number of samples in the dataset.
|
| 150 |
-
|
| 151 |
-
Returns
|
| 152 |
-
-------
|
| 153 |
-
int
|
| 154 |
-
Total number of samples.
|
| 155 |
-
"""
|
| 156 |
-
return int(self.huggingface_dataset.num_rows * self.percentage)
|
| 157 |
-
|
| 158 |
-
def __getitem__(self, idx) -> dict | tuple:
|
| 159 |
-
"""
|
| 160 |
-
Retrieves a single sample and generates segmentation labels.
|
| 161 |
-
|
| 162 |
-
Parameters
|
| 163 |
-
----------
|
| 164 |
-
idx : int
|
| 165 |
-
Index of the sample.
|
| 166 |
-
|
| 167 |
-
Returns
|
| 168 |
-
-------
|
| 169 |
-
tuple
|
| 170 |
-
A tuple or dict (x_i, y_i, mask_x) with noisy input and corresponding target.
|
| 171 |
-
"""
|
| 172 |
-
sample = self.huggingface_dataset[idx]['text']
|
| 173 |
-
sentences = self.segmenter(sample)
|
| 174 |
-
tokenized = self.tokenizer(sentences['sentences'])
|
| 175 |
-
|
| 176 |
-
if self.return_type == tuple:
|
| 177 |
-
return (
|
| 178 |
-
tokenized['input_ids'], # x
|
| 179 |
-
sentences['sentence_boundaries'], # y
|
| 180 |
-
tokenized['attention_mask'], # x_mask
|
| 181 |
-
sentences['sentence_mask'], # y_mask
|
| 182 |
-
sentences['sentence_candidates'], # y_prime_mask
|
| 183 |
-
)
|
| 184 |
-
elif self.return_type == dict:
|
| 185 |
-
return_value = {
|
| 186 |
-
'input': tokenized['input_ids'],
|
| 187 |
-
'input_mask': tokenized['attention_mask'],
|
| 188 |
-
'labels': sentences['sentence_boundaries'],
|
| 189 |
-
'output_mask': sentences['sentence_mask'],
|
| 190 |
-
'candidate_mask': sentences['sentence_candidates']
|
| 191 |
-
}
|
| 192 |
-
else:
|
| 193 |
-
raise ValueError(f'[SegmentationDataset] return_type must be either dict or tuple.')
|
| 194 |
-
return return_value
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 198 |
-
# END OF FILE #
|
| 199 |
-
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
|
|
|
| 1 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 2 |
+
# #
|
| 3 |
+
# This file was created by: Alberto Palomo Alonso #
|
| 4 |
+
# Universidad de Alcalá - Escuela Politécnica Superior #
|
| 5 |
+
# #
|
| 6 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 7 |
+
# Import statements:
|
| 8 |
+
import logging
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader
|
| 10 |
+
from datasets import Dataset as HfDataset
|
| 11 |
+
from datasets import load_from_disk
|
| 12 |
+
from .tokenizer import SegmentationTokenizer, SentenceSegmenter
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 16 |
+
# #
|
| 17 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 18 |
+
class SegmentationDataset(Dataset):
    """
    Torch dataset that segments and tokenizes articles on access.

    Each item is read from the underlying huggingface dataset, split into
    sentences by the segmenter and tokenized by the tokenizer every time it
    is requested; nothing is cached.
    """

    def __init__(
            self,
            huggingface_dataset: str | HfDataset,
            tokenizer: SegmentationTokenizer,
            segmenter: SentenceSegmenter,
            logger: logging.Logger = None,
            percentage: float = 1.0,
            return_type: type = dict
    ):
        """
        A segmentation dataset takes a huggingface dataset or a path to a dataset on disk with the
        wikipedia-segmentation format. It loads the dataset and prepares it for training.

        Wikipedia-segmentation format:
            - The dataset is expected to be a huggingface dataset or a path to a dataset on disk.
            - Each row is a dictionary with the following fields::

                sample = {
                    'text': ['Article 1', 'Article 2', ...],
                    'titles': ['Title 1', 'Title 2', ...],
                    'id': str,
                    'words': int,
                    'paragraphs': int,
                    'sentences': int
                }

            - Only the 'text' field is read by this class.

        Parameters
        ----------
        huggingface_dataset : str | HfDataset
            A huggingface dataset or a path to a dataset on disk with the wikipedia-segmentation format.

        tokenizer : callable
            Callable applied to the sentences produced by the segmenter. Its result must provide
            the 'input_ids' and 'attention_mask' entries.

        segmenter : SentenceSegmenter
            Sentence segmenter applied to the 'text' field of a row. Its result must provide the
            'sentences', 'sentence_boundaries', 'sentence_mask' and 'sentence_candidates' entries.

        logger : logging.Logger, optional
            Logger instance. If not provided, the logger named "null" with a NullHandler is used.

        percentage : float
            Fraction of the dataset reported by ``len()``, in (0.0, 1.0]. Default is 1.0 (100%).

        return_type : type
            The return type of __getitem__, either dict or tuple. Default is dict.

        Raises
        ------
        ValueError
            If the huggingface_dataset is not a string or a HfDataset.
        ValueError
            If the tokenizer is not callable.
        ValueError
            If the segmenter is not a SentenceSegmenter instance.
        ValueError
            If the percentage is outside (0.0, 1.0].
        ValueError
            If the return_type is not a type, or is neither dict nor tuple.

        """
        # Null logging:
        if isinstance(logger, logging.Logger):
            self.logger = logger
        else:
            self.logger = logging.getLogger("null")
            # getLogger returns the same object on every call, so only attach
            # a NullHandler once instead of one per dataset instance.
            if not any(isinstance(handler, logging.NullHandler) for handler in self.logger.handlers):
                self.logger.addHandler(logging.NullHandler())

        # Loading:
        if isinstance(huggingface_dataset, HfDataset):
            self.huggingface_dataset = huggingface_dataset
        elif isinstance(huggingface_dataset, str):
            self.huggingface_dataset = load_from_disk(huggingface_dataset)
        else:
            self._raise_value_error('[SegmentationDataset] huggingface_dataset must be either a string or a HfDataset.')
        self.logger.info(f'[SegmentationDataset] Loaded dataset: {self.huggingface_dataset}')
        self.logger.info(f'[SegmentationDataset] Loaded dataset length: {self.huggingface_dataset.num_rows}')

        # Tokenizer:
        if not callable(tokenizer):
            self._raise_value_error('[SegmentationDataset] Tokenizer must be a callable function.')
        self.tokenizer = tokenizer

        # Segmenter:
        if not isinstance(segmenter, SentenceSegmenter):
            self._raise_value_error('[SegmentationDataset] Segmenter must be a SentenceSegmenter instance.')
        self.segmenter = segmenter

        # Percentage:
        if not (0.0 < percentage <= 1.0):
            self._raise_value_error('[SegmentationDataset] Percentage must be between 0.0 and 1.0.')
        self.percentage = percentage

        # Return type:
        if not isinstance(return_type, type):
            self._raise_value_error('[SegmentationDataset] return_type must be a type.')
        if return_type not in [dict, tuple]:
            self._raise_value_error('[SegmentationDataset] return_type must be either dict or tuple.')
        self.return_type = return_type

    def _raise_value_error(self, message: str):
        """Log ``message`` as an error and raise it as a ValueError."""
        self.logger.error(message)
        raise ValueError(message)

    def get_loader(self, batch_size=8, shuffle=True, num_workers=0, **kwargs) -> DataLoader:
        """
        Returns a PyTorch DataLoader for this dataset.

        Parameters
        ----------
        batch_size : int
            Number of samples per batch.
        shuffle : bool
            Whether to shuffle the dataset.
        num_workers : int
            Number of worker processes.
        **kwargs
            Additional arguments for DataLoader. ``pin_memory`` defaults to
            True and can be overridden here.

        Returns
        -------
        torch.utils.data.DataLoader
            Configured DataLoader.
        """
        # Keep pinned memory as the default, but let the caller override it
        # instead of failing with a duplicated keyword argument.
        kwargs.setdefault('pin_memory', True)
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, **kwargs)

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.

        Returns
        -------
        int
            Number of rows of the underlying dataset scaled by ``percentage``
            and truncated to an integer.
        """
        return int(self.huggingface_dataset.num_rows * self.percentage)

    def __getitem__(self, idx) -> dict | tuple:
        """
        Retrieves a single sample and generates segmentation labels.

        Parameters
        ----------
        idx : int
            Index of the sample in the underlying huggingface dataset.

        Returns
        -------
        dict | tuple
            For ``return_type=tuple``: (input_ids, sentence_boundaries, attention_mask,
            sentence_mask, sentence_candidates).
            For ``return_type=dict``: the same values under the keys 'input', 'labels',
            'input_mask', 'output_mask' and 'candidate_mask' respectively.

        Raises
        ------
        ValueError
            If ``return_type`` was changed to something other than dict or tuple.
        """
        sample = self.huggingface_dataset[idx]['text']
        sentences = self.segmenter(sample)
        tokenized = self.tokenizer(sentences['sentences'])

        if self.return_type is tuple:
            return (
                tokenized['input_ids'],              # x
                sentences['sentence_boundaries'],    # y
                tokenized['attention_mask'],         # x_mask
                sentences['sentence_mask'],          # y_mask
                sentences['sentence_candidates'],    # y_prime_mask
            )
        if self.return_type is dict:
            return {
                'input': tokenized['input_ids'],
                'input_mask': tokenized['attention_mask'],
                'labels': sentences['sentence_boundaries'],
                'output_mask': sentences['sentence_mask'],
                'candidate_mask': sentences['sentence_candidates']
            }
        # Only reachable if the attribute was modified after construction.
        raise ValueError('[SegmentationDataset] return_type must be either dict or tuple.')
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|
| 198 |
+
# END OF FILE #
|
| 199 |
+
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|