# lsn-analysis / load_data_wiki.py
# tvkain's picture
# Upload folder using huggingface_hub
# fed1832 verified
#!/usr/bin/env python3
"""
Preprocess multilingual Wikipedia data for model training.
This script performs the following steps:
1. Downloads the Wikimedia Wikipedia dataset for the specified languages
(https://huggingface.co/datasets/wikimedia/wikipedia).
2. Tokenizes the dataset using a Hugging Face tokenizer corresponding
to the specified model.
3. Aggregates token IDs up to a target number of tokens per language.
4. Saves the tokenized data as a PyTorch tensor for later use.
Usage example:
python load_data_wiki.py \
--languages en,zh,vi \
--model-id meta-llama/Llama-3.1-8B-Instruct \
--tokenizer meta-llama/Llama-3.1-8B-Instruct \
--output-dir train-data
"""
#!/usr/bin/env python3
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
import os
from tqdm import tqdm
import multiprocessing
from functools import partial
import argparse
# Default Dataset.map worker count: half of the visible CPUs, but at least 1.
# os.cpu_count() may return None; that case is treated as 2 CPUs -> 1 worker.
NUM_PROC_BASE = max(1, (os.cpu_count() or 2) // 2)
# Maximum number of token IDs kept per language (documents concatenated in order).
TARGET_TOKENS_PER_LANGUAGE = 10 ** 8
# Wikipedia dump date; joined with the language code to form the dataset config.
DATE_SNAPSHOT = "20231101"  # fixed date
def tokenize_function(examples, tokenizer):
    """Tokenize a batch of documents; intended for ``Dataset.map(batched=True)``.

    Args:
        examples: Batch mapping whose ``"text"`` entry is handed to the tokenizer.
        tokenizer: Callable tokenizer whose result exposes ``input_ids``.

    Returns:
        A dict ``{"input_ids": ...}`` holding the tokenizer's ``input_ids`` as-is.
    """
    # Raw IDs only: no special tokens added, no truncation, no padding.
    encoding_options = dict(add_special_tokens=False, truncation=False, padding=False)
    batch_encoding = tokenizer(examples["text"], **encoding_options)
    return dict(input_ids=batch_encoding.input_ids)
def build_and_save(
    lang,
    model_id,
    tokenizer_name,
    output_dir,
    num_proc_map=NUM_PROC_BASE,
    target_tokens=TARGET_TOKENS_PER_LANGUAGE,
):
    """Tokenize one language's Wikipedia snapshot and save its leading token IDs.

    Loads ``wikimedia/wikipedia`` (config ``f"{DATE_SNAPSHOT}.{lang}"``, split
    ``"train"``) and tokenizes it with ``tokenize_function``. It then concatenates
    document token IDs in dataset order until ``target_tokens`` IDs are collected;
    the last document used is truncated to fit. The result is saved as a 1-D
    ``torch.long`` tensor via ``torch.save`` at
    ``<output_dir>/id.<lang>.train.<model_id with '/' replaced by '_'>``.

    Args:
        lang: Language code used in the dataset config name and the output filename.
        model_id: Model identifier; only used to build the output filename.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        output_dir: Output directory (created if missing).
        num_proc_map: ``num_proc`` for the tokenizing ``Dataset.map`` call.
        target_tokens: Maximum number of token IDs to keep. Defaults to
            ``TARGET_TOKENS_PER_LANGUAGE``.

    Returns:
        None. Returns early without writing a file if the dataset is empty or
        no tokens are collected.

    Raises:
        Re-raises any exception from loading the dataset or the tokenizer,
        after printing an error message.
    """
    print(f"Starting data processing for language: {lang}")
    train_filename_base = f"id.{lang}.train.{model_id.replace('/', '_')}"
    train_output_path = os.path.join(output_dir, train_filename_base)

    try:
        ds = load_dataset(
            "wikimedia/wikipedia",
            f"{DATE_SNAPSHOT}.{lang}",
            split="train",
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Error loading dataset for {lang}: {e}")
        raise
    if len(ds) == 0:
        print(f"Warning: Dataset for {lang} is empty. Skipping.")
        return

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            use_fast=True,
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Error loading tokenizer '{tokenizer_name}': {e}")
        raise

    # Dataset.map still tokenizes every article in the split; only the
    # aggregation loop below stops early once the target is reached.
    tokenized_ds = ds.map(
        partial(tokenize_function, tokenizer=tokenizer),
        batched=True,
        num_proc=num_proc_map,
        remove_columns=ds.column_names,
        desc=f"Tokenizing {lang}",
    )

    # Stream documents straight from the tokenized dataset instead of first
    # copying every document's token list into memory. Kept tokens are stored
    # as compact per-document LongTensors rather than one huge list of Python ints.
    chunks = []
    collected_tokens_count = 0
    saw_token_list = False  # distinguishes "no token sequences" from "all empty"
    for processed_example in tqdm(tokenized_ds, desc=f"Aggregating tokens for {lang}"):
        doc_tokens = processed_example["input_ids"]
        if not isinstance(doc_tokens, list):
            continue
        saw_token_list = True
        if not doc_tokens:
            continue
        remaining_needed = target_tokens - collected_tokens_count
        kept = doc_tokens[:remaining_needed]  # truncates only the final document
        if kept:
            chunks.append(torch.tensor(kept, dtype=torch.long))
            collected_tokens_count += len(kept)
        if collected_tokens_count >= target_tokens:
            break

    del tokenized_ds
    del ds

    if not saw_token_list:
        print(f"Warning: No token sequences found for {lang} after tokenization. Skipping.")
        return
    if collected_tokens_count == 0:
        print(f"Warning: Zero tokens collected for {lang}. Skipping save.")
        return
    if collected_tokens_count < target_tokens:
        print(f"Warning: Language {lang} has only {collected_tokens_count:,} tokens, "
              f"which is less than the target of {target_tokens:,}.")

    full_tensor = torch.cat(chunks)
    del chunks
    os.makedirs(output_dir, exist_ok=True)
    torch.save(full_tensor, train_output_path)
    print(f"Saved {full_tensor.numel():,} tokens for {lang}.")
    del full_tensor
def run_job(args):
    """Process one language in a pool worker and report the outcome.

    Args:
        args: Tuple ``(lang, model_id, tokenizer_name, output_dir, num_proc_map)``
            forwarded to ``build_and_save``.

    Returns:
        ``(lang, True, None)`` on success. If ``build_and_save`` raised an
        ``Exception``, returns ``(lang, False, str(exc))`` after printing the
        traceback.
    """
    lang, model_id, tokenizer_name, output_dir, num_proc_map = args
    print(f"Processing language: {lang} (PID: {os.getpid()})")
    try:
        build_and_save(
            lang=lang,
            model_id=model_id,
            tokenizer_name=tokenizer_name,
            output_dir=output_dir,
            num_proc_map=num_proc_map,
        )
    except Exception as exc:
        import traceback
        traceback.print_exc()
        return lang, False, str(exc)
    return lang, True, None
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess Wikipedia data for multiple languages.")
    parser.add_argument(
        "--languages", type=str, nargs="+", default=["en,zh,eu,ga"],
        help="Languages to process, comma- and/or space-separated, e.g. 'en,zh,fr' or 'en zh fr'"
    )
    parser.add_argument("--model-id", type=str, required=True, help="Model identifier (used for file naming).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path.")
    parser.add_argument("--output-dir", type=str, default="train-data", help="Where to store tokenized tensors.")
    parser.add_argument("--max-concurrent", type=int, default=6, help="Max concurrent processes.")
    args = parser.parse_args()

    # Flatten "en,zh fr"-style input into ["en", "zh", "fr"], dropping blanks and
    # duplicates while preserving order. A duplicate language would otherwise run
    # as two concurrent jobs writing the same output file.
    args.languages = list(dict.fromkeys(
        code.strip()
        for item in args.languages
        for code in item.split(",")
        if code.strip()
    ))
    if not args.languages:
        parser.error("--languages must list at least one language code.")
    # multiprocessing.Pool(processes=0) raises ValueError; reject such values up front.
    if args.max_concurrent < 1:
        parser.error("--max-concurrent must be at least 1.")

    MAX_CONCURRENT_LANGUAGES = args.max_concurrent
    # Split the Dataset.map CPU budget across concurrently processed languages.
    # NOTE(review): each job runs inside a Pool worker, and Dataset.map with
    # num_proc > 1 starts further processes from there — confirm this works
    # with the multiprocessing start method in use.
    NUM_MAP_PROC_PER_LANG = max(1, NUM_PROC_BASE // MAX_CONCURRENT_LANGUAGES)
    print(f"Starting batch processing for {len(args.languages)} languages.")
    job_args_list = [
        (lang, args.model_id, args.tokenizer, args.output_dir, NUM_MAP_PROC_PER_LANG)
        for lang in args.languages
    ]
    successful_langs = []
    failed_langs_with_errors = {}
    # Never start more workers than there are jobs.
    pool_size = min(MAX_CONCURRENT_LANGUAGES, len(job_args_list))
    with multiprocessing.Pool(processes=pool_size) as pool:
        results_iterable = pool.imap_unordered(run_job, job_args_list)
        for result in tqdm(results_iterable, total=len(args.languages), desc="Overall Language Progress"):
            lang_processed, success, error_msg = result
            if success:
                successful_langs.append(lang_processed)
            else:
                failed_langs_with_errors[lang_processed] = error_msg
    print("Batch processing finished.")
    print(f"Successfully processed: {', '.join(sorted(successful_langs))}")
    if failed_langs_with_errors:
        print(f"Failed to process: {', '.join(sorted(failed_langs_with_errors.keys()))}")
        for lang_failed, err in failed_langs_with_errors.items():
            print(f" - {lang_failed}: {err}")
        # Non-zero exit status so shells / schedulers can detect partial failure.
        raise SystemExit(1)