|
|
|
|
|
""" |
|
|
Preprocess multilingual Wikipedia data for model training. |
|
|
|
|
|
This script performs the following steps: |
|
|
1. Downloads the Wikimedia Wikipedia dataset for the specified languages |
|
|
(https://huggingface.co/datasets/wikimedia/wikipedia). |
|
|
2. Tokenizes the dataset using a Hugging Face tokenizer corresponding |
|
|
to the specified model. |
|
|
3. Aggregates token IDs up to a target number of tokens per language. |
|
|
4. Saves the tokenized data as a PyTorch tensor for later use. |
|
|
|
|
|
Usage example: |
|
|
python load_datas.py \ |
|
|
    --languages en,zh,vi \
|
|
--model-id meta-llama/Llama-3.1-8B-Instruct \ |
|
|
--tokenizer meta-llama/Llama-3.1-8B-Instruct \ |
|
|
--output-dir train-data |
|
|
""" |
|
|
|
|
|
from datasets import load_dataset |
|
|
from transformers import AutoTokenizer |
|
|
import torch |
|
|
import os |
|
|
from tqdm import tqdm |
|
|
import multiprocessing |
|
|
from functools import partial |
|
|
import argparse |
|
|
|
|
|
# Baseline worker count for dataset.map(): half the visible CPU cores, but
# never fewer than one (os.cpu_count() may return None on some platforms).
NUM_PROC_BASE = max(1, os.cpu_count() // 2 if os.cpu_count() else 1)

# Stop aggregating once this many token IDs have been collected per language.
TARGET_TOKENS_PER_LANGUAGE = 100_000_000

# Wikimedia Wikipedia dump snapshot date (YYYYMMDD) used for every language.
DATE_SNAPSHOT = "20231101"
|
|
|
|
|
def tokenize_function(examples, tokenizer):
    """Tokenize a batch of examples' "text" field into raw token-ID lists.

    Special tokens, truncation, and padding are all disabled, so every
    document maps to its full, unmodified token sequence.

    Args:
        examples: Batch dict from datasets.map with a "text" key.
        tokenizer: A Hugging Face tokenizer (or compatible callable).

    Returns:
        Dict with a single "input_ids" key holding one token list per document.
    """
    encoded = tokenizer(
        examples["text"],
        add_special_tokens=False,
        truncation=False,
        padding=False,
    )
    return {"input_ids": encoded.input_ids}
|
|
|
|
|
def build_and_save(
    lang,
    model_id,
    tokenizer_name,
    output_dir,
    num_proc_map=NUM_PROC_BASE
):
    """Download, tokenize, and save up to TARGET_TOKENS_PER_LANGUAGE tokens
    for one Wikipedia language as a 1-D torch.long tensor.

    Args:
        lang: Wikipedia language code (e.g. "en").
        model_id: Model identifier; only used to build the output filename.
        tokenizer_name: Hugging Face tokenizer name or path.
        output_dir: Directory for the output tensor (created if missing).
        num_proc_map: Worker processes for the tokenization map step.

    Raises:
        Exception: re-raises any dataset- or tokenizer-loading failure
            after printing it, so the caller can record the error.
    """
    print(f"Starting data processing for language: {lang}")

    train_filename_base = f"id.{lang}.train.{model_id.replace('/', '_')}"
    train_output_path = os.path.join(output_dir, train_filename_base)

    try:
        # NOTE(review): trust_remote_code executes repo-provided code; assumed
        # acceptable for the official wikimedia/wikipedia dataset — confirm.
        ds = load_dataset("wikimedia/wikipedia", f"{DATE_SNAPSHOT}.{lang}", split="train", trust_remote_code=True)
        if len(ds) == 0:
            print(f"Warning: Dataset for {lang} is empty. Skipping.")
            return
    except Exception as e:
        print(f"Error loading dataset for {lang}: {e}")
        raise

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            use_fast=True,
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Error loading tokenizer '{tokenizer_name}': {e}")
        raise

    tokenization_func_with_tokenizer = partial(tokenize_function, tokenizer=tokenizer)

    tokenized_ds = ds.map(
        tokenization_func_with_tokenizer,
        batched=True,
        num_proc=num_proc_map,
        remove_columns=ds.column_names,
        desc=f"Tokenizing {lang}"
    )

    # Aggregate token IDs in a single streaming pass, stopping as soon as the
    # per-language target is reached. The previous implementation first
    # materialized every document's token list in memory and only then
    # truncated, which could hold far more than TARGET_TOKENS_PER_LANGUAGE
    # IDs at peak for large languages.
    final_token_ids = []
    collected_tokens_count = 0
    saw_token_list = False  # tracks whether any document produced a list at all
    for processed_example in tqdm(tokenized_ds, desc=f"Aggregating tokens for {lang}"):
        doc_tokens_list = processed_example['input_ids']
        if not isinstance(doc_tokens_list, list):
            continue
        saw_token_list = True
        if not doc_tokens_list:
            continue

        current_doc_token_count = len(doc_tokens_list)

        if collected_tokens_count + current_doc_token_count <= TARGET_TOKENS_PER_LANGUAGE:
            final_token_ids.extend(doc_tokens_list)
            collected_tokens_count += current_doc_token_count
        else:
            # Take only the remainder of the final (partial) document.
            remaining_needed = TARGET_TOKENS_PER_LANGUAGE - collected_tokens_count
            final_token_ids.extend(doc_tokens_list[:remaining_needed])
            collected_tokens_count += remaining_needed
            break

        if collected_tokens_count >= TARGET_TOKENS_PER_LANGUAGE:
            break

    # Release dataset references before allocating the output tensor.
    del tokenized_ds
    del ds

    if not saw_token_list:
        print(f"Warning: No token sequences found for {lang} after tokenization. Skipping.")
        return

    if collected_tokens_count == 0:
        print(f"Warning: Zero tokens collected for {lang}. Skipping save.")
        return

    if collected_tokens_count < TARGET_TOKENS_PER_LANGUAGE:
        print(f"Warning: Language {lang} has only {collected_tokens_count:,} tokens, "
              f"which is less than the target of {TARGET_TOKENS_PER_LANGUAGE:,}.")

    full_tensor = torch.tensor(final_token_ids, dtype=torch.long)
    del final_token_ids

    os.makedirs(output_dir, exist_ok=True)
    torch.save(full_tensor, train_output_path)
    print(f"Saved {full_tensor.numel():,} tokens for {lang}.")
    del full_tensor
|
|
|
|
|
def run_job(args):
    """Pool worker entry point: unpack a job tuple and process one language.

    Always returns a (lang, success, error_message) triple so failures are
    reported back to the parent process instead of crossing the pool
    boundary as exceptions.
    """
    lang, model_id, tokenizer_name, output_dir, num_proc_map = args
    print(f"Processing language: {lang} (PID: {os.getpid()})")
    try:
        build_and_save(
            lang=lang,
            model_id=model_id,
            tokenizer_name=tokenizer_name,
            output_dir=output_dir,
            num_proc_map=num_proc_map
        )
    except Exception as exc:
        import traceback
        traceback.print_exc()
        return lang, False, str(exc)
    return lang, True, None
|
|
|
|
|
if __name__ == "__main__":
    # --- CLI ---------------------------------------------------------------
    parser = argparse.ArgumentParser(description="Preprocess Wikipedia data for multiple languages.")
    parser.add_argument(
        "--languages", type=str, default='en,zh,eu,ga',
        help="Comma-separated list of languages to process, e.g., 'en,zh,fr'"
    )
    parser.add_argument("--model-id", type=str, required=True, help="Model identifier (used for file naming).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path.")
    parser.add_argument("--output-dir", type=str, default="train-data", help="Where to store tokenized tensors.")
    parser.add_argument("--max-concurrent", type=int, default=6, help="Max concurrent processes.")
    args = parser.parse_args()

    # Normalize the comma-separated language string into a clean list.
    args.languages = [code.strip() for code in args.languages.split(',') if code.strip()]

    MAX_CONCURRENT_LANGUAGES = args.max_concurrent
    # Split the base worker budget across concurrently running languages.
    if MAX_CONCURRENT_LANGUAGES > 0:
        NUM_MAP_PROC_PER_LANG = max(1, NUM_PROC_BASE // MAX_CONCURRENT_LANGUAGES)
    else:
        NUM_MAP_PROC_PER_LANG = max(1, NUM_PROC_BASE)

    print(f"Starting batch processing for {len(args.languages)} languages.")

    # One job tuple per language; run_job unpacks it in the worker process.
    job_args_list = [
        (code, args.model_id, args.tokenizer, args.output_dir, NUM_MAP_PROC_PER_LANG)
        for code in args.languages
    ]

    successful_langs = []
    failed_langs_with_errors = {}

    # Fan the languages out over a process pool; collect results as they finish.
    with multiprocessing.Pool(processes=MAX_CONCURRENT_LANGUAGES) as pool:
        outcomes = pool.imap_unordered(run_job, job_args_list)
        for done_lang, ok, err_msg in tqdm(outcomes, total=len(args.languages), desc="Overall Language Progress"):
            if ok:
                successful_langs.append(done_lang)
            else:
                failed_langs_with_errors[done_lang] = err_msg

    # --- Summary -------------------------------------------------------------
    print("Batch processing finished.")
    print(f"Successfully processed: {', '.join(sorted(successful_langs))}")
    if failed_langs_with_errors:
        print(f"Failed to process: {', '.join(sorted(failed_langs_with_errors.keys()))}")
        for lang_failed, err in failed_langs_with_errors.items():
            print(f"  - {lang_failed}: {err}")
|
|