craffel's picture
download
raw
13.6 kB
"""
Script to count tokens in JSONL files in parallel by chunk and segment. Each worker process handles one line range (segment) of a chunk file: it streams the file from the start, skipping lines before its range, then tokenizes the "text" field of every record in the range and counts tokens and UTF-8 bytes. Results are aggregated by chunk and, for records that carry a "source" field, by source.
The script uses a multiprocessing Pool to efficiently utilize CPU cores and can handle large files by processing them in segments without loading entire files into memory.
"""
# Arguments use OmegaConf dotlist syntax (key=value), not argparse-style --flags.
# PYTHONPATH=/home/gsa/lingua:$PYTHONPATH python scripts/count_tokens.py tokenizer_name=huggingface tokenizer_path=gpt2 add_bos=true add_eos=true 'chunks=[1]' workers=64 segments_per_chunk=16
# PYTHONPATH=/home/gsa/lingua:$PYTHONPATH python scripts/count_tokens.py tokenizer_name=huggingface tokenizer_path=meta-llama/Llama-3.2-1B add_bos=true add_eos=true 'chunks=[1]' workers=64 segments_per_chunk=16
# PYTHONPATH=/home/gsa/lingua:$PYTHONPATH python scripts/count_tokens.py tokenizer_name=huggingface tokenizer_path=google/gemma-2-2b add_bos=true add_eos=true 'chunks=[1]' workers=64 segments_per_chunk=16
# PYTHONPATH=$PROJECT/lingua:$PYTHONPATH python scripts/count_tokens.py config=count_tokens_hq_.yaml glob_pattern="arb_Arab*.jsonl"
## on vulcan
# PYTHONPATH=/home/gsa/lingua:$PYTHONPATH python scripts/count_tokens.py tokenizer_name=huggingface tokenizer_path=meta-llama/Llama-3.2-1B add_bos=true add_eos=true 'chunks=[1]' workers=64 segments_per_chunk=16 input_path=/scratch/gsa/controlled_data
import sys
from pathlib import Path
# Put the parent of this script's directory (presumably the repository root)
# at the front of sys.path so the `lingua` imports below resolve when the script
# is run directly. The former debug `print(sys.path)` was removed: it wrote to
# stdout every time this module was imported.
sys.path.insert(0, Path(__file__).parents[1].as_posix())
import argparse
import functools
import json
import logging
import os
import time
from dataclasses import dataclass
from multiprocessing import Pool, cpu_count
from typing import Any, List, Optional
import yaml
from omegaconf import OmegaConf
from lingua.tokenizer import TokenizerArgs, build_tokenizer
# Configure logging to handle multiprocessing output cleanly
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(processName)s - %(message)s"
)
logger = logging.getLogger(__name__)
def get_file_line_offsets(file_path, num_segments):
    """
    Split a file into contiguous, near-equal line ranges.

    The file is read once to count its lines. Lines are then distributed so that
    segment sizes differ by at most one line. Previously the last segment
    absorbed the whole remainder, e.g. 31 lines in 16 segments gave the last
    segment 16 lines.

    Args:
        file_path: Path of the text file to split.
        num_segments: Desired number of segments (must be >= 1). Fewer segments
            are returned when the file has fewer lines than this.

    Returns:
        List of ``(start_line, end_line)`` tuples: 0-based, end-exclusive,
        covering every line exactly once. Empty for an empty file.

    Raises:
        ValueError: if ``num_segments`` is less than 1.
    """
    if num_segments < 1:
        raise ValueError(f"num_segments must be >= 1, got {num_segments}")
    # Text mode (same as the workers' reader) so line indices agree with theirs.
    with open(file_path, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)
    if total_lines == 0:
        return []
    # Never create empty segments.
    num_segments = min(num_segments, total_lines)
    base, extra = divmod(total_lines, num_segments)
    segments = []
    start_line = 0
    for i in range(num_segments):
        # The first `extra` segments take one additional line each.
        end_line = start_line + base + (1 if i < extra else 0)
        segments.append((start_line, end_line))
        start_line = end_line
    return segments
def process_file_segment(
    chunk_id, segment_id, start_line, end_line, file_path, args=None, tokenizer_args=None
):
    """
    Worker function: tokenize one line range of a single JSONL chunk file.

    The file is streamed from its beginning. Lines before ``start_line`` are
    read and discarded, and reading stops at ``end_line``. Each non-blank line
    in between is parsed as JSON, and its ``text`` field (``""`` when absent) is
    tokenized.

    Records that cannot be handled are skipped with a warning instead of
    discarding the whole segment's counts. This covers invalid JSON, a
    non-object line, a non-string ``text``, or text that cannot be UTF-8
    encoded.

    Args:
        chunk_id: Chunk identifier, echoed back in the result.
        segment_id: Segment index within the chunk, echoed back in the result.
        start_line: First 0-based line index to process (inclusive).
        end_line: Line index at which to stop (exclusive); ``float("inf")``
            means "until end of file".
        file_path: Path of the JSONL file.
        args: Run configuration; ``add_bos`` and ``add_eos`` are read from it.
        tokenizer_args: Object with ``name``, ``path`` and ``tokenizers``
            attributes, passed to ``build_tokenizer``.

    Returns:
        ``(chunk_id, segment_id, token_count, total_bytes,
        token_count_per_source, file_path)``. ``total_bytes`` is the UTF-8 size
        of the processed ``text`` fields. ``token_count_per_source`` maps each
        record's ``source`` value to its token count. If the file itself fails
        (e.g. unreadable, invalid UTF-8, tokenizer error), the error is logged
        and zero counts are returned.
    """
    start_time = time.time()
    # Re-build the tokenizer inside the worker process.
    tokenizer = build_tokenizer(
        tokenizer_args.name, tokenizer_args.path, tokenizer_args.tokenizers
    )
    logger.info(f"Tokenizer built, processing {file_path} segment {segment_id}")
    local_count = 0
    total_bytes = 0
    lines_processed = 0
    lines_skipped = 0
    token_count_per_source = {}
    try:
        with open(file_path, "r", encoding="utf-8") as f_in:
            for i, line in enumerate(f_in):
                # Skip lines outside our segment; `i >= end_line` is never true
                # when end_line is float("inf").
                if i < start_line:
                    continue
                if i >= end_line:
                    break
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                    if not isinstance(data, dict):
                        raise ValueError(
                            f"expected a JSON object, got {type(data).__name__}"
                        )
                    text = data.get("text", "")
                    if not isinstance(text, str):
                        raise ValueError(
                            f"'text' field is {type(text).__name__}, not str"
                        )
                    # json.loads accepts escaped lone surrogates ("\ud800") that
                    # cannot be UTF-8 encoded; encoding raises UnicodeEncodeError
                    # (a ValueError) here, before anything is counted.
                    num_bytes = len(text.encode("utf-8"))
                except ValueError as exc:
                    # JSONDecodeError and UnicodeError are both ValueErrors.
                    lines_skipped += 1
                    logger.warning(
                        f"Chunk {chunk_id} Segment {segment_id}: skipping malformed line {i} of {file_path}: {exc}"
                    )
                    continue
                tokens = tokenizer.encode(
                    text, add_bos=args.add_bos, add_eos=args.add_eos
                )
                num_tokens = len(tokens)
                local_count += num_tokens
                total_bytes += num_bytes
                lines_processed += 1
                if "source" in data:
                    source = data["source"]
                    token_count_per_source[source] = (
                        token_count_per_source.get(source, 0) + num_tokens
                    )
                # Progress logging
                if lines_processed % 50000 == 0:
                    elapsed = time.time() - start_time
                    rate = lines_processed / elapsed if elapsed > 0 else 0
                    logger.info(
                        f"Chunk {chunk_id} Segment {segment_id}: {lines_processed} lines, {local_count} tokens {total_bytes / 1e9:.2f} GB ({rate:.0f} lines/sec)"
                    )
    except Exception as e:
        # Best effort: keep the run going, but log the full traceback.
        logger.exception(f"Error processing chunk {chunk_id} segment {segment_id}: {e}")
        return (chunk_id, segment_id, 0, 0, {}, file_path)
    elapsed = time.time() - start_time
    if lines_skipped:
        logger.warning(
            f"Chunk {chunk_id} Segment {segment_id}: skipped {lines_skipped} malformed lines"
        )
    logger.info(
        f"Chunk {chunk_id} Segment {segment_id} finished: {lines_processed} lines, {local_count} tokens {total_bytes / 1e9:.2f} GB in {elapsed:.1f}s"
    )
    return (
        chunk_id,
        segment_id,
        local_count,
        total_bytes,
        token_count_per_source,
        file_path,
    )
import re
def extract_chunk_sequence_number(filename):
    """
    Return the numeric chunk/sequence id encoded in ``filename``, or ``None``.

    Two naming schemes are recognised, tried in this order:

    * ``<anything>.chunk.<N>.jsonl``: returns ``N``.
    * ``train_data_<anything>-0<N>-8.jsonl``: returns the digits that follow
      the ``-0`` (a leading ``0`` is required, so ``-10-8.jsonl`` does not
      match).
    """
    patterns = (
        # Digits between ".chunk." and the ".jsonl" extension.
        r"\.chunk\.(\d+)\.jsonl$",
        # Digits after "-0" in a "train_data_...-0<seq>-8.jsonl" name.
        r"train_data_.*-0(\d+)-8\.jsonl$",
    )
    for pattern in patterns:
        found = re.search(pattern, filename)
        if found is not None:
            return int(found.group(1))
    return None
def create_work_items(chunks, input_path, glob_pattern, segments_per_chunk):
    """
    Build the work items handed to the worker pool.

    Each matching file yields one or more tuples of the form
    ``(chunk_id, segment_id, start_line, end_line, file_path)``.

    Args:
        chunks: Optional collection of chunk ids. When non-empty, only files
            whose extracted chunk id is in it are processed. ``None`` or an
            empty collection processes every matching file. (This parameter
            was previously accepted but ignored.)
        input_path: Directory (``pathlib.Path``) searched with ``glob_pattern``.
        glob_pattern: Glob pattern selecting input files.
        segments_per_chunk: Number of line segments per file. ``1`` makes each
            file a single item with ``end_line=float("inf")``, without counting
            its lines first.

    Returns:
        List of work-item tuples, ordered by file path, then segment id.
    """
    # Sorted so the work-item (and log) order is deterministic across runs.
    # NOTE: any file whose name merely *contains* "val" is excluded (presumably
    # validation splits); this substring test also drops e.g. "interval.jsonl".
    files_to_process = sorted(
        f for f in input_path.glob(glob_pattern) if "val" not in f.name
    )
    wanted_chunks = set(chunks) if chunks else None
    work_items = []
    for file_path in files_to_process:
        chunk_id = extract_chunk_sequence_number(file_path.name)
        if wanted_chunks is not None and chunk_id not in wanted_chunks:
            logger.info(
                f"Skipping {file_path.name}: chunk {chunk_id} not in requested chunks"
            )
            continue
        if chunk_id is None:
            logger.warning(f"Could not extract chunk ID from filename: {file_path.name}, will use None as key")
        if segments_per_chunk == 1:
            # Process entire file as single segment - skip line counting
            logger.info(
                f"Processing {file_path.name} as single segment (skipping line count)"
            )
            work_items.append((chunk_id, 0, 0, float("inf"), file_path))
        else:
            logger.info(f"Analyzing {file_path.name} for segmentation...")
            segments = get_file_line_offsets(file_path, segments_per_chunk)
            logger.info(f"{file_path.name}: Split into {len(segments)} segments")
            work_items.extend(
                (chunk_id, seg_id, start_line, end_line, file_path)
                for seg_id, (start_line, end_line) in enumerate(segments)
            )
    return work_items
def auto_detect_chunks(input_path):
    """
    Return the chunk numbers of the ``train_data.<N>.jsonl`` files in a directory.

    Files are visited in sorted path order. Any file whose last dotted stem
    component is not an integer (e.g. ``train_data.x.jsonl``) is ignored.
    """
    detected = []
    candidates = sorted(input_path.glob("train_data.*.jsonl"))
    for candidate in candidates:
        # Last "."-separated piece of the stem, e.g. "train_data.7" -> "7".
        suffix = candidate.stem.rsplit(".", 1)[-1]
        try:
            detected.append(int(suffix))
        except ValueError:
            pass
    return detected
@dataclass
class Args:
    """Run configuration schema; OmegaConf builds it from these defaults merged with CLI arguments."""

    tokenizer_name: Optional[str] = None  # "huggingface"; when set, main builds TokenizerArgs from this and tokenizer_path
    tokenizer_path: Optional[str] = None  # "meta-llama/Llama-3.2-1B"
    tokenizer: Optional[TokenizerArgs] = None  # nested tokenizer config, used by main only when tokenizer_name is unset
    input_path: str = "/scratch/gsa/data/"  # directory searched with glob_pattern (plain str, not a Path)
    add_bos: bool = False  # forwarded to tokenizer.encode for every record
    add_eos: bool = False  # forwarded to tokenizer.encode for every record
    workers: Optional[int] = None  # None -> min(cpu_count(), 64); explicit values are also capped at 64 in main
    segments_per_chunk: int = 1  # 1 -> each file is a single work item and its lines are not counted up front
    glob_pattern: str = "*.jsonl"  # files under input_path to process; names containing "val" are excluded
    chunks: Optional[List[int]] = None  # chunk ids, passed to create_work_items (see there for how they are used)
    code: Optional[str] = None  # key of the YAML record appended to output_file; a key is derived when unset
    output_file: Optional[str] = None  # if set, a YAML summary of the run is appended to this file
    config: Any = None  # value of a config=<path> CLI argument; NOTE(review): see parse_and_merge_config for whether the file is loaded
def parse_and_merge_config():
    """
    Build the run configuration from defaults, an optional YAML file, and the CLI.

    Command-line arguments use OmegaConf dotlist syntax (``key=value``). If a
    ``config=<path>`` argument is present, that YAML file is loaded and merged
    between the ``Args`` defaults and the command line. Precedence is therefore
    defaults < config file < CLI. Previously the ``config`` value was accepted
    but the file was never read.

    Returns:
        An ``Args`` instance.
    """
    cli_cfg = OmegaConf.from_cli()
    layers = [OmegaConf.structured(Args())]
    # `config` is itself a field of Args, so it need not be removed before merging.
    config_path = cli_cfg.get("config")
    if config_path:
        layers.append(OmegaConf.load(config_path))
    layers.append(cli_cfg)
    return OmegaConf.to_object(OmegaConf.merge(*layers))
def _resolve_tokenizer_args(args):
    """
    Return the tokenizer configuration for this run.

    The flat ``tokenizer_name``/``tokenizer_path`` fields take precedence;
    otherwise the nested ``tokenizer`` config is used.

    Raises:
        ValueError: if neither form is configured.
    """
    if args.tokenizer_name:
        return TokenizerArgs(
            name=args.tokenizer_name,
            path=args.tokenizer_path,
            tokenizers=None,
        )
    if args.tokenizer is None:
        raise ValueError(
            "No tokenizer configured: set tokenizer_name (and tokenizer_path) or tokenizer"
        )
    return args.tokenizer


def _aggregate_results(results):
    """Sum worker result tuples into per-chunk token totals, per-chunk byte totals and per-source token totals."""
    chunk_totals = {}
    chunk_bytes = {}
    per_source = {}
    for chunk_id, _segment_id, token_count, num_bytes, source_counts, _file_path in results:
        chunk_totals[chunk_id] = chunk_totals.get(chunk_id, 0) + token_count
        chunk_bytes[chunk_id] = chunk_bytes.get(chunk_id, 0) + num_bytes
        for source, count in source_counts.items():
            per_source[source] = per_source.get(source, 0) + count
    return chunk_totals, chunk_bytes, per_source


def _append_yaml_summary(output_path, key, total_tokens, total_bytes, chunk_counts, per_source):
    """Append one ``{key: {...}}`` YAML mapping holding the run's totals to ``output_path``."""
    record = {
        key: {
            "total_tokens": total_tokens,
            "total_bytes": total_bytes,
            "chunk_counts": chunk_counts,
            "token_count_per_source": per_source,
        }
    }
    with open(output_path, "a", encoding="utf-8") as f_out:
        yaml.dump(record, f_out)


def main():
    """
    Entry point: count tokens and bytes in the configured JSONL files.

    Reads the configuration and splits the matching files into work items.
    The items are tokenized with a process pool, and a summary is logged.
    When ``output_file`` is set, the totals are also appended to it as YAML.
    Exits with status 1 when no work items are found.
    """
    args = parse_and_merge_config()
    input_path = Path(args.input_path)
    # Auto-selected and explicit worker counts are both capped at 64.
    if args.workers is None:
        args.workers = min(cpu_count(), 64)
        logger.info(f"Auto-selected {args.workers} workers (CPUs: {cpu_count()})")
    else:
        args.workers = min(args.workers, 64)
    logger.info("Configuration:")
    logger.info(f" Workers: {args.workers}")
    logger.info(f" Segments per chunk: {args.segments_per_chunk}")

    tokenizer_args = _resolve_tokenizer_args(args)
    # Build the tokenizer once in the parent before starting the pool, so a bad
    # tokenizer configuration fails here rather than in every worker. The
    # instance is discarded; each worker builds its own.
    build_tokenizer(
        tokenizer_args.name, tokenizer_args.path, tokenizer_args.tokenizers
    )

    logger.info("Creating work items...")
    work_items = create_work_items(
        args.chunks, input_path, args.glob_pattern, args.segments_per_chunk
    )
    if not work_items:
        logger.error("No valid work items to process!")
        sys.exit(1)
    logger.info(f"Created {len(work_items)} work items")

    # Bind per-run constants; each work item supplies the positional arguments.
    worker_func = functools.partial(
        process_file_segment,
        tokenizer_args=tokenizer_args,
        args=args,
    )

    logger.info(f"Starting processing with {args.workers} workers...")
    logger.info(
        f"Tokenizer: {tokenizer_args.name}, add_bos={args.add_bos}, add_eos={args.add_eos}"
    )
    start_time = time.time()
    with Pool(processes=args.workers) as pool:
        results = pool.starmap(worker_func, work_items)
    elapsed_time = time.time() - start_time

    chunk_totals, chunk_bytes, token_count_per_source_aggregate = _aggregate_results(results)
    total_tokens = sum(chunk_totals.values())
    total_bytes = sum(chunk_bytes.values())
    # Guard against a zero elapsed time (coarse clocks / tiny inputs).
    throughput = total_tokens / elapsed_time if elapsed_time > 0 else 0.0

    # Summary statistics
    logger.info("=" * 80)
    logger.info("PROCESSING COMPLETE")
    logger.info("=" * 80)
    logger.info(f"Total time: {elapsed_time:.1f}s")
    logger.info(f"Workers used: {args.workers}")
    logger.info(f"Segments per chunk: {args.segments_per_chunk}")
    logger.info(f"Total work items: {len(work_items)}")
    logger.info(f"Total tokens: {total_tokens:,}")
    logger.info(f"Total data processed: {total_bytes / 1e9:.2f} GB")
    logger.info(f"Total tokens per source: {token_count_per_source_aggregate}")
    logger.info(f"Throughput: {throughput:,.0f} tokens/sec")
    logger.info("\nToken counts per chunk:")
    # None (unparseable chunk id) sorts after all numeric ids.
    for chunk_id in sorted(chunk_totals.keys(), key=lambda x: (x is None, x or 0)):
        logger.info(f" Chunk {str(chunk_id):>4}: {chunk_totals[chunk_id]:,} tokens")
    logger.info(f"Total tokens: {total_tokens:,}")
    logger.info("=" * 80)

    if args.output_file:
        # args.input_path is a plain str (see Args), so take `.name` from the
        # Path built above; `args.input_path.name` raised AttributeError.
        key = args.code if args.code else f"{tokenizer_args.path}_{input_path.name}"
        _append_yaml_summary(
            Path(args.output_file),
            key,
            total_tokens,
            total_bytes,
            chunk_totals,
            token_count_per_source_aggregate,
        )
if __name__ == "__main__":
    # Start a run only when executed as a script, not when the module is imported.
    main()

Xet Storage Details

Size:
13.6 kB
·
Xet hash:
cfbda28d04cf1b1e36c3e289f53fa55dcc8f76bbfcc0e4540727ca8c66b0c8a6

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.