# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
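# Tokenizes a Hugging Face dataset with the given tokenizer and saves the result to disk
# (columns: `input_ids` and `bits_per_token`) for later consumption by flame training jobs.
#
# Example invocation (a sketch, assuming the script lives at flame/utils/preprocess.py and
# is run from the repository root; adjust dataset/paths to your setup):
#   python -m flame.utils.preprocess \
#       --dataset HuggingFaceFW/fineweb-edu \
#       --dataset_name sample-100BT \
#       --tokenizer fla-hub/transformer-1.3B-100B \
#       --path data/fineweb-edu
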
import argparse
from typing import Any, Dict, List

from transformers import AutoTokenizer, PreTrainedTokenizer

from flame.data import build_dataset
from torchtitan.tools.logging import init_logger, logger


def tokenize(
    examples: Dict[str, List[Any]],
    tokenizer: PreTrainedTokenizer,
) -> Dict:
    """Tokenize a batch of examples and compute the UTF-8 bits encoded per produced token."""
    if 'text' in examples:
        samples = examples['text']
    elif 'content' in examples:
        samples = examples['content']
    else:
        raise ValueError(f'No "text" or "content" field found in examples:\n{examples}')
    input_ids = tokenizer(samples)['input_ids']
    # Bits of raw UTF-8 text carried by each token, i.e. a simple compression-rate proxy
    bits_per_token = [
        len(sample.encode(encoding='utf-8')) * 8 / len(input_ids[i])
        for i, sample in enumerate(samples)
    ]
    return {'input_ids': input_ids, 'bits_per_token': bits_per_token}
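
# Illustrative example (hypothetical batch): for {'text': ['hello world']}, `tokenize` returns
# {'input_ids': [[...]], 'bits_per_token': [88 / <num tokens>]}, since 'hello world' is
# 11 UTF-8 bytes = 88 bits.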


if __name__ == '__main__':
    init_logger()
    parser = argparse.ArgumentParser(description='Preprocess the dataset.')
    parser.add_argument(
        '--dataset',
        default='HuggingFaceFW/fineweb-edu',
        help='Dataset to use, with comma-separated values',
    )
    parser.add_argument(
        '--dataset_name',
        default='sample-100BT',
        help='The name of the dataset config, with comma-separated values if provided',
    )
    parser.add_argument(
        '--dataset_split',
        default='train',
        help='Dataset split to use, with comma-separated values if provided',
    )
    parser.add_argument(
        '--data_dir',
        default=None,
        help='Data dirs to use, with comma-separated values if provided',
    )
    parser.add_argument(
        '--data_files',
        default=None,
        help='Data files to use, with comma-separated values if provided',
    )
    parser.add_argument(
        '--data_probs',
        default=None,
        help='Data sampling probabilities, with comma-separated values if provided',
    )
    parser.add_argument(
        '--streaming',
        action='store_true',
        help='Whether to use streaming mode',
    )
    parser.add_argument(
        '--num_workers',
        type=int,
        default=64,
        help='Number of workers to use for preprocessing',
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for preprocessing',
    )
    parser.add_argument(
        '--path',
        default='data',
        help='Path to save the preprocessed dataset',
    )
    parser.add_argument(
        '--tokenizer',
        default='fla-hub/transformer-1.3B-100B',
        help='Tokenizer to use',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=2048,
        help='Batch size for processing',
    )
    args = parser.parse_args()
    logger.info(f'Loading tokenizer {args.tokenizer}')
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    logger.info(f'{tokenizer}')

    logger.info(f'Loading dataset {args.dataset} {args.dataset_name} {args.dataset_split}')
    dataset = build_dataset(
        dataset=args.dataset,
        dataset_name=args.dataset_name,
        dataset_split=args.dataset_split,
        data_dir=args.data_dir,
        data_files=args.data_files,
        data_probs=args.data_probs,
        streaming=args.streaming,
        num_workers=args.num_workers,
        seed=args.seed,
    )
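
    # Note (assumption): with --streaming, build_dataset is expected to yield an
    # IterableDataset, which does not support `num_proc` in `.map()` or `.save_to_disk()`;
    # the offline preprocessing below assumes a map-style dataset.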
    logger.info(f'Tokenizing and processing the dataset with batch size {args.batch_size}')
    dataset = dataset.map(
        lambda examples: tokenize(examples, tokenizer),
        batched=True,
        batch_size=args.batch_size,
        remove_columns=list(next(iter(dataset)).keys()),
        num_proc=args.num_workers,
        desc='Running tokenizer on dataset',
    )
    logger.info(f'{dataset}')

    logger.info(f'Saving tokenized dataset to {args.path}')
    dataset.save_to_disk(args.path)
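    # The saved dataset can be reloaded later with `datasets.load_from_disk(args.path)`.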