Lexa committed on
Commit
b5a0bec
·
1 Parent(s): bb10ea5

Converted .pt files to safetensors, then (dirtily) patched fairseq to enable loading of safetensor files

Browse files
.gitattributes CHANGED
@@ -1 +1,2 @@
1
  *.pt filter=lfs diff=lfs merge=lfs -text
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -115,4 +115,7 @@ mortimer_env.txt
115
  _LexaLCM_Block0/Datasets/
116
 
117
  # UV
118
- uv.lock
 
 
 
 
115
  _LexaLCM_Block0/Datasets/
116
 
117
  # UV
118
+ uv.lock
119
+
120
+ # Unsafe files
121
+ *.pt
Patches/Patch_TorchLoader.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Patch for fairseq2.utils.file.load_tensors
2
+ #
3
+ # This patch allows for loading safetensors files
4
+ #
5
+ # It is used in the two_tower_diffusion_lcm model loader:
6
+ # ./lcm/models/two_tower_diffusion_lcm/loader.py
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+ from pathlib import Path
12
+ from typing import Any, Callable, Dict, Mapping, Optional, Protocol, Union
13
+ from warnings import catch_warnings
14
+
15
+ import torch
16
+ from torch import Tensor
17
+ from typing_extensions import TypeAlias
18
+
19
+ from fairseq2.typing import Device
20
+
21
+ from safetensors.torch import load_file
22
+
23
+ MapLocation: TypeAlias = Optional[
24
+ Union[Callable[[Tensor, str], Tensor], Device, str, Dict[str, str]]
25
+ ]
26
+
27
+
28
+ class TensorLoader(Protocol):
29
+ """Loads tensors from files."""
30
+
31
+ def __call__(
32
+ self,
33
+ path: Path,
34
+ *,
35
+ map_location: MapLocation = None,
36
+ restrict: bool = False,
37
+ ) -> Dict[str, Any]:
38
+ """
39
+ :param path:
40
+ The path to the file.
41
+ :param map_location:
42
+ Same as the ``map_location`` parameter of :func:`torch.load`.
43
+ :param restrict:
44
+ If ``True``, restricts the Python unpickler to load only tensors,
45
+ primitive types, and dictionaries.
48
+ """
49
+
50
+
51
+ class TensorDumper(Protocol):
52
+ """Dumps tensors to files."""
53
+
54
+ def __call__(self, data: Mapping[str, Any], path: Path) -> None:
55
+ """
56
+ :param data:
57
+ The dictionary containing tensors and other auxiliary data.
58
+ :param path:
59
+ The path to the file.
60
+ """
61
+
62
+
63
+ def load_tensors(
64
+ path: Path,
65
+ *,
66
+ map_location=None,
67
+ restrict: bool = False,
68
+ ) -> Dict[str, Any]:
69
+ """Load a checkpoint in .pt or .safetensors format."""
70
+ if str(path).endswith(".safetensors"):
71
+ tensors = load_file(str(path), device=str(map_location) if map_location else "cpu")
72
+ return {"model": tensors} # ✅ Wrap it like a .pt file
73
+
74
+
75
+ with warnings.catch_warnings():
76
+ warnings.simplefilter("ignore")
77
+ return torch.load(
78
+ str(path), map_location, weights_only=restrict # type: ignore[arg-type]
79
+ )
80
+
81
+
82
+ def dump_tensors(data: Mapping[str, Any], path: Path) -> None:
83
+ """Dump ``data`` to a PyTorch tensor file under ``path``."""
84
+ with catch_warnings():
85
+ warnings.simplefilter("ignore") # Suppress noisy FSDP warnings.
86
+
87
+ torch.save(data, path)
_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/{metadata.pt → metadata.safetensors} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:72a183d6a5d90ff8ae2bd4ceaab9cc107d20c53f4e4d37f1152fbc27b356a5b4
3
- size 5284
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
3
+ size 16
_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/{model.pt → model.safetensors} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c587394ef0a4ab818d9e023974d351d70852a2f02847efdbd13ef327a4c6ac33
3
- size 575893434
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f6160840e8a76276b126f4da6ded5568c2dcc777fd40007ccfa5bcfb08d9bce
3
+ size 575804960
_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml CHANGED
@@ -1,5 +1,5 @@
1
  __source__: inproc
2
- checkpoint: file:///home/lexa/DevProjects/_Unsorted/LexaLCM_Pre0_288M/_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model.pt
3
  model_arch: arch_lexa_lcm_pre0
4
  model_family: two_tower_diffusion_lcm
5
  name: on_the_fly_lcm
 
1
  __source__: inproc
2
+ checkpoint: file:///home/lexa/DevProjects/_Unsorted/LexaLCM_Pre0_288M/_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model.safetensors
3
  model_arch: arch_lexa_lcm_pre0
4
  model_family: two_tower_diffusion_lcm
5
  name: on_the_fly_lcm
_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/{rank_0.pt → rank_0.safetensors} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:766a467589456f9d9e060dc79d5837c8e7f0f9dd8572997cae32c97d66eb74cb
3
- size 2307681830
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
3
+ size 16
lcm/datasets/base.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ import logging
7
+ from abc import ABC, abstractmethod
8
+ from typing import Callable, Dict, Generic, Iterator, Optional, Sequence, TypeVar, Union
9
+
10
+ import torch
11
+ from fairseq2.data.data_pipeline import DataPipeline
12
+ from fairseq2.gang import FakeGang, Gang
13
+ from fairseq2.typing import DataType
14
+
15
+ from lcm.datasets.configs import (
16
+ DataLoadingConfig,
17
+ DatasetConfigT,
18
+ create_dataset_config_from_cards,
19
+ )
20
+ from lcm.datasets.dataloading import (
21
+ build_weighted_pipeline_with_renaming as default_build_fn,
22
+ )
23
+ from lcm.utils.common import Batched, set_mkl_num_threads
24
+
25
+ BatchT_co = TypeVar("BatchT_co", bound=Union[Dict, Batched], covariant=True)
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ class DataLoader(ABC, Generic[BatchT_co, DatasetConfigT]):
30
+ def __init__(
31
+ self,
32
+ data_config: DataLoadingConfig,
33
+ datasets: Sequence[DatasetConfigT],
34
+ gang: Gang,
35
+ builder_func: Callable[..., DataPipeline] = default_build_fn,
36
+ dtype: DataType = torch.float16,
37
+ ):
38
+ self.data_config = data_config
39
+ self.datasets = list(map(create_dataset_config_from_cards, datasets))
40
+ self.dtype = dtype
41
+ self.gang = gang
42
+ self.builder_func = builder_func
43
+
44
+ self._pipeline: Optional[DataPipeline] = None
45
+
46
+ @property
47
+ def pipeline(self) -> DataPipeline:
48
+ if self._pipeline is None:
49
+ logger.info(f"R{self.gang.rank} self._pipeline is None, building...")
50
+ gang_rank = self.gang.rank if self.gang else 0
51
+ world_size = self.gang.size if self.gang else 1
52
+
53
+ self._pipeline = self.builder_func(
54
+ self.datasets, self.data_config, gang_rank, world_size
55
+ )
56
+ assert self._pipeline, (
57
+ f"Cannot build data pipeline from config {self.data_config}"
58
+ )
59
+ return self._pipeline
60
+
61
+ def destroy(self) -> None:
62
+ """Destroy the pipeline to rebuild it with different shuffling"""
63
+ self._pipeline = None
64
+ # Build again and reset it
65
+ logger.info(f"R{self.gang.rank} resetting the pipeline in DataLoader.destroy")
66
+ self.reset()
67
+
68
+ def reset(self) -> None:
69
+ """
70
+ Applying reset will result in different shuffling for next iterations,
71
+ since pipeline will use modified generator state from previous one.
72
+ This is a suitable side effect for the `sharding_in_memory=False` (training) scenario.
73
+
74
+ Illustrative example :
75
+ >>> import torch
76
+ >>> from fairseq2.data import read_sequence
77
+
78
+ >>> def get_one_epoch_pipeline():
79
+ ... torch.manual_seed(13)
80
+ ... return read_sequence(list(range(10))).shuffle(5)
81
+
82
+ >>> bb = get_one_epoch_pipeline().and_return()
83
+ >>> list(bb)
84
+ [3, 1, 2, 4, 0, 8, 5, 6, 9, 7]
85
+ >>> bb.reset()
86
+ >>> list(bb)
87
+ [4, 0, 3, 2, 1, 9, 7, 6, 8, 5]
88
+ """
89
+ self.pipeline.reset()
90
+
91
+ @abstractmethod
92
+ def iterate_batches(self) -> Iterator[BatchT_co]: ...
93
+
94
+
95
+ class BaseDataLoader(DataLoader[dict, DatasetConfigT]):
96
+ def __init__(
97
+ self,
98
+ data_config: DataLoadingConfig,
99
+ datasets: Sequence[DatasetConfigT],
100
+ dtype: DataType = torch.float16,
101
+ gang: Gang = None,
102
+ ) -> None:
103
+ gang = gang or FakeGang()
104
+ super().__init__(
105
+ data_config=data_config,
106
+ datasets=datasets,
107
+ builder_func=default_build_fn,
108
+ dtype=dtype,
109
+ gang=gang,
110
+ )
111
+ set_mkl_num_threads()
112
+
113
+ def iterate_batches(self) -> Iterator[dict]:
114
+ yield from iter(self.pipeline)
lcm/datasets/configs.py ADDED
@@ -0,0 +1,774 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ import logging
7
+ import re
8
+ from dataclasses import asdict, dataclass, fields
9
+ from enum import Enum
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar
12
+
13
+ # XXX: these should be kept for eval of filters expressions
14
+ import pyarrow as pa
15
+ import pyarrow.compute as pc
16
+ import pyarrow.parquet as pq
17
+ from fairseq2.assets import default_asset_store
18
+ from omegaconf import MISSING
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class ParquetBatchFormat(Enum):
24
+ pyarrow = 0
25
+ pandas = 1
26
+ torch = 2
27
+
28
+
29
+ class ColumnsNames(Enum):
30
+ source_column = "_source_column"
31
+ source_text_column = "_source_text_column"
32
+ target_column = "_target_column"
33
+ target_text_column = "_target_text_column"
34
+
35
+ dataset_name = "_dataset_name"
36
+
37
+
38
+ @dataclass
39
+ class SonarTextColumn:
40
+ text_value: Optional[str] = None
41
+ """
42
+ Raw text expression that will be used as a constant column after being split into sentences and SONAR-encoded.
43
+ """
44
+ text_column: Optional[str] = None
45
+ sonar_column: Optional[str] = None
46
+ """
47
+ Note `text_column` and `sonar_column` should be aligned (so `sonar_column` should be sonar encoded `text_column`).
48
+ If `sonar_column` is None and `text_column` is provided, we set `sonar_column = f"{text_column}_sonar_emb"` as default processing value!
49
+ """
50
+
51
+
52
+ @dataclass
53
+ class ParquetDatasetLimitOptions:
54
+ fraction_of_files: Optional[float] = None
55
+ nb_files: Optional[int] = None
56
+ nb_fragments: Optional[int] = None
57
+ nb_rows: Optional[int] = None
58
+
59
+
60
+ @dataclass(frozen=True)
61
+ class SonarDecoderConfig:
62
+ tokenizer: str = "text_sonar_basic_decoder"
63
+ """ SONAR tokenizer """
64
+
65
+ decoder: str = "text_sonar_basic_decoder"
66
+ """ SONAR decoder"""
67
+
68
+ lang: str = "eng_Latn"
69
+ """ Target language """
70
+
71
+ max_tokens_in_sentence: int = 256
72
+ """Maximum number of tokens generated in the text"""
73
+
74
+ temperature: float = 1.0
75
+ """The decoding logit temperature, where values greater than 1.0 produce more
76
+ uniform logits; values less than 1.0 produce sharper logits."""
77
+
78
+
79
+ @dataclass(frozen=True)
80
+ class SonarEncoderConfig:
81
+ tokenizer: str = "text_sonar_basic_encoder"
82
+ """ SONAR tokenizer """
83
+
84
+ encoder: str = "text_sonar_basic_encoder"
85
+ """ SONAR decoder"""
86
+
87
+ lang: str = "eng_Latn"
88
+ """ Target language """
89
+
90
+
91
+ @dataclass
92
+ class DatasetConfig:
93
+ """
94
+ Generic dataset config
95
+ """
96
+
97
+ columns: Optional[List[str]] = None
98
+ """The list of columns to load.
99
+ Columns such as `source_column`, ..., will be added automatically.
100
+ """
101
+
102
+ source_text_column: Optional[str] = None
103
+ """ Column to load as source raw text"""
104
+
105
+ target_text_column: Optional[str] = None
106
+ """ Column to load as target raw text for paired data"""
107
+
108
+ source_prefix_text: Optional[str] = None
109
+ """ Text to prepend to the content of the source_column"""
110
+
111
+ source_suffix_text: Optional[str] = None
112
+ """ Text to append to the content of the target_column"""
113
+
114
+ target_prefix_text: Optional[str] = None
115
+ """ Text to prepend to the content of the source_column"""
116
+
117
+ target_suffix_text: Optional[str] = None
118
+ """ Text to append to the content of the target_column"""
119
+
120
+ source_sequences: Optional[List[SonarTextColumn]] = None
121
+ """
122
+ Designed to make on-the-fly prompts from existing columns that are more complex than prefix and suffix.
123
+ Each element of source_sequences is a SonarTextColumn, which can be either:
124
+ - constant raw text (with the text_value argument)
125
+ - text column (with the text_column argument)
126
+ - sonar column (with the sonar_column argument)
127
+
128
+ Note that text_value cannot co-exist with text_column or sonar_column, and sonar column cannot be specified
129
+ without a text column. Further behaviour for parquet datasets:
130
+ - If text_value is specified, this will be split to sentences and sonarized
131
+ - If only text_column is specified, a new column named "<text_column>_sonar_emb" will be added as sonar_column.
132
+ - If both (text_column, sonar_column) is specified,
133
+
134
+ All SonarTextColumn elements from source_sequences will be concatenated together to produce new source_column
135
+ and source_text_column (same for target), which will have names as defined in ColumnsNames.
136
+ Using source_sequences is NOT compatible with using source_column or source_text_column, as well as quality filtering.
137
+ """
138
+
139
+ target_sequences: Optional[List[SonarTextColumn]] = None
140
+ """Designed to make on-the-fly prompts / instructions for target column, see `source_sequences` for more details"""
141
+
142
+ silent_freeze: bool = False
143
+ """If set to true, the config value can only be set once, i.e. it will not be able to update after the being set is instantiated.
144
+ This is helpful to avoid side-effect in setting some configs after being specified by the user application (Hydra, CLI)"""
145
+
146
+ def __post_init__(self):
147
+ if self.source_sequences is not None:
148
+ if self.source_text_column is not None:
149
+ logger.warning(
150
+ f"Both `source_sequence` and `source_text_column` is specified. "
151
+ f"Ignore `source_text_column` and use default value `{ColumnsNames.source_text_column.value}`.\n"
152
+ f"(`source_sequences` = {self.source_sequences}, \n"
153
+ f"`source_text_column` = {self.source_text_column} )"
154
+ )
155
+ self.source_text_column = ColumnsNames.source_text_column.value
156
+
157
+ if self.target_sequences is not None:
158
+ if self.target_text_column is not None:
159
+ logger.warning(
160
+ f"Both `target_sequences` and `target_text_column` is specified. "
161
+ f"Ignore `target_text_column` and use default value `{ColumnsNames.target_text_column.value}`.\n"
162
+ f"(`target_sequences` = {self.target_sequences}, \n"
163
+ f"`target_text_column` = {self.target_text_column} )"
164
+ )
165
+ self.target_text_column = ColumnsNames.target_text_column.value
166
+
167
+ for col in (self.source_sequences or []) + (self.target_sequences or []):
168
+ if col.text_value is not None:
169
+ assert col.text_column is None and col.sonar_column is None
170
+ else:
171
+ assert col.text_column is not None
172
+
173
+ self._has_initialized_: bool = True
174
+
175
+ def __setattr__(self, name: str, value: Any) -> None:
176
+ if not getattr(self, "_has_initialized_", False):
177
+ return super().__setattr__(name, value)
178
+ if name == "silent_freeze":
179
+ raise ValueError(
180
+ "Direct change of silent_freeze outside __init__ is forbidden"
181
+ )
182
+ if self.silent_freeze and getattr(self, name) not in ("", None, MISSING):
183
+ logger.debug(
184
+ f"Ignore change of {name} since silent_freeze is set and value is not empty ({getattr(self, name)})"
185
+ )
186
+ return
187
+ super().__setattr__(name, value)
188
+
189
+ def override_attr(self, name: str, value: Any) -> None:
190
+ try:
191
+ self._has_initialized_ = False
192
+ super().__setattr__(name, value)
193
+ finally:
194
+ self._has_initialized_ = True
195
+
196
+ def freeze(self) -> None:
197
+ """Turn the `silent_freeze` flag on"""
198
+ try:
199
+ self._has_initialized_ = False
200
+ self.silent_freeze = True
201
+ finally:
202
+ self._has_initialized_ = True
203
+
204
+
205
+ @dataclass
206
+ class JSONDatasetConfig(DatasetConfig):
207
+ """Config for datasets stored in JsonL format."""
208
+
209
+ file_path: str = str()
210
+ """
211
+ Path to the directory containing the Jsonl dataset.
212
+ Each task will replace this with a real JSON file
213
+ TODO: Add support for remote JsonL file (e.g. with "s3://...")
214
+ """
215
+
216
+ prompt_template: Optional[str] = None
217
+ """
218
+ A jinja-format string to apply for each item in the dataset to transform into a string.
219
+ Useful for example when compiling a dynamic instruction / prompt for training or evaluation.
220
+ Note that when this is specified, it will take precedence over the "affix" option, i.e. the
221
+ columns `source_prefix_text`, `source_suffix_text`,... will be ignored.
222
+ """
223
+
224
+ def __setattr__(self, name: str, value: Any) -> None:
225
+ if not getattr(self, "_has_initialized_", False):
226
+ return super().__setattr__(name, value)
227
+
228
+ if name == "silent_freeze":
229
+ raise ValueError("Direct change of silent_freeze is forbidden")
230
+
231
+ if self.silent_freeze:
232
+ if getattr(self, name) not in ("", None, MISSING):
233
+ logger.debug(
234
+ f"Ignore change of {name} in silent frozen mode when value is not empty ({getattr(self, name)})"
235
+ )
236
+ return
237
+
238
+ # Ensure we cannot set the default `prompt_template` value when the user specifies
239
+ # source_sequences or source_text_column explicitly
240
+ for hi_prior_col, lo_prior_col, lo_prior_value in [
241
+ ("source_sequences", "source_text_column", self.source_text_column),
242
+ ("target_sequences", "target_text_column", self.target_text_column),
243
+ ("prompt_template", "source_sequences", self.source_sequences),
244
+ ("prompt_template", "source_prefix_text", self.source_prefix_text),
245
+ ("prompt_template", "source_suffix_text", self.source_suffix_text),
246
+ ]:
247
+ if name == hi_prior_col and lo_prior_value not in ("", None, MISSING):
248
+ logger.warning(
249
+ f"Updating value of {hi_prior_col} will cause conflicts with the user-defined "
250
+ f"value in {lo_prior_col}. The update will be ignored.\n"
251
+ )
252
+ return
253
+
254
+ super().__setattr__(name, value)
255
+
256
+
257
+ @dataclass
258
+ class ParquetDatasetConfig(DatasetConfig):
259
+ """
260
+ Config for datasets stored in Parquet format.
261
+
262
+ XXX: this config should not hold non-trival default values.
263
+ We want this to make datacards info and hydra config merge easier.
264
+ All None value should be filled up in downstream `build_parquet_iterator_pipeline`.
265
+ """
266
+
267
+ name: Optional[str] = None
268
+ """When name is provided, it will use preregistered cards to populate all attributes.
269
+ name convention is the following
270
+ - {card_name}={split}:{weight}
271
+
272
+ Example:
273
+ - wiki
274
+ - wiki:0.2 # no split
275
+ - wiki=dev # default weight=1
276
+ - wiki=dev:0.2
277
+
278
+ Cards attributes will be overwritten by user defined ParquetDatasetConfig in
279
+ `create_dataset_config_from_cards`.
280
+ """
281
+
282
+ parquet_path: str = str()
283
+ """The path to parquet dataset file.
284
+ if `parquet_path` is remote (like stats with "s3://..."),
285
+ the filesystem will be automatically detected and `filesystem_expr` should remain None
286
+ """
287
+
288
+ weight: float = 1.0
289
+ """
290
+ Indicates relative weight of dataset that can be used for sampling from different datasets.
291
+ """
292
+
293
+ limit: Optional[ParquetDatasetLimitOptions] = None
294
+ """
295
+ Contains different options that allows to load only a part of the provided dataset.
296
+ It will **always** take some number of **first** fragments according to the order in which
297
+ they appear in the dataset, and this logic will not be dependent on shuffling/seed.
298
+ When several limits are provided, each of them will be applied (resulting in the strongest limit).
299
+ """
300
+
301
+ source_column: Optional[str] = None
302
+ """ Column to load as source embeddings"""
303
+
304
+ target_column: Optional[str] = None
305
+ """ Column to load as target embeddings for paired data"""
306
+
307
+ source_quality_column: Optional[str] = None
308
+ source_quality_range: Optional[Any] = None
309
+
310
+ partition_filters: Optional[str] = None
311
+ """
312
+ Filters that should be applied only on partition columns for fast partition pruning.
313
+ This filters should not be duplicated in `filters` (below) which are used on materialized data.
314
+ To know the partition columns on dataset :
315
+ ```python
316
+ >>> pq.ParquetDataset(parquet_path).partitioning.schema.names
317
+ ```
318
+ Note that for if `parquet_path` references a single file -> the result above will NOT be correct (returns all columns).
319
+ Note that for a single-file case, there should be no partition_filters since there are no partitions !!
320
+ """
321
+
322
+ filters: Optional[str] = None
323
+ """See https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression
324
+
325
+ Some examples :
326
+
327
+ >>> import pyarrow.compute as pc
328
+ >>> import pyarrow as pa
329
+
330
+ >>> filters = (pc.field("data_split") == pc.scalar("train")) & (pc.field("duration") > 7)
331
+ >>> filters = pa.compute.greater(pa.compute.utf8_length(ds.field("lang1_text")), 4)
332
+ >>> filters = pa.compute.less_equal(pa.compute.list_value_length(pa.dataset.field("audio_wav")), 16_000 * 30)
333
+
334
+ Note that all fields used here should be among existing columns in the dataset schema.
335
+ For hydra compatibility, we need to pass this filters as an str expression that'll be passed to `eval(...)`
336
+ """
337
+
338
+ filesystem_expr: Optional[str] = None
339
+ """
340
+ DEPRECATED: not used any more and will be removed soon
341
+ """
342
+
343
+ filesystem: Optional[Any] = None
344
+ """
345
+ DEPRECATED: not used any more and will be removed soon
346
+ """
347
+
348
+ split_to_row_groups: Optional[bool] = None
349
+ """If ``True``, uses Parquet row groups instead of simple partitions which
350
+ are generally smaller. Highly recommended for non-partitioned parquet files."""
351
+
352
+ nb_parallel_fragments: Optional[int] = None
353
+ """
354
+ This parameter can be dataset specific:
355
+ For dataset with large number of sentences per document (sample),
356
+ it's enough to set `nb_parallel_fragments=2 or 3`.
357
+ For datasets, with smaller number of sentences (~10) and small row_group_size (~200-600),
358
+ `nb_parallel_fragments` could be increase to 10 - 20.
359
+
360
+ The number of Parquet fragments allowed to be read in parallel. Higher
361
+ values will result in higher speeds, better randomization, and higher memory
362
+ footprint. If partition size is rather small compared to the batch size, we
363
+ recommend to increase ``nb_parallel_fragments``.
364
+
365
+ Leaving ``nb_parallel_fragments`` to None will trigger auto-detection based on dataset metadata.
366
+ """
367
+
368
+ sharding_in_memory: bool = False
369
+ """
370
+ This option should be activated for sharding small datasets whose total number of row groups is small
371
+ that makes sharding per row group impossible.
372
+ """
373
+
374
+ def __post_init__(self):
375
+ super().__post_init__()
376
+
377
+ if self.source_sequences is not None:
378
+ if self.source_column is not None:
379
+ logger.warning(
380
+ f"Both `source_sequences` and `source_column` is specified. "
381
+ f"Ignore `source_column` and use default value `{ColumnsNames.source_column.value}`.\n"
382
+ f"(`source_sequences` = {self.source_sequences}, \n"
383
+ f"`source_column` = {self.source_column} )"
384
+ )
385
+ assert self.source_quality_range is None
386
+ self.source_column = ColumnsNames.source_column.value
387
+
388
+ if self.target_sequences is not None:
389
+ if self.target_column is not None:
390
+ logger.warning(
391
+ f"Both `target_sequences` and `target_column` is specified. "
392
+ f"Ignore `target_column` and use default value `{ColumnsNames.target_column.value}`.\n"
393
+ f"(`target_sequences` = {self.target_sequences}, \n"
394
+ f"`target_column` = {self.target_column} )"
395
+ )
396
+ self.target_column = ColumnsNames.target_column.value
397
+
398
+ for col in (self.source_sequences or []) + (self.target_sequences or []):
399
+ if col.sonar_column is None and col.text_value is None:
400
+ assert col.text_column, f"Invalid SonarTextColumn: {col}"
401
+ col.sonar_column = col.text_column + "_sonar_emb"
402
+
403
+ if self.source_quality_range is None:
404
+ self.source_quality_column = None
405
+
406
+
407
+ DatasetConfigT = TypeVar("DatasetConfigT", bound=DatasetConfig, contravariant=True)
408
+
409
+
410
+ @dataclass
411
+ class DataLoadingConfig:
412
+ multiple_dataset_chaining: str = "sample"
413
+ """
414
+ This option allows to chain several datasets together.
415
+ The chaining can be done in two ways:
416
+ - `sample` : each dataset will be sampled with the provided weight
417
+ - `concat` : datasets will be concatenated together (no weights taken into account)
418
+ - `round_robin`: datasets will be sampled in a round robin fashion (no weights taken into account)
419
+ """
420
+ batch_size: Optional[int] = None
421
+ """The output batch size."""
422
+
423
+ order_by_length: bool = True
424
+ """
425
+ Whether to create the batches with homogeneous tokens length
426
+ for more efficient padding.
427
+ """
428
+
429
+ max_tokens: Optional[int] = None
430
+ """Used with the ``order_by_length`` option to control the total number of
431
+ padded tokens in each batch. Typically, this option is preferred over
432
+ ``batch_size`` to reduce the memory footprint.
433
+ """
434
+
435
+ len_to_wrap_long_seq: Optional[int] = None
436
+ """
437
+ Wrapping a source sequences to the length of `len_to_wrap_long_seq`.
438
+ For instance, for a `len_to_wrap_long_seq=2`
439
+ batch = {
440
+ "source": [["v1", "v2", "v3", "v4", "v5"], ["u1", "u2", "u3"], ["w1"]],
441
+ }
442
+ will be transformed to
443
+ 1. if packing is False :
444
+ batch = {
445
+ "source": [['v1', 'v2'], ['v3', 'v4'], ['v5'], ["u1", "u2"], ["u3"], ["w1"]]
446
+ }
447
+ 1. if packing is True :
448
+ batch = {
449
+ "source": [['v1', 'v2'], ['v3', 'v4'], ['v5', 'u1'], ["u2", "u3"], ["w1"]]
450
+ }
451
+
452
+ Note: currently only allowed to be used with no "target" provided (unsupervised style) !
453
+ """
454
+
455
+ packing: bool = False
456
+ """
457
+ If True, all sequential documents (seqs of sentences) will be concated into one big document
458
+ before applying wrapping.
459
+ This will result in all samples (except maybe one) having exactly `len_to_wrap_long_seq` length !
460
+ """
461
+
462
+ wrap_before_affixing: bool = False
463
+ """
464
+ If True, we will wrap the sequences before adding the source prefix/suffix.
465
+ Recommended when pre-training with packed data i.e len_to_wrap_long_seq not None and packing=True
466
+ """
467
+
468
+ max_sentence_len_in_doc: Optional[int] = None
469
+ """
470
+ Remove samples (documents) whose `source_text_column` contains at least one sentence of len > `max_sentence_len_in_doc`.
471
+ This operations is done after long sequences wrapping (if applicable).
472
+ Typically values: 100 - 300
473
+ """
474
+ min_sentence_len_in_doc: Optional[int] = None
475
+ """
476
+ Remove samples (documents) `source_text_column` contains at least one sentence of len < `min_sentence_len_in_doc`.
477
+ This operations is done after long sequences wrapping (if applicable).
478
+ Typically values: 5 - 15
479
+ """
480
+
481
+ max_sentence_len_in_target_doc: Optional[int] = None
482
+ """
483
+ same filtering option as above but for `target_text_column`
484
+ """
485
+ min_sentence_len_in_target_doc: Optional[int] = None
486
+ """
487
+ same filtering option as above but for `target_text_column`
488
+ """
489
+
490
+ min_length_of_sequences: Optional[int] = 1
491
+ """
492
+ Remove samples (documents) whose `source_text_column` is strictly shorter than `min_length_of_sequences`.
493
+ This operations is done after long sequences wrapping (if applicable).
494
+ One can use here the same value as for sequences wrapping
495
+ in order to produce all sequences with the same length.
496
+ """
497
+ min_length_of_sequences_after_batching: Optional[int] = 1
498
+ """
499
+ Remove source sequences shorter than `min_length_of_sequences_after_batching`
500
+ This filtering is applied after batching and potentially affixing and wrapping.
501
+ """
502
+ min_length_of_target_sequences: Optional[int] = 1
503
+ """
504
+ Same as above applied for `target_text_column`
505
+ """
506
+ min_length_of_target_sequences_after_batching: Optional[int] = 1
507
+ """
508
+ Same as above applied for `target_text_column`
509
+ """
510
+
511
+ output_format: ParquetBatchFormat = ParquetBatchFormat.torch
512
+ """The format to use for output batches."""
513
+
514
+ shuffle: bool = True
515
+ """If ``True``, shuffles the dataset samples during the iteration. If ``False``
516
+ and ``order_by_length`` is ``None``, the batch samples will be produced in
517
+ natural Parquet dataset reading order."""
518
+
519
+ drop_null: bool = True
520
+ """If ``True``, drops rows containing any null value."""
521
+
522
+ seed: int = 123
523
+ """The RNG seed value for deterministic behavior."""
524
+
525
+ nb_epochs: int = 100
526
+ """
527
+ Number of passes over the data before iterations stop
528
+ """
529
+
530
+ min_batch_size: int = 1
531
+ """Drops batches whose length is less than ``min_batch_size``"""
532
+
533
+ nb_prefetch: float = 3.0
534
+ """The number of producer groups (of size `nb_parallel_fragments`) to
535
+ prefetch."""
536
+
537
+ num_parallel_calls: float = 1.5
538
+ """The number of parallel calls in map operations."""
539
+
540
+ use_threads: bool = False
541
+ """Whether pyarrow should use its internal threads to read the Parquet file.
542
+ Since we rely on the external parallelism, this param is tuned off by
543
+ default."""
544
+
545
+ ignore_checkpointed_pipeline: bool = False
546
+ """Whether to ignore the saved datapipeline state or load it when resuming.
547
+ Temporary fix for issues re-loading saved checkpoints"""
548
+
549
+ even_sharding: bool = False
550
+ """
551
+ This option should be activated ONLY for validataion on small datasets
552
+ to guarantee the perfect data sharding accross the workers.
553
+ Note that in current impmentation, activating `even_sharding` requires `sharding_in_memory=True`
554
+ which will lead to big overhead for big dataset.
555
+ Note also that some fraction of the data may be dropped due to even sharding.
556
+ For big validation datasets, prefer using large `nb_epoch` + limiting `max_validation_iterations`
557
+ instead of using `even_sharding` !
558
+
559
+ For training use case, it should left to False and combined with large number of epochs.
560
+ For evaluation use case, it also should be False since we dont care about the batch syncronization across different workers.
561
+ """
562
+ max_iteration_steps: Optional[int] = None
563
+ """
564
+ If not None, it will be used to limit the number of batches produced per each dataset
565
+ """
566
+
567
+
568
+ @dataclass
569
+ class ValidationDataLoadingConfig(DataLoadingConfig):
570
+ """
571
+ This class allows to have some hardcoded parameters for data loading of validation datasets
572
+ """
573
+
574
+ multiple_dataset_chaining: str = "concat"
575
+ nb_epochs: int = 1
576
+ min_batch_size: int = 1 # we want to keep all samples
577
+ shuffle: bool = False # we dont need the randomness here
578
+ batch_size: Optional[int] = None
579
+ max_tokens: Optional[int] = None
580
+ """
581
+ Leaving both `max_tokens` and `batch_size` to None will trigger auto-detection based on dataset metadata and distributed training world size.
582
+ to make more or less even distribution of samples across workers. Typically,
583
+ if worker_batch_size = total_batch_size // world_size <= 40, we will use batch_size=worker_batch_size,
584
+ otherwise we will use max_tokens=min(total_tokens_number // world_size, 3000).
585
+ See dataloading:SingleParquetDatasetDataloader::set_validation_params for more details.
586
+ """
587
+
588
+
589
+ @dataclass
590
+ class EvaluationDataLoadingConfig(DataLoadingConfig):
591
+ """
592
+ This class allows to have some hardcoded parameters for data loading of evaluation datasets.
593
+ In partitcular, even in distributed setup evaluation should not require workers syncronization.
594
+ Therefore, we set `even_sharding` = False to get the all data samples !
595
+ """
596
+
597
+ multiple_dataset_chaining: str = "concat"
598
+ nb_epochs: int = 1 # only ONE full pass over the full data !
599
+ min_batch_size: int = 1 # we want to keep all samples
600
+ shuffle: bool = False # we dont need the randomness here
601
+ batch_size: Optional[int] = 10
602
+ max_tokens: Optional[int] = None # this should be ok for most of models
603
+ even_sharding: bool = False # we dont want to lose any sample !
604
+ sharding_in_memory: bool = True # activate sharding by rank and world size
605
+ rank: int = 0
606
+ world_size: int = 1
607
+ max_samples: Optional[int] = None # fmt: skip
608
+ """evaluate only the first n samples (for debugging)"""
609
+
610
+
611
+ def setup_fairseq2_extensions() -> None:
612
+ # path where all datacards should be located !
613
+ cards_dir = Path(__file__).parent.parent.joinpath("datacards")
614
+ if cards_dir.exists():
615
+ default_asset_store.add_file_metadata_provider(cards_dir)
616
+
617
+
618
+ setup_fairseq2_extensions()
619
+
620
+
621
+ def get_cluster() -> Optional[str]:
622
+ """Returns the cluster name of the current environment.
623
+ User can implement their own logic to load datasets living in different locations/clusters
624
+ """
625
+ return "s3"
626
+
627
+
628
+ def _resolve_parquet_path(options: Dict[str, str]) -> Optional[str]:
629
+ cluster_name = get_cluster() or "s3"
630
+
631
+ parquet_path = options.get(cluster_name)
632
+ if parquet_path is None:
633
+ # best effort - taking first element
634
+ parquet_path = next(iter(options.values()))
635
+
636
+ return parquet_path
637
+
638
+
639
+ def _resolve_filters(
640
+ split: Optional[str],
641
+ card_filter: Optional[str],
642
+ user_filter: Optional[str],
643
+ card_partition_filters: Optional[str],
644
+ user_partition_filters: Optional[str],
645
+ ) -> Tuple[Optional[pc.Expression], Optional[pc.Expression]]:
646
+ custom_filters = user_filter or card_filter
647
+ partition_filters = user_partition_filters or card_partition_filters
648
+
649
+ if custom_filters is not None:
650
+ custom_filters = pq.filters_to_expression(eval(custom_filters))
651
+
652
+ if partition_filters is not None:
653
+ partition_filters = pq.filters_to_expression(eval(partition_filters))
654
+
655
+ if split:
656
+ split_filter = pc.equal(pc.field("split"), split)
657
+ if partition_filters is None:
658
+ partition_filters = split_filter
659
+ else:
660
+ partition_filters = pa.compute.if_else(
661
+ split_filter, partition_filters, False
662
+ )
663
+
664
+ return custom_filters, partition_filters
665
+
666
+
667
+ def _default_resolver(a, b):
668
+ res = a if bool(a) and a is not MISSING else b
669
+ return res
670
+
671
+
672
+ def get_parquet_config_from_name(
673
+ name: str, config: Optional[ParquetDatasetConfig] = None
674
+ ) -> ParquetDatasetConfig:
675
+ """
676
+ name convention is the following
677
+ - {card_name}={split}:{weight}
678
+ """
679
+ # parsing name
680
+ pattern = r"^(?P<card_name>[a-zA-Z0-9_]+)=?(?P<split>[a-zA-Z0-9_]*)?:?(?P<weight>\d+(?:\.\d+)?)?$"
681
+ match_ = re.match(pattern, name)
682
+ assert match_ is not None, f"name parsing failed: {name}"
683
+ card_name = match_.group("card_name")
684
+ split = match_.group("split")
685
+ weight = match_.group("weight")
686
+
687
+ if weight:
688
+ weight = float(weight)
689
+ logger.info(
690
+ f"Parsing {name} : card_name={card_name}, split={split}, weight={weight}"
691
+ )
692
+
693
+ reload_config = default_asset_store.retrieve_card(card_name)
694
+ cards_metadata: Dict[str, Any] = {**reload_config._metadata}
695
+
696
+ if config is None:
697
+ config = ParquetDatasetConfig(name=card_name, parquet_path="")
698
+
699
+ assert config is not None
700
+
701
+ if isinstance(config, ParquetDatasetConfig):
702
+ config_dict = asdict(config)
703
+ else:
704
+ config_dict = config # type: ignore
705
+
706
+ metadata = {}
707
+ # resolve parquet_path according to the cluster
708
+ for field in fields(ParquetDatasetConfig):
709
+ field_name = field.name
710
+ metadata[field_name] = _default_resolver(
711
+ config_dict.get(field_name), cards_metadata.get(field_name)
712
+ )
713
+
714
+ if isinstance(metadata["source_sequences"], list):
715
+ metadata["source_sequences"] = [
716
+ SonarTextColumn(**item) for item in metadata["source_sequences"]
717
+ ]
718
+
719
+ if isinstance(metadata["target_sequences"], list):
720
+ metadata["target_sequences"] = [
721
+ SonarTextColumn(**item) for item in metadata["target_sequences"]
722
+ ]
723
+
724
+ metadata["parquet_path"] = _default_resolver(
725
+ config_dict.get("parquet_path"),
726
+ _resolve_parquet_path(cards_metadata["parquet_path"]),
727
+ )
728
+
729
+ metadata["filters"], metadata["partition_filters"] = _resolve_filters(
730
+ split,
731
+ card_filter=cards_metadata.get("filters"),
732
+ user_filter=config_dict.get("filters"),
733
+ card_partition_filters=cards_metadata.get("partition_filters"),
734
+ user_partition_filters=config_dict.get("partition_filters"),
735
+ )
736
+ if weight: # priority from parsed name
737
+ metadata["weight"] = weight
738
+ metadata["name"] = name
739
+
740
+ # to patch nested hydra case !
741
+ if metadata["limit"] is not None and isinstance(metadata["limit"], dict):
742
+ metadata["limit"] = ParquetDatasetLimitOptions(**metadata["limit"])
743
+
744
+ return ParquetDatasetConfig(**metadata)
745
+
746
+
747
+ def create_dataset_config_from_cards(
748
+ config: DatasetConfig,
749
+ ) -> DatasetConfig:
750
+ if getattr(config, "name", None) is None:
751
+ return config
752
+ output_config = get_parquet_config_from_name(config.name, config) # type: ignore
753
+ return output_config
754
+
755
+
756
+ def get_renaming_mappers(configs: Sequence[DatasetConfig]) -> List[dict]:
757
+ used_columns = [x for x in ColumnsNames.__members__ if x != "dataset_name"]
758
+
759
+ pre_mapping = {
760
+ att: [getattr(cc, att) for cc in configs if hasattr(cc, att)]
761
+ for att in used_columns
762
+ }
763
+
764
+ mappers: List[dict] = [{} for _ in configs]
765
+ for att, val in pre_mapping.items():
766
+ if all(x is None for x in val):
767
+ continue
768
+ for i, name in enumerate(val):
769
+ if name is None:
770
+ raise ValueError(
771
+ f"All datasets should provide {att} param, but got {configs[i]}"
772
+ )
773
+ mappers[i][name] = getattr(ColumnsNames, att).value
774
+ return mappers
lcm/datasets/dataloader.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ import gc
7
+ import logging
8
+ from copy import deepcopy
9
+ from functools import partial
10
+ from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple
11
+
12
+ import pyarrow.compute as pc
13
+ import torch
14
+ from fairseq2.data.data_pipeline import DataPipeline, read_sequence
15
+ from fairseq2.data.text import TextTokenizer
16
+ from fairseq2.gang import FakeGang, Gang
17
+ from fairseq2.models.sequence import SequenceBatch
18
+ from fairseq2.nn.padding import pad_seqs
19
+ from fairseq2.typing import DataType
20
+ from fairseq2.utils.state import Stateful
21
+ from sonar.models.sonar_text import load_sonar_tokenizer
22
+
23
+ from lcm.datasets.base import DataLoader
24
+ from lcm.datasets.batch import LCMInput
25
+ from lcm.datasets.configs import (
26
+ ColumnsNames,
27
+ DataLoadingConfig,
28
+ ParquetDatasetConfig,
29
+ ParquetDatasetLimitOptions,
30
+ SonarDecoderConfig,
31
+ )
32
+ from lcm.datasets.utils import move_eos_to_the_end
33
+ from lcm.utils.common import set_mkl_num_threads
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ def truncate_sequence(tokens: torch.Tensor, max_len: int = 512) -> torch.Tensor:
39
+ if len(tokens) > max_len:
40
+ return tokens[:max_len]
41
+ return tokens
42
+
43
+
44
+ class LCMDataLoader(DataLoader[LCMInput, ParquetDatasetConfig], Stateful):
45
+ def __init__(
46
+ self,
47
+ data_config: DataLoadingConfig,
48
+ datasets: Sequence[ParquetDatasetConfig],
49
+ dtype: DataType = torch.float16,
50
+ use_decoder_backprop: bool = False,
51
+ max_subword_length: int = 64,
52
+ gang: Gang = None,
53
+ sonar_decoder_config: Optional[SonarDecoderConfig] = None,
54
+ ) -> None:
55
+ gang = gang or FakeGang()
56
+
57
+ super().__init__(
58
+ data_config=data_config,
59
+ datasets=datasets,
60
+ dtype=dtype,
61
+ gang=gang,
62
+ )
63
+ set_mkl_num_threads()
64
+
65
+ self.use_decoder_backprop = use_decoder_backprop
66
+ self.sonar_tokenizer: Optional[TextTokenizer] = None
67
+ self.max_subword_length = max_subword_length
68
+ if sonar_decoder_config is not None:
69
+ self.setup_sonar_decoder_tokenizer(config=sonar_decoder_config)
70
+ self._dummy_example: Optional[LCMInput] = None
71
+
72
+ def setup_sonar_decoder_tokenizer(
73
+ self,
74
+ config: SonarDecoderConfig,
75
+ ):
76
+ if self.use_decoder_backprop:
77
+ # The tokenizer
78
+ self.tokenizer = load_sonar_tokenizer(config.tokenizer, progress=False)
79
+ # Target text encoder
80
+ self.sonar_tokenizer = self.tokenizer.create_encoder(
81
+ task="translation",
82
+ lang=config.lang,
83
+ mode="target",
84
+ device=self.gang.device,
85
+ )
86
+ else:
87
+ self.sonar_tokenizer = None
88
+
89
+ def _prepare_subword_tokens(
90
+ self, batch: Dict[str, Any]
91
+ ) -> Tuple[Optional[SequenceBatch], Optional[SequenceBatch]]:
92
+ """
93
+ Given a batch of paragraphs/documents,
94
+ prepare a batch of sentences (flattened) tokenized at the subword-level
95
+ to feed to the SONAR decoder (a standard token-level decoder)
96
+
97
+ Args:
98
+ batch: attributes of a batch from the dataset.
99
+ A batch is M documents/paragraphs each spanning
100
+ a variable number of sentences {N_1, ..., N_M}.
101
+
102
+ E.g., {'text_sentences': [[sent^1_1, ...sent^1_{N_1}],
103
+ ...[sent^M_1, ... sent^M_{N_M}],
104
+ 'text_sentences_sonar_emb': [X^1 in (N_1, D), ... X^M in (N_M, D)]}
105
+ where D is the sonar embedding dimension.
106
+ Returns:
107
+ Toeknized sentences (subword-level) in (\sum_i=1^M N_i, max_len)
108
+ where max_len is min(self.max_subword_length, max length of the sentences in the batch)
109
+
110
+ """
111
+
112
+ if not self.use_decoder_backprop:
113
+ return None, None
114
+
115
+ # flatten the sentences from different documents/paragraphs
116
+ flattened_source_text = (
117
+ pc.list_flatten(batch[ColumnsNames.source_text_column.value])
118
+ .to_pandas()
119
+ .values
120
+ )
121
+
122
+ pipeline: DataPipeline = (
123
+ read_sequence(flattened_source_text)
124
+ .map(
125
+ [
126
+ self.sonar_tokenizer, # type: ignore
127
+ partial(truncate_sequence, max_len=self.max_subword_length),
128
+ ],
129
+ num_parallel_calls=int(max(8 * self.data_config.num_parallel_calls, 1)),
130
+ )
131
+ .and_return(max_num_warnings=4)
132
+ )
133
+
134
+ tokens_seqs, tokens_padding_mask = pad_seqs(list(pipeline)) # type: ignore
135
+ prefix_batch = SequenceBatch(tokens_seqs, tokens_padding_mask)
136
+ # TODO: instead of moving the EOS around, make the tokenizer append at the tokenization.
137
+ target_batch = move_eos_to_the_end(
138
+ prefix_batch,
139
+ eos_token_id=self.tokenizer.vocab_info.eos_idx,
140
+ pad_token_id=self.tokenizer.vocab_info.pad_idx,
141
+ )
142
+
143
+ return prefix_batch, target_batch
144
+
145
+ def _tokenize_batch(self, batch: Dict[str, Any]) -> LCMInput:
146
+ """
147
+ Given a batch of documents,
148
+ prepare a batch of input features for the LCM
149
+ This step is to simply fetch the right column for source/target & source text
150
+ and convert torch NestedTensors to list of tensors
151
+
152
+ Args:
153
+ batch: attributes of a batch from the dataset.
154
+ A batch is M documents each spanning
155
+ a variable number of sentences {N_1, ..., N_M}.
156
+
157
+ E.g., {'text_sentences': [[sent^1_1, ...sent^1_{N_1}],
158
+ ...[sent^M_1, ... sent^M_{N_M}],
159
+ 'text_sentences_sonar_emb': [X^1 in (N_1, D), ... X^M in (N_M, D)]}
160
+ where D is the sonar embedding dimension.
161
+ Returns:
162
+ LCMInput(
163
+ source: SONAR embeddings of the source text
164
+ i.e [X^1 in (N_1, D), ... X^M in (N_M, D)]
165
+ target: If supervised data: SONAR embeddings of the source text
166
+ tokens: Tokenized flattened sentences for the SONAR decoder (see `_prepare_subword_tokens`)
167
+ )
168
+
169
+ """
170
+
171
+ # Prepare sentence-wise subword tokens if needed:
172
+ tokens, target_tokens = self._prepare_subword_tokens(batch)
173
+
174
+ # Load target embeddings if requested and to propagate all other embeddings
175
+
176
+ possible_emb_columns = {
177
+ "source": ColumnsNames.source_column,
178
+ "target": ColumnsNames.target_column,
179
+ }
180
+
181
+ outputs = {
182
+ "tokens": tokens,
183
+ "target_tokens": target_tokens,
184
+ "name": batch[ColumnsNames.dataset_name.value],
185
+ "batch": batch,
186
+ }
187
+ for key, col in possible_emb_columns.items():
188
+ col_name = col.value
189
+ if col_name in batch:
190
+ dtype = self.dtype if "_length" not in key else torch.int64
191
+ embs = [x.to(self.gang.device).to(dtype) for x in batch[col_name]]
192
+ # Special case when some embeddings are not shaped as (T, D) e.g., XLMC's answer columns
193
+ if embs[0].dim() == 1 and "_length" not in key:
194
+ embs = [t.unsqueeze(0) for t in embs]
195
+ else:
196
+ embs = None
197
+ outputs[key] = embs
198
+ assert outputs["source"] is not None, (
199
+ "LCMDataLoader requires `source` sequences to be present in batches"
200
+ )
201
+ return LCMInput(**outputs)
202
+
203
+ def iterate_batches(self) -> Iterator[LCMInput]:
204
+ yield from map(self._tokenize_batch, self.pipeline)
205
+
206
+ def iterate_dummy_batches(self) -> Iterator[LCMInput]:
207
+ """
208
+ it's needed to simulate the data that follows the strucutre of self.pipeline (by always returning the same element).
209
+ It can be used only for fast forward pass (to avoid uneven sharding multi-gpus training).
210
+ """
211
+ if self._dummy_example is None:
212
+ # patching the params to get less data with less cost
213
+ limited_datasets = deepcopy(self.datasets)
214
+ for ds_conf in limited_datasets:
215
+ assert isinstance(ds_conf, ParquetDatasetConfig)
216
+ ds_conf.limit = ParquetDatasetLimitOptions(nb_fragments=1)
217
+
218
+ # Copy the true data config and reduce the batch size.
219
+ # When wrapping data, we want to also wrap the dummy batches
220
+ # to not exceed model max_length
221
+ dummy_dataloading_config = deepcopy(self.data_config)
222
+ dummy_dataloading_config.batch_size = 1
223
+
224
+ self._dummy_example = self._tokenize_batch(
225
+ next(
226
+ iter(
227
+ self.builder_func(
228
+ limited_datasets, dummy_dataloading_config, 0, 1
229
+ )
230
+ )
231
+ )
232
+ )
233
+ gc.collect()
234
+
235
+ while True:
236
+ yield self._dummy_example
237
+
238
+ def state_dict(self) -> Dict[str, Any]:
239
+ logger.info("Getting the data pipeline state ...")
240
+ state = self.pipeline.state_dict(strict=False)
241
+ return state
242
+
243
+ def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
244
+ if state_dict is not None:
245
+ assert self.pipeline is not None
246
+ if self.data_config.ignore_checkpointed_pipeline:
247
+ logger.warning("Ignoring existing dataloader state")
248
+ else:
249
+ try:
250
+ self.pipeline.load_state_dict(state_dict)
251
+ logger.info(f"Reloaded datapipeline state: {str(state_dict)[:400]}")
252
+ except ValueError:
253
+ logger.warning(
254
+ f"Failed to load dataloader state: {str(state_dict)[:400]}"
255
+ )
256
+ else:
257
+ # retro-compatibility
258
+ logger.warning(f"Attempt to restore a dataloader {self} with empty state")
lcm/datasets/dataloading.py ADDED
@@ -0,0 +1,1109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ import logging
7
+ from copy import deepcopy
8
+ from dataclasses import asdict, dataclass
9
+ from functools import lru_cache, partial
10
+ from typing import Any, Generator, List, Optional, Sequence
11
+
12
+ import numpy as np
13
+ import pyarrow as pa
14
+ import pyarrow.compute as pc
15
+ import pyarrow.parquet as pq
16
+ from fairseq2.data.data_pipeline import DataPipeline, DataPipelineBuilder
17
+ from fairseq2.data.parquet.tools import BatchOutputType, apply_filter, concat_table
18
+ from pyarrow.dataset import get_partition_keys
19
+ from stopes.utils.arrow_utils import (
20
+ explode_table_with_fixed_length,
21
+ explode_table_with_max_length,
22
+ is_list_like,
23
+ )
24
+
25
+ from lcm.datasets.configs import (
26
+ DataLoadingConfig,
27
+ ParquetBatchFormat,
28
+ ParquetDatasetConfig,
29
+ ValidationDataLoadingConfig,
30
+ get_renaming_mappers,
31
+ )
32
+ from lcm.datasets.parquet_utils import (
33
+ build_batching_loop_over_one_table,
34
+ define_parquet_dataset,
35
+ filter_document_by_quality,
36
+ filter_long_short_sentence_document,
37
+ filter_table_with_different_lengths,
38
+ get_row_group_level_metadata,
39
+ materialize_sequence,
40
+ prefix_and_suffix_one_list_column,
41
+ prepare_suffix_prefix_embeddings,
42
+ pyarrow_table_to_torch_dict,
43
+ renaming,
44
+ shuffle_table,
45
+ stream_parquet_fragments,
46
+ )
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+ PA_NB_CPU = 4
51
+ pa.set_cpu_count(PA_NB_CPU)
52
+ pa.set_io_thread_count(PA_NB_CPU)
53
+
54
+
55
+ def return_none_on_failure(func):
56
+ def wrapper(*args, **kwargs):
57
+ try:
58
+ return func(*args, **kwargs)
59
+ except Exception as e:
60
+ print(f"An error occurred: {e}")
61
+ return None
62
+
63
+ return wrapper
64
+
65
+
66
+ @dataclass
67
+ class GlobalPQStats:
68
+ min_number_of_fragment: int
69
+ mean_fragment_length: float
70
+ mean_fragment_number_of_tokens: Optional[float] = None
71
+
72
+
73
+ class SingleParquetDatasetDataloader:
74
+ _pq_ds: Optional[pq.ParquetDataset] = None
75
+ proxy_number_of_fragments: int
76
+ basic_stats: GlobalPQStats
77
+
78
+ def __init__(
79
+ self, dataset_config: ParquetDatasetConfig, loading_config: DataLoadingConfig
80
+ ):
81
+ self.dataset_config = deepcopy(dataset_config)
82
+ self.loading_config = deepcopy(loading_config)
83
+ self.config_post_init()
84
+ nb_parallel_fragments = self.dataset_config.nb_parallel_fragments
85
+ assert isinstance(nb_parallel_fragments, int)
86
+ self.nb_parallel_fragments: int = nb_parallel_fragments
87
+
88
+ @property
89
+ def is_validation(self) -> bool:
90
+ return isinstance(self.loading_config, ValidationDataLoadingConfig)
91
+
92
+ def head(self, top=5):
93
+ return self.dataset._dataset.head(top)
94
+
95
+ @property
96
+ def dataset(self) -> pq.ParquetDataset:
97
+ if self._pq_ds is None:
98
+ self._pq_ds = self.get_dataset()
99
+
100
+ return self._pq_ds
101
+
102
+ @property
103
+ def full_schema(self) -> pa.Schema:
104
+ return self.dataset.schema
105
+
106
+ def _warn_filters_usage(self, pq_ds: pq.ParquetDataset) -> None:
107
+ partition_filters = self.dataset_config.partition_filters
108
+
109
+ frags = pq_ds.fragments
110
+ if len(frags) == 0:
111
+ raise ValueError(
112
+ f"Working on empty dataset, probably due to wrong `partition_filters` definition : {partition_filters}"
113
+ )
114
+
115
+ partition_columns = list(
116
+ get_partition_keys(frags[0].partition_expression).keys()
117
+ )
118
+ if not partition_columns and partition_filters is not None:
119
+ raise ValueError(
120
+ f"Partition filters {partition_filters} is set but dataset has NO partition columns"
121
+ )
122
+
123
+ if partition_columns and partition_filters is not None:
124
+ expression_candidates = [
125
+ x for x in partition_columns if x in str(partition_filters)
126
+ ]
127
+ if len(expression_candidates) == 0:
128
+ logger.warning(
129
+ f"Partition filters are NOT compatible with partition columns, got: "
130
+ f"partition_filters={partition_filters} and partition_columns={partition_columns}"
131
+ )
132
+ filters = self.dataset_config.filters
133
+ if partition_columns and filters is not None:
134
+ expression_candidates = [x for x in partition_columns if x in str(filters)]
135
+ if len(expression_candidates) > 0:
136
+ logger.warning(
137
+ f"Partitionning columns {expression_candidates} are used as `filters` {filters}. ",
138
+ "You may want to use them in `partition_filters` instead",
139
+ )
140
+
141
+ def get_dataset(self) -> pq.ParquetDataset:
142
+ if isinstance(self.dataset_config.filters, str):
143
+ self.dataset_config.filters = pq.filters_to_expression(
144
+ eval(self.dataset_config.filters)
145
+ )
146
+
147
+ if isinstance(self.dataset_config.partition_filters, str):
148
+ self.dataset_config.partition_filters = pq.filters_to_expression(
149
+ eval(self.dataset_config.partition_filters)
150
+ )
151
+
152
+ pq_ds = define_parquet_dataset(
153
+ str(self.dataset_config.parquet_path), self.dataset_config.partition_filters
154
+ )
155
+
156
+ try:
157
+ self._warn_filters_usage(pq_ds)
158
+ except Exception as e:
159
+ logger.info(f"getting exception during filters examination : {e}")
160
+
161
+ return pq_ds
162
+
163
+ def set_validation_params(
164
+ self,
165
+ world_size: int,
166
+ default_max_tokens: int = 3000,
167
+ default_batch_size: int = 40,
168
+ ) -> None:
169
+ if not (
170
+ self.loading_config.batch_size is None
171
+ and self.loading_config.max_tokens is None
172
+ ):
173
+ return
174
+
175
+ total_batch_size = int(
176
+ self.basic_stats.min_number_of_fragment
177
+ * self.basic_stats.mean_fragment_length
178
+ )
179
+ batch_size = total_batch_size // world_size + int(
180
+ total_batch_size % world_size != 0
181
+ )
182
+
183
+ # for small datasets we can set `batch_size`
184
+ if (
185
+ batch_size <= default_batch_size
186
+ or self.basic_stats.mean_fragment_number_of_tokens is None
187
+ ):
188
+ self.loading_config.batch_size = min(batch_size, default_batch_size)
189
+ self.loading_config.max_tokens = None
190
+ else:
191
+ # for bigger dataset, let's use `max_tokens`
192
+ self.loading_config.batch_size = None
193
+ total_tokens_number = int(
194
+ self.basic_stats.min_number_of_fragment
195
+ * self.basic_stats.mean_fragment_number_of_tokens
196
+ )
197
+ self.loading_config.max_tokens = min(
198
+ max(total_tokens_number // world_size, 1), default_max_tokens
199
+ )
200
+
201
+ def build_dataload_pipeline(
202
+ self, rank: int = 0, world_size: int = 1
203
+ ) -> DataPipelineBuilder:
204
+ if world_size > 1:
205
+ assert self.loading_config.seed is not None, (
206
+ "for distributed training with `world_size` > 1, `seed` should be set !"
207
+ )
208
+ if self.is_validation:
209
+ self.set_validation_params(world_size)
210
+
211
+ # to propagate sharding_in_memory
212
+ if not self.dataset_config.sharding_in_memory:
213
+ sharding_in_memory = (
214
+ self.loading_config.nb_epochs * self.proxy_number_of_fragments
215
+ < 2 * world_size
216
+ )
217
+ else:
218
+ sharding_in_memory = self.dataset_config.sharding_in_memory
219
+ if self.loading_config.even_sharding:
220
+ sharding_in_memory = True
221
+
222
+ if sharding_in_memory:
223
+ logger.info("Activating sharding_in_memory")
224
+
225
+ self.random_state = np.random.RandomState(
226
+ self._get_inner_seed(rank, sharding_in_memory)
227
+ )
228
+ pipeline = self.get_fragments_pipeline()
229
+
230
+ if not sharding_in_memory:
231
+ pipeline = pipeline.shard(
232
+ shard_idx=rank,
233
+ num_shards=world_size,
234
+ allow_uneven=not self.loading_config.even_sharding,
235
+ )
236
+
237
+ pipeline = self.add_basic_fragment_loading_pipeline(pipeline)
238
+
239
+ pipeline = self.create_on_the_fly_columns(pipeline)
240
+ pipeline = self.filter_by_aligned_length(pipeline)
241
+
242
+ # If we want to wrap before adding affixes
243
+ if self.loading_config.wrap_before_affixing:
244
+ pipeline = self.add_wrapping_to_max_length_pipeline(pipeline)
245
+
246
+ # Filtering
247
+ pipeline = self.add_quality_score_filters(pipeline)
248
+ pipeline = self.add_min_sentence_number_in_doc_filter(
249
+ pipeline,
250
+ min_source_length=self.loading_config.min_length_of_sequences,
251
+ min_target_length=self.loading_config.min_length_of_target_sequences,
252
+ )
253
+ pipeline = self.add_min_max_sentence_len_in_doc_filter(pipeline)
254
+
255
+ # Affix
256
+ pipeline = self._add_source_target_affixes_to_pipeline(pipeline)
257
+
258
+ def cost_fn(table) -> float:
259
+ cost = 0
260
+ for name in [
261
+ self.dataset_config.source_column,
262
+ self.dataset_config.target_column,
263
+ ]:
264
+ if name is not None:
265
+ col = table[name]
266
+ if is_list_like(col):
267
+ cost += pa.compute.list_value_length(col).to_numpy().sum()
268
+ else:
269
+ # we should not be there, but let take batch_size as a proxy
270
+ cost += len(col)
271
+ return cost
272
+
273
+ pipeline = pipeline.dynamic_bucket(
274
+ self._shuffling_tokens_size,
275
+ cost_fn,
276
+ min_num_examples=self.nb_parallel_fragments,
277
+ max_num_examples=100, # max number of small fragements
278
+ drop_remainder=False,
279
+ )
280
+ pipeline = pipeline.map(concat_table, num_parallel_calls=1)
281
+
282
+ # wrap documents after affixing
283
+ if not self.loading_config.wrap_before_affixing:
284
+ # Note that packing with proper attention masks and position codes requires
285
+ # document indices that cover all sentences. Currently this can only come from affixing before wrapping.
286
+ # Adding affixes after wrapping will require annexing these affixes to edge sentences which is not intuitive.
287
+ if self.loading_config.shuffle:
288
+ pipeline = pipeline.map(
289
+ partial(shuffle_table, random_state=self.random_state),
290
+ num_parallel_calls=1,
291
+ )
292
+ pipeline = self.add_wrapping_to_max_length_pipeline(pipeline)
293
+
294
+ # batch with batch_size or max_tokens
295
+ pipeline = self.add_inner_pipeline(pipeline)
296
+
297
+ # Filter once again after wrapping and batching to remove batches with few number sentences
298
+ pipeline = self.add_min_sentence_number_in_doc_filter(
299
+ pipeline,
300
+ min_source_length=self.loading_config.min_length_of_sequences_after_batching,
301
+ min_target_length=self.loading_config.min_length_of_target_sequences_after_batching,
302
+ )
303
+
304
+ # Remove batch sizes with a size smaller than min_batch_size (default=1)
305
+ pipeline = pipeline.filter(
306
+ lambda table: bool(len(table) >= self.loading_config.min_batch_size)
307
+ )
308
+
309
+ if sharding_in_memory:
310
+ pipeline = pipeline.shard(
311
+ shard_idx=rank,
312
+ num_shards=world_size,
313
+ allow_uneven=not self.loading_config.even_sharding,
314
+ )
315
+ if self.loading_config.max_iteration_steps is not None:
316
+ pipeline = pipeline.take(self.loading_config.max_iteration_steps)
317
+ pipeline = self.add_format_conversion(pipeline)
318
+ return pipeline
319
+
320
def create_on_the_fly_columns(
    self, pipeline: DataPipelineBuilder
) -> DataPipelineBuilder:
    """Materialize source/target sequence columns on the fly, then drop raw inputs.

    When `source_sequences` / `target_sequences` are configured, the combined
    vector/text columns are built per table; the raw sequence columns that are
    not explicitly required by the caller are dropped afterwards.
    """
    cfg = self.dataset_config
    parallelism = self._num_parallel_call(self.nb_parallel_fragments)

    if cfg.source_sequences is not None:
        assert cfg.source_column is not None, (
            f"Expected a source_column - found {self.dataset_config.source_column}"
        )
        assert cfg.source_text_column is not None, (
            f"Expected a source_text_column - found {self.dataset_config.source_text_column}"
        )
        pipeline = pipeline.map(
            partial(
                materialize_sequence,
                column_sequence=cfg.source_sequences,
                vector_name=cfg.source_column,
                text_name=cfg.source_text_column,
            ),
            num_parallel_calls=parallelism,
        )

    if cfg.target_sequences is not None:
        assert cfg.target_column is not None, (
            f"Expected a target_column, found {self.dataset_config.target_column}"
        )
        assert cfg.target_text_column is not None, (
            f"Expected a target_text_columns, found {self.dataset_config.target_text_column}"
        )
        pipeline = pipeline.map(
            partial(
                materialize_sequence,
                column_sequence=cfg.target_sequences,
                vector_name=cfg.target_column,
                text_name=cfg.target_text_column,
            ),
            num_parallel_calls=parallelism,
        )

    # Raw sequence columns are only inputs for materialization; drop any that
    # the caller did not explicitly ask to keep.
    columns_to_drop = list(
        set(self._get_sequences_columns()) - set(self.extra_required_columns)
    )
    if columns_to_drop:
        pipeline = pipeline.map(lambda table: table.drop(columns_to_drop))

    return pipeline
def _add_source_target_affixes_to_pipeline(self, pipeline) -> DataPipelineBuilder:
    """Attach prefix/suffix embeddings and sentences before wrapping/packing."""
    affixes = self._get_suffix_prefix_vector()
    cfg = self.dataset_config

    # Source side: vector column, then text column.
    pipeline = self.add_prefix_suffix_pipeline(
        pipeline,
        cfg.source_column,
        affixes["source_prefix_vector"],
        affixes["source_suffix_vector"],
    )
    pipeline = self.add_prefix_suffix_pipeline(
        pipeline,
        cfg.source_text_column,
        affixes["source_prefix_sentences"],
        affixes["source_suffix_sentences"],
    )

    # Quality scores get a null placeholder wherever a text affix was added,
    # keeping the per-sentence alignment with the source column intact.
    pipeline = self.add_prefix_suffix_pipeline(
        pipeline,
        cfg.source_quality_column,
        pa.array([None]) if cfg.source_prefix_text else pa.array([]),
        pa.array([None]) if cfg.source_suffix_text else pa.array([]),
    )

    # Target side: vector column, then text column.
    pipeline = self.add_prefix_suffix_pipeline(
        pipeline,
        cfg.target_column,
        affixes["target_prefix_vector"],
        affixes["target_suffix_vector"],
    )
    pipeline = self.add_prefix_suffix_pipeline(
        pipeline,
        cfg.target_text_column,
        affixes["target_prefix_sentences"],
        affixes["target_suffix_sentences"],
    )

    return pipeline
+ def _num_parallel_call(self, x: float) -> int:
413
+ return int(max(self.loading_config.num_parallel_calls * x, 1))
414
+
415
+ def _nb_prefetch(self, x: float) -> int:
416
+ return int(max(self.loading_config.nb_prefetch * x, 0))
417
+
418
def config_post_init(self) -> None:
    """Validate and normalize the loading/dataset configs and precompute stats.

    Side effects: mutates ``self.loading_config`` and ``self.dataset_config``
    (column restriction, ``nb_parallel_fragments`` auto-detection) and sets
    ``self.basic_stats``, ``self._shuffling_tokens_size`` and
    ``self.proxy_number_of_fragments``.

    Raises:
        ValueError: if wrapping is requested for supervised data, if neither
            or both of `batch_size`/`max_tokens` are given, or if `max_tokens`
            is used without a `source_column`.
    """
    if getattr(self.loading_config, "len_to_wrap_long_seq", None):
        if (
            self.dataset_config.target_column
            or self.dataset_config.target_text_column
        ):
            # Wrapping would break the source/target alignment.
            raise ValueError(
                "Using `len_to_wrap_long_seq` is not supported for supervised training"
            )

    if self.loading_config.even_sharding:
        assert self.loading_config.seed is not None, (
            "`even_sharding` requires `seed` to be set"
        )

    if self.loading_config.max_tokens == 0:
        # setting max_tokens=0 turns off this option (argparser won't accept None directly)
        self.loading_config.max_tokens = None

    # Exactly one of `batch_size` / `max_tokens` must be provided
    # (validation loaders may provide neither).
    if (self.loading_config.batch_size is None) == (
        self.loading_config.max_tokens is None
    ) and (not self.is_validation or self.loading_config.max_tokens is not None):
        raise ValueError(
            "Need to provide either `batch_size` or `max_tokens` - "
            f"Received batch_size={self.loading_config.batch_size} "
            f"and max_tokens={self.loading_config.max_tokens}"
        )

    if self.loading_config.max_tokens and not self.dataset_config.source_column:
        raise ValueError(
            "Cannot batch based on `max_tokens` when `source_column` is not specified, "
            "please use `batch_size` instead."
        )

    self.dataset_config.split_to_row_groups = (
        self.dataset_config.split_to_row_groups
        if self.dataset_config.split_to_row_groups is not None
        else True
    )
    # Columns the caller explicitly requested must survive column dropping.
    self.extra_required_columns = self.dataset_config.columns or []
    self.dataset_config.override_attr("columns", self._get_minimal_columns())
    logger.info(f"Following columns will be loaded: {self.dataset_config.columns}")

    self.basic_stats = self.compute_stats()

    self._shuffling_tokens_size = self._get_shuffling_tokens_size(self.basic_stats)
    logger.info(
        f"Bucketing will require at least: {self._shuffling_tokens_size} of tokens (source + target)"
    )
    logger.info(f"Dataset stats: {asdict(self.basic_stats)}")

    self.proxy_number_of_fragments = self.basic_stats.min_number_of_fragment
    if self.dataset_config.nb_parallel_fragments is None:
        self.dataset_config.nb_parallel_fragments = (
            self._find_nb_parallel_fragments(self.basic_stats)
        )

    logger.info(f"Dataset Config: {self.dataset_config}")
    logger.info(f"Using Loading Config: {self.loading_config}")
+ def _get_shuffling_tokens_size(self, basic_stats) -> int:
479
+ """
480
+ `_shuffling_tokens_size` is used in dynamic bucketing to determine how many small parquet tables
481
+ (which are loaded raw parquet fragments that were potentially filtered on-the-fly) will be merged together :
482
+ we'll get a such number of consecutive parquet tables so that their total number of tokens (sentences)
483
+ will be greater than `_shuffling_tokens_size`.
484
+ It's called "shuffling" because all merged documents (from different tables) will be permuated together (if `shuffle=True`)
485
+ before being returned as final small batches (of required shape or volume).
486
+
487
+ The formula behind `_shuffling_tokens_size` is the following:
488
+ - If we use `max_tokens` in config, we want to have a least _shuffling_tokens_size = 4 * max_tokens,
489
+ so that at least 4 full batch will be formed next. It's good for shuffling and to avoid having "remainders" too often.
490
+ - For wrapping/packing case, we use a proxy for `max_tokens` as `batch_size` * `len_to_wrap_long_seq`
491
+ - If not, some average fragment characteristic `mean_fragment_number_of_tokens`, multiplied by 1.5 to get on average >=2 tables
492
+ - Finally, if no, other info is available, we use 10_000 as arbitrary proxy (good typical value for many of our datasets).
493
+
494
+ """
495
+ if self.loading_config.max_tokens is not None:
496
+ return 4 * self.loading_config.max_tokens
497
+ if (
498
+ self.loading_config.batch_size is not None
499
+ and self.loading_config.len_to_wrap_long_seq is not None
500
+ ):
501
+ return (
502
+ 4
503
+ * self.loading_config.len_to_wrap_long_seq
504
+ * self.loading_config.batch_size
505
+ )
506
+
507
+ if basic_stats.mean_fragment_number_of_tokens is not None:
508
+ return int(
509
+ 1.5 * basic_stats.mean_fragment_number_of_tokens
510
+ ) # to get few fragments grouped together
511
+
512
+ return 10_000 # default number that should not take a lot of RAM
513
+
514
+ def _find_nb_parallel_fragments(
515
+ self, basic_stats: GlobalPQStats, max_fragments=20, min_fragments=2
516
+ ) -> int:
517
+ """
518
+ Experimental!
519
+ Allows to determine nb of parallel fragments to load base on simple rules and dataset row group stats.
520
+ In particular, if `nb_parallel_fragments` will increase with increasing batch_size of max_tokens.
521
+ """
522
+ if basic_stats.min_number_of_fragment < 3:
523
+ return basic_stats.min_number_of_fragment
524
+
525
+ if basic_stats.mean_fragment_number_of_tokens is None:
526
+ logger.warning(
527
+ f"Cannot get `mean_fragment_number_of_tokens` from dataset {self.dataset_config}, `nb_parallel_fragement` detection can be wrong",
528
+ )
529
+
530
+ mean_fragment_number_of_tokens = (
531
+ basic_stats.mean_fragment_number_of_tokens or 5000
532
+ ) # typical, but arbitrary value
533
+ if (
534
+ self.loading_config.batch_size is None
535
+ and self.loading_config.max_tokens is None
536
+ ):
537
+ # it can happen for evaluation
538
+ nb_frags = 1.0
539
+ elif self.loading_config.batch_size is not None:
540
+ if self.loading_config.len_to_wrap_long_seq is not None:
541
+ max_tokens = (
542
+ self.loading_config.len_to_wrap_long_seq
543
+ * self.loading_config.batch_size
544
+ )
545
+ nb_frags = 3 * max_tokens / mean_fragment_number_of_tokens
546
+ else:
547
+ nb_frags = (
548
+ 5
549
+ * self.loading_config.batch_size
550
+ / basic_stats.mean_fragment_length
551
+ )
552
+ elif self.loading_config.max_tokens is not None:
553
+ nb_frags = (
554
+ 3 * self.loading_config.max_tokens / mean_fragment_number_of_tokens
555
+ )
556
+
557
+ return max(min(max_fragments, round(nb_frags)), min_fragments)
558
+
559
+ @lru_cache
560
+ def _get_sequences_columns(self):
561
+ candidate_columns = []
562
+ for col in (self.dataset_config.source_sequences or []) + (
563
+ self.dataset_config.target_sequences or []
564
+ ):
565
+ candidate_columns.append(col.text_column)
566
+ candidate_columns.append(col.sonar_column)
567
+ return [x for x in candidate_columns if x is not None]
568
+
569
+ def _get_minimal_columns(self):
570
+ # restrict on used collumns
571
+ candidate_columns = [
572
+ self.dataset_config.source_column,
573
+ self.dataset_config.source_text_column,
574
+ self.dataset_config.source_quality_column,
575
+ self.dataset_config.target_column,
576
+ self.dataset_config.target_text_column,
577
+ "split",
578
+ ] + self._get_sequences_columns()
579
+
580
+ minimal_columns: List[str] = [
581
+ x
582
+ for x in candidate_columns
583
+ if x is not None and x in self.full_schema.names
584
+ ]
585
+
586
+ if self.dataset_config.columns is None:
587
+ columns = sorted(set(minimal_columns))
588
+ else:
589
+ columns = sorted(set(minimal_columns + list(self.dataset_config.columns)))
590
+ if not set(columns).issubset(set(self.full_schema.names)):
591
+ raise ValueError(
592
+ f"columns {sorted(set(columns) - set(self.full_schema.names))} are not found in the dataset schema"
593
+ )
594
+
595
+ return columns
596
+
597
def _get_suffix_prefix_vector(self):
    """Compute affix embeddings/sentences and flatten them into a name -> value dict."""
    cfg = self.dataset_config
    nested_result = prepare_suffix_prefix_embeddings(
        cfg.source_prefix_text,
        cfg.source_suffix_text,
        cfg.target_prefix_text,
        cfg.target_suffix_text,
    )

    # Each nested entry yields a (vector, sentences) pair under these keys.
    key_pairs = (
        ("source_prefix_vector", "source_prefix_sentences"),
        ("source_suffix_vector", "source_suffix_sentences"),
        ("target_prefix_vector", "target_prefix_sentences"),
        ("target_suffix_vector", "target_suffix_sentences"),
    )

    flattened = {}
    for keys, values in zip(key_pairs, nested_result):
        for key, value in zip(keys, values):
            flattened[key] = value
    return flattened
def get_fragments_pipeline(self):
    """Stream (optionally row-group-split and shuffled) parquet fragments of the dataset."""
    split_to_row_groups = self.dataset_config.split_to_row_groups
    assert isinstance(split_to_row_groups, bool)

    # one can use `list_parquet_fragments` for a full fragments scan
    return stream_parquet_fragments(
        parquet_ds=self.dataset,
        nb_epochs=self.loading_config.nb_epochs,
        split_to_row_groups=split_to_row_groups,
        shuffle=self.loading_config.shuffle,
        seed=self.loading_config.seed,
        limit_options=self.dataset_config.limit,
        shuffling_window=20 * self.nb_parallel_fragments,
    )
def compute_stats(self, max_fragments=100) -> GlobalPQStats:
    """Estimate fragment-level statistics (counts, rows, tokens) from row-group metadata.

    At most `max_fragments` fragments are inspected, so the returned numbers
    are estimates rather than exact dataset totals.
    """
    # When sequences are materialized on the fly, the source column does not
    # exist yet in the raw files, so token stats cannot come from metadata.
    if self.dataset_config.source_sequences:
        source_column = None
    else:
        source_column = self.dataset_config.source_column

    split_to_row_groups = self.dataset_config.split_to_row_groups
    columns = [source_column] if source_column else None

    limit = self.dataset_config.limit
    if limit is not None and limit.nb_fragments is not None:
        # TODO: take into account other limit options to get better estimates
        max_fragments = min(limit.nb_fragments, max_fragments)

    self._stats_df = get_row_group_level_metadata(
        self.dataset, columns=columns, max_fragments=max_fragments
    )

    dim = 1
    if source_column:
        self._stats_df["num_tokens"] = self._stats_df[source_column].apply(
            lambda x: x["num_values"]
        )

        # For fixed-size list columns, `num_values` counts scalars; dividing
        # by the vector dimension recovers a sentence count.
        type_source = self.full_schema.field(source_column).type
        try:
            dim = type_source.value_type.list_size
            if not dim or dim < 0:
                dim = 1  # not a fixed vector size
        except AttributeError:
            logger.warning(f"source column {source_column} is not of list type")
            if self.dataset_config.nb_parallel_fragments is None:
                logger.warning("you may need to provide `nb_parallel_fragments`")
            dim = 1

    if split_to_row_groups:
        global_stats_df = self._stats_df
    elif "num_tokens" in self._stats_df:
        global_stats_df = self._stats_df.groupby("parquet_file_path").agg(
            {"num_rows": "sum", "num_tokens": "sum"}
        )
    else:
        global_stats_df = self._stats_df.groupby("parquet_file_path").agg(
            {"num_rows": "sum"}
        )

    mean_len_frag = global_stats_df["num_rows"].mean()

    if "num_tokens" in global_stats_df:
        mean_num_tokens_frag = self._stats_df["num_tokens"].mean() / dim
    else:
        mean_num_tokens_frag = None

    return GlobalPQStats(
        len(global_stats_df),
        mean_len_frag,
        mean_fragment_number_of_tokens=mean_num_tokens_frag,
    )
def add_inner_pipeline(self, pipeline: DataPipelineBuilder) -> DataPipelineBuilder:
    """Turn each pyarrow table into a stream of batches (by size or token budget)."""
    loading_config = self.loading_config

    length_columns = [
        col
        for col in (
            self.dataset_config.source_column,
            self.dataset_config.target_column,
        )
        if col is not None
    ]

    def inner_iterator(table: pa.Table) -> DataPipeline:
        # One sub-pipeline per merged table; seeds are drawn per table so that
        # inner shuffling differs between tables.
        return build_batching_loop_over_one_table(
            table=table,
            order_by_length=self.loading_config.order_by_length,
            length_column=length_columns,
            batch_size=loading_config.batch_size,
            max_tokens=loading_config.max_tokens,
            shuffle=loading_config.shuffle,
            seed=self.random_state.randint(0, 2**32),
            num_parallel_calls=self._num_parallel_call(3),
        )

    return pipeline.yield_from(inner_iterator)
+ def _get_inner_seed(self, rank: int, sharding_in_memory: bool) -> Optional[int]:
716
+ if self.loading_config.seed is not None:
717
+ if not sharding_in_memory:
718
+ return int(self.loading_config.seed) + rank * 100_000
719
+ else:
720
+ # for `sharding_in_memory`, we want the same shuffling
721
+ # to guarantee the consistent sharding across ranks
722
+ return int(self.loading_config.seed)
723
+ else:
724
+ return None
725
+
726
def add_prefix_suffix_pipeline(
    self,
    pipeline: DataPipelineBuilder,
    column: Optional[str],
    prefix,
    suffix,
) -> DataPipelineBuilder:
    """Prepend/append the given arrays to every row of a list column.

    No-op when the column is unset or when neither prefix nor suffix is given.
    """
    if column is None or (prefix is None and suffix is None):
        return pipeline
    return pipeline.map(
        partial(
            prefix_and_suffix_one_list_column,
            column=column,
            prefix_array=prefix,
            suffix_array=suffix,
        ),
        num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
    )
def add_basic_fragment_loading_pipeline(
    self, pipeline: DataPipelineBuilder
) -> DataPipelineBuilder:
    """Load fragments into tables, skipping broken ones, and re-apply row filters."""

    def load_fn(safe_frag):
        # Best effort: a corrupted fragment is logged and dropped instead of
        # failing the whole pipeline.
        try:
            return safe_frag.load(columns=self.dataset_config.columns)
        except Exception as e:
            logger.error(
                f"Error {e} occured while loading fragment {safe_frag} \n, skipping it"
            )
            return None

    pipeline = pipeline.map(
        load_fn,
        num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
    ).filter(lambda table: bool(table is not None))

    # we reapply the partition filters just in case of misusage
    # but it should not change the performance
    partition_filters = self.dataset_config.partition_filters
    filters = self.dataset_config.filters
    if partition_filters is not None and filters is not None:
        # `if_else(filters, partition_filters, False)` acts as a logical AND
        # of the two filter expressions.
        full_filter = pa.compute.if_else(filters, partition_filters, False)
    else:
        full_filter = partition_filters if filters is None else filters

    pipeline = pipeline.map(
        partial(
            apply_filter,
            filters=full_filter,
            drop_null=self.loading_config.drop_null,
        )
    )

    pipeline = pipeline.filter(lambda table: bool(len(table) > 0))
    return pipeline.prefetch(self._nb_prefetch(self.nb_parallel_fragments))
def filter_by_aligned_length(
    self, pipeline: DataPipelineBuilder
) -> DataPipelineBuilder:
    """Drop rows whose per-row list lengths disagree across aligned columns.

    A mismatch between the number of sentences and the number of sonar
    embeddings should never happen in well-formed data; this is a safety net.
    """
    parallelism = self._num_parallel_call(self.nb_parallel_fragments)

    def _apply_alignment_filter(pipe, columns):
        pipe = pipe.map(
            partial(filter_table_with_different_lengths, columns=columns),
            num_parallel_calls=parallelism,
        )
        return pipe.filter(lambda table: bool(len(table) > 0))

    source_columns: List[str] = [
        col
        for col in (
            self.dataset_config.source_column,
            self.dataset_config.source_text_column,
            self.dataset_config.source_quality_column,
        )
        if col is not None
    ]
    pipeline = _apply_alignment_filter(pipeline, source_columns)

    target_columns: List[str] = [
        col
        for col in (
            self.dataset_config.target_column,
            self.dataset_config.target_text_column,
        )
        if col is not None
    ]
    pipeline = _apply_alignment_filter(pipeline, target_columns)

    return pipeline
def add_wrapping_to_max_length_pipeline(
    self, pipeline: DataPipelineBuilder
) -> DataPipelineBuilder:
    """Split over-long documents into chunks of at most `len_to_wrap_long_seq` sentences."""
    len_to_wrap_long_seq = getattr(
        self.loading_config, "len_to_wrap_long_seq", None
    )
    if len_to_wrap_long_seq is None:
        return pipeline

    columns_to_wrap: List[str] = [
        col
        for col in (
            self.dataset_config.source_column,
            self.dataset_config.source_text_column,
            self.dataset_config.source_quality_column,
        )
        if col is not None
    ]

    if self.loading_config.packing:
        wrap_fn = return_none_on_failure(explode_table_with_fixed_length)
        logger.info(
            f"Wrapping to len_to_wrap_long_seq={len_to_wrap_long_seq} with fixed length (packing)"
        )
    else:
        wrap_fn = return_none_on_failure(explode_table_with_max_length)
        logger.info(
            f"Wrapping to len_to_wrap_long_seq={len_to_wrap_long_seq} with max length (without packing)"
        )

    pipeline = pipeline.map(
        partial(
            wrap_fn,
            columns=columns_to_wrap,
            max_seq_len=len_to_wrap_long_seq,
        ),
        num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
    )
    # Failed wraps were converted to None above; drop them here.
    return pipeline.filter(lambda table: table is not None)
def add_min_max_sentence_len_in_doc_filter(
    self, pipeline: DataPipelineBuilder
) -> DataPipelineBuilder:
    """Drop documents containing sentences that are too long or too short.

    The source side is filtered on `source_text_column`; the target side is
    filtered on `target_column`. Empty tables are removed after each filter.
    """
    parallelism = self._num_parallel_call(self.nb_parallel_fragments)

    if (
        self.loading_config.max_sentence_len_in_doc
        or self.loading_config.min_sentence_len_in_doc
    ):
        assert self.dataset_config.source_text_column is not None, (
            f"Expected a source_text_column, found {self.dataset_config.source_text_column}"
        )

        pipeline = pipeline.map(
            partial(
                filter_long_short_sentence_document,
                column=self.dataset_config.source_text_column,
                max_sentence_len=self.loading_config.max_sentence_len_in_doc,
                min_sentence_len=self.loading_config.min_sentence_len_in_doc,
            ),
            num_parallel_calls=parallelism,
        ).filter(lambda table: bool(len(table) > 0))

    if self.dataset_config.target_column is not None and (
        self.loading_config.max_sentence_len_in_target_doc
        or self.loading_config.min_sentence_len_in_target_doc
    ):
        # NOTE(review): the source side filters on the *text* column while this
        # filters on `target_column` (the embedding column) — confirm whether
        # `target_text_column` was intended here.
        pipeline = pipeline.map(
            partial(
                filter_long_short_sentence_document,
                column=self.dataset_config.target_column,
                max_sentence_len=self.loading_config.max_sentence_len_in_target_doc,
                min_sentence_len=self.loading_config.min_sentence_len_in_target_doc,
            ),
            num_parallel_calls=parallelism,
        ).filter(lambda table: bool(len(table) > 0))

    return pipeline
def add_min_sentence_number_in_doc_filter(
    self,
    pipeline: DataPipelineBuilder,
    min_source_length: Optional[int] = None,
    min_target_length: Optional[int] = None,
) -> DataPipelineBuilder:
    """Remove sequences with fewer sentences than the given minima.

    When `min_source_length` (resp. `min_target_length`) is set and the
    corresponding column is configured, rows whose list length is below the
    threshold are filtered out; empty tables are dropped entirely.
    """

    def _min_length_filter(table, column, length):
        keep = pc.greater_equal(pc.list_value_length(table[column]), length)
        # Fast path: skip the copy when no row would be removed.
        if pc.all(keep).as_py():
            return table
        return table.filter(keep)

    def _attach(pipe, column, length):
        return pipe.map(
            partial(_min_length_filter, column=column, length=length),
            num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
        ).filter(lambda table: bool(len(table) > 0))

    if (
        self.dataset_config.source_column is not None
        and min_source_length is not None
    ):
        pipeline = _attach(
            pipeline, self.dataset_config.source_column, min_source_length
        )

    if (
        self.dataset_config.target_column is not None
        and min_target_length is not None
    ):
        pipeline = _attach(
            pipeline, self.dataset_config.target_column, min_target_length
        )

    return pipeline
def add_quality_score_filters(
    self, pipeline: DataPipelineBuilder
) -> DataPipelineBuilder:
    """Keep only documents whose quality score lies in `source_quality_range`.

    No-op when no quality range is configured; requires
    `source_quality_column` to be set otherwise.
    """
    source_quality_range = self.dataset_config.source_quality_range
    if source_quality_range is None:
        return pipeline

    assert self.dataset_config.source_quality_column is not None, (
        f"Expected a source_quality_column, found {self.dataset_config.source_quality_column}"
    )

    return pipeline.map(
        partial(
            filter_document_by_quality,
            column=self.dataset_config.source_quality_column,
            min_score=source_quality_range[0],
            max_score=source_quality_range[1],
        ),
        num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
    ).filter(lambda table: bool(len(table) > 0))
def add_format_conversion(
    self, pipeline: DataPipelineBuilder
) -> DataPipelineBuilder:
    """Convert output batches from pyarrow to the configured format (pandas/torch)."""
    output_format = self.loading_config.output_format
    if output_format == ParquetBatchFormat.pandas:
        pipeline = pipeline.map(lambda table: table.to_pandas())
    elif output_format == ParquetBatchFormat.torch:
        pipeline = pipeline.map(pyarrow_table_to_torch_dict)
    return pipeline
def get_python_iterator(
    self, rank: int = 0, world_size: int = 1
) -> Generator[BatchOutputType, None, None]:  # type: ignore
    """Build the pipeline for this (rank, world_size) and iterate over its batches."""
    built = (
        self.build_dataload_pipeline(rank=rank, world_size=world_size)
        .prefetch(self._nb_prefetch(5))
        .and_return(max_num_warnings=4)
    )
    yield from built
def parquet_iterator(
    dataset_config: ParquetDatasetConfig,
    loading_config: DataLoadingConfig,
    rank: int,
    world_size: int,
) -> Generator[BatchOutputType, None, None]:  # type: ignore
    """Convenience wrapper: build a single-dataset loader and iterate its batches."""
    loader = SingleParquetDatasetDataloader(dataset_config, loading_config)
    yield from loader.get_python_iterator(rank, world_size)
def build_parquet_iterator_pipeline(
    dataset_config: ParquetDatasetConfig,
    loading_config: DataLoadingConfig,
    rank: int = 0,
    world_size: int = 1,
) -> DataPipelineBuilder:
    """Build (without iterating) the data-loading pipeline of a single dataset."""
    loader = SingleParquetDatasetDataloader(dataset_config, loading_config)
    return loader.build_dataload_pipeline(rank=rank, world_size=world_size)
def ds_name(conf: "ParquetDatasetConfig") -> str:
    """Human-readable dataset identifier: explicit name if set, else the parquet path."""
    return conf.name if conf.name is not None else str(conf.parquet_path)
def circular_shift_left(lst: List[Any], k: int) -> List[Any]:
    """Rotate `lst` left by `k` positions; `k` may exceed `len(lst)`."""
    if len(lst) <= 1:
        return lst
    offset = k % len(lst)  # handle shifts larger than the list length
    return lst[offset:] + lst[:offset]
def build_weighted_pipeline_with_renaming(
    dataset_configs: Sequence[ParquetDatasetConfig],
    loading_config: DataLoadingConfig,
    rank: int = 0,
    world_size: int = 1,
) -> DataPipeline:
    """Combine several parquet dataset pipelines (sample/concat/round_robin) with renaming."""
    assert loading_config.multiple_dataset_chaining in [
        "sample",
        "concat",
        "round_robin",
    ]

    # Split the parallelism/prefetch budget evenly across all datasets.
    dataset_configs = list(dataset_configs)
    n_datasets = len(dataset_configs)
    loading_config.num_parallel_calls = loading_config.num_parallel_calls / n_datasets
    loading_config.nb_prefetch = loading_config.nb_prefetch // n_datasets

    name_mappers = get_renaming_mappers(dataset_configs)

    def process_one_pipeline(cc, mapper):
        return build_parquet_iterator_pipeline(
            dataset_config=cc,
            loading_config=loading_config,
            rank=rank,
            world_size=world_size,
        ).map(
            partial(renaming, mapper=mapper, name=ds_name(cc)),
            num_parallel_calls=1,
        )

    # creating all datasets pipeline in parallel
    pipelines: List[DataPipelineBuilder] = [
        process_one_pipeline(cc, mapper)
        for cc, mapper in zip(dataset_configs, name_mappers)
    ]

    if len(pipelines) == 1:
        return (
            pipelines[0]
            .prefetch(int(max(loading_config.nb_prefetch, 1)))
            .and_return(max_num_warnings=4)
        )

    if loading_config.seed is not None:
        seed = loading_config.seed + (0 if loading_config.even_sharding else rank)
    else:
        seed = None

    finalized = [pp.and_return(max_num_warnings=4) for pp in pipelines]

    chaining = loading_config.multiple_dataset_chaining
    if chaining == "concat":
        # TODO : check that all weights = 1
        weighted_pipeline = DataPipeline.concat(
            circular_shift_left(finalized, k=rank),
        )
    elif chaining == "round_robin":
        weighted_pipeline = DataPipeline.round_robin(
            circular_shift_left(finalized, k=rank), allow_repeats=False
        )
    else:  # "sample"
        weighted_pipeline = DataPipeline.sample(
            finalized,
            [getattr(cc, "weight", 1.0) for cc in dataset_configs],
            seed=seed,
        )

    # try to prefetch at least one element from each dataset
    return weighted_pipeline.prefetch(
        int(max(loading_config.nb_prefetch * n_datasets**2, 1))
    ).and_return(max_num_warnings=4)
lcm/datasets/parquet_utils.py ADDED
@@ -0,0 +1,1141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+
7
+ import logging
8
+ from dataclasses import dataclass
9
+ from functools import lru_cache, reduce, wraps
10
+ from pickle import dumps, loads
11
+ from typing import Any, Iterator, List, Optional, Union
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+ import polars as pl
16
+ import pyarrow as pa
17
+ import pyarrow.compute as pc
18
+ import pyarrow.parquet as pq
19
+ import torch
20
+ from fairseq2.data.data_pipeline import (
21
+ DataPipeline,
22
+ DataPipelineBuilder,
23
+ read_iterator,
24
+ read_sequence,
25
+ )
26
+ from fairseq2.data.parquet.tools import (
27
+ NestedDict,
28
+ NestedDictValue,
29
+ add_partitioning_values,
30
+ compute_rows_length,
31
+ get_dataset_fragments,
32
+ split_fragment_in_row_groups,
33
+ )
34
+ from joblib import Parallel, delayed
35
+ from numpy.typing import NDArray
36
+ from pyarrow.dataset import get_partition_keys
37
+ from retrying import retry
38
+ from stopes.modules.preprocess.sonar_text_embedding import (
39
+ LangColumnConfig,
40
+ SonarTextBatchEmbedder,
41
+ SonarTextEmbedderConfig,
42
+ )
43
+ from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo
44
+ from stopes.utils.arrow_utils import (
45
+ hstack_pyarray_list,
46
+ is_list_like,
47
+ pyarrow_column_to_array,
48
+ simple_array_to_nested,
49
+ )
50
+ from tqdm.auto import tqdm
51
+
52
+ from lcm.datasets.configs import (
53
+ ColumnsNames,
54
+ ParquetDatasetLimitOptions,
55
+ SonarTextColumn,
56
+ )
57
+ from lcm.utils.common import batched
58
+
59
# Optional numba acceleration: `_get_hierarchical_indices_and_offsets` below
# is decorated with @njit; without numba it runs as plain Python.
try:
    from numba import njit
except ModuleNotFoundError:
    print("Numba is not installed. Fall-back to the non-recompiled version")

    def empty_jit(f):
        # Identity decorator standing in for numba.njit.
        @wraps(f)
        def _f(*args, **kwargs):
            return f(*args, **kwargs)

        return _f

    njit = empty_jit


# Retry wrapper for transient OSError failures when reading fragments
# (e.g. expired object-store credentials).
# NOTE(review): stop_max_attempt_number=1 allows only a single attempt, so the
# exponential-backoff settings never take effect — confirm retries are intended.
loading_retry = retry(
    retry_on_exception=lambda exception: isinstance(exception, OSError),
    stop_max_attempt_number=1,
    wait_exponential_multiplier=2,
    wait_exponential_max=20,
)


logger = logging.getLogger(__name__)
83
+
84
+
85
def prefix_and_suffix_one_list_column(
    table: pa.Table, column: str, prefix_array: pa.Array, suffix_array: pa.Array
):
    """Prepend `prefix_array` and append `suffix_array` to every list in
    `table[column]`, casting both to the column's dtype when needed."""

    def _tile_as_lists(values: pa.Array) -> pa.ChunkedArray:
        # Wrap `values` as a single-list array and repeat it once per row.
        single = pa.ListArray.from_arrays([0, len(values)], values)
        return pa.chunked_array([single] * len(table))

    prefix_extended = _tile_as_lists(prefix_array)
    suffix_extended = _tile_as_lists(suffix_array)
    target_dtype = table[column].type
    if prefix_extended.type != target_dtype:
        prefix_extended = prefix_extended.cast(target_dtype)
    if suffix_extended.type != target_dtype:
        suffix_extended = suffix_extended.cast(target_dtype)

    combined = hstack_pyarray_list(prefix_extended, table[column], suffix_extended)
    return table.drop([column]).append_column(column, combined)
102
+
103
+
104
def define_parquet_dataset(parquet_path: str, partition_filters) -> pq.ParquetDataset:
    """Open the parquet dataset at `parquet_path`, restricted by
    `partition_filters` (pyarrow filter expression or DNF list)."""
    return pq.ParquetDataset(parquet_path, filters=partition_filters)
109
+
110
+
111
@lru_cache()
def default_sonar_pipeline() -> SonarTextBatchEmbedder:
    """Build (once, memoized) a CPU Sonar text embedder configured for
    English (`eng_Latn`) input on the `input_text` column."""
    config = SonarTextEmbedderConfig(
        column_config=[LangColumnConfig("input_text", lang_value="eng_Latn")],
        batch_size=10,
        device="cpu",
    )
    return SonarTextBatchEmbedder(config)
121
+
122
+
123
@lru_cache(2000)
def _get_embed_sentences(text: Optional[str]) -> "tuple[pa.Array, pa.Array]":
    # Split `text` into English sentences and embed each with the default
    # Sonar pipeline; memoized (LRU, 2000 entries) since the same constant
    # prefix/suffix texts are embedded repeatedly.
    # Returns (vectors, sentences); for falsy `text` both are empty arrays of
    # the correct type. (Annotation fixed: the previous `-> pa.Array` did not
    # reflect the tuple return value.)
    sentences_splitter = get_split_algo("eng_Latn", "default")
    lstbe = default_sonar_pipeline()
    sentences = pa.array(sentences_splitter(text) if text else [""])
    input_table = pa.Table.from_pydict({"input_text": sentences})
    vectors = pyarrow_column_to_array(lstbe(input_table)["input_text_sonar_emb"])
    if not text:
        # empty output of the right type
        vectors = vectors.slice(0, 0)
        sentences = sentences.slice(0, 0)
    return vectors, sentences
135
+
136
+
137
def prepare_suffix_prefix_embeddings(*args):
    """Embed each optional text in `args`, returning one
    (vectors, sentences) pair per argument.

    When every argument is None, returns (None, None) placeholders without
    touching `_get_embed_sentences` (avoids loading the Sonar model).
    """
    if any(text is not None for text in args):
        return [_get_embed_sentences(text) for text in args]
    return [(None, None)] * len(args)
142
+
143
+
144
def from_pyarrow_to_torch_tensor(
    arr: Union[pa.Array, pa.ChunkedArray], strict: bool = False
) -> NestedDictValue:
    """Convert a pyarrow array into torch tensors / python containers.

    Examples of supported inputs::

        struct_array = pa.Array.from_pandas([{"x": 4, "y": "RR"}] * 10)
        nest_array = pa.Array.from_pandas([[{'a': 1}, {'a': 2}]])

    :param arr: column content; chunked arrays are combined first.
    :param strict: if True, raise ``NotImplementedError`` for types with no
        torch equivalent instead of returning the array unchanged.
    :raises ValueError: if the array contains nulls.
    """
    # for future ideas https://arrow.apache.org/docs/python/generated/pyarrow.Tensor.html
    # for sparse matrix support https://github.com/apache/arrow/blob/main/python/pyarrow/tests/test_sparse_tensor.py

    if arr.null_count != 0:
        raise ValueError("to torch conversion does not support null values")

    arr = pyarrow_column_to_array(arr)

    arr_type = arr.type

    # Zero-copy fast path. Fix: the previous version attempted this identical
    # call twice for primitive types (first swallowing every exception, then
    # retrying with `except pa.ArrowInvalid`); the second attempt fully
    # determined the outcome, so the first block was dead code.
    try:
        return torch.from_numpy(arr.to_numpy(zero_copy_only=True))
    except pa.ArrowInvalid:
        pass

    if pa.types.is_dictionary(arr_type):
        return from_pyarrow_to_torch_tensor(arr.dictionary_decode())

    if pa.types.is_string(arr_type):
        return arr.to_pandas().tolist()

    if pa.types.is_list(arr_type) or pa.types.is_large_list(arr_type):
        if pa.types.is_primitive(arr_type.value_type):
            return arr.to_pandas().map(torch.from_numpy).tolist()

        if pa.types.is_fixed_size_list(arr_type.value_type) and pa.types.is_primitive(
            arr_type.value_type.value_type
        ):
            # list<fixed_size_list<primitive>> -> one 2D tensor per row
            return (
                arr.to_pandas()
                .map(
                    lambda x: torch.from_numpy(
                        np.vstack(x) if len(x) > 0 else np.array([], dtype=np.float32)
                    )
                )
                .tolist()
            )

    if pa.types.is_fixed_size_list(arr_type):
        if pa.types.is_primitive(arr_type.value_type):
            # flat values reshaped into (nb_rows, list_size)
            return torch.from_numpy(np.reshape(arr.values, (-1, arr_type.list_size)))

    if pa.types.is_struct(arr_type):
        return {
            arr_type.field(i).name: from_pyarrow_to_torch_tensor(arr.field(i))
            for i in range(arr_type.num_fields)
        }

    if pa.types.is_nested(arr_type):
        # TODO: deal with arr = [[{'a': 1}, {'a': 2}]]
        pass

    if strict:
        raise NotImplementedError(f"{arr_type} cannot be converted to torch.Tensor")
    else:
        return arr  # keeping as in the original pyarrow form
212
+
213
+
214
def pyarrow_table_to_torch_dict(tt: pa.Table, strict: bool = False) -> NestedDict:
    """Convert each column of `tt` with `from_pyarrow_to_torch_tensor`.

    Columns that fail conversion with ValueError (e.g. containing nulls) are
    kept in their original pyarrow form and the failure is logged.
    """
    out = {}
    for col in tt.column_names:
        try:
            out[col] = from_pyarrow_to_torch_tensor(tt[col], strict)
        except ValueError as e:
            # Fix: use lazy %-formatting; the previous call passed `str(e)` as
            # a stray positional argument after an f-string, which the logging
            # module cannot format.
            logger.info(
                "Column %s of type %s was not converted to torch as expected: %s",
                col,
                tt[col].type,
                e,
            )
            out[col] = tt[col]
    return out
226
+
227
+
228
def add_fragments_trace(table: pa.Table, fragment: pa.dataset.Fragment) -> pa.Table:
    """Append provenance columns recording which row groups of `fragment`
    produced each row, and the row's position within the fragment."""
    row_group_ids = np.array(
        [int(rg.id) for rg in fragment.row_groups], dtype=np.int32
    )
    # same row-group id list repeated for every row
    table = table.append_column("__row_groups_ids", [row_group_ids] * len(table))
    position_in_fragment = pa.array(np.arange(len(table), dtype=np.int32))
    table = table.append_column("__index_in_fragement", position_in_fragment)
    return table
238
+
239
+
240
def shuffle_table(table: pa.Table, random_state: np.random.RandomState) -> pa.Table:
    """Return a copy of `table` with its rows permuted by `random_state`."""
    shuffled_indices = random_state.permutation(len(table))
    return table.take(pa.array(shuffled_indices))
243
+
244
+
245
class SafeFragment:
    """
    Experimental :
    Simple wrapper around `ParquetFileFragment` that allows to reinit the state of filesystem
    if aws session token has expired.
    """

    fragment: pa.dataset.ParquetFileFragment

    def __init__(self, fragment: pa.dataset.ParquetFileFragment):
        self.fragment = fragment

    def __repr__(self) -> str:
        out = ""
        out += "SafeFragment \n"
        out += "path = " + self.fragment.path + "\n"
        out += f"row_groups = {[int(rg.id) for rg in self.fragment.row_groups]} \n"
        out += f"physical_schema = \n {self.fragment.physical_schema} \n"
        return out

    @loading_retry
    def load(self, columns: Optional[List[str]] = None) -> pa.Table:
        """Materialize the fragment as a table.

        :param columns: subset of columns to read; columns absent from the
            fragment's physical schema are silently skipped (they may be
            partition keys, re-attached below).
        :return: the fragment's rows, plus partition values and trace columns.
        """
        if columns is not None:
            fragment_columns = [
                col for col in columns if col in self.fragment.physical_schema.names
            ]
        else:
            fragment_columns = self.fragment.physical_schema.names
        # adding technical columns for tracking
        fragment_columns = list(fragment_columns) + [
            "__batch_index",
            "__fragment_index",
            "__filename",
        ]
        try:
            fragment_table = self.fragment.to_table(
                columns=fragment_columns, use_threads=False
            )

        except OSError as e:
            # Fix: pass the exception as a lazy %-argument; the previous call
            # handed `str(e)` to logging as a stray positional arg with no
            # placeholder, which logging cannot format.
            logger.info(
                "could not load fragment, reinit the fragment state. Error: %s", e
            )
            # pickle round-trip rebuilds the fragment, which re-creates the
            # underlying filesystem handle (e.g. with refreshed credentials)
            self.fragment = loads(dumps(self.fragment))
            fragment_table = self.fragment.to_table(
                columns=fragment_columns, use_threads=False
            )

        fragment_table = add_partitioning_values(fragment_table, self.fragment, columns)
        fragment_table = add_fragments_trace(fragment_table, self.fragment)
        return fragment_table
296
+
297
+
298
def _parquet_fragments_to_pipeline_builder(
    file_ds_fragments: List[pa.dataset.Fragment],
    nb_epochs: int = 1,
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> DataPipelineBuilder:
    """Build a pipeline of `SafeFragment`s covering `nb_epochs` passes over
    `file_ds_fragments`.

    With shuffling, each epoch receives an independent permutation drawn from
    one RNG seeded with `seed` (a random seed is drawn when not provided).
    """
    if shuffle:
        if seed is None:
            seed = int(torch.randint(0, 2**31, ()).item())

        rng = np.random.RandomState(seed)
        fragments_arr = np.asarray(file_ds_fragments, dtype="O")
        per_epoch = [rng.permutation(fragments_arr) for _ in range(nb_epochs)]
        ds_fragments = np.concatenate(per_epoch).tolist()
    else:
        ds_fragments = file_ds_fragments * nb_epochs

    return read_sequence(ds_fragments).map(SafeFragment)
319
+
320
+
321
def list_parquet_fragments(
    parquet_ds: pq.ParquetDataset,
    nb_epochs: int = 1,
    split_to_row_groups: bool = True,
    shuffle: bool = True,
    seed: Optional[int] = None,
    limit_options: Optional[ParquetDatasetLimitOptions] = None,
    nb_jobs: int = 10,
) -> DataPipelineBuilder:
    """Enumerate the dataset's fragments (optionally split down to row
    groups), apply the `limit_options` truncations (fraction of files, number
    of files / fragments / rows), and return a pipeline builder over
    `SafeFragment`s repeated for `nb_epochs`.

    NOTE(review): when splitting, per-fragment row counts are taken from
    `frag.row_groups[0].num_rows`, so the `nb_rows` limit is approximate —
    confirm this is intended.
    """
    if limit_options is None:
        limit_options = ParquetDatasetLimitOptions()

    file_ds_fragments = get_dataset_fragments(parquet_ds, parquet_ds._filter_expression)
    # dataset root used for log messages: strip the partition suffix from the
    # first file's path
    proxy_ds_path = "/".join(parquet_ds.files[0].split("=")[0].split("/")[:-1])

    logger.info(f"{proxy_ds_path} : full number of files {len(file_ds_fragments)}")
    if limit_options.fraction_of_files is not None:
        # keep at least one file
        file_ds_fragments = file_ds_fragments[
            : max(
                int(round(limit_options.fraction_of_files * len(file_ds_fragments))), 1
            )
        ]
        logger.info(
            f"{proxy_ds_path} : reducing number of files to {len(file_ds_fragments)} because of fraction_of_files={limit_options.fraction_of_files}"
        )
    if limit_options.nb_files is not None and limit_options.nb_files < len(
        file_ds_fragments
    ):
        file_ds_fragments = file_ds_fragments[: limit_options.nb_files]
        logger.info(
            f"{proxy_ds_path} : reducing number of files to {len(file_ds_fragments)} because of nb_files={limit_options.nb_files}"
        )

    output_fragments = []  # fragments retained after applying limits
    total_nb_rows = 0
    if split_to_row_groups:
        logger.info(f"{proxy_ds_path} : starting split in row groups")

        with Parallel(backend="threading", n_jobs=nb_jobs) as parallel:
            total_nb_fragments = 0
            early_stop = False

            # split files into row-group fragments in parallel, by batches
            for batch_of_files in batched(file_ds_fragments, 20 * nb_jobs):
                row_groups = parallel(
                    delayed(split_fragment_in_row_groups)(ff) for ff in batch_of_files
                )
                new_file_fragments = [x for y in row_groups for x in y]
                if limit_options.nb_rows is not None:
                    # row counts are only needed to enforce the nb_rows limit
                    new_file_fragments_stats = parallel(
                        delayed(lambda frag: frag.row_groups[0].num_rows)(ff)
                        for ff in new_file_fragments
                    )
                else:
                    new_file_fragments_stats = [0] * len(new_file_fragments)

                for nb_row, frag in zip(new_file_fragments_stats, new_file_fragments):
                    output_fragments.append(frag)
                    total_nb_rows += nb_row
                    total_nb_fragments += 1
                    if (
                        limit_options.nb_fragments is not None
                        and total_nb_fragments >= limit_options.nb_fragments
                    ):
                        early_stop = True
                        if limit_options.nb_rows is not None:
                            logger.info(
                                f"{proxy_ds_path} : nb_fragments limit {limit_options.nb_fragments} was reached with around {total_nb_rows} rows"
                            )
                        else:
                            logger.info(
                                f"{proxy_ds_path} : nb_fragments limit {limit_options.nb_fragments} was reached"
                            )
                        break
                    if (
                        limit_options.nb_rows is not None
                        and total_nb_rows >= limit_options.nb_rows
                    ):
                        early_stop = True
                        logger.info(
                            f"{proxy_ds_path} : nb_rows limit {limit_options.nb_rows} was reached with around {total_nb_fragments} fragments"
                        )
                        break
                if early_stop:
                    break
    else:
        # whole-file fragments; only nb_fragments / nb_rows limits apply
        for frag in file_ds_fragments[: limit_options.nb_fragments]:
            output_fragments.append(frag)
            if limit_options.nb_rows is not None:
                total_nb_rows += frag.count_rows()
                if total_nb_rows >= limit_options.nb_rows:
                    break

    logger.info(f"{proxy_ds_path} : finding fragments {len(output_fragments)}")

    return _parquet_fragments_to_pipeline_builder(
        output_fragments,
        nb_epochs=nb_epochs,
        shuffle=shuffle,
        seed=seed,
    )
421
+
422
+
423
def compute_length_splits(
    length_col: NDArray[np.int32],
    max_tokens: int,
    order_by_length: bool = True,
    drop_long_sample: bool = True,
) -> List[NDArray[np.int32]]:
    """Split the samples described by `length_col` into chunks whose padded
    size (chunk size * max element length in the chunk) is <= `max_tokens`.

    Args:
        length_col (np.ndarray): per-sample lengths.
        max_tokens (int): budget per chunk, counting padding up to the
            longest element of the chunk.
        order_by_length (bool): sort samples by length first (tighter packing).
        drop_long_sample (bool): drop samples longer than `max_tokens`;
            otherwise append them at the end as singleton chunks.

    Returns:
        List[np.ndarray]: splits that contain indices over the original length_col
    """
    argsort_ind = (
        np.argsort(length_col)
        if order_by_length
        else np.arange(len(length_col), dtype=np.int32)
    )

    sorted_length_col = length_col[argsort_ind]

    # samples that individually fit into the budget
    small_elements_masks = sorted_length_col <= max_tokens
    big_elements_inds = argsort_ind[~small_elements_masks]

    argsort_ind = argsort_ind[small_elements_masks]
    sorted_length_col = sorted_length_col[small_elements_masks]

    size = len(sorted_length_col)
    splits = []
    begin, end = 0, 0
    while end < size:
        begin = end
        # Fix: seed the running max with the first element of the *current*
        # chunk. The previous version read `sorted_length_col[begin]` before
        # updating `begin`, seeding with the previous chunk's first element
        # and over-estimating padding when the input is not length-sorted.
        current_max_len = sorted_length_col[begin]
        while end < size:
            current_max_len = max(current_max_len, sorted_length_col[end])
            if current_max_len * (end + 1 - begin) > max_tokens:
                splits.append(argsort_ind[begin:end])
                break
            end += 1
        else:
            # inner loop exhausted the input: flush the trailing chunk
            if begin < size:
                splits.append(argsort_ind[begin:])

    # adding big sample at the end one by one
    if not drop_long_sample and len(big_elements_inds):
        splits.extend(np.array_split(big_elements_inds, len(big_elements_inds)))

    return splits
476
+
477
+
478
def build_batching_loop_over_one_table(
    table: pa.Table,
    order_by_length: bool = False,
    length_column: Optional[List[str]] = None,
    batch_size: Optional[int] = None,
    max_tokens: Optional[int] = None,
    shuffle: bool = True,
    seed: Optional[int] = None,
    num_parallel_calls: int = 1,
) -> DataPipeline:
    """Build a pipeline that yields batches (sub-tables) of `table`.

    Exactly one of `batch_size` (fixed number of rows per batch) or
    `max_tokens` (padding-aware token budget; requires `length_column`) must
    be provided.

    Fixes: `length_column` was annotated `List[Optional[str]] = None` — the
    container itself is optional, so `Optional[List[str]]`; the deprecated
    `logger.warn` alias is replaced with `logger.warning`.
    """
    if max_tokens is not None:
        assert length_column is not None, (
            "Need to provide a column to compute the number of tokens"
        )

    random_state = np.random.RandomState(seed)
    if length_column is not None and len(length_column) > 0:
        # total per-row length, summed over all provided length columns
        length_col = reduce(
            np.add, (compute_rows_length(table[lc]) for lc in length_column)
        )
    else:
        if shuffle:
            # random pseudo-lengths -> random batch membership
            length_col = random_state.randint(0, 2**23, len(table))
        else:
            length_col = np.zeros(len(table), dtype=np.int32)

    if batch_size is not None:
        if order_by_length:
            sorting_ind = np.argsort(length_col, kind="stable")
        else:
            sorting_ind = np.arange(len(length_col), dtype=np.int32)

        order_tt = pa.Table.from_arrays([pa.array(sorting_ind)], ["order"])
        batches = [ind["order"] for ind in order_tt.to_batches(batch_size)]
    elif max_tokens is not None:
        batches = compute_length_splits(
            length_col, max_tokens, order_by_length=order_by_length
        )
    else:
        raise ValueError("unknown batching method")

    if shuffle:
        batches = [batches[i] for i in random_state.permutation(len(batches))]

    def _getter(ind):
        # best-effort take: a failing batch is logged and dropped downstream
        try:
            tt = table.take(ind)
            return tt
        except Exception as e:
            logger.warning(f"Unexpected error : \n {str(e)} \n {table} \n {ind}")
            return None

    return (
        read_sequence(batches)
        .map(_getter, num_parallel_calls=num_parallel_calls)
        .filter(lambda tt: bool(tt is not None))
        .and_return(max_num_warnings=4)
    )
536
+
537
+
538
def filter_long_short_sentence_document(
    batch: pa.Table,
    column: str,
    max_sentence_len: Optional[int],
    min_sentence_len: Optional[int],
) -> pa.Table:
    """Keep only the rows of `batch` whose longest sentence in `column` is
    <= `max_sentence_len` AND whose shortest sentence is >= `min_sentence_len`
    (sentence lengths measured in bytes).

    Fix: the second condition previously re-checked `max_sentence_len` against
    the per-row minimum, so `min_sentence_len` was never enforced.
    """
    assert max_sentence_len is not None or min_sentence_len is not None
    if min_sentence_len is None:
        min_sentence_len = 0

    if max_sentence_len is None:
        max_sentence_len = 2**32

    tt = pl.from_arrow(batch.select([column]), rechunk=False)
    assert isinstance(tt, pl.DataFrame)
    sentence_byte_lens = pl.col(column).list.eval(pl.col("").str.len_bytes())
    filter_ = tt.with_columns(
        (sentence_byte_lens.list.max() <= max_sentence_len)
        & (sentence_byte_lens.list.min() >= min_sentence_len)
    )[column].to_arrow()

    # avoid copying when nothing is filtered out
    if pc.all(filter_).as_py():
        return batch
    return batch.filter(filter_)
567
+
568
+
569
def filter_document_by_quality(
    batch: pa.Table,
    column: str,
    min_score: Optional[float] = None,
    max_score: Optional[float] = None,
) -> pa.Table:
    """Keep only the rows of `batch` whose per-element scores in `column`
    all lie within [`min_score`, `max_score`].

    Fix: the previous signature used `min_score=Optional[float]`, making the
    *typing object* the default value (`=` instead of `: ... = None`), so the
    `is None` checks below could never trigger for omitted arguments.
    """
    if min_score is None and max_score is None:
        return batch

    if min_score is None:
        min_score = -float(np.inf)
    if max_score is None:
        max_score = float(np.inf)

    tt = pl.from_arrow(batch.select([column]), rechunk=False)
    assert isinstance(tt, pl.DataFrame)
    filter_ = tt.with_columns(
        (pl.col(column).list.max() <= max_score)
        & (pl.col(column).list.min() >= min_score)
    )[column].to_arrow()
    # avoid copying when nothing is filtered out
    if pc.all(filter_).as_py():
        return batch
    return batch.filter(filter_)
592
+
593
+
594
def renaming(inp: NestedDict, mapper: dict, name: str) -> NestedDict:
    """Rename the keys/columns of `inp` through `mapper` and tag the result
    with the dataset name under `ColumnsNames.dataset_name.value`.

    Supports dict, pandas DataFrame and pyarrow Table inputs.

    Raises:
        TypeError: for unsupported input types (the previous version crashed
            with an UnboundLocalError on `res` instead).
    """
    renamed_name = ColumnsNames.dataset_name.value
    if isinstance(inp, dict):
        out_dict = {mapper.get(key, key): value for key, value in inp.items()}
        out_dict[renamed_name] = name
        return out_dict
    if isinstance(inp, pd.DataFrame):
        out_pd = inp.rename(mapper=mapper, axis=1)
        out_pd[renamed_name] = name
        return out_pd
    if isinstance(inp, pa.Table):
        out_pa: pa.Table = inp.rename_columns(
            [mapper.get(key, key) for key in inp.column_names],
        )
        return out_pa.append_column(renamed_name, pa.array([name] * len(out_pa)))
    raise TypeError(f"renaming does not support inputs of type {type(inp).__name__}")
611
+
612
+
613
def materialize_sequence(
    table: pa.Table,
    column_sequence: List[SonarTextColumn],
    vector_name: str,
    text_name: str,
) -> pa.Table:
    """
    Given `table`, it materializes `column_sequence`.
    Different elements from `column_sequence` are concatenated sequentially.
    Constant text elements will be sentencized and sonarized.
    It also accepts columns with single text and embeddings values instead of list.

    It returns a new table with two new columns with sequences of sentences and corresponding sequences of their embeddings.

    :param table: input batch.
    :param column_sequence: ordered parts to concatenate; each part is either
        a constant text (`text_value`) or a (text_column, sonar_column) pair.
    :param vector_name: name of the appended embeddings-sequence column.
    :param text_name: name of the appended sentences-sequence column.
    """

    table_len = len(table)
    sentences_seq = []  # per-part list arrays of sentences
    vectors_seq = []  # per-part list arrays of embeddings

    # use the first real sonar column's dtype as the common embedding dtype
    target_dtype = None
    for col in column_sequence:
        if col.sonar_column is not None:
            target_dtype = table[col.sonar_column].type
            break

    for col in column_sequence:
        if col.text_value is not None:
            # constant text: embed once (memoized) and repeat for every row
            vectors, sentences = _get_embed_sentences(col.text_value)
            vectors_extended = pa.chunked_array(
                [pa.ListArray.from_arrays([0, len(vectors)], vectors)] * table_len
            )
            sentences_extended = pa.chunked_array(
                [pa.ListArray.from_arrays([0, len(sentences)], sentences)] * table_len
            )
        else:
            assert (col.text_column is not None) and (col.sonar_column is not None)
            vectors_extended = table[col.sonar_column]
            sentences_extended = table[col.text_column]
            if is_list_like(vectors_extended):
                assert is_list_like(sentences_extended)
            else:
                # scalar columns: wrap each value into a length-1 list
                vectors_extended = simple_array_to_nested(vectors_extended)
                sentences_extended = simple_array_to_nested(sentences_extended)

            if target_dtype and vectors_extended.type != target_dtype:
                vectors_extended = vectors_extended.cast(target_dtype)

        vectors_seq.append(vectors_extended)
        sentences_seq.append(sentences_extended)

    # row-wise concatenation of all parts
    new_vectors_array = hstack_pyarray_list(*vectors_seq)
    new_sentences_array = hstack_pyarray_list(*sentences_seq)
    del vectors_seq, sentences_seq
    table = table.append_column(vector_name, new_vectors_array)
    table = table.append_column(text_name, new_sentences_array)
    return table
669
+
670
+
671
@njit
def _get_hierarchical_indices_and_offsets(
    pagaraphs_lengths: List[np.ndarray], max_seq_len: int
):
    # Greedily pack each row's consecutive paragraph lengths into chunks whose
    # total element count stays <= max_seq_len.
    # Returns:
    #   indices            - source-row index of each output chunk
    #   new_lens           - cumulative element offsets delimiting the chunks
    #   hierarchy_new_lens - cumulative paragraph counts delimiting the chunks
    # NOTE(review): if a row's first paragraph alone exceeds max_seq_len, an
    # empty leading chunk is emitted (flush happens with tmp_lens_sum == 0) —
    # confirm upstream guards against over-long paragraphs.
    indices = []
    new_lens = [0]
    hierarchy_new_lens = [0]

    for i, current_lens in enumerate(pagaraphs_lengths):
        tmp_lens_sum = 0  # elements accumulated in the open chunk
        nb_blocks = 0  # paragraphs accumulated in the open chunk
        for ll in current_lens:
            if ll + tmp_lens_sum > max_seq_len:
                # flush the open chunk, start a new one with this paragraph
                indices.append(i)
                new_lens.append(new_lens[-1] + tmp_lens_sum)
                hierarchy_new_lens.append(hierarchy_new_lens[-1] + nb_blocks)

                tmp_lens_sum = ll
                nb_blocks = 0
            else:
                tmp_lens_sum += ll

            nb_blocks += 1

        if nb_blocks > 0:
            # flush the trailing chunk of this row
            indices.append(i)
            new_lens.append(new_lens[-1] + tmp_lens_sum)
            hierarchy_new_lens.append(hierarchy_new_lens[-1] + nb_blocks)

    return (
        np.array(indices, dtype=np.int32),
        np.array(new_lens, dtype=np.int32),
        np.array(hierarchy_new_lens, dtype=np.int32),
    )
705
+
706
+
707
def hierarchical_explode_table_with_max_length(
    table: pa.Table,
    columns: Union[str, List[str]],
    max_seq_len: int,
    page_len_column: str,
    page_embs_columns: Optional[Union[str, List[str]]],
) -> pa.Table:
    """Explode rows into chunks of at most `max_seq_len` elements, cutting
    only at paragraph boundaries (given by `page_len_column`).

    :param table: input batch.
    :param columns: parallel list columns to re-chunk element-wise.
    :param max_seq_len: maximum number of elements per output row.
    :param page_len_column: list column of per-paragraph element counts.
    :param page_embs_columns: paragraph-level columns re-chunked at
        paragraph granularity (same offsets as `page_len_column`).
    """
    if isinstance(columns, str):
        columns = [columns]

    if isinstance(page_embs_columns, str):
        page_embs_columns = [page_embs_columns]
    elif page_embs_columns is None:
        page_embs_columns = []

    assert len(columns) > 0

    cols = [pc.fill_null(table[columns[0]], [None])]
    lengths = pc.list_value_length(cols[0]).to_numpy()

    for name in columns[1:]:
        col = pc.fill_null(table[name], [None])
        # checking that all columns list structures are parallel
        assert (lengths == pc.list_value_length(col).to_numpy()).all()
        cols.append(col)

    pagaraphs_lengths = table[page_len_column].to_pandas().to_list()
    # assert [x.sum() for x in pagaraphs_lengths] == lengths.tolist()
    # next unroll with max_seq_len
    indices, new_offests, hierarchy_offsets = _get_hierarchical_indices_and_offsets(
        pagaraphs_lengths, max_seq_len
    )

    # columns that are neither re-chunked element-wise nor paragraph-wise are
    # simply duplicated per output chunk
    other_columns = list(table.schema.names)
    for name in set(columns + [page_len_column] + page_embs_columns):
        other_columns.remove(name)

    remaining_table = table.select(other_columns).take(indices)

    result_dict = {}
    for name in other_columns:
        result_dict[name] = remaining_table[name]

    # element-level columns: re-slice flattened values with the new offsets
    for name, col in zip(columns, cols):
        rolled_array = pa.ListArray.from_arrays(
            offsets=new_offests,
            values=pyarrow_column_to_array(pc.list_flatten(col)),
        )
        result_dict[name] = rolled_array

    # paragraph-level columns: re-slice with the paragraph-count offsets
    for name in set([page_len_column] + page_embs_columns):
        col = table[name]
        rolled_array = pa.ListArray.from_arrays(
            offsets=hierarchy_offsets,
            values=pyarrow_column_to_array(pc.list_flatten(col)),
        )
        result_dict[name] = rolled_array

    return pa.Table.from_pydict(result_dict, schema=table.schema)
766
+
767
+
768
def filter_table_with_different_lengths(
    table: pa.Table, columns: List[str]
) -> pa.Table:
    """Drop rows where the list columns in `columns` do not all have the same
    per-row lengths (e.g. number of sentences vs. number of sonar vectors).

    Fixes: the reference lengths were not recomputed after an intermediate
    `table.filter(...)`, so a second mismatching column compared arrays of
    different sizes; also replaces the deprecated `logger.warn` and restores
    the missing space in the log message.
    """
    if len(columns) <= 1 or not all(is_list_like(table[col]) for col in columns):
        return table

    ref_lengths = pc.list_value_length(table[columns[0]])
    for col in columns[1:]:
        same_lens = pc.equal(pc.list_value_length(table[col]), ref_lengths)
        if pc.all(same_lens).as_py():
            continue
        logger.warning(
            "filtering table whose nb sentences and nb sonar vectors are aligned, "
            "keeping %s rows out of %s",
            pc.sum(same_lens).as_py(),
            len(table),
        )
        table = table.filter(same_lens)
        # reference lengths must track the filtered table for the next columns
        ref_lengths = pc.list_value_length(table[columns[0]])
    return table
785
+
786
+
787
@dataclass
class PFSState:
    """Resumable read-position state for `ParquetFragmentStreamer`."""

    # number of files whose fragments have all been yielded
    nb_fully_read_files: int = 0
    # number of fragments already yielded from the file currently being read
    nb_current_file_read_fragements: int = 0
    # running totals across the whole stream, used to enforce limits
    total_nb_fragments: int = 0
    total_nb_rows: int = 0
794
+
795
class ParquetFragmentStreamer:
    """Lazily iterate over a dataset's fragments (optionally split to row
    groups), enforcing `ParquetDatasetLimitOptions` and keeping a resumable
    `PFSState` so iteration can be pickled and restarted mid-stream."""

    def __init__(
        self,
        parquet_ds: pq.ParquetDataset,
        split_to_row_groups: bool = True,
        limit_options: Optional[ParquetDatasetLimitOptions] = None,
        read_state: Optional[PFSState] = None,
    ):
        self.split_to_row_groups = split_to_row_groups
        self.limit_options = limit_options or ParquetDatasetLimitOptions()
        self.parquet_ds = parquet_ds

        if read_state is not None:
            # resume from a previously saved position
            self.state = read_state
        else:
            self.reset_state()

    def reset_state(self):
        # restart iteration from the beginning of the dataset
        self.state = PFSState()

    def __reduce__(self):
        # pickling preserves the current read position via `self.state`
        return (
            self.__class__,
            (
                self.parquet_ds,
                self.split_to_row_groups,
                self.limit_options,
                self.state,
            ),
        )

    def truncate_files(
        self,
        parquet_ds: pq.ParquetDataset,
        fraction_of_files: Optional[float],
        nb_files: Optional[int],
    ) -> List[pa.dataset.Fragment]:
        """Apply the file-level limits and return the retained fragments."""
        file_ds_fragments = get_dataset_fragments(
            parquet_ds, parquet_ds._filter_expression
        )
        # dataset root (strip the partition suffix), used for log messages
        self.proxy_ds_path = "/".join(parquet_ds.files[0].split("=")[0].split("/")[:-1])
        logger.info(
            f"{self.proxy_ds_path} : full number of files {len(file_ds_fragments)}"
        )

        if fraction_of_files is not None:
            # keep at least one file
            file_ds_fragments = file_ds_fragments[
                : max(
                    int(round(fraction_of_files * len(file_ds_fragments))),
                    1,
                )
            ]
            logger.info(
                f"{self.proxy_ds_path} : reducing number of files to {len(file_ds_fragments)} because of fraction_of_files={fraction_of_files}"
            )
        if nb_files is not None and nb_files < len(file_ds_fragments):
            file_ds_fragments = file_ds_fragments[:nb_files]
            logger.info(
                f"{self.proxy_ds_path} : reducing number of files to {len(file_ds_fragments)} because of nb_files={nb_files}"
            )
        return file_ds_fragments

    def __iter__(self):
        limit_options = self.limit_options

        file_ds_fragments = self.truncate_files(
            self.parquet_ds,
            limit_options.fraction_of_files,
            limit_options.nb_files,
        )

        if not self.split_to_row_groups:
            # whole-file fragments; resume after the already-read files
            for frag in file_ds_fragments[
                self.state.nb_fully_read_files : limit_options.nb_fragments
            ]:
                self.state.nb_fully_read_files += 1
                yield frag

                if limit_options.nb_rows is not None:
                    self.state.total_nb_rows += frag.count_rows()
                    if self.state.total_nb_rows >= limit_options.nb_rows:
                        break
        else:
            early_stop = False
            logger.info(f"{self.proxy_ds_path} : starting split in row groups")

            for new_file in file_ds_fragments[self.state.nb_fully_read_files :]:
                new_file_fragments = split_fragment_in_row_groups(new_file)
                # skip row groups already yielded from this file on resume
                new_file_fragments = new_file_fragments[
                    self.state.nb_current_file_read_fragements :
                ]
                if limit_options.nb_rows is not None:
                    # NOTE(review): uses row_groups[0].num_rows per fragment,
                    # so the nb_rows limit is approximate — confirm intended
                    new_file_fragments_stats = [
                        frag.row_groups[0].num_rows for frag in new_file_fragments
                    ]
                else:
                    new_file_fragments_stats = [0] * len(new_file_fragments)

                for nb_row, frag in zip(new_file_fragments_stats, new_file_fragments):
                    self.state.total_nb_rows += nb_row
                    self.state.total_nb_fragments += 1
                    self.state.nb_current_file_read_fragements += (
                        1  # increment before yield
                    )
                    yield frag

                    if (
                        limit_options.nb_fragments is not None
                        and self.state.total_nb_fragments >= limit_options.nb_fragments
                    ):
                        early_stop = True
                        if limit_options.nb_rows is not None:
                            logger.info(
                                f"{self.proxy_ds_path} : nb_fragments limit {limit_options.nb_fragments} was reached with around {self.state.total_nb_rows} rows"
                            )
                        else:
                            logger.info(
                                f"{self.proxy_ds_path} : nb_fragments limit {limit_options.nb_fragments} was reached"
                            )
                        break
                    if (
                        limit_options.nb_rows is not None
                        and self.state.total_nb_rows >= limit_options.nb_rows
                    ):
                        early_stop = True
                        logger.info(
                            f"{self.proxy_ds_path} : nb_rows limit {limit_options.nb_rows} was reached with around {self.state.total_nb_fragments} fragments"
                        )
                        break
                if early_stop:
                    break
                # only when the full file is read do we increment this
                self.state.nb_fully_read_files += 1
                self.state.nb_current_file_read_fragements = 0
929
+
930
+
931
@dataclass
class ShuffledIteratorState:
    """Resumable state for `ShuffledIterator` (pickled via __reduce__)."""

    # number of completed passes over the base iterator
    epoch_count: int
    # shuffled window currently being consumed
    current_window: List[Any]
    # next position to read within `current_window`
    index: int
    # RNG driving the window shuffles
    random_state: np.random.RandomState
937
+
938
+
939
class ShuffledIterator(Iterator[Any]):
    """Window-shuffling wrapper over a resettable iterator.

    Reads `window_size` items at a time, shuffles each window with a seeded
    RNG, and replays the base iterator for `nb_epoch` passes. The base
    iterator must expose `reset_state()`. State is picklable for resumption.
    """

    def __init__(
        self,
        iterator,
        window_size: int,
        nb_epoch: int,
        seed: Optional[int],
        state: Optional[ShuffledIteratorState] = None,
    ):
        self.base_iterator = iterator
        self.window_size = window_size
        self.seed = seed
        self.nb_epoch = nb_epoch

        if state is None:
            state = ShuffledIteratorState(
                random_state=np.random.RandomState(self.seed),
                epoch_count=0,
                current_window=[],
                index=0,
            )
        self.state = state
        self.window_iterator = None

    def reset_state(self):
        # full restart: reseed the RNG and rewind the base iterator
        self.state.random_state = np.random.RandomState(self.seed)
        self.state.epoch_count = 0
        self._reset_inner()

    def __reduce__(self):
        # pickling captures the current window/epoch position via `self.state`
        return (
            self.__class__,
            (
                self.base_iterator,
                self.window_size,
                self.nb_epoch,
                self.seed,
                self.state,
            ),
        )

    def _reset_inner(self):
        # rewind the base iterator for the next epoch (keeps the RNG state)
        self.base_iterator.reset_state()
        self.state.index = 0
        self.state.current_window = []
        self.window_iterator = None

    def __iter__(self):
        return self

    def __next__(self) -> Any:
        if self.state.epoch_count >= self.nb_epoch:
            raise StopIteration

        # If current window is exhausted, fetch the next window
        if self.window_iterator is None:
            self.window_iterator = batched(self.base_iterator, self.window_size)  # type: ignore
        assert self.window_iterator is not None

        if self.state.index >= len(self.state.current_window):
            try:
                # Get the next window batch
                window = next(self.window_iterator)
                window = np.array(window, dtype="O")
                self.state.random_state.shuffle(window)
                self.state.current_window = window
                self.state.index = 0
            except StopIteration:
                # If no more batches, increment epoch count and reset iterator
                self.state.epoch_count += 1
                self._reset_inner()
                return self.__next__()

        # Return the next element from the current window
        result = self.state.current_window[self.state.index]
        self.state.index += 1
        return result
1016
+
1017
+
1018
def stream_parquet_fragments(
    parquet_ds: pq.ParquetDataset,
    nb_epochs: int,
    split_to_row_groups: bool = True,
    shuffle: bool = True,
    seed: Optional[int] = None,
    limit_options: Optional[ParquetDatasetLimitOptions] = None,
    shuffling_window: int = 200,
) -> DataPipelineBuilder:
    """Build a pipeline streaming dataset fragments for ``nb_epochs`` passes.

    When ``shuffle`` is enabled, fragments are shuffled within a sliding
    window of ``shuffling_window`` items (a window of 1 is equivalent to no
    shuffling).  Each yielded fragment is wrapped in a ``SafeFragment``.
    """
    streamer = ParquetFragmentStreamer(
        parquet_ds=parquet_ds,
        split_to_row_groups=split_to_row_groups,
        limit_options=limit_options,
    )

    def _restart(it):
        # read_iterator calls this to rewind the stream between uses.
        it.reset_state()
        return it

    shuffled = ShuffledIterator(
        streamer,
        window_size=shuffling_window if shuffle else 1,
        nb_epoch=nb_epochs,
        seed=seed,
    )
    builder = read_iterator(shuffled, _restart, infinite=False)

    return builder.map(SafeFragment)
1049
+
1050
+
1051
def get_row_group_level_metadata(
    dataset: pq.ParquetDataset,
    columns: Optional[List[str]] = None,
    nb_jobs: int = 40,
    max_fragments: int = -1,
    seed: int = 123,
) -> pd.DataFrame:
    """
    Parse row-group level metadata of a Parquet dataset into a pandas DataFrame.

    Similar to ``get_parquet_dataset_metadata`` but presents an unnested view
    on row-group statistics for a subset of columns, which is convenient for
    downstream analysis.  Fragments are processed in parallel (joblib threads)
    with a tqdm progress bar.

    Parameters:
    - dataset (pq.ParquetDataset): The Parquet dataset to parse.
    - columns (list of str, optional): Columns to include. If None, all columns
      are included; with ``columns=[]`` no column-wise information is produced
      (generally much faster).
    - nb_jobs (int, default=40): Number of parallel jobs.
    - max_fragments (int, default=-1): Maximum number of fragments to include
      (-1 means all), sampled deterministically with ``seed``.
    - seed (int, default=123): RNG seed used when subsampling fragments.

    Returns:
    - pd.DataFrame: one row per row group, with fragment path, partition keys,
      ``row_group_id``, ``num_rows``, ``total_byte_size`` and per-column stats.
    """
    assert max_fragments >= -1
    fragments = list(dataset._dataset.get_fragments(filter=dataset._filter_expression))

    # Optionally subsample the fragments deterministically.
    if max_fragments != -1 and max_fragments < len(fragments):
        rng = np.random.RandomState(seed)
        fragments = rng.choice(
            np.array(fragments, dtype="O"), max_fragments, replace=False
        ).tolist()

    physical_schema = fragments[0].physical_schema

    if columns is None:
        columns = physical_schema.names
    # Keep only columns that actually exist in the physical schema.
    missing_columns = tuple(set(columns) - set(physical_schema.names))
    if missing_columns:
        print(
            "Following colums are not present in physical schema and will be ignored",
            missing_columns,
        )
    columns = [col for col in columns if col in physical_schema.names]

    columns_index = [physical_schema.get_field_index(col) for col in columns]

    # These output keys are reserved; a requested column must not shadow them.
    reserved = set(["row_group_id", "num_rows", "total_byte_size"]) & set(columns)
    assert len(reserved) == 0, f"names conflict, rename/remove : {reserved}"

    def _row_group_info(row_group):
        meta = row_group.metadata
        info = {
            "row_group_id": row_group.id,
            "num_rows": meta.num_rows,
            "total_byte_size": meta.total_byte_size,
        }
        for col, ind in zip(columns, columns_index):
            info[col] = meta.column(ind).to_dict()
        return info

    def _fragment_info(frag):
        return {
            "rg_stats": [_row_group_info(rg) for rg in frag.row_groups],
            "parquet_file_path": frag.path,
            **get_partition_keys(frag.partition_expression),
        }

    stats = Parallel(nb_jobs, backend="threading")(
        delayed(_fragment_info)(frag) for frag in tqdm(fragments)
    )

    # Explode per-fragment row-group lists into one row per row group.
    frame = pd.DataFrame(stats).explode("rg_stats")
    per_rg = pd.DataFrame(frame.pop("rg_stats").tolist(), index=frame.index)
    return pd.concat([frame, per_rg], axis=1)
lcm/datasets/sentence_splitter_pipeline.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import gc
9
+ import typing as tp
10
+ from builtins import enumerate
11
+ from dataclasses import dataclass, field
12
+
13
+ import numba
14
+ import numpy as np
15
+ import polars as pl
16
+ import pyarrow as pa
17
+ import pyarrow.compute as pc
18
+ import torch
19
+ from stopes.modules.partitioned_data_mapper import BatchMapper
20
+ from stopes.modules.preprocess.sonar_text_embedding import (
21
+ SonarTextBatchEmbedder,
22
+ SonarTextEmbedderConfig,
23
+ )
24
+ from stopes.utils.arrow_utils import (
25
+ apply_on_nested_array,
26
+ )
27
+ from wtpsplit import SaT, indices_to_sentences
28
+
29
+ from lcm.datasets.sentence_splitting import remove_emojis, resplit
30
+
31
+
32
@numba.jit(nopython=True)
def insert_elements(arr, max_diff):
    """
    Densify a sorted integer array so that no two consecutive elements differ
    by more than ``max_diff``.

    Parameters:
        arr (numpy array): The original (sorted) array of integers.
        max_diff (int): Maximum allowed difference between consecutive
            elements after insertion.

    Returns:
        numpy array: int32 array with extra points inserted, evenly spread
        across each oversized gap.
    """
    out = []
    for i in range(len(arr) - 1):
        out.append(arr[i])
        gap = arr[i + 1] - arr[i]
        if gap > max_diff:
            # Spread the inserted points evenly across the gap.
            n_extra = int(gap // max_diff)
            step = gap / (n_extra + 1)
            prev = arr[i]
            for _ in range(1, n_extra + 1):
                candidate = round(prev + step)
                if candidate < arr[i + 1]:
                    out.append(candidate)
                    prev = candidate
    out.append(arr[-1])
    return np.array(out, dtype=np.int32)
60
+
61
+
62
@numba.jit(nopython=True)
def merge_small_intervals(
    lenghts: np.ndarray, min_merging_length: int = 2, max_merge_length: int = 15
):
    """
    Merge runs of small intervals in a list of lengths.

    Intervals of at most ``min_merging_length`` are accumulated into a larger
    merged interval, capped at ``max_merge_length``.

    Parameters:
        lenghts (np.ndarray): lengths to be merged.
        min_merging_length (int): maximum length of an interval eligible for
            merging. Defaults to 2.
        max_merge_length (int): maximum length of a merged interval.
            Defaults to 15.
    Returns:
        np.ndarray: int32 array of merged lengths.

    Examples:
        >>> merge_small_intervals(np.array([1, 2, 3, 4, 5]))
        array([3, 3, 4, 5], dtype=int32)
        >>> merge_small_intervals(np.array([1, 1, 1, 1, 1]))
        array([5], dtype=int32)
        >>> merge_small_intervals(np.array([1, 2, 3, 2, 2, 2, 4, 1, 1, 5]))
        array([3, 3, 6, 4, 2, 5], dtype=int32)
    """
    merged = []
    pending = 0

    for curr_len in lenghts:
        if curr_len <= min_merging_length and pending + curr_len <= max_merge_length:
            # Accumulate this short interval into the pending merge.
            pending += curr_len
        else:
            # Flush the pending merge (if any) before the long interval.
            if pending > 0:
                merged.append(pending)
                pending = 0
            merged.append(curr_len)
    if pending > 0:
        merged.append(pending)

    return np.array(merged, dtype=np.int32)
100
+
101
+
102
@numba.jit(nopython=True)
def find_closest_indices(arr1, arr2):
    """
    Find indices of the closest elements in arr2 for each element in arr1.

    Parameters:
        arr1 (numpy array): elements for which to find the closest match.
        arr2 (numpy array): sorted array searched for the closest elements.

    Returns:
        indices (numpy array): index into arr2 of the closest element for
        each element of arr1 (ties resolved towards the smaller index).
    """
    # searchsorted returns positions in [0, len(arr2)]; a value greater than
    # every element of arr2 yields len(arr2), which would make arr2[indices]
    # go out of bounds. Clipping to the last valid index fixes that while
    # keeping in-range results unchanged (the last element IS the closest
    # one for such values).
    indices = np.searchsorted(arr2, arr1, side="left")
    indices = np.clip(indices, a_min=0, a_max=len(arr2) - 1)

    indices_bis = np.clip(indices - 1, a_min=0, a_max=len(arr2) - 1)
    dist_one = np.abs(arr2[indices] - arr1)
    dist_bis = np.abs(arr2[indices_bis] - arr1)

    return np.where(dist_one < dist_bis, indices, indices_bis)
122
+
123
+
124
@dataclass
class SentenceSplitterConfig:
    """Configuration for the SaT-based ``SentenceSplitter``."""

    columns: tp.List[str]  # input text columns to split
    model_name: str = "sat-6l"  # wtpsplit SaT checkpoint name
    sentence_suffix: str = "_sentences"  # suffix appended to output column names
    sentence_threshold: float = 0.01  # SaT split-probability threshold
    max_sentence_len: int = 256  # sentences longer than this (chars) get re-split
    min_text_length: int = 10  # texts at most this long are kept unsplit
    min_unique_chars: int = 0  # drop sentences with <= this many unique chars (0 = off)
    fallback_separators: tp.List[str] = field(
        default_factory=lambda: [
            "...",
            "\n",
            "!",
            "?",
            ";",
            ":",
            ".",
            ",",
            "\t",
            " ",
        ]
    )
    device: str = "cuda"
    remove_whitespace_before_inference: bool = False
    batch_size: int = 256
    block_size: int = 256
    stride: int = 256
    outer_batch_size: int = 1024
    verbose: bool = False
    pad_last_batch: bool = False
155
+
156
+
157
class SentenceSplitter(BatchMapper):
    """Split text columns of a pyarrow Table into lists of sentences.

    Uses a wtpsplit SaT model for the initial split, then re-splits any
    sentence longer than ``max_sentence_len`` — first with model
    probabilities, then with hard fallback separators.
    """

    def __init__(self, config: SentenceSplitterConfig):
        # NOTE(review): self.config is read below but only assigned by the
        # BatchMapper base class in super().__init__ — confirm upstream.
        super().__init__(config)
        self.columns = config.columns
        device = torch.device(config.device if torch.cuda.is_available() else "cpu")

        try:
            # Prefer a locally cached checkpoint; fall back to downloading.
            self.model = SaT(
                self.config.model_name,
                from_pretrained_kwargs={"local_files_only": True},
            )
        except Exception:
            self.model = SaT(self.config.model_name)

        if "cuda" in config.device:
            self.model.half()

        self.model.eval().to(device)

    @torch.inference_mode()
    def _resplit_long_sentences(self, col: pa.Array) -> pa.Array:
        """Re-split sentences >= max_sentence_len chars; wrap others as [text]."""
        mask = pc.greater_equal(pc.utf8_length(col), self.config.max_sentence_len)
        texts_to_resplit = col.filter(mask).to_pandas().to_list()

        resplit_sentences = []
        for text, probs in zip(
            texts_to_resplit,
            self.model.predict_proba(
                texts_to_resplit,
                stride=self.config.stride,
                block_size=self.config.block_size,
                batch_size=self.config.batch_size,
                pad_last_batch=self.config.pad_last_batch,
                remove_whitespace_before_inference=self.config.remove_whitespace_before_inference,
                outer_batch_size=self.config.outer_batch_size,
                verbose=self.config.verbose,
            ),
        ):
            # Pick a per-text threshold yielding roughly
            # len(text) / max_sentence_len split points.
            nb_split = round(len(probs) / self.config.max_sentence_len) + 1
            sentence_threshold = np.partition(probs, -nb_split)[-nb_split]
            sentences = indices_to_sentences(
                text,
                np.where(probs >= sentence_threshold)[0],
                strip_whitespace=False,
            )
            resplit_sentences.append(sentences)

        # Hard fallback: force-split anything still too long on separators.
        def _resplit(raw_sentences):
            for separator in self.config.fallback_separators:
                raw_sentences = [
                    subchunk.strip()
                    for sent in raw_sentences
                    for subchunk in resplit(
                        sent, max_length=self.config.max_sentence_len, sep=separator
                    )
                ]
            return raw_sentences

        np_mask = mask.to_pandas().to_numpy()
        full_text = col.to_pandas().to_list()

        # Re-interleave re-split texts with the untouched (short) ones.
        output_sentences = []
        j = 0
        for i, text in enumerate(full_text):
            if np_mask[i]:
                output_sentences.append(_resplit(resplit_sentences[j]))
                j += 1
            else:
                output_sentences.append([text])

        return pa.array(output_sentences, type=pa.list_(pa.string()))

    def resplit_long_sentences(self, col: pa.Array) -> pa.Array:
        """Apply _resplit_long_sentences within nested lists, then flatten back."""
        list_col = apply_on_nested_array(self._resplit_long_sentences, col)
        reflatten_col = pl.from_arrow(list_col).list.eval(pl.element().explode())  # type: ignore
        # Optionally drop near-degenerate sentences (e.g. one char repeated).
        if self.config.min_unique_chars > 0:
            reflatten_col = reflatten_col.list.eval(
                pl.when(
                    pl.element().str.split("").list.n_unique()
                    > self.config.min_unique_chars
                )
                .then(pl.element())
                .drop_nulls()
            )
        return reflatten_col.to_arrow().cast(pa.list_(pa.string()))

    @torch.inference_mode()
    def basic_split_on_single_column(
        self,
        col: tp.Union[pa.Array, pa.ChunkedArray],
    ) -> tp.Union[pa.Array, pa.ChunkedArray]:
        """Split each string of *col* into a list of sentences with SaT."""
        if not (pa.types.is_large_string(col.type) or pa.types.is_string(col.type)):
            raise ValueError("Column must be of type string")

        texts = col.to_pandas().to_list()
        texts = list(map(remove_emojis, texts))

        # Only run the model on texts long enough to be worth splitting.
        long_texts = [t for t in texts if len(t) > self.config.min_text_length]
        keep_texts = [
            (idx, t)
            for idx, t in enumerate(texts)
            if len(t) <= self.config.min_text_length
        ]

        outputs = self.model.split(
            long_texts,
            threshold=self.config.sentence_threshold,
            stride=self.config.stride,
            block_size=self.config.block_size,
            batch_size=self.config.batch_size,
            pad_last_batch=self.config.pad_last_batch,
            remove_whitespace_before_inference=self.config.remove_whitespace_before_inference,
            outer_batch_size=self.config.outer_batch_size,
            verbose=self.config.verbose,
        )
        sentences = []
        for row in outputs:
            sentences.append([s.strip() for s in row if s.strip()])

        # Re-insert the short texts at their original positions.
        # BUGFIX: wrap the text in a list — every element of `sentences` must
        # be a list[str] to match pa.list_(pa.string()) below; the original
        # inserted a bare str.
        for idx, text in keep_texts:
            sentences.insert(idx, [text])

        return pa.array(sentences, type=pa.list_(pa.string()))

    def __call__(self, table: pa.Table) -> pa.Table:
        """Append a ``<col><sentence_suffix>`` list column for each configured column."""
        for column in self.columns:
            sentence_array = self.basic_split_on_single_column(table[column])

            sentence_array = self.resplit_long_sentences(sentence_array)

            table = table.append_column(
                f"{column}{self.config.sentence_suffix}", sentence_array
            )

        return table
294
+
295
+
296
@dataclass
class FullPipelineConfig:
    """Configuration for ``FullPipeline``: split sentences, then embed them."""

    splitter_config: SentenceSplitterConfig  # sentence-splitting stage
    sonar_encoder_config: SonarTextEmbedderConfig  # SONAR embedding stage
    min_text_length: int = 10  # rows with shorter text are filtered out first
301
+
302
+
303
class FullPipeline(BatchMapper):
    """
    Create SONAR vectors from raw text in one pass:
    split each text column into sentences, then compute SONAR embeddings.

    Requires only one input column (e.g. `text`), and the text must not be
    empty.

    Example of config:

        splitter_config = SentenceSplitterConfig(
            columns=["text"],
            model_name="sat-3l",
            verbose=True,
            sentence_threshold=0.02,
            max_sentence_len=256,
        )
        sonar_encoder_config = SonarTextEmbedderConfig(
            column_config=[LangColumnConfig("text_sentences", lang_value="eng_Latn")],
            device="cuda",
        )

        full_config = FullPipelineConfig(
            splitter_config=splitter_config,
            sonar_encoder_config=sonar_encoder_config,
        )
    """

    def __init__(self, config: FullPipelineConfig):
        self.config = config
        self.splitter = SentenceSplitter(self.config.splitter_config)
        self.sonar_encoder = SonarTextBatchEmbedder(self.config.sonar_encoder_config)

    def __call__(self, batch: pa.Table) -> pa.Table:
        # Drop rows whose raw text is too short to be worth splitting.
        for col in self.config.splitter_config.columns:
            keep = pc.greater_equal(
                pc.utf8_length(batch[col]), self.config.min_text_length
            )
            batch = batch.filter(keep)

        batch = self.splitter(batch)
        batch = self.sonar_encoder(batch)
        # Free intermediate buffers from the heavy split/embed stages.
        gc.collect()
        return batch
lcm/datasets/sentence_splitting.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+
7
+ import codecs
8
+ import re
9
+ import typing as tp
10
+ from functools import lru_cache
11
+
12
+ import spacy
13
+ import torch
14
+ from sacremoses import MosesDetokenizer, MosesPunctNormalizer
15
+ from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo
16
+ from stopes.utils.language_codes import language_code_to_short_code
17
+
18
+
19
def remove_emojis(text: str) -> str:
    """Strip emoji and pictographic symbols from *text*."""
    # NOTE(review): the U+24C2-U+1F251 range is very broad and also covers
    # some non-emoji symbols — kept as-is to preserve behavior.
    pattern = re.compile(
        "["
        "\U0001f600-\U0001f64f"  # emoticons
        "\U0001f300-\U0001f5ff"  # symbols & pictographs
        "\U0001f680-\U0001f6ff"  # transport & map symbols
        "\U0001f1e0-\U0001f1ff"  # flags (iOS)
        "\U00002702-\U000027b0"
        "\U000024c2-\U0001f251"
        "\U0001f900-\U0001f9ff"  # Supplemental Symbols and Pictographs
        "\U0001f700-\U0001f77f"  # Alchemical Symbols
        "\U0001f780-\U0001f7ff"  # Geometric Shapes Extended
        "\U0001f800-\U0001f8ff"  # Supplemental Arrows-C
        "\U0001fa00-\U0001fa6f"  # Chess Symbols
        "\U0001fa70-\U0001faff"  # Symbols and Pictographs Extended-A
        "\U0001f6c0-\U0001f6cf"  # Miscellaneous Symbols and Pictographs (part)
        "\U0001f6d0-\U0001f6d5"  # Miscellaneous Symbols and Pictographs (part)
        "\U0001f6f0-\U0001f6fa"  # Miscellaneous Symbols and Pictographs (part)
        "]+",
        flags=re.UNICODE,
    )
    return pattern.sub("", text)
41
+
42
+
43
def batched(inputs: tp.Iterable, batch_size=10000) -> tp.Iterable:
    """Yield successive lists of at most *batch_size* items from *inputs*.

    Fixes the original implementation, which unconditionally yielded the
    final buffer and therefore emitted a spurious empty list whenever
    ``len(inputs)`` was a multiple of ``batch_size`` (or inputs was empty).
    """
    batch = []
    for item in inputs:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Only yield a non-empty remainder.
    if batch:
        yield batch
51
+
52
+
53
def filter_empty_string(text):
    """Return True when *text* contains no alphanumeric character at all."""
    return all(not ch.isalnum() for ch in text)
55
+
56
+
57
def remove_non_printable_chars(string):
    """Drop every character outside the printable ASCII range 0x20-0x7E."""
    return "".join(ch for ch in string if "\x20" <= ch <= "\x7e")
59
+
60
+
61
def deescape_special_chars(string):
    """Interpret backslash escape sequences (e.g. ``"\\n"``) in *string*."""
    decoded = codecs.decode(string, "unicode_escape")
    return decoded
63
+
64
+
65
def resplit(text: str, max_length: int, sep: str) -> tp.List[str]:
    """Greedily split *text* on *sep* into chunks of at most *max_length* chars.

    The separator is kept at the end of every fragment except the last one.
    A single fragment longer than ``max_length`` is emitted as-is.
    """
    tokens = text.split(sep)
    chunks: tp.List[str] = []
    buffer = ""

    # Every token but the last carries its separator.
    for token in tokens[:-1]:
        piece = token + sep
        if len(buffer) + len(piece) <= max_length:
            buffer += piece
        else:
            if buffer:
                chunks.append(buffer)
            buffer = piece

    # The final token has no trailing separator.
    tail = tokens[-1]
    if len(buffer) + len(tail) <= max_length:
        buffer += tail
    else:
        if buffer:
            chunks.append(buffer)
        buffer = tail

    if buffer:
        chunks.append(buffer)

    return chunks
93
+
94
+
95
@lru_cache
def get_moses_normalizers(lang):
    """Return a cached ``(punct_normalizer, detokenizer)`` pair for *lang*."""
    short_lang = language_code_to_short_code(lang, try_replacing_with_macro=True)
    normalizer = MosesPunctNormalizer(lang=short_lang)
    # Pre-compile the substitution regexes once; the pair is cached per lang.
    normalizer.substitutions = [
        (re.compile(r), sub) for r, sub in normalizer.substitutions
    ]
    return normalizer, MosesDetokenizer(lang=short_lang)
102
+
103
+
104
@lru_cache
def get_splitter(lang: str, model_name: tp.Optional[str] = None):
    """Return a cached sentence-splitting callable for *lang*.

    Tries a spaCy sentencizer pipeline first (on GPU when available) and
    falls back to the stopes splitter when the spaCy model cannot be loaded.

    :param lang: language code (e.g. ``eng_Latn``).
    :param model_name: explicit spaCy model name; derived from the language
        when None (fixes the original ``str = None`` annotation).
    """
    moses_lang = language_code_to_short_code(lang, try_replacing_with_macro=True)
    if model_name is None:
        model_name = (
            f"{moses_lang}_core_web_sm"
            if moses_lang == "en"
            else f"{moses_lang}_core_news_sm"
        )
    try:
        if torch.cuda.is_available():
            spacy.require_gpu()
        spacy_nlp = spacy.load(model_name, enable=["sentencizer"])
        spacy_nlp.add_pipe("sentencizer")

        def spacy_splitter(text):
            # Feed the text in chunks below spaCy's max input length.
            for batch in batched(text, batch_size=999_000):
                for sent in spacy_nlp("".join(batch)).sents:
                    yield str(sent)

        return spacy_splitter
    except (ModuleNotFoundError, OSError):
        # BUGFIX: spacy.load raises OSError when the model package is not
        # installed; the original only caught ModuleNotFoundError, so the
        # fallback below was unreachable in the common failure case.
        print(
            f"Spacy splitter not found for {lang}, switching to stopes implementation"
        )
        return get_split_algo(lang[:3], "default")
130
+
131
+
132
class ResplitSentenceSplitter:
    """Language-aware sentence splitter that re-chunks long sentences.

    Splits a document with the best available splitter for the language,
    then force-splits any sentence longer than ``max_length`` by trying
    the fallback separators in order, and finally Moses-normalizes the
    surviving sentences.
    """

    def __init__(
        self,
        fallback_separators=(".", "!", "?", "...", "\n", ";", ",", ":", ">", " "),
    ):
        self.fallback_separators = fallback_separators

    def __call__(
        self, document: str, lang: str = "eng_Latn", max_length: int = 200
    ) -> tp.List[str]:
        normalizer, detokenizer = get_moses_normalizers(lang)
        # XXX: two below are not various language friendly
        # document = deescape_special_chars(document)
        # document = remove_non_printable_chars(document)
        document = remove_emojis(document)

        sentences = get_splitter(lang)(document)
        # Re-chunk long sentences, trying each separator in order.
        for separator in self.fallback_separators or []:
            sentences = [
                piece.strip()
                for sent in sentences
                for piece in resplit(sent, max_length=max_length, sep=separator)
            ]

        return [
            normalizer.normalize(detokenizer.detokenize(sent.strip().split()))
            for sent in sentences
            if len(sent) > 1 and not filter_empty_string(sent)
        ]
lcm/datasets/utils.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+
7
+ import torch
8
+ from fairseq2.models.sequence import SequenceBatch
9
+
10
+
11
def move_eos_to_the_end(
    batch: SequenceBatch, pad_token_id: int = 0, eos_token_id: int = 3
) -> SequenceBatch:
    """
    Convert a decoder-input batch (EOS prepended) into a decoder-output batch
    (EOS appended) of the same shape.

    Known caveats (unchanged from the original implementation):
    1) If the sequence end has been truncated away, EOS is appended anyway.
    2) The language-code token remains included in the loss computation.
    """
    # Drop the leading EOS and pad on the right to keep the original length.
    shifted = torch.cat(
        [
            batch.seqs[:, 1:],
            torch.zeros_like(batch.seqs[:, :1]) + pad_token_id,
        ],
        dim=-1,
    )
    # Write EOS over the last real (non-padding) token of every row.
    if batch.padding_mask:
        rows = torch.arange(shifted.shape[0], dtype=torch.int32)
        shifted[rows, batch.padding_mask.seq_lens - 1] = eos_token_id
    else:
        # No padding: every row is full, EOS goes in the final position.
        shifted[:, -1] = eos_token_id

    return SequenceBatch(
        seqs=shifted,
        padding_mask=batch.padding_mask,
    )
lcm/models/two_tower_diffusion_lcm/loader.py CHANGED
@@ -6,6 +6,7 @@
6
 
7
  from fairseq2.models.config_loader import StandardModelConfigLoader
8
  from fairseq2.models.loader import StandardModelLoader, load_model
 
9
 
10
  from lcm.models.base_lcm.loader import convert_lcm_checkpoint
11
  from lcm.models.two_tower_diffusion_lcm.builder import (
@@ -23,11 +24,12 @@ load_two_tower_diffusion_lcm_config = StandardModelConfigLoader(
23
  )
24
 
25
 
26
- load_two_tower_diffusion_lcm_model = StandardModelLoader( # type: ignore # FIXME
27
  config_loader=load_two_tower_diffusion_lcm_config,
28
  factory=create_two_tower_diffusion_lcm_model,
29
  checkpoint_converter=convert_lcm_checkpoint,
30
  restrict_checkpoints=False,
 
31
  )
32
 
33
  load_model.register(
 
6
 
7
  from fairseq2.models.config_loader import StandardModelConfigLoader
8
  from fairseq2.models.loader import StandardModelLoader, load_model
9
+ from Patches import Patch_TorchLoader
10
 
11
  from lcm.models.base_lcm.loader import convert_lcm_checkpoint
12
  from lcm.models.two_tower_diffusion_lcm.builder import (
 
24
  )
25
 
26
 
27
+ load_two_tower_diffusion_lcm_model = StandardModelLoader(
28
  config_loader=load_two_tower_diffusion_lcm_config,
29
  factory=create_two_tower_diffusion_lcm_model,
30
  checkpoint_converter=convert_lcm_checkpoint,
31
  restrict_checkpoints=False,
32
+ tensor_loader=Patch_TorchLoader.load_tensors, # 🔥 the key patch
33
  )
34
 
35
  load_model.register(
lcm/train/__main__.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+
7
+ import asyncio
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Optional
11
+
12
+ import hydra
13
+ import submitit
14
+ from omegaconf import DictConfig, OmegaConf
15
+ from omegaconf.omegaconf import open_dict, read_write
16
+ from stopes.core import Requirements, StopesModule
17
+
18
+ from lcm.train.common import get_trainer
19
+ from lcm.utils.common import setup_conf
20
+
21
+ setup_conf()
22
+
23
+
24
class TrainModule(StopesModule):
    """Stopes wrapper that launches an LCM trainer as a schedulable job."""

    def requirements(self) -> Requirements:
        return self.config.requirements

    def run(self, iteration_value: Optional[Any] = None, iteration_index: int = 0):
        # Scope the log folder to this job's name.
        with read_write(self.config):
            self.config.log_folder = Path(self.config.log_folder) / self.name()

        trainer = get_trainer(self.config)

        # Trainers are expected to expose a run() entry point.
        trainer.run()

    def should_retry(
        self,
        ex: Exception,
        attempt: int,
        iteration_value: Optional[Any] = None,
        iteration_index: int = 0,
    ) -> bool:
        # Clean the environment before retrying so the fairseq2
        # ProcessGroupGang can be set up again without stale state.
        with submitit.helpers.clean_env():
            return "ValueError" not in str(ex)

    def name(self):
        """Job name: the configured experiment name, or class name + config hash."""
        return self.config.get(
            "experiment_name", f"{self.__class__.__name__}_{self.sha_key()[:10]}"
        )
59
+
60
+
61
@dataclass
class TrainingConfig:
    """Top-level config: a trainer config plus a stopes launcher config."""

    trainer: DictConfig  # hydra config of the trainer (see recipes/train)
    launcher: DictConfig  # stopes launcher configuration
    dry_run: bool = False  # only instantiate the trainer, don't train
66
+
67
+
68
async def run(config: TrainingConfig):
    """Dump the resolved config, then schedule (or dry-run) the train module."""
    # Persist the fully-resolved config next to the job outputs.
    dump_dir = Path(config.launcher.config_dump_dir)
    dump_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.resolve(config)  # type: ignore
    # XXX: do we want to promote datasets configs from thier names to the final params
    OmegaConf.save(
        config=config,
        f=str(dump_dir / "all_config.yaml"),
    )

    train_config = config.trainer

    # A debug cluster implies debug mode on the trainer itself.
    with open_dict(train_config):
        if config.launcher.cluster == "debug":
            train_config.debug = True
        train_config.log_folder = config.launcher.log_folder

    if getattr(config, "dry_run", False):
        # Instantiate only — lets the user validate the config quickly.
        trainer = get_trainer(train_config)
        print(f"Trainer: {trainer}")
        print(f"Train config: {getattr(trainer, 'config')}")

        return

    launcher = hydra.utils.instantiate(config.launcher)

    module = TrainModule(train_config)
    await launcher.schedule(module)
100
+
101
+
102
@hydra.main(
    version_base="1.2",
    config_path="../../recipes/train",
    config_name="defaults.yaml",
)
def main(config: TrainingConfig) -> None:
    """
    Launch a train module from the CLI.

    Example:

    ```sh
    python -m lcm.train +pretrain=mse
    ```

    Here `pretrain` is a folder under the `recipes` directory and `mse` is a
    yaml file with the trainer configuration. That yaml file must be in the
    `trainer` package (i.e. start with the `# @package trainer` hydra
    directive) and contain a `__trainer__` entry defining the constructor for
    the trainer.

    Use `-c job` to inspect the configuration without running anything;
    `dry_run=true` to build the trainer from the configuration without
    training; and `launcher.cluster=debug` to debug jobs locally.
    """
    asyncio.run(run(config))


if __name__ == "__main__":
    main()
lcm/train/common.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from inspect import signature
7
+ from typing import Any, Dict, Protocol, Union, runtime_checkable
8
+
9
+ import hydra
10
+ from omegaconf import DictConfig, OmegaConf, read_write
11
+
12
+ from lcm.utils.common import promote_config
13
+
14
# Key under which the trainer constructor is declared in hydra configs.
TRAINER_KEY = "_trainer_"


@runtime_checkable
class Trainer(Protocol):
    """Structural interface of an LCM trainer: anything exposing run()."""

    def run(self) -> Any: ...
22
+
23
+
24
def _parse_training_config(train_config: DictConfig):
    """Resolve the trainer callable and its typed config from an omegaconf node.

    The node must contain a ``_trainer_`` entry pointing at a callable taking
    a single ``config`` argument; that argument's annotation is used to
    promote the raw DictConfig into a typed config object.

    :returns: ``(trainer_callable, typed_config)``
    :raises ValueError: on any parsing/validation failure, chained to the
        original exception.
    """
    # The config is set to read-only within the stopes module __init__.
    assert TRAINER_KEY in train_config, (
        f"The trainer configuration is missing a {TRAINER_KEY} configuration, "
        "you need to specify a Callable to initialize your config."
    )
    trainer_cls_or_func = train_config.get(TRAINER_KEY)
    try:
        trainer_obj = hydra.utils.get_object(trainer_cls_or_func)
        sign = signature(trainer_obj)
        assert len(sign.parameters) == 1 and "config" in sign.parameters, (
            f'{trainer_cls_or_func} should take a single argument called "config"'
        )
        param_type = sign.parameters["config"].annotation

        OmegaConf.resolve(train_config)
        # The trainer key is meta-information, not part of the typed config.
        with read_write(train_config):
            del train_config._trainer_

        typed_config = promote_config(train_config, param_type)
        return trainer_obj, typed_config
    except Exception as ex:
        # BUGFIX: the original raised ValueError(msg, str(ex)) — a two-argument
        # exception whose str() is a tuple repr; use one formatted message.
        raise ValueError(
            f"Couldn't parse the train config: {train_config}: {ex}"
        ) from ex
51
+
52
+
53
def get_trainer(train_config: DictConfig) -> Trainer:
    """Instantiate the trainer declared by *train_config*.

    Delegates parsing/validation to ``_parse_training_config``.
    """
    factory, typed_config = _parse_training_config(train_config)
    return factory(typed_config)
56
+
57
+
58
+ def _is_missing(config: Union[DictConfig, Dict], attr: str) -> bool:
59
+ if isinstance(config, Dict):
60
+ return attr in config and config[attr]
61
+ if OmegaConf.is_missing(config, attr):
62
+ return True
63
+ if not hasattr(config, attr) or not getattr(config, attr):
64
+ return True
65
+ return False
lcm/train/criterion.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from abc import abstractmethod
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, List, Literal
9
+
10
+ from fairseq2.logging import get_log_writer
11
+ from omegaconf import MISSING
12
+ from torch.distributed.fsdp.fully_sharded_data_parallel import (
13
+ FullyShardedDataParallel as FSDP,
14
+ )
15
+ from torch.nn import Module
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+
18
+ from lcm.train.metrics import LossTerm
19
+
20
+ logger = get_log_writer(__name__)
21
+
22
+
23
@dataclass
class CriterionConfig:
    """A dataclass for criterion parameters.

    Base config shared by all training criterions; subclasses add
    criterion-specific fields.
    """

    name: str = MISSING
    """Name of the criterion, a unique identifier used in the CriterionsFactory"""

    reduction: Literal["sum", "mean"] = "sum"
    """How to reduce the loss across samples"""
32
+
33
+
34
class Criterion:
    """An abstract class for training criterions.

    Subclasses implement ``__call__`` to run the model's forward pass on a
    batch and return a :class:`LossTerm`.
    """

    def __init__(
        self,
        config: CriterionConfig,
        model: Module,
    ):
        self.config = config

        self.model = model

        # Names of loss terms tracked during training; one metric bag is
        # created per entry. Subclasses populate this list.
        self.summands: List[str] = []
        """ A list of loss term names to track during training.
        This will create metric bags for each
        """

        self.reduction = config.reduction

    @property
    def throughput_metric_name(self) -> str:
        # Name of the metric used as the throughput denominator.
        return "num_target_elements"

    @property
    def base_model(self):
        """A pointer to the unwrapped model if training with FSDP/DDP"""
        if isinstance(self.model, (DDP, FSDP)):
            _model = self.model.module
        else:
            _model = self.model
        return _model

    @abstractmethod
    def __call__(self, batch) -> LossTerm:
        """
        Computes the loss given an input batch.
        The model's forward pass is performed here
        """
72
+
73
+
74
class CriterionsFactory:
    """Registry-based factory for LCM criterions.

    Criterion classes register themselves under a unique name with the
    :meth:`register` decorator and are later instantiated by name with
    :meth:`build_criterion`.
    """

    # Shared, class-level registry mapping criterion name -> class.
    registry: Dict[str, Any] = {}

    @classmethod
    def build_criterion(cls, name: str, **kwargs) -> Any:
        """Build the criterion of choice from within the trainer.

        :param name: registry key the criterion class was registered with.
        :param kwargs: forwarded verbatim to the criterion's constructor.
        :raises KeyError: if ``name`` was never registered.
        """

        criterion_class = cls.registry[name]

        criterion = criterion_class(**kwargs)

        return criterion

    @classmethod
    def register(cls, name: str) -> Callable:
        """Decorator for adding criterions to the registry.

        :param name: unique registry key; registering the same name twice
            is a programming error and raises an ``AssertionError``.
        """

        def inner_wrapper(wrapped_class: Criterion) -> Callable:
            # Fixed message grammar: "already register" -> "already registered".
            assert name not in cls.registry, (
                f"{name} is already registered as a criterion"
            )
            cls.registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper
lcm/train/lcm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
lcm/train/lcm/criterion.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from abc import abstractmethod
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Tuple
9
+
10
+ import torch
11
+ from fairseq2.logging import get_log_writer
12
+ from torch import Tensor
13
+
14
+ from lcm.datasets.batch import LCMInput, LCMStyle
15
+ from lcm.models.abstract_lcm import AbstractLCModel
16
+ from lcm.models.sonar_normalizer import SonarNormalizer
17
+ from lcm.train.criterion import Criterion, CriterionConfig
18
+ from lcm.train.metrics import LossTerm
19
+
20
+ logger = get_log_writer(__name__)
21
+
22
+
23
def compute_standard_mse(
    flattened_predictions: Tensor,
    flattened_target: Tensor,
    scales: Optional[Tensor] = None,
    normalizer: Optional[SonarNormalizer] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Computes MSE loss between predictions and targets.
    Note that, unlike regular MSE with mean/sum reduction, we first sum across channels
    before later reducing in the criterion.

    Parameters:
        flattened_predictions (Tensor): The predictions in (N, C)
        flattened_target (Tensor): The targets in (N, C)
        scales (Optional[Tensor]): If not None, each channel will be weighted by the corresponding scale.
        normalizer (Optional[SonarNormalizer]): If a normalizer is provided,
            the predictions and targets will first be denormalized before computing the MSE loss

    Returns:
        mse (Tensor): the per-row MSE loss (summed over channels) with optional scaling
        plain_mse (Tensor): The MSE loss without any scaling (for logging)
    """
    # NOTE: the original docstring documented an `epsilon` parameter and an
    # RMSE square root that this function does not take or compute; removed.

    assert flattened_predictions.dim() == 2, (
        "Expecting two-dimensional predictions and targets. ",
        f"Found targets in {flattened_target.size()} and ",
        f"predictions in {flattened_predictions.size()}",
    )

    assert flattened_predictions.shape == flattened_target.shape, (
        "Expecting predictions and targets of the same shape ",
        f"Received predictions {flattened_predictions.shape} and targets {flattened_target.shape}",
    )

    if scales is not None:
        assert scales.dim() == 1, (
            "Expecting a uni-dimensional tensor of scales ",
            f"Found a tensor with dimension {scales.dim()}",
        )
        assert len(scales) == flattened_target.shape[-1], (
            "The provided scales should have the same size as the target channels. ",
            f"Found {len(scales)} expected {flattened_target.shape[-1]}",
        )

    if normalizer is not None:
        # Fixed message grammar: "has not method" -> "has no method".
        assert hasattr(normalizer, "denormalize"), (
            "The provided normalizer has no method `denormalize`"
        )
        flattened_predictions = normalizer.denormalize(flattened_predictions)
        flattened_target = normalizer.denormalize(flattened_target)

    # Per-element squared error, then summed over the channel dimension.
    full_mse = torch.nn.functional.mse_loss(
        flattened_predictions, flattened_target, reduction="none"
    )
    plain_mse = full_mse.sum(dim=-1)

    if scales is not None:
        # Weight each channel before summing; `plain_mse` stays unscaled.
        full_mse = full_mse * scales.unsqueeze(0)
        mse = full_mse.sum(dim=-1)
    else:
        mse = plain_mse
    return mse, plain_mse
86
+
87
+
88
@dataclass
class LCMCriterionConfig(CriterionConfig):
    """Base config for LCM criterions."""

    compute_rmse: bool = True
    """If `True` take the square-root of MSE.
    This is for now `True` by default for backward compatibility"""
93
+
94
+
95
class LCMCriterion(Criterion):
    """An abstract class for the LCM's criterions"""

    config: LCMCriterionConfig

    def __init__(
        self,
        config: LCMCriterionConfig,
        model: AbstractLCModel,
        style: LCMStyle = LCMStyle.UNSUPERVISED,
    ):
        super().__init__(config, model)

        self.style = style

        # Summands for log/tb recorders
        self.summands = ["mse_loss", "reconstruction_loss"]

        # Normalize inside the criterion only when the model was built with
        # a sonar normalizer (i.e. its config names one).
        self.normalize_in_criterion = (
            self.base_model.config.sonar_normalizer_name is not None
        )

    @property
    def sonar_normalizer(self) -> Optional[SonarNormalizer]:
        # The normalizer may live on the model itself or on its frontend;
        # fall back to None (with a warning) if neither attribute exists.
        if hasattr(self.base_model, "sonar_normalizer"):
            return self.base_model.sonar_normalizer

        elif hasattr(self.base_model, "frontend") and hasattr(
            self.base_model.frontend, "sonar_normalizer"
        ):
            return self.base_model.frontend.sonar_normalizer

        else:
            logger.warning(
                "Couldn't find the model's `sonar_normalizer`, defaulting to None"
            )
            return None

    @property
    def throughput_metric_name(self) -> str:
        # Throughput is measured in target elements, matching the base class.
        return "num_target_elements"

    @abstractmethod
    def __call__(self, batch: LCMInput) -> LossTerm:
        """
        Computes the loss given an input batch.
        The model's forward pass is performed here
        Input batch is LCMInput (see `lcm.datasets.batch`)
        """
lcm/train/lcm/trainer.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import Dict, List, Mapping, Optional, Union
8
+
9
+ from fairseq2.assets import AssetCard
10
+ from fairseq2.checkpoint import FileCheckpointManager
11
+ from fairseq2.gang import Gang
12
+ from fairseq2.logging import get_log_writer
13
+ from fairseq2.metrics import MetricRecorder
14
+ from fairseq2.optim import DynamicLossScaler
15
+ from fairseq2.optim.lr_scheduler import AbstractLRScheduler
16
+ from fairseq2.utils.profiler import Profiler, Stopwatch
17
+ from fairseq2.utils.rng import RngBag
18
+ from omegaconf import MISSING
19
+ from stopes.core import Requirements
20
+ from torch.nn import Module
21
+ from torch.optim import Optimizer
22
+ import torch
23
+
24
+ from lcm.datasets.configs import ParquetDatasetConfig
25
+ from lcm.datasets.dataloader import LCMDataLoader
26
+ from lcm.datasets.dataloading import ds_name
27
+ from lcm.models.abstract_lcm import AbstractLCModelConfig
28
+ from lcm.models.base_lcm.loader import load_base_lcm_model
29
+ from lcm.train.criterion import CriterionsFactory
30
+ from lcm.train.metrics import LCMMetricBag
31
+ from lcm.train.mse_lcm.criterion import ReconstructionCriterionConfig
32
+ from lcm.train.trainer import Trainer, TrainerBuilder, TrainingConfig
33
+ from lcm.utils.card_utils import create_model_card
34
+
35
+ logger = get_log_writer(__name__)
36
+
37
+
38
@dataclass
class LCMTrainingConfig(TrainingConfig):
    """Holds the configuration of an LCM training job."""

    training_data: List[ParquetDatasetConfig] = field(default_factory=list)
    """The datasets to train with."""  # TODO use dataset cards

    validation_data: List[ParquetDatasetConfig] = field(default_factory=list)
    """The datasets to validate on."""  # TODO use dataset cards

    model_config_or_name: Union[AbstractLCModelConfig, str, None] = None
    """The model configuration or name to train."""

    requirements: Requirements = field(
        default_factory=lambda: Requirements(
            nodes=1,
            tasks_per_node=8,
            gpus_per_node=8,
            cpus_per_task=8,
            mem_gb=256,
            timeout_min=3 * 24 * 60,  # three days, in minutes
            constraint="volta32gb",
        )
    )
    """The scheduling requirements for this trainer"""

    criterion: ReconstructionCriterionConfig = MISSING
    """The MSE loss is the default base criterion used in either the `lcm` or `mse_lcm` trainers"""

    max_subword_length: int = 512
    """ Max subword length used to truncate seqs during sonar decoder backprop"""
69
+
70
+
71
class LCMTrainer(Trainer):
    """Trainer for LCM models.

    Thin specialization of the generic :class:`Trainer`: it builds the
    criterion from the config via the :class:`CriterionsFactory`, sets up
    per-dataset metric bags, and can emit a model card for the last
    checkpoint.
    """

    config: LCMTrainingConfig
    model: Module
    training_data_loader: LCMDataLoader
    validation_data_loader: Optional[LCMDataLoader]
    gang: Gang
    optimizer: Optimizer
    loss_scaler: DynamicLossScaler
    lr_scheduler: AbstractLRScheduler
    rng_bag: RngBag
    step_nr: int
    train_metric_bag: LCMMetricBag
    valid_metric_bag: Mapping[str, LCMMetricBag]
    metric_recorders: List[MetricRecorder]
    profiler: Profiler
    stopwatch: Stopwatch

    def __init__(
        self,
        config: LCMTrainingConfig,
        model: Module,
        training_data_loader: LCMDataLoader,
        validation_data_loader: Optional[LCMDataLoader],
        gang: Gang,
        checkpoint_manager: FileCheckpointManager,
        rng_bag: RngBag,
        stopwatch: Stopwatch,
        card_metadata: Dict,
    ) -> None:
        super().__init__(
            config,
            model,
            training_data_loader,
            validation_data_loader,
            gang,
            checkpoint_manager,
            rng_bag,
            stopwatch,
            card_metadata=card_metadata,
        )

    def setup_criterion(self):
        """Build the criterion named in the config from the factory registry."""
        return CriterionsFactory.build_criterion(
            name=self.config.criterion.name,
            config=self.config.criterion,
            model=self.model,
        )

    def setup_metric_bags(self):
        """Create the training metric bag and one validation bag per dataset."""
        self.train_metric_bag = LCMMetricBag(
            self.gang,
            loss_summands=self.criterion.summands,
            reduction=self.criterion.reduction,
        )

        # Validation bags are non-stateful: they are not saved in checkpoints.
        self.register_non_stateful(
            "valid_metric_bag",
            {
                ds_name(dataset): LCMMetricBag(
                    self.gang,
                    loss_summands=self.criterion.summands,
                    reduction=self.criterion.reduction,
                )
                for dataset in self.config.validation_data
            },
        )

    def create_model_card_for_last_checkpoint(
        self, is_final: bool = True, **card_kwargs
    ) -> Optional[AssetCard]:
        """Create a model card based on the last saved
        checkpoint and the model config.

        :param is_final: if True, use the last completed checkpoint step;
            otherwise use the in-progress checkpoint step.
        :returns: the created card, or None when no checkpoint exists.
        """

        current_step_number: Optional[int] = None
        if is_final:
            steps = self.checkpoint_manager.get_step_numbers()
            current_step_number = steps[-1] if len(steps) else None
        else:
            current_step_number = self.checkpoint_manager._get_checkpoint_step_nr()

        if current_step_number is None:
            # Fixed typo in log message: "wil" -> "will".
            logger.warning(
                "No checkpoint was saved, the final model card will not be created"
            )
            return None

        # NOTE(review): this hard-codes a `model.pt` filename; if checkpoints
        # are saved as safetensors this path may be stale — confirm against
        # the checkpoint manager's actual output.
        cp_fn = (
            self.checkpoint_manager._checkpoint_dir
            / f"step_{current_step_number}"
            / "model.pt"  # type: ignore
        )

        card = create_model_card(
            checkpoint_path=cp_fn.absolute(),
            model_arch=self.card_metadata["model_arch"],
            model_config=self.card_metadata["model_config"],
            model_type=self.card_metadata["model_type"],
            **card_kwargs,
        )
        return card
171
+
172
+
173
class LCMTrainerBuilder(TrainerBuilder):
    """Builds an :class:`LCMTrainer`: loads data, creates and wraps the
    model (FSDP/DDP), and restores from a checkpoint when one exists."""

    config: LCMTrainingConfig

    def __init__(self, config: LCMTrainingConfig):
        super().__init__(config)

    def load_data(self):
        """Load training and validation data"""

        training_data_loader = LCMDataLoader(
            data_config=self.config.data_loading_config,
            datasets=self.config.training_data,
            max_subword_length=self.config.max_subword_length,
            dtype=self.dtype,
            gang=self.gang,
        )

        validation_data_loader = LCMDataLoader(
            data_config=self.config.validation_data_loading_config,
            datasets=self.config.validation_data,
            max_subword_length=self.config.max_subword_length,
            dtype=self.dtype,
            gang=self.gang,
        )

        return training_data_loader, validation_data_loader

    @property
    def model_loader(self):
        """A fairseq2 ModelLoader"""
        return load_base_lcm_model

    def build_trainer(self):
        """Build the trainer by loading data and
        setting up the model for training"""

        training_data_loader, validation_data_loader = self.load_data()

        checkpoint_manager = FileCheckpointManager(
            self.config.output_dir.joinpath("checkpoints"),
            self.gang,
        )

        # Must be checked before model creation: it decides whether the
        # trainer restores state at the end of this method.
        self.has_checkpoint = checkpoint_manager.has_checkpoint()

        model = self.create_model()

        # Force all model parameters to bfloat16 regardless of submodule defaults
        model = model.to(dtype=torch.bfloat16)

        model = self.maybe_load_model(model)

        model = self.maybe_freeze_parameters(model)

        # If using the META device, we need to move the model to gang.device
        wrapped_model = None

        if self.use_fsdp:
            wrapped_model = self.wrap_model_with_fsdp(model)
        elif self.use_ddp:
            wrapped_model = self.wrap_model_with_ddp(model)  # type: ignore

        trainer = LCMTrainer(
            self.config,  # type: ignore
            wrapped_model or model,
            training_data_loader,
            validation_data_loader,
            self.gang,
            checkpoint_manager,
            self.rng_bag,
            self.stopwatch,
            card_metadata=self.card_metadata,
        )

        trainer.setup()

        if self.has_checkpoint:
            trainer.restore()

        return trainer
253
+
254
+
255
def prepare_lcm_trainer(config: LCMTrainingConfig) -> LCMTrainer:
    """Create an LCM Trainer.

    :param config: The training configuration.
    """
    builder = LCMTrainerBuilder(config)
    return builder.build_trainer()
lcm/train/metrics.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from collections.abc import MutableMapping
7
+ from dataclasses import dataclass, field
8
+ from functools import partial
9
+ from pathlib import Path
10
+ from typing import (
11
+ Any,
12
+ Callable,
13
+ Dict,
14
+ List,
15
+ Mapping,
16
+ Optional,
17
+ Sequence,
18
+ Set,
19
+ Tuple,
20
+ Union,
21
+ )
22
+
23
+ import torch
24
+ from fairseq2.gang import Gang
25
+ from fairseq2.logging import get_log_writer
26
+ from fairseq2.metrics import (
27
+ MetricBag,
28
+ format_as_float,
29
+ format_as_int,
30
+ format_as_seconds,
31
+ )
32
+ from fairseq2.metrics.recorder import (
33
+ MetricRecorder,
34
+ _metric_formatters,
35
+ register_metric_formatter,
36
+ )
37
+ from fairseq2.typing import override
38
+ from torch import Tensor
39
+ from torch.cuda import _get_device_index
40
+ from torcheval.metrics import Max, Mean, Sum, Throughput
41
+
42
+ logger = get_log_writer(__name__)
43
+
44
+ format_as_percent = partial(format_as_int, postfix="%")
45
+
46
+
47
+ def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> Dict:
48
+ """
49
+ A helper function to flatten nested dictionaries
50
+ Example. With a training config like
51
+ config = {
52
+ 'data': {
53
+ 'training': {'batch_size': 10},
54
+ 'validation': {'batch_size': 2}
55
+ },
56
+ 'model': {'model_dim': 1024},
57
+ 'use_fsdp': True
58
+ }
59
+ The flat config will be:
60
+ {
61
+ 'data.training.batch_size': 10,
62
+ 'data.validation.batch_size': 2,
63
+ 'model.model_dim': 1024,
64
+ 'use_fsdp': True
65
+ }
66
+ This helper is used to convert our nested training config into a flat
67
+ dictionary for Tensoarboard's HParams conusmption
68
+
69
+ """
70
+ items: List = []
71
+ for k, v in d.items():
72
+ new_key = parent_key + sep + k if parent_key else k
73
+ if isinstance(v, MutableMapping):
74
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
75
+ else:
76
+ items.append((new_key, v))
77
+ return dict(items)
78
+
79
+
80
+ def get_allocated_gpu_memory(device):
81
+ """
82
+ Get allocated memory in GiB for GPU devices
83
+ """
84
+ if device.type == "cpu":
85
+ return 0, 0
86
+ device = _get_device_index(device, optional=True)
87
+ memory_stats = torch.cuda.memory_stats(device=device)
88
+ current_usage = memory_stats["allocated_bytes.all.current"] / (1024**3)
89
+ peak_usage = memory_stats["allocated_bytes.all.peak"] / (1024**3)
90
+ return current_usage, peak_usage
91
+
92
+
93
@dataclass
class LossTerm:
    """Dataclass for a batch loss term"""

    value: Tensor
    """The final loss to be optimized"""

    batch_size: int

    num_target_elements: Union[int, float]

    summands: Dict[str, Tuple[Any, Any]] = field(default_factory=lambda: {})
    """A dictionary of loss terms to record. Each term is a tuple of (loss, number of elements)
    The second term is optional; if None, we will use `num_target_elements` when aggregating"""
107
+
108
+
109
class LCMMetricBag(MetricBag):
    """Holds the common metrics of an LCM."""

    loss: Mean
    batch_size: Sum
    elements_per_batch: Mean
    elements_per_second: Throughput
    num_target_elements: Sum
    total_num_target_elements: Sum

    # NOTE(review): declared and reset in `reset_batch_metrics` but never
    # registered in `__init__` — presumably registered elsewhere; confirm.
    grad_norm: Mean

    def __init__(
        self, gang: Gang, loss_summands: Sequence[str] = [], reduction: str = "sum"
    ) -> None:
        """
        :param gang:
            The gang to sync metrics across all processes.
        :param loss_summands:
            Names of additional loss terms to track (one Mean metric each).
        :param reduction:
            "sum" or "mean"; controls the reporting-time normalization in
            :meth:`update`.
        """
        super().__init__(gang)

        # temporary fix:

        self.reduction = reduction

        d = gang.device

        # A temporary solution to track as many loss terms as we explore
        self.loss_summands = loss_summands

        self.register_metric("loss", Mean(device=d), persistent=False)

        # this is the effective batch size
        self.register_metric("batch_size", Sum(device=d), persistent=False)

        self.register_metric("elements_per_batch", Mean(device=d), persistent=False)

        self.register_metric(
            "elements_per_second", Throughput(device=d), persistent=False
        )

        self.register_metric("gpu_memory_usage", Max(device=d), persistent=False)

        self.register_metric("gpu_peak_memory_usage", Max(device=d), persistent=False)

        # self.register_metric("ram_percentage", Max(device=d), persistent=False)

        # self.register_metric("cpu_percentage", Max(device=d), persistent=False)

        for summand in self.loss_summands:
            self.register_metric(summand, Mean(device=d), persistent=False)

        # The number of target tokens in a parallel batch. Used for computing throughput
        self.register_metric("num_target_elements", Sum(device=d), persistent=False)

        # The total_num_target_elements is persistent and is supposed to track the
        # total number of tokens consumed since training started
        self.total_num_target_elements = Sum(device=d)

    def register_adaln_metric(self, module_name: str):
        # Registers mean/std tracking metrics (plus display formatters) for
        # the AdaLN shift/scale/gate tensors of a module's mha and ffn blocks.
        for block in ["mha", "ffn"]:
            for tensor in [
                "shift",
                "scale",
                "gate",
            ]:
                self.register_metric(
                    f"{module_name}_{block}_{tensor}_mean",
                    Mean(device=self._gang.device),
                    persistent=False,
                )
                self.register_metric(
                    f"{module_name}_{block}_{tensor}_std",
                    Mean(device=self._gang.device),
                    persistent=False,
                )
                # formatters
                register_metric_formatter(
                    f"{module_name}_{block}_{tensor}_mean",
                    f"{module_name}_{block}_{tensor}_mean",
                    1000,
                    format_as_float,
                )
                register_metric_formatter(
                    f"{module_name}_{block}_{tensor}_std",
                    f"{module_name}_{block}_{tensor}_std",
                    1000,
                    format_as_float,
                )

    def register_module_metric(self, module_name: str):
        # Registers mean/std tracking metrics (plus display formatters) for a
        # module's input/output activations and gradients.
        for tensor in [
            "input_gradient",
            "output_gradient",
            "input_activations",
            "output_activations",
        ]:
            self.register_metric(
                f"{module_name}_{tensor}_mean",
                Mean(device=self._gang.device),
                persistent=False,
            )
            self.register_metric(
                f"{module_name}_{tensor}_std",
                Mean(device=self._gang.device),
                persistent=False,
            )
            # formatters
            register_metric_formatter(
                f"{module_name}_{tensor}_mean",
                f"{module_name}_{tensor}_mean",
                1000,
                format_as_float,
            )
            register_metric_formatter(
                f"{module_name}_{tensor}_std",
                f"{module_name}_{tensor}_std",
                1000,
                format_as_float,
            )

    @torch.inference_mode()
    def update(
        self,
        losses: Sequence[LossTerm],
    ) -> None:
        """Update the metrics.

        :param losses:
            The losses generated by the model for each batch; more than one
            element only when using gradient accumulation.
        """

        loss = torch.zeros((), dtype=torch.float64)

        loss_summands = {
            s: torch.zeros((), dtype=torch.float64) for s in self.loss_summands
        }
        # Denominator to normalize the loss summands, if -1,
        # we will default to normalizing with `num_target_elements`
        loss_summands_numel = {
            s: -torch.ones((), dtype=torch.long) for s in self.loss_summands
        }

        batch_size = torch.zeros((), dtype=torch.int64)

        num_target_elements = torch.zeros((), dtype=torch.int64)

        # Only in the case of using gradient accumulation that `losses` will be a non-singleton
        for batch_loss in losses:
            loss += float(batch_loss.value)

            for s in self.loss_summands:
                loss_term = batch_loss.summands.get(s, (0.0, None))
                loss_summands[s] += float(loss_term[0])
                if loss_term[1] is not None and not loss_term[1] == -1:
                    # First explicit denominator seen: switch from the -1
                    # sentinel to an accumulating counter.
                    if loss_summands_numel[s] == -1:
                        loss_summands_numel[s] = torch.zeros((), dtype=torch.int64)
                    loss_summands_numel[s] += loss_term[1]

            batch_size += batch_loss.batch_size
            num_target_elements += batch_loss.num_target_elements

        # Misleading normalization in the metric bag with reduction == "mean"
        # Kept here for backward compatibility
        # Any normalization here is only for reporting and doesn't impact optimization
        if self.reduction == "sum":
            loss /= num_target_elements
            keys = list(loss_summands)
            for k in keys:
                denom = loss_summands_numel[k]
                if denom == -1:
                    denom = num_target_elements
                loss_summands[k] /= denom + 1e-6

        self.loss.update(loss, weight=num_target_elements)

        for s in loss_summands:
            weight = loss_summands_numel[s]
            if weight == -1:
                weight = num_target_elements
            getattr(self, s).update(loss_summands[s], weight=weight)

        self.batch_size.update(batch_size)

        self.elements_per_batch.update(num_target_elements)

        self.num_target_elements.update(num_target_elements)

        # update the cumulative metric
        self.total_num_target_elements.update(num_target_elements)

        # Get GPU memory usage
        gpu_memory_usage, gpu_peak_memory_usage = get_allocated_gpu_memory(
            self._gang.device
        )
        self.gpu_memory_usage.update(torch.tensor(gpu_memory_usage))
        self.gpu_peak_memory_usage.update(torch.tensor(gpu_peak_memory_usage))

    def reset_batch_metrics(self) -> None:
        """Reset the batch metrics to their initial state."""
        self.loss.reset()
        for s in self.loss_summands:
            getattr(self, s).reset()

        self.batch_size.reset()
        self.elements_per_batch.reset()
        self.elements_per_second.reset()
        self.grad_norm.reset()
        self.gpu_memory_usage.reset()
        self.gpu_peak_memory_usage.reset()
        # self.ram_percentage.reset()
        # self.cpu_percentage.reset()
323
+
324
+
325
## Weight and Biases recorder

# wandb is an optional dependency: only needed when LCMWandBRecorder is used.
try:
    import wandb  # type: ignore[import-not-found]
except ImportError:
    has_wandb = False
else:
    has_wandb = True
333
+
334
+
335
class LCMWandBRecorder(MetricRecorder):
    """Records metric values to Weights & Biases."""

    # Class-level (shared across instances) set of run names whose custom
    # step axis has already been defined.
    defined_runs: Set[str] = set()

    def __init__(
        self,
        project: Optional[str] = None,
        name: Optional[str] = None,
        output_dir: Optional[Path] = None,
        config: Dict[str, Any] = {},
        **kwargs,
    ) -> None:
        """
        :param project: A project to organise this run with other experiments, if none, the run will go under `uncategorized`.
        :param name: A unique name for your run, if none is given, a random name will be generated
        :param output_dir: The base directory under which to store the W&B files. You don't have to provide this.
        :param config: A dictionary of key-value pairs to be stored as the experiment's config. (akin to hparams in tb)
        :param kwargs: Additional arguments to pass to wandb.init()

        In order to use W&B, run `wandb login` from the command line and enter
        the API key when prompted.
        """
        if not has_wandb:
            log = get_log_writer(__name__)
            log.warning("wandb not found. Please install it with `pip install wandb`.")  # fmt: skip

            self._run = None
        else:
            if output_dir:
                output_dir.mkdir(parents=True, exist_ok=True)
            self._run = wandb.init(  # type: ignore
                project=project,
                name=name,
                dir=output_dir,
                resume="allow",
                config=config,
                **kwargs,
            )

    def _define_run(self, run: str):
        """Define the custom step axis for ``run``, once per run name."""
        if run in self.defined_runs:
            return
        # https://docs.wandb.ai/guides/track/log/customize-logging-axes/
        wandb.define_metric(f"{run}/step")
        wandb.define_metric(f"{run}/*", step_metric=f"{run}/step")
        # BUGFIX: the run was never recorded, so the early-return above never
        # fired and define_metric was re-issued on every record_metrics call.
        self.defined_runs.add(run)

    @override
    def record_metrics(
        self,
        run: str,
        values: Mapping[str, Any],
        step_nr: Optional[int] = None,
        *,
        flush: bool = True,
    ) -> None:
        """Log ``values`` for ``run`` against its custom step axis.

        No-op when wandb is unavailable (``self._run is None``).
        """
        if self._run is None:
            return

        self._define_run(run)

        for name, value in values.items():
            # Use the registered display name when a formatter exists.
            formatter = _metric_formatters.get(name)
            if formatter is None:
                display_name = name
            else:
                display_name = formatter.display_name

            self._run.log({f"{run}/{display_name}": value, f"{run}/step": step_nr})

    @override
    def close(self) -> None:
        """Finish the underlying wandb run, if one was started."""
        if self._run is not None:
            self._run.finish()
409
+
410
+
411
# Display name, sort priority, and value formatter for every LCM metric;
# registered with fairseq2's metric recorder below (overwriting defaults).
lcm_metric_formatters: Dict[str, Tuple[str, int, Callable[[Any], str]]] = {
    # fmt: off
    "loss": ("Loss", 100, format_as_float),
    "nll_loss": ("NLL Loss", 100, format_as_float),
    "mse_loss": ("MSE Loss", 100, format_as_float),
    "contrastive_loss": ("Contrastive Loss", 110, format_as_float),
    "reconstruction_loss": ("Reconstruction loss", 110, format_as_float),
    "unnormalized_reconstruction_loss": (
        "Unnormalized Reconstruction Loss",
        110,
        format_as_float,
    ),
    "kld": ("KLD loss", 110, format_as_float),
    "encoder_mse_loss": ("Encoder MSE loss", 110, format_as_float),
    "decoder_ce_loss": ("Decoder CE loss", 110, format_as_float),
    "elapsed_time": ("Elapsed Time", 500, format_as_seconds),
    "wall_time": ("Wall Time", 510, format_as_seconds),
    "lr": ("Learning Rate", 800, format_as_float),
    "loss_scale": ("Loss Scale", 810, format_as_float),
    "grad_norm": ("Grad norm", 810, format_as_float),
    "raw_grad_norm": ("Raw Grad norm", 815, format_as_float),
    "encoder_mse_scale": ("Encoder MSE loss scale", 850, format_as_float),
    "batch_size": ("Batch Size", 900, format_as_int),
    "elements_per_batch": ("Elements per Batch", 900, format_as_int),
    "elements_per_second": ("Elements per Second", 900, format_as_int),
    "num_examples": ("Number of Examples", 900, format_as_int),
    "num_source_elements": ("Number of Source Elements", 900, format_as_int),
    "num_target_elements": ("Number of Target Elements", 900, format_as_int),
    "total_num_target_elements": ("Accumulated Target Elements", 920, format_as_int),
    "gpu_memory_usage": ("GPU memory usage (GiB)", 910, format_as_float),
    "gpu_peak_memory_usage": ("GPU peak memory usage (GiB)", 910, format_as_float),
    "ram_percentage": ("RAM usage", 920, format_as_percent),
    "cpu_percentage": ("CPU usage", 920, format_as_percent),
    "mean_predicted_embeddings": ("mean_predicted_embeddings", 920, format_as_float),
    "std_predicted_embeddings": ("std_predicted_embeddings", 920, format_as_float),
    # fmt: on
}
for key in lcm_metric_formatters:
    register_metric_formatter(key, *lcm_metric_formatters[key], overwrite=True)
lcm/train/mse_lcm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
lcm/train/mse_lcm/criterion.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Tuple
8
+
9
+ import torch
10
+ from fairseq2.logging import get_log_writer
11
+ from torch import Tensor
12
+
13
+ from lcm.datasets.batch import EmbeddingsBatch, LCMInput, LCMStyle
14
+ from lcm.models.abstract_lcm import AbstractLCModel
15
+ from lcm.train.criterion import CriterionsFactory
16
+ from lcm.train.lcm.criterion import (
17
+ LCMCriterion,
18
+ LCMCriterionConfig,
19
+ compute_standard_mse,
20
+ )
21
+ from lcm.train.metrics import LossTerm
22
+
23
+ logger = get_log_writer(__name__)
24
+
25
+
26
@dataclass
class ReconstructionCriterionConfig(LCMCriterionConfig):
    """Configuration for MSE-based next-sentence reconstruction criterions."""

    min_context_size: int = 1
    """minimum context size for next sentence prediction"""
30
+
31
+
32
@CriterionsFactory.register("next_sentence_mse")
class ReconstructionCriterion(LCMCriterion):
    """Computes the MSE reconstruction loss for next-sentence prediction.

    Given a document as a sequence of sentence embeddings, the model is asked
    to predict embedding ``s_t`` from the prefix ``s_{<t}``; the loss is the
    (optionally root-) mean squared error over the masked target positions.
    """

    config: ReconstructionCriterionConfig

    def __init__(
        self,
        config: ReconstructionCriterionConfig,
        model: AbstractLCModel,
        style: LCMStyle = LCMStyle.UNSUPERVISED,
    ):
        super().__init__(config, model, style)

        if style is not LCMStyle.SUPERVISED:
            assert (
                config.min_context_size is not None and config.min_context_size > 0
            ), (
                "For unsupervised pre-training, expecting a min_context_size of at least 1. "
                f"Received min_context_size={config.min_context_size}. "
                "Note that we need some context to predict the first position and "
                "this context can come from a dummy `beginning of document (BOD)` vector. "
                "With a minimum context size of 1 we ensure that we never ask the model to predict BOD"
            )

        self.min_context_size = config.min_context_size

    def prepare_input_and_mask(
        self,
        batch: LCMInput,
    ) -> Tuple[EmbeddingsBatch, torch.Tensor]:
        """
        A method for preparing model inputs and mask for a batch.
        It will be typically reused by the `__call__`
        implementations of the subclasses.
        """
        input_embeddings = batch.prepare_input(style=self.style)

        target_mask = batch.prepare_target_mask(
            input_embeddings,
            style=self.style,
            min_context_size=self.config.min_context_size,
        )

        return input_embeddings, target_mask

    def __call__(self, batch: LCMInput) -> LossTerm:
        """
        Args:
            batch is an LCMInput (see lcm.datasets.batch):

        Returns a LossTerm
        """

        # prepare_input_and mask returns embeddings with seqs in B,T,C
        # and a target mask in B,T,C. Note that the first position is never used as target
        # (i.e. BOS vector or first sentence in the document) and will always be set to False
        # in the target mask
        input_embeddings, target_mask = self.prepare_input_and_mask(batch)

        if self.normalize_in_criterion:
            # the input to the model will be normalize and
            # so is the target used for loss computation
            input_embeddings = input_embeddings.normalize_seqs(self.sonar_normalizer)

        # Predict model outputs
        output_embeddings = self.model(input_embeddings)

        # Prepare predictions and targets:
        # Shift the input to remove the first position.
        # Shifted seqs from input_embeddings are used as ground truth target embeddings
        target_seqs = input_embeddings.seqs[:, 1:].contiguous()
        batch_size, _, sonar_dim = target_seqs.size()

        # shift and flatten
        target_mask = target_mask[:, 1:].reshape(-1)
        # i.e. s2, s3, s4, s5

        # Trim the last position.
        # output_seqs represent contextualized embeddings / predictions for the next sentence
        # This shifting/trimming allows us to predict `s_t` conditioned on `s_{<t}`
        predicted_seqs = output_embeddings.seqs[:, :-1].contiguous()
        # i.e. s<=1, s<=2, s<=3, s<=4

        # only measure distance over `target_mask = True` positions
        flattened_predictions = predicted_seqs.view(-1, sonar_dim)[target_mask]
        flattened_target = target_seqs.view(-1, sonar_dim)[target_mask]

        # Cast features to float32 before computing the loss:
        reconstruction_loss, mse_loss = self.compute_loss(
            flattened_predictions.float(), flattened_target.float()
        )

        num_target_elements = target_mask.sum()

        # NOTE(review): if self.reduction is neither "sum" nor "mean" (and there
        # are target elements), `reduced_reconstruction_loss` is never bound and
        # the line below raises NameError — presumably `reduction` is validated
        # upstream in LCMCriterion; confirm.
        if self.reduction == "sum" or num_target_elements == 0:
            reduced_reconstruction_loss = reconstruction_loss.sum()
            mse_loss = mse_loss.sum()

        elif self.reduction == "mean":
            reduced_reconstruction_loss = reconstruction_loss.mean()
            mse_loss = mse_loss.mean()

        final_loss = reduced_reconstruction_loss

        # Loss summands for records
        summands = {
            "mse_loss": (mse_loss.item(), None),
            "reconstruction_loss": (reduced_reconstruction_loss.item(), None),
        }

        return LossTerm(
            value=final_loss,
            batch_size=batch_size,
            num_target_elements=num_target_elements.item(),
            summands=summands,
        )

    def compute_loss(
        self, flattened_predictions, flattened_target
    ) -> Tuple[Tensor, Tensor]:
        """
        Computes the following loss terms:
        1. The Reconstruction loss we want to optimize as well as:
        2. RMSE loss (for tracking) (in this parent class, RMSE=Reconstruction loss)
        Returns reconstruction_loss, mse_loss
        """
        reconstruction_loss, _ = compute_standard_mse(
            flattened_predictions, flattened_target
        )
        if self.config.compute_rmse:
            # Small epsilon keeps the sqrt differentiable at zero loss.
            epsilon = 1e-5
            reconstruction_loss = torch.sqrt(reconstruction_loss + epsilon)

        return reconstruction_loss, reconstruction_loss
167
+
168
+
169
@CriterionsFactory.register("target_mse")
class TargetMSECriterion(ReconstructionCriterion):
    """Computes the LCM training objective given source/target pairs.

    Behaviorally identical to ReconstructionCriterion; the only difference is
    that the default style is SUPERVISED (source/target pairs) rather than
    UNSUPERVISED (next-sentence pre-training).
    """

    def __init__(
        self,
        config: ReconstructionCriterionConfig,
        model: AbstractLCModel,
        style: LCMStyle = LCMStyle.SUPERVISED,
    ):
        super().__init__(config, model, style)
lcm/train/optim.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from typing import Tuple
7
+
8
+ from fairseq2.logging import get_log_writer
9
+ from fairseq2.optim.lr_scheduler import (
10
+ AbstractLRScheduler,
11
+ CosineAnnealingLR,
12
+ MyleLR,
13
+ NoopLR,
14
+ PolynomialDecayLR,
15
+ TriStageLR,
16
+ )
17
+ from torch.optim import Optimizer
18
+
19
+ logger = get_log_writer(__name__)
20
+
21
+
22
def build_lr_scheduler(
    optimizer: Optimizer,
    lr: float,
    warmup_steps: int,
    start_lr: float = 1e-7,
    final_lr: float = 1e-5,
    max_steps: int = 10_000,
    stage_ratio: Tuple[float, ...] = (0.1, 0.4, 0.5),
    schedule: str = "myle",
) -> AbstractLRScheduler:
    """Build a fairseq2 learning-rate scheduler for ``optimizer``.

    Args:
        optimizer: the optimizer whose learning rate will be scheduled.
        lr: the main (post-warm-up) learning rate.
        warmup_steps: number of warm-up steps.
        start_lr: the initial warm-up learning rate.
        final_lr: the final (decayed) learning rate.
        max_steps: total number of training steps (used by cosine/wsd/polynomial).
        stage_ratio: warmup/stable/decay ratios for the ``wsd`` (tri-stage) schedule.
        schedule: one of ``noop``, ``myle``, ``cosine``, ``wsd``, ``polynomial``.

    Returns:
        The configured ``AbstractLRScheduler``.

    Raises:
        AssertionError: if ``schedule`` is unknown, ``lr`` is not strictly
            positive, or (for ``wsd``) the lr bounds are inconsistent.
    """
    # Fix: the original message omitted "polynomial" even though it is supported.
    supported_schedules = ("noop", "myle", "cosine", "wsd", "polynomial")
    assert schedule in supported_schedules, (
        f"Cannot recognize the learning rate schedule {schedule}, "
        f"only {', '.join(supported_schedules)} are supported"
    )

    assert lr > 0, "The learning rate should be strictly positive"

    lr_scheduler: AbstractLRScheduler

    if schedule == "noop":
        # Constant learning rate, no warm-up.
        lr_scheduler = NoopLR(optimizer)

    elif schedule == "myle":
        # Inverse-sqrt schedule as implemented in Fairseq.
        lr_scheduler = MyleLR(
            optimizer,
            num_warmup_steps=warmup_steps,
            start_lr=[start_lr],
        )

    elif schedule == "cosine":
        # Single cosine cycle spanning all post-warm-up steps.
        lr_scheduler = CosineAnnealingLR(
            optimizer,
            cycle_len=max_steps - warmup_steps + 1,
            num_warmup_steps=warmup_steps,
            start_lr=[start_lr],
            final_lr=[final_lr],
            cycle_mul=1.0,
            lr_mul=1.0,
        )

    elif schedule == "wsd":
        # Warmup-Stable-Decay (tri-stage): TriStageLR works with scales
        # relative to the main lr, hence the checks and divisions below.
        assert lr > start_lr, (
            f"the starting learning rate {start_lr} should be less than the main lr {lr}"
        )
        start_lr_scale = start_lr / lr

        assert lr > final_lr, (
            f"the final learning rate {final_lr} should be less than the main lr {lr}"
        )
        final_lr_scale = final_lr / lr

        lr_scheduler = TriStageLR(
            optimizer,
            max_steps,
            stage_ratio=stage_ratio,  # type: ignore
            start_lr_scale=start_lr_scale,
            final_lr_scale=final_lr_scale,
        )

    elif schedule == "polynomial":
        lr_scheduler = PolynomialDecayLR(
            optimizer,
            max_steps,
            warmup_steps,
            power=200,
            start_lr=start_lr,
            final_lr=final_lr,
        )

    return lr_scheduler
lcm/train/step_sampler.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Literal, Optional
8
+
9
+ import torch
10
+ import torch.distributions as D
11
+ from fairseq2.logging import get_log_writer
12
+ from torch import Tensor
13
+
14
+ from lcm.nn.schedulers import DDIMScheduler
15
+
16
+ SUPPORTED_SAMPLERS = Literal["uniform", "beta"]
17
+ SUPPORTED_WEIGHTINGS = Literal["none", "clamp_snr"]
18
+
19
+ logger = get_log_writer(__name__)
20
+
21
+
22
def beta_function(a, b):
    """Euler Beta function B(a, b) = Γ(a)·Γ(b) / Γ(a + b).

    Evaluated in log-space via ``lgamma`` for numerical stability.
    """
    log_beta = torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b)
    return log_beta.exp()
25
+
26
+
27
@dataclass
class StepsSamplerConfig:
    """Configuration for sampling diffusion timesteps and weighting their losses."""

    # How training timesteps are drawn: "uniform" or "beta".
    sampling: SUPPORTED_SAMPLERS = "uniform"
    # Per-step loss weighting: "none" or Min-SNR clamping ("clamp_snr").
    weighting: SUPPORTED_WEIGHTINGS = "none"
    # Beta distribution parameters (used only when sampling="beta"); a=b=1 is uniform.
    beta_a: float = 0.8
    beta_b: float = 1
    # Clamping bounds for the Min-SNR gamma (used only when weighting="clamp_snr").
    max_gamma: float = 5.0
    min_gamma: float = 0
35
+
36
+
37
class StepsSampler(object):
    """Samples diffusion training timesteps and provides per-step loss scales.

    Builds a categorical distribution over the scheduler's training steps
    (uniform or Beta-shaped) and, optionally, Min-SNR loss weights.
    """

    def __init__(
        self,
        config: StepsSamplerConfig,
        noise_scheduler: DDIMScheduler,
    ):
        num_diffusion_train_steps = noise_scheduler.num_diffusion_train_steps
        weights: Optional[Tensor] = None

        if config.sampling == "uniform":
            weights = torch.ones(
                num_diffusion_train_steps,
            )

        elif config.sampling == "beta":
            # As motivated in https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/00328.pdf
            a = torch.tensor([config.beta_a])
            b = torch.tensor([config.beta_b])
            # a=1, b=1 -> uniform
            # The paper empirically chooses b=1, a=0.8 < 1

            # Normalized step positions in (0, 1].
            steps = (
                torch.arange(1, num_diffusion_train_steps + 1)
                / num_diffusion_train_steps
            )
            # Beta(a, b) pdf evaluated at each step position.
            weights = (
                1 / beta_function(a, b) * (steps ** (a - 1)) * ((1 - steps) ** (b - 1))
            )

        assert weights is not None, "The sampling weights were not properly set!"
        logger.info(f"Training with sampling weights={weights}")

        # Categorical over timesteps; weights are normalized to probabilities.
        self.distrib = D.Categorical(
            probs=weights / weights.sum(),
        )

        # setup weights for scaling:
        # NOTE(review): if config.weighting is neither "none" nor "clamp_snr",
        # self.gamma_per_step is never assigned and the log line below raises
        # AttributeError — presumably configs are validated upstream; confirm.
        if config.weighting == "none":
            self.gamma_per_step = None

        elif config.weighting == "clamp_snr":
            # Min-SNR scheme from
            # https://arxiv.org/abs/2303.09556
            snrs = noise_scheduler.get_snrs()
            # gamma(t) = min(max_gamma, snr(t))
            self.gamma_per_step = torch.clamp(
                snrs, max=config.max_gamma, min=config.min_gamma
            )

        logger.info(f"Training with Gamma={self.gamma_per_step}")

    @property
    def _training_weights(self) -> Tensor:
        # Probabilities of the underlying categorical distribution.
        return self.distrib.probs

    def sample(self, size: torch.Size, device: torch.device):
        """Draw a tensor of timestep indices of the given ``size`` on ``device``."""
        samples = self.distrib.sample(size).to(device)
        # print('Samples', samples)
        # print('Counts:', torch.bincount(samples.flatten()))
        return samples

    def get_loss_scales(self, steps):
        """Return per-step loss scales for ``steps``, or None when weighting is off."""
        if self.gamma_per_step is None:
            return None

        # If we're using constant Gamma=1 (returning None), then the sum of
        # the loss scales is steps.numel(), to match the total mass,
        # we normalize the scales to sum to steps.numel()
        gamma = self.gamma_per_step.to(steps.device)[steps]
        gamma = gamma / gamma.sum() * steps.numel()
        return gamma
lcm/train/trainer.py ADDED
@@ -0,0 +1,1422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ import gc
7
+ import logging
8
+ import os
9
+ import sys
10
+ from abc import abstractmethod
11
+ from contextlib import nullcontext
12
+ from dataclasses import asdict, dataclass, field
13
+ from functools import cached_property
14
+ from itertools import count
15
+ from pathlib import Path
16
+ from pprint import pformat
17
+ from typing import (
18
+ Any,
19
+ ContextManager,
20
+ Dict,
21
+ Iterator,
22
+ List,
23
+ Literal,
24
+ Mapping,
25
+ Optional,
26
+ Tuple,
27
+ )
28
+
29
+ import torch
30
+ import yaml
31
+ from fairseq2.assets import AssetCard, AssetCardFieldNotFoundError
32
+ from fairseq2.checkpoint import FileCheckpointManager
33
+ from fairseq2.gang import FakeGang, Gang, ReduceOperation, all_sum
34
+ from fairseq2.logging import get_log_writer
35
+ from fairseq2.metrics import (
36
+ LogMetricRecorder,
37
+ MetricBag,
38
+ MetricRecorder,
39
+ TensorBoardRecorder,
40
+ record_metrics,
41
+ )
42
+ from fairseq2.nn.ddp import to_ddp
43
+ from fairseq2.nn.fsdp import to_fsdp
44
+ from fairseq2.nn.utils.gradient import (
45
+ check_gradient_norms,
46
+ clip_gradient_norm,
47
+ scale_gradients,
48
+ )
49
+ from fairseq2.nn.utils.module import (
50
+ _get_named_modules,
51
+ freeze_parameters,
52
+ to_device,
53
+ )
54
+ from fairseq2.optim import AdamW, DynamicLossScaler
55
+ from fairseq2.optim.lr_scheduler import AbstractLRScheduler, get_effective_lr
56
+ from fairseq2.recipes.utils.log import log_model
57
+ from fairseq2.utils.profiler import Profiler, Stopwatch
58
+ from fairseq2.utils.rng import RngBag
59
+ from fairseq2.utils.state import StatefulObjectBag
60
+ from omegaconf import MISSING
61
+ from stopes.core import Requirements
62
+ from torch.distributed.fsdp.fully_sharded_data_parallel import (
63
+ FullyShardedDataParallel as FSDP,
64
+ )
65
+ from torch.nn import Module
66
+ from torch.nn.parallel import DistributedDataParallel as DDP
67
+ from torch.optim import Optimizer
68
+ from torch.profiler import record_function
69
+ from torcheval.metrics import Mean
70
+
71
+ from lcm.datasets.configs import DataLoadingConfig, ValidationDataLoadingConfig
72
+ from lcm.datasets.dataloading import ds_name
73
+ from lcm.train.metrics import (
74
+ LCMWandBRecorder,
75
+ flatten_dict,
76
+ )
77
+ from lcm.train.optim import build_lr_scheduler
78
+ from lcm.utils.data_utils import update_dataclass
79
+ from lcm.utils.distributed import (
80
+ SUPPORTED_FSDP_MEMORY_POLICIES,
81
+ SUPPORTED_FSDP_WRAP_POLICIES,
82
+ get_fsdp_memory_policy,
83
+ get_fsdp_wrap_policy,
84
+ init_process_group,
85
+ )
86
+ from lcm.utils.logging import (
87
+ log_env_variables,
88
+ setup_additional_logging,
89
+ )
90
+
91
+ logger = get_log_writer(__name__)
92
+
93
+
94
+ @dataclass
95
+ class TrainingConfig:
96
+ """Holds the configuration of a training job."""
97
+
98
+ training_data: Any = MISSING
99
+ """The datasets to train with."""
100
+
101
+ validation_data: Any = MISSING
102
+ """The datasets to validate on."""
103
+
104
+ model_arch: Optional[str] = None
105
+ """Starting architecture for the model to train"""
106
+
107
+ model_arch_overrides: Optional[Dict] = None
108
+ """Dict of parameters to overwrite in `model_arch`"""
109
+
110
+ model_config_or_name: Optional[Any] = None
111
+ """The model configuration or name to train.
112
+ This option cannot be paired with model_arch + model_arch_overrides
113
+ If provided, this option supersedes model_arch + model_arch_overrides
114
+ """
115
+ output_dir: Path = MISSING
116
+ """The output directory to store checkpoints and logs."""
117
+
118
+ log_folder: Optional[Path] = None
119
+ """The executor's log directory where stdout/stderr will be redirected.
120
+ We will use this directory to optionally enable ATEN and NCCL
121
+ logging (if debug is True) """
122
+
123
+ tb_dir: Optional[Path] = None
124
+ """The output directory to store tensorbaord logs"""
125
+
126
+ # defaults to "uncategorized"
127
+ wandb_project: Optional[str] = None
128
+ wandb_run_name: Optional[str] = None
129
+ wandb_entity: Optional[str] = None
130
+
131
+ requirements: Requirements = field(
132
+ default_factory=lambda: Requirements(
133
+ nodes=1,
134
+ tasks_per_node=8,
135
+ gpus_per_node=8,
136
+ cpus_per_task=8,
137
+ mem_gb=256,
138
+ timeout_min=3 * 24 * 60,
139
+ constraint="volta32gb",
140
+ )
141
+ )
142
+ """The scheduling requirements for this trainer"""
143
+
144
+ data_loading_config: DataLoadingConfig = MISSING
145
+
146
+ validation_data_loading_config: ValidationDataLoadingConfig = field(
147
+ default_factory=lambda: ValidationDataLoadingConfig()
148
+ )
149
+
150
+ criterion: Any = MISSING
151
+
152
+ dtype: str = "torch.float32"
153
+ """The data type of the model."""
154
+
155
+ lr_schedule: str = "myle"
156
+ """The learning rate schedule out of
157
+ `noop`: no learning rate schedule, just use the initial learning rate,
158
+ `myle`: inv-sqrt as implemented in Fairseq,
159
+ `cosine` cosine annealing schedule,
160
+ `wsd` for Warmup-Stable-Decay (WSD) or tri-stage """
161
+
162
+ lr: float = 0.004
163
+ """The initial (post-warm-up) learning rate for AdamW."""
164
+
165
+ start_lr: float = 1e-7
166
+ """The initial warmup learning rate."""
167
+
168
+ final_lr: float = 1e-5
169
+ """The final learning rate."""
170
+
171
+ lr_stage_ratios: List[float] = field(default_factory=lambda: [0.1, 0.4, 0.5])
172
+ """The ratios of the wsd (tri-stage) learning rate scheduler."""
173
+
174
+ num_lr_warmup_steps: int = 800
175
+ """The number of warm-up steps for the learning rate."""
176
+
177
+ weight_decay: float = 0.1
178
+ """The weight decay coefficient of AdamW (PyTorch default: 1e-2, Fs2 default: 0.0)."""
179
+
180
+ adam_betas: List[float] = field(default_factory=lambda: [0.9, 0.98])
181
+ """The beta coefficients of AdamW used for computing running averages of gradient and its square."""
182
+
183
+ adam_eps: float = 1e-6
184
+ """The term added to the denominator in AdamW to improve numerical stability.
185
+ Default in FS2 and PyTorch is 1e-8. Previous hard coded value in our trainer is 1e-6"""
186
+
187
+ use_optimizer_in_fp32: bool = True
188
+ """if True, the optimizer (AdamW) will be initialized with `use_fp32 = True`
189
+ i.e. we will store the optimizer state in single precision and convert
190
+ gradients on-the-fly to single precision for numerical stability"""
191
+
192
+ max_steps: int = 10_000
193
+ """The maximum number of training steps."""
194
+
195
+ max_grad_norm: float = 1000
196
+ """Maximal gradient norm, for gradient clipping.
197
+ gradients are multiplied by `torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)`
198
+ if max_norm is arbitrarily large, then we'll only report gradients norm
199
+ """
200
+ turn_off_grad_normalization: bool = False
201
+ """If ``True``, Turn off gradient normalization"""
202
+
203
+ gradient_accumulation: int = 1
204
+ """The number of steps to accumulate gradients before an optimizer update."""
205
+
206
+ validate_every_n_steps: int = 5000
207
+ """The number of steps after which to validate the model."""
208
+
209
+ checkpoint_every_n_steps: int = 5000
210
+ """The number of steps after which to checkpoint."""
211
+
212
+ keep_last_n_checkpoints: int = -1
213
+ """The number of checkpoints to keep on disk."""
214
+
215
+ save_model_every_n_steps: int = 5000
216
+ """The number of steps after which to save a consolidated version of the model."""
217
+
218
+ preserve_consolidated_models: bool = False
219
+ """If `True`, only pt files excluding ones starting with `mdoel` will be deleted from the step checkpoint directory."""
220
+
221
+ publish_metrics_every_n_steps: int = 1
222
+ """The number of steps after which to publish training metrics."""
223
+
224
+ gc_every_n_steps: int = 1000
225
+ """The frequency of steps at which we collect garbage with `gc.collect()`."""
226
+
227
+ seed: int = 2
228
+ """The RNG seed to use while starting the job."""
229
+
230
+ debug: bool = False
231
+ """If ``True``, runs the trainer in debug mode"""
232
+
233
+ profile: bool = False
234
+ """If ``True``, runs the PyTorch profiler at the beginning of the training."""
235
+
236
+ profiler_skip_first: int = 200
237
+
238
+ profiler_active: int = 3
239
+ """If profiling (``profile = True``), The profiler will skip the first ``skip_first`` steps, then do the active recording for the next ``active`` steps
240
+ If planning to visualize the trace with tensorbaord, then ``active`` should be small (less than 10 steps), otherwise tb won't load!
241
+ """
242
+ loss_scaler_init_scale: float = 2.0**15
243
+ """The initial scale for the gradient scaler, fairseq2's default is 2.0**15"""
244
+
245
+ loss_scaler_scale_window: Optional[int] = None
246
+ """The number of consecutive optimizer steps without inf/NaN gradients that must occur for the scale to be updated"""
247
+
248
+ use_fsdp: bool = True
249
+ """If ``True``, uses FSDP instead of DDP."""
250
+
251
+ use_autocast: bool = False
252
+ """If ``True``, wrap the forward pass in AMP autocast context.
253
+ autocast is only needed if training with mixed precision.
254
+ If training fails without it, check if some module with its weights is not properly cast
255
+ """
256
+
257
+ fsdp_wrap_granularity: SUPPORTED_FSDP_WRAP_POLICIES = "model"
258
+ """The granularity at which to wrap the model."""
259
+
260
+ fsdp_memory_policy: SUPPORTED_FSDP_MEMORY_POLICIES = "standard"
261
+ """The FSDP memory policy."""
262
+
263
+ fsdp_fp32_reduce: bool = False
264
+ """ If ``True``, the gradients will be reduced in full precision even when dtype is `torch.float16`"""
265
+
266
+ use_submitit: bool = True
267
+ """If ``True``, setup the environment ti use submitit."""
268
+
269
+ fake_gang_device: Optional[str] = None
270
+ """If non-empty, the trainer will be set locally on a device, instead of distributed training."""
271
+
272
+ experiment_name: Optional[str] = None
273
+ """experiment name for job trackin, if None default to StopesModule naming"""
274
+
275
+ raise_oom: bool = False
276
+ """If ``True``, raise OOM errors when they occur, if ``False`` give it another try."""
277
+
278
+ raise_nan_or_inf: bool = False
279
+ """If ``True``, raise FloatingPointError with Nan/Inf losses, if ``False`` give it another try."""
280
+
281
+ max_ooms: int = 10
282
+ """If ```raise_oom`` is False, how many OOMs we can tolerate per rank before raising an error."""
283
+
284
+ max_nans_or_infs: int = 10
285
+ """If ```raise_nan_or_inf`` is False, how many Nan/Infs we can tolerate per rank before raising an error."""
286
+
287
+ freeze_modules: Optional[List[str]] = None
288
+ """Name of modules in the model to be frozen when training/finetuning"""
289
+
290
+ freezing_strategy: Literal["none", "modules", "ffn", "ffn-adaln", "adaln"] = "none"
291
+ """
292
+ Freezing strategy to follow. Options are:
293
+ 1. none: Nothing will be frozen (default)
294
+ 2. modules: A list of modules to freeze will be read from `freeze_modules`
295
+ 3. ffn: All ffn sub-modules will be frozen
296
+ 4. ffn-adaln: all FFN and Adaln sub-modules will be frozen.
297
+ """
298
+
299
+
300
+ class Trainer(StatefulObjectBag):
301
+ config: TrainingConfig
302
+ model: Module
303
+ training_data_loader: Any
304
+ validation_data_loader: Optional[Any]
305
+ gang: Gang
306
+ optimizer: Optimizer
307
+ loss_scaler: DynamicLossScaler
308
+ lr_scheduler: AbstractLRScheduler
309
+ rng_bag: RngBag
310
+ step_nr: int
311
+ train_metric_bag: MetricBag
312
+ valid_metric_bag: Mapping[str, MetricBag]
313
+ metric_recorders: List[MetricRecorder]
314
+ profiler: Profiler
315
+ stopwatch: Stopwatch
316
+ criterion: Any
317
+ card_metdata: Dict
318
+ _train_step_time: float
319
+ _valid_step_time: float
320
+
321
    def __init__(
        self,
        config: TrainingConfig,
        model: Module,
        training_data_loader: Any,
        validation_data_loader: Optional[Any],
        gang: Gang,
        checkpoint_manager: FileCheckpointManager,
        rng_bag: RngBag,
        stopwatch: Stopwatch,
        card_metadata: Dict,
    ) -> None:
        """Wire up the trainer state: loaders, gang, recorders, profiler.

        Criterion, optimizer and loss scaler are deferred to `setup()`.
        """
        super().__init__()

        self.config = config

        if self.config.debug:
            # Debug mode: verbose logging and synchronous CUDA kernel launches.
            logger._logger.setLevel(logging.DEBUG)
            os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

        self.card_metadata = card_metadata

        # NOTE(review): eval of a config-provided string (e.g. "torch.float32").
        # Fine for trusted configs, but an explicit str -> dtype mapping would be safer.
        self.dtype = eval(config.dtype)

        self.model = model

        self.training_data_loader = training_data_loader

        # Skip saving and loading the state of validation dataloader
        self.register_non_stateful("validation_data_loader", validation_data_loader)

        self.gang = gang

        self.rng_bag = rng_bag

        self.step_nr = 1

        self.current_run_steps = 0

        self.checkpoint_manager = checkpoint_manager

        tb_dir = config.tb_dir or config.output_dir.joinpath("tb")

        self.metric_recorders = [LogMetricRecorder(logger)]

        # Only rank 0 writes TensorBoard and W&B records.
        if gang.rank == 0:
            self.metric_recorders.append(TensorBoardRecorder(tb_dir))
            self.metric_recorders.append(
                LCMWandBRecorder(
                    name=config.wandb_run_name,
                    project=config.wandb_project or "uncategorized",
                    output_dir=config.output_dir / "wandb",
                    config=self._tb_flat_config,
                )
            )

        self.profiler = Profiler(
            skip_first=config.profiler_skip_first,
            active=config.profiler_active,
            log_dir=tb_dir,
            gang=gang,
            enabled=config.profile,
        )

        self.stopwatch = stopwatch
        self._train_step_time = 0.0
        self._valid_step_time = 0.0

        # Populated later by setup() / setup_optimizer_and_lr_schedule().
        self.criterion = None  # type: ignore

        self.loss_scaler = None  # type: ignore
392
+
393
    @property
    def is_fsdp(self) -> bool:
        """True if the model is wrapped with FullyShardedDataParallel."""
        return isinstance(self.model, FSDP)
396
+
397
    @property
    def is_ddp(self) -> bool:
        """True if the model is wrapped with DistributedDataParallel."""
        return isinstance(self.model, DDP)
400
+
401
    def setup(self) -> None:
        """Finalize construction: criterion, metric bags, optimizer and LR schedule."""
        self.criterion = self.setup_criterion()

        self.setup_metric_bags()

        # Add the grad_norm metric to the training metric bag
        self.train_metric_bag.register_metric(
            "grad_norm", Mean(device=self.gang.device), persistent=False
        )
        self.train_metric_bag.register_metric(
            "raw_grad_norm", Mean(device=self.gang.device), persistent=False
        )

        self.setup_optimizer_and_lr_schedule()
415
+
416
    def setup_optimizer_and_lr_schedule(self):
        """Create the AdamW optimizer, loss scaler and LR scheduler from config.

        Optimizer and LR scheduler are registered as stateful so they are
        checkpointed and restored on resume.
        """
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.lr,
            betas=tuple(self.config.adam_betas),  # type: ignore
            eps=self.config.adam_eps,
            use_fp32=self.config.use_optimizer_in_fp32,
            weight_decay=self.config.weight_decay,
        )
        logger.info(
            (
                f"Setting up AdamW optimizer with betas={self.config.adam_betas}, "
                f"base lr={self.config.lr} and weight decay={self.config.weight_decay} "
                f"and use_fp32={self.config.use_optimizer_in_fp32}"
            )
        )

        self.register_stateful("optimizer", optimizer)

        # Dynamic loss scaling is only active when training in fp16.
        self.loss_scaler = DynamicLossScaler(
            optimizer,
            gang=self.gang,
            init_scale=self.config.loss_scaler_init_scale,
            min_scale=0.0001,
            scale_window=self.config.loss_scaler_scale_window,
            enabled=self.dtype == torch.float16,
        )

        if self.loss_scaler.is_enabled:
            logger.info(
                f"Initializing DynamicLossScaler with init_scale={self.config.loss_scaler_init_scale}"
            )

        lr_scheduler = build_lr_scheduler(
            optimizer=self.optimizer,
            schedule=self.config.lr_schedule,
            lr=self.config.lr,
            warmup_steps=self.config.num_lr_warmup_steps,
            start_lr=self.config.start_lr,
            final_lr=self.config.final_lr,
            max_steps=self.config.max_steps,
            stage_ratio=tuple(self.config.lr_stage_ratios),
        )

        # Saving the lr_scheduler as well to properly resume training
        self.register_stateful("lr_scheduler", lr_scheduler)
462
+
463
    @abstractmethod
    def setup_criterion(self):
        """Define a criterion (loss / objective function to optimize).

        Subclasses must return the criterion instance assigned to ``self.criterion``.
        """
466
+
467
    def setup_metric_bags(self):
        """Setup metric bags for tracking.

        One bag for training, plus one per validation dataset (keyed by
        dataset name). The validation bags are non-stateful, i.e. excluded
        from checkpoints.
        """

        self.train_metric_bag = MetricBag(self.gang)

        self.register_non_stateful(
            "valid_metric_bag",
            {
                ds_name(dataset): MetricBag(self.gang)
                for dataset in self.config.validation_data
            },
        )
479
+
480
    def checkpoint_and_raise(self, exc) -> None:
        """Save a best-effort checkpoint, then re-raise ``exc``.

        Early failures (<= 100 steps into the current run) skip the
        checkpoint to avoid persisting a likely-broken state.
        """
        # Checkpoint before exiting
        if torch.cuda.is_available():
            # Drain in-flight CUDA work so the saved state is consistent.
            torch.cuda.synchronize()
        logger.warning(f"R{self.gang.rank} checkpoint_and_raise - error={exc}")
        if self.current_run_steps > 100:
            # avoid checkpointing for early failures
            self._checkpoint(crash=exc)
        raise exc
489
+
490
+ @cached_property
491
+ def _tb_flat_config(self):
492
+ """
493
+ Prepare the flat config that will be used as HParams
494
+ to record training metadata, namely config and environment hashes.
495
+ """
496
+
497
+ dict_config = flatten_dict(asdict(self.config))
498
+
499
+ # Merge the data lists:
500
+ def get_data_signature(dataset):
501
+ return ":".join(
502
+ map(str, (dataset["name"], dataset["weight"], dataset["filters"]))
503
+ )
504
+
505
+ dict_config["training_data"] = "+".join(
506
+ get_data_signature(dataset) for dataset in dict_config["training_data"]
507
+ )
508
+ dict_config["validation_data"] = "+".join(
509
+ get_data_signature(dataset) for dataset in dict_config["validation_data"]
510
+ )
511
+
512
+ # value should be one of int, float, str, bool, or torch.Tensor
513
+ allowed_types = (int, float, str, bool, torch.Tensor)
514
+ config_keys = list(dict_config)
515
+ for k in config_keys:
516
+ if not isinstance(dict_config[k], allowed_types):
517
+ del dict_config[k]
518
+
519
+ return dict_config
520
+
521
    def run(self) -> None:
        """Run the trainer for up to `max_steps`.

        Recovers (up to configured limits) from CUDA OOM and NaN/Inf losses
        by rolling back metric updates and zeroing gradients; any other
        exception triggers a best-effort checkpoint and re-raises.
        """
        logger.info(f"Running training on {self.gang.size} device(s).")

        data_iter = self.training_data_loader.iterate_batches()

        logger.info(
            f"R{self.gang.rank} - waiting for all ranks to prepare a data iterator!"
        )
        self.gang.barrier()

        # These counters are rank-specific
        ooms, nans_or_infs = 0, 0

        # TODO: validate before training
        # logger.info(f"Starting with validation at step={self.step_nr}")
        # self._validate()

        with self.profiler:
            while self.step_nr <= self.config.max_steps:
                with record_function(f"step_{self.step_nr}"):
                    try:
                        # Main training step: forward -> backward -> optimizer.step -> log
                        stepped = self._train_step(data_iter)

                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            self._log_oom(e)
                            ooms += 1
                            if self.config.raise_oom or ooms > self.config.max_ooms:
                                # Previous behaviour, no retries but still checkpointing
                                self.checkpoint_and_raise(e)

                            logger.warning(
                                f"Attempting to recover from OOM on R{self.gang.rank} (OOMS={ooms})"
                            )
                            # Treat the failed step as "done" so counters advance.
                            stepped = True
                            # reset optimizer
                            self.optimizer.zero_grad(set_to_none=True)

                            # rollback updates
                            self.train_metric_bag.rollback_updates()

                            # Empty CUDA cache before trying again
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                        else:
                            # Other RuntimeErrors
                            self.checkpoint_and_raise(e)

                    except FloatingPointError as e:
                        if "Losses are Nan/Inf" in str(e):
                            self._log_nan_loss(e)
                            nans_or_infs += 1
                            if (
                                self.config.raise_nan_or_inf
                                or nans_or_infs > self.config.max_nans_or_infs
                            ):
                                self.checkpoint_and_raise(e)

                            logger.warning(
                                f"Attempting to recover from NaN/Inf loss on R{self.gang.rank} (NaNs/Infs={nans_or_infs})"
                            )
                            stepped = True
                            # reset optimizer
                            self.optimizer.zero_grad(set_to_none=True)

                            # rollback updates
                            self.train_metric_bag.rollback_updates()

                        else:
                            # Other FloatingPointErrors
                            self.checkpoint_and_raise(e)

                    except Exception as e:
                        self.checkpoint_and_raise(e)

                if stepped:
                    if self._should_publish_train_metrics():
                        self._publish_train_metrics()

                    if self._should_checkpoint():
                        self._checkpoint()

                    if self._should_validate():
                        self._validate()

                    if self._should_collect_garbage():
                        self._collect_garbage()

                    self.profiler.step()

                    self.step_nr += 1
                    self.current_run_steps += 1

                else:
                    # `_train_step` returned False: the iterator is exhausted;
                    # reset the pipeline and start a new epoch.
                    logger.info(f"R{self.gang.rank} - Resetting the datapipeline")
                    self.training_data_loader.pipeline.reset()

                    logger.info(f"R{self.gang.rank} - Done resetting the datapipeline")
                    data_iter = self.training_data_loader.iterate_batches()

        self._save_model_card_for_last_checkpoint(to_checkpoint_dir=False)
        logger.info(f"Finished training after {self.step_nr - 1} step(s).")

        self.gang.close()
629
+
630
    def restore(self) -> None:
        """Load the last checkpoint and resume training from the next step."""
        logger.info("Attempting to load last checkpoint.")

        step_nr, checkpoint = self.checkpoint_manager.load_last_checkpoint()

        logger.info(f"Checkpoint loaded, restoring training from step {step_nr}.")

        self.load_state_dict(checkpoint)

        # Make sure every rank has finished restoring before continuing.
        self.gang.barrier()

        logger.info("Training restored, resuming.")

        # Resume from the step after the checkpointed one.
        self.step_nr = step_nr + 1
644
+
645
+ def _maybe_with_autocast(self) -> ContextManager[None]:
646
+ # autocast is only needed if training with mixed precision.
647
+ # If training fails without it, check if some module with its weights
648
+ # is not properly cast
649
+ if self.config.use_autocast:
650
+ return torch.autocast(device_type="cuda", dtype=self.dtype)
651
+ else:
652
+ return nullcontext()
653
+
654
    def _train_step(self, data_iter: Iterator) -> bool:
        """Run one optimizer step with gradient accumulation.

        Returns True once a step was taken; returns False when the data
        iterator is exhausted before `gradient_accumulation` batches could
        be collected. On fp16 loss-scale overflow the step is retried.
        """
        step_nr = self.step_nr

        step_stopwatch = Stopwatch(start=True, device=self.gang.device)

        stepped = False

        # We have to retry the step in case of a gradient overflow.
        while not stepped:
            batches = []

            # Collect batches.
            with record_function(f"step_{step_nr}_data_load"):
                for _ in range(self.config.gradient_accumulation):
                    try:
                        batches.append(next(data_iter))
                    except StopIteration:
                        break

            if len(batches) != self.config.gradient_accumulation:
                logger.info(
                    f"R{self.gang.rank} -End of data reached at training step {step_nr}."
                )

                return False

            # create a copy of the current metrics
            # any update to the metrics from this point will either be committed with `commit_updates`
            # or ignored with `rollback_updates`
            self.train_metric_bag.begin_updates()

            num_targets = 0

            # Accumulate gradients.
            for batch_nr, batch in enumerate(batches):
                # Skip gradient sync on all but the last accumulation batch.
                with self._maybe_no_sync(batch_nr, len(batches)):
                    with record_function(f"step_{step_nr}_{batch_nr}_forward"):
                        # autocast should wrap only the forward pass(es)
                        # of your network, including the loss computation(s).
                        # Backward passes under autocast are not recommended.
                        with self._maybe_with_autocast():
                            loss = self.criterion(batch)

                    # With loss scaling enabled, non-finite losses are handled
                    # by the scaler's overflow path instead of raising here.
                    if not (
                        torch.isfinite(loss.value).all() or self.loss_scaler.is_enabled
                    ):
                        raise FloatingPointError("Losses are Nan/Inf.")

                    # update metrics
                    self.train_metric_bag.update([loss])

                    with record_function(f"step_{step_nr}_{batch_nr}_backward"):
                        self.loss_scaler.backward(loss.value)

                    num_targets += loss.num_target_elements

            # Record and clip gradient norm
            grad_norm, raw_grad_norm = self.process_gradients(step_nr, num_targets)

            # Update parameters.
            with record_function(f"step_{step_nr}_optimizer"):
                # scale_result: LossScaleResult(old_scale: float, new_scale: float, overflow: bool, min_reached: bool)
                _, scale_result = self.loss_scaler.run_optimizer_step(step_nr)

            if scale_result.overflow:
                # Walk back the metrics update:
                self.train_metric_bag.rollback_updates()
                logger.debug(
                    f"R{self.gang.rank} rolled back update {self.train_metric_bag._original_metrics is None}"
                )

                if scale_result.min_reached:
                    logger.error(f"Loss has started exploding at step {step_nr}. Stopping training.")  # fmt: skip

                    raise FloatingPointError("The training loss has exploded.")

                logger.debug(f"Repeating training step {step_nr}.")

            else:
                self.lr_scheduler.step()

                stepped = True

            # Reset.
            self.optimizer.zero_grad(set_to_none=True)

        # Stepped = True:
        with record_function(f"step_{step_nr}_metrics"):
            self.train_metric_bag.commit_updates()

            # gradient norm is common to workers
            self.train_metric_bag.grad_norm.update(grad_norm)
            self.train_metric_bag.raw_grad_norm.update(raw_grad_norm)

        if self.gang.rank == 0:
            # update elapsed time once
            self._train_step_time += step_stopwatch.get_elapsed_time()

        del batches
        return stepped
756
+
757
+ def _maybe_no_sync(self, batch_nr: int, num_batches: int) -> ContextManager[None]:
758
+ if batch_nr < num_batches - 1 and self.gang.size > 1:
759
+ return self.model.no_sync()
760
+ return nullcontext()
761
+
762
    def normalize_gradients(self, num_targets: int) -> None:
        """
        :param num_targets:
            The number of targets used in loss computation in this process.

        If reduction = sum:
            similar to fairseq2's `normalize_gradients`, will normalize the gradients of the model by ``world_size/num_targets``.
        If reduction = mean:
            will simply multiply by world size i.e undo DDP/FSDP's default normalization
        """
        reduction = self.criterion.reduction
        if reduction == "sum":
            total_num_targets = torch.tensor(
                num_targets, device=self.gang.device, dtype=torch.int64
            )

            # Sum the per-rank target counts across all workers.
            self.gang.all_reduce(total_num_targets, ReduceOperation.SUM)

            # Both DDP and FSDP divide gradients by the world size which we also undo.
            if total_num_targets > 0:
                grad_scale = self.gang.size / total_num_targets
            else:
                # If total_num_targets == 0, gradients will be zeroes anyway
                grad_scale = self.gang.size

        else:
            grad_scale = self.gang.size

        scale_gradients(self.model, grad_scale)
791
+
792
    def process_gradients(
        self, step_nr: int, num_targets: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize, unscale and clip the gradients.

        Returns ``(grad_norm, raw_grad_norm)``; ``raw_grad_norm`` is measured
        before any normalization/clipping and is only used for debugging.

        Raises FloatingPointError if gradient norms disagree across workers.
        """
        with record_function(f"step_{self.step_nr}_process_grads"):
            # this raw grad norm is only used for debugging
            # (max_norm=None measures the norm without clipping)
            raw_grad_norm = clip_gradient_norm(
                self.model,
                max_norm=None,
            )

            if not self.config.turn_off_grad_normalization:
                self.normalize_gradients(num_targets=num_targets)

            # undo the GradScaler's scaling before clipping
            self.loss_scaler.unscale_gradients_()

            # Clip gradients
            # If DDP, we use torch.nn.utils.clip_grad_norm_, if FSDP,
            # we use torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
            # this method handles the fact that gradients might be sharded across ranks.
            grad_norm = clip_gradient_norm(
                self.model,
                max_norm=self.config.max_grad_norm,
            )

            # Check for gradient consistency across workers:
            if not check_gradient_norms(grad_norm, self.gang, step_nr):
                raise FloatingPointError(
                    f"The gradients are inconsistent between processes at step {step_nr}. Training cannot continue."
                )

        return grad_norm, raw_grad_norm
828
+
829
+ def _should_validate(self) -> bool:
830
+ return self._should_do(self.config.validate_every_n_steps)
831
+
832
+ def _should_collect_garbage(self) -> bool:
833
+ return self._should_do(self.config.gc_every_n_steps)
834
+
835
    def _collect_garbage(self):
        """Force a Python garbage-collection pass."""
        logger.info("Collecting garbage...")
        gc.collect()
838
+
839
    @torch.inference_mode()
    def _validate(self) -> None:
        """Run one full pass over the validation data and publish metrics.

        Every rank keeps stepping (on dummy batches once its own data is
        exhausted) until all ranks run out, so collective operations stay
        in sync across the gang.
        """
        gc.collect()
        torch.cuda.empty_cache()

        if self.validation_data_loader is None:
            logger.info("Skip validation as the data loader is empty")
            return

        self.model.eval()

        logger.info(f"Starting validation after step {self.step_nr}.")

        self.validation_data_loader.pipeline.reset()

        data_iter = self.validation_data_loader.iterate_batches()
        data_dummy_iter = self.validation_data_loader.iterate_dummy_batches()

        logger.info(f"R{self.gang.rank} done creating the validation data iterator")

        for step_nr in count(start=1):
            step_stopwatch = Stopwatch(start=True, device=self.gang.device)

            try:
                batch = next(data_iter)
                true_batch = 1
            except StopIteration:
                # This rank is out of data; feed a dummy batch to keep
                # collective ops in lockstep with ranks that still have data.
                batch = next(data_dummy_iter)
                true_batch = 0

            # Number of ranks that still have real data this step.
            total_nb_batches = all_sum(self.gang, true_batch)

            if bool(total_nb_batches == 0):
                break

            # we apply model for all workers to avoid process groups sync issues
            loss = self.criterion(batch)

            if true_batch:
                self._valid_step_time += step_stopwatch.get_elapsed_time()
                self.valid_metric_bag[batch.name].update([loss])

        self._publish_validation_metrics()

        logger.info(
            f"R{self.gang.rank} Validation complete in {step_nr} steps, resuming training."
        )

        self.model.train()
887
+
888
+ def _should_publish_train_metrics(self) -> bool:
889
+ return self._should_do(self.config.publish_metrics_every_n_steps)
890
+
891
+ def _set_elements_per_second(
892
+ self, metric_values: Dict[str, Any], elapsed_time: float
893
+ ) -> None:
894
+ try:
895
+ num_elements = metric_values[self.criterion.throughput_metric_name]
896
+ except KeyError:
897
+ return
898
+
899
+ if not isinstance(num_elements, (int, float, torch.Tensor)):
900
+ return
901
+
902
+ if elapsed_time == 0.0:
903
+ metric_values["elements_per_second"] = 0.0
904
+ else:
905
+ metric_values["elements_per_second"] = num_elements / elapsed_time
906
+
907
    def _publish_train_metrics(self) -> None:
        """Sync train metrics across ranks and record them on rank 0."""
        values = self.train_metric_bag.sync_and_compute_metrics()

        self.train_metric_bag.reset_non_persistent_metrics()

        # Only rank-0 to record and publish
        # since sync_and_compute_metrics's recipient rank is 0
        if self.gang.rank != 0:
            return

        assert values is not None

        values["lr"] = get_effective_lr(self.lr_scheduler)

        self._set_elements_per_second(values, self._train_step_time)

        if self.loss_scaler.is_enabled:
            values["grad_scale"] = self.loss_scaler.get_scale()

        values["wall_time"] = self.stopwatch.get_elapsed_time()
        values["elapsed_time"] = self._train_step_time

        record_metrics(self.metric_recorders, "Train", values, self.step_nr)

        # Reset the per-publish-window timer.
        self._train_step_time = 0.0
932
+
933
    def _publish_validation_metrics(self) -> None:
        """Sync per-dataset validation metrics and record them on rank 0."""
        values = {}
        for name, metric_bag in self.valid_metric_bag.items():
            values[name] = metric_bag.sync_and_compute_metrics()
            metric_bag.reset_non_persistent_metrics()

        # Only rank-0 to record and publish
        if self.gang.rank != 0:
            return

        for name, val in values.items():
            assert val is not None
            self._set_elements_per_second(val, self._valid_step_time)
            val["elapsed_time"] = self._valid_step_time
            val["wall_time"] = self.stopwatch.get_elapsed_time()
            valid_name = f"Valid | {name}"
            record_metrics(self.metric_recorders, valid_name, val, self.step_nr)

        # reset timers
        self._valid_step_time = 0.0
953
+
954
+ def _should_checkpoint(self) -> bool:
955
+ return self._should_do(self.config.checkpoint_every_n_steps)
956
+
957
+ def _should_save_consolidated_model(self) -> bool:
958
+ return self.is_fsdp and self._should_do(self.config.save_model_every_n_steps)
959
+
960
    def _checkpoint(self, crash=None) -> None:
        """Save the full trainer state via the checkpoint manager.

        :param crash: the exception that triggered this checkpoint, if any;
            stored in the checkpoint metadata.
        """
        logger.info(f"Saving checkpoint at step {self.step_nr}")
        checkpoint = self.state_dict()

        metadata = {
            "config": self.config,
            "crash": crash,
        }

        self.checkpoint_manager.begin_checkpoint(self.step_nr)

        if self.is_fsdp:
            # FSDP: each rank holds its own shard, nothing is replicated.
            replicated_keys = None
        elif self.is_ddp:
            # If we do not shard, save the model and the optimizer only on rank 0.
            replicated_keys = {"model", "optimizer"}
        else:
            replicated_keys = {"*"}

        self.checkpoint_manager.save_state(checkpoint, replicated_keys=replicated_keys)

        self.checkpoint_manager.save_metadata(metadata)

        if self._should_save_consolidated_model():
            self._save_consolidated_model()

        # Create a model card only after creating model.pt
        # i.e., regular checkpointing with DDP or after consolidation with FSDP
        if not self.is_fsdp:
            self._save_model_card_for_last_checkpoint(to_checkpoint_dir=True)

        self.checkpoint_manager.commit_checkpoint()

        # Note that this logic looks at saved directories regardless of
        # the nature of the checkpointing, consolidated or not
        if self.config.keep_last_n_checkpoints != -1:
            self.checkpoint_manager.keep_last_n_checkpoints(
                self.config.keep_last_n_checkpoints,
                preserve_model=self.config.preserve_consolidated_models,
            )

        logger.info(f"Checkpoint saved by worker @rank={self.gang.rank}")
1002
+
1003
    def _save_consolidated_model(self) -> None:
        """Consolidate the sharded FSDP weights into one model file and save its card."""
        logger.info(f"Saving consolidated model at step {self.step_nr}.")
        self.checkpoint_manager.save_consolidated_fsdp_model(self.model)
        self._save_model_card_for_last_checkpoint(to_checkpoint_dir=True)
        logger.info("Consolidated model saved.")
1008
+
1009
+ def _should_do(self, n_step: int) -> bool:
1010
+ return self.step_nr % n_step == 0
1011
+
1012
    def create_model_card_for_last_checkpoint(
        self, is_final: bool = False, **card_kwargs
    ) -> Optional[AssetCard]:
        """Create a model card based on the last saved checkpoint and the model config.

        The generic trainer cannot build a card; model-specific trainers
        override this. Always returns None here.
        """
        logger.warning(
            "Could not create a model card with a generic trainer. Please use a model-specific one."
        )
        return None
1020
+
1021
    def _save_model_card_for_last_checkpoint(
        self, to_checkpoint_dir: bool = False
    ) -> None:
        """Save the model card for the last checkpoint to the checkpoint directory or the core output directory."""
        # Only rank 0 writes the card.
        if self.gang.rank != 0:
            return

        if to_checkpoint_dir:
            # NOTE(review): relies on CheckpointManager *private* attributes;
            # the ".tmp" suffix presumably matches the in-progress checkpoint
            # directory naming — confirm against the checkpoint manager.
            current_step_nr = self.checkpoint_manager._checkpoint_step_nr
            output_dir = self.checkpoint_manager._checkpoint_dir.joinpath(
                f"step_{current_step_nr}.tmp"
            )
        else:
            output_dir = self.config.output_dir

        card = self.create_model_card_for_last_checkpoint(
            is_final=not to_checkpoint_dir
        )

        if card is not None:
            card_data = card._metadata  # TODO: use the exposed attribute when available
            with open(output_dir / "model_card.yaml", "w", encoding="utf-8") as outfile:
                yaml.dump(card_data, outfile, default_flow_style=False)
            logger.info(f"Model card saved in {output_dir}")
1045
+
1046
    def _log_oom(self, exc):
        """Log an out-of-memory error with per-device CUDA memory summaries."""
        logger.warning(
            f"OOM: Ran out of memory on R{self.gang.rank} with exception: {exc}"
        )

        if torch.cuda.is_available():
            for device_idx in range(torch.cuda.device_count()):
                logger.warning(torch.cuda.memory_summary(device=device_idx))

        sys.stderr.flush()
1056
+
1057
    def _log_nan_loss(self, exc):
        """Log a NaN/Inf loss occurrence; recovery is handled by the caller."""
        logger.warning(f"We hit a Nan/Inf Loss: raised with exception: {exc}")
1059
+
1060
+
1061
+ class TrainerBuilder:
1062
+ def __init__(self, config: TrainingConfig):
1063
+ assert config.save_model_every_n_steps % config.checkpoint_every_n_steps == 0, (
1064
+ f"save_model_every_n_steps={config.save_model_every_n_steps} for saving consolidated models should be a multiplier of checkpoint_every_n_steps={config.checkpoint_every_n_steps}"
1065
+ )
1066
+
1067
+ self.config = config
1068
+
1069
+ self.stopwatch = Stopwatch(start=True)
1070
+
1071
+ # In case we train on Ampere or later, use TF32.
1072
+ torch.set_float32_matmul_precision("high")
1073
+
1074
+ if self.config.fake_gang_device is None:
1075
+ # By default, we work with a process group
1076
+ self.gang = init_process_group(config, logger=logger._logger)
1077
+ else:
1078
+ # For testing purposes, we use a fake gang on the chosen device
1079
+ self.gang = FakeGang(device=torch.device(self.config.fake_gang_device))
1080
+
1081
+ self.gang_rank = self.gang.rank if self.gang else 0
1082
+
1083
+ if self.gang.device.type == "cuda":
1084
+ # Setup ATEN and NCCL logging if in debug mode
1085
+ self._setup_additional_logging()
1086
+
1087
+ # Dump environment variables:
1088
+ log_env_variables(self.gang.device)
1089
+
1090
+ # A variable to carry fields necessary to build concise model cards
1091
+ self.card_metdata: Dict = {}
1092
+
1093
+ if self.gang_rank == 0:
1094
+ logger.info(f"Job Config\n{pformat(config)}")
1095
+
1096
+ self.device = self.gang.device
1097
+
1098
+ rng_bag = RngBag.from_device_defaults(self.device)
1099
+
1100
+ # Ensure that each run has deterministic behavior.
1101
+ rng_bag.manual_seed(config.seed)
1102
+
1103
+ self.rng_bag = rng_bag
1104
+
1105
+ self.dtype = eval(config.dtype)
1106
+
1107
+ self.finetune: bool = False
1108
+
1109
+ self.has_checkpoint: bool = False
1110
+
1111
    @property
    @abstractmethod
    def model_loader(self):
        """A fairseq2 ModelLoader (provided by concrete builders)."""
1115
+
1116
    @property
    def model_config_loader(self):
        """A fairseq2 ConfigLoader.

        NOTE(review): accesses the model loader's private `_config_loader`.
        """
        return self.model_loader._config_loader
1120
+
1121
    @abstractmethod
    def load_data(self):
        """Load training and validation data.

        Returns one loader for training data and one for validation data.
        """
1126
+
1127
    def create_model_config(self, set_finetune_flag: bool = False):
        """
        Given `model_config_or_name`, `model_arch` and `model_arch_overrides`
        create the model config dict.
        If `set_finetune_flag` is `True` then the trainer's finetune flag will be set
        here, inferred from the use of `model_config_or_name`.

        NOTE(review): if both `model_config_or_name` and `model_arch` are None,
        `model_config`/`finetune` are never assigned and this raises
        UnboundLocalError below — consider validating the config up front.
        """
        if self.config.model_config_or_name is not None:
            assert self.config.model_arch is None, (
                "We cannot set both `model_config_or_name` and `model_arch`"
            )

            if isinstance(self.config.model_config_or_name, str):
                # The config of a registered model i.e. we're finetuning
                logger.info(
                    f"Loading pretrained model from {self.config.model_config_or_name}"
                )

                model_config = self.model_config_loader(
                    self.config.model_config_or_name
                )
                finetune = True

                # Metadata for card creation
                source_card = self.model_config_loader._asset_store.retrieve_card(
                    self.config.model_config_or_name
                )
                try:
                    arch = source_card.field("model_arch").as_(str)
                except AssetCardFieldNotFoundError:
                    arch = None

                # Record the full config only when no registered arch is known.
                self.card_metadata = {
                    "model_config": model_config if arch is None else None,
                    "model_type": model_config.model_type,
                    "model_arch": arch,
                }

            else:
                # model_config_or_name is a dataclass
                logger.info(
                    "Creating a model from the provided config in model_config_or_name"
                )
                model_config = self.config.model_config_or_name

                self.card_metadata = {
                    "model_config": model_config,
                    "model_type": model_config.model_type,
                    "model_arch": None,
                }

                finetune = False

        elif self.config.model_arch is not None:
            assert (
                self.config.model_arch in self.model_config_loader._arch_configs.names()
            ), (
                f"Could not recognise {self.config.model_arch} as a registered architecture "
            )

            logger.info(
                f"Creating a model from registered arch {self.config.model_arch}"
            )

            finetune = False
            model_config = self.model_config_loader._arch_configs.get(
                self.config.model_arch
            )
            self.card_metadata = {
                "model_config": None,
                "model_type": model_config.model_type,
                "model_arch": self.config.model_arch,
            }

        # In all setups we can override some config parameters
        if self.config.model_arch_overrides is not None:
            try:
                update_dataclass(model_config, self.config.model_arch_overrides)

            except (TypeError, ValueError) as ex:
                raise ValueError(
                    "The model_arch_overrides contain one or more invalid keys"
                ) from ex

            # An overridden config no longer matches the registered arch.
            self.card_metadata["model_arch"] = None
            self.card_metadata["model_config"] = model_config

            logger.info(
                f"Overwriting model config parameters with {self.config.model_arch_overrides}"
            )

        if set_finetune_flag:
            self.finetune = finetune

        return model_config
1222
+
1223
    def create_model(self):
        """
        Load the model to be trained.
        In case other models are developed following a different paradigm, we can create
        corresponding trainers by overriding `create_model`.
        """
        logger.info("Initializing model.")

        model_config = self.create_model_config(set_finetune_flag=True)

        if self.gang_rank == 0:
            logger.info(f"Final model config:\n{pformat(model_config)}")

        # NOTE(review): uses the loader's private `_factory` to build from config.
        model = self.model_loader._factory(
            model_config,
            device=self.device,
            dtype=self.dtype,
        )
        # log model before any wrapping:
        log_model(model, logger)

        return model
1245
+
1246
    def wrap_model_with_ddp(self, model) -> DDP:
        """Wrap the model with DDP.

        Falls back to plain PyTorch DDP when fairseq2's `to_ddp` signature
        does not match (older/newer fairseq2 versions).
        """
        try:
            ddp_model = to_ddp(
                model,
                self.gang,
            )

        except ValueError:
            logger.warning(
                "Using pytorch DDP instead of fairseq's `to_ddp`\
                - please check fairseq2 after a3de79dcc6a4ea34cde644e15b4056f1a808a6a8"
            )

            ddp_model = DDP(model)

        if self.gang_rank == 0:
            log_model(ddp_model, logger)

        return ddp_model
1267
+
1268
    def wrap_model_with_fsdp(self, model) -> FSDP:
        """Wrap the model with FSDP."""
        wrap_policy, ignored_modules = get_fsdp_wrap_policy(
            model, wrap_granularity=self.config.fsdp_wrap_granularity
        )
        memory_policy = get_fsdp_memory_policy(policy=self.config.fsdp_memory_policy)

        # Full fp32 training needs no mixed-precision policy.
        if self.dtype == torch.float32:
            mixed_precision_dtype = None
        else:
            mixed_precision_dtype = self.dtype

        skip_init = False
        # Broadcast weights from rank 0 only when finetuning without a
        # checkpoint to restore from.
        broadcast_state = self.finetune and not self.has_checkpoint
        fp32_reduce = self.config.fsdp_fp32_reduce

        if self.gang.rank == 0:
            logger.info(
                (
                    f"FSDP init with: \n--- ignored_modules={ignored_modules}"
                    f"\n--- wrap_policy={wrap_policy}"
                    f"\n--- mixed_precision_dtype={mixed_precision_dtype}"
                    f"\n--- skip_init={skip_init}"
                    f"\n--- broadcast_state (FSDP's sync_module_states)={broadcast_state}"
                    f"\n--- fp32_reduce={fp32_reduce}"
                    f"\n--- memory_policy={memory_policy}"
                )
            )

        fsdp_model = to_fsdp(
            model,
            self.gang,
            wrap_policy,
            mixed_precision_dtype=mixed_precision_dtype,
            ignored_modules=ignored_modules,
            fp32_reduce=fp32_reduce,
            skip_init=skip_init,
            broadcast_state=broadcast_state,
            memory_policy=memory_policy,
        )

        if self.gang_rank == 0:
            log_model(fsdp_model, logger)

        return fsdp_model
1314
+
1315
    def maybe_load_model(self, model):
        """
        If we are finetuning and we don't have a checkpoint,
        load the pre-trained model and broadcast it to
        all gang processes from rank 0.
        """
        if not self.has_checkpoint and self.finetune:
            logger.info(f"Loading for finetuning: {self.config.model_config_or_name}")

            if self.gang_rank == 0:
                pretrained_model = self.model_loader(
                    model_name_or_card=self.config.model_config_or_name,
                    device=self.gang.device,
                    dtype=self.dtype,
                )  # type: ignore[arg-type]

                try:
                    model.load_state_dict(
                        pretrained_model.state_dict(),
                        strict=True,
                        assign=False,
                    )
                except (KeyError, ValueError) as ex:
                    raise ValueError(
                        f"The model state form {self.config.model_config_or_name} "
                        "cannot be loaded. See nested exception for details."
                    ) from ex

            # All ranks wait for rank 0 to finish loading.
            self.gang.barrier()

            to_device(model, self.gang.device)

            logger.info(
                f"Done loading model for finetuning: {self.config.model_config_or_name}"
            )

        return model
1352
+
1353
+ def maybe_freeze_parameters(self, model):
1354
+ assert (self.config.freezing_strategy == "modules") == (
1355
+ self.config.freeze_modules is not None
1356
+ ), (
1357
+ "For the `modules` freezing_strategy, we need a list of `freeze_modules`. "
1358
+ "If `freeze_modules` is provided, make sure to use freezing_strategy=modules"
1359
+ )
1360
+
1361
+ if self.config.freezing_strategy == "none":
1362
+ return model
1363
+
1364
+ if self.config.freezing_strategy == "modules":
1365
+ # Optionally freeze the parameters of sub-modules:
1366
+ if self.config.freeze_modules is not None:
1367
+ for module in self.config.freeze_modules:
1368
+ logger.info(f"... Freezing module={module}")
1369
+ freeze_parameters(getattr(model, module))
1370
+ return model
1371
+
1372
+ if self.config.freezing_strategy == "ffn":
1373
+ for name, m in _get_named_modules(model):
1374
+ if "ffn" in name:
1375
+ logger.info(f"... Freezing module={name}")
1376
+ freeze_parameters(m)
1377
+ return model
1378
+
1379
+ if self.config.freezing_strategy == "adaln":
1380
+ for name, m in _get_named_modules(model):
1381
+ if "modulator" in name:
1382
+ logger.info(f"... Freezing module={name}")
1383
+ freeze_parameters(m)
1384
+ return model
1385
+
1386
+ if self.config.freezing_strategy == "ffn-adaln":
1387
+ for name, m in _get_named_modules(model):
1388
+ if "modulator" in name or "ffn" in name:
1389
+ logger.info(f"... Freezing module={name}")
1390
+ freeze_parameters(m)
1391
+ return model
1392
+
1393
+ raise ValueError(f"Unknown freezing stratgey {self.config.freezing_strategy}")
1394
+
1395
    def _setup_additional_logging(self):
        """Enable extra (ATEN/NCCL) logging when running in debug mode."""
        if self.config.debug:
            assert self.config.log_folder is not None, (
                "Missing log_folder, \
                make sure the log_folder is properly set in the training config"
            )
            setup_additional_logging(log_folder=self.config.log_folder)
1402
+
1403
    @property
    def use_fsdp(self) -> bool:
        """Whether FSDP sharding is requested in the config."""
        return self.config.use_fsdp
1406
+
1407
    @property
    def use_ddp(self) -> bool:
        """
        Whether DDP should be used.
        if self.gang.size == 1: single worker, no parallelism
        if use_fsdp: use FSDP instead
        """
        return not (self.gang.size == 1 or self.use_fsdp)
1415
+
1416
+ @abstractmethod
1417
+ def build_trainer(self):
1418
+ """Build the trainer by loading data and
1419
+ setting up the model for training
1420
+
1421
+ Returns trainer
1422
+ """
lcm/train/two_tower_diffusion_lcm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
lcm/train/two_tower_diffusion_lcm/criterion.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import List, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairseq2.logging import get_log_writer
12
+ from fairseq2.nn.padding import pad_seqs
13
+ from torch import Tensor
14
+
15
+ from lcm.datasets.batch import EmbeddingsBatch, LCMInput, LCMStyle
16
+ from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModel
17
+ from lcm.train.criterion import CriterionsFactory
18
+ from lcm.train.lcm.criterion import (
19
+ LCMCriterion,
20
+ LCMCriterionConfig,
21
+ compute_standard_mse,
22
+ )
23
+ from lcm.train.metrics import LossTerm, format_as_float, register_metric_formatter
24
+ from lcm.train.step_sampler import StepsSampler, StepsSamplerConfig
25
+
26
+ logger = get_log_writer(__name__)
27
+
28
+
29
@dataclass
class TowerDiffusionLCMCriterionConfig(LCMCriterionConfig):
    """Configuration for the two-tower diffusion LCM training criterion."""

    cf_guidance_probability: float = 0.0
    """Probability to use classifier-free guidance by dropping conditioning.
    Note that this requires the model to be set with
    `trained_with_cf_guidance = True`!
    """
    # How diffusion timesteps are sampled (and how per-step losses are
    # re-weighted) during training.
    step_sampling: StepsSamplerConfig = field(
        default_factory=lambda: StepsSamplerConfig()
    )

    # When True, additionally report reconstruction losses aggregated per
    # bucket of diffusion timesteps (see TwoTowerDiffusionCriterion).
    log_losses_per_timestep_bucket: bool = False
41
+
42
+
43
@CriterionsFactory.register("two_tower_diffusion_next_sent")
class TwoTowerDiffusionCriterion(LCMCriterion):
    """Computes the LCM training objective for next-sentence prediction with diffusion"""

    config: TowerDiffusionLCMCriterionConfig
    model: TwoTowerDiffusionLCModel

    def __init__(
        self,
        config: TowerDiffusionLCMCriterionConfig,
        model: TwoTowerDiffusionLCModel,
        style: LCMStyle = LCMStyle.UNSUPERVISED,
    ):
        """
        :param config: criterion configuration (guidance probability, step sampling, ...)
        :param model: the two-tower diffusion LCM being trained
        :param style: unsupervised (single stream) or supervised (src/tgt) batches
        """
        super().__init__(config, model, style)
        assert hasattr(self.base_model, "noise_scheduler"), (
            "Expecting the diffusion model to have a `noise_scheduler`"
        )
        self.noise_scheduler = self.base_model.noise_scheduler

        # What the denoiser predicts: "sample", "epsilon" or "v_prediction"
        self.prediction_type = self.noise_scheduler.prediction_type

        self.trained_with_cf_guidance = self.base_model.config.trained_with_cf_guidance

        self.cf_guidance_probability = config.cf_guidance_probability

        # A non-zero conditioning-dropout probability is only meaningful (and
        # only allowed) if the model was built for classifier-free guidance.
        # NOTE(review): the assert message below is a tuple (always truthy as a
        # message) and contains the typo "cf_guidance_probabilitya" — left
        # unchanged here; worth fixing separately.
        assert (
            bool(self.cf_guidance_probability > 0) == self.trained_with_cf_guidance
        ), (
            "Expecting the config's cf_guidance_probabilitya to align with the model's `trained_with_cf_guidance` ",
            f"Found cf_guidance_probability={config.cf_guidance_probability} and "
            f"trained_with_cf_guidance={self.trained_with_cf_guidance}",
        )

        assert self.normalize_in_criterion, (
            "We only support `normalize_in_criterion = True` in the diffusion criterions"
        )

        self.summands.append("unnormalized_reconstruction_loss")

        if self.config.log_losses_per_timestep_bucket:
            # customize if needed
            # 10 equal-width buckets over [0, num_diffusion_train_steps]
            self.step_bucketing_boundaries = torch.linspace(
                0, self.noise_scheduler.num_diffusion_train_steps, 11
            )
            self.step_bucketing_labels: List[str] = []
            for e in range(len(self.step_bucketing_boundaries) - 1):
                bucket_left = self.step_bucketing_boundaries[e]
                bucket_right = self.step_bucketing_boundaries[e + 1]
                self.step_bucketing_labels.append(
                    f"reconstruction_loss_t{bucket_left:.0f}-{bucket_right:.0f}"
                )

            self.summands.extend(self.step_bucketing_labels)
            for label in self.step_bucketing_labels:
                register_metric_formatter(
                    label, label, 1000, format_as_float, overwrite=True
                )

        # Step sampler + loss weighter
        self.step_sampler = StepsSampler(
            config.step_sampling,
            noise_scheduler=self.noise_scheduler,
        )

    def prepare_input_and_mask(
        self,
        batch: LCMInput,
    ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch, torch.Tensor]:
        """
        A method for preparing model inputs and mask for a batch.
        It will be typically reused by the `__call__`
        implementations of the subclasses.
        Returns:
            - input_batch: context
            - target_batch: denoiser input
            - target_mask mask of positions to compute the loss over

        """
        # Prepare the input as in MSE LCM: each sequence is (src, tgt)
        input_embeddings = batch.prepare_input(style=self.style)

        # Normalize the embeddings
        if self.normalize_in_criterion:
            input_embeddings = input_embeddings.normalize_seqs(self.sonar_normalizer)

        # Start from an all-True mask: every position is a potential target.
        target_mask = torch.ones(
            size=input_embeddings.seqs.shape[:-1],
            dtype=torch.bool,
            device=input_embeddings.seqs.device,
        )

        # Factor in padded positions:
        if input_embeddings.padding_mask is not None:
            target_mask &= input_embeddings.padding_mask.materialize()

        # Unsupervised case: context and denoiser input are the same sequence;
        # clone so the noising step does not alias the context tensor.
        return input_embeddings, input_embeddings.clone(), target_mask

    def sample_noisy_input_and_targets(self, input_batch, target_mask):
        """
        (1)
        Prepares the noised inputs (latents) by sampling diffusion timesteps and calling
        on the model's noise_scheduler to add noise accordingly
        (2) Given the scheduler prediction type, prepares the target that the model will be
        trained to predict.

        :param input_batch: EmbeddingsBatch of the ground truth embeddings with seqs in (B, T, C)
        :param target_mask: Bool tensor in (B, T) where `True` signals that the
            model will be asked to predict the position
        """
        input_seqs, padding_mask = input_batch.seqs, input_batch.padding_mask

        # One diffusion timestep per (batch, position) element, shape (B, T).
        timesteps = self.step_sampler.sample(
            size=input_seqs[..., 0].size(), device=input_seqs.device
        )

        # Sample noise
        noise_seqs = torch.randn_like(input_seqs)

        # Define target in (B*T, C)
        sonar_dim = input_seqs.size(-1)
        if self.prediction_type == "sample":
            """Predict the clean ground truth embeddings. Default mode"""
            target = input_seqs.view(-1, sonar_dim)

        elif self.prediction_type == "epsilon":
            """Predict the added noise"""
            target = noise_seqs.view(-1, sonar_dim)

        elif self.prediction_type == "v_prediction":
            """Predict an interpolation of the ground truth clean
            embeddings and the added noise.
            As introduced in https://arxiv.org/pdf/2305.08891
            """
            target = self.noise_scheduler.get_velocity(
                input_seqs.view(-1, sonar_dim),
                noise_seqs.view(-1, sonar_dim),
                timesteps.view(-1),
            ).clone()
        else:
            raise ValueError(
                "Prediction type should be either: sample, epsilon, v_prediction"
            )

        # Add noise
        # Reshape inputs and noise into in (B*T , C) -> add noise -> reshape back as (B, T, C)
        noisy_input_seqs = self.noise_scheduler.add_noise(
            input_seqs.view(-1, sonar_dim),
            noise_seqs.view(-1, sonar_dim),
            timesteps.view(-1),
        ).view(input_seqs.size())

        # Create sequence batch with diffusion timesteps
        noisy_input_batch = EmbeddingsBatch(
            noisy_input_seqs,
            padding_mask,
            diffusion_timesteps=timesteps,
        )
        return noisy_input_batch, target, target_mask

    def compute_loss(
        self, flattened_predictions, flattened_target
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Parameters:
            flattened_predictions (Tensor): The predictions in (N, C)
            flattened_target (Tensor): The targets in (N, C)

        Returns:
            reconstruction_loss (Tensor): The Reconstruction loss we want to optimize (RMSE, SmoothL1, Huber etc.).
            plain_reconstruction_loss (Tensor): plain RMSE loss.
            unnormalized_reconstruction_loss (Tensor): plain RMSE loss between unnormalized features.
        """
        reconstruction_loss, plain_reconstruction_loss = compute_standard_mse(
            flattened_predictions,
            flattened_target,
        )

        # Same loss but measured in the original (un-normalized) SONAR space,
        # for monitoring only.
        unnormalized_reconstruction_loss, _ = compute_standard_mse(
            flattened_predictions,
            flattened_target,
            normalizer=self.sonar_normalizer,
        )
        # For backward compatibility with ongoing runs, take the sqrt
        if self.config.compute_rmse:
            # epsilon keeps sqrt differentiable/finite at 0
            epsilon = 1e-5
            reconstruction_loss = torch.sqrt(reconstruction_loss + epsilon)
            plain_reconstruction_loss = torch.sqrt(plain_reconstruction_loss + epsilon)
            unnormalized_reconstruction_loss = torch.sqrt(
                unnormalized_reconstruction_loss + epsilon
            )

        return (
            reconstruction_loss,
            plain_reconstruction_loss,
            unnormalized_reconstruction_loss,
        )

    @torch.no_grad()
    def _log_losses_per_step(self, batch_steps, reconstruction_loss):
        """Return per-timestep-bucket loss summands for metric tracking.

        :param batch_steps: diffusion timesteps, flattened to (B*T,)
        :param reconstruction_loss: per-element losses, flattened to (B*T,)
        """
        # Aggregate loss terms based on their bucket of diffusion steps for tracking
        summands = {}
        if self.config.log_losses_per_timestep_bucket:
            # Reconstruction_loss in BT,
            # batch_steps in BT,
            bucket_index = torch.bucketize(
                batch_steps, self.step_bucketing_boundaries.to(batch_steps.device)
            )
            onehot = F.one_hot(
                bucket_index,
                num_classes=self.step_bucketing_boundaries.numel(),
            )
            # (num_buckets, BT) @ (BT,) -> per-bucket loss sums
            loss_per_step = torch.matmul(onehot.t().float(), reconstruction_loss)
            # small epsilon avoids division by zero for empty buckets
            count_steps = onehot.sum(dim=0) + 1e-6
            if self.reduction == "mean":
                loss_per_step /= count_steps

            for e, label in enumerate(self.step_bucketing_labels):
                summands[label] = (
                    loss_per_step[e].item(),
                    count_steps[e].long().item(),
                )

        return summands

    def __call__(self, batch: LCMInput) -> LossTerm:
        """
        Input batch is LCMInput with:
            source: List[Tensor]
            target: Union[None, List[Tensor]]
        """

        # Prepare the clean inputs and target mask:
        input_batch, target_batch, target_mask = self.prepare_input_and_mask(batch)

        noisy_target_batch, target, target_mask = self.sample_noisy_input_and_targets(
            target_batch, target_mask
        )
        # Encode the context and diffuse:
        output_batch = self.model(
            input_batch,
            noisy_target_batch,
            cf_guidance_prob=self.cf_guidance_probability,
        )

        # Shape B, T, C
        output_seqs = output_batch.seqs

        sonar_dim = output_seqs.size(-1)

        # only measure distance over `target_mask = True` positions
        target_mask = target_mask.reshape(-1)

        # The target is basically the doubled ground truth sequence before noising
        # (with some modification to adjust for the denoiser's prediction type)

        # contextualized latents (noised inputs preceding the target) e_1, e_2, ...
        flattened_predictions = output_seqs.view(-1, sonar_dim)[target_mask]

        # x1, x2, ..., xT
        # Target is already in B*T, C
        flattened_target = target[target_mask]

        # Cast features to float32 before computing the loss:
        (
            reconstruction_loss,
            mse_loss,
            unnormalized_reconstruction_loss,
        ) = self.compute_loss(flattened_predictions.float(), flattened_target.float())

        num_target_elements = target_mask.sum()

        batch_steps = noisy_target_batch.diffusion_timesteps.view(-1)[target_mask]

        summands = self._log_losses_per_step(batch_steps, reconstruction_loss)

        # Get loss scales per timestep (gamma)
        gammas = self.step_sampler.get_loss_scales(batch_steps)
        # Weight the loss terms
        if gammas is not None:
            reconstruction_loss = torch.mul(reconstruction_loss, gammas)

        # When there are no target positions, fall back to sum() so the
        # reduction yields 0 instead of a NaN from mean() over an empty tensor.
        if self.reduction == "sum" or num_target_elements == 0:
            reduced_reconstruction_loss = reconstruction_loss.sum()
            mse_loss = mse_loss.sum()
            unnormalized_reconstruction_loss = unnormalized_reconstruction_loss.sum()

        elif self.reduction == "mean":
            reduced_reconstruction_loss = reconstruction_loss.mean()
            mse_loss = mse_loss.mean()
            unnormalized_reconstruction_loss = unnormalized_reconstruction_loss.mean()

        final_loss = reduced_reconstruction_loss

        # Loss summands for records
        summands.update(
            {
                "mse_loss": (mse_loss.item(), -1),
                "reconstruction_loss": (reduced_reconstruction_loss.item(), -1),
                "unnormalized_reconstruction_loss": (
                    unnormalized_reconstruction_loss.item(),
                    -1,
                ),
            }
        )

        return LossTerm(
            value=final_loss,
            batch_size=output_seqs.size(0),
            num_target_elements=num_target_elements.item(),
            summands=summands,
        )
354
+
355
+
356
@CriterionsFactory.register("two_tower_diffusion_next_sent_finetuning")
class DiffusionNextSentFinetuningCriterion(TwoTowerDiffusionCriterion):
    """Supervised (source -> target) finetuning variant of the two-tower
    diffusion criterion: the context is the source document and the loss is
    computed over the separately padded target sentences only."""

    def __init__(
        self,
        config: TowerDiffusionLCMCriterionConfig,
        model: TwoTowerDiffusionLCModel,
    ):
        super().__init__(config, model, LCMStyle.SUPERVISED)

    def prepare_input_and_mask(
        self,
        batch: LCMInput,
    ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch, torch.Tensor]:
        """
        A method for preparing model inputs and mask for a batch.
        It will be typically reused by the `__call__`
        implementations of the subclasses.

        Returns:
            - input_batch: context
            - target_batch: denoiser input
            - target_mask mask of positions to compute the loss over
        """

        # Prepare the input as in MSE LCM
        input_embeddings = batch.prepare_input(style=self.style)

        assert input_embeddings.source_lengths is not None, (
            "Missing source lengths needed for the two-tower supervised finetuning"
        )

        # Pad the target sentence embeddings into their own batch.
        target_embeddings = EmbeddingsBatch(*pad_seqs(batch.target))  # type: ignore

        # Normalize the embeddings
        if self.normalize_in_criterion:
            input_embeddings = input_embeddings.normalize_seqs(self.sonar_normalizer)
            target_embeddings = target_embeddings.normalize_seqs(self.sonar_normalizer)

        # FIX: read the shape from the wrapped tensor (`.seqs.shape`), matching
        # the base-class implementation; `EmbeddingsBatch` itself has no `shape`.
        target_mask = torch.ones(
            size=target_embeddings.seqs.shape[:-1],
            dtype=torch.bool,
            device=input_embeddings.seqs.device,
        )

        # Factor in padded positions:
        if target_embeddings.padding_mask is not None:
            target_mask &= target_embeddings.padding_mask.materialize()

        return input_embeddings, target_embeddings, target_mask
lcm/train/two_tower_diffusion_lcm/trainer.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ #
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import Union
8
+
9
+ from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModelConfig
10
+ from lcm.models.two_tower_diffusion_lcm.loader import (
11
+ load_two_tower_diffusion_lcm_model,
12
+ )
13
+ from lcm.train.lcm.trainer import LCMTrainer, LCMTrainerBuilder, LCMTrainingConfig
14
+ from lcm.train.two_tower_diffusion_lcm.criterion import (
15
+ TowerDiffusionLCMCriterionConfig,
16
+ )
17
+
18
+
19
@dataclass
class TwoTowerDiffusionLCMTrainingConfig(LCMTrainingConfig):
    """Training configuration specialized for two-tower diffusion LCMs."""

    model_config_or_name: Union[TwoTowerDiffusionLCModelConfig, str, None] = None
    """The model configuration or name to train."""

    # Use the diffusion-specific criterion config by default.
    criterion: TowerDiffusionLCMCriterionConfig = field(  # type: ignore
        default_factory=lambda: TowerDiffusionLCMCriterionConfig()
    )
27
+
28
+
29
class DiffusionLCMTrainerBuilder(LCMTrainerBuilder):
    """Trainer builder specialized for two-tower diffusion LCMs.

    Only overrides the model loader; everything else is inherited from
    :class:`LCMTrainerBuilder`.
    """

    config: TwoTowerDiffusionLCMTrainingConfig

    def __init__(self, config: TwoTowerDiffusionLCMTrainingConfig):
        super().__init__(config)

    @property
    def model_loader(self):
        """A fairseq2 ModelLoader for two-tower diffusion LCM checkpoints."""
        loader = load_two_tower_diffusion_lcm_model
        return loader
39
+
40
+
41
+ def prepare_two_tower_diffusion_lcm_trainer(
42
+ config: TwoTowerDiffusionLCMTrainingConfig,
43
+ ) -> LCMTrainer:
44
+ """Create an LCM Trainer.
45
+ :param config: The training configuration.
46
+ """
47
+ return DiffusionLCMTrainerBuilder(config).build_trainer()
pyproject.toml CHANGED
@@ -13,6 +13,7 @@ dependencies = [
13
  "polars>=1.16.0",
14
  "pyarrow>=16.1.0",
15
  "retrying>=1.3.4",
 
16
  "sentence-splitter>=1.4",
17
  "sonar-space>=0.3.2",
18
  "stopes[mono]>=2.2.0",
 
13
  "polars>=1.16.0",
14
  "pyarrow>=16.1.0",
15
  "retrying>=1.3.4",
16
+ "safetensors>=0.5.3",
17
  "sentence-splitter>=1.4",
18
  "sonar-space>=0.3.2",
19
  "stopes[mono]>=2.2.0",
scripts/CovertToST.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from safetensors.torch import save_file
3
+ import os
4
+
5
+ # Define the location and files to process
6
+ location = "_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000"
7
+ files = ["model", "rank_0", "metadata"]
8
+
9
+ for file in files:
10
+ pt_path = os.path.join(location, f"{file}.pt")
11
+ st_path = os.path.join(location, f"{file}.safetensors")
12
+
13
+ try:
14
+ # Attempt to load the checkpoint with weights_only=True
15
+ checkpoint = torch.load(pt_path, weights_only=True)
16
+ except Exception as e:
17
+ print(f"Warning: Failed to load {pt_path} with weights_only=True due to {e}")
18
+ print("Attempting to load with weights_only=False (ensure the source is trusted).")
19
+ try:
20
+ checkpoint = torch.load(pt_path, weights_only=False)
21
+ except Exception as e:
22
+ print(f"Error: Failed to load {pt_path} with weights_only=False due to {e}")
23
+ continue # Skip to the next file
24
+
25
+ # Determine the state_dict
26
+ state_dict = checkpoint.get('model', checkpoint)
27
+
28
+ # Filter out non-tensor entries
29
+ tensor_state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
30
+
31
+ # Save the filtered state_dict to a .safetensors file
32
+ save_file(tensor_state_dict, st_path)
33
+ print(f"Successfully converted {pt_path} to {st_path}")