# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Any, Generator

import torch
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)
from bytelatent.data.iterators.limit_iterator import LimitIterator, LimitIteratorState
from bytelatent.data.iterators.looping_iterator import (
    LoopingIterator,
    LoopingIteratorState,
)
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


class PreprocessIteratorState(PydanticIteratorState):
    """
    Serializable snapshot of a PreprocessIterator: the source iterator's
    state plus the args needed to rebuild the tokenizer and patcher.
    """

    model_config = ConfigDict(extra="forbid")
    arrow_file_iterator_state: (
        ArrowFileIteratorState | LoopingIteratorState | LimitIteratorState
    )
    add_tokens: bool
    add_patches: bool
    tokenizer_args: TokenizerArgs
    patcher_args: PatcherArgs

    def build(self):
        arrow_iterator = self.arrow_file_iterator_state.build()
        return PreprocessIterator(
            arrow_iterator,
            patcher_args=self.patcher_args,
            tokenizer_args=self.tokenizer_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )
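

# A minimal sketch, not part of the original module, of how the state above
# can be round-tripped through JSON to checkpoint and resume a data pipeline.
# It assumes PydanticIteratorState is a Pydantic v2 model, so
# model_dump_json()/model_validate_json() are available on the state class.
def _resume_from_json(state_json: str):
    """Rebuild a PreprocessIterator from a serialized state snapshot."""
    state = PreprocessIteratorState.model_validate_json(state_json)
    # build() re-creates the source iterator at its saved position and wraps
    # it in a fresh PreprocessIterator with the same tokenizer/patcher args.
    return state.build()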


class PreprocessIterator(StatefulIterator):
    """
    Take BltExamples whose fields have been filled in only by the
    ArrowFileIterator, and fill in the fields that require preprocessing,
    such as tokenization and patching.
    """

    def __init__(
        self,
        arrow_iterator: ArrowFileIterator | LoopingIterator | LimitIterator,
        *,
        patcher_args: PatcherArgs,
        tokenizer_args: TokenizerArgs,
        add_tokens: bool = True,
        add_patches: bool = True,
    ):
        self.arrow_iterator = arrow_iterator
        self.tokenizer_args = tokenizer_args
        self.patcher_args = patcher_args
        self.add_tokens = add_tokens
        self.add_patches = add_patches
        # The tokenizer and patcher are built lazily in create_iter(), so the
        # iterator stays lightweight until it is actually consumed.
        self.tokenizer: BltTokenizer | None = None
        self.patcher: Patcher | None = None

    def get_state(self) -> PreprocessIteratorState:
        """
        The only state to maintain is the underlying arrow iterator's;
        this iterator keeps no internal state of its own.
        """
        return PreprocessIteratorState(
            arrow_file_iterator_state=self.arrow_iterator.get_state(),
            tokenizer_args=self.tokenizer_args,
            patcher_args=self.patcher_args,
            add_tokens=self.add_tokens,
            add_patches=self.add_patches,
        )

    def create_iter(self) -> Generator[BltExample, Any, None]:
        # Build the tokenizer and patcher on first use.
        if self.tokenizer is None and self.add_tokens:
            self.tokenizer = self.tokenizer_args.build()
        if self.patcher is None and self.add_patches:
            self.patcher = self.patcher_args.build()

        example_iter = self.arrow_iterator.create_iter()
        for example in example_iter:
            if self.add_tokens:
                tokens = self.tokenizer.encode(example.text)
            else:
                tokens = example.tokens
            if (
                self.patcher is not None
                and self.patcher.patching_mode == PatchingModeEnum.entropy
            ):
                assert (
                    example.entropies is not None
                ), "For entropy patching, entropies cannot be None"
                # The patcher expects a batch dimension: [1, seq_len].
                entropies = torch.tensor(example.entropies).unsqueeze(0)
            else:
                entropies = None
            if self.patcher is None:
                patch_lengths = None
            else:
                # patch() operates on a batch of one sequence; take the first
                # row of the returned patch lengths and convert it to a list.
                patch_lengths = self.patcher.patch(
                    torch.tensor(tokens).unsqueeze(0),
                    include_next_token=False,
                    entropies=entropies,
                )[0][0].tolist()
            yield BltExample(
                sample_id=example.sample_id,
                text=example.text,
                tokens=tokens,
                mask=[True] * len(tokens),
                patch_lengths=patch_lengths,
                entropies=example.entropies,
            )
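

# A minimal usage sketch, not part of the original module: wire an
# already-built source iterator and arg objects into a PreprocessIterator,
# consume a few examples, and capture a resumable state snapshot. How the
# arrow iterator, TokenizerArgs, and PatcherArgs are constructed is left to
# the caller, since their constructors live in other modules.
def _demo_preprocess(
    arrow_iterator: ArrowFileIterator,
    tokenizer_args: TokenizerArgs,
    patcher_args: PatcherArgs,
) -> PreprocessIteratorState:
    iterator = PreprocessIterator(
        arrow_iterator,
        patcher_args=patcher_args,
        tokenizer_args=tokenizer_args,
        add_tokens=True,
        add_patches=True,
    )
    for i, example in enumerate(iterator.create_iter()):
        # Each yielded BltExample now carries tokens, a mask, and (if a
        # patcher is configured) per-patch lengths.
        print(example.sample_id, len(example.tokens), example.patch_lengths)
        if i >= 2:
            break
    # get_state() captures only the source iterator's position plus the args
    # needed to rebuild the tokenizer and patcher on resume.
    return iterator.get_state()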