Commit b5a0bec by Lexa
Parent(s): bb10ea5

Converted .pt files to safetensors, then (dirtily) patched fairseq to enable loading of safetensor files
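For context, the conversion step itself lives in scripts/CovertToST.py, which is not shown on this page. A minimal sketch of what such a conversion plausibly amounts to (the function name, paths, and the assumption of a flat or {"model": ...}-wrapped state dict are illustrative, not the confirmed script contents):

# Hypothetical sketch of the .pt -> .safetensors conversion; the real
# scripts/CovertToST.py may differ.
import torch
from safetensors.torch import save_file

def convert_pt_to_safetensors(pt_path: str, st_path: str) -> None:
    # weights_only=True avoids executing arbitrary pickled code in the .pt file.
    state = torch.load(pt_path, map_location="cpu", weights_only=True)
    # Checkpoints are often nested ({"model": state_dict}); safetensors needs a
    # flat str -> Tensor mapping, so unwrap and make tensors contiguous first.
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    tensors = {k: v.contiguous() for k, v in state.items() if isinstance(v, torch.Tensor)}
    save_file(tensors, st_path)

convert_pt_to_safetensors("model.pt", "model.safetensors")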
Files changed:
- .gitattributes +1 -0
- .gitignore +4 -1
- Patches/Patch_TorchLoader.py +87 -0
- _LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/{metadata.pt → metadata.safetensors} +2 -2
- _LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/{model.pt → model.safetensors} +2 -2
- _LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml +1 -1
- _LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/{rank_0.pt → rank_0.safetensors} +2 -2
- lcm/datasets/base.py +114 -0
- lcm/datasets/configs.py +774 -0
- lcm/datasets/dataloader.py +258 -0
- lcm/datasets/dataloading.py +1109 -0
- lcm/datasets/parquet_utils.py +1141 -0
- lcm/datasets/sentence_splitter_pipeline.py +351 -0
- lcm/datasets/sentence_splitting.py +160 -0
- lcm/datasets/utils.py +42 -0
- lcm/models/two_tower_diffusion_lcm/loader.py +3 -1
- lcm/train/__main__.py +131 -0
- lcm/train/common.py +65 -0
- lcm/train/criterion.py +100 -0
- lcm/train/lcm/__init__.py +4 -0
- lcm/train/lcm/criterion.py +143 -0
- lcm/train/lcm/trainer.py +259 -0
- lcm/train/metrics.py +449 -0
- lcm/train/mse_lcm/__init__.py +4 -0
- lcm/train/mse_lcm/criterion.py +179 -0
- lcm/train/optim.py +96 -0
- lcm/train/step_sampler.py +107 -0
- lcm/train/trainer.py +1422 -0
- lcm/train/two_tower_diffusion_lcm/__init__.py +4 -0
- lcm/train/two_tower_diffusion_lcm/criterion.py +404 -0
- lcm/train/two_tower_diffusion_lcm/trainer.py +47 -0
- pyproject.toml +1 -0
- scripts/CovertToST.py +33 -0
.gitattributes
CHANGED
@@ -1 +1,2 @@
 *.pt filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -115,4 +115,7 @@ mortimer_env.txt
 _LexaLCM_Block0/Datasets/
 
 # UV
 uv.lock
+
+# Unsafe files
+*.pt
Patches/Patch_TorchLoader.py
ADDED
@@ -0,0 +1,87 @@
+# Patch for fairseq2.utils.file.load_tensors
+#
+# This patch allows for loading safetensors files
+#
+# It is used in the two_tower_diffusion_lcm model loader:
+# ./lcm/models/two_tower_diffusion_lcm/loader.py
+
+from __future__ import annotations
+
+import warnings
+from pathlib import Path
+from typing import Any, Callable, Dict, Mapping, Optional, Protocol, Union
+from warnings import catch_warnings
+
+import torch
+from torch import Tensor
+from typing_extensions import TypeAlias
+
+from fairseq2.typing import Device
+
+from safetensors.torch import load_file
+
+MapLocation: TypeAlias = Optional[
+    Union[Callable[[Tensor, str], Tensor], Device, str, Dict[str, str]]
+]
+
+
+class TensorLoader(Protocol):
+    """Loads tensors from files."""
+
+    def __call__(
+        self,
+        path: Path,
+        *,
+        map_location: MapLocation = None,
+        restrict: bool = False,
+    ) -> Dict[str, Any]:
+        """
+        :param path:
+            The path to the file.
+        :param map_location:
+            Same as the ``map_location`` parameter of ``torch.load()``.
+        :param restrict:
+            If ``True``, restricts the deserialization to tensors and
+            other types considered safe to unpickle.
+        :returns:
+            The loaded tensors and auxiliary data.
+        """
+
+
+class TensorDumper(Protocol):
+    """Dumps tensors to files."""
+
+    def __call__(self, data: Mapping[str, Any], path: Path) -> None:
+        """
+        :param data:
+            The dictionary containing tensors and other auxiliary data.
+        :param path:
+            The path to the file.
+        """
+
+
+def load_tensors(
+    path: Path,
+    *,
+    map_location=None,
+    restrict: bool = False,
+) -> Dict[str, Any]:
+    """Load a checkpoint in .pt or .safetensors format."""
+    if str(path).endswith(".safetensors"):
+        tensors = load_file(str(path), device=str(map_location) if map_location else "cpu")
+        return {"model": tensors}  # ✅ Wrap it like a .pt file
+
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        return torch.load(
+            str(path), map_location, weights_only=restrict  # type: ignore[arg-type]
+        )
+
+
+def dump_tensors(data: Mapping[str, Any], path: Path) -> None:
+    """Dump ``data`` to a PyTorch tensor file under ``path``."""
+    with catch_warnings():
+        warnings.simplefilter("ignore")  # Suppress noisy FSDP warnings.
+
+        torch.save(data, path)
_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/{metadata.pt → metadata.safetensors}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
+size 16
_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/{model.pt → model.safetensors}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:7f6160840e8a76276b126f4da6ded5568c2dcc777fd40007ccfa5bcfb08d9bce
+size 575804960
_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model_card.yaml
CHANGED
@@ -1,5 +1,5 @@
 __source__: inproc
-checkpoint: file:///home/lexa/DevProjects/_Unsorted/LexaLCM_Pre0_288M/_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model.pt
+checkpoint: file:///home/lexa/DevProjects/_Unsorted/LexaLCM_Pre0_288M/_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/model.safetensors
 model_arch: arch_lexa_lcm_pre0
 model_family: two_tower_diffusion_lcm
 name: on_the_fly_lcm
_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000/{rank_0.pt → rank_0.safetensors}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
+size 16
lcm/datasets/base.py
ADDED
@@ -0,0 +1,114 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+#
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Callable, Dict, Generic, Iterator, Optional, Sequence, TypeVar, Union
+
+import torch
+from fairseq2.data.data_pipeline import DataPipeline
+from fairseq2.gang import FakeGang, Gang
+from fairseq2.typing import DataType
+
+from lcm.datasets.configs import (
+    DataLoadingConfig,
+    DatasetConfigT,
+    create_dataset_config_from_cards,
+)
+from lcm.datasets.dataloading import (
+    build_weighted_pipeline_with_renaming as default_build_fn,
+)
+from lcm.utils.common import Batched, set_mkl_num_threads
+
+BatchT_co = TypeVar("BatchT_co", bound=Union[Dict, Batched], covariant=True)
+logger = logging.getLogger(__name__)
+
+
+class DataLoader(ABC, Generic[BatchT_co, DatasetConfigT]):
+    def __init__(
+        self,
+        data_config: DataLoadingConfig,
+        datasets: Sequence[DatasetConfigT],
+        gang: Gang,
+        builder_func: Callable[..., DataPipeline] = default_build_fn,
+        dtype: DataType = torch.float16,
+    ):
+        self.data_config = data_config
+        self.datasets = list(map(create_dataset_config_from_cards, datasets))
+        self.dtype = dtype
+        self.gang = gang
+        self.builder_func = builder_func
+
+        self._pipeline: Optional[DataPipeline] = None
+
+    @property
+    def pipeline(self) -> DataPipeline:
+        if self._pipeline is None:
+            logger.info(f"R{self.gang.rank} self._pipeline is None, building...")
+            gang_rank = self.gang.rank if self.gang else 0
+            world_size = self.gang.size if self.gang else 1
+
+            self._pipeline = self.builder_func(
+                self.datasets, self.data_config, gang_rank, world_size
+            )
+            assert self._pipeline, (
+                f"Cannot build data pipeline from config {self.data_config}"
+            )
+        return self._pipeline
+
+    def destroy(self) -> None:
+        """Destroy the pipeline to rebuild it with different shuffling"""
+        self._pipeline = None
+        # Build again and reset it
+        logger.info(f"R{self.gang.rank} resetting the pipeline in DataLoader.destroy")
+        self.reset()
+
+    def reset(self) -> None:
+        """
+        Applying reset will result in different shuffling for next iterations,
+        since pipeline will use modified generator state from previous one.
+        This's suitable side effect for `sharding_in_memory=False` (training) scenario.
+
+        Illustrative example :
+        >>> import torch
+        >>> from fairseq2.data import read_sequence
+
+        >>> def get_one_epoch_pipeline():
+        ...     torch.manual_seed(13)
+        ...     return read_sequence(list(range(10))).shuffle(5)
+
+        >>> bb = get_one_epoch_pipeline().and_return()
+        >>> list(bb)
+        [3, 1, 2, 4, 0, 8, 5, 6, 9, 7]
+        >>> bb.reset()
+        >>> list(bb)
+        [4, 0, 3, 2, 1, 9, 7, 6, 8, 5]
+        """
+        self.pipeline.reset()
+
+    @abstractmethod
+    def iterate_batches(self) -> Iterator[BatchT_co]: ...
+
+
+class BaseDataLoader(DataLoader[dict, DatasetConfigT]):
+    def __init__(
+        self,
+        data_config: DataLoadingConfig,
+        datasets: Sequence[DatasetConfigT],
+        dtype: DataType = torch.float16,
+        gang: Gang = None,
+    ) -> None:
+        gang = gang or FakeGang()
+        super().__init__(
+            data_config=data_config,
+            datasets=datasets,
+            builder_func=default_build_fn,
+            dtype=dtype,
+            gang=gang,
+        )
+        set_mkl_num_threads()
+
+    def iterate_batches(self) -> Iterator[dict]:
+        yield from iter(self.pipeline)
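As a quick orientation, driving this loader might look like the following sketch (the dataset card name "wiki" and the chosen config fields are illustrative; BaseDataLoader falls back to a FakeGang when no gang is passed, per the code above):

# Hypothetical usage sketch; assumes a registered datacard named "wiki".
from lcm.datasets.base import BaseDataLoader
from lcm.datasets.configs import DataLoadingConfig, ParquetDatasetConfig

loader = BaseDataLoader(
    data_config=DataLoadingConfig(batch_size=8, nb_epochs=1),
    datasets=[ParquetDatasetConfig(name="wiki")],  # resolved via datacards
)
for batch in loader.iterate_batches():
    ...  # each batch is a dict produced by the weighted pipeline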
lcm/datasets/configs.py
ADDED
@@ -0,0 +1,774 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+#
+
+import logging
+import re
+from dataclasses import asdict, dataclass, fields
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar
+
+# XXX: these should be kept for eval of filters expressions
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.parquet as pq
+from fairseq2.assets import default_asset_store
+from omegaconf import MISSING
+
+logger = logging.getLogger(__name__)
+
+
+class ParquetBatchFormat(Enum):
+    pyarrow = 0
+    pandas = 1
+    torch = 2
+
+
+class ColumnsNames(Enum):
+    source_column = "_source_column"
+    source_text_column = "_source_text_column"
+    target_column = "_target_column"
+    target_text_column = "_target_text_column"
+
+    dataset_name = "_dataset_name"
+
+
+@dataclass
+class SonarTextColumn:
+    text_value: Optional[str] = None
+    """
+    Raw text expression that will be used as constant colum after being sententized and sonarized.
+    """
+    text_column: Optional[str] = None
+    sonar_column: Optional[str] = None
+    """
+    Note `text_column` and `sonar_column` should be aligned (so `sonar_column` should be sonar encoded `text_column`).
+    If `sonar_column` is None and `text_column` is provided, we set `sonar_column = f"{text_column}_sonar_emb"` as default processing value!
+    """
+
+
+@dataclass
+class ParquetDatasetLimitOptions:
+    fraction_of_files: Optional[float] = None
+    nb_files: Optional[int] = None
+    nb_fragments: Optional[int] = None
+    nb_rows: Optional[int] = None
+
+
+@dataclass(frozen=True)
+class SonarDecoderConfig:
+    tokenizer: str = "text_sonar_basic_decoder"
+    """ SONAR tokenizer """
+
+    decoder: str = "text_sonar_basic_decoder"
+    """ SONAR decoder"""
+
+    lang: str = "eng_Latn"
+    """ Target language """
+
+    max_tokens_in_sentence: int = 256
+    """Maximum number of tokens generated in the text"""
+
+    temperature: float = 1.0
+    """The decoding logit temperature, where values greater than 1.0 produce more
+    uniform logits; values less than 1.0 produce sharper logits."""
+
+
+@dataclass(frozen=True)
+class SonarEncoderConfig:
+    tokenizer: str = "text_sonar_basic_encoder"
+    """ SONAR tokenizer """
+
+    encoder: str = "text_sonar_basic_encoder"
+    """ SONAR decoder"""
+
+    lang: str = "eng_Latn"
+    """ Target language """
+
+
+@dataclass
+class DatasetConfig:
+    """
+    Generic dataset config
+    """
+
+    columns: Optional[List[str]] = None
+    """The list of columns to load.
+    Columns such as `source_column`, ..., will be added automatically.
+    """
+
+    source_text_column: Optional[str] = None
+    """ Column to load as source raw text"""
+
+    target_text_column: Optional[str] = None
+    """ Column to load as target raw text for paired data"""
+
+    source_prefix_text: Optional[str] = None
+    """ Text to prepend to the content of the source_column"""
+
+    source_suffix_text: Optional[str] = None
+    """ Text to append to the content of the target_column"""
+
+    target_prefix_text: Optional[str] = None
+    """ Text to prepend to the content of the source_column"""
+
+    target_suffix_text: Optional[str] = None
+    """ Text to append to the content of the target_column"""
+
+    source_sequences: Optional[List[SonarTextColumn]] = None
+    """
+    Designed to make on-the-fly prompts from existing columns that are more complex than prefix and suffix.
+    Each element of source_sequences is a SonarTextColumn, which can be either:
+    - constant raw text (with the text_value argument)
+    - text column (with the text_column argument)
+    - sonar column (with the sonar_column argument)
+
+    Note that text_value cannot co-exist with text_column or sonar_column, and sonar column cannot be specified
+    without a text column. Further behaviour for parquet datasets:
+    - If text_value is specified, this will be split to sentences and sonarized
+    - If only text_column is specified, a new column named "<text_column>_sonar_emb" will be added as sonar_column.
+    - If both (text_column, sonar_column) is specified,
+
+    All SonarTextColumn elements from source_sequences will be concatenated together to produce new source_column
+    and source_text_column (same for target), which will have names as defined in ColumnsNames.
+    Using source_sequences is NOT compatible with using source_column or source_text_column, as well as quality filtering.
+    """
+
+    target_sequences: Optional[List[SonarTextColumn]] = None
+    """Designed to make on-the-fly prompts / instructions for target column, see `source_sequences` for more details"""
+
+    silent_freeze: bool = False
+    """If set to true, the config value can only be set once, i.e. it will not be able to update after the being set is instantiated.
+    This is helpful to avoid side-effect in setting some configs after being specified by the user application (Hydra, CLI)"""
+
+    def __post_init__(self):
+        if self.source_sequences is not None:
+            if self.source_text_column is not None:
+                logger.warning(
+                    f"Both `source_sequence` and `source_text_column` is specified. "
+                    f"Ignore `source_text_column` and use default value `{ColumnsNames.source_text_column.value}`.\n"
+                    f"(`source_sequences` = {self.source_sequences}, \n"
+                    f"`source_text_column` = {self.source_text_column} )"
+                )
+            self.source_text_column = ColumnsNames.source_text_column.value
+
+        if self.target_sequences is not None:
+            if self.target_text_column is not None:
+                logger.warning(
+                    f"Both `target_sequences` and `target_text_column` is specified. "
+                    f"Ignore `target_text_column` and use default value `{ColumnsNames.target_text_column.value}`.\n"
+                    f"(`target_sequences` = {self.target_sequences}, \n"
+                    f"`target_text_column` = {self.target_text_column} )"
+                )
+            self.target_text_column = ColumnsNames.target_text_column.value
+
+        for col in (self.source_sequences or []) + (self.target_sequences or []):
+            if col.text_value is not None:
+                assert col.text_column is None and col.sonar_column is None
+            else:
+                assert col.text_column is not None
+
+        self._has_initialized_: bool = True
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if not getattr(self, "_has_initialized_", False):
+            return super().__setattr__(name, value)
+        if name == "silent_freeze":
+            raise ValueError(
+                "Direct change of silent_freeze outside __init__ is forbidden"
+            )
+        if self.silent_freeze and getattr(self, name) not in ("", None, MISSING):
+            logger.debug(
+                f"Ignore change of {name} since silent_freeze is set and value is not empty ({getattr(self, name)})"
+            )
+            return
+        super().__setattr__(name, value)
+
+    def override_attr(self, name: str, value: Any) -> None:
+        try:
+            self._has_initialized_ = False
+            super().__setattr__(name, value)
+        finally:
+            self._has_initialized_ = True
+
+    def freeze(self) -> None:
+        """Turn the `silent_freeze` flag on"""
+        try:
+            self._has_initialized_ = False
+            self.silent_freeze = True
+        finally:
+            self._has_initialized_ = True
+
+
+@dataclass
+class JSONDatasetConfig(DatasetConfig):
+    """Config for datasets stored in JsonL format."""
+
+    file_path: str = str()
+    """
+    Path to the directory containing the Jsonl dataset.
+    Each task will replace this wil a real Json files
+    TODO: Add support for remote JsonL file (e.g. with "s3://...")
+    """
+
+    prompt_template: Optional[str] = None
+    """
+    A jinja-format string to apply for each item in the dataset to transform into a string.
+    Useful for example when compiling a dynamic instruction / prompt for training or evaluation.
+    Note that when this is specified, it will take precedence over the "affix" option, i.e. the
+    columns `source_prefix_text`, `source_suffix_text`,... will be ignored.
+    """
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if not getattr(self, "_has_initialized_", False):
+            return super().__setattr__(name, value)
+
+        if name == "silent_freeze":
+            raise ValueError("Direct change of silent_freeze is forbidden")
+
+        if self.silent_freeze:
+            if getattr(self, name) not in ("", None, MISSING):
+                logger.debug(
+                    f"Ignore change of {name} in silent frozen mode when value is not empty ({getattr(self, name)})"
+                )
+                return
+
+        # Ensure we cannot set the default `prompt_template` value when the user specifies
+        # source_sequences or source_text_column explicitly
+        for hi_prior_col, lo_prior_col, lo_prior_value in [
+            ("source_sequences", "source_text_column", self.source_text_column),
+            ("target_sequences", "target_text_column", self.target_text_column),
+            ("prompt_template", "source_sequences", self.source_sequences),
+            ("prompt_template", "source_prefix_text", self.source_prefix_text),
+            ("prompt_template", "source_suffix_text", self.source_suffix_text),
+        ]:
+            if name == hi_prior_col and lo_prior_value not in ("", None, MISSING):
+                logger.warning(
+                    f"Updating value of {hi_prior_col} will cause conflicts with the user-defined "
+                    f"value in {lo_prior_col}. The update will be ignored.\n"
+                )
+                return
+
+        super().__setattr__(name, value)
+
+
+@dataclass
+class ParquetDatasetConfig(DatasetConfig):
+    """
+    Config for datasets stored in Parquet format.
+
+    XXX: this config should not hold non-trival default values.
+    We want this to make datacards info and hydra config merge easier.
+    All None value should be filled up in downstream `build_parquet_iterator_pipeline`.
+    """
+
+    name: Optional[str] = None
+    """When name is provided, it will use preregistered cards to populate all attributes.
+    name convention is the following
+    - {card_name}={split}:{weight}
+
+    Example:
+    - wiki
+    - wiki:0.2 # no split
+    - wiki=dev # default weight=1
+    - wiki=dev:0.2
+
+    Cards attributes will be overwritten by user defined ParquetDatasetConfig in
+    `create_dataset_config_from_cards`.
+    """
+
+    parquet_path: str = str()
+    """The path to parquet dataset file.
+    if `parquet_path` is remote (like stats with "s3://..."),
+    the filesystem will be automatically detected and `filesystem_expr` should remain None
+    """
+
+    weight: float = 1.0
+    """
+    Indicates relative weight of dataset that can be used for sampling from different datasets.
+    """
+
+    limit: Optional[ParquetDatasetLimitOptions] = None
+    """
+    Contains different options that allows to load only a part of the provided dataset.
+    It will **always** take some number of **first** fragments according to the order in which
+    they appear in the dataset and this logic will not be depedent on suffling/seed.
+    When several limits are provided, each of them will be applied (resulting in the strongest limit).
+    """
+
+    source_column: Optional[str] = None
+    """ Column to load as source embeddings"""
+
+    target_column: Optional[str] = None
+    """ Column to load as target embeddings for paired data"""
+
+    source_quality_column: Optional[str] = None
+    source_quality_range: Optional[Any] = None
+
+    partition_filters: Optional[str] = None
+    """
+    Filters that should be applied only on partition columns for fast partition prunning.
+    This filters should not be duplicated in `filters` (below) which are used on materialized data.
+    To know the partition columns on dataset :
+    ```python
+    >>> pq.ParquetDataset(parquet_path).partitioning.schema.names
+    ```
+    Note that for if `parquet_path` references a single file -> the result above will NOT be correct (returns all columns).
+    Note that for a single file case, there should no partition_filters since there're no partitions !!
+    """
+
+    filters: Optional[str] = None
+    """See https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression
+
+    Some examples :
+
+    >>> import pyarrow.compute as pc
+    >>> import pyarrow as pa
+
+    >>> filters = (pc.field("data_split") == pc.scalar("train")) & (pc.field("duration") > 7)
+    >>> filters = pa.compute.greater(pa.compute.utf8_length(ds.field("lang1_text")), 4)
+    >>> filters = pa.compute.less_equal(pa.compute.list_value_length(pa.dataset.field("audio_wav")), 16_000 * 30)
+
+    Note that all fields used here should be among existing columns in the dataset schema.
+    For hydra compatibility, we need to pass this filters as an str expression that'll be passed to `eval(...)`
+    """
+
+    filesystem_expr: Optional[str] = None
+    """
+    DEPRECATED : not used any more and will be remove soon
+    """
+
+    filesystem: Optional[Any] = None
+    """
+    DEPRECATED: not used any more and will be remove soon
+    """
+
+    split_to_row_groups: Optional[bool] = None
+    """If ``True``, uses Parquet row groups instead of simple partitions which
+    are generally smaller. Highly recommended for non-partitioned parquet files."""
+
+    nb_parallel_fragments: Optional[int] = None
+    """
+    This parameter can be dataset specific:
+    For dataset with large number of sentences per document (sample),
+    it's enough to set `nb_parallel_fragments=2 or 3`.
+    For datasets, with smaller number of sentences (~10) and small row_group_size (~200-600),
+    `nb_parallel_fragments` could be increase to 10 - 20.
+
+    The number of Parquet fragments allowed to be read in parallel. Higher
+    values will result in higher speeds, better randomization, and higher memory
+    footprint. If partition size is rather small compared to the batch size, we
+    recommend to increase ``nb_parallel_fragments``.
+
+    Leaving ``nb_parallel_fragments`` to None will trigger auto-detection based on dataset metadata.
+    """
+
+    sharding_in_memory: bool = False
+    """
+    This option should be activated for sharding small datasets whose total number of row groups is small
+    that makes sharding per row group impossible.
+    """
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        if self.source_sequences is not None:
+            if self.source_column is not None:
+                logger.warning(
+                    f"Both `source_sequences` and `source_column` is specified. "
+                    f"Ignore `source_column` and use default value `{ColumnsNames.source_column.value}`.\n"
+                    f"(`source_sequences` = {self.source_sequences}, \n"
+                    f"`source_column` = {self.source_column} )"
+                )
+            assert self.source_quality_range is None
+            self.source_column = ColumnsNames.source_column.value
+
+        if self.target_sequences is not None:
+            if self.target_column is not None:
+                logger.warning(
+                    f"Both `target_sequences` and `target_column` is specified. "
+                    f"Ignore `target_column` and use default value `{ColumnsNames.target_column.value}`.\n"
+                    f"(`target_sequences` = {self.target_sequences}, \n"
+                    f"`target_column` = {self.target_column} )"
+                )
+            self.target_column = ColumnsNames.target_column.value
+
+        for col in (self.source_sequences or []) + (self.target_sequences or []):
+            if col.sonar_column is None and col.text_value is None:
+                assert col.text_column, f"Invalid SonarTextColumn: {col}"
+                col.sonar_column = col.text_column + "_sonar_emb"
+
+        if self.source_quality_range is None:
+            self.source_quality_column = None
+
+
+DatasetConfigT = TypeVar("DatasetConfigT", bound=DatasetConfig, contravariant=True)
+
+
+@dataclass
+class DataLoadingConfig:
+    multiple_dataset_chaining: str = "sample"
+    """
+    This option allows to chain several datasets together.
+    The chaining can be done in two ways:
+    - `sample` : each dataset will be sampled with the provided weight
+    - `concat` : datasets will be concatenated together (no weights taken into account)
+    - `round_robin`: datasets will be sampled in a round robin fashion (no weights taken into account)
+    """
+    batch_size: Optional[int] = None
+    """The output batch size."""
+
+    order_by_length: bool = True
+    """
+    Whether to create the batches with homogeneous tokens length
+    for more efficient padding.
+    """
+
+    max_tokens: Optional[int] = None
+    """Used with the ``order_by_length`` option to control the total number of
+    padded tokens in each batch. Typically, this option is preferred over
+    ``batch_size`` to reduce the memory footprint.
+    """
+
+    len_to_wrap_long_seq: Optional[int] = None
+    """
+    Wrapping a source sequences to the length of `len_to_wrap_long_seq`.
+    For instance, for a `len_to_wrap_long_seq=2`
+    batch = {
+        "source": [["v1", "v2", "v3", "v4", "v5"], ["u1", "u2", "u3"], ["w1"]],
+    }
+    will be transormed to
+    1. if packing is False :
+    batch = {
+        "source": [['v1', 'v2'], ['v3', 'v4'], ['v5'], ["u1", "u2"], ["u3"], ["w1"]]
+    }
+    1. if packing is True :
+    batch = {
+        "source": [['v1', 'v2'], ['v3', 'v4'], ['v5', 'u1'], ["u2", "u3"], ["w1"]]
+    }
+
+    Note: currently only allowed to be used with no "target" provided (unsupervised style) !
+    """
+
+    packing: bool = False
+    """
+    If True, all sequential documents (seqs of sentences) will be concated into one big document
+    before applying wrapping.
+    This will result in all samples (except maybe one) having exactly `len_to_wrap_long_seq` length !
+    """
+
+    wrap_before_affixing: bool = False
+    """
+    If True, we will wrap the sequences before adding the source prefix/suffix.
+    Recommended when pre-training with packed data i.e len_to_wrap_long_seq not None and packing=True
+    """
+
+    max_sentence_len_in_doc: Optional[int] = None
+    """
+    Remove samples (documents) whose `source_text_column` contains at least one sentence of len > `max_sentence_len_in_doc`.
+    This operations is done after long sequences wrapping (if applicable).
+    Typically values: 100 - 300
+    """
+    min_sentence_len_in_doc: Optional[int] = None
+    """
+    Remove samples (documents) `source_text_column` contains at least one sentence of len < `min_sentence_len_in_doc`.
+    This operations is done after long sequences wrapping (if applicable).
+    Typically values: 5 - 15
+    """
+
+    max_sentence_len_in_target_doc: Optional[int] = None
+    """
+    same filtering option as above but for `target_text_column`
+    """
+    min_sentence_len_in_target_doc: Optional[int] = None
+    """
+    same filtering option as above but for `target_text_column`
+    """
+
+    min_length_of_sequences: Optional[int] = 1
+    """
+    Remove samples (documents) whose `source_text_column` are scrictly shorter than `min_length_of_sequences`.
+    This operations is done after long sequences wrapping (if applicable).
+    One can use here the same value as for sequences wrapping
+    in order to produce all sequences with the same length.
+    """
+    min_length_of_sequences_after_batching: Optional[int] = 1
+    """
+    Remove source sequences shorter than `min_length_of_sequences_after_batching`
+    This filtering is applied after batching and potentially affixing and wrapping.
+    """
+    min_length_of_target_sequences: Optional[int] = 1
+    """
+    Same as above applied for `target_text_column`
+    """
+    min_length_of_target_sequences_after_batching: Optional[int] = 1
+    """
+    Same as above applied for `target_text_column`
+    """
+
+    output_format: ParquetBatchFormat = ParquetBatchFormat.torch
+    """The format to use for output batches."""
+
+    shuffle: bool = True
+    """If ``True``, shuffles the dataset samples during the iteration. If ``False``
+    and ``order_by_length`` is ``None``, the batch samples will be produced in
+    natural Parquet dataset reading order."""
+
+    drop_null: bool = True
+    """If ``True``, drops rows containing any null value."""
+
+    seed: int = 123
+    """The RNG seed value for deterministic behavior."""
+
+    nb_epochs: int = 100
+    """
+    Number of passes over the data before iterations stop
+    """
+
+    min_batch_size: int = 1
+    """Drops batches whose length is less than ``min_batch_size``"""
+
+    nb_prefetch: float = 3.0
+    """The number of producer groups (of size `nb_parallel_fragments`) to
+    prefetch."""
+
+    num_parallel_calls: float = 1.5
+    """The number of parallel calls in map operations."""
+
+    use_threads: bool = False
+    """Whether pyarrow should use its internal threads to read the Parquet file.
+    Since we rely on the external parallelism, this param is tuned off by
+    default."""
+
+    ignore_checkpointed_pipeline: bool = False
+    """Whether to ignore the saved datapipeline state or load it when resuming.
+    Temporary fix for issues re-loading saved checkpoints"""
+
+    even_sharding: bool = False
+    """
+    This option should be activated ONLY for validataion on small datasets
+    to guarantee the perfect data sharding accross the workers.
+    Note that in current impmentation, activating `even_sharding` requires `sharding_in_memory=True`
+    which will lead to big overhead for big dataset.
+    Note also that some fraction of the data may be dropped due to even sharding.
+    For big validation datasets, prefer using large `nb_epoch` + limiting `max_validation_iterations`
+    instead of using `even_sharding` !
+
+    For training use case, it should left to False and combined with large number of epochs.
+    For evaluation use case, it also should be False since we dont care about the batch syncronization across different workers.
+    """
+    max_iteration_steps: Optional[int] = None
+    """
+    If not None, it will be used to limit the number of batches produced per each dataset
+    """
+
+
+@dataclass
+class ValidationDataLoadingConfig(DataLoadingConfig):
+    """
+    This class allows to have some hardcoded parameters for data loading of validation datasets
+    """
+
+    multiple_dataset_chaining: str = "concat"
+    nb_epochs: int = 1
+    min_batch_size: int = 1  # we want to keep all samples
+    shuffle: bool = False  # we dont need the randomness here
+    batch_size: Optional[int] = None
+    max_tokens: Optional[int] = None
+    """
+    Leaving both `max_tokens` and `batch_size` to None will trigger auto-detection based on dataset metadata and distributed training world size.
+    to make more or less even distribution of samples across workers. Typically,
+    if worker_batch_size = total_batch_size // world_size <= 40, we will use batch_size=worker_batch_size,
+    otherwise we will use max_tokens=min(total_tokens_number // world_size, 3000).
+    See dataloading:SingleParquetDatasetDataloader::set_validation_params for more details.
+    """
+
+
+@dataclass
+class EvaluationDataLoadingConfig(DataLoadingConfig):
+    """
+    This class allows to have some hardcoded parameters for data loading of evaluation datasets.
+    In partitcular, even in distributed setup evaluation should not require workers syncronization.
+    Therefore, we set `even_sharding` = False to get the all data samples !
+    """
+
+    multiple_dataset_chaining: str = "concat"
+    nb_epochs: int = 1  # only ONE full pass over the full data !
+    min_batch_size: int = 1  # we want to keep all samples
+    shuffle: bool = False  # we dont need the randomness here
+    batch_size: Optional[int] = 10
+    max_tokens: Optional[int] = None  # this should be ok for most of models
+    even_sharding: bool = False  # we dont want to lose any sample !
+    sharding_in_memory: bool = True  # activate sharding by rank and world size
+    rank: int = 0
+    world_size: int = 1
+    max_samples: Optional[int] = None  # fmt: skip
+    """evaluate only the first n samples (for debugging)"""
+
+
+def setup_fairseq2_extensions() -> None:
+    # path where all datacards should be located !
+    cards_dir = Path(__file__).parent.parent.joinpath("datacards")
+    if cards_dir.exists():
+        default_asset_store.add_file_metadata_provider(cards_dir)
+
+
+setup_fairseq2_extensions()
+
+
+def get_cluster() -> Optional[str]:
+    """Returns the cluster name of the current environment.
+    User can implement their own logic to load datasets living in different locations/clusters
+    """
+    return "s3"
+
+
+def _resolve_parquet_path(options: Dict[str, str]) -> Optional[str]:
+    cluster_name = get_cluster() or "s3"
+
+    parquet_path = options.get(cluster_name)
+    if parquet_path is None:
+        # best effort - taking first element
+        parquet_path = next(iter(options.values()))
+
+    return parquet_path
+
+
+def _resolve_filters(
+    split: Optional[str],
+    card_filter: Optional[str],
+    user_filter: Optional[str],
+    card_partition_filters: Optional[str],
+    user_partition_filters: Optional[str],
+) -> Tuple[Optional[pc.Expression], Optional[pc.Expression]]:
+    custom_filters = user_filter or card_filter
+    partition_filters = user_partition_filters or card_partition_filters
+
+    if custom_filters is not None:
+        custom_filters = pq.filters_to_expression(eval(custom_filters))
+
+    if partition_filters is not None:
+        partition_filters = pq.filters_to_expression(eval(partition_filters))
+
+    if split:
+        split_filter = pc.equal(pc.field("split"), split)
+        if partition_filters is None:
+            partition_filters = split_filter
+        else:
+            partition_filters = pa.compute.if_else(
+                split_filter, partition_filters, False
+            )
+
+    return custom_filters, partition_filters
+
+
+def _default_resolver(a, b):
+    res = a if bool(a) and a is not MISSING else b
+    return res
+
+
+def get_parquet_config_from_name(
+    name: str, config: Optional[ParquetDatasetConfig] = None
+) -> ParquetDatasetConfig:
+    """
+    name convention is the following
+    - {card_name}={split}:{weight}
+    """
+    # parsing name
+    pattern = r"^(?P<card_name>[a-zA-Z0-9_]+)=?(?P<split>[a-zA-Z0-9_]*)?:?(?P<weight>\d+(?:\.\d+)?)?$"
+    match_ = re.match(pattern, name)
+    assert match_ is not None, f"name parsing failed: {name}"
+    card_name = match_.group("card_name")
+    split = match_.group("split")
+    weight = match_.group("weight")
+
+    if weight:
+        weight = float(weight)
+    logger.info(
+        f"Parsing {name} : card_name={card_name}, split={split}, weight={weight}"
+    )
+
+    reload_config = default_asset_store.retrieve_card(card_name)
+    cards_metadata: Dict[str, Any] = {**reload_config._metadata}
+
+    if config is None:
+        config = ParquetDatasetConfig(name=card_name, parquet_path="")
+
+    assert config is not None
+
+    if isinstance(config, ParquetDatasetConfig):
+        config_dict = asdict(config)
+    else:
+        config_dict = config  # type: ignore
+
+    metadata = {}
+    # resolve parquet_path according to the cluster
+    for field in fields(ParquetDatasetConfig):
+        field_name = field.name
+        metadata[field_name] = _default_resolver(
+            config_dict.get(field_name), cards_metadata.get(field_name)
+        )
+
+    if isinstance(metadata["source_sequences"], list):
+        metadata["source_sequences"] = [
+            SonarTextColumn(**item) for item in metadata["source_sequences"]
+        ]
+
+    if isinstance(metadata["target_sequences"], list):
+        metadata["target_sequences"] = [
+            SonarTextColumn(**item) for item in metadata["target_sequences"]
+        ]
+
+    metadata["parquet_path"] = _default_resolver(
+        config_dict.get("parquet_path"),
+        _resolve_parquet_path(cards_metadata["parquet_path"]),
+    )
+
+    metadata["filters"], metadata["partition_filters"] = _resolve_filters(
+        split,
+        card_filter=cards_metadata.get("filters"),
+        user_filter=config_dict.get("filters"),
+        card_partition_filters=cards_metadata.get("partition_filters"),
+        user_partition_filters=config_dict.get("partition_filters"),
+    )
+    if weight:  # priority from parsed name
+        metadata["weight"] = weight
+    metadata["name"] = name
+
+    # to patch nested hydra case !
+    if metadata["limit"] is not None and isinstance(metadata["limit"], dict):
+        metadata["limit"] = ParquetDatasetLimitOptions(**metadata["limit"])
+
+    return ParquetDatasetConfig(**metadata)
+
+
+def create_dataset_config_from_cards(
+    config: DatasetConfig,
+) -> DatasetConfig:
+    if getattr(config, "name", None) is None:
+        return config
+    output_config = get_parquet_config_from_name(config.name, config)  # type: ignore
+    return output_config
+
+
+def get_renaming_mappers(configs: Sequence[DatasetConfig]) -> List[dict]:
+    used_columns = [x for x in ColumnsNames.__members__ if x != "dataset_name"]
+
+    pre_mapping = {
+        att: [getattr(cc, att) for cc in configs if hasattr(cc, att)]
+        for att in used_columns
+    }
+
+    mappers: List[dict] = [{} for _ in configs]
+    for att, val in pre_mapping.items():
+        if all(x is None for x in val):
+            continue
+        for i, name in enumerate(val):
+            if name is None:
+                raise ValueError(
+                    f"All datasets should provide {att} param, but got {configs[i]}"
+                )
+            mappers[i][name] = getattr(ColumnsNames, att).value
    return mappers
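To illustrate the `{card_name}={split}:{weight}` convention parsed by `get_parquet_config_from_name` above, a hypothetical call (it assumes a datacard named "wiki" has been registered with the asset store):

# Hypothetical example; assumes a registered datacard named "wiki".
from lcm.datasets.configs import get_parquet_config_from_name

cfg = get_parquet_config_from_name("wiki=dev:0.2")
# card_name="wiki"; split="dev" becomes a partition filter on the "split"
# column via _resolve_filters; weight=0.2 is used when sampling across
# datasets with multiple_dataset_chaining="sample".
print(cfg.weight)  # 0.2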
lcm/datasets/dataloader.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

import gc
import logging
from copy import deepcopy
from functools import partial
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple

import pyarrow.compute as pc
import torch
from fairseq2.data.data_pipeline import DataPipeline, read_sequence
from fairseq2.data.text import TextTokenizer
from fairseq2.gang import FakeGang, Gang
from fairseq2.models.sequence import SequenceBatch
from fairseq2.nn.padding import pad_seqs
from fairseq2.typing import DataType
from fairseq2.utils.state import Stateful
from sonar.models.sonar_text import load_sonar_tokenizer

from lcm.datasets.base import DataLoader
from lcm.datasets.batch import LCMInput
from lcm.datasets.configs import (
    ColumnsNames,
    DataLoadingConfig,
    ParquetDatasetConfig,
    ParquetDatasetLimitOptions,
    SonarDecoderConfig,
)
from lcm.datasets.utils import move_eos_to_the_end
from lcm.utils.common import set_mkl_num_threads

logger = logging.getLogger(__name__)


def truncate_sequence(tokens: torch.Tensor, max_len: int = 512) -> torch.Tensor:
    if len(tokens) > max_len:
        return tokens[:max_len]
    return tokens


class LCMDataLoader(DataLoader[LCMInput, ParquetDatasetConfig], Stateful):
    def __init__(
        self,
        data_config: DataLoadingConfig,
        datasets: Sequence[ParquetDatasetConfig],
        dtype: DataType = torch.float16,
        use_decoder_backprop: bool = False,
        max_subword_length: int = 64,
        gang: Optional[Gang] = None,
        sonar_decoder_config: Optional[SonarDecoderConfig] = None,
    ) -> None:
        gang = gang or FakeGang()

        super().__init__(
            data_config=data_config,
            datasets=datasets,
            dtype=dtype,
            gang=gang,
        )
        set_mkl_num_threads()

        self.use_decoder_backprop = use_decoder_backprop
        self.sonar_tokenizer: Optional[TextTokenizer] = None
        self.max_subword_length = max_subword_length
        if sonar_decoder_config is not None:
            self.setup_sonar_decoder_tokenizer(config=sonar_decoder_config)
        self._dummy_example: Optional[LCMInput] = None

    def setup_sonar_decoder_tokenizer(
        self,
        config: SonarDecoderConfig,
    ):
        if self.use_decoder_backprop:
            # The tokenizer
            self.tokenizer = load_sonar_tokenizer(config.tokenizer, progress=False)
            # Target text encoder
            self.sonar_tokenizer = self.tokenizer.create_encoder(
                task="translation",
                lang=config.lang,
                mode="target",
                device=self.gang.device,
            )
        else:
            self.sonar_tokenizer = None

    def _prepare_subword_tokens(
        self, batch: Dict[str, Any]
    ) -> Tuple[Optional[SequenceBatch], Optional[SequenceBatch]]:
        """
        Given a batch of paragraphs/documents,
        prepare a batch of (flattened) sentences tokenized at the subword level
        to feed to the SONAR decoder (a standard token-level decoder).

        Args:
            batch: attributes of a batch from the dataset.
                A batch is M documents/paragraphs, each spanning
                a variable number of sentences {N_1, ..., N_M}.

                E.g., {'text_sentences': [[sent^1_1, ... sent^1_{N_1}],
                                          ... [sent^M_1, ... sent^M_{N_M}]],
                       'text_sentences_sonar_emb': [X^1 in (N_1, D), ... X^M in (N_M, D)]}
                where D is the SONAR embedding dimension.
        Returns:
            Tokenized sentences (subword-level) in (\sum_{i=1}^M N_i, max_len),
            where max_len is min(self.max_subword_length, max length of the sentences in the batch)

        """

        if not self.use_decoder_backprop:
            return None, None

        # flatten the sentences from different documents/paragraphs
        flattened_source_text = (
            pc.list_flatten(batch[ColumnsNames.source_text_column.value])
            .to_pandas()
            .values
        )

        pipeline: DataPipeline = (
            read_sequence(flattened_source_text)
            .map(
                [
                    self.sonar_tokenizer,  # type: ignore
                    partial(truncate_sequence, max_len=self.max_subword_length),
                ],
                num_parallel_calls=int(max(8 * self.data_config.num_parallel_calls, 1)),
            )
            .and_return(max_num_warnings=4)
        )

        tokens_seqs, tokens_padding_mask = pad_seqs(list(pipeline))  # type: ignore
        prefix_batch = SequenceBatch(tokens_seqs, tokens_padding_mask)
        # TODO: instead of moving the EOS around, make the tokenizer append it at tokenization time.
        target_batch = move_eos_to_the_end(
            prefix_batch,
            eos_token_id=self.tokenizer.vocab_info.eos_idx,
            pad_token_id=self.tokenizer.vocab_info.pad_idx,
        )

        return prefix_batch, target_batch

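    # Example: for a batch of M=2 documents with N_1=2 and N_2=3 sentences,
    # `_prepare_subword_tokens` returns SequenceBatches whose `seqs` have shape
    # (5, L), with L = min(self.max_subword_length, longest tokenized sentence).
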
    def _tokenize_batch(self, batch: Dict[str, Any]) -> LCMInput:
        """
        Given a batch of documents,
        prepare a batch of input features for the LCM.
        This step simply fetches the right columns for source/target & source text
        and converts torch NestedTensors to lists of tensors.

        Args:
            batch: attributes of a batch from the dataset.
                A batch is M documents, each spanning
                a variable number of sentences {N_1, ..., N_M}.

                E.g., {'text_sentences': [[sent^1_1, ... sent^1_{N_1}],
                                          ... [sent^M_1, ... sent^M_{N_M}]],
                       'text_sentences_sonar_emb': [X^1 in (N_1, D), ... X^M in (N_M, D)]}
                where D is the SONAR embedding dimension.
        Returns:
            LCMInput(
                source: SONAR embeddings of the source text,
                    i.e., [X^1 in (N_1, D), ... X^M in (N_M, D)]
                target: if supervised data: SONAR embeddings of the target text
                tokens: tokenized flattened sentences for the SONAR decoder (see `_prepare_subword_tokens`)
            )

        """

        # Prepare sentence-wise subword tokens if needed:
        tokens, target_tokens = self._prepare_subword_tokens(batch)

        # Load target embeddings if requested and propagate all other embeddings

        possible_emb_columns = {
            "source": ColumnsNames.source_column,
            "target": ColumnsNames.target_column,
        }

        outputs = {
            "tokens": tokens,
            "target_tokens": target_tokens,
            "name": batch[ColumnsNames.dataset_name.value],
            "batch": batch,
        }
        for key, col in possible_emb_columns.items():
            col_name = col.value
            if col_name in batch:
                dtype = self.dtype if "_length" not in key else torch.int64
                embs = [x.to(self.gang.device).to(dtype) for x in batch[col_name]]
                # Special case when some embeddings are not shaped as (T, D), e.g., XLMC's answer columns
                if embs[0].dim() == 1 and "_length" not in key:
                    embs = [t.unsqueeze(0) for t in embs]
            else:
                embs = None
            outputs[key] = embs
        assert outputs["source"] is not None, (
            "LCMDataLoader requires `source` sequences to be present in batches"
        )
        return LCMInput(**outputs)

    def iterate_batches(self) -> Iterator[LCMInput]:
        yield from map(self._tokenize_batch, self.pipeline)

    def iterate_dummy_batches(self) -> Iterator[LCMInput]:
        """
        Needed to simulate data that follows the structure of self.pipeline (by always returning the same element).
        It should only be used for a fast forward pass (to avoid uneven sharding in multi-GPU training).
        """
        if self._dummy_example is None:
            # patch the params to load less data at a lower cost
            limited_datasets = deepcopy(self.datasets)
            for ds_conf in limited_datasets:
                assert isinstance(ds_conf, ParquetDatasetConfig)
                ds_conf.limit = ParquetDatasetLimitOptions(nb_fragments=1)

            # Copy the true data config and reduce the batch size.
            # When wrapping data, we want to also wrap the dummy batches
            # to not exceed the model's max_length
            dummy_dataloading_config = deepcopy(self.data_config)
            dummy_dataloading_config.batch_size = 1

            self._dummy_example = self._tokenize_batch(
                next(
                    iter(
                        self.builder_func(
                            limited_datasets, dummy_dataloading_config, 0, 1
                        )
                    )
                )
            )
            gc.collect()

        while True:
            yield self._dummy_example

    def state_dict(self) -> Dict[str, Any]:
        logger.info("Getting the data pipeline state ...")
        state = self.pipeline.state_dict(strict=False)
        return state

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        if state_dict is not None:
            assert self.pipeline is not None
            if self.data_config.ignore_checkpointed_pipeline:
                logger.warning("Ignoring existing dataloader state")
            else:
                try:
                    self.pipeline.load_state_dict(state_dict)
                    logger.info(f"Reloaded datapipeline state: {str(state_dict)[:400]}")
                except ValueError:
                    logger.warning(
                        f"Failed to load dataloader state: {str(state_dict)[:400]}"
                    )
        else:
            # retro-compatibility
            logger.warning(f"Attempt to restore a dataloader {self} with empty state")
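
As a quick orientation before the larger dataloading module below, here is a minimal sketch of how `LCMDataLoader` is meant to be driven. The config field values are illustrative assumptions, not values taken from this commit:

# Minimal usage sketch (illustrative; the config fields passed here are assumptions):
from lcm.datasets.configs import DataLoadingConfig, ParquetDatasetConfig

loader = LCMDataLoader(
    data_config=DataLoadingConfig(batch_size=4),           # assumed constructor field
    datasets=[ParquetDatasetConfig(parquet_path="...")],   # hypothetical dataset path
)
for lcm_input in loader.iterate_batches():
    # each LCMInput carries `source` as a list of (N_i, D) SONAR embedding tensors
    break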
lcm/datasets/dataloading.py
ADDED
@@ -0,0 +1,1109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

import logging
from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import lru_cache, partial
from typing import Any, Generator, List, Optional, Sequence

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
from fairseq2.data.data_pipeline import DataPipeline, DataPipelineBuilder
from fairseq2.data.parquet.tools import BatchOutputType, apply_filter, concat_table
from pyarrow.dataset import get_partition_keys
from stopes.utils.arrow_utils import (
    explode_table_with_fixed_length,
    explode_table_with_max_length,
    is_list_like,
)

from lcm.datasets.configs import (
    DataLoadingConfig,
    ParquetBatchFormat,
    ParquetDatasetConfig,
    ValidationDataLoadingConfig,
    get_renaming_mappers,
)
from lcm.datasets.parquet_utils import (
    build_batching_loop_over_one_table,
    define_parquet_dataset,
    filter_document_by_quality,
    filter_long_short_sentence_document,
    filter_table_with_different_lengths,
    get_row_group_level_metadata,
    materialize_sequence,
    prefix_and_suffix_one_list_column,
    prepare_suffix_prefix_embeddings,
    pyarrow_table_to_torch_dict,
    renaming,
    shuffle_table,
    stream_parquet_fragments,
)

logger = logging.getLogger(__name__)

PA_NB_CPU = 4
pa.set_cpu_count(PA_NB_CPU)
pa.set_io_thread_count(PA_NB_CPU)


def return_none_on_failure(func):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"An error occurred: {e}")
            return None

    return wrapper


@dataclass
class GlobalPQStats:
    min_number_of_fragment: int
    mean_fragment_length: float
    mean_fragment_number_of_tokens: Optional[float] = None

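# Example: `return_none_on_failure(explode_table_with_max_length)(table, ...)`
# behaves like the wrapped call but returns None instead of raising, so a
# downstream `.filter(lambda table: table is not None)` can drop failed tables
# (this is exactly how it is used in `add_wrapping_to_max_length_pipeline` below).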
class SingleParquetDatasetDataloader:
    _pq_ds: Optional[pq.ParquetDataset] = None
    proxy_number_of_fragments: int
    basic_stats: GlobalPQStats

    def __init__(
        self, dataset_config: ParquetDatasetConfig, loading_config: DataLoadingConfig
    ):
        self.dataset_config = deepcopy(dataset_config)
        self.loading_config = deepcopy(loading_config)
        self.config_post_init()
        nb_parallel_fragments = self.dataset_config.nb_parallel_fragments
        assert isinstance(nb_parallel_fragments, int)
        self.nb_parallel_fragments: int = nb_parallel_fragments

    @property
    def is_validation(self) -> bool:
        return isinstance(self.loading_config, ValidationDataLoadingConfig)

    def head(self, top=5):
        return self.dataset._dataset.head(top)

    @property
    def dataset(self) -> pq.ParquetDataset:
        if self._pq_ds is None:
            self._pq_ds = self.get_dataset()

        return self._pq_ds

    @property
    def full_schema(self) -> pa.Schema:
        return self.dataset.schema

    def _warn_filters_usage(self, pq_ds: pq.ParquetDataset) -> None:
        partition_filters = self.dataset_config.partition_filters

        frags = pq_ds.fragments
        if len(frags) == 0:
            raise ValueError(
                f"Working on an empty dataset, probably due to a wrong `partition_filters` definition: {partition_filters}"
            )

        partition_columns = list(
            get_partition_keys(frags[0].partition_expression).keys()
        )
        if not partition_columns and partition_filters is not None:
            raise ValueError(
                f"Partition filters {partition_filters} are set but the dataset has NO partition columns"
            )

        if partition_columns and partition_filters is not None:
            expression_candidates = [
                x for x in partition_columns if x in str(partition_filters)
            ]
            if len(expression_candidates) == 0:
                logger.warning(
                    f"Partition filters are NOT compatible with partition columns, got: "
                    f"partition_filters={partition_filters} and partition_columns={partition_columns}"
                )
        filters = self.dataset_config.filters
        if partition_columns and filters is not None:
            expression_candidates = [x for x in partition_columns if x in str(filters)]
            if len(expression_candidates) > 0:
                logger.warning(
                    f"Partitioning columns {expression_candidates} are used as `filters` {filters}. "
                    "You may want to use them in `partition_filters` instead",
                )

    def get_dataset(self) -> pq.ParquetDataset:
        if isinstance(self.dataset_config.filters, str):
            self.dataset_config.filters = pq.filters_to_expression(
                eval(self.dataset_config.filters)
            )

        if isinstance(self.dataset_config.partition_filters, str):
            self.dataset_config.partition_filters = pq.filters_to_expression(
                eval(self.dataset_config.partition_filters)
            )

        pq_ds = define_parquet_dataset(
            str(self.dataset_config.parquet_path), self.dataset_config.partition_filters
        )

        try:
            self._warn_filters_usage(pq_ds)
        except Exception as e:
            logger.info(f"got an exception during filters examination: {e}")

        return pq_ds

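    # Example: `filters` may be given as a string of DNF-style triples, e.g.
    # "[('split', '=', 'train')]"; it is eval'ed and converted by
    # `pq.filters_to_expression` into a pyarrow dataset expression.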
    def set_validation_params(
        self,
        world_size: int,
        default_max_tokens: int = 3000,
        default_batch_size: int = 40,
    ) -> None:
        if not (
            self.loading_config.batch_size is None
            and self.loading_config.max_tokens is None
        ):
            return

        total_batch_size = int(
            self.basic_stats.min_number_of_fragment
            * self.basic_stats.mean_fragment_length
        )
        batch_size = total_batch_size // world_size + int(
            total_batch_size % world_size != 0
        )

        # for small datasets we can set `batch_size`
        if (
            batch_size <= default_batch_size
            or self.basic_stats.mean_fragment_number_of_tokens is None
        ):
            self.loading_config.batch_size = min(batch_size, default_batch_size)
            self.loading_config.max_tokens = None
        else:
            # for bigger datasets, let's use `max_tokens`
            self.loading_config.batch_size = None
            total_tokens_number = int(
                self.basic_stats.min_number_of_fragment
                * self.basic_stats.mean_fragment_number_of_tokens
            )
            self.loading_config.max_tokens = min(
                max(total_tokens_number // world_size, 1), default_max_tokens
            )

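    # Example for `set_validation_params`: with min_number_of_fragment=10 fragments
    # of mean_fragment_length=100 rows and world_size=8, total_batch_size=1000 and
    # batch_size=125 > 40, so the `max_tokens` branch is taken (capped at
    # default_max_tokens=3000).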
    def build_dataload_pipeline(
        self, rank: int = 0, world_size: int = 1
    ) -> DataPipelineBuilder:
        if world_size > 1:
            assert self.loading_config.seed is not None, (
                "for distributed training with `world_size` > 1, `seed` must be set!"
            )
        if self.is_validation:
            self.set_validation_params(world_size)

        # decide whether to shard in memory; propagate `sharding_in_memory` when set in the dataset config
        if not self.dataset_config.sharding_in_memory:
            sharding_in_memory = (
                self.loading_config.nb_epochs * self.proxy_number_of_fragments
                < 2 * world_size
            )
        else:
            sharding_in_memory = self.dataset_config.sharding_in_memory
        if self.loading_config.even_sharding:
            sharding_in_memory = True

        if sharding_in_memory:
            logger.info("Activating sharding_in_memory")

        self.random_state = np.random.RandomState(
            self._get_inner_seed(rank, sharding_in_memory)
        )
        pipeline = self.get_fragments_pipeline()

        if not sharding_in_memory:
            pipeline = pipeline.shard(
                shard_idx=rank,
                num_shards=world_size,
                allow_uneven=not self.loading_config.even_sharding,
            )

        pipeline = self.add_basic_fragment_loading_pipeline(pipeline)

        pipeline = self.create_on_the_fly_columns(pipeline)
        pipeline = self.filter_by_aligned_length(pipeline)

        # If we want to wrap before adding affixes
        if self.loading_config.wrap_before_affixing:
            pipeline = self.add_wrapping_to_max_length_pipeline(pipeline)

        # Filtering
        pipeline = self.add_quality_score_filters(pipeline)
        pipeline = self.add_min_sentence_number_in_doc_filter(
            pipeline,
            min_source_length=self.loading_config.min_length_of_sequences,
            min_target_length=self.loading_config.min_length_of_target_sequences,
        )
        pipeline = self.add_min_max_sentence_len_in_doc_filter(pipeline)

        # Affix
        pipeline = self._add_source_target_affixes_to_pipeline(pipeline)

        def cost_fn(table) -> float:
            cost = 0
            for name in [
                self.dataset_config.source_column,
                self.dataset_config.target_column,
            ]:
                if name is not None:
                    col = table[name]
                    if is_list_like(col):
                        cost += pa.compute.list_value_length(col).to_numpy().sum()
                    else:
                        # we should not end up here, but take the batch size as a proxy
                        cost += len(col)
            return cost

        pipeline = pipeline.dynamic_bucket(
            self._shuffling_tokens_size,
            cost_fn,
            min_num_examples=self.nb_parallel_fragments,
            max_num_examples=100,  # max number of small fragments
            drop_remainder=False,
        )
        pipeline = pipeline.map(concat_table, num_parallel_calls=1)

        # wrap documents after affixing
        if not self.loading_config.wrap_before_affixing:
            # Note that packing with proper attention masks and position codes requires
            # document indices that cover all sentences. Currently this can only come from affixing before wrapping.
            # Adding affixes after wrapping would require annexing these affixes to edge sentences, which is not intuitive.
            if self.loading_config.shuffle:
                pipeline = pipeline.map(
                    partial(shuffle_table, random_state=self.random_state),
                    num_parallel_calls=1,
                )
            pipeline = self.add_wrapping_to_max_length_pipeline(pipeline)

        # batch with batch_size or max_tokens
        pipeline = self.add_inner_pipeline(pipeline)

        # Filter once again after wrapping and batching to remove batches with too few sentences
        pipeline = self.add_min_sentence_number_in_doc_filter(
            pipeline,
            min_source_length=self.loading_config.min_length_of_sequences_after_batching,
            min_target_length=self.loading_config.min_length_of_target_sequences_after_batching,
        )

        # Remove batches smaller than min_batch_size (default=1)
        pipeline = pipeline.filter(
            lambda table: bool(len(table) >= self.loading_config.min_batch_size)
        )

        if sharding_in_memory:
            pipeline = pipeline.shard(
                shard_idx=rank,
                num_shards=world_size,
                allow_uneven=not self.loading_config.even_sharding,
            )
        if self.loading_config.max_iteration_steps is not None:
            pipeline = pipeline.take(self.loading_config.max_iteration_steps)
        pipeline = self.add_format_conversion(pipeline)
        return pipeline

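    # Stage order recap for `build_dataload_pipeline`:
    #   stream fragments -> (shard across ranks) -> load + filter fragments
    #   -> on-the-fly columns -> aligned-length filters -> (wrap before affixing)
    #   -> quality/length filters -> affixes -> dynamic_bucket + concat
    #   -> (shuffle + wrap after affixing) -> batch -> post-batching filters
    #   -> min_batch_size filter -> (shard in memory) -> take -> format conversion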
    def create_on_the_fly_columns(
        self, pipeline: DataPipelineBuilder
    ) -> DataPipelineBuilder:
        if self.dataset_config.source_sequences is not None:
            assert self.dataset_config.source_column is not None, (
                f"Expected a source_column - found {self.dataset_config.source_column}"
            )
            assert self.dataset_config.source_text_column is not None, (
                f"Expected a source_text_column - found {self.dataset_config.source_text_column}"
            )

            pipeline = pipeline.map(
                partial(
                    materialize_sequence,
                    column_sequence=self.dataset_config.source_sequences,
                    vector_name=self.dataset_config.source_column,
                    text_name=self.dataset_config.source_text_column,
                ),
                num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
            )
        if self.dataset_config.target_sequences is not None:
            assert self.dataset_config.target_column is not None, (
                f"Expected a target_column, found {self.dataset_config.target_column}"
            )
            assert self.dataset_config.target_text_column is not None, (
                f"Expected a target_text_column, found {self.dataset_config.target_text_column}"
            )

            pipeline = pipeline.map(
                partial(
                    materialize_sequence,
                    column_sequence=self.dataset_config.target_sequences,
                    vector_name=self.dataset_config.target_column,
                    text_name=self.dataset_config.target_text_column,
                ),
                num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
            )

        columns_to_drop = list(
            set(self._get_sequences_columns()) - set(self.extra_required_columns)
        )
        if columns_to_drop:
            pipeline = pipeline.map(lambda table: table.drop(columns_to_drop))

        return pipeline

    def _add_source_target_affixes_to_pipeline(self, pipeline) -> DataPipelineBuilder:
        # prefixing/suffixing before wrapping/packing
        ps_vals = self._get_suffix_prefix_vector()
        pipeline = self.add_prefix_suffix_pipeline(
            pipeline,
            self.dataset_config.source_column,
            ps_vals["source_prefix_vector"],
            ps_vals["source_suffix_vector"],
        )
        pipeline = self.add_prefix_suffix_pipeline(
            pipeline,
            self.dataset_config.source_text_column,
            ps_vals["source_prefix_sentences"],
            ps_vals["source_suffix_sentences"],
        )

        pipeline = self.add_prefix_suffix_pipeline(
            pipeline,
            self.dataset_config.source_quality_column,
            (
                pa.array([None])
                if self.dataset_config.source_prefix_text
                else pa.array([])
            ),
            (
                pa.array([None])
                if self.dataset_config.source_suffix_text
                else pa.array([])
            ),
        )

        pipeline = self.add_prefix_suffix_pipeline(
            pipeline,
            self.dataset_config.target_column,
            ps_vals["target_prefix_vector"],
            ps_vals["target_suffix_vector"],
        )
        pipeline = self.add_prefix_suffix_pipeline(
            pipeline,
            self.dataset_config.target_text_column,
            ps_vals["target_prefix_sentences"],
            ps_vals["target_suffix_sentences"],
        )

        return pipeline

    def _num_parallel_call(self, x: float) -> int:
        return int(max(self.loading_config.num_parallel_calls * x, 1))

    def _nb_prefetch(self, x: float) -> int:
        return int(max(self.loading_config.nb_prefetch * x, 0))

    def config_post_init(self) -> None:
        if getattr(self.loading_config, "len_to_wrap_long_seq", None):
            if (
                self.dataset_config.target_column
                or self.dataset_config.target_text_column
            ):
                raise ValueError(
                    "Using `len_to_wrap_long_seq` is not supported for supervised training"
                )

        if self.loading_config.even_sharding:
            assert self.loading_config.seed is not None, (
                "`even_sharding` requires `seed` to be set"
            )

        if self.loading_config.max_tokens == 0:
            self.loading_config.max_tokens = None
            # setting max_tokens=0 turns off this option (argparser won't accept None directly)

        if (self.loading_config.batch_size is None) == (
            self.loading_config.max_tokens is None
        ) and (not self.is_validation or self.loading_config.max_tokens is not None):
            raise ValueError(
                f"Need to provide either `batch_size` or `max_tokens` - \
                Received batch_size={self.loading_config.batch_size} \
                and max_tokens={self.loading_config.max_tokens}"
            )

        if self.loading_config.max_tokens and not self.dataset_config.source_column:
            raise ValueError(
                "Cannot batch based on `max_tokens` when `source_column` is not specified, "
                "please use `batch_size` instead."
            )

        self.dataset_config.split_to_row_groups = (
            self.dataset_config.split_to_row_groups
            if self.dataset_config.split_to_row_groups is not None
            else True
        )
        self.extra_required_columns = self.dataset_config.columns or []
        self.dataset_config.override_attr("columns", self._get_minimal_columns())
        logger.info(f"Following columns will be loaded: {self.dataset_config.columns}")

        self.basic_stats = self.compute_stats()

        self._shuffling_tokens_size = self._get_shuffling_tokens_size(self.basic_stats)
        logger.info(
            f"Bucketing will require at least {self._shuffling_tokens_size} tokens (source + target)"
        )
        logger.info(f"Dataset stats: {asdict(self.basic_stats)}")

        self.proxy_number_of_fragments = self.basic_stats.min_number_of_fragment
        if self.dataset_config.nb_parallel_fragments is None:
            self.dataset_config.nb_parallel_fragments = (
                self._find_nb_parallel_fragments(self.basic_stats)
            )

        logger.info(f"Dataset Config: {self.dataset_config}")
        logger.info(f"Using Loading Config: {self.loading_config}")

    def _get_shuffling_tokens_size(self, basic_stats) -> int:
        """
        `_shuffling_tokens_size` is used in dynamic bucketing to determine how many small parquet tables
        (raw parquet fragments that were loaded and potentially filtered on the fly) will be merged together:
        we take enough consecutive parquet tables that their total number of tokens (sentences)
        is greater than `_shuffling_tokens_size`.
        It is called "shuffling" because all merged documents (from different tables) will be permuted together
        (if `shuffle=True`) before being returned as final small batches (of the required shape or volume).

        The formula behind `_shuffling_tokens_size` is the following:
        - If `max_tokens` is set in the config, we want at least _shuffling_tokens_size = 4 * max_tokens,
          so that at least 4 full batches can be formed next. This is good for shuffling and avoids
          producing "remainders" too often.
        - For the wrapping/packing case, we use `batch_size` * `len_to_wrap_long_seq` as a proxy for `max_tokens`.
        - Otherwise, we use the average fragment statistic `mean_fragment_number_of_tokens`,
          multiplied by 1.5 to get on average >= 2 tables.
        - Finally, if no other info is available, we use 10_000 as an arbitrary proxy
          (a good typical value for many of our datasets).

        """
        if self.loading_config.max_tokens is not None:
            return 4 * self.loading_config.max_tokens
        if (
            self.loading_config.batch_size is not None
            and self.loading_config.len_to_wrap_long_seq is not None
        ):
            return (
                4
                * self.loading_config.len_to_wrap_long_seq
                * self.loading_config.batch_size
            )

        if basic_stats.mean_fragment_number_of_tokens is not None:
            return int(
                1.5 * basic_stats.mean_fragment_number_of_tokens
            )  # to get a few fragments grouped together

        return 10_000  # default number that should not take a lot of RAM

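    # Worked example: with max_tokens=2048 the bucketing threshold is 4 * 2048 = 8192
    # tokens; with batch_size=8 and len_to_wrap_long_seq=128 it is 4 * 128 * 8 = 4096.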
    def _find_nb_parallel_fragments(
        self, basic_stats: GlobalPQStats, max_fragments=20, min_fragments=2
    ) -> int:
        """
        Experimental!
        Determines the number of parallel fragments to load based on simple rules and dataset row-group stats.
        In particular, `nb_parallel_fragments` will increase with increasing `batch_size` or `max_tokens`.
        """
        if basic_stats.min_number_of_fragment < 3:
            return basic_stats.min_number_of_fragment

        if basic_stats.mean_fragment_number_of_tokens is None:
            logger.warning(
                f"Cannot get `mean_fragment_number_of_tokens` from dataset {self.dataset_config}, `nb_parallel_fragments` detection can be wrong",
            )

        mean_fragment_number_of_tokens = (
            basic_stats.mean_fragment_number_of_tokens or 5000
        )  # typical, but arbitrary value
        if (
            self.loading_config.batch_size is None
            and self.loading_config.max_tokens is None
        ):
            # it can happen for evaluation
            nb_frags = 1.0
        elif self.loading_config.batch_size is not None:
            if self.loading_config.len_to_wrap_long_seq is not None:
                max_tokens = (
                    self.loading_config.len_to_wrap_long_seq
                    * self.loading_config.batch_size
                )
                nb_frags = 3 * max_tokens / mean_fragment_number_of_tokens
            else:
                nb_frags = (
                    5
                    * self.loading_config.batch_size
                    / basic_stats.mean_fragment_length
                )
        elif self.loading_config.max_tokens is not None:
            nb_frags = (
                3 * self.loading_config.max_tokens / mean_fragment_number_of_tokens
            )

        return max(min(max_fragments, round(nb_frags)), min_fragments)

    @lru_cache
    def _get_sequences_columns(self):
        candidate_columns = []
        for col in (self.dataset_config.source_sequences or []) + (
            self.dataset_config.target_sequences or []
        ):
            candidate_columns.append(col.text_column)
            candidate_columns.append(col.sonar_column)
        return [x for x in candidate_columns if x is not None]

    def _get_minimal_columns(self):
        # restrict to the columns that are actually used
        candidate_columns = [
            self.dataset_config.source_column,
            self.dataset_config.source_text_column,
            self.dataset_config.source_quality_column,
            self.dataset_config.target_column,
            self.dataset_config.target_text_column,
            "split",
        ] + self._get_sequences_columns()

        minimal_columns: List[str] = [
            x
            for x in candidate_columns
            if x is not None and x in self.full_schema.names
        ]

        if self.dataset_config.columns is None:
            columns = sorted(set(minimal_columns))
        else:
            columns = sorted(set(minimal_columns + list(self.dataset_config.columns)))
            if not set(columns).issubset(set(self.full_schema.names)):
                raise ValueError(
                    f"columns {sorted(set(columns) - set(self.full_schema.names))} are not found in the dataset schema"
                )

        return columns

    def _get_suffix_prefix_vector(self):
        nested_result = prepare_suffix_prefix_embeddings(
            self.dataset_config.source_prefix_text,
            self.dataset_config.source_suffix_text,
            self.dataset_config.target_prefix_text,
            self.dataset_config.target_suffix_text,
        )

        names = (
            ("source_prefix_vector", "source_prefix_sentences"),
            ("source_suffix_vector", "source_suffix_sentences"),
            ("target_prefix_vector", "target_prefix_sentences"),
            ("target_suffix_vector", "target_suffix_sentences"),
        )

        return {n: v for nn, val in zip(names, nested_result) for n, v in zip(nn, val)}

    def get_fragments_pipeline(self):
        split_to_row_groups = self.dataset_config.split_to_row_groups
        assert isinstance(split_to_row_groups, bool)

        # one can use `list_parquet_fragments` for a full fragment scan
        fragments_pipeline_builder = stream_parquet_fragments(
            parquet_ds=self.dataset,
            nb_epochs=self.loading_config.nb_epochs,
            split_to_row_groups=split_to_row_groups,
            shuffle=self.loading_config.shuffle,
            seed=self.loading_config.seed,
            limit_options=self.dataset_config.limit,
            shuffling_window=20 * self.nb_parallel_fragments,
        )

        return fragments_pipeline_builder

    def compute_stats(self, max_fragments=100) -> GlobalPQStats:
        if self.dataset_config.source_sequences:
            source_column = None
        else:
            source_column = self.dataset_config.source_column

        split_to_row_groups = self.dataset_config.split_to_row_groups

        columns = [source_column] if source_column else None

        if (
            self.dataset_config.limit is not None
            and self.dataset_config.limit.nb_fragments is not None
        ):
            # TODO: take into account other limit options to get better estimates
            max_fragments = min(self.dataset_config.limit.nb_fragments, max_fragments)

        self._stats_df = get_row_group_level_metadata(
            self.dataset, columns=columns, max_fragments=max_fragments
        )
        dim = 1
        if source_column:
            self._stats_df["num_tokens"] = self._stats_df[source_column].apply(
                lambda x: x["num_values"]
            )

            type_source = self.full_schema.field(source_column).type
            try:
                dim = type_source.value_type.list_size
                if not dim or dim < 0:
                    dim = 1  # not a fixed vector size
            except AttributeError:
                logger.warning(f"source column {source_column} is not of list type")
                if self.dataset_config.nb_parallel_fragments is None:
                    logger.warning("you may need to provide `nb_parallel_fragments`")
                dim = 1

        if split_to_row_groups:
            global_stats_df = self._stats_df
        elif "num_tokens" in self._stats_df:
            global_stats_df = self._stats_df.groupby("parquet_file_path").agg(
                {"num_rows": "sum", "num_tokens": "sum"}
            )
        else:
            global_stats_df = self._stats_df.groupby("parquet_file_path").agg(
                {"num_rows": "sum"}
            )

        mean_len_frag = global_stats_df["num_rows"].mean()

        if "num_tokens" in global_stats_df:
            mean_num_tokens_frag = self._stats_df["num_tokens"].mean() / dim
        else:
            mean_num_tokens_frag = None

        return GlobalPQStats(
            len(global_stats_df),
            mean_len_frag,
            mean_fragment_number_of_tokens=mean_num_tokens_frag,
        )

    def add_inner_pipeline(self, pipeline: DataPipelineBuilder) -> DataPipelineBuilder:
        loading_config = self.loading_config

        columns_to_bucket = [
            self.dataset_config.source_column,
            self.dataset_config.target_column,
        ]
        columns_to_bucket = [x for x in columns_to_bucket if x is not None]

        def inner_iterator(table: pa.Table) -> DataPipeline:
            return build_batching_loop_over_one_table(
                table=table,
                order_by_length=self.loading_config.order_by_length,
                length_column=columns_to_bucket,
                batch_size=loading_config.batch_size,
                max_tokens=loading_config.max_tokens,
                shuffle=loading_config.shuffle,
                seed=self.random_state.randint(0, 2**32),
                num_parallel_calls=self._num_parallel_call(3),
            )

        return pipeline.yield_from(inner_iterator)

    def _get_inner_seed(self, rank: int, sharding_in_memory: bool) -> Optional[int]:
        if self.loading_config.seed is not None:
            if not sharding_in_memory:
                return int(self.loading_config.seed) + rank * 100_000
            else:
                # for `sharding_in_memory`, we want the same shuffling
                # to guarantee consistent sharding across ranks
                return int(self.loading_config.seed)
        else:
            return None

    def add_prefix_suffix_pipeline(
        self,
        pipeline: DataPipelineBuilder,
        column: Optional[str],
        prefix,
        suffix,
    ) -> DataPipelineBuilder:
        if (suffix is None and prefix is None) or column is None:
            return pipeline
        pipeline = pipeline.map(
            partial(
                prefix_and_suffix_one_list_column,
                column=column,
                prefix_array=prefix,
                suffix_array=suffix,
            ),
            num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
        )
        return pipeline

    def add_basic_fragment_loading_pipeline(
        self, pipeline: DataPipelineBuilder
    ) -> DataPipelineBuilder:
        def load_fn(safe_frag):
            try:
                return safe_frag.load(columns=self.dataset_config.columns)
            except Exception as e:
                logger.error(
                    f"Error {e} occurred while loading fragment {safe_frag}, skipping it"
                )
                return None

        pipeline = pipeline.map(
            load_fn,
            num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
        )

        pipeline = pipeline.filter(lambda table: bool(table is not None))

        # we reapply the partition filters just in case of misuse,
        # but it should not change the performance
        partition_filters = self.dataset_config.partition_filters
        filters = self.dataset_config.filters
        if partition_filters is not None and filters is not None:
            full_filter = pa.compute.if_else(filters, partition_filters, False)
        else:
            full_filter = partition_filters if filters is None else filters

        pipeline = pipeline.map(
            partial(
                apply_filter,
                filters=full_filter,
                drop_null=self.loading_config.drop_null,
            )
        )

        pipeline = pipeline.filter(lambda table: bool(len(table) > 0))
        pipeline = pipeline.prefetch(self._nb_prefetch(self.nb_parallel_fragments))

        return pipeline

    def filter_by_aligned_length(
        self, pipeline: DataPipelineBuilder
    ) -> DataPipelineBuilder:
        source_columns: List[str] = [
            x
            for x in (
                self.dataset_config.source_column,
                self.dataset_config.source_text_column,
                self.dataset_config.source_quality_column,
            )
            if x is not None
        ]

        # filter out samples where the number of sentences and the number of sonar
        # embeddings are not equal, which should normally never happen

        pipeline = pipeline.map(
            partial(
                filter_table_with_different_lengths,
                columns=source_columns,
            ),
            num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
        )
        pipeline = pipeline.filter(lambda table: bool(len(table) > 0))

        target_columns: List[str] = [
            x
            for x in (
                self.dataset_config.target_column,
                self.dataset_config.target_text_column,
            )
            if x is not None
        ]

        pipeline = pipeline.map(
            partial(
                filter_table_with_different_lengths,
                columns=target_columns,
            ),
            num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
        )
        pipeline = pipeline.filter(lambda table: bool(len(table) > 0))

        return pipeline

    def add_wrapping_to_max_length_pipeline(
        self, pipeline: DataPipelineBuilder
    ) -> DataPipelineBuilder:
        len_to_wrap_long_seq = getattr(
            self.loading_config, "len_to_wrap_long_seq", None
        )
        if len_to_wrap_long_seq is None:
            return pipeline

        columns_to_wrap: List[str] = [
            x
            for x in (
                self.dataset_config.source_column,
                self.dataset_config.source_text_column,
                self.dataset_config.source_quality_column,
            )
            if x is not None
        ]

        if self.loading_config.packing:
            method = return_none_on_failure(explode_table_with_fixed_length)
            logger.info(
                f"Wrapping to len_to_wrap_long_seq={len_to_wrap_long_seq} with fixed length (packing)"
            )
        else:
            method = return_none_on_failure(explode_table_with_max_length)
            logger.info(
                f"Wrapping to len_to_wrap_long_seq={len_to_wrap_long_seq} with max length (without packing)"
            )

        pipeline = pipeline.map(
            partial(
                method,
                columns=columns_to_wrap,
                max_seq_len=len_to_wrap_long_seq,
            ),
            num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
        )
        return pipeline.filter(lambda table: table is not None)

    def add_min_max_sentence_len_in_doc_filter(
        self, pipeline: DataPipelineBuilder
    ) -> DataPipelineBuilder:
        if (
            self.loading_config.max_sentence_len_in_doc
            or self.loading_config.min_sentence_len_in_doc
        ):
            assert self.dataset_config.source_text_column is not None, (
                f"Expected a source_text_column, found {self.dataset_config.source_text_column}"
            )

            pipeline = pipeline.map(
                partial(
                    filter_long_short_sentence_document,
                    column=self.dataset_config.source_text_column,
                    max_sentence_len=self.loading_config.max_sentence_len_in_doc,
                    min_sentence_len=self.loading_config.min_sentence_len_in_doc,
                ),
                num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
            ).filter(lambda table: bool(len(table) > 0))

        if self.dataset_config.target_column is not None and (
            self.loading_config.max_sentence_len_in_target_doc
            or self.loading_config.min_sentence_len_in_target_doc
        ):
            pipeline = pipeline.map(
                partial(
                    filter_long_short_sentence_document,
                    column=self.dataset_config.target_column,
                    max_sentence_len=self.loading_config.max_sentence_len_in_target_doc,
                    min_sentence_len=self.loading_config.min_sentence_len_in_target_doc,
                ),
                num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
            ).filter(lambda table: bool(len(table) > 0))

        return pipeline

    def add_min_sentence_number_in_doc_filter(
        self,
        pipeline: DataPipelineBuilder,
        min_source_length: Optional[int] = None,
        min_target_length: Optional[int] = None,
    ) -> DataPipelineBuilder:
        """
        If `min_source_length` is not None: filter the source to remove sequences
        with fewer than `min_source_length` sentences.
        If `min_target_length` is not None and data comes with a target column:
        filter the target to remove sequences with fewer than `min_target_length` sentences.

        """

        def _min_length_filter(table, column, length):
            filter_ = pc.greater_equal(pc.list_value_length(table[column]), length)

            if pc.all(filter_).as_py():
                return table
            return table.filter(filter_)

        if (
            self.dataset_config.source_column is not None
            and min_source_length is not None
        ):
            pipeline = pipeline.map(
                partial(
                    _min_length_filter,
                    column=self.dataset_config.source_column,
                    length=min_source_length,
                ),
                num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
            ).filter(lambda table: bool(len(table) > 0))

        if (
            self.dataset_config.target_column is not None
            and min_target_length is not None
        ):
            pipeline = pipeline.map(
                partial(
                    _min_length_filter,
                    column=self.dataset_config.target_column,
                    length=min_target_length,
                ),
                num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
            ).filter(lambda table: bool(len(table) > 0))

        return pipeline

    def add_quality_score_filters(
        self, pipeline: DataPipelineBuilder
    ) -> DataPipelineBuilder:
        source_quality_range = self.dataset_config.source_quality_range
        if source_quality_range is None:
            return pipeline

        assert self.dataset_config.source_quality_column is not None, (
            f"Expected a source_quality_column, found {self.dataset_config.source_quality_column}"
        )

        pipeline = pipeline.map(
            partial(
                filter_document_by_quality,
                column=self.dataset_config.source_quality_column,
                min_score=source_quality_range[0],
                max_score=source_quality_range[1],
            ),
            num_parallel_calls=self._num_parallel_call(self.nb_parallel_fragments),
        ).filter(lambda table: bool(len(table) > 0))
        return pipeline

    def add_format_conversion(
        self, pipeline: DataPipelineBuilder
    ) -> DataPipelineBuilder:
        if self.loading_config.output_format == ParquetBatchFormat.pandas:
            pipeline = pipeline.map(lambda table: table.to_pandas())
        elif self.loading_config.output_format == ParquetBatchFormat.torch:
            pipeline = pipeline.map(lambda wt: pyarrow_table_to_torch_dict(wt))
        return pipeline

    def get_python_iterator(
        self, rank: int = 0, world_size: int = 1
    ) -> Generator[BatchOutputType, None, None]:  # type: ignore
        yield from iter(
            self.build_dataload_pipeline(
                rank=rank,
                world_size=world_size,
            )
            .prefetch(self._nb_prefetch(5))
            .and_return(max_num_warnings=4)
        )


def parquet_iterator(
    dataset_config: ParquetDatasetConfig,
    loading_config: DataLoadingConfig,
    rank: int,
    world_size: int,
) -> Generator[BatchOutputType, None, None]:  # type: ignore
    spdd = SingleParquetDatasetDataloader(dataset_config, loading_config)
    yield from spdd.get_python_iterator(rank, world_size)


def build_parquet_iterator_pipeline(
    dataset_config: ParquetDatasetConfig,
    loading_config: DataLoadingConfig,
    rank: int = 0,
    world_size: int = 1,
) -> DataPipelineBuilder:
    return SingleParquetDatasetDataloader(
        dataset_config, loading_config
    ).build_dataload_pipeline(rank=rank, world_size=world_size)


def ds_name(conf: ParquetDatasetConfig) -> str:
    if conf.name is not None:
        return conf.name
    return str(conf.parquet_path)


def circular_shift_left(lst: List[Any], k: int) -> List[Any]:
    if len(lst) <= 1:
        return lst

    k = k % len(lst)  # To handle shifts larger than the list length
    return lst[k:] + lst[:k]

| 1037 |
+
def build_weighted_pipeline_with_renaming(
    dataset_configs: Sequence[ParquetDatasetConfig],
    loading_config: DataLoadingConfig,
    rank: int = 0,
    world_size: int = 1,
) -> DataPipeline:
    assert loading_config.multiple_dataset_chaining in [
        "sample",
        "concat",
        "round_robin",
    ]

    # adjusting the number of parallel calls and prefetch according to the total number of datasets
    dataset_configs = list(dataset_configs)
    loading_config.num_parallel_calls = loading_config.num_parallel_calls / len(
        dataset_configs
    )
    loading_config.nb_prefetch = loading_config.nb_prefetch // len(dataset_configs)

    name_mappers = get_renaming_mappers(dataset_configs)
    pipelines: List[DataPipelineBuilder] = []

    def process_one_pipeline(cc, mapper):
        return build_parquet_iterator_pipeline(
            dataset_config=cc,
            loading_config=loading_config,
            rank=rank,
            world_size=world_size,
        ).map(
            partial(renaming, mapper=mapper, name=ds_name(cc)),
            num_parallel_calls=1,
        )

    # creating all dataset pipelines in parallel
    pipelines = [
        process_one_pipeline(cc, mapper)
        for cc, mapper in zip(dataset_configs, name_mappers)
    ]

    if len(pipelines) == 1:
        return (
            pipelines[0]
            .prefetch(int(max(loading_config.nb_prefetch, 1)))
            .and_return(max_num_warnings=4)
        )
    if loading_config.seed is not None:
        seed = loading_config.seed + (0 if loading_config.even_sharding else rank)
    else:
        seed = None

    pipelines_with_return = [pp.and_return(max_num_warnings=4) for pp in pipelines]

    if loading_config.multiple_dataset_chaining == "concat":
        # TODO: check that all weights equal 1
        weighted_pipeline = DataPipeline.concat(
            circular_shift_left(pipelines_with_return, k=rank),
        )
    elif loading_config.multiple_dataset_chaining == "round_robin":
        weighted_pipeline = DataPipeline.round_robin(
            circular_shift_left(pipelines_with_return, k=rank), allow_repeats=False
        )
    else:
        weighted_pipeline = DataPipeline.sample(
            pipelines_with_return,
            [getattr(cc, "weight", 1.0) for cc in dataset_configs],
            seed=seed,
        )

    return weighted_pipeline.prefetch(
        int(
            max(loading_config.nb_prefetch * len(dataset_configs) ** 2, 1)
        )  # try to prefetch at least one element from each dataset
    ).and_return(max_num_warnings=4)
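Taken together, these helpers compose into a single entry point for weighted multi-dataset loading. A hedged usage sketch (the dataset paths, weights, and config fields below are hypothetical; the actual `ParquetDatasetConfig` / `DataLoadingConfig` fields live in `lcm/datasets/configs.py`):

    from lcm.datasets.configs import DataLoadingConfig, ParquetDatasetConfig

    datasets = [
        ParquetDatasetConfig(parquet_path="data/wiki.parquet", weight=0.7),
        ParquetDatasetConfig(parquet_path="data/books.parquet", weight=0.3),
    ]
    loading = DataLoadingConfig(multiple_dataset_chaining="sample", seed=42)

    pipeline = build_weighted_pipeline_with_renaming(
        datasets, loading, rank=0, world_size=1
    )
    for batch in pipeline:
        ...  # each batch is tagged with its dataset name via `renaming`
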
lcm/datasets/parquet_utils.py
ADDED
@@ -0,0 +1,1141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

import logging
from dataclasses import dataclass
from functools import lru_cache, reduce, wraps
from pickle import dumps, loads
from typing import Any, Iterator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
import torch
from fairseq2.data.data_pipeline import (
    DataPipeline,
    DataPipelineBuilder,
    read_iterator,
    read_sequence,
)
from fairseq2.data.parquet.tools import (
    NestedDict,
    NestedDictValue,
    add_partitioning_values,
    compute_rows_length,
    get_dataset_fragments,
    split_fragment_in_row_groups,
)
from joblib import Parallel, delayed
from numpy.typing import NDArray
from pyarrow.dataset import get_partition_keys
from retrying import retry
from stopes.modules.preprocess.sonar_text_embedding import (
    LangColumnConfig,
    SonarTextBatchEmbedder,
    SonarTextEmbedderConfig,
)
from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo
from stopes.utils.arrow_utils import (
    hstack_pyarray_list,
    is_list_like,
    pyarrow_column_to_array,
    simple_array_to_nested,
)
from tqdm.auto import tqdm

from lcm.datasets.configs import (
    ColumnsNames,
    ParquetDatasetLimitOptions,
    SonarTextColumn,
)
from lcm.utils.common import batched

try:
    from numba import njit
except ModuleNotFoundError:
    print("Numba is not installed. Falling back to the non-compiled version.")

    def empty_jit(f):
        @wraps(f)
        def _f(*args, **kwargs):
            return f(*args, **kwargs)

        return _f

    njit = empty_jit


loading_retry = retry(
    retry_on_exception=lambda exception: isinstance(exception, OSError),
    stop_max_attempt_number=1,
    wait_exponential_multiplier=2,
    wait_exponential_max=20,
)


logger = logging.getLogger(__name__)

def prefix_and_suffix_one_list_column(
    table: pa.Table, column: str, prefix_array: pa.Array, suffix_array: pa.Array
):
    prefix_extended = pa.chunked_array(
        [pa.ListArray.from_arrays([0, len(prefix_array)], prefix_array)] * len(table)
    )
    suffix_extended = pa.chunked_array(
        [pa.ListArray.from_arrays([0, len(suffix_array)], suffix_array)] * len(table)
    )
    target_dtype = table[column].type
    if prefix_extended.type != target_dtype:
        prefix_extended = prefix_extended.cast(target_dtype)
    if suffix_extended.type != target_dtype:
        suffix_extended = suffix_extended.cast(target_dtype)

    new_array = hstack_pyarray_list(prefix_extended, table[column], suffix_extended)
    return table.drop([column]).append_column(column, new_array)


def define_parquet_dataset(parquet_path: str, partition_filters) -> pq.ParquetDataset:
    return pq.ParquetDataset(
        parquet_path,
        filters=partition_filters,
    )

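`prefix_and_suffix_one_list_column` prepends and appends the same fixed list to every row of a list column. A minimal pure-pyarrow sketch of the same idea, without the vectorized stopes `hstack_pyarray_list` helper (toy data, illustration only):

    import pyarrow as pa

    table = pa.table({"sents": [["b"], ["c", "d"]]})
    prefix, suffix = ["<bos>"], ["<eos>"]

    # Rebuild the list column row by row; hstack_pyarray_list does this vectorized.
    new_col = pa.array([prefix + row.as_py() + suffix for row in table["sents"]])
    table = table.drop(["sents"]).append_column("sents", new_col)
    # -> [['<bos>', 'b', '<eos>'], ['<bos>', 'c', 'd', '<eos>']]
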
@lru_cache()
def default_sonar_pipeline() -> SonarTextBatchEmbedder:
    local_sonar_config = SonarTextEmbedderConfig(
        column_config=[
            LangColumnConfig("input_text", lang_value="eng_Latn"),
        ],
        batch_size=10,
        device="cpu",
    )
    return SonarTextBatchEmbedder(local_sonar_config)


@lru_cache(2000)
def _get_embed_sentences(text: Optional[str]) -> Tuple[pa.Array, pa.Array]:
    sentences_splitter = get_split_algo("eng_Latn", "default")
    lstbe = default_sonar_pipeline()
    sentences = pa.array(sentences_splitter(text) if text else [""])
    input_table = pa.Table.from_pydict({"input_text": sentences})
    vectors = pyarrow_column_to_array(lstbe(input_table)["input_text_sonar_emb"])
    if not text:
        # empty output of the right type
        vectors = vectors.slice(0, 0)
        sentences = sentences.slice(0, 0)
    return vectors, sentences


def prepare_suffix_prefix_embeddings(*args):
    if all(xx is None for xx in args):  # to avoid loading the SONAR model
        return [(None, None) for _ in args]

    return [_get_embed_sentences(xx) for xx in args]

def from_pyarrow_to_torch_tensor(
    arr: Union[pa.Array, pa.ChunkedArray], strict: bool = False
) -> NestedDictValue:
    """
    Examples of convertible inputs:
    struct_array = pa.Array.from_pandas([{"x": 4, "y": "RR"}] * 10)
    nest_array = pa.Array.from_pandas([[{'a': 1}, {'a': 2}]])
    """
    # for future ideas: https://arrow.apache.org/docs/python/generated/pyarrow.Tensor.html
    # for sparse matrix support: https://github.com/apache/arrow/blob/main/python/pyarrow/tests/test_sparse_tensor.py

    if arr.null_count != 0:
        raise ValueError("to torch conversion does not support null values")

    arr = pyarrow_column_to_array(arr)

    arr_type = arr.type
    if pa.types.is_primitive(arr_type):
        try:
            return torch.from_numpy(arr.to_numpy(zero_copy_only=True))
        except Exception:
            pass

    try:
        return torch.from_numpy(arr.to_numpy(zero_copy_only=True))
    except pa.ArrowInvalid:
        pass

    if pa.types.is_dictionary(arr_type):
        return from_pyarrow_to_torch_tensor(arr.dictionary_decode())

    if pa.types.is_string(arr_type):
        return arr.to_pandas().tolist()

    if pa.types.is_list(arr_type) or pa.types.is_large_list(arr_type):
        if pa.types.is_primitive(arr_type.value_type):
            return arr.to_pandas().map(torch.from_numpy).tolist()

        if pa.types.is_fixed_size_list(arr_type.value_type) and pa.types.is_primitive(
            arr_type.value_type.value_type
        ):
            return (
                arr.to_pandas()
                .map(
                    lambda x: torch.from_numpy(
                        np.vstack(x) if len(x) > 0 else np.array([], dtype=np.float32)
                    )
                )
                .tolist()
            )

    if pa.types.is_fixed_size_list(arr_type):
        if pa.types.is_primitive(arr_type.value_type):
            return torch.from_numpy(np.reshape(arr.values, (-1, arr_type.list_size)))

    if pa.types.is_struct(arr_type):
        return {
            arr_type.field(i).name: from_pyarrow_to_torch_tensor(arr.field(i))
            for i in range(arr_type.num_fields)
        }

    if pa.types.is_nested(arr_type):
        # TODO: deal with arr = [[{'a': 1}, {'a': 2}]]
        pass

    if strict:
        raise NotImplementedError(f"{arr_type} cannot be converted to torch.Tensor")
    else:
        return arr  # keeping it in the original pyarrow form


def pyarrow_table_to_torch_dict(tt: pa.Table, strict: bool = False) -> NestedDict:
    out = {}
    for col in tt.column_names:
        try:
            out[col] = from_pyarrow_to_torch_tensor(tt[col], strict)
        except ValueError as e:
            logger.info(
                f"Column {col} of type {tt[col].type} was not converted to torch as expected: {e}"
            )
            out[col] = tt[col]
    return out

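A quick usage sketch of the conversion above on a toy table (the column names are made up; only pyarrow and the functions from this module are assumed):

    import pyarrow as pa

    batch = pa.table({
        "ids": [1, 2, 3],                  # primitive -> one torch tensor
        "text": ["a", "b", "c"],           # strings -> list of str
        "lens": [[1, 2], [3], [4, 5, 6]],  # list<int> -> list of 1-D tensors
    })
    out = pyarrow_table_to_torch_dict(batch)
    # out["ids"] is a torch.Tensor, out["text"] a list of str,
    # out["lens"] a list of variable-length 1-D tensors.
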
def add_fragments_trace(table: pa.Table, fragment: pa.dataset.Fragment) -> pa.Table:
    table = table.append_column(
        "__row_groups_ids",
        len(table)
        * [np.array([int(rg.id) for rg in fragment.row_groups], dtype=np.int32)],
    )
    table = table.append_column(
        "__index_in_fragement", pa.array(np.arange(len(table), dtype=np.int32))
    )
    return table


def shuffle_table(table: pa.Table, random_state: np.random.RandomState) -> pa.Table:
    permutation = pa.array(random_state.permutation(len(table)))
    return table.take(permutation)

class SafeFragment:
    """
    Experimental:
    Simple wrapper around `ParquetFileFragment` that allows reinitializing the filesystem state
    if the AWS session token has expired.
    """

    fragment: pa.dataset.ParquetFileFragment

    def __init__(self, fragment: pa.dataset.ParquetFileFragment):
        self.fragment = fragment

    def __repr__(self) -> str:
        out = ""
        out += "SafeFragment \n"
        out += "path = " + self.fragment.path + "\n"
        out += f"row_groups = {[int(rg.id) for rg in self.fragment.row_groups]} \n"
        out += f"physical_schema = \n {self.fragment.physical_schema} \n"
        return out

    @loading_retry
    def load(self, columns: Optional[List[str]] = None) -> pa.Table:
        if columns is not None:
            fragment_columns = [
                col for col in columns if col in self.fragment.physical_schema.names
            ]
        else:
            fragment_columns = self.fragment.physical_schema.names
        # adding technical columns for tracking
        fragment_columns = list(fragment_columns) + [
            "__batch_index",
            "__fragment_index",
            "__filename",
        ]
        try:
            fragment_table = self.fragment.to_table(
                columns=fragment_columns, use_threads=False
            )

        except OSError as e:
            logger.info(
                f"could not load fragment, reinitializing the fragment state. Error: {e}"
            )
            self.fragment = loads(dumps(self.fragment))
            fragment_table = self.fragment.to_table(
                columns=fragment_columns, use_threads=False
            )

        fragment_table = add_partitioning_values(fragment_table, self.fragment, columns)
        fragment_table = add_fragments_trace(fragment_table, self.fragment)
        return fragment_table

def _parquet_fragments_to_pipeline_builder(
    file_ds_fragments: List[pa.dataset.Fragment],
    nb_epochs: int = 1,
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> DataPipelineBuilder:
    if shuffle:
        if seed is None:
            seed = int(torch.randint(0, 2**31, ()).item())

        rsg = np.random.RandomState(seed)
        ds_fragments_ = np.asarray(file_ds_fragments, dtype="O")
        ds_fragments = np.concatenate(
            [rsg.permutation(ds_fragments_) for _ in range(nb_epochs)]
        ).tolist()
    else:
        ds_fragments = file_ds_fragments * nb_epochs

    pipeline_builder = read_sequence(ds_fragments)
    pipeline_builder = pipeline_builder.map(SafeFragment)
    return pipeline_builder


def list_parquet_fragments(
    parquet_ds: pq.ParquetDataset,
    nb_epochs: int = 1,
    split_to_row_groups: bool = True,
    shuffle: bool = True,
    seed: Optional[int] = None,
    limit_options: Optional[ParquetDatasetLimitOptions] = None,
    nb_jobs: int = 10,
) -> DataPipelineBuilder:
    if limit_options is None:
        limit_options = ParquetDatasetLimitOptions()

    file_ds_fragments = get_dataset_fragments(parquet_ds, parquet_ds._filter_expression)
    proxy_ds_path = "/".join(parquet_ds.files[0].split("=")[0].split("/")[:-1])

    logger.info(f"{proxy_ds_path} : full number of files {len(file_ds_fragments)}")
    if limit_options.fraction_of_files is not None:
        file_ds_fragments = file_ds_fragments[
            : max(
                int(round(limit_options.fraction_of_files * len(file_ds_fragments))), 1
            )
        ]
        logger.info(
            f"{proxy_ds_path} : reducing number of files to {len(file_ds_fragments)} because of fraction_of_files={limit_options.fraction_of_files}"
        )
    if limit_options.nb_files is not None and limit_options.nb_files < len(
        file_ds_fragments
    ):
        file_ds_fragments = file_ds_fragments[: limit_options.nb_files]
        logger.info(
            f"{proxy_ds_path} : reducing number of files to {len(file_ds_fragments)} because of nb_files={limit_options.nb_files}"
        )

    output_fragments = []
    total_nb_rows = 0
    if split_to_row_groups:
        logger.info(f"{proxy_ds_path} : starting split in row groups")

        with Parallel(backend="threading", n_jobs=nb_jobs) as parallel:
            total_nb_fragments = 0
            early_stop = False

            for batch_of_files in batched(file_ds_fragments, 20 * nb_jobs):
                row_groups = parallel(
                    delayed(split_fragment_in_row_groups)(ff) for ff in batch_of_files
                )
                new_file_fragments = [x for y in row_groups for x in y]
                if limit_options.nb_rows is not None:
                    new_file_fragments_stats = parallel(
                        delayed(lambda frag: frag.row_groups[0].num_rows)(ff)
                        for ff in new_file_fragments
                    )
                else:
                    new_file_fragments_stats = [0] * len(new_file_fragments)

                for nb_row, frag in zip(new_file_fragments_stats, new_file_fragments):
                    output_fragments.append(frag)
                    total_nb_rows += nb_row
                    total_nb_fragments += 1
                    if (
                        limit_options.nb_fragments is not None
                        and total_nb_fragments >= limit_options.nb_fragments
                    ):
                        early_stop = True
                        if limit_options.nb_rows is not None:
                            logger.info(
                                f"{proxy_ds_path} : nb_fragments limit {limit_options.nb_fragments} was reached with around {total_nb_rows} rows"
                            )
                        else:
                            logger.info(
                                f"{proxy_ds_path} : nb_fragments limit {limit_options.nb_fragments} was reached"
                            )
                        break
                    if (
                        limit_options.nb_rows is not None
                        and total_nb_rows >= limit_options.nb_rows
                    ):
                        early_stop = True
                        logger.info(
                            f"{proxy_ds_path} : nb_rows limit {limit_options.nb_rows} was reached with around {total_nb_fragments} fragments"
                        )
                        break
                if early_stop:
                    break
    else:
        for frag in file_ds_fragments[: limit_options.nb_fragments]:
            output_fragments.append(frag)
            if limit_options.nb_rows is not None:
                total_nb_rows += frag.count_rows()
                if total_nb_rows >= limit_options.nb_rows:
                    break

    logger.info(f"{proxy_ds_path} : found {len(output_fragments)} fragments")

    return _parquet_fragments_to_pipeline_builder(
        output_fragments,
        nb_epochs=nb_epochs,
        shuffle=shuffle,
        seed=seed,
    )

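A hedged usage sketch: building a fragment pipeline from a local dataset and loading each fragment. The dataset path is hypothetical, and the `ParquetDatasetLimitOptions` fields follow their usage above:

    import pyarrow.parquet as pq

    ds = pq.ParquetDataset("data/my_dataset/")  # hypothetical path
    builder = list_parquet_fragments(
        ds,
        nb_epochs=1,
        shuffle=True,
        seed=0,
        limit_options=ParquetDatasetLimitOptions(nb_fragments=100),
    )
    pipeline = builder.map(lambda frag: frag.load()).and_return()
    for table in pipeline:
        ...  # each `table` is one row group as a pa.Table, with trace columns added
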
def compute_length_splits(
    length_col: NDArray[np.int32],
    max_tokens: int,
    order_by_length: bool = True,
    drop_long_sample: bool = True,
) -> List[NDArray[np.int32]]:
    """Split the sequence `length_col` into chunks such that the total length of each chunk
    is ~ `max_tokens`, counting the padding to the max length of the elements in a chunk.

    Args:
        length_col (np.ndarray):
        max_tokens (int):
        order_by_length (bool):
        drop_long_sample (bool):

    Returns:
        List[np.ndarray]: splits that contain indices over the original length_col
    """
    argsort_ind = (
        np.argsort(length_col)
        if order_by_length
        else np.arange(len(length_col), dtype=np.int32)
    )

    sorted_length_col = length_col[argsort_ind]

    small_elements_masks = sorted_length_col <= max_tokens
    big_elements_inds = argsort_ind[~small_elements_masks]

    argsort_ind = argsort_ind[small_elements_masks]
    sorted_length_col = sorted_length_col[small_elements_masks]

    size = len(sorted_length_col)
    splits = []
    begin, end = 0, 0
    while end < size:
        current_max_len = sorted_length_col[begin]
        begin = end
        while end < size:
            current_max_len = max(current_max_len, sorted_length_col[end])
            if current_max_len * (end + 1 - begin) > max_tokens:
                splits.append(argsort_ind[begin:end])
                break
            end += 1
        else:
            if begin < size:
                splits.append(argsort_ind[begin:])

    # adding big samples at the end, one by one
    if not drop_long_sample and len(big_elements_inds):
        splits.extend(np.array_split(big_elements_inds, len(big_elements_inds)))

    return splits

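A small worked example of the bucketing above (as I read the code; the returned indices refer back into the original `length_col`):

    import numpy as np

    lengths = np.array([5, 3, 9, 2], dtype=np.int32)
    splits = compute_length_splits(lengths, max_tokens=8)
    # The sample of length 9 exceeds max_tokens and is dropped (drop_long_sample=True).
    # The remaining lengths [2, 3, 5] are bucketed so that
    # max_len_in_bucket * bucket_size <= 8:
    # [s.tolist() for s in splits] -> [[3, 1], [0]], i.e. lengths [2, 3] and [5]
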
def build_batching_loop_over_one_table(
    table: pa.Table,
    order_by_length: bool = False,
    length_column: Optional[List[str]] = None,
    batch_size: Optional[int] = None,
    max_tokens: Optional[int] = None,
    shuffle: bool = True,
    seed: Optional[int] = None,
    num_parallel_calls: int = 1,
) -> DataPipeline:
    if max_tokens is not None:
        assert length_column is not None, (
            "Need to provide a column to compute the number of tokens"
        )

    random_state = np.random.RandomState(seed)
    if length_column is not None and len(length_column) > 0:
        length_col = reduce(
            np.add, (compute_rows_length(table[lc]) for lc in length_column)
        )
    else:
        if shuffle:
            length_col = random_state.randint(0, 2**23, len(table))
        else:
            length_col = np.zeros(len(table), dtype=np.int32)

    if batch_size is not None:
        if order_by_length:
            sorting_ind = np.argsort(length_col, kind="stable")
        else:
            sorting_ind = np.arange(len(length_col), dtype=np.int32)

        order_tt = pa.Table.from_arrays([pa.array(sorting_ind)], ["order"])
        batches = [ind["order"] for ind in order_tt.to_batches(batch_size)]
    elif max_tokens is not None:
        batches = compute_length_splits(
            length_col, max_tokens, order_by_length=order_by_length
        )
    else:
        raise ValueError("unknown batching method")

    if shuffle:
        batches = [batches[i] for i in random_state.permutation(len(batches))]

    def _getter(ind):
        try:
            tt = table.take(ind)
            return tt
        except Exception as e:
            logger.warning(f"Unexpected error : \n {str(e)} \n {table} \n {ind}")
            return None

    return (
        read_sequence(batches)
        .map(_getter, num_parallel_calls=num_parallel_calls)
        .filter(lambda tt: bool(tt is not None))
        .and_return(max_num_warnings=4)
    )

def filter_long_short_sentence_document(
    batch: pa.Table,
    column: str,
    max_sentence_len: Optional[int],
    min_sentence_len: Optional[int],
) -> pa.Table:
    assert max_sentence_len is not None or min_sentence_len is not None
    if min_sentence_len is None:
        min_sentence_len = 0

    if max_sentence_len is None:
        max_sentence_len = 2**32

    tt = pl.from_arrow(batch.select([column]), rechunk=False)
    assert isinstance(tt, pl.DataFrame)
    filter_ = tt.with_columns(
        (
            pl.col(column).list.eval(pl.col("").str.len_bytes()).list.max()
            <= max_sentence_len
        )
        & (
            pl.col(column).list.eval(pl.col("").str.len_bytes()).list.min()
            >= min_sentence_len
        )
    )[column].to_arrow()

    if pc.all(filter_).as_py():
        return batch
    return batch.filter(filter_)


def filter_document_by_quality(
    batch: pa.Table,
    column: str,
    min_score: Optional[float] = None,
    max_score: Optional[float] = None,
) -> pa.Table:
    if min_score is None and max_score is None:
        return batch

    if min_score is None:
        min_score = -float(np.inf)
    if max_score is None:
        max_score = float(np.inf)

    tt = pl.from_arrow(batch.select([column]), rechunk=False)
    assert isinstance(tt, pl.DataFrame)
    filter_ = tt.with_columns(
        (pl.col(column).list.max() <= max_score)
        & (pl.col(column).list.min() >= min_score)
    )[column].to_arrow()
    if pc.all(filter_).as_py():
        return batch
    return batch.filter(filter_)

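Both filters above use the same polars pattern: evaluate a per-row predicate over a list column, then filter the Arrow batch with the resulting boolean mask. A self-contained toy version (the column name is made up):

    import polars as pl
    import pyarrow as pa

    batch = pa.table({"sents": [["hi", "there"], ["a", "very-long-sentence"]]})
    tt = pl.from_arrow(batch.select(["sents"]), rechunk=False)

    # Keep rows where every sentence is between 2 and 10 bytes long.
    mask = tt.with_columns(
        (pl.col("sents").list.eval(pl.col("").str.len_bytes()).list.max() <= 10)
        & (pl.col("sents").list.eval(pl.col("").str.len_bytes()).list.min() >= 2)
    )["sents"].to_arrow()
    filtered = batch.filter(mask)  # only the first row survives
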
def renaming(inp: NestedDict, mapper: dict, name: str) -> NestedDict:
    renamed_name = ColumnsNames.dataset_name.value
    if isinstance(inp, dict):
        out_dict = {mapper.get(key, key): value for key, value in inp.items()}
        out_dict[renamed_name] = name
        res = out_dict
    elif isinstance(inp, pd.DataFrame):
        out_pd = inp.rename(mapper=mapper, axis=1)
        out_pd[renamed_name] = name
        res = out_pd
    elif isinstance(inp, pa.Table):
        out_pa: pa.Table = inp.rename_columns(
            [mapper.get(key, key) for key in inp.column_names],
        )
        out_pa = out_pa.append_column(renamed_name, pa.array([name] * len(out_pa)))
        res = out_pa
    else:
        raise TypeError(f"Unsupported batch type for renaming: {type(inp)}")
    return res

def materialize_sequence(
    table: pa.Table,
    column_sequence: List[SonarTextColumn],
    vector_name: str,
    text_name: str,
) -> pa.Table:
    """
    Given `table`, it materializes `column_sequence`.
    Different elements from `column_sequence` are concatenated sequentially.
    Constant text elements will be sentencized and sonarized.
    It also accepts columns with single text and embedding values instead of lists.

    It returns a new table with two new columns containing the sequences of sentences
    and the corresponding sequences of their embeddings.
    """

    table_len = len(table)
    sentences_seq = []
    vectors_seq = []

    target_dtype = None
    for col in column_sequence:
        if col.sonar_column is not None:
            target_dtype = table[col.sonar_column].type
            break

    for col in column_sequence:
        if col.text_value is not None:
            vectors, sentences = _get_embed_sentences(col.text_value)
            vectors_extended = pa.chunked_array(
                [pa.ListArray.from_arrays([0, len(vectors)], vectors)] * table_len
            )
            sentences_extended = pa.chunked_array(
                [pa.ListArray.from_arrays([0, len(sentences)], sentences)] * table_len
            )
        else:
            assert (col.text_column is not None) and (col.sonar_column is not None)
            vectors_extended = table[col.sonar_column]
            sentences_extended = table[col.text_column]
            if is_list_like(vectors_extended):
                assert is_list_like(sentences_extended)
            else:
                vectors_extended = simple_array_to_nested(vectors_extended)
                sentences_extended = simple_array_to_nested(sentences_extended)

        if target_dtype and vectors_extended.type != target_dtype:
            vectors_extended = vectors_extended.cast(target_dtype)

        vectors_seq.append(vectors_extended)
        sentences_seq.append(sentences_extended)

    new_vectors_array = hstack_pyarray_list(*vectors_seq)
    new_sentences_array = hstack_pyarray_list(*sentences_seq)
    del vectors_seq, sentences_seq
    table = table.append_column(vector_name, new_vectors_array)
    table = table.append_column(text_name, new_sentences_array)
    return table

@njit
def _get_hierarchical_indices_and_offsets(
    paragraphs_lengths: List[np.ndarray], max_seq_len: int
):
    indices = []
    new_lens = [0]
    hierarchy_new_lens = [0]

    for i, current_lens in enumerate(paragraphs_lengths):
        tmp_lens_sum = 0
        nb_blocks = 0
        for ll in current_lens:
            if ll + tmp_lens_sum > max_seq_len:
                indices.append(i)
                new_lens.append(new_lens[-1] + tmp_lens_sum)
                hierarchy_new_lens.append(hierarchy_new_lens[-1] + nb_blocks)

                tmp_lens_sum = ll
                nb_blocks = 0
            else:
                tmp_lens_sum += ll

            nb_blocks += 1

        if nb_blocks > 0:
            indices.append(i)
            new_lens.append(new_lens[-1] + tmp_lens_sum)
            hierarchy_new_lens.append(hierarchy_new_lens[-1] + nb_blocks)

    return (
        np.array(indices, dtype=np.int32),
        np.array(new_lens, dtype=np.int32),
        np.array(hierarchy_new_lens, dtype=np.int32),
    )

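A worked example of the packing above, as I read it: each document's paragraph lengths are greedily grouped into chunks of at most `max_seq_len` sentences, and the function returns the source row index per chunk plus cumulative sentence and paragraph offsets (assuming the jitted function compiles as written, or the plain-Python fallback is active):

    import numpy as np

    # One document with paragraphs of 3, 4 and 5 sentences, packed to chunks of <= 8.
    idx, sent_offsets, par_offsets = _get_hierarchical_indices_and_offsets(
        [np.array([3, 4, 5])], max_seq_len=8
    )
    # idx          -> [0, 0]      two chunks, both from row 0
    # sent_offsets -> [0, 7, 12]  chunk 1 has 7 sentences, chunk 2 has 5
    # par_offsets  -> [0, 2, 3]   chunk 1 has 2 paragraphs, chunk 2 has 1
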
def hierarchical_explode_table_with_max_length(
    table: pa.Table,
    columns: Union[str, List[str]],
    max_seq_len: int,
    page_len_column: str,
    page_embs_columns: Optional[Union[str, List[str]]],
) -> pa.Table:
    if isinstance(columns, str):
        columns = [columns]

    if isinstance(page_embs_columns, str):
        page_embs_columns = [page_embs_columns]
    elif page_embs_columns is None:
        page_embs_columns = []

    assert len(columns) > 0

    cols = [pc.fill_null(table[columns[0]], [None])]
    lengths = pc.list_value_length(cols[0]).to_numpy()

    for name in columns[1:]:
        col = pc.fill_null(table[name], [None])
        # checking that all columns' list structures are parallel
        assert (lengths == pc.list_value_length(col).to_numpy()).all()
        cols.append(col)

    paragraphs_lengths = table[page_len_column].to_pandas().to_list()
    # assert [x.sum() for x in paragraphs_lengths] == lengths.tolist()
    # next, unroll with max_seq_len
    indices, new_offsets, hierarchy_offsets = _get_hierarchical_indices_and_offsets(
        paragraphs_lengths, max_seq_len
    )

    other_columns = list(table.schema.names)
    for name in set(columns + [page_len_column] + page_embs_columns):
        other_columns.remove(name)

    remaining_table = table.select(other_columns).take(indices)

    result_dict = {}
    for name in other_columns:
        result_dict[name] = remaining_table[name]

    for name, col in zip(columns, cols):
        rolled_array = pa.ListArray.from_arrays(
            offsets=new_offsets,
            values=pyarrow_column_to_array(pc.list_flatten(col)),
        )
        result_dict[name] = rolled_array

    for name in set([page_len_column] + page_embs_columns):
        col = table[name]
        rolled_array = pa.ListArray.from_arrays(
            offsets=hierarchy_offsets,
            values=pyarrow_column_to_array(pc.list_flatten(col)),
        )
        result_dict[name] = rolled_array

    return pa.Table.from_pydict(result_dict, schema=table.schema)

def filter_table_with_different_lengths(
    table: pa.Table, columns: List[str]
) -> pa.Table:
    if len(columns) <= 1 or not all(is_list_like(table[col]) for col in columns):
        return table

    ref_lengths = pc.list_value_length(table[columns[0]])
    for col in columns[1:]:
        same_lens = pc.equal(pc.list_value_length(table[col]), ref_lengths)
        if pc.all(same_lens).as_py():
            continue
        else:
            logger.warning(
                f"filtering rows whose numbers of sentences and of SONAR vectors are not aligned, keeping {pc.sum(same_lens).as_py()} rows out of {len(table)}"
            )
            table = table.filter(same_lens)
    return table

@dataclass
class PFSState:
    nb_fully_read_files: int = 0
    nb_current_file_read_fragements: int = 0
    total_nb_fragments: int = 0
    total_nb_rows: int = 0


class ParquetFragmentStreamer:
    def __init__(
        self,
        parquet_ds: pq.ParquetDataset,
        split_to_row_groups: bool = True,
        limit_options: Optional[ParquetDatasetLimitOptions] = None,
        read_state: Optional[PFSState] = None,
    ):
        self.split_to_row_groups = split_to_row_groups
        self.limit_options = limit_options or ParquetDatasetLimitOptions()
        self.parquet_ds = parquet_ds

        if read_state is not None:
            self.state = read_state
        else:
            self.reset_state()

    def reset_state(self):
        self.state = PFSState()

    def __reduce__(self):
        return (
            self.__class__,
            (
                self.parquet_ds,
                self.split_to_row_groups,
                self.limit_options,
                self.state,
            ),
        )

    def truncate_files(
        self,
        parquet_ds: pq.ParquetDataset,
        fraction_of_files: Optional[float],
        nb_files: Optional[int],
    ) -> List[pa.dataset.Fragment]:
        file_ds_fragments = get_dataset_fragments(
            parquet_ds, parquet_ds._filter_expression
        )
        self.proxy_ds_path = "/".join(parquet_ds.files[0].split("=")[0].split("/")[:-1])
        logger.info(
            f"{self.proxy_ds_path} : full number of files {len(file_ds_fragments)}"
        )

        if fraction_of_files is not None:
            file_ds_fragments = file_ds_fragments[
                : max(
                    int(round(fraction_of_files * len(file_ds_fragments))),
                    1,
                )
            ]
            logger.info(
                f"{self.proxy_ds_path} : reducing number of files to {len(file_ds_fragments)} because of fraction_of_files={fraction_of_files}"
            )
        if nb_files is not None and nb_files < len(file_ds_fragments):
            file_ds_fragments = file_ds_fragments[:nb_files]
            logger.info(
                f"{self.proxy_ds_path} : reducing number of files to {len(file_ds_fragments)} because of nb_files={nb_files}"
            )
        return file_ds_fragments

    def __iter__(self):
        limit_options = self.limit_options

        file_ds_fragments = self.truncate_files(
            self.parquet_ds,
            limit_options.fraction_of_files,
            limit_options.nb_files,
        )

        if not self.split_to_row_groups:
            for frag in file_ds_fragments[
                self.state.nb_fully_read_files : limit_options.nb_fragments
            ]:
                self.state.nb_fully_read_files += 1
                yield frag

                if limit_options.nb_rows is not None:
                    self.state.total_nb_rows += frag.count_rows()
                    if self.state.total_nb_rows >= limit_options.nb_rows:
                        break
        else:
            early_stop = False
            logger.info(f"{self.proxy_ds_path} : starting split in row groups")

            for new_file in file_ds_fragments[self.state.nb_fully_read_files :]:
                new_file_fragments = split_fragment_in_row_groups(new_file)
                new_file_fragments = new_file_fragments[
                    self.state.nb_current_file_read_fragements :
                ]
                if limit_options.nb_rows is not None:
                    new_file_fragments_stats = [
                        frag.row_groups[0].num_rows for frag in new_file_fragments
                    ]
                else:
                    new_file_fragments_stats = [0] * len(new_file_fragments)

                for nb_row, frag in zip(new_file_fragments_stats, new_file_fragments):
                    self.state.total_nb_rows += nb_row
                    self.state.total_nb_fragments += 1
                    self.state.nb_current_file_read_fragements += (
                        1  # increment before yield
                    )
                    yield frag

                    if (
                        limit_options.nb_fragments is not None
                        and self.state.total_nb_fragments >= limit_options.nb_fragments
                    ):
                        early_stop = True
                        if limit_options.nb_rows is not None:
                            logger.info(
                                f"{self.proxy_ds_path} : nb_fragments limit {limit_options.nb_fragments} was reached with around {self.state.total_nb_rows} rows"
                            )
                        else:
                            logger.info(
                                f"{self.proxy_ds_path} : nb_fragments limit {limit_options.nb_fragments} was reached"
                            )
                        break
                    if (
                        limit_options.nb_rows is not None
                        and self.state.total_nb_rows >= limit_options.nb_rows
                    ):
                        early_stop = True
                        logger.info(
                            f"{self.proxy_ds_path} : nb_rows limit {limit_options.nb_rows} was reached with around {self.state.total_nb_fragments} fragments"
                        )
                        break
                if early_stop:
                    break
                # incremented only once the full file has been read
                self.state.nb_fully_read_files += 1
                self.state.nb_current_file_read_fragements = 0

@dataclass
class ShuffledIteratorState:
    epoch_count: int
    current_window: List[Any]
    index: int
    random_state: np.random.RandomState


class ShuffledIterator(Iterator[Any]):
    def __init__(
        self,
        iterator,
        window_size: int,
        nb_epoch: int,
        seed: Optional[int],
        state: Optional[ShuffledIteratorState] = None,
    ):
        self.base_iterator = iterator
        self.window_size = window_size
        self.seed = seed
        self.nb_epoch = nb_epoch

        if state is None:
            state = ShuffledIteratorState(
                random_state=np.random.RandomState(self.seed),
                epoch_count=0,
                current_window=[],
                index=0,
            )
        self.state = state
        self.window_iterator = None

    def reset_state(self):
        self.state.random_state = np.random.RandomState(self.seed)
        self.state.epoch_count = 0
        self._reset_inner()

    def __reduce__(self):
        return (
            self.__class__,
            (
                self.base_iterator,
                self.window_size,
                self.nb_epoch,
                self.seed,
                self.state,
            ),
        )

    def _reset_inner(self):
        self.base_iterator.reset_state()
        self.state.index = 0
        self.state.current_window = []
        self.window_iterator = None

    def __iter__(self):
        return self

    def __next__(self) -> Any:
        if self.state.epoch_count >= self.nb_epoch:
            raise StopIteration

        # If the current window is exhausted, fetch the next window
        if self.window_iterator is None:
            self.window_iterator = batched(self.base_iterator, self.window_size)  # type: ignore
        assert self.window_iterator is not None

        if self.state.index >= len(self.state.current_window):
            try:
                # Get the next window batch
                window = next(self.window_iterator)
                window = np.array(window, dtype="O")
                self.state.random_state.shuffle(window)
                self.state.current_window = window
                self.state.index = 0
            except StopIteration:
                # If there are no more batches, increment the epoch count and reset the iterator
                self.state.epoch_count += 1
                self._reset_inner()
                return self.__next__()

        # Return the next element from the current window
        result = self.state.current_window[self.state.index]
        self.state.index += 1
        return result

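A minimal sketch of `ShuffledIterator` over a restartable source. The toy `Range` class below is hypothetical; any object with a `reset_state()` hook and iterator protocol should work, since the wrapper only consumes it through `batched` and calls `reset_state()` between epochs:

    class Range:
        """Toy restartable iterator with the reset_state() hook the wrapper expects."""
        def __init__(self, n):
            self.n, self.i = n, 0
        def reset_state(self):
            self.i = 0
        def __iter__(self):
            return self
        def __next__(self):
            if self.i >= self.n:
                raise StopIteration
            self.i += 1
            return self.i - 1

    it = ShuffledIterator(Range(6), window_size=3, nb_epoch=2, seed=0)
    items = list(it)  # 12 items: 2 epochs, shuffled within windows of 3
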
def stream_parquet_fragments(
    parquet_ds: pq.ParquetDataset,
    nb_epochs: int,
    split_to_row_groups: bool = True,
    shuffle: bool = True,
    seed: Optional[int] = None,
    limit_options: Optional[ParquetDatasetLimitOptions] = None,
    shuffling_window: int = 200,
) -> DataPipelineBuilder:
    fragments_iterator = ParquetFragmentStreamer(
        parquet_ds=parquet_ds,
        split_to_row_groups=split_to_row_groups,
        limit_options=limit_options,
    )

    def reset_fn(iterator):
        iterator.reset_state()
        return iterator

    pipeline = read_iterator(
        ShuffledIterator(
            fragments_iterator,
            window_size=shuffling_window if shuffle else 1,
            nb_epoch=nb_epochs,
            seed=seed,
        ),
        reset_fn,
        infinite=False,
    )

    return pipeline.map(SafeFragment)

def get_row_group_level_metadata(
    dataset: pq.ParquetDataset,
    columns: Optional[List[str]] = None,
    nb_jobs: int = 40,
    max_fragments: int = -1,
    seed: int = 123,
) -> pd.DataFrame:
    """
    Parses row-group-level metadata from a Parquet dataset and returns it as a pandas DataFrame.
    It's similar to `get_parquet_dataset_metadata`,
    but presents an unnested view of row group statistics for only a subset of columns.
    This function can be used for any kind of downstream analysis.

    It uses joblib for parallel processing
    and tqdm for progress tracking, which are good practices for handling large datasets.

    Parameters:
    - dataset (pq.ParquetDataset): The Parquet dataset to parse.
    - columns (list of str, optional): The columns to include in the output DataFrame. If not specified, all columns are included.
      For `columns=[]` no column-wise information will be provided (which is generally much faster).
    - nb_jobs (int, default=40): The number of parallel jobs to run.
    - max_fragments (int, default=-1): The maximum number of fragments to include. If -1, all fragments are included.
    - seed (int, default=123): The seed for the random number generator, used when selecting fragments.

    Returns:
    - pd.DataFrame: A DataFrame containing the row group level metadata.
    Example:
        >>> import pyarrow as pa
        >>> import pyarrow.fs
        >>> import pyarrow.compute as pc
        >>> fs, parquet_uri = pa.fs.FileSystem.from_uri("s3://<bucket_name>/<dataset_name>/")
        >>> dataset = pq.ParquetDataset(parquet_uri, filesystem=fs, filters=pc.equal(pc.field("split"), "validation"))
        >>> df_stats = get_row_group_level_metadata(dataset, columns=["col1", "col2", ...])
    """
    assert max_fragments >= -1
    fragments = list(dataset._dataset.get_fragments(filter=dataset._filter_expression))

    if max_fragments != -1 and max_fragments < len(fragments):
        fragments = (
            np.random.RandomState(seed)
            .choice(np.array(fragments, dtype="O"), max_fragments, replace=False)
            .tolist()
        )

    physical_schema = fragments[0].physical_schema

    columns = columns if columns is not None else physical_schema.names
    # taking only existing columns
    non_existing_columns = tuple(set(columns) - set(physical_schema.names))
    if non_existing_columns:
        print(
            "The following columns are not present in the physical schema and will be ignored:",
            non_existing_columns,
        )
        columns = [col for col in columns if col in physical_schema.names]

    columns_index = [physical_schema.get_field_index(col) for col in columns]

    columns_to_exclude = set(["row_group_id", "num_rows", "total_byte_size"]) & set(
        columns
    )
    assert len(columns_to_exclude) == 0, (
        f"names conflict, rename/remove : {columns_to_exclude}"
    )

    def get_one_row_group_stats(row_group):
        metadata = row_group.metadata
        info = {
            "row_group_id": row_group.id,
            "num_rows": metadata.num_rows,
            "total_byte_size": metadata.total_byte_size,
        }
        for col, ind in zip(columns, columns_index):
            info[col] = metadata.column(ind).to_dict()
        return info

    def get_fragment_stats(frag):
        return {
            "rg_stats": list(map(get_one_row_group_stats, frag.row_groups)),
            "parquet_file_path": frag.path,
            **get_partition_keys(frag.partition_expression),
        }

    stats = Parallel(nb_jobs, backend="threading")(
        delayed(get_fragment_stats)(frag) for frag in tqdm(fragments)
    )

    stats = pd.DataFrame(stats).explode("rg_stats")
    flatten_row_df = pd.DataFrame(stats.pop("rg_stats").tolist(), index=stats.index)
    result_df = pd.concat([stats, flatten_row_df], axis=1)
    return result_df
lcm/datasets/sentence_splitter_pipeline.py
ADDED
@@ -0,0 +1,351 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import gc
import typing as tp
from builtins import enumerate
from dataclasses import dataclass, field

import numba
import numpy as np
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import torch
from stopes.modules.partitioned_data_mapper import BatchMapper
from stopes.modules.preprocess.sonar_text_embedding import (
    SonarTextBatchEmbedder,
    SonarTextEmbedderConfig,
)
from stopes.utils.arrow_utils import (
    apply_on_nested_array,
)
from wtpsplit import SaT, indices_to_sentences

from lcm.datasets.sentence_splitting import remove_emojis, resplit


@numba.jit(nopython=True)
def insert_elements(arr, max_diff):
    """
    Insert elements into an array to ensure no two consecutive elements have a difference greater than max_diff.

    Parameters:
        arr (numpy array): The original array of integers.
        max_diff (int): The maximum allowed difference between consecutive elements after insertion.

    Returns:
        numpy array: The modified array with additional elements inserted to satisfy the max_diff condition.
    """

    result = []
    for i in range(len(arr) - 1):
        result.append(arr[i])
        diff = arr[i + 1] - arr[i]
        if diff > max_diff:
            num_insert = int(diff // max_diff)
            step_size = diff / (num_insert + 1)
            last_val = arr[i]
            for j in range(1, num_insert + 1):
                val = round(last_val + step_size)
                if val < arr[i + 1]:
                    result.append(val)
                    last_val = val
    result.append(arr[-1])
    return np.array(result, dtype=np.int32)


@numba.jit(nopython=True)
def merge_small_intervals(
    lengths: np.ndarray, min_merging_length: int = 2, max_merge_length: int = 15
):
    """
    Merge small intervals in a list of lengths.
    This function takes a list of lengths and merges any intervals that are smaller than or equal to `min_merging_length`
    into larger intervals. The merged intervals are limited to a maximum length of `max_merge_length`.
    Parameters:
        lengths (np.ndarray): A list of lengths to be merged.
        min_merging_length (int): The minimum length of an interval to be merged. Defaults to 2.
        max_merge_length (int): The maximum length of a merged interval. Defaults to 15.
    Returns:
        list: A list of merged lengths.

    Examples:
        >>> merge_small_intervals(np.array([1, 2, 3, 4, 5]))
        array([3, 3, 4, 5], dtype=int32)
        >>> merge_small_intervals(np.array([1, 1, 1, 1, 1]))
        array([5], dtype=int32)
        >>> merge_small_intervals(np.array([1, 2, 3, 2, 2, 2, 4, 1, 1, 5]))
        array([3, 3, 6, 4, 2, 5], dtype=int32)
    """
    merge_arr = []
    merge_len = 0

    for curr_len in lengths:
        if curr_len <= min_merging_length and merge_len + curr_len <= max_merge_length:
            merge_len += curr_len
        else:
            if merge_len > 0:
                merge_arr.append(merge_len)
                merge_len = 0
            merge_arr.append(curr_len)
    if merge_len > 0:
        merge_arr.append(merge_len)

    return np.array(merge_arr, dtype=np.int32)


@numba.jit(nopython=True)
def find_closest_indices(arr1, arr2):
    """
    Find indices of the closest elements in arr2 for each element in arr1.

    Parameters:
        arr1 (numpy array): The array containing the elements for which we want to find the closest elements in arr2.
        arr2 (numpy array): The array in which we want to find the closest elements.

    Returns:
        indices (numpy array): The indices of the closest elements in arr2 for each element in arr1.
    """
    # Use searchsorted to find the indices where elements from arr1 should be inserted in arr2
    indices = np.searchsorted(arr2, arr1, side="left")
    # clamp to the valid range: searchsorted returns len(arr2) when an element
    # of arr1 is greater than every element of arr2
    indices = np.minimum(indices, len(arr2) - 1)

    indices_bis = np.clip(indices - 1, a_min=0, a_max=len(arr2) - 1)
    dist_one = np.abs(arr2[indices] - arr1)
    dist_bis = np.abs(arr2[indices_bis] - arr1)

    return np.where(dist_one < dist_bis, indices, indices_bis)


@dataclass
class SentenceSplitterConfig:
    columns: tp.List[str]
    model_name: str = "sat-6l"
    sentence_suffix: str = "_sentences"
    sentence_threshold: float = 0.01
    max_sentence_len: int = 256
    min_text_length: int = 10
    min_unique_chars: int = 0
    fallback_separators: tp.List[str] = field(
        default_factory=lambda: [
            "...",
            "\n",
            "!",
            "?",
            ";",
            ":",
            ".",
            ",",
            "\t",
            " ",
        ]
    )
    device: str = "cuda"
    remove_whitespace_before_inference: bool = False
    batch_size: int = 256
    block_size: int = 256
    stride: int = 256
    outer_batch_size: int = 1024
    verbose: bool = False
    pad_last_batch: bool = False


class SentenceSplitter(BatchMapper):
    def __init__(self, config: SentenceSplitterConfig):
        super().__init__(config)
        self.columns = config.columns
        device = torch.device(config.device if torch.cuda.is_available() else "cpu")

        try:
            self.model = SaT(
                self.config.model_name,
                from_pretrained_kwargs={"local_files_only": True},
            )
        except Exception:
            self.model = SaT(self.config.model_name)

        if "cuda" in config.device:
            self.model.half()

        self.model.eval().to(device)

    @torch.inference_mode()
    def _resplit_long_sentences(self, col: pa.Array) -> pa.Array:
        mask = pc.greater_equal(pc.utf8_length(col), self.config.max_sentence_len)
        texts_to_resplit = col.filter(mask).to_pandas().to_list()

        resplit_sentences = []
        for text, probs in zip(
            texts_to_resplit,
            self.model.predict_proba(
                texts_to_resplit,
                stride=self.config.stride,
                block_size=self.config.block_size,
                batch_size=self.config.batch_size,
                pad_last_batch=self.config.pad_last_batch,
                remove_whitespace_before_inference=self.config.remove_whitespace_before_inference,
                outer_batch_size=self.config.outer_batch_size,
                verbose=self.config.verbose,
            ),
        ):
            nb_split = round(len(probs) / self.config.max_sentence_len) + 1
            sentence_threshold = np.partition(probs, -nb_split)[-nb_split]
            sentences = indices_to_sentences(
                text,
                np.where(probs >= sentence_threshold)[0],
                strip_whitespace=False,
            )
            resplit_sentences.append(sentences)

        # fall back: hard-resplit any sentences that are still too long using the separators
        def _resplit(raw_sentences):
            for separator in self.config.fallback_separators:
                raw_sentences = [
                    subchunk.strip()
                    for sent in raw_sentences
                    for subchunk in resplit(
                        sent, max_length=self.config.max_sentence_len, sep=separator
                    )
                ]
            return raw_sentences

        np_mask = mask.to_pandas().to_numpy()
        full_text = col.to_pandas().to_list()

        output_sentences = []
        j = 0
        for i, text in enumerate(full_text):
            if np_mask[i]:
                output_sentences.append(_resplit(resplit_sentences[j]))
                j += 1
            else:
                output_sentences.append([text])

        return pa.array(output_sentences, type=pa.list_(pa.string()))

    def resplit_long_sentences(self, col: pa.Array) -> pa.Array:
        list_col = apply_on_nested_array(self._resplit_long_sentences, col)
        reflatten_col = pl.from_arrow(list_col).list.eval(pl.element().explode())  # type: ignore
        # remove sentences made of a single character repeated
        if self.config.min_unique_chars > 0:
            reflatten_col = reflatten_col.list.eval(
                pl.when(
                    pl.element().str.split("").list.n_unique()
                    > self.config.min_unique_chars
                )
                .then(pl.element())
                .drop_nulls()
            )
        return reflatten_col.to_arrow().cast(pa.list_(pa.string()))

    @torch.inference_mode()
    def basic_split_on_single_column(
        self,
        col: tp.Union[pa.Array, pa.ChunkedArray],
    ) -> tp.Union[pa.Array, pa.ChunkedArray]:
        if not (pa.types.is_large_string(col.type) or pa.types.is_string(col.type)):
            raise ValueError("Column must be of type string")

        texts = col.to_pandas().to_list()
        texts = list(map(remove_emojis, texts))

        long_texts = [t for t in texts if len(t) > self.config.min_text_length]
        keep_texts = [
            (idx, t)
            for idx, t in enumerate(texts)
            if len(t) <= self.config.min_text_length
        ]

        outputs = self.model.split(
            long_texts,
            threshold=self.config.sentence_threshold,
            stride=self.config.stride,
            block_size=self.config.block_size,
            batch_size=self.config.batch_size,
            pad_last_batch=self.config.pad_last_batch,
            remove_whitespace_before_inference=self.config.remove_whitespace_before_inference,
            outer_batch_size=self.config.outer_batch_size,
            verbose=self.config.verbose,
        )
        sentences = []
        for row in outputs:
            sentences.append([s.strip() for s in row if s.strip()])

        for idx, text in keep_texts:
            # wrap in a list so the element matches the list<string> output type
            sentences.insert(idx, [text])

        return pa.array(sentences, type=pa.list_(pa.string()))

    def __call__(self, table: pa.Table) -> pa.Table:
        for column in self.columns:
            sentence_array = self.basic_split_on_single_column(table[column])

            sentence_array = self.resplit_long_sentences(sentence_array)

            table = table.append_column(
                f"{column}{self.config.sentence_suffix}", sentence_array
            )

        return table


@dataclass
class FullPipelineConfig:
    splitter_config: SentenceSplitterConfig
    sonar_encoder_config: SonarTextEmbedderConfig
    min_text_length: int = 10


class FullPipeline(BatchMapper):
    """
    Creates SONAR vectors from scratch:
    splits documents into sentences, then computes SONAR embeddings.

    The config requires only one input column, e.g. `text`.

    Note also that the text should not be empty!

    Example config:

    splitter_config = SentenceSplitterConfig(
        columns=["text"],
        model_name="sat-3l",
        verbose=True,
        sentence_threshold=0.02,
        max_sentence_len=256,
    )
    sonar_encoder_config = SonarTextEmbedderConfig(
        column_config=[LangColumnConfig("text_sentences", lang_value="eng_Latn")],
        device="cuda",
    )

    full_config = FullPipelineConfig(
        splitter_config=splitter_config,
        sonar_encoder_config=sonar_encoder_config,
    )
    """

    def __init__(self, config: FullPipelineConfig):
        self.config = config
        self.splitter = SentenceSplitter(self.config.splitter_config)
        self.sonar_encoder = SonarTextBatchEmbedder(self.config.sonar_encoder_config)

    def __call__(self, batch: pa.Table) -> pa.Table:
        for col in self.config.splitter_config.columns:
            batch = batch.filter(
                pc.greater_equal(
                    pc.utf8_length(batch[col]), self.config.min_text_length
                )
            )

        batch = self.splitter(batch)
        batch = self.sonar_encoder(batch)
        gc.collect()
        return batch
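To make the helper semantics concrete, a small worked sketch of the numba utilities above (outputs computed by hand; requires numpy and numba at import time):

    import numpy as np

    # insert_elements caps consecutive gaps: for 0 -> 10 with max_diff=4 it
    # inserts two rounded intermediate points at a ~3.33 step
    print(insert_elements(np.array([0, 10], dtype=np.int32), 4))
    # -> [ 0  3  6 10]

    # find_closest_indices maps each query in arr1 to the index of its nearest
    # value in the *sorted* array arr2
    print(find_closest_indices(np.array([2, 7]), np.array([0, 5, 10])))
    # -> [0 1]  (2 is closest to 0, 7 is closest to 5)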
lcm/datasets/sentence_splitting.py ADDED
@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#


import codecs
import re
import typing as tp
from functools import lru_cache

import spacy
import torch
from sacremoses import MosesDetokenizer, MosesPunctNormalizer
from stopes.pipelines.monolingual.utils.sentence_split import get_split_algo
from stopes.utils.language_codes import language_code_to_short_code


def remove_emojis(text: str) -> str:
    emoji_pattern = re.compile(
        "["
        "\U0001f600-\U0001f64f"  # emoticons
        "\U0001f300-\U0001f5ff"  # symbols & pictographs
        "\U0001f680-\U0001f6ff"  # transport & map symbols
        "\U0001f1e0-\U0001f1ff"  # flags (iOS)
        "\U00002702-\U000027b0"
        "\U000024c2-\U0001f251"
        "\U0001f900-\U0001f9ff"  # Supplemental Symbols and Pictographs
        "\U0001f700-\U0001f77f"  # Alchemical Symbols
        "\U0001f780-\U0001f7ff"  # Geometric Shapes Extended
        "\U0001f800-\U0001f8ff"  # Supplemental Arrows-C
        "\U0001fa00-\U0001fa6f"  # Chess Symbols
        "\U0001fa70-\U0001faff"  # Symbols and Pictographs Extended-A
        "\U0001f6c0-\U0001f6cf"  # Miscellaneous Symbols and Pictographs (part)
        "\U0001f6d0-\U0001f6d5"  # Miscellaneous Symbols and Pictographs (part)
        "\U0001f6f0-\U0001f6fa"  # Miscellaneous Symbols and Pictographs (part)
        "]+",
        flags=re.UNICODE,
    )
    return emoji_pattern.sub(r"", text)


def batched(inputs: tp.Iterable, batch_size=10000) -> tp.Iterable:
    batch = []
    for line in inputs:
        batch.append(line)
        if len(batch) == batch_size:
            yield batch
            batch = []
    yield batch


def filter_empty_string(text):
    # True when the string contains no alphanumeric character at all
    return not any(char.isalnum() for char in text)


def remove_non_printable_chars(string):
    return re.sub(r"[^\x20-\x7E]", "", string)


def deescape_special_chars(string):
    return codecs.decode(string, "unicode_escape")


def resplit(text: str, max_length: int, sep: str) -> tp.List[str]:
    words = text.split(sep)
    result = []
    current_piece = ""

    for i, word in enumerate(words[:-1]):
        # Append the separator back to each word except the last
        word += sep
        if len(current_piece) + len(word) <= max_length:
            current_piece += word
        else:
            if current_piece:
                result.append(current_piece)
            current_piece = word

    # Handle the last word separately to avoid adding an extra separator
    last_word = words[-1]
    if len(current_piece) + len(last_word) <= max_length:
        current_piece += last_word
    else:
        if current_piece:
            result.append(current_piece)
        current_piece = last_word

    if current_piece:
        result.append(current_piece)

    return result


@lru_cache
def get_moses_normalizers(lang):
    moses_lang = language_code_to_short_code(lang, try_replacing_with_macro=True)
    mpn = MosesPunctNormalizer(lang=moses_lang)
    mpn.substitutions = [(re.compile(r), sub) for r, sub in mpn.substitutions]
    md = MosesDetokenizer(lang=moses_lang)
    return mpn, md


@lru_cache
def get_splitter(lang: str, model_name: tp.Optional[str] = None):
    moses_lang = language_code_to_short_code(lang, try_replacing_with_macro=True)
    if model_name is None:
        model_name = (
            f"{moses_lang}_core_web_sm"
            if moses_lang == "en"
            else f"{moses_lang}_core_news_sm"
        )
    try:
        if torch.cuda.is_available():
            spacy.require_gpu()
        spacy_nlp = spacy.load(model_name, enable=["sentencizer"])
        spacy_nlp.add_pipe("sentencizer")

        def spacy_splitter(text):
            # chunk the raw text into <1M-character pieces to stay under
            # spaCy's default max_length before running the sentencizer
            for batch in batched(text, batch_size=999_000):
                for sent in spacy_nlp("".join(batch)).sents:
                    yield str(sent)

        return spacy_splitter
    except ModuleNotFoundError:
        print(
            f"Spacy splitter not found for {lang}, switching to stopes implementation"
        )
        return get_split_algo(lang[:3], "default")


class ResplitSentenceSplitter:
    def __init__(
        self,
        fallback_separators=(".", "!", "?", "...", "\n", ";", ",", ":", ">", " "),
    ):
        self.fallback_separators = fallback_separators

    def __call__(
        self, document: str, lang: str = "eng_Latn", max_length: int = 200
    ) -> tp.List[str]:
        mpn, md = get_moses_normalizers(lang)
        # XXX: the two calls below are not friendly to all languages
        # document = deescape_special_chars(document)
        # document = remove_non_printable_chars(document)
        document = remove_emojis(document)

        raw_sentences = get_splitter(lang)(document)
        for separator in self.fallback_separators or []:
            raw_sentences = [
                subchunk.strip()
                for sent in raw_sentences
                for subchunk in resplit(sent, max_length=max_length, sep=separator)
            ]

        return [
            mpn.normalize(md.detokenize(sent.strip().split()))
            for sent in raw_sentences
            if len(sent) > 1 and not filter_empty_string(sent)
        ]
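A worked example of the greedy `resplit` fallback (pure Python, no model required; the input string is made up):

    chunks = resplit("a,bb,ccc", max_length=4, sep=",")
    print(chunks)
    # -> ['a,', 'bb,', 'ccc']: every chunk stays within max_length and the
    # separator is kept at the end of each non-final chunk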
lcm/datasets/utils.py ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#


import torch
from fairseq2.models.sequence import SequenceBatch


def move_eos_to_the_end(
    batch: SequenceBatch, pad_token_id: int = 0, eos_token_id: int = 3
) -> SequenceBatch:
    """
    Convert a decoder-input batch (with the EOS token at the beginning) to a decoder-output batch
    (with EOS at the end) of the same shape.
    Note that this processing misses two potentially critical issues:
    1) If the sequence end has been truncated away, an EOS token will be appended erroneously.
    2) The language code token is still included in the loss computation (we may want to avoid that).
    """
    # strip the EOS token prepended to the input and add an empty token at the end
    seqs = torch.cat(
        [
            batch.seqs[:, 1:],
            torch.zeros_like(batch.seqs[:, :1]) + pad_token_id,
        ],
        dim=-1,
    )
    # fill the last real token in the batch with the EOS value
    if batch.padding_mask:
        seqs[
            torch.arange(seqs.shape[0], dtype=torch.int32),
            batch.padding_mask.seq_lens - 1,
        ] = eos_token_id
    else:
        seqs[:, -1] = eos_token_id

    result = SequenceBatch(
        seqs=seqs,
        padding_mask=batch.padding_mask,
    )
    return result
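A small worked example of `move_eos_to_the_end`, assuming fairseq2's `SequenceBatch` with no padding mask (token ids are arbitrary):

    import torch
    from fairseq2.models.sequence import SequenceBatch

    batch = SequenceBatch(seqs=torch.tensor([[3, 5, 6, 7]]), padding_mask=None)
    out = move_eos_to_the_end(batch, pad_token_id=0, eos_token_id=3)
    print(out.seqs)
    # tensor([[5, 6, 7, 3]]): the leading EOS is stripped, a pad slot is
    # appended, and the last position is overwritten with EOS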
lcm/models/two_tower_diffusion_lcm/loader.py CHANGED
@@ -6,6 +6,7 @@
 
 from fairseq2.models.config_loader import StandardModelConfigLoader
 from fairseq2.models.loader import StandardModelLoader, load_model
+from Patches import Patch_TorchLoader
 
 from lcm.models.base_lcm.loader import convert_lcm_checkpoint
 from lcm.models.two_tower_diffusion_lcm.builder import (
@@ -23,11 +24,12 @@ load_two_tower_diffusion_lcm_config = StandardModelConfigLoader(
 )
 
 
 load_two_tower_diffusion_lcm_model = StandardModelLoader(
     config_loader=load_two_tower_diffusion_lcm_config,
     factory=create_two_tower_diffusion_lcm_model,
     checkpoint_converter=convert_lcm_checkpoint,
     restrict_checkpoints=False,
+    tensor_loader=Patch_TorchLoader.load_tensors,  # 🔥 the key patch
 )
 
 load_model.register(
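For context, a minimal sketch of what an extension-dispatching tensor loader can look like (illustrative only: the actual patched loader lives in Patches/Patch_TorchLoader.py, and the exact dict layout fairseq2 expects back is assumed here):

    from pathlib import Path
    from typing import Any, Dict, Optional

    import torch
    from safetensors.torch import load_file


    def load_tensors_sketch(path: Path, map_location: Optional[Any] = None) -> Dict[str, Any]:
        # safetensors files carry raw tensors only, so loading them never
        # executes pickled code; fall back to torch.load for legacy .pt files
        if Path(path).suffix == ".safetensors":
            return load_file(str(path))
        return torch.load(path, map_location=map_location, weights_only=True)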
lcm/train/__main__.py ADDED
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#


import asyncio
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import hydra
import submitit
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict, read_write
from stopes.core import Requirements, StopesModule

from lcm.train.common import get_trainer
from lcm.utils.common import setup_conf

setup_conf()


class TrainModule(StopesModule):
    def requirements(self) -> Requirements:
        return self.config.requirements

    def run(self, iteration_value: Optional[Any] = None, iteration_index: int = 0):
        # Add module.name to the config's log_folder
        with read_write(self.config):
            self.config.log_folder = Path(self.config.log_folder) / self.name()

        trainer = get_trainer(self.config)

        # the trainer is expected to have a run() method
        trainer.run()

    def should_retry(
        self,
        ex: Exception,
        attempt: int,
        iteration_value: Optional[Any] = None,
        iteration_index: int = 0,
    ) -> bool:
        # Before retrying the failed train run, clean the environment to make sure
        # the fs2 ProcessGroupGang can be set up properly without raising an error
        # if the gang was not set up reliably
        with submitit.helpers.clean_env():
            return "ValueError" not in str(ex)

    def name(self):
        """
        Implement this if you want to give a fancy name to your job.
        """
        name = self.config.get(
            "experiment_name", f"{self.__class__.__name__}_{self.sha_key()[:10]}"
        )
        return name


@dataclass
class TrainingConfig:
    trainer: DictConfig
    launcher: DictConfig
    dry_run: bool = False


async def run(config: TrainingConfig):
    # dump the full config to the output config log
    dump_dir = Path(config.launcher.config_dump_dir)
    dump_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.resolve(config)  # type: ignore
    # XXX: do we want to promote dataset configs from their names to the final params?
    OmegaConf.save(
        config=config,
        f=str(dump_dir / "all_config.yaml"),
    )

    train_config = config.trainer

    # If launcher.cluster == "debug", set debug in the trainer to True
    with open_dict(train_config):
        if config.launcher.cluster == "debug":
            train_config.debug = True
        train_config.log_folder = config.launcher.log_folder

    if getattr(config, "dry_run", False):
        trainer = get_trainer(train_config)
        print(f"Trainer: {trainer}")
        print(f"Train config: {getattr(trainer, 'config')}")

        return

    launcher = hydra.utils.instantiate(config.launcher)

    train_module = TrainModule(train_config)
    wait_on = launcher.schedule(train_module)

    await wait_on


@hydra.main(
    version_base="1.2",
    config_path="../../recipes/train",
    config_name="defaults.yaml",
)
def main(config: TrainingConfig) -> None:
    """
    Launch a train module from the CLI.

    Example:

    ```sh
    python -m lcm.train +pretrain=mse
    ```

    In this example, `pretrain` is a folder under the `recipes` directory and `mse`
    is a yaml file with the trainer configuration.
    This yaml file must be in the `trainer` package (i.e. start with the `# @package trainer`
    hydra directive).
    It must contain a `_trainer_` entry defining the constructor for the trainer.

    You can use `-c job` to see the configuration without running anything. You can use
    `dry_run=true` to initialize the trainer from the configuration and make sure it's correct
    without running the actual training. To debug jobs, you can use `launcher.cluster=debug`.
    """
    asyncio.run(run(config))


if __name__ == "__main__":
    main()
lcm/train/common.py ADDED
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from inspect import signature
from typing import Any, Dict, Protocol, Union, runtime_checkable

import hydra
from omegaconf import DictConfig, OmegaConf, read_write

from lcm.utils.common import promote_config

TRAINER_KEY = "_trainer_"


@runtime_checkable
class Trainer(Protocol):
    """Abstract trainer in LCM"""

    def run(self) -> Any: ...


def _parse_training_config(train_config: DictConfig):
    """Return the TrainingConfig object from the omegaconf inputs"""
    # The train_config should have the two keys "_target_" and "_trainer_";
    # the config is set to read-only within the stopes module __init__
    assert TRAINER_KEY in train_config, (
        f"The trainer configuration is missing a {TRAINER_KEY} configuration; "
        "you need to specify a Callable to initialize your config."
    )
    trainer_cls_or_func = train_config.get(TRAINER_KEY)
    try:
        trainer_obj = hydra.utils.get_object(trainer_cls_or_func)
        sign = signature(trainer_obj)
        assert len(sign.parameters) == 1 and "config" in sign.parameters, (
            f'{trainer_cls_or_func} should take a single argument called "config"'
        )
        param_type = sign.parameters["config"].annotation

        OmegaConf.resolve(train_config)
        with read_write(train_config):
            del train_config._trainer_

        typed_config = promote_config(train_config, param_type)
        return trainer_obj, typed_config
    except Exception as ex:
        raise ValueError(
            f"couldn't parse the train config: {train_config}.", str(ex)
        ) from ex


def get_trainer(train_config: DictConfig) -> Trainer:
    trainer_obj, typed_config = _parse_training_config(train_config)
    return trainer_obj(typed_config)


def _is_missing(config: Union[DictConfig, Dict], attr: str) -> bool:
    if isinstance(config, Dict):
        # missing when the key is absent or holds a falsy value
        return attr not in config or not config[attr]
    if OmegaConf.is_missing(config, attr):
        return True
    if not hasattr(config, attr) or not getattr(config, attr):
        return True
    return False
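A minimal sketch of the `_trainer_` contract that `get_trainer` enforces (a real config needs many more fields; `prepare_lcm_trainer` is defined later in this diff, and the extra key below is illustrative):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            # dotted path resolved by hydra.utils.get_object; must be a callable
            # taking a single `config` argument with a typed annotation
            "_trainer_": "lcm.train.lcm.trainer.prepare_lcm_trainer",
            "output_dir": "/tmp/run",  # remaining keys are promoted to the typed config
        }
    )
    trainer = get_trainer(cfg)  # builds the typed config, then calls the constructor
    trainer.run()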
lcm/train/criterion.py ADDED
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal

from fairseq2.logging import get_log_writer
from omegaconf import MISSING
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
)
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel as DDP

from lcm.train.metrics import LossTerm

logger = get_log_writer(__name__)


@dataclass
class CriterionConfig:
    """A dataclass for criterion parameters"""

    name: str = MISSING
    """Name of the criterion, a unique identifier used in the CriterionsFactory"""

    reduction: Literal["sum", "mean"] = "sum"
    """How to reduce the loss across samples"""


class Criterion:
    """An abstract class for training criterions"""

    def __init__(
        self,
        config: CriterionConfig,
        model: Module,
    ):
        self.config = config

        self.model = model

        self.summands: List[str] = []
        """A list of loss term names to track during training.
        A metric bag will be created for each of these.
        """

        self.reduction = config.reduction

    @property
    def throughput_metric_name(self) -> str:
        return "num_target_elements"

    @property
    def base_model(self):
        """A pointer to the unwrapped model when training with FSDP/DDP"""
        if isinstance(self.model, (DDP, FSDP)):
            _model = self.model.module
        else:
            _model = self.model
        return _model

    @abstractmethod
    def __call__(self, batch) -> LossTerm:
        """
        Computes the loss given an input batch.
        The model's forward pass is performed here.
        """


class CriterionsFactory:
    """Factory for LCM criterions"""

    registry: Dict[str, Any] = {}

    @classmethod
    def build_criterion(cls, name: str, **kwargs) -> Any:
        """Build the criterion of choice from within the trainer"""

        criterion_class = cls.registry[name]

        criterion = criterion_class(**kwargs)

        return criterion

    @classmethod
    def register(cls, name: str) -> Callable:
        """Decorator for adding criterions to the registry"""

        def inner_wrapper(wrapped_class: Criterion) -> Callable:
            assert name not in cls.registry, (
                f"{name} is already registered as a criterion"
            )
            cls.registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper
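A sketch of how the registry is meant to be used (the criterion name and model below are made-up examples, not part of this diff):

    @CriterionsFactory.register("toy_criterion")
    class ToyCriterion(Criterion):
        def __call__(self, batch) -> LossTerm:
            # a real criterion would run the model forward pass here
            ...

    criterion = CriterionsFactory.build_criterion(
        "toy_criterion",
        config=CriterionConfig(name="toy_criterion"),
        model=my_model,  # hypothetical torch.nn.Module
    )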
lcm/train/lcm/__init__.py ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
lcm/train/lcm/criterion.py ADDED
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from fairseq2.logging import get_log_writer
from torch import Tensor

from lcm.datasets.batch import LCMInput, LCMStyle
from lcm.models.abstract_lcm import AbstractLCModel
from lcm.models.sonar_normalizer import SonarNormalizer
from lcm.train.criterion import Criterion, CriterionConfig
from lcm.train.metrics import LossTerm

logger = get_log_writer(__name__)


def compute_standard_mse(
    flattened_predictions: Tensor,
    flattened_target: Tensor,
    scales: Optional[Tensor] = None,
    normalizer: Optional[SonarNormalizer] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Computes the MSE loss between predictions and targets.
    Note that, unlike regular MSE with mean/sum reduction, we first sum across channels
    before later reducing in the criterion.

    Parameters:
        flattened_predictions (Tensor): The predictions in (N, C)
        flattened_target (Tensor): The targets in (N, C)
        scales (Optional[Tensor]): If not None, each channel will be weighted by the corresponding scale.
        normalizer (Optional[SonarNormalizer]): If a normalizer is provided,
            the predictions and targets will first be denormalized before computing the loss

    Returns:
        mse (Tensor): the MSE loss with optional scaling
        plain_mse (Tensor): the MSE loss without any scaling (for logging)
    """

    assert flattened_predictions.dim() == 2, (
        "Expecting two-dimensional predictions and targets. "
        f"Found targets in {flattened_target.size()} and "
        f"predictions in {flattened_predictions.size()}"
    )

    assert flattened_predictions.shape == flattened_target.shape, (
        "Expecting predictions and targets of the same shape. "
        f"Received predictions {flattened_predictions.shape} and targets {flattened_target.shape}"
    )

    if scales is not None:
        assert scales.dim() == 1, (
            "Expecting a uni-dimensional tensor of scales. "
            f"Found a tensor with dimension {scales.dim()}"
        )
        assert len(scales) == flattened_target.shape[-1], (
            "The provided scales should have the same size as the target channels. "
            f"Found {len(scales)}, expected {flattened_target.shape[-1]}"
        )

    if normalizer is not None:
        assert hasattr(normalizer, "denormalize"), (
            "The provided normalizer has no `denormalize` method"
        )
        flattened_predictions = normalizer.denormalize(flattened_predictions)
        flattened_target = normalizer.denormalize(flattened_target)

    full_mse = torch.nn.functional.mse_loss(
        flattened_predictions, flattened_target, reduction="none"
    )
    plain_mse = full_mse.sum(dim=-1)

    if scales is not None:
        full_mse = full_mse * scales.unsqueeze(0)
        mse = full_mse.sum(dim=-1)
    else:
        mse = plain_mse
    return mse, plain_mse


@dataclass
class LCMCriterionConfig(CriterionConfig):
    compute_rmse: bool = True
    """If `True`, take the square root of the MSE.
    This is for now `True` by default for backward compatibility"""


class LCMCriterion(Criterion):
    """An abstract class for the LCM's criterions"""

    config: LCMCriterionConfig

    def __init__(
        self,
        config: LCMCriterionConfig,
        model: AbstractLCModel,
        style: LCMStyle = LCMStyle.UNSUPERVISED,
    ):
        super().__init__(config, model)

        self.style = style

        # Summands for log/tb recorders
        self.summands = ["mse_loss", "reconstruction_loss"]

        self.normalize_in_criterion = (
            self.base_model.config.sonar_normalizer_name is not None
        )

    @property
    def sonar_normalizer(self) -> Optional[SonarNormalizer]:
        if hasattr(self.base_model, "sonar_normalizer"):
            return self.base_model.sonar_normalizer

        elif hasattr(self.base_model, "frontend") and hasattr(
            self.base_model.frontend, "sonar_normalizer"
        ):
            return self.base_model.frontend.sonar_normalizer

        else:
            logger.warning(
                "Couldn't find the model's `sonar_normalizer`, defaulting to None"
            )
            return None

    @property
    def throughput_metric_name(self) -> str:
        return "num_target_elements"

    @abstractmethod
    def __call__(self, batch: LCMInput) -> LossTerm:
        """
        Computes the loss given an input batch.
        The model's forward pass is performed here.
        The input batch is an LCMInput (see `lcm.datasets.batch`).
        """
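A worked example of `compute_standard_mse` with tiny tensors (values chosen by hand):

    import torch

    pred = torch.tensor([[1.0, 2.0]])
    target = torch.tensor([[0.0, 0.0]])
    _, plain_mse = compute_standard_mse(pred, target)
    print(plain_mse)  # tensor([5.]) = (1-0)^2 + (2-0)^2, summed over channels

    scales = torch.tensor([2.0, 1.0])
    mse, _ = compute_standard_mse(pred, target, scales=scales)
    print(mse)  # tensor([6.]) = 2*1 + 1*4, channel-weighted before summing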
lcm/train/lcm/trainer.py ADDED
@@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from dataclasses import dataclass, field
from typing import Dict, List, Mapping, Optional, Union

import torch
from fairseq2.assets import AssetCard
from fairseq2.checkpoint import FileCheckpointManager
from fairseq2.gang import Gang
from fairseq2.logging import get_log_writer
from fairseq2.metrics import MetricRecorder
from fairseq2.optim import DynamicLossScaler
from fairseq2.optim.lr_scheduler import AbstractLRScheduler
from fairseq2.utils.profiler import Profiler, Stopwatch
from fairseq2.utils.rng import RngBag
from omegaconf import MISSING
from stopes.core import Requirements
from torch.nn import Module
from torch.optim import Optimizer

from lcm.datasets.configs import ParquetDatasetConfig
from lcm.datasets.dataloader import LCMDataLoader
from lcm.datasets.dataloading import ds_name
from lcm.models.abstract_lcm import AbstractLCModelConfig
from lcm.models.base_lcm.loader import load_base_lcm_model
from lcm.train.criterion import CriterionsFactory
from lcm.train.metrics import LCMMetricBag
from lcm.train.mse_lcm.criterion import ReconstructionCriterionConfig
from lcm.train.trainer import Trainer, TrainerBuilder, TrainingConfig
from lcm.utils.card_utils import create_model_card

logger = get_log_writer(__name__)


@dataclass
class LCMTrainingConfig(TrainingConfig):
    """Holds the configuration of an LCM training job."""

    training_data: List[ParquetDatasetConfig] = field(default_factory=list)
    """The datasets to train with."""  # TODO use dataset cards

    validation_data: List[ParquetDatasetConfig] = field(default_factory=list)
    """The datasets to validate on."""  # TODO use dataset cards

    model_config_or_name: Union[AbstractLCModelConfig, str, None] = None
    """The model configuration or name to train."""

    requirements: Requirements = field(
        default_factory=lambda: Requirements(
            nodes=1,
            tasks_per_node=8,
            gpus_per_node=8,
            cpus_per_task=8,
            mem_gb=256,
            timeout_min=3 * 24 * 60,
            constraint="volta32gb",
        )
    )
    """The scheduling requirements for this trainer"""

    criterion: ReconstructionCriterionConfig = MISSING
    """The MSE loss is the default base criterion used in either the `lcm` or `mse_lcm` trainers"""

    max_subword_length: int = 512
    """Max subword length used to truncate sequences during SONAR decoder backprop"""


class LCMTrainer(Trainer):
    config: LCMTrainingConfig
    model: Module
    training_data_loader: LCMDataLoader
    validation_data_loader: Optional[LCMDataLoader]
    gang: Gang
    optimizer: Optimizer
    loss_scaler: DynamicLossScaler
    lr_scheduler: AbstractLRScheduler
    rng_bag: RngBag
    step_nr: int
    train_metric_bag: LCMMetricBag
    valid_metric_bag: Mapping[str, LCMMetricBag]
    metric_recorders: List[MetricRecorder]
    profiler: Profiler
    stopwatch: Stopwatch

    def __init__(
        self,
        config: LCMTrainingConfig,
        model: Module,
        training_data_loader: LCMDataLoader,
        validation_data_loader: Optional[LCMDataLoader],
        gang: Gang,
        checkpoint_manager: FileCheckpointManager,
        rng_bag: RngBag,
        stopwatch: Stopwatch,
        card_metadata: Dict,
    ) -> None:
        super().__init__(
            config,
            model,
            training_data_loader,
            validation_data_loader,
            gang,
            checkpoint_manager,
            rng_bag,
            stopwatch,
            card_metadata=card_metadata,
        )

    def setup_criterion(self):
        return CriterionsFactory.build_criterion(
            name=self.config.criterion.name,
            config=self.config.criterion,
            model=self.model,
        )

    def setup_metric_bags(self):
        self.train_metric_bag = LCMMetricBag(
            self.gang,
            loss_summands=self.criterion.summands,
            reduction=self.criterion.reduction,
        )

        self.register_non_stateful(
            "valid_metric_bag",
            {
                ds_name(dataset): LCMMetricBag(
                    self.gang,
                    loss_summands=self.criterion.summands,
                    reduction=self.criterion.reduction,
                )
                for dataset in self.config.validation_data
            },
        )

    def create_model_card_for_last_checkpoint(
        self, is_final: bool = True, **card_kwargs
    ) -> Optional[AssetCard]:
        """Create a model card based on the last saved
        checkpoint and the model config."""

        current_step_number: Optional[int] = None
        if is_final:
            steps = self.checkpoint_manager.get_step_numbers()
            current_step_number = steps[-1] if len(steps) else None
        else:
            current_step_number = self.checkpoint_manager._get_checkpoint_step_nr()

        if current_step_number is None:
            logger.warning(
                "No checkpoint was saved, the final model card will not be created"
            )
            return None

        cp_fn = (
            self.checkpoint_manager._checkpoint_dir
            / f"step_{current_step_number}"
            / "model.pt"  # type: ignore
        )

        card = create_model_card(
            checkpoint_path=cp_fn.absolute(),
            model_arch=self.card_metadata["model_arch"],
            model_config=self.card_metadata["model_config"],
            model_type=self.card_metadata["model_type"],
            **card_kwargs,
        )
        return card


class LCMTrainerBuilder(TrainerBuilder):
    config: LCMTrainingConfig

    def __init__(self, config: LCMTrainingConfig):
        super().__init__(config)

    def load_data(self):
        """Load training and validation data"""

        training_data_loader = LCMDataLoader(
            data_config=self.config.data_loading_config,
            datasets=self.config.training_data,
            max_subword_length=self.config.max_subword_length,
            dtype=self.dtype,
            gang=self.gang,
        )

        validation_data_loader = LCMDataLoader(
            data_config=self.config.validation_data_loading_config,
            datasets=self.config.validation_data,
            max_subword_length=self.config.max_subword_length,
            dtype=self.dtype,
            gang=self.gang,
        )

        return training_data_loader, validation_data_loader

    @property
    def model_loader(self):
        """A fairseq2 ModelLoader"""
        return load_base_lcm_model

    def build_trainer(self):
        """Build the trainer by loading data and
        setting up the model for training"""

        training_data_loader, validation_data_loader = self.load_data()

        checkpoint_manager = FileCheckpointManager(
            self.config.output_dir.joinpath("checkpoints"),
            self.gang,
        )

        self.has_checkpoint = checkpoint_manager.has_checkpoint()

        model = self.create_model()

        # Force all model parameters to bfloat16 regardless of submodule defaults
        model = model.to(dtype=torch.bfloat16)

        model = self.maybe_load_model(model)

        model = self.maybe_freeze_parameters(model)

        # If using the META device, we need to move the model to gang.device
        wrapped_model = None

        if self.use_fsdp:
            wrapped_model = self.wrap_model_with_fsdp(model)
        elif self.use_ddp:
            wrapped_model = self.wrap_model_with_ddp(model)  # type: ignore

        trainer = LCMTrainer(
            self.config,  # type: ignore
            wrapped_model or model,
            training_data_loader,
            validation_data_loader,
            self.gang,
            checkpoint_manager,
            self.rng_bag,
            self.stopwatch,
            card_metadata=self.card_metadata,
        )

        trainer.setup()

        if self.has_checkpoint:
            trainer.restore()

        return trainer


def prepare_lcm_trainer(config: LCMTrainingConfig) -> LCMTrainer:
    """Create an LCM Trainer.

    :param config: The training configuration.
    """
    return LCMTrainerBuilder(config).build_trainer()
lcm/train/metrics.py ADDED
@@ -0,0 +1,449 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from collections.abc import MutableMapping
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from fairseq2.gang import Gang
from fairseq2.logging import get_log_writer
from fairseq2.metrics import (
    MetricBag,
    format_as_float,
    format_as_int,
    format_as_seconds,
)
from fairseq2.metrics.recorder import (
    MetricRecorder,
    _metric_formatters,
    register_metric_formatter,
)
from fairseq2.typing import override
from torch import Tensor
from torch.cuda import _get_device_index
from torcheval.metrics import Max, Mean, Sum, Throughput

logger = get_log_writer(__name__)

format_as_percent = partial(format_as_int, postfix="%")


def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = ".") -> Dict:
    """
    A helper function to flatten nested dictionaries.
    Example: with a training config like
        config = {
            'data': {
                'training': {'batch_size': 10},
                'validation': {'batch_size': 2}
            },
            'model': {'model_dim': 1024},
            'use_fsdp': True
        }
    the flat config will be:
        {
            'data.training.batch_size': 10,
            'data.validation.batch_size': 2,
            'model.model_dim': 1024,
            'use_fsdp': True
        }
    This helper is used to convert our nested training config into a flat
    dictionary for TensorBoard's HParams consumption.
    """
    items: List = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def get_allocated_gpu_memory(device):
    """
    Get allocated memory in GiB for GPU devices
    """
    if device.type == "cpu":
        return 0, 0
    device = _get_device_index(device, optional=True)
    memory_stats = torch.cuda.memory_stats(device=device)
    current_usage = memory_stats["allocated_bytes.all.current"] / (1024**3)
    peak_usage = memory_stats["allocated_bytes.all.peak"] / (1024**3)
    return current_usage, peak_usage


@dataclass
class LossTerm:
    """Dataclass for a batch loss term"""

    value: Tensor
    """The final loss to be optimized"""

    batch_size: int

    num_target_elements: Union[int, float]

    summands: Dict[str, Tuple[Any, Any]] = field(default_factory=lambda: {})
    """A dictionary of loss terms to record. Each term is a tuple of (loss, number of elements).
    The second term is optional; if None, we will use `num_target_elements` when aggregating."""


class LCMMetricBag(MetricBag):
    """Holds the common metrics of an LCM."""

    loss: Mean
    batch_size: Sum
    elements_per_batch: Mean
    elements_per_second: Throughput
    num_target_elements: Sum
    total_num_target_elements: Sum

    grad_norm: Mean

    def __init__(
        self, gang: Gang, loss_summands: Sequence[str] = [], reduction: str = "sum"
    ) -> None:
        """
        :param gang:
            The gang to sync metrics across all processes.
        """
        super().__init__(gang)

        # temporary fix:
        self.reduction = reduction

        d = gang.device

        # A temporary solution to track as many loss terms as we explore
        self.loss_summands = loss_summands

        self.register_metric("loss", Mean(device=d), persistent=False)

        # this is the effective batch size
        self.register_metric("batch_size", Sum(device=d), persistent=False)

        self.register_metric("elements_per_batch", Mean(device=d), persistent=False)

        self.register_metric(
            "elements_per_second", Throughput(device=d), persistent=False
        )

        self.register_metric("gpu_memory_usage", Max(device=d), persistent=False)

        self.register_metric("gpu_peak_memory_usage", Max(device=d), persistent=False)

        # self.register_metric("ram_percentage", Max(device=d), persistent=False)

        # self.register_metric("cpu_percentage", Max(device=d), persistent=False)

        for summand in self.loss_summands:
            self.register_metric(summand, Mean(device=d), persistent=False)

        # The number of target tokens in a parallel batch. Used for computing throughput
        self.register_metric("num_target_elements", Sum(device=d), persistent=False)

        # The total_num_target_elements is persistent and is supposed to track the
        # total number of tokens consumed since training started
        self.total_num_target_elements = Sum(device=d)

    def register_adaln_metric(self, module_name: str):
        for block in ["mha", "ffn"]:
            for tensor in [
                "shift",
                "scale",
                "gate",
            ]:
                self.register_metric(
                    f"{module_name}_{block}_{tensor}_mean",
                    Mean(device=self._gang.device),
                    persistent=False,
                )
                self.register_metric(
                    f"{module_name}_{block}_{tensor}_std",
                    Mean(device=self._gang.device),
                    persistent=False,
                )
                # formatters
                register_metric_formatter(
                    f"{module_name}_{block}_{tensor}_mean",
                    f"{module_name}_{block}_{tensor}_mean",
                    1000,
                    format_as_float,
                )
                register_metric_formatter(
                    f"{module_name}_{block}_{tensor}_std",
                    f"{module_name}_{block}_{tensor}_std",
                    1000,
                    format_as_float,
                )

    def register_module_metric(self, module_name: str):
        for tensor in [
            "input_gradient",
            "output_gradient",
            "input_activations",
            "output_activations",
        ]:
            self.register_metric(
                f"{module_name}_{tensor}_mean",
                Mean(device=self._gang.device),
                persistent=False,
            )
            self.register_metric(
                f"{module_name}_{tensor}_std",
                Mean(device=self._gang.device),
                persistent=False,
            )
            # formatters
            register_metric_formatter(
                f"{module_name}_{tensor}_mean",
                f"{module_name}_{tensor}_mean",
                1000,
                format_as_float,
            )
            register_metric_formatter(
                f"{module_name}_{tensor}_std",
                f"{module_name}_{tensor}_std",
                1000,
                format_as_float,
            )

    @torch.inference_mode()
    def update(
        self,
        losses: Sequence[LossTerm],
    ) -> None:
        """Update the metrics.

        :param losses:
            The losses generated by the model for each batch.
        """

        loss = torch.zeros((), dtype=torch.float64)

        loss_summands = {
            s: torch.zeros((), dtype=torch.float64) for s in self.loss_summands
        }
        # Denominator to normalize the loss summands; if -1,
        # we will default to normalizing with `num_target_elements`
        loss_summands_numel = {
            s: -torch.ones((), dtype=torch.long) for s in self.loss_summands
        }

        batch_size = torch.zeros((), dtype=torch.int64)

        num_target_elements = torch.zeros((), dtype=torch.int64)

        # Only in the case of using gradient accumulation will `losses` be a non-singleton
        for batch_loss in losses:
            loss += float(batch_loss.value)

            for s in self.loss_summands:
                loss_term = batch_loss.summands.get(s, (0.0, None))
                loss_summands[s] += float(loss_term[0])
                if loss_term[1] is not None and loss_term[1] != -1:
                    if loss_summands_numel[s] == -1:
                        loss_summands_numel[s] = torch.zeros((), dtype=torch.int64)
                    loss_summands_numel[s] += loss_term[1]

            batch_size += batch_loss.batch_size
            num_target_elements += batch_loss.num_target_elements

        # Misleading normalization in the metric bag with reduction == "mean"
        # Kept here for backward compatibility
        # Any normalization here is only for reporting and doesn't impact optimization
        if self.reduction == "sum":
            loss /= num_target_elements
            keys = list(loss_summands)
            for k in keys:
                denom = loss_summands_numel[k]
                if denom == -1:
                    denom = num_target_elements
                loss_summands[k] /= denom + 1e-6

        self.loss.update(loss, weight=num_target_elements)

        for s in loss_summands:
            weight = loss_summands_numel[s]
            if weight == -1:
                weight = num_target_elements
            getattr(self, s).update(loss_summands[s], weight=weight)

        self.batch_size.update(batch_size)

        self.elements_per_batch.update(num_target_elements)

        self.num_target_elements.update(num_target_elements)

        # update the cumulative metric
        self.total_num_target_elements.update(num_target_elements)

        # Get GPU memory usage
        gpu_memory_usage, gpu_peak_memory_usage = get_allocated_gpu_memory(
            self._gang.device
        )
        self.gpu_memory_usage.update(torch.tensor(gpu_memory_usage))
        self.gpu_peak_memory_usage.update(torch.tensor(gpu_peak_memory_usage))

    def reset_batch_metrics(self) -> None:
        """Reset the batch metrics to their initial state."""
        self.loss.reset()
        for s in self.loss_summands:
            getattr(self, s).reset()

        self.batch_size.reset()
        self.elements_per_batch.reset()
        self.elements_per_second.reset()
        self.grad_norm.reset()
        self.gpu_memory_usage.reset()
        self.gpu_peak_memory_usage.reset()
        # self.ram_percentage.reset()
        # self.cpu_percentage.reset()


## Weights & Biases recorder

try:
    import wandb  # type: ignore[import-not-found]
except ImportError:
    has_wandb = False
else:
    has_wandb = True


class LCMWandBRecorder(MetricRecorder):
    """Records metric values to Weights & Biases."""

    defined_runs: Set[str] = set()

    def __init__(
        self,
        project: Optional[str] = None,
        name: Optional[str] = None,
        output_dir: Optional[Path] = None,
        config: Dict[str, Any] = {},
        **kwargs,
    ) -> None:
        """
        :param project: A project to organize this run with other experiments; if None, the run will go under `uncategorized`.
        :param name: A unique name for your run; if None is given, a random name will be generated.
        :param output_dir: The base directory under which to store the W&B files. You don't have to provide this.
        :param config: A dictionary of key-value pairs to be stored as the experiment's config (akin to hparams in tb).
        :param kwargs: Additional arguments to pass to wandb.init().

        In order to use W&B, run `wandb login` from the command line and enter
        the API key when prompted.
        """
        if not has_wandb:
            log = get_log_writer(__name__)
            log.warning("wandb not found. Please install it with `pip install wandb`.")  # fmt: skip

            self._run = None
        else:
            if output_dir:
                output_dir.mkdir(parents=True, exist_ok=True)
            self._run = wandb.init(  # type: ignore
                project=project,
                name=name,
                dir=output_dir,
                resume="allow",
                config=config,
                **kwargs,
            )

    def _define_run(self, run: str):
        if run in self.defined_runs:
            return
        # https://docs.wandb.ai/guides/track/log/customize-logging-axes/
        wandb.define_metric(f"{run}/step")
        wandb.define_metric(f"{run}/*", step_metric=f"{run}/step")
        # remember that this run's axes are already defined
        self.defined_runs.add(run)

    @override
    def record_metrics(
        self,
        run: str,
        values: Mapping[str, Any],
        step_nr: Optional[int] = None,
        *,
        flush: bool = True,
    ) -> None:
        if self._run is None:
            return

        self._define_run(run)

        for name, value in values.items():
            formatter = _metric_formatters.get(name)
            if formatter is None:
                display_name = name
            else:
                display_name = formatter.display_name

            self._run.log({f"{run}/{display_name}": value, f"{run}/step": step_nr})

    @override
    def close(self) -> None:
        if self._run is not None:
            self._run.finish()


lcm_metric_formatters: Dict[str, Tuple[str, int, Callable[[Any], str]]] = {
    # fmt: off
    "loss": ("Loss", 100, format_as_float),
    "nll_loss": ("NLL Loss", 100, format_as_float),
    "mse_loss": ("MSE Loss", 100, format_as_float),
    "contrastive_loss": ("Contrastive Loss", 110, format_as_float),
    "reconstruction_loss": ("Reconstruction loss", 110, format_as_float),
    "unnormalized_reconstruction_loss": (
        "Unnormalized Reconstruction Loss",
        110,
        format_as_float,
    ),
    "kld": ("KLD loss", 110, format_as_float),
    "encoder_mse_loss": ("Encoder MSE loss", 110, format_as_float),
    "decoder_ce_loss": ("Decoder CE loss", 110, format_as_float),
    "elapsed_time": ("Elapsed Time", 500, format_as_seconds),
    "wall_time": ("Wall Time", 510, format_as_seconds),
    "lr": ("Learning Rate", 800, format_as_float),
    "loss_scale": ("Loss Scale", 810, format_as_float),
    "grad_norm": ("Grad norm", 810, format_as_float),
    "raw_grad_norm": ("Raw Grad norm", 815, format_as_float),
    "encoder_mse_scale": ("Encoder MSE loss scale", 850, format_as_float),
    "batch_size": ("Batch Size", 900, format_as_int),
    "elements_per_batch": ("Elements per Batch", 900, format_as_int),
    "elements_per_second": ("Elements per Second", 900, format_as_int),
    "num_examples": ("Number of Examples", 900, format_as_int),
    "num_source_elements": ("Number of Source Elements", 900, format_as_int),
    "num_target_elements": ("Number of Target Elements", 900, format_as_int),
    "total_num_target_elements": ("Accumulated Target Elements", 920, format_as_int),
    "gpu_memory_usage": ("GPU memory usage (GiB)", 910, format_as_float),
    "gpu_peak_memory_usage": ("GPU peak memory usage (GiB)", 910, format_as_float),
    "ram_percentage": ("RAM usage", 920, format_as_percent),
    "cpu_percentage": ("CPU usage", 920, format_as_percent),
    "mean_predicted_embeddings": ("mean_predicted_embeddings", 920, format_as_float),
    "std_predicted_embeddings": ("std_predicted_embeddings", 920, format_as_float),
    # fmt: on
}
for key in lcm_metric_formatters:
    register_metric_formatter(key, *lcm_metric_formatters[key], overwrite=True)
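Note: two standalone sanity checks (not from this commit) for the helpers above; both follow directly from the docstrings:

    import torch
    from lcm.train.metrics import LossTerm, flatten_dict

    # flatten_dict: nested config -> dotted keys, exactly as documented above
    cfg = {"data": {"training": {"batch_size": 10}}, "use_fsdp": True}
    assert flatten_dict(cfg) == {"data.training.batch_size": 10, "use_fsdp": True}

    # LossTerm: a summand with a None denominator is normalized by
    # num_target_elements when LCMMetricBag.update() aggregates it
    term = LossTerm(
        value=torch.tensor(12.0),   # summed loss over the batch
        batch_size=4,
        num_target_elements=48,
        summands={"mse_loss": (12.0, None)},
    )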
lcm/train/mse_lcm/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
lcm/train/mse_lcm/criterion.py
ADDED
@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Tuple

import torch
from fairseq2.logging import get_log_writer
from torch import Tensor

from lcm.datasets.batch import EmbeddingsBatch, LCMInput, LCMStyle
from lcm.models.abstract_lcm import AbstractLCModel
from lcm.train.criterion import CriterionsFactory
from lcm.train.lcm.criterion import (
    LCMCriterion,
    LCMCriterionConfig,
    compute_standard_mse,
)
from lcm.train.metrics import LossTerm

logger = get_log_writer(__name__)


@dataclass
class ReconstructionCriterionConfig(LCMCriterionConfig):
    min_context_size: int = 1
    """minimum context size for next sentence prediction"""


@CriterionsFactory.register("next_sentence_mse")
class ReconstructionCriterion(LCMCriterion):
    """Computes the MSE reconstruction loss for next-sentence prediction"""

    config: ReconstructionCriterionConfig

    def __init__(
        self,
        config: ReconstructionCriterionConfig,
        model: AbstractLCModel,
        style: LCMStyle = LCMStyle.UNSUPERVISED,
    ):
        super().__init__(config, model, style)

        if style is not LCMStyle.SUPERVISED:
            assert (
                config.min_context_size is not None and config.min_context_size > 0
            ), (
                "For unsupervised pre-training, expecting a min_context_size of at least 1. "
                f"Received min_context_size={config.min_context_size}. "
                "Note that we need some context to predict the first position and "
                "this context can come from a dummy `beginning of document (BOD)` vector. "
                "With a minimum context size of 1 we ensure that we never ask the model to predict BOD."
            )

        self.min_context_size = config.min_context_size

    def prepare_input_and_mask(
        self,
        batch: LCMInput,
    ) -> Tuple[EmbeddingsBatch, torch.Tensor]:
        """
        A method for preparing model inputs and mask for a batch.
        It will typically be reused by the `__call__`
        implementations of the subclasses.
        """
        input_embeddings = batch.prepare_input(style=self.style)

        target_mask = batch.prepare_target_mask(
            input_embeddings,
            style=self.style,
            min_context_size=self.config.min_context_size,
        )

        return input_embeddings, target_mask

    def __call__(self, batch: LCMInput) -> LossTerm:
        """
        Args:
            batch: an LCMInput (see lcm.datasets.batch)

        Returns a LossTerm
        """

        # prepare_input_and_mask returns embeddings with seqs in B,T,C
        # and a target mask in B,T. Note that the first position is never used as a target
        # (i.e. BOS vector or first sentence in the document) and will always be set to False
        # in the target mask
        input_embeddings, target_mask = self.prepare_input_and_mask(batch)

        if self.normalize_in_criterion:
            # the input to the model will be normalized and
            # so is the target used for loss computation
            input_embeddings = input_embeddings.normalize_seqs(self.sonar_normalizer)

        # Predict model outputs
        output_embeddings = self.model(input_embeddings)

        # Prepare predictions and targets:
        # Shift the input to remove the first position.
        # Shifted seqs from input_embeddings are used as ground-truth target embeddings
        target_seqs = input_embeddings.seqs[:, 1:].contiguous()
        batch_size, _, sonar_dim = target_seqs.size()

        # shift and flatten
        target_mask = target_mask[:, 1:].reshape(-1)
        # i.e. s2, s3, s4, s5

        # Trim the last position.
        # output_seqs represent contextualized embeddings / predictions for the next sentence.
        # This shifting/trimming allows us to predict `s_t` conditioned on `s_{<t}`
        predicted_seqs = output_embeddings.seqs[:, :-1].contiguous()
        # i.e. s<=1, s<=2, s<=3, s<=4

        # only measure distance over `target_mask = True` positions
        flattened_predictions = predicted_seqs.view(-1, sonar_dim)[target_mask]
        flattened_target = target_seqs.view(-1, sonar_dim)[target_mask]

        # Cast features to float32 before computing the loss:
        reconstruction_loss, mse_loss = self.compute_loss(
            flattened_predictions.float(), flattened_target.float()
        )

        num_target_elements = target_mask.sum()

        if self.reduction == "sum" or num_target_elements == 0:
            reduced_reconstruction_loss = reconstruction_loss.sum()
            mse_loss = mse_loss.sum()

        elif self.reduction == "mean":
            reduced_reconstruction_loss = reconstruction_loss.mean()
            mse_loss = mse_loss.mean()

        final_loss = reduced_reconstruction_loss

        # Loss summands for records
        summands = {
            "mse_loss": (mse_loss.item(), None),
            "reconstruction_loss": (reduced_reconstruction_loss.item(), None),
        }

        return LossTerm(
            value=final_loss,
            batch_size=batch_size,
            num_target_elements=num_target_elements.item(),
            summands=summands,
        )

    def compute_loss(
        self, flattened_predictions, flattened_target
    ) -> Tuple[Tensor, Tensor]:
        """
        Computes the following loss terms:
        1. the reconstruction loss we want to optimize, as well as
        2. the (R)MSE loss (for tracking); in this parent class, the
           reconstruction loss and the (R)MSE are one and the same.
        Returns reconstruction_loss, mse_loss
        """
        reconstruction_loss, _ = compute_standard_mse(
            flattened_predictions, flattened_target
        )
        if self.config.compute_rmse:
            epsilon = 1e-5
            reconstruction_loss = torch.sqrt(reconstruction_loss + epsilon)

        return reconstruction_loss, reconstruction_loss


@CriterionsFactory.register("target_mse")
class TargetMSECriterion(ReconstructionCriterion):
    """Computes the LCM training objective given source/target pairs"""

    def __init__(
        self,
        config: ReconstructionCriterionConfig,
        model: AbstractLCModel,
        style: LCMStyle = LCMStyle.SUPERVISED,
    ):
        super().__init__(config, model, style)
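Note: the shift/trim alignment in `__call__` can be checked on a toy tensor; this sketch (not from this commit) is illustrative only and does not exercise the criterion itself:

    import torch

    seqs = torch.arange(5).view(1, 5, 1)    # a "document" of 5 sentence embeddings, dim=1
    targets = seqs[:, 1:]                   # ground truth: s2..s5
    preds = seqs[:, :-1]                    # positions s<=1..s<=4 predict them
    assert targets.shape == preds.shape     # pairs (s<=t, s_{t+1}) line up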
lcm/train/optim.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

from typing import Tuple

from fairseq2.logging import get_log_writer
from fairseq2.optim.lr_scheduler import (
    AbstractLRScheduler,
    CosineAnnealingLR,
    MyleLR,
    NoopLR,
    PolynomialDecayLR,
    TriStageLR,
)
from torch.optim import Optimizer

logger = get_log_writer(__name__)


def build_lr_scheduler(
    optimizer: Optimizer,
    lr: float,
    warmup_steps: int,
    start_lr: float = 1e-7,
    final_lr: float = 1e-5,
    max_steps: int = 10_000,
    stage_ratio: Tuple[float, ...] = (0.1, 0.4, 0.5),
    schedule: str = "myle",
) -> AbstractLRScheduler:
    assert schedule in [
        "noop",
        "myle",
        "cosine",
        "wsd",
        "polynomial",
    ], (
        f"Cannot recognize the learning rate schedule {schedule}; only noop, myle, cosine, wsd and polynomial are supported"
    )

    assert lr > 0, "The learning rate should be strictly positive"

    lr_scheduler: AbstractLRScheduler

    if schedule == "noop":
        lr_scheduler = NoopLR(optimizer)

    elif schedule == "myle":
        lr_scheduler = MyleLR(
            optimizer,
            num_warmup_steps=warmup_steps,
            start_lr=[start_lr],
        )

    elif schedule == "cosine":
        lr_scheduler = CosineAnnealingLR(
            optimizer,
            cycle_len=max_steps - warmup_steps + 1,
            num_warmup_steps=warmup_steps,
            start_lr=[start_lr],
            final_lr=[final_lr],
            cycle_mul=1.0,
            lr_mul=1.0,
        )

    elif schedule == "wsd":
        assert lr > start_lr, (
            f"the starting learning rate {start_lr} should be less than the main lr {lr}"
        )
        start_lr_scale = start_lr / lr

        assert lr > final_lr, (
            f"the final learning rate {final_lr} should be less than the main lr {lr}"
        )
        final_lr_scale = final_lr / lr

        lr_scheduler = TriStageLR(
            optimizer,
            max_steps,
            stage_ratio=stage_ratio,  # type: ignore
            start_lr_scale=start_lr_scale,
            final_lr_scale=final_lr_scale,
        )

    elif schedule == "polynomial":
        lr_scheduler = PolynomialDecayLR(
            optimizer,
            max_steps,
            warmup_steps,
            power=200,
            start_lr=start_lr,
            final_lr=final_lr,
        )

    return lr_scheduler
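Note: a minimal sketch (not from this commit) of driving `build_lr_scheduler`; the toy parameter and step counts are placeholders:

    import torch
    from lcm.train.optim import build_lr_scheduler

    params = [torch.nn.Parameter(torch.zeros(2, 2))]
    optimizer = torch.optim.AdamW(params, lr=0.004)
    scheduler = build_lr_scheduler(
        optimizer, lr=0.004, warmup_steps=800, schedule="cosine", max_steps=10_000
    )
    for _ in range(3):
        optimizer.step()
        scheduler.step()  # fairseq2 schedulers follow the torch LRScheduler API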
lcm/train/step_sampler.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass
from typing import Literal, Optional

import torch
import torch.distributions as D
from fairseq2.logging import get_log_writer
from torch import Tensor

from lcm.nn.schedulers import DDIMScheduler

SUPPORTED_SAMPLERS = Literal["uniform", "beta"]
SUPPORTED_WEIGHTINGS = Literal["none", "clamp_snr"]

logger = get_log_writer(__name__)


def beta_function(a, b):
    result = torch.exp(torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b))
    return result


@dataclass
class StepsSamplerConfig:
    sampling: SUPPORTED_SAMPLERS = "uniform"
    weighting: SUPPORTED_WEIGHTINGS = "none"
    beta_a: float = 0.8
    beta_b: float = 1
    max_gamma: float = 5.0
    min_gamma: float = 0


class StepsSampler(object):
    def __init__(
        self,
        config: StepsSamplerConfig,
        noise_scheduler: DDIMScheduler,
    ):
        num_diffusion_train_steps = noise_scheduler.num_diffusion_train_steps
        weights: Optional[Tensor] = None

        if config.sampling == "uniform":
            weights = torch.ones(
                num_diffusion_train_steps,
            )

        elif config.sampling == "beta":
            # As motivated in https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/00328.pdf
            a = torch.tensor([config.beta_a])
            b = torch.tensor([config.beta_b])
            # a=1, b=1 -> uniform
            # The paper empirically chooses b=1, a=0.8 < 1

            steps = (
                torch.arange(1, num_diffusion_train_steps + 1)
                / num_diffusion_train_steps
            )
            weights = (
                1 / beta_function(a, b) * (steps ** (a - 1)) * ((1 - steps) ** (b - 1))
            )

        assert weights is not None, "The sampling weights were not properly set!"
        logger.info(f"Training with sampling weights={weights}")

        self.distrib = D.Categorical(
            probs=weights / weights.sum(),
        )

        # setup weights for scaling:
        if config.weighting == "none":
            self.gamma_per_step = None

        elif config.weighting == "clamp_snr":
            # Min-SNR scheme from
            # https://arxiv.org/abs/2303.09556
            snrs = noise_scheduler.get_snrs()
            # gamma(t) = min(max_gamma, snr(t))
            self.gamma_per_step = torch.clamp(
                snrs, max=config.max_gamma, min=config.min_gamma
            )

        logger.info(f"Training with Gamma={self.gamma_per_step}")

    @property
    def _training_weights(self) -> Tensor:
        return self.distrib.probs

    def sample(self, size: torch.Size, device: torch.device):
        samples = self.distrib.sample(size).to(device)
        # print('Samples', samples)
        # print('Counts:', torch.bincount(samples.flatten()))
        return samples

    def get_loss_scales(self, steps):
        if self.gamma_per_step is None:
            return None

        # If we're using constant Gamma=1 (returning None), then the sum of
        # the loss scales is steps.numel(); to match the total mass,
        # we normalize the scales to sum to steps.numel()
        gamma = self.gamma_per_step.to(steps.device)[steps]
        gamma = gamma / gamma.sum() * steps.numel()
        return gamma
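Note: the beta-distribution sampling weights above can be reproduced standalone (a=1, b=1 recovers uniform; the paper's a=0.8, b=1 upweights low step indices). A small numeric check, not from this commit:

    import torch
    from lcm.train.step_sampler import beta_function

    a, b = torch.tensor([0.8]), torch.tensor([1.0])
    t = torch.arange(1, 11) / 10                     # 10 normalized train steps
    w = 1 / beta_function(a, b) * t ** (a - 1) * (1 - t) ** (b - 1)
    probs = w / w.sum()
    assert probs[0] > probs[-1]                      # early steps sampled more often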
lcm/train/trainer.py
ADDED
@@ -0,0 +1,1422 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
#
|
| 5 |
+
|
| 6 |
+
import gc
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
from abc import abstractmethod
|
| 11 |
+
from contextlib import nullcontext
|
| 12 |
+
from dataclasses import asdict, dataclass, field
|
| 13 |
+
from functools import cached_property
|
| 14 |
+
from itertools import count
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from pprint import pformat
|
| 17 |
+
from typing import (
|
| 18 |
+
Any,
|
| 19 |
+
ContextManager,
|
| 20 |
+
Dict,
|
| 21 |
+
Iterator,
|
| 22 |
+
List,
|
| 23 |
+
Literal,
|
| 24 |
+
Mapping,
|
| 25 |
+
Optional,
|
| 26 |
+
Tuple,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import yaml
|
| 31 |
+
from fairseq2.assets import AssetCard, AssetCardFieldNotFoundError
|
| 32 |
+
from fairseq2.checkpoint import FileCheckpointManager
|
| 33 |
+
from fairseq2.gang import FakeGang, Gang, ReduceOperation, all_sum
|
| 34 |
+
from fairseq2.logging import get_log_writer
|
| 35 |
+
from fairseq2.metrics import (
|
| 36 |
+
LogMetricRecorder,
|
| 37 |
+
MetricBag,
|
| 38 |
+
MetricRecorder,
|
| 39 |
+
TensorBoardRecorder,
|
| 40 |
+
record_metrics,
|
| 41 |
+
)
|
| 42 |
+
from fairseq2.nn.ddp import to_ddp
|
| 43 |
+
from fairseq2.nn.fsdp import to_fsdp
|
| 44 |
+
from fairseq2.nn.utils.gradient import (
|
| 45 |
+
check_gradient_norms,
|
| 46 |
+
clip_gradient_norm,
|
| 47 |
+
scale_gradients,
|
| 48 |
+
)
|
| 49 |
+
from fairseq2.nn.utils.module import (
|
| 50 |
+
_get_named_modules,
|
| 51 |
+
freeze_parameters,
|
| 52 |
+
to_device,
|
| 53 |
+
)
|
| 54 |
+
from fairseq2.optim import AdamW, DynamicLossScaler
|
| 55 |
+
from fairseq2.optim.lr_scheduler import AbstractLRScheduler, get_effective_lr
|
| 56 |
+
from fairseq2.recipes.utils.log import log_model
|
| 57 |
+
from fairseq2.utils.profiler import Profiler, Stopwatch
|
| 58 |
+
from fairseq2.utils.rng import RngBag
|
| 59 |
+
from fairseq2.utils.state import StatefulObjectBag
|
| 60 |
+
from omegaconf import MISSING
|
| 61 |
+
from stopes.core import Requirements
|
| 62 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
| 63 |
+
FullyShardedDataParallel as FSDP,
|
| 64 |
+
)
|
| 65 |
+
from torch.nn import Module
|
| 66 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 67 |
+
from torch.optim import Optimizer
|
| 68 |
+
from torch.profiler import record_function
|
| 69 |
+
from torcheval.metrics import Mean
|
| 70 |
+
|
| 71 |
+
from lcm.datasets.configs import DataLoadingConfig, ValidationDataLoadingConfig
|
| 72 |
+
from lcm.datasets.dataloading import ds_name
|
| 73 |
+
from lcm.train.metrics import (
|
| 74 |
+
LCMWandBRecorder,
|
| 75 |
+
flatten_dict,
|
| 76 |
+
)
|
| 77 |
+
from lcm.train.optim import build_lr_scheduler
|
| 78 |
+
from lcm.utils.data_utils import update_dataclass
|
| 79 |
+
from lcm.utils.distributed import (
|
| 80 |
+
SUPPORTED_FSDP_MEMORY_POLICIES,
|
| 81 |
+
SUPPORTED_FSDP_WRAP_POLICIES,
|
| 82 |
+
get_fsdp_memory_policy,
|
| 83 |
+
get_fsdp_wrap_policy,
|
| 84 |
+
init_process_group,
|
| 85 |
+
)
|
| 86 |
+
from lcm.utils.logging import (
|
| 87 |
+
log_env_variables,
|
| 88 |
+
setup_additional_logging,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
logger = get_log_writer(__name__)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
|
| 95 |
+
class TrainingConfig:
|
| 96 |
+
"""Holds the configuration of a training job."""
|
| 97 |
+
|
| 98 |
+
training_data: Any = MISSING
|
| 99 |
+
"""The datasets to train with."""
|
| 100 |
+
|
| 101 |
+
validation_data: Any = MISSING
|
| 102 |
+
"""The datasets to validate on."""
|
| 103 |
+
|
| 104 |
+
model_arch: Optional[str] = None
|
| 105 |
+
"""Starting architecture for the model to train"""
|
| 106 |
+
|
| 107 |
+
model_arch_overrides: Optional[Dict] = None
|
| 108 |
+
"""Dict of parameters to overwrite in `model_arch`"""
|
| 109 |
+
|
| 110 |
+
model_config_or_name: Optional[Any] = None
|
| 111 |
+
"""The model configuration or name to train.
|
| 112 |
+
This option cannot be paired with model_arch + model_arch_overrides
|
| 113 |
+
If provided, this option supersedes model_arch + model_arch_overrides
|
| 114 |
+
"""
|
| 115 |
+
output_dir: Path = MISSING
|
| 116 |
+
"""The output directory to store checkpoints and logs."""
|
| 117 |
+
|
| 118 |
+
log_folder: Optional[Path] = None
|
| 119 |
+
"""The executor's log directory where stdout/stderr will be redirected.
|
| 120 |
+
We will use this directory to optionally enable ATEN and NCCL
|
| 121 |
+
logging (if debug is True) """
|
| 122 |
+
|
| 123 |
+
tb_dir: Optional[Path] = None
|
| 124 |
+
"""The output directory to store tensorbaord logs"""
|
| 125 |
+
|
| 126 |
+
# defaults to "uncategorized"
|
| 127 |
+
wandb_project: Optional[str] = None
|
| 128 |
+
wandb_run_name: Optional[str] = None
|
| 129 |
+
wandb_entity: Optional[str] = None
|
| 130 |
+
|
| 131 |
+
requirements: Requirements = field(
|
| 132 |
+
default_factory=lambda: Requirements(
|
| 133 |
+
nodes=1,
|
| 134 |
+
tasks_per_node=8,
|
| 135 |
+
gpus_per_node=8,
|
| 136 |
+
cpus_per_task=8,
|
| 137 |
+
mem_gb=256,
|
| 138 |
+
timeout_min=3 * 24 * 60,
|
| 139 |
+
constraint="volta32gb",
|
| 140 |
+
)
|
| 141 |
+
)
|
| 142 |
+
"""The scheduling requirements for this trainer"""
|
| 143 |
+
|
| 144 |
+
data_loading_config: DataLoadingConfig = MISSING
|
| 145 |
+
|
| 146 |
+
validation_data_loading_config: ValidationDataLoadingConfig = field(
|
| 147 |
+
default_factory=lambda: ValidationDataLoadingConfig()
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
criterion: Any = MISSING
|
| 151 |
+
|
| 152 |
+
dtype: str = "torch.float32"
|
| 153 |
+
"""The data type of the model."""
|
| 154 |
+
|
| 155 |
+
lr_schedule: str = "myle"
|
| 156 |
+
"""The learning rate schedule out of
|
| 157 |
+
`noop`: no learning rate schedule, just use the initial learning rate,
|
| 158 |
+
`myle`: inv-sqrt as implemented in Fairseq,
|
| 159 |
+
`cosine` cosine annealing schedule,
|
| 160 |
+
`wsd` for Warmup-Stable-Decay (WSD) or tri-stage """
|
| 161 |
+
|
| 162 |
+
lr: float = 0.004
|
| 163 |
+
"""The initial (post-warm-up) learning rate for AdamW."""
|
| 164 |
+
|
| 165 |
+
start_lr: float = 1e-7
|
| 166 |
+
"""The initial warmup learning rate."""
|
| 167 |
+
|
| 168 |
+
final_lr: float = 1e-5
|
| 169 |
+
"""The final learning rate."""
|
| 170 |
+
|
| 171 |
+
lr_stage_ratios: List[float] = field(default_factory=lambda: [0.1, 0.4, 0.5])
|
| 172 |
+
"""The ratios of the wsd (tri-stage) learning rate scheduler."""
|
| 173 |
+
|
| 174 |
+
num_lr_warmup_steps: int = 800
|
| 175 |
+
"""The number of warm-up steps for the learning rate."""
|
| 176 |
+
|
| 177 |
+
weight_decay: float = 0.1
|
| 178 |
+
"""The weight decay coefficient of AdamW (PyTorch default: 1e-2, Fs2 default: 0.0)."""
|
| 179 |
+
|
| 180 |
+
adam_betas: List[float] = field(default_factory=lambda: [0.9, 0.98])
|
| 181 |
+
"""The beta coefficients of AdamW used for computing running averages of gradient and its square."""
|
| 182 |
+
|
| 183 |
+
adam_eps: float = 1e-6
|
| 184 |
+
"""The term added to the denominator in AdamW to improve numerical stability.
|
| 185 |
+
Default in FS2 and PyTorch is 1e-8. Previous hard coded value in our trainer is 1e-6"""
|
| 186 |
+
|
| 187 |
+
use_optimizer_in_fp32: bool = True
|
| 188 |
+
"""if True, the optimizer (AdamW) will be initialized with `use_fp32 = True`
|
| 189 |
+
i.e. we will store the optimizer state in single precision and convert
|
| 190 |
+
gradients on-the-fly to single precision for numerical stability"""
|
| 191 |
+
|
| 192 |
+
max_steps: int = 10_000
|
| 193 |
+
"""The maximum number of training steps."""
|
| 194 |
+
|
| 195 |
+
max_grad_norm: float = 1000
|
| 196 |
+
"""Maximal gradient norm, for gradient clipping.
|
| 197 |
+
gradients are multiplied by `torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)`
|
| 198 |
+
if max_norm is arbitrarily large, then we'll only report gradients norm
|
| 199 |
+
"""
|
| 200 |
+
turn_off_grad_normalization: bool = False
|
| 201 |
+
"""If ``True``, Turn off gradient normalization"""
|
| 202 |
+
|
| 203 |
+
gradient_accumulation: int = 1
|
| 204 |
+
"""The number of steps to accumulate gradients before an optimizer update."""
|
| 205 |
+
|
| 206 |
+
validate_every_n_steps: int = 5000
|
| 207 |
+
"""The number of steps after which to validate the model."""
|
| 208 |
+
|
| 209 |
+
checkpoint_every_n_steps: int = 5000
|
| 210 |
+
"""The number of steps after which to checkpoint."""
|
| 211 |
+
|
| 212 |
+
keep_last_n_checkpoints: int = -1
|
| 213 |
+
"""The number of checkpoints to keep on disk."""
|
| 214 |
+
|
| 215 |
+
save_model_every_n_steps: int = 5000
|
| 216 |
+
"""The number of steps after which to save a consolidated version of the model."""
|
| 217 |
+
|
| 218 |
+
preserve_consolidated_models: bool = False
|
| 219 |
+
"""If `True`, only pt files excluding ones starting with `mdoel` will be deleted from the step checkpoint directory."""
|
| 220 |
+
|
| 221 |
+
publish_metrics_every_n_steps: int = 1
|
| 222 |
+
"""The number of steps after which to publish training metrics."""
|
| 223 |
+
|
| 224 |
+
gc_every_n_steps: int = 1000
|
| 225 |
+
"""The frequency of steps at which we collect garbage with `gc.collect()`."""
|
| 226 |
+
|
| 227 |
+
seed: int = 2
|
| 228 |
+
"""The RNG seed to use while starting the job."""
|
| 229 |
+
|
| 230 |
+
debug: bool = False
|
| 231 |
+
"""If ``True``, runs the trainer in debug mode"""
|
| 232 |
+
|
| 233 |
+
profile: bool = False
|
| 234 |
+
"""If ``True``, runs the PyTorch profiler at the beginning of the training."""
|
| 235 |
+
|
| 236 |
+
profiler_skip_first: int = 200
|
| 237 |
+
|
| 238 |
+
profiler_active: int = 3
|
| 239 |
+
"""If profiling (``profile = True``), The profiler will skip the first ``skip_first`` steps, then do the active recording for the next ``active`` steps
|
| 240 |
+
If planning to visualize the trace with tensorbaord, then ``active`` should be small (less than 10 steps), otherwise tb won't load!
|
| 241 |
+
"""
|
| 242 |
+
loss_scaler_init_scale: float = 2.0**15
|
| 243 |
+
"""The initial scale for the gradient scaler, fairseq2's default is 2.0**15"""
|
| 244 |
+
|
| 245 |
+
loss_scaler_scale_window: Optional[int] = None
|
| 246 |
+
"""The number of consecutive optimizer steps without inf/NaN gradients that must occur for the scale to be updated"""
|
| 247 |
+
|
| 248 |
+
use_fsdp: bool = True
|
| 249 |
+
"""If ``True``, uses FSDP instead of DDP."""
|
| 250 |
+
|
| 251 |
+
use_autocast: bool = False
|
| 252 |
+
"""If ``True``, wrap the forward pass in AMP autocast context.
|
| 253 |
+
autocast is only needed if training with mixed precision.
|
| 254 |
+
If training fails without it, check if some module with its weights is not properly cast
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
fsdp_wrap_granularity: SUPPORTED_FSDP_WRAP_POLICIES = "model"
|
| 258 |
+
"""The granularity at which to wrap the model."""
|
| 259 |
+
|
| 260 |
+
fsdp_memory_policy: SUPPORTED_FSDP_MEMORY_POLICIES = "standard"
|
| 261 |
+
"""The FSDP memory policy."""
|
| 262 |
+
|
| 263 |
+
fsdp_fp32_reduce: bool = False
|
| 264 |
+
""" If ``True``, the gradients will be reduced in full precision even when dtype is `torch.float16`"""
|
| 265 |
+
|
| 266 |
+
use_submitit: bool = True
|
| 267 |
+
"""If ``True``, setup the environment ti use submitit."""
|
| 268 |
+
|
| 269 |
+
fake_gang_device: Optional[str] = None
|
| 270 |
+
"""If non-empty, the trainer will be set locally on a device, instead of distributed training."""
|
| 271 |
+
|
| 272 |
+
experiment_name: Optional[str] = None
|
| 273 |
+
"""experiment name for job trackin, if None default to StopesModule naming"""
|
| 274 |
+
|
| 275 |
+
raise_oom: bool = False
|
| 276 |
+
"""If ``True``, raise OOM errors when they occur, if ``False`` give it another try."""
|
| 277 |
+
|
| 278 |
+
raise_nan_or_inf: bool = False
|
| 279 |
+
"""If ``True``, raise FloatingPointError with Nan/Inf losses, if ``False`` give it another try."""
|
| 280 |
+
|
| 281 |
+
max_ooms: int = 10
|
| 282 |
+
"""If ```raise_oom`` is False, how many OOMs we can tolerate per rank before raising an error."""
|
| 283 |
+
|
| 284 |
+
max_nans_or_infs: int = 10
|
| 285 |
+
"""If ```raise_nan_or_inf`` is False, how many Nan/Infs we can tolerate per rank before raising an error."""
|
| 286 |
+
|
| 287 |
+
freeze_modules: Optional[List[str]] = None
|
| 288 |
+
"""Name of modules in the model to be frozen when training/finetuning"""
|
| 289 |
+
|
| 290 |
+
freezing_strategy: Literal["none", "modules", "ffn", "ffn-adaln", "adaln"] = "none"
|
| 291 |
+
"""
|
| 292 |
+
Freezing strategy to follow. Options are:
|
| 293 |
+
1. none: Nothing will be frozen (default)
|
| 294 |
+
2. modules: A list of modules to freeze will be read from `freeze_modules`
|
| 295 |
+
3. ffn: All ffn sub-modules will be frozen
|
| 296 |
+
4. ffn-adaln: all FFN and Adaln sub-modules will be frozen.
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
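
# A minimal, standalone sketch (hypothetical helper, not part of this file) of
# how the `*_every_n_steps` fields above are consumed: every periodic action
# (checkpointing, validation, GC, metric publishing) reduces to the same
# modulo test, cf. `Trainer._should_do` below.
def _cadence_sketch(step_nr: int, every_n_steps: int) -> bool:
    return step_nr % every_n_steps == 0


def _cadence_demo() -> None:
    # e.g. with checkpoint_every_n_steps=5000: steps 5000, 10000, ... checkpoint
    assert _cadence_sketch(10_000, every_n_steps=5_000)
    assert not _cadence_sketch(10_001, every_n_steps=5_000)
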
class Trainer(StatefulObjectBag):
    config: TrainingConfig
    model: Module
    training_data_loader: Any
    validation_data_loader: Optional[Any]
    gang: Gang
    optimizer: Optimizer
    loss_scaler: DynamicLossScaler
    lr_scheduler: AbstractLRScheduler
    rng_bag: RngBag
    step_nr: int
    train_metric_bag: MetricBag
    valid_metric_bag: Mapping[str, MetricBag]
    metric_recorders: List[MetricRecorder]
    profiler: Profiler
    stopwatch: Stopwatch
    criterion: Any
    card_metadata: Dict
    _train_step_time: float
    _valid_step_time: float

    def __init__(
        self,
        config: TrainingConfig,
        model: Module,
        training_data_loader: Any,
        validation_data_loader: Optional[Any],
        gang: Gang,
        checkpoint_manager: FileCheckpointManager,
        rng_bag: RngBag,
        stopwatch: Stopwatch,
        card_metadata: Dict,
    ) -> None:
        super().__init__()

        self.config = config

        if self.config.debug:
            logger._logger.setLevel(logging.DEBUG)
            os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

        self.card_metadata = card_metadata

        self.dtype = eval(config.dtype)

        self.model = model

        self.training_data_loader = training_data_loader

        # Skip saving and loading the state of the validation dataloader
        self.register_non_stateful("validation_data_loader", validation_data_loader)

        self.gang = gang

        self.rng_bag = rng_bag

        self.step_nr = 1

        self.current_run_steps = 0

        self.checkpoint_manager = checkpoint_manager

        tb_dir = config.tb_dir or config.output_dir.joinpath("tb")

        self.metric_recorders = [LogMetricRecorder(logger)]

        if gang.rank == 0:
            self.metric_recorders.append(TensorBoardRecorder(tb_dir))
            self.metric_recorders.append(
                LCMWandBRecorder(
                    name=config.wandb_run_name,
                    project=config.wandb_project or "uncategorized",
                    output_dir=config.output_dir / "wandb",
                    config=self._tb_flat_config,
                )
            )

        self.profiler = Profiler(
            skip_first=config.profiler_skip_first,
            active=config.profiler_active,
            log_dir=tb_dir,
            gang=gang,
            enabled=config.profile,
        )

        self.stopwatch = stopwatch
        self._train_step_time = 0.0
        self._valid_step_time = 0.0

        self.criterion = None  # type: ignore

        self.loss_scaler = None  # type: ignore

    @property
    def is_fsdp(self) -> bool:
        return isinstance(self.model, FSDP)

    @property
    def is_ddp(self) -> bool:
        return isinstance(self.model, DDP)

    def setup(self) -> None:
        self.criterion = self.setup_criterion()

        self.setup_metric_bags()

        # Add the grad_norm metric to the training metric bag
        self.train_metric_bag.register_metric(
            "grad_norm", Mean(device=self.gang.device), persistent=False
        )
        self.train_metric_bag.register_metric(
            "raw_grad_norm", Mean(device=self.gang.device), persistent=False
        )

        self.setup_optimizer_and_lr_schedule()

    def setup_optimizer_and_lr_schedule(self):
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.lr,
            betas=tuple(self.config.adam_betas),  # type: ignore
            eps=self.config.adam_eps,
            use_fp32=self.config.use_optimizer_in_fp32,
            weight_decay=self.config.weight_decay,
        )
        logger.info(
            (
                f"Setting up AdamW optimizer with betas={self.config.adam_betas}, "
                f"base lr={self.config.lr} and weight decay={self.config.weight_decay} "
                f"and use_fp32={self.config.use_optimizer_in_fp32}"
            )
        )

        self.register_stateful("optimizer", optimizer)

        self.loss_scaler = DynamicLossScaler(
            optimizer,
            gang=self.gang,
            init_scale=self.config.loss_scaler_init_scale,
            min_scale=0.0001,
            scale_window=self.config.loss_scaler_scale_window,
            enabled=self.dtype == torch.float16,
        )

        if self.loss_scaler.is_enabled:
            logger.info(
                f"Initializing DynamicLossScaler with init_scale={self.config.loss_scaler_init_scale}"
            )

        lr_scheduler = build_lr_scheduler(
            optimizer=self.optimizer,
            schedule=self.config.lr_schedule,
            lr=self.config.lr,
            warmup_steps=self.config.num_lr_warmup_steps,
            start_lr=self.config.start_lr,
            final_lr=self.config.final_lr,
            max_steps=self.config.max_steps,
            stage_ratio=tuple(self.config.lr_stage_ratios),
        )

        # Save the lr_scheduler state as well to properly resume training
        self.register_stateful("lr_scheduler", lr_scheduler)
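
    # A minimal sketch (plain PyTorch; an assumption, since fairseq2's
    # DynamicLossScaler above wraps equivalent mechanics): fp16 dynamic loss
    # scaling scales the loss before backward, unscales before clipping, and
    # skips optimizer steps whose gradients contain inf/NaN.
    @staticmethod
    def _loss_scaling_sketch(model, optimizer, loss):
        scaler = torch.cuda.amp.GradScaler(init_scale=2.0**15)
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.unscale_(optimizer)  # unscale before gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)  # skipped if inf/NaN gradients are found
        scaler.update()  # grow/shrink the scale based on recent overflows
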

    @abstractmethod
    def setup_criterion(self):
        """Define a criterion (loss / objective function to optimize)"""

    def setup_metric_bags(self):
        """Set up metric bags for tracking"""

        self.train_metric_bag = MetricBag(self.gang)

        self.register_non_stateful(
            "valid_metric_bag",
            {
                ds_name(dataset): MetricBag(self.gang)
                for dataset in self.config.validation_data
            },
        )

    def checkpoint_and_raise(self, exc) -> None:
        # Checkpoint before exiting
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        logger.warning(f"R{self.gang.rank} checkpoint_and_raise - error={exc}")
        if self.current_run_steps > 100:
            # Avoid checkpointing for early failures
            self._checkpoint(crash=exc)
        raise exc

    @cached_property
    def _tb_flat_config(self):
        """
        Prepare the flat config that will be used as HParams
        to record training metadata, namely config and environment hashes.
        """

        dict_config = flatten_dict(asdict(self.config))

        # Merge the data lists:
        def get_data_signature(dataset):
            return ":".join(
                map(str, (dataset["name"], dataset["weight"], dataset["filters"]))
            )

        dict_config["training_data"] = "+".join(
            get_data_signature(dataset) for dataset in dict_config["training_data"]
        )
        dict_config["validation_data"] = "+".join(
            get_data_signature(dataset) for dataset in dict_config["validation_data"]
        )

        # Each value should be one of int, float, str, bool, or torch.Tensor
        allowed_types = (int, float, str, bool, torch.Tensor)
        config_keys = list(dict_config)
        for k in config_keys:
            if not isinstance(dict_config[k], allowed_types):
                del dict_config[k]

        return dict_config

    def run(self) -> None:
        """Run the trainer for up to `max_steps`"""

        logger.info(f"Running training on {self.gang.size} device(s).")

        data_iter = self.training_data_loader.iterate_batches()

        logger.info(
            f"R{self.gang.rank} - waiting for all ranks to prepare a data iterator!"
        )
        self.gang.barrier()

        # These counters are rank-specific
        ooms, nans_or_infs = 0, 0

        # TODO: validate before training
        # logger.info(f"Starting with validation at step={self.step_nr}")
        # self._validate()

        with self.profiler:
            while self.step_nr <= self.config.max_steps:
                with record_function(f"step_{self.step_nr}"):
                    try:
                        # Main training step: forward -> backward -> optimizer.step -> log
                        stepped = self._train_step(data_iter)

                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            self._log_oom(e)
                            ooms += 1
                            if self.config.raise_oom or ooms > self.config.max_ooms:
                                # Previous behaviour: no retries, but still checkpointing
                                self.checkpoint_and_raise(e)

                            logger.warning(
                                f"Attempting to recover from OOM on R{self.gang.rank} (OOMs={ooms})"
                            )
                            stepped = True
                            # Reset the optimizer
                            self.optimizer.zero_grad(set_to_none=True)

                            # Roll back updates
                            self.train_metric_bag.rollback_updates()

                            # Empty the CUDA cache before trying again
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                        else:
                            # Other RuntimeErrors
                            self.checkpoint_and_raise(e)

                    except FloatingPointError as e:
                        if "Losses are Nan/Inf" in str(e):
                            self._log_nan_loss(e)
                            nans_or_infs += 1
                            if (
                                self.config.raise_nan_or_inf
                                or nans_or_infs > self.config.max_nans_or_infs
                            ):
                                self.checkpoint_and_raise(e)

                            logger.warning(
                                f"Attempting to recover from NaN/Inf loss on R{self.gang.rank} (NaNs/Infs={nans_or_infs})"
                            )
                            stepped = True
                            # Reset the optimizer
                            self.optimizer.zero_grad(set_to_none=True)

                            # Roll back updates
                            self.train_metric_bag.rollback_updates()

                        else:
                            # Other FloatingPointErrors
                            self.checkpoint_and_raise(e)

                    except Exception as e:
                        self.checkpoint_and_raise(e)

                if stepped:
                    if self._should_publish_train_metrics():
                        self._publish_train_metrics()

                    if self._should_checkpoint():
                        self._checkpoint()

                    if self._should_validate():
                        self._validate()

                    if self._should_collect_garbage():
                        self._collect_garbage()

                    self.profiler.step()

                    self.step_nr += 1
                    self.current_run_steps += 1

                else:
                    logger.info(f"R{self.gang.rank} - Resetting the datapipeline")
                    self.training_data_loader.pipeline.reset()

                    logger.info(f"R{self.gang.rank} - Done resetting the datapipeline")
                    data_iter = self.training_data_loader.iterate_batches()

        self._save_model_card_for_last_checkpoint(to_checkpoint_dir=False)
        logger.info(f"Finished training after {self.step_nr - 1} step(s).")

        self.gang.close()

    def restore(self) -> None:
        logger.info("Attempting to load the last checkpoint.")

        step_nr, checkpoint = self.checkpoint_manager.load_last_checkpoint()

        logger.info(f"Checkpoint loaded, restoring training from step {step_nr}.")

        self.load_state_dict(checkpoint)

        self.gang.barrier()

        logger.info("Training restored, resuming.")

        self.step_nr = step_nr + 1

    def _maybe_with_autocast(self) -> ContextManager[None]:
        # autocast is only needed if training with mixed precision.
        # If training fails without it, check whether some module and its
        # weights are not properly cast.
        if self.config.use_autocast:
            return torch.autocast(device_type="cuda", dtype=self.dtype)
        else:
            return nullcontext()
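
    # A minimal sketch (plain PyTorch, names hypothetical) of the autocast
    # contract referenced above: wrap only the forward pass and the loss
    # computation in autocast; run the backward pass outside of it.
    @staticmethod
    def _autocast_sketch(model, batch, dtype=torch.bfloat16):
        with torch.autocast(device_type="cuda", dtype=dtype):
            loss = model(batch).mean()  # forward + loss under autocast
        loss.backward()  # backward outside the autocast context
        return loss
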

    def _train_step(self, data_iter: Iterator) -> bool:
        step_nr = self.step_nr

        step_stopwatch = Stopwatch(start=True, device=self.gang.device)

        stepped = False

        # We have to retry the step in case of a gradient overflow.
        while not stepped:
            batches = []

            # Collect batches.
            with record_function(f"step_{step_nr}_data_load"):
                for _ in range(self.config.gradient_accumulation):
                    try:
                        batches.append(next(data_iter))
                    except StopIteration:
                        break

                if len(batches) != self.config.gradient_accumulation:
                    logger.info(
                        f"R{self.gang.rank} - End of data reached at training step {step_nr}."
                    )

                    return False

            # Create a copy of the current metrics.
            # Any update to the metrics from this point will either be
            # committed with `commit_updates` or ignored with `rollback_updates`.
            self.train_metric_bag.begin_updates()

            num_targets = 0

            # Accumulate gradients.
            for batch_nr, batch in enumerate(batches):
                with self._maybe_no_sync(batch_nr, len(batches)):
                    with record_function(f"step_{step_nr}_{batch_nr}_forward"):
                        # autocast should wrap only the forward pass(es)
                        # of your network, including the loss computation(s).
                        # Backward passes under autocast are not recommended.
                        with self._maybe_with_autocast():
                            loss = self.criterion(batch)

                    if not (
                        torch.isfinite(loss.value).all() or self.loss_scaler.is_enabled
                    ):
                        raise FloatingPointError("Losses are Nan/Inf.")

                    # Update metrics
                    self.train_metric_bag.update([loss])

                    with record_function(f"step_{step_nr}_{batch_nr}_backward"):
                        self.loss_scaler.backward(loss.value)

                    num_targets += loss.num_target_elements

            # Record and clip the gradient norm
            grad_norm, raw_grad_norm = self.process_gradients(step_nr, num_targets)

            # Update parameters.
            with record_function(f"step_{step_nr}_optimizer"):
                # scale_result: LossScaleResult(old_scale: float, new_scale: float, overflow: bool, min_reached: bool)
                _, scale_result = self.loss_scaler.run_optimizer_step(step_nr)

            if scale_result.overflow:
                # Walk back the metrics update:
                self.train_metric_bag.rollback_updates()
                logger.debug(
                    f"R{self.gang.rank} rolled back update {self.train_metric_bag._original_metrics is None}"
                )

                if scale_result.min_reached:
                    logger.error(f"Loss has started exploding at step {step_nr}. Stopping training.")  # fmt: skip

                    raise FloatingPointError("The training loss has exploded.")

                logger.debug(f"Repeating training step {step_nr}.")

            else:
                self.lr_scheduler.step()

                stepped = True

            # Reset.
            self.optimizer.zero_grad(set_to_none=True)

        # Stepped = True:
        with record_function(f"step_{step_nr}_metrics"):
            # Do something with the losses and grad_norm

            self.train_metric_bag.commit_updates()

            # The gradient norm is common to all workers
            self.train_metric_bag.grad_norm.update(grad_norm)
            self.train_metric_bag.raw_grad_norm.update(raw_grad_norm)

            if self.gang.rank == 0:
                # Update the elapsed time once
                self._train_step_time += step_stopwatch.get_elapsed_time()

        del batches
        return stepped

    def _maybe_no_sync(self, batch_nr: int, num_batches: int) -> ContextManager[None]:
        if batch_nr < num_batches - 1 and self.gang.size > 1:
            return self.model.no_sync()
        return nullcontext()

    def normalize_gradients(self, num_targets: int) -> None:
        """
        :param num_targets:
            The number of targets used in loss computation in this process.

        If reduction = sum:
            similar to fairseq2's `normalize_gradients`, will normalize the
            gradients of the model by ``world_size/num_targets``.
        If reduction = mean:
            will simply multiply by the world size, i.e. undo DDP/FSDP's
            default normalization.
        """
        reduction = self.criterion.reduction
        if reduction == "sum":
            total_num_targets = torch.tensor(
                num_targets, device=self.gang.device, dtype=torch.int64
            )

            self.gang.all_reduce(total_num_targets, ReduceOperation.SUM)

            # Both DDP and FSDP divide gradients by the world size, which we also undo.
            if total_num_targets > 0:
                grad_scale = self.gang.size / total_num_targets
            else:
                # If total_num_targets == 0, gradients will be zeroes anyway
                grad_scale = self.gang.size

        else:
            grad_scale = self.gang.size

        scale_gradients(self.model, grad_scale)
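
    # A minimal sketch (plain PyTorch DDP; an illustration, not this trainer's
    # code path) combining the two ideas above: `no_sync()` suppresses the
    # gradient all-reduce on all but the last micro-batch, and with
    # reduction="sum" the default mean over ranks is undone by rescaling
    # gradients with world_size / total_num_targets.
    @staticmethod
    def _grad_accumulation_sketch(ddp_model, batches, world_size, total_num_targets):
        for i, batch in enumerate(batches):
            ctx = ddp_model.no_sync() if i < len(batches) - 1 else nullcontext()
            with ctx:
                ddp_model(batch).sum().backward()

        # Turn DDP's mean-over-ranks into a global per-target normalization:
        grad_scale = world_size / max(total_num_targets, 1)
        for p in ddp_model.parameters():
            if p.grad is not None:
                p.grad.mul_(grad_scale)
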

    def process_gradients(
        self, step_nr: int, num_targets: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize and clip the gradients."""
        with record_function(f"step_{self.step_nr}_process_grads"):
            # This raw grad norm is only used for debugging
            raw_grad_norm = clip_gradient_norm(
                self.model,
                max_norm=None,
            )

            if not self.config.turn_off_grad_normalization:
                self.normalize_gradients(num_targets=num_targets)

            # Undo the GradScaler's scaling before clipping
            self.loss_scaler.unscale_gradients_()

            # Clip gradients.
            # If DDP, we use torch.nn.utils.clip_grad_norm_; if FSDP, we use
            # torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_.
            # The latter handles the fact that gradients might be sharded across ranks.
            grad_norm = clip_gradient_norm(
                self.model,
                max_norm=self.config.max_grad_norm,
            )

            # Check for gradient consistency across workers:
            if not check_gradient_norms(grad_norm, self.gang, step_nr):
                raise FloatingPointError(
                    f"The gradients are inconsistent between processes at step {step_nr}. Training cannot continue."
                )

        return grad_norm, raw_grad_norm
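
    # A plausible shape (an assumption; not the actual source of
    # `check_gradient_norms`) for the cross-rank consistency check above:
    # gather every rank's clipped norm and verify they (nearly) agree.
    @staticmethod
    def _check_grad_norms_sketch(grad_norm: torch.Tensor) -> bool:
        import torch.distributed as dist

        norms = [torch.zeros_like(grad_norm) for _ in range(dist.get_world_size())]
        dist.all_gather(norms, grad_norm)
        return all(torch.allclose(n, norms[0], rtol=1e-4) for n in norms)
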

    def _should_validate(self) -> bool:
        return self._should_do(self.config.validate_every_n_steps)

    def _should_collect_garbage(self) -> bool:
        return self._should_do(self.config.gc_every_n_steps)

    def _collect_garbage(self):
        logger.info("Collecting garbage...")
        gc.collect()

    @torch.inference_mode()
    def _validate(self) -> None:
        gc.collect()
        torch.cuda.empty_cache()

        if self.validation_data_loader is None:
            logger.info("Skipping validation as the data loader is empty")
            return

        self.model.eval()

        logger.info(f"Starting validation after step {self.step_nr}.")

        self.validation_data_loader.pipeline.reset()

        data_iter = self.validation_data_loader.iterate_batches()
        data_dummy_iter = self.validation_data_loader.iterate_dummy_batches()

        logger.info(f"R{self.gang.rank} done creating the validation data iterator")

        for step_nr in count(start=1):
            step_stopwatch = Stopwatch(start=True, device=self.gang.device)

            try:
                batch = next(data_iter)
                true_batch = 1
            except StopIteration:
                batch = next(data_dummy_iter)
                true_batch = 0

            total_nb_batches = all_sum(self.gang, true_batch)

            if bool(total_nb_batches == 0):
                break
            # We apply the model on all workers to avoid process group sync issues
            loss = self.criterion(batch)

            if true_batch:
                self._valid_step_time += step_stopwatch.get_elapsed_time()
                self.valid_metric_bag[batch.name].update([loss])

        self._publish_validation_metrics()

        logger.info(
            f"R{self.gang.rank} Validation complete in {step_nr} steps, resuming training."
        )

        self.model.train()

    def _should_publish_train_metrics(self) -> bool:
        return self._should_do(self.config.publish_metrics_every_n_steps)

    def _set_elements_per_second(
        self, metric_values: Dict[str, Any], elapsed_time: float
    ) -> None:
        try:
            num_elements = metric_values[self.criterion.throughput_metric_name]
        except KeyError:
            return

        if not isinstance(num_elements, (int, float, torch.Tensor)):
            return

        if elapsed_time == 0.0:
            metric_values["elements_per_second"] = 0.0
        else:
            metric_values["elements_per_second"] = num_elements / elapsed_time

    def _publish_train_metrics(self) -> None:
        values = self.train_metric_bag.sync_and_compute_metrics()

        self.train_metric_bag.reset_non_persistent_metrics()

        # Only rank 0 records and publishes,
        # since sync_and_compute_metrics's recipient rank is 0
        if self.gang.rank != 0:
            return

        assert values is not None

        values["lr"] = get_effective_lr(self.lr_scheduler)

        self._set_elements_per_second(values, self._train_step_time)

        if self.loss_scaler.is_enabled:
            values["grad_scale"] = self.loss_scaler.get_scale()

        values["wall_time"] = self.stopwatch.get_elapsed_time()
        values["elapsed_time"] = self._train_step_time

        record_metrics(self.metric_recorders, "Train", values, self.step_nr)

        self._train_step_time = 0.0

    def _publish_validation_metrics(self) -> None:
        values = {}
        for name, metric_bag in self.valid_metric_bag.items():
            values[name] = metric_bag.sync_and_compute_metrics()
            metric_bag.reset_non_persistent_metrics()

        # Only rank 0 records and publishes
        if self.gang.rank != 0:
            return

        for name, val in values.items():
            assert val is not None
            self._set_elements_per_second(val, self._valid_step_time)
            val["elapsed_time"] = self._valid_step_time
            val["wall_time"] = self.stopwatch.get_elapsed_time()
            valid_name = f"Valid | {name}"
            record_metrics(self.metric_recorders, valid_name, val, self.step_nr)

        # Reset timers
        self._valid_step_time = 0.0

    def _should_checkpoint(self) -> bool:
        return self._should_do(self.config.checkpoint_every_n_steps)

    def _should_save_consolidated_model(self) -> bool:
        return self.is_fsdp and self._should_do(self.config.save_model_every_n_steps)

    def _checkpoint(self, crash=None) -> None:
        logger.info(f"Saving checkpoint at step {self.step_nr}")
        checkpoint = self.state_dict()

        metadata = {
            "config": self.config,
            "crash": crash,
        }

        self.checkpoint_manager.begin_checkpoint(self.step_nr)

        if self.is_fsdp:
            replicated_keys = None
        elif self.is_ddp:
            # If we do not shard, save the model and the optimizer only on rank 0.
            replicated_keys = {"model", "optimizer"}
        else:
            replicated_keys = {"*"}

        self.checkpoint_manager.save_state(checkpoint, replicated_keys=replicated_keys)

        self.checkpoint_manager.save_metadata(metadata)

        if self._should_save_consolidated_model():
            self._save_consolidated_model()

        # Create a model card only after creating model.pt,
        # i.e., regular checkpointing with DDP or after consolidation with FSDP
        if not self.is_fsdp:
            self._save_model_card_for_last_checkpoint(to_checkpoint_dir=True)

        self.checkpoint_manager.commit_checkpoint()

        # Note that this logic looks at saved directories regardless of
        # the nature of the checkpointing, consolidated or not
        if self.config.keep_last_n_checkpoints != -1:
            self.checkpoint_manager.keep_last_n_checkpoints(
                self.config.keep_last_n_checkpoints,
                preserve_model=self.config.preserve_consolidated_models,
            )

        logger.info(f"Checkpoint saved by worker @rank={self.gang.rank}")

    def _save_consolidated_model(self) -> None:
        logger.info(f"Saving consolidated model at step {self.step_nr}.")
        self.checkpoint_manager.save_consolidated_fsdp_model(self.model)
        self._save_model_card_for_last_checkpoint(to_checkpoint_dir=True)
        logger.info("Consolidated model saved.")

    def _should_do(self, n_step: int) -> bool:
        return self.step_nr % n_step == 0

    def create_model_card_for_last_checkpoint(
        self, is_final: bool = False, **card_kwargs
    ) -> Optional[AssetCard]:
        """Create a model card based on the last saved checkpoint and the model config."""
        logger.warning(
            "Could not create a model card with a generic trainer. Please use a model-specific one."
        )
        return None

    def _save_model_card_for_last_checkpoint(
        self, to_checkpoint_dir: bool = False
    ) -> None:
        """Save the model card for the last checkpoint to the checkpoint directory or the core output directory."""
        if self.gang.rank != 0:
            return

        if to_checkpoint_dir:
            current_step_nr = self.checkpoint_manager._checkpoint_step_nr
            output_dir = self.checkpoint_manager._checkpoint_dir.joinpath(
                f"step_{current_step_nr}.tmp"
            )
        else:
            output_dir = self.config.output_dir

        card = self.create_model_card_for_last_checkpoint(
            is_final=not to_checkpoint_dir
        )

        if card is not None:
            card_data = card._metadata  # TODO: use the exposed attribute when available
            with open(output_dir / "model_card.yaml", "w", encoding="utf-8") as outfile:
                yaml.dump(card_data, outfile, default_flow_style=False)
            logger.info(f"Model card saved in {output_dir}")

    def _log_oom(self, exc):
        logger.warning(
            f"OOM: Ran out of memory on R{self.gang.rank} with exception: {exc}"
        )

        if torch.cuda.is_available():
            for device_idx in range(torch.cuda.device_count()):
                logger.warning(torch.cuda.memory_summary(device=device_idx))

        sys.stderr.flush()

    def _log_nan_loss(self, exc):
        logger.warning(f"We hit a NaN/Inf loss: raised with exception: {exc}")
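
# A minimal sketch (assuming torch.distributed; `all_sum` above performs the
# equivalent reduction over the gang) of why exhausted ranks keep feeding
# dummy batches during validation: every rank must agree, via a collective,
# on when validation is over, otherwise the process group deadlocks.
def _validation_sync_sketch(true_batch: int) -> bool:
    import torch.distributed as dist

    flag = torch.tensor(true_batch)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # count real batches this round
    return bool(flag.item() == 0)  # True => all ranks may stop validating
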

class TrainerBuilder:
    def __init__(self, config: TrainingConfig):
        assert config.save_model_every_n_steps % config.checkpoint_every_n_steps == 0, (
            f"save_model_every_n_steps={config.save_model_every_n_steps} for saving consolidated models should be a multiple of checkpoint_every_n_steps={config.checkpoint_every_n_steps}"
        )

        self.config = config

        self.stopwatch = Stopwatch(start=True)

        # In case we train on Ampere or later, use TF32.
        torch.set_float32_matmul_precision("high")

        if self.config.fake_gang_device is None:
            # By default, we work with a process group
            self.gang = init_process_group(config, logger=logger._logger)
        else:
            # For testing purposes, we use a fake gang on the chosen device
            self.gang = FakeGang(device=torch.device(self.config.fake_gang_device))

        self.gang_rank = self.gang.rank if self.gang else 0

        if self.gang.device.type == "cuda":
            # Set up ATEN and NCCL logging if in debug mode
            self._setup_additional_logging()

            # Dump environment variables:
            log_env_variables(self.gang.device)

        # A variable to carry fields necessary to build concise model cards
        self.card_metadata: Dict = {}

        if self.gang_rank == 0:
            logger.info(f"Job Config\n{pformat(config)}")

        self.device = self.gang.device

        rng_bag = RngBag.from_device_defaults(self.device)

        # Ensure that each run has deterministic behavior.
        rng_bag.manual_seed(config.seed)

        self.rng_bag = rng_bag

        self.dtype = eval(config.dtype)

        self.finetune: bool = False

        self.has_checkpoint: bool = False

    @property
    @abstractmethod
    def model_loader(self):
        """A fairseq2 ModelLoader"""

    @property
    def model_config_loader(self):
        """A fairseq2 ConfigLoader"""
        return self.model_loader._config_loader

    @abstractmethod
    def load_data(self):
        """Load training and validation data.
        Returns one loader for training data and one for validation data.
        """

    def create_model_config(self, set_finetune_flag: bool = False):
        """
        Given `model_config_or_name`, `model_arch` and `model_arch_overrides`,
        create the model config dict.
        If `set_finetune_flag` is `True`, then the trainer's finetune flag will
        be set here, inferred from the use of `model_config_or_name`.
        """
        if self.config.model_config_or_name is not None:
            assert self.config.model_arch is None, (
                "We cannot set both `model_config_or_name` and `model_arch`"
            )

            if isinstance(self.config.model_config_or_name, str):
                # The config of a registered model, i.e. we're finetuning
                logger.info(
                    f"Loading pretrained model from {self.config.model_config_or_name}"
                )

                model_config = self.model_config_loader(
                    self.config.model_config_or_name
                )
                finetune = True

                # Metadata for card creation
                source_card = self.model_config_loader._asset_store.retrieve_card(
                    self.config.model_config_or_name
                )
                try:
                    arch = source_card.field("model_arch").as_(str)
                except AssetCardFieldNotFoundError:
                    arch = None

                self.card_metadata = {
                    "model_config": model_config if arch is None else None,
                    "model_type": model_config.model_type,
                    "model_arch": arch,
                }

            else:
                # model_config_or_name is a dataclass
                logger.info(
                    "Creating a model from the provided config in model_config_or_name"
                )
                model_config = self.config.model_config_or_name

                self.card_metadata = {
                    "model_config": model_config,
                    "model_type": model_config.model_type,
                    "model_arch": None,
                }

                finetune = False

        elif self.config.model_arch is not None:
            assert (
                self.config.model_arch in self.model_config_loader._arch_configs.names()
            ), (
                f"Could not recognize {self.config.model_arch} as a registered architecture"
            )

            logger.info(
                f"Creating a model from registered arch {self.config.model_arch}"
            )

            finetune = False
            model_config = self.model_config_loader._arch_configs.get(
                self.config.model_arch
            )
            self.card_metadata = {
                "model_config": None,
                "model_type": model_config.model_type,
                "model_arch": self.config.model_arch,
            }

        # In all setups we can override some config parameters
        if self.config.model_arch_overrides is not None:
            try:
                update_dataclass(model_config, self.config.model_arch_overrides)

            except (TypeError, ValueError) as ex:
                raise ValueError(
                    "The model_arch_overrides contain one or more invalid keys"
                ) from ex

            self.card_metadata["model_arch"] = None
            self.card_metadata["model_config"] = model_config

            logger.info(
                f"Overwriting model config parameters with {self.config.model_arch_overrides}"
            )

        if set_finetune_flag:
            self.finetune = finetune

        return model_config

    def create_model(self):
        """
        Load the model to be trained.
        In case other models are developed following a different paradigm, we
        can create corresponding trainers by overriding `create_model`.
        """
        logger.info("Initializing model.")

        model_config = self.create_model_config(set_finetune_flag=True)

        if self.gang_rank == 0:
            logger.info(f"Final model config:\n{pformat(model_config)}")

        model = self.model_loader._factory(
            model_config,
            device=self.device,
            dtype=self.dtype,
        )
        # Log the model before any wrapping:
        log_model(model, logger)

        return model

    def wrap_model_with_ddp(self, model) -> DDP:
        """Wrap the model with DDP"""

        try:
            ddp_model = to_ddp(
                model,
                self.gang,
            )

        except ValueError:
            logger.warning(
                "Using PyTorch DDP instead of fairseq2's `to_ddp` - "
                "please check fairseq2 after a3de79dcc6a4ea34cde644e15b4056f1a808a6a8"
            )

            ddp_model = DDP(model)

        if self.gang_rank == 0:
            log_model(ddp_model, logger)

        return ddp_model

    def wrap_model_with_fsdp(self, model) -> FSDP:
        """Wrap the model with FSDP."""

        wrap_policy, ignored_modules = get_fsdp_wrap_policy(
            model, wrap_granularity=self.config.fsdp_wrap_granularity
        )
        memory_policy = get_fsdp_memory_policy(policy=self.config.fsdp_memory_policy)

        if self.dtype == torch.float32:
            mixed_precision_dtype = None
        else:
            mixed_precision_dtype = self.dtype

        skip_init = False
        broadcast_state = self.finetune and not self.has_checkpoint
        fp32_reduce = self.config.fsdp_fp32_reduce

        if self.gang.rank == 0:
            logger.info(
                (
                    f"FSDP init with: \n--- ignored_modules={ignored_modules}"
                    f"\n--- wrap_policy={wrap_policy}"
                    f"\n--- mixed_precision_dtype={mixed_precision_dtype}"
                    f"\n--- skip_init={skip_init}"
                    f"\n--- broadcast_state (FSDP's sync_module_states)={broadcast_state}"
                    f"\n--- fp32_reduce={fp32_reduce}"
                    f"\n--- memory_policy={memory_policy}"
                )
            )

        fsdp_model = to_fsdp(
            model,
            self.gang,
            wrap_policy,
            mixed_precision_dtype=mixed_precision_dtype,
            ignored_modules=ignored_modules,
            fp32_reduce=fp32_reduce,
            skip_init=skip_init,
            broadcast_state=broadcast_state,
            memory_policy=memory_policy,
        )

        if self.gang_rank == 0:
            log_model(fsdp_model, logger)

        return fsdp_model

    def maybe_load_model(self, model):
        """
        If we are finetuning and we don't have a checkpoint,
        load the pre-trained model and broadcast it to
        all gang processes from rank 0.
        """
        if not self.has_checkpoint and self.finetune:
            logger.info(f"Loading for finetuning: {self.config.model_config_or_name}")

            if self.gang_rank == 0:
                pretrained_model = self.model_loader(
                    model_name_or_card=self.config.model_config_or_name,
                    device=self.gang.device,
                    dtype=self.dtype,
                )  # type: ignore[arg-type]

                try:
                    model.load_state_dict(
                        pretrained_model.state_dict(),
                        strict=True,
                        assign=False,
                    )
                except (KeyError, ValueError) as ex:
                    raise ValueError(
                        f"The model state from {self.config.model_config_or_name} "
                        "cannot be loaded. See the nested exception for details."
                    ) from ex

            self.gang.barrier()

            to_device(model, self.gang.device)

            logger.info(
                f"Done loading model for finetuning: {self.config.model_config_or_name}"
            )

        return model

    def maybe_freeze_parameters(self, model):
        assert (self.config.freezing_strategy == "modules") == (
            self.config.freeze_modules is not None
        ), (
            "For the `modules` freezing_strategy, we need a list of `freeze_modules`. "
            "If `freeze_modules` is provided, make sure to use freezing_strategy=modules"
        )

        if self.config.freezing_strategy == "none":
            return model

        if self.config.freezing_strategy == "modules":
            # Optionally freeze the parameters of sub-modules:
            if self.config.freeze_modules is not None:
                for module in self.config.freeze_modules:
                    logger.info(f"... Freezing module={module}")
                    freeze_parameters(getattr(model, module))
            return model

        if self.config.freezing_strategy == "ffn":
            for name, m in _get_named_modules(model):
                if "ffn" in name:
                    logger.info(f"... Freezing module={name}")
                    freeze_parameters(m)
            return model

        if self.config.freezing_strategy == "adaln":
            for name, m in _get_named_modules(model):
                if "modulator" in name:
                    logger.info(f"... Freezing module={name}")
                    freeze_parameters(m)
            return model

        if self.config.freezing_strategy == "ffn-adaln":
            for name, m in _get_named_modules(model):
                if "modulator" in name or "ffn" in name:
                    logger.info(f"... Freezing module={name}")
                    freeze_parameters(m)
            return model

        raise ValueError(f"Unknown freezing strategy {self.config.freezing_strategy}")

    def _setup_additional_logging(self):
        if self.config.debug:
            assert self.config.log_folder is not None, (
                "Missing log_folder; "
                "make sure the log_folder is properly set in the training config"
            )
            setup_additional_logging(log_folder=self.config.log_folder)

    @property
    def use_fsdp(self) -> bool:
        return self.config.use_fsdp

    @property
    def use_ddp(self) -> bool:
        """
        Whether DDP should be used.
        If self.gang.size == 1: single worker, no parallelism.
        If use_fsdp: use FSDP instead.
        """
        return not (self.gang.size == 1 or self.use_fsdp)

    @abstractmethod
    def build_trainer(self):
        """Build the trainer by loading data and
        setting up the model for training.

        Returns a trainer.
        """
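The freezing strategies in `maybe_freeze_parameters` above all reduce to disabling gradients on sub-modules matched by name. A minimal standalone sketch (plain PyTorch; `ffn` and `modulator` are the patterns the trainer matches, everything else here is hypothetical):

import torch.nn as nn


def freeze_by_name(model: nn.Module, patterns=("ffn", "modulator")) -> None:
    # Disable gradients for every sub-module whose qualified name matches
    for name, module in model.named_modules():
        if any(p in name for p in patterns):
            for param in module.parameters():
                param.requires_grad_(False)
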
lcm/train/two_tower_diffusion_lcm/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#
lcm/train/two_tower_diffusion_lcm/criterion.py
ADDED
@@ -0,0 +1,404 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from dataclasses import dataclass, field
from typing import List, Tuple

import torch
import torch.nn.functional as F
from fairseq2.logging import get_log_writer
from fairseq2.nn.padding import pad_seqs
from torch import Tensor

from lcm.datasets.batch import EmbeddingsBatch, LCMInput, LCMStyle
from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModel
from lcm.train.criterion import CriterionsFactory
from lcm.train.lcm.criterion import (
    LCMCriterion,
    LCMCriterionConfig,
    compute_standard_mse,
)
from lcm.train.metrics import LossTerm, format_as_float, register_metric_formatter
from lcm.train.step_sampler import StepsSampler, StepsSamplerConfig

logger = get_log_writer(__name__)


@dataclass
class TowerDiffusionLCMCriterionConfig(LCMCriterionConfig):
    cf_guidance_probability: float = 0.0
    """Probability to use classifier-free guidance by dropping the conditioning.
    Note that this requires the model to be set with
    `trained_with_cf_guidance = True`!
    """

    step_sampling: StepsSamplerConfig = field(
        default_factory=lambda: StepsSamplerConfig()
    )

    log_losses_per_timestep_bucket: bool = False


@CriterionsFactory.register("two_tower_diffusion_next_sent")
class TwoTowerDiffusionCriterion(LCMCriterion):
    """Computes the LCM training objective for next-sentence prediction with diffusion"""

    config: TowerDiffusionLCMCriterionConfig
    model: TwoTowerDiffusionLCModel

    def __init__(
        self,
        config: TowerDiffusionLCMCriterionConfig,
        model: TwoTowerDiffusionLCModel,
        style: LCMStyle = LCMStyle.UNSUPERVISED,
    ):
        super().__init__(config, model, style)
        assert hasattr(self.base_model, "noise_scheduler"), (
            "Expecting the diffusion model to have a `noise_scheduler`"
        )
        self.noise_scheduler = self.base_model.noise_scheduler

        self.prediction_type = self.noise_scheduler.prediction_type

        self.trained_with_cf_guidance = self.base_model.config.trained_with_cf_guidance

        self.cf_guidance_probability = config.cf_guidance_probability

        assert (
            bool(self.cf_guidance_probability > 0) == self.trained_with_cf_guidance
        ), (
            "Expecting the config's cf_guidance_probability to align with the model's `trained_with_cf_guidance`. "
            f"Found cf_guidance_probability={config.cf_guidance_probability} and "
            f"trained_with_cf_guidance={self.trained_with_cf_guidance}"
        )

        assert self.normalize_in_criterion, (
            "We only support `normalize_in_criterion = True` in the diffusion criterions"
        )

        self.summands.append("unnormalized_reconstruction_loss")

        if self.config.log_losses_per_timestep_bucket:
            # Customize if needed
            self.step_bucketing_boundaries = torch.linspace(
                0, self.noise_scheduler.num_diffusion_train_steps, 11
            )
            self.step_bucketing_labels: List[str] = []
            for e in range(len(self.step_bucketing_boundaries) - 1):
                bucket_left = self.step_bucketing_boundaries[e]
                bucket_right = self.step_bucketing_boundaries[e + 1]
                self.step_bucketing_labels.append(
                    f"reconstruction_loss_t{bucket_left:.0f}-{bucket_right:.0f}"
                )

            self.summands.extend(self.step_bucketing_labels)
            for label in self.step_bucketing_labels:
                register_metric_formatter(
                    label, label, 1000, format_as_float, overwrite=True
                )

        # Step sampler + loss weighter
        self.step_sampler = StepsSampler(
            config.step_sampling,
            noise_scheduler=self.noise_scheduler,
        )

    def prepare_input_and_mask(
        self,
        batch: LCMInput,
    ) -> Tuple[EmbeddingsBatch, EmbeddingsBatch, torch.Tensor]:
        """
        A method for preparing model inputs and the mask for a batch.
        It will typically be reused by the `__call__`
        implementations of the subclasses.
        Returns:
            - input_batch: context
            - target_batch: denoiser input
            - target_mask: mask of positions to compute the loss over
        """
        # Prepare the input as in the MSE LCM: each sequence is (src, tgt)
        input_embeddings = batch.prepare_input(style=self.style)

        # Normalize the embeddings
        if self.normalize_in_criterion:
            input_embeddings = input_embeddings.normalize_seqs(self.sonar_normalizer)

        target_mask = torch.ones(
            size=input_embeddings.seqs.shape[:-1],
            dtype=torch.bool,
            device=input_embeddings.seqs.device,
        )

        # Factor in padded positions:
        if input_embeddings.padding_mask is not None:
            target_mask &= input_embeddings.padding_mask.materialize()

        return input_embeddings, input_embeddings.clone(), target_mask

    def sample_noisy_input_and_targets(self, input_batch, target_mask):
        """
        (1) Prepares the noised inputs (latents) by sampling diffusion
        timesteps and calling the model's noise_scheduler to add noise
        accordingly.
        (2) Given the scheduler's prediction type, prepares the target that
        the model will be trained to predict.

        :param input_batch: EmbeddingsBatch of the ground truth embeddings with seqs in (B, T, C)
        :param target_mask: Bool tensor in (B, T) where `True` signals that the
            model will be asked to predict the position
        """
        input_seqs, padding_mask = input_batch.seqs, input_batch.padding_mask

        timesteps = self.step_sampler.sample(
            size=input_seqs[..., 0].size(), device=input_seqs.device
        )

        # Sample noise
        noise_seqs = torch.randn_like(input_seqs)

        # Define the target in (B*T, C)
        sonar_dim = input_seqs.size(-1)
        if self.prediction_type == "sample":
            """Predict the clean ground truth embeddings. Default mode"""
            target = input_seqs.view(-1, sonar_dim)

        elif self.prediction_type == "epsilon":
            """Predict the added noise"""
            target = noise_seqs.view(-1, sonar_dim)

        elif self.prediction_type == "v_prediction":
            """Predict an interpolation of the ground truth clean
            embeddings and the added noise.
            As introduced in https://arxiv.org/pdf/2305.08891
            """
            target = self.noise_scheduler.get_velocity(
                input_seqs.view(-1, sonar_dim),
                noise_seqs.view(-1, sonar_dim),
                timesteps.view(-1),
            ).clone()
        else:
            raise ValueError(
                "The prediction type should be one of: sample, epsilon, v_prediction"
            )

        # Add noise.
        # Reshape inputs and noise into (B*T, C) -> add noise -> reshape back to (B, T, C)
        noisy_input_seqs = self.noise_scheduler.add_noise(
            input_seqs.view(-1, sonar_dim),
            noise_seqs.view(-1, sonar_dim),
            timesteps.view(-1),
        ).view(input_seqs.size())

        # Create a sequence batch with diffusion timesteps
        noisy_input_batch = EmbeddingsBatch(
            noisy_input_seqs,
            padding_mask,
            diffusion_timesteps=timesteps,
        )
        return noisy_input_batch, target, target_mask
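
    # A minimal sketch (assuming the `diffusers` library as a stand-in; the
    # in-house noise_scheduler above exposes the same `add_noise` /
    # `get_velocity` surface) of the three prediction types handled above.
    @staticmethod
    def _prediction_target_sketch():
        from diffusers import DDPMScheduler

        scheduler = DDPMScheduler(num_train_timesteps=1000)
        x0 = torch.randn(8, 1024)  # clean embeddings (N, C)
        noise = torch.randn_like(x0)
        t = torch.randint(0, 1000, (8,))
        noisy = scheduler.add_noise(x0, noise, t)  # the denoiser's input
        targets = {
            "sample": x0,  # predict the clean embeddings (default)
            "epsilon": noise,  # predict the added noise
            "v_prediction": scheduler.get_velocity(x0, noise, t),
        }
        return noisy, targets
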

    def compute_loss(
        self, flattened_predictions, flattened_target
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Parameters:
            flattened_predictions (Tensor): The predictions in (N, C)
            flattened_target (Tensor): The targets in (N, C)

        Returns:
            reconstruction_loss (Tensor): The reconstruction loss we want to optimize (RMSE, SmoothL1, Huber, etc.).
            plain_reconstruction_loss (Tensor): Plain MSE loss (RMSE if ``compute_rmse``).
            unnormalized_reconstruction_loss (Tensor): The same loss computed between unnormalized features.
        """
        reconstruction_loss, plain_reconstruction_loss = compute_standard_mse(
            flattened_predictions,
            flattened_target,
        )

        unnormalized_reconstruction_loss, _ = compute_standard_mse(
            flattened_predictions,
            flattened_target,
            normalizer=self.sonar_normalizer,
        )
        # For backward compatibility with ongoing runs, take the sqrt
        if self.config.compute_rmse:
            epsilon = 1e-5
            reconstruction_loss = torch.sqrt(reconstruction_loss + epsilon)
            plain_reconstruction_loss = torch.sqrt(plain_reconstruction_loss + epsilon)
            unnormalized_reconstruction_loss = torch.sqrt(
                unnormalized_reconstruction_loss + epsilon
            )

        return (
            reconstruction_loss,
            plain_reconstruction_loss,
            unnormalized_reconstruction_loss,
        )
@torch.no_grad()
|
| 241 |
+
def _log_losses_per_step(self, batch_steps, reconstruction_loss):
|
| 242 |
+
# Aggregate loss terms based on their bucket of diffusion steps for tracking
|
| 243 |
+
summands = {}
|
| 244 |
+
if self.config.log_losses_per_timestep_bucket:
|
| 245 |
+
# Reconstruction_loss in BT,
|
| 246 |
+
# batch_steps in BT,
|
| 247 |
+
bucket_index = torch.bucketize(
|
| 248 |
+
batch_steps, self.step_bucketing_boundaries.to(batch_steps.device)
|
| 249 |
+
)
|
| 250 |
+
onehot = F.one_hot(
|
| 251 |
+
bucket_index,
|
| 252 |
+
num_classes=self.step_bucketing_boundaries.numel(),
|
| 253 |
+
)
|
| 254 |
+
loss_per_step = torch.matmul(onehot.t().float(), reconstruction_loss)
|
| 255 |
+
count_steps = onehot.sum(dim=0) + 1e-6
|
| 256 |
+
if self.reduction == "mean":
|
| 257 |
+
loss_per_step /= count_steps
|
| 258 |
+
|
| 259 |
+
for e, label in enumerate(self.step_bucketing_labels):
|
| 260 |
+
summands[label] = (
|
| 261 |
+
loss_per_step[e].item(),
|
| 262 |
+
count_steps[e].long().item(),
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
return summands
|
| 266 |
+
|
| 267 |
+
def __call__(self, batch: LCMInput) -> LossTerm:
|
| 268 |
+
"""
|
| 269 |
+
Input batch is LCMInput with:
|
| 270 |
+
source: List[Tensor]
|
| 271 |
+
target: Union[None, List[Tensor]]
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
# Prepare the clean inputs and target mask:
|
| 275 |
+
input_batch, target_batch, target_mask = self.prepare_input_and_mask(batch)
|
| 276 |
+
|
| 277 |
+
noisy_target_batch, target, target_mask = self.sample_noisy_input_and_targets(
|
| 278 |
+
target_batch, target_mask
|
| 279 |
+
)
|
| 280 |
+
# Encode the context and diffuse:
|
| 281 |
+
output_batch = self.model(
|
| 282 |
+
input_batch,
|
| 283 |
+
noisy_target_batch,
|
| 284 |
+
cf_guidance_prob=self.cf_guidance_probability,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Shape B, T, C
|
| 288 |
+
output_seqs = output_batch.seqs
|
| 289 |
+
|
| 290 |
+
sonar_dim = output_seqs.size(-1)
|
| 291 |
+
|
| 292 |
+
# only measure distance over `target_mask = True` positions
|
| 293 |
+
target_mask = target_mask.reshape(-1)
|
| 294 |
+
|
| 295 |
+
# The target is basically the doubled ground truth sequence before noising
|
| 296 |
+
# (with some modification to adjust for the denoiser's prediction type)
|
| 297 |
+
|
| 298 |
+
# contextualized latents (noised inputs preceding the target) e_1, e_2, ...
|
| 299 |
+
flattened_predictions = output_seqs.view(-1, sonar_dim)[target_mask]
|
| 300 |
+
|
| 301 |
+
# x1, x2, ..., xT
|
| 302 |
+
# Target is already in B*T, C
|
| 303 |
+
flattened_target = target[target_mask]
|
| 304 |
+
|
| 305 |
+
# Cast features to float32 before computing the loss:
|
| 306 |
+
(
|
| 307 |
+
reconstruction_loss,
|
| 308 |
+
mse_loss,
|
| 309 |
+
unnormalized_reconstruction_loss,
|
| 310 |
+
) = self.compute_loss(flattened_predictions.float(), flattened_target.float())
|
| 311 |
+
|
| 312 |
+
num_target_elements = target_mask.sum()
|
| 313 |
+
|
| 314 |
+
batch_steps = noisy_target_batch.diffusion_timesteps.view(-1)[target_mask]
|
| 315 |
+
|
| 316 |
+
summands = self._log_losses_per_step(batch_steps, reconstruction_loss)
|
| 317 |
+
|
| 318 |
+
# Get loss scales per timestep (gamma)
|
| 319 |
+
gammas = self.step_sampler.get_loss_scales(batch_steps)
|
| 320 |
+
# Weight the loss terms
|
| 321 |
+
if gammas is not None:
|
| 322 |
+
reconstruction_loss = torch.mul(reconstruction_loss, gammas)
|
| 323 |
+
|
| 324 |
+
if self.reduction == "sum" or num_target_elements == 0:
|
| 325 |
+
reduced_reconstruction_loss = reconstruction_loss.sum()
|
| 326 |
+
mse_loss = mse_loss.sum()
|
| 327 |
+
unnormalized_reconstruction_loss = unnormalized_reconstruction_loss.sum()
|
| 328 |
+
|
| 329 |
+
elif self.reduction == "mean":
|
| 330 |
+
reduced_reconstruction_loss = reconstruction_loss.mean()
|
| 331 |
+
mse_loss = mse_loss.mean()
|
| 332 |
+
unnormalized_reconstruction_loss = unnormalized_reconstruction_loss.mean()
|
| 333 |
+
|
| 334 |
+
final_loss = reduced_reconstruction_loss
|
| 335 |
+
|
| 336 |
+
# Loss summands for records
|
| 337 |
+
summands.update(
|
| 338 |
+
{
|
| 339 |
+
"mse_loss": (mse_loss.item(), -1),
|
| 340 |
+
"reconstruction_loss": (reduced_reconstruction_loss.item(), -1),
|
| 341 |
+
"unnormalized_reconstruction_loss": (
|
| 342 |
+
unnormalized_reconstruction_loss.item(),
|
| 343 |
+
-1,
|
| 344 |
+
),
|
| 345 |
+
}
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
return LossTerm(
|
| 349 |
+
value=final_loss,
|
| 350 |
+
batch_size=output_seqs.size(0),
|
| 351 |
+
num_target_elements=num_target_elements.item(),
|
| 352 |
+
summands=summands,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
@CriterionsFactory.register("two_tower_diffusion_next_sent_finetuning")
|
| 357 |
+
class DiffusionNextSentFinetuningCriterion(TwoTowerDiffusionCriterion):
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
config: TowerDiffusionLCMCriterionConfig,
|
| 361 |
+
model: TwoTowerDiffusionLCModel,
|
| 362 |
+
):
|
| 363 |
+
super().__init__(config, model, LCMStyle.SUPERVISED)
|
| 364 |
+
|
| 365 |
+
def prepare_input_and_mask(
|
| 366 |
+
self,
|
| 367 |
+
batch: LCMInput,
|
| 368 |
+
) -> Tuple[EmbeddingsBatch, EmbeddingsBatch, torch.Tensor]:
|
| 369 |
+
"""
|
| 370 |
+
A method for preparing model inputs and mask for a batch.
|
| 371 |
+
It will be typically reused by the `__call__`
|
| 372 |
+
implementations of the subclasses.
|
| 373 |
+
|
| 374 |
+
Returns:
|
| 375 |
+
- input_batch: context
|
| 376 |
+
- target_batch: denoiser input
|
| 377 |
+
- target_mask mask of positions to compute the loss over
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
# Prepare the input as in MSE LCM
|
| 381 |
+
input_embeddings = batch.prepare_input(style=self.style)
|
| 382 |
+
|
| 383 |
+
assert input_embeddings.source_lengths is not None, (
|
| 384 |
+
"Missing source lengths needed for the two-tower supervised fintuning"
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
target_embeddings = EmbeddingsBatch(*pad_seqs(batch.target)) # type: ignore
|
| 388 |
+
|
| 389 |
+
# Normalize the embeddings
|
| 390 |
+
if self.normalize_in_criterion:
|
| 391 |
+
input_embeddings = input_embeddings.normalize_seqs(self.sonar_normalizer)
|
| 392 |
+
target_embeddings = target_embeddings.normalize_seqs(self.sonar_normalizer)
|
| 393 |
+
|
| 394 |
+
target_mask = torch.ones(
|
| 395 |
+
size=target_embeddings.shape[:-1],
|
| 396 |
+
dtype=torch.bool,
|
| 397 |
+
device=input_embeddings.seqs.device,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
# Factor in padded positions:
|
| 401 |
+
if target_embeddings.padding_mask is not None:
|
| 402 |
+
target_mask &= target_embeddings.padding_mask.materialize()
|
| 403 |
+
|
| 404 |
+
return input_embeddings, target_embeddings, target_mask
|
lcm/train/two_tower_diffusion_lcm/trainer.py
ADDED
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+#
+
+from dataclasses import dataclass, field
+from typing import Union
+
+from lcm.models.two_tower_diffusion_lcm.builder import TwoTowerDiffusionLCModelConfig
+from lcm.models.two_tower_diffusion_lcm.loader import (
+    load_two_tower_diffusion_lcm_model,
+)
+from lcm.train.lcm.trainer import LCMTrainer, LCMTrainerBuilder, LCMTrainingConfig
+from lcm.train.two_tower_diffusion_lcm.criterion import (
+    TowerDiffusionLCMCriterionConfig,
+)
+
+
+@dataclass
+class TwoTowerDiffusionLCMTrainingConfig(LCMTrainingConfig):
+    model_config_or_name: Union[TwoTowerDiffusionLCModelConfig, str, None] = None
+    """The model configuration or name to train."""
+
+    criterion: TowerDiffusionLCMCriterionConfig = field(  # type: ignore
+        default_factory=lambda: TowerDiffusionLCMCriterionConfig()
+    )
+
+
+class DiffusionLCMTrainerBuilder(LCMTrainerBuilder):
+    config: TwoTowerDiffusionLCMTrainingConfig
+
+    def __init__(self, config: TwoTowerDiffusionLCMTrainingConfig):
+        super().__init__(config)
+
+    @property
+    def model_loader(self):
+        """A fairseq2 ModelLoader"""
+        return load_two_tower_diffusion_lcm_model
+
+
+def prepare_two_tower_diffusion_lcm_trainer(
+    config: TwoTowerDiffusionLCMTrainingConfig,
+) -> LCMTrainer:
+    """Create an LCM Trainer.
+    :param config: The training configuration.
+    """
+    return DiffusionLCMTrainerBuilder(config).build_trainer()
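For orientation, this file is thin glue: it binds the generic `LCMTrainer` machinery to the two-tower diffusion model loader and criterion config. A hypothetical invocation might look like the sketch below; the model name string and the `run()` entry point are assumptions for illustration, since the real dispatch goes through `lcm/train/__main__.py`:

# Hypothetical usage sketch, not taken from this commit.
from lcm.train.two_tower_diffusion_lcm.trainer import (
    TwoTowerDiffusionLCMTrainingConfig,
    prepare_two_tower_diffusion_lcm_trainer,
)

config = TwoTowerDiffusionLCMTrainingConfig(
    model_config_or_name="toy_two_tower_diffusion_lcm",  # placeholder card name
)
trainer = prepare_two_tower_diffusion_lcm_trainer(config)
trainer.run()  # assumed LCMTrainer entry point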
pyproject.toml
CHANGED
@@ -13,6 +13,7 @@ dependencies = [
     "polars>=1.16.0",
     "pyarrow>=16.1.0",
     "retrying>=1.3.4",
+    "safetensors>=0.5.3",
     "sentence-splitter>=1.4",
     "sonar-space>=0.3.2",
     "stopes[mono]>=2.2.0",
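This one-line dependency bump is what the loader patch and the conversion script below build on. For reference, the core round trip `safetensors` provides is just the following (illustrative path and tensor names):

# Illustrative safetensors round trip, not part of this commit.
import torch
from safetensors.torch import save_file, load_file

tensors = {"weight": torch.randn(4, 4)}
save_file(tensors, "demo.safetensors")
restored = load_file("demo.safetensors")  # tensor-only load, no pickle execution
assert torch.equal(tensors["weight"], restored["weight"])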
scripts/CovertToST.py
ADDED
@@ -0,0 +1,33 @@
+import torch
+from safetensors.torch import save_file
+import os
+
+# Define the location and files to process
+location = "_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000"
+files = ["model", "rank_0", "metadata"]
+
+for file in files:
+    pt_path = os.path.join(location, f"{file}.pt")
+    st_path = os.path.join(location, f"{file}.safetensors")
+
+    try:
+        # Attempt to load the checkpoint with weights_only=True
+        checkpoint = torch.load(pt_path, weights_only=True)
+    except Exception as e:
+        print(f"Warning: Failed to load {pt_path} with weights_only=True due to {e}")
+        print("Attempting to load with weights_only=False (ensure the source is trusted).")
+        try:
+            checkpoint = torch.load(pt_path, weights_only=False)
+        except Exception as e:
+            print(f"Error: Failed to load {pt_path} with weights_only=False due to {e}")
+            continue  # Skip to the next file
+
+    # Determine the state_dict
+    state_dict = checkpoint.get('model', checkpoint)
+
+    # Filter out non-tensor entries
+    tensor_state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
+
+    # Save the filtered state_dict to a .safetensors file
+    save_file(tensor_state_dict, st_path)
+    print(f"Successfully converted {pt_path} to {st_path}")
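One caveat with the script above: `save_file` only receives the tensor entries, so anything else in the checkpoint (optimizer state, step counters, metadata dicts) is silently dropped, and the `.safetensors` file is not byte-for-byte equivalent to the `.pt` it replaces. A quick post-conversion sanity check could look like the sketch below; paths mirror the script, and re-loading the `.pt` carries the same trusted-source assumption:

# Illustrative post-conversion check, not part of this commit.
import os
import torch
from safetensors.torch import load_file

location = "_LexaLCM_Pre0/Checkpoints/LCM_TwoTower_Pre0/checkpoints/step_250000"
checkpoint = torch.load(os.path.join(location, "model.pt"), weights_only=False)
state_dict = checkpoint.get("model", checkpoint)
converted = load_file(os.path.join(location, "model.safetensors"))

# Every converted tensor should match its original exactly
for key, tensor in converted.items():
    assert torch.equal(tensor, state_dict[key]), f"Mismatch in {key}"

dropped = [k for k, v in state_dict.items() if not isinstance(v, torch.Tensor)]
print(f"{len(converted)} tensors verified; dropped non-tensor keys: {dropped}")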