from __future__ import annotations

import argparse
import logging
from itertools import chain
from typing import Any, Dict, List, Optional

from datasets import load_dataset
from transformers import AutoTokenizer, PreTrainedTokenizerBase

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def tokenize(
    examples: Dict[str, List[Any]],
    tokenizer: PreTrainedTokenizerBase,
    context_length: int
) -> Dict[str, List[List[int]]]:
    """
    Tokenize the input text and split it into chunks of the specified context length.

    Args:
        examples:
            Batch dictionary containing the input text.
        tokenizer:
            Initialized tokenizer.
        context_length:
            Length of each context chunk.

    Returns:
        Dictionary containing the tokenized and chunked input ids.
    """
    text = examples['text']
    input_ids = tokenizer(text)['input_ids']
    # Concatenate every document in the batch into one continuous token stream.
    input_ids = list(chain(*input_ids))
    # Truncate to a multiple of context_length so every chunk is full-size;
    # the trailing remainder of the stream is discarded.
    total_length = (len(input_ids) // context_length) * context_length

    return {'input_ids': [input_ids[i:i + context_length] for i in range(0, total_length, context_length)]}
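
# A minimal usage sketch of `tokenize` (illustrative only; the 'gpt2' checkpoint
# is an assumption chosen for its small size, not this script's default model).
# With context_length=4, a stream of N tokens yields N // 4 chunks and the
# remaining N % 4 tokens are dropped:
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained('gpt2')
#     batch = {'text': ['hello world', 'a tiny example document']}
#     chunks = tokenize(batch, tok, context_length=4)
#     assert all(len(chunk) == 4 for chunk in chunks['input_ids'])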


def preprocess(
    dataset: str,
    name: Optional[str] = None,
    split: str = 'train',
    output: str = 'data',
    model: str = 'mistralai/Mistral-7B-v0.1',
    num_proc: int = 64,
    context_length: int = 8192
) -> None:
    """
    Load, tokenize, and save the processed dataset.

    Args:
        dataset:
            Path or name of the dataset.
        name:
            Name of the dataset configuration.
        split:
            Dataset split to process.
        output:
            Output directory.
        model:
            Model name for tokenizer.
        num_proc:
            Number of processes for parallel processing.
        context_length:
            Context length for tokenization.
    """
    tokenized_path = f'{output}/{dataset}/{name}/{split}' if name is not None else f'{output}/{dataset}/{split}'

    logging.info(f'Initializing tokenizer for {model}')
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    logging.info(f'Tokenizer initialized: {tokenizer}')

    logging.info(f'Loading dataset: {dataset}')
    dataset = load_dataset(dataset, name=name, split=split)

    # Drop all original columns so only the chunked input ids remain.
    remove_columns = dataset.column_names
    logging.info('Tokenizing and processing dataset')
    dataset = dataset.map(
        lambda examples: tokenize(examples, tokenizer, context_length),
        batched=True,
        remove_columns=remove_columns,
        num_proc=num_proc,
        desc="Running tokenizer on dataset"
    )

    logging.info(f'Saving processed dataset to {tokenized_path}')
    dataset.save_to_disk(tokenized_path, num_proc=num_proc)
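
# A minimal sketch of reading the output back (illustrative; the path assumes
# the default --output, --dataset, and --split values below, and load_from_disk
# is the datasets-library counterpart of save_to_disk):
#
#     from datasets import load_from_disk
#     ds = load_from_disk('data/HuggingFaceFW/fineweb-edu/train')
#     print(len(ds), len(ds[0]['input_ids']))  # number of chunks, tokens per chunk (8192)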


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess and tokenize dataset")
    parser.add_argument("--dataset", default="HuggingFaceFW/fineweb-edu", help="Path or name of the dataset")
    parser.add_argument("--name", default=None, help="Name of the dataset configuration")
    parser.add_argument("--split", default="train", help="Dataset split to process")
    parser.add_argument("--output", default="data", help="Output directory")
    parser.add_argument("--model", default="mistralai/Mistral-7B-v0.1", help="Model name for tokenizer")
    parser.add_argument("--num_proc", type=int, default=64, help="Number of processes for parallel processing")
    parser.add_argument("--context_length", type=int, default=8192, help="Context length for tokenization")
    args = parser.parse_args()

    preprocess(
        dataset=args.dataset,
        name=args.name,
        split=args.split,
        output=args.output,
        model=args.model,
        num_proc=args.num_proc,
        context_length=args.context_length
    )
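
# Example invocation (illustrative; the file name preprocess.py is hypothetical,
# and the values shown are this script's own defaults):
#
#     python preprocess.py --dataset HuggingFaceFW/fineweb-edu --split train \
#         --model mistralai/Mistral-7B-v0.1 --context_length 8192 --output data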