lsn-analysis / load_data_oscar.py
tvkain's picture
Upload folder using huggingface_hub
fed1832 verified
#!/usr/bin/env python3
"""
Preprocess multilingual OSCAR data for model training.
This script performs:
1. Downloads the OSCAR dataset for the specified languages:
https://huggingface.co/datasets/oscar-corpus
2. Tokenizes text using a Hugging Face tokenizer
3. Aggregates up to TARGET_TOKENS_PER_LANGUAGE
4. Saves tokenized data as PyTorch tensors
"""
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
import os
from tqdm import tqdm
import multiprocessing
from functools import partial
import argparse
NUM_PROC_BASE = max(1, os.cpu_count() // 2 if os.cpu_count() else 1)
TARGET_TOKENS_PER_LANGUAGE = 100_000_000
def tokenize_function(examples, tokenizer):
texts = []
for example_texts in examples["text"]:
for text in example_texts:
texts.append(text["text"])
if not texts:
return {"input_ids": []}
output = tokenizer(
texts,
add_special_tokens=False,
truncation=False,
padding=False,
)
return {"input_ids": output.input_ids}
def build_and_save(
lang, model_id, tokenizer_name, output_dir, num_proc_map=NUM_PROC_BASE
):
print(f"Starting data processing for language: {lang}")
train_filename_base = f"oscar.{lang}.train.{model_id.replace('/', '_')}"
train_output_path = os.path.join(output_dir, train_filename_base)
try:
ds = load_dataset(
"oscar-corpus/mOSCAR",
# data_files=f"data/{SNAPSHOT}/{lang}_meta/*.jsonl.zst",
lang,
streaming=False,
)['train']
if len(ds) == 0:
print(f"Warning: Dataset for {lang} is empty. Skipping.")
return
except Exception as e:
print(f"Error loading dataset for {lang}: {e}")
raise
print(f"Dataset length is {len(ds)}")
limit = 2_000_000
if len(ds) > limit:
ds = ds.select(range(limit))
print(f"Dataset length is {len(ds)}")
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
use_fast=True,
trust_remote_code=True,
)
except Exception as e:
print(f"Error loading tokenizer '{tokenizer_name}': {e}")
raise
tokenization_func_with_tokenizer = partial(tokenize_function, tokenizer=tokenizer)
tokenized_ds = ds.map(
tokenization_func_with_tokenizer,
batched=True,
num_proc=num_proc_map,
remove_columns=ds.column_names,
desc=f"Tokenizing {lang}",
)
all_document_token_lists = []
for processed_example in tqdm(
tokenized_ds, desc=f"Collecting token lists for {lang}"
):
token_list_for_one_doc = processed_example["input_ids"]
if isinstance(token_list_for_one_doc, list):
all_document_token_lists.append(token_list_for_one_doc)
if not all_document_token_lists:
print(f"Warning: No token sequences found for {lang}. Skipping.")
return
final_token_ids = []
collected_tokens_count = 0
for doc_tokens_list in tqdm(
all_document_token_lists, desc=f"Aggregating tokens for {lang}"
):
if not doc_tokens_list:
continue
current_doc_token_count = len(doc_tokens_list)
if (
collected_tokens_count + current_doc_token_count
<= TARGET_TOKENS_PER_LANGUAGE
):
final_token_ids.extend(doc_tokens_list)
collected_tokens_count += current_doc_token_count
else:
remaining_needed = TARGET_TOKENS_PER_LANGUAGE - collected_tokens_count
final_token_ids.extend(doc_tokens_list[:remaining_needed])
collected_tokens_count += remaining_needed
break
if collected_tokens_count >= TARGET_TOKENS_PER_LANGUAGE:
break
del all_document_token_lists
del tokenized_ds
del ds
if collected_tokens_count == 0:
print(f"Warning: Zero tokens collected for {lang}. Skipping save.")
return
if collected_tokens_count < TARGET_TOKENS_PER_LANGUAGE:
print(
f"Warning: Language {lang} has only {collected_tokens_count:,} tokens, "
f"which is less than the target of {TARGET_TOKENS_PER_LANGUAGE:,}."
)
full_tensor = torch.tensor(final_token_ids, dtype=torch.long)
del final_token_ids
os.makedirs(output_dir, exist_ok=True)
torch.save(full_tensor, train_output_path)
print(f"Saved {full_tensor.numel():,} tokens for {lang}.")
del full_tensor
def run_job(args):
lang, model_id, tokenizer_name, output_dir, num_proc_map = args
print(f"Processing language: {lang} (PID: {os.getpid()})")
try:
build_and_save(lang, model_id, tokenizer_name, output_dir, num_proc_map)
return lang, True, None
except Exception as e:
import traceback
traceback.print_exc()
return lang, False, str(e)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Preprocess OSCAR data for multiple languages."
)
parser.add_argument(
"--languages",
type=str,
default="en,zh,fr",
help="Comma-separated list of languages to process, e.g., 'en,zh,fr'",
)
parser.add_argument(
"--model-id",
type=str,
required=True,
help="Model identifier (used for file naming).",
)
parser.add_argument(
"--tokenizer", type=str, required=True, help="Tokenizer name or path."
)
parser.add_argument(
"--output-dir",
type=str,
default="train-data",
help="Where to store tokenized tensors.",
)
parser.add_argument(
"--max-concurrent", type=int, default=6, help="Max concurrent processes."
)
args = parser.parse_args()
args.languages = [
lang.strip() for lang in args.languages.split(",") if lang.strip()
]
MAX_CONCURRENT_LANGUAGES = args.max_concurrent
NUM_MAP_PROC_PER_LANG = max(
1,
(
NUM_PROC_BASE // MAX_CONCURRENT_LANGUAGES
if MAX_CONCURRENT_LANGUAGES > 0
else NUM_PROC_BASE
),
)
print(f"Starting batch processing for {len(args.languages)} languages.")
job_args_list = [
(lang, args.model_id, args.tokenizer, args.output_dir, NUM_MAP_PROC_PER_LANG)
for lang in args.languages
]
successful_langs = []
failed_langs_with_errors = {}
with multiprocessing.Pool(processes=MAX_CONCURRENT_LANGUAGES) as pool:
results_iterable = pool.imap_unordered(run_job, job_args_list)
for result in tqdm(
results_iterable,
total=len(args.languages),
desc="Overall Language Progress",
):
lang_processed, success, error_msg = result
if success:
successful_langs.append(lang_processed)
else:
failed_langs_with_errors[lang_processed] = error_msg
print("Batch processing finished.")
print(f"Successfully processed: {', '.join(sorted(successful_langs))}")
if failed_langs_with_errors:
print(
f"Failed to process: {', '.join(sorted(failed_langs_with_errors.keys()))}"
)
for lang_failed, err in failed_langs_with_errors.items():
print(f" - {lang_failed}: {err}")