|
|
"""
|
|
|
OpenWebText Data Extraction Pipeline
|
|
|
====================================
|
|
|
|
|
|
This module processes compressed OpenWebText dataset files (.xz) in parallel,
|
|
|
extracting text content and building character vocabularies for GPT training.
|
|
|
|
|
|
Features:
|
|
|
- Parallel processing using ProcessPoolExecutor
|
|
|
- 90/10 train/validation split
|
|
|
- Character-level vocabulary extraction
|
|
|
- Windows multiprocessing support
|
|
|
|
|
|
Author: Your Name
|
|
|
Date: September 2025
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import lzma
|
|
|
from tqdm import tqdm
|
|
|
from multiprocessing import Pool, cpu_count, freeze_support
|
|
|
import concurrent.futures
|
|
|
|
|
|
def process_file(args):
    """
    Extract the text of one .xz archive and append it to an output file.

    Args:
        args (tuple): (directory, filename, output_file, vocab)
            - directory (str): folder holding the archive
            - filename (str): name of the .xz file to decompress
            - output_file (str): path the decompressed text is appended to
            - vocab (set): carried along for call-signature consistency;
              not consulted here (each worker process receives its own copy,
              so mutating it would never reach the parent anyway)

    Returns:
        set: the unique characters seen in this file's text
    """
    directory, filename, output_file, vocab = args
    source_path = os.path.join(directory, filename)

    # Decompress the whole archive into memory in one read.
    with lzma.open(source_path, "rt", encoding="utf-8") as compressed:
        contents = compressed.read()

    # Append to the shared output; single write call per file.
    with open(output_file, "a", encoding="utf-8") as sink:
        sink.write(contents)

    return set(contents)
|
|
|
|
|
|
def xz_files_in_dir(directory):
    """
    List the names of all regular .xz files directly inside a directory.

    Args:
        directory (str): path of the directory to scan

    Returns:
        list: filenames (not full paths) ending in ".xz"; subdirectories
            are excluded even if their names end in ".xz"
    """
    found = []
    for entry in os.listdir(directory):
        if not entry.endswith(".xz"):
            continue
        # Skip anything that is not a plain file (e.g. a directory).
        if os.path.isfile(os.path.join(directory, entry)):
            found.append(entry)
    return found
|
|
|
|
|
|
def process_files_in_parallel(files, folder_path, output_file):
    """
    Decompress many .xz files concurrently, appending their text to one output.

    Fans the work out across one worker process per CPU core via
    ProcessPoolExecutor, then merges the per-file character sets returned
    by process_file into a single vocabulary in the parent process.

    Args:
        files (list): filenames (relative to folder_path) to process
        folder_path (str): directory containing the files
        output_file (str): destination file the text is appended to

    Returns:
        set: union of the characters found across all processed files
    """
    vocabulary = set()
    with concurrent.futures.ProcessPoolExecutor(max_workers=cpu_count()) as pool:
        # Each worker gets (dir, name, output, vocab); the vocab copy in the
        # worker is ignored — characters come back via the return value.
        task_args = [(folder_path, name, output_file, vocabulary) for name in files]
        results = pool.map(process_file, task_args)
        for char_set in tqdm(results, total=len(task_args)):
            vocabulary |= char_set
    return vocabulary
|
|
|
|
|
|
def main():
    """
    Run the OpenWebText extraction pipeline.

    Process flow:
    1. Scan the 'openwebtext' directory for .xz archives
    2. Split the file list 90% train / 10% validation
    3. Decompress each split in parallel, appending text to its output file
    4. Union the per-split character vocabularies
    5. Write the sorted vocabulary to vocab.txt, one character per line

    Output files:
    - output_train.txt: training text data
    - output_val.txt: validation text data
    - vocab.txt: character vocabulary (one char per line)
    """
    folder_path = "openwebtext"
    output_file_train = "output_train.txt"
    output_file_val = "output_val.txt"
    vocab_file = "vocab.txt"

    # Sort so the 90/10 split is reproducible across runs —
    # os.listdir order is filesystem-dependent and otherwise arbitrary.
    files = sorted(xz_files_in_dir(folder_path))
    split_index = int(len(files) * 0.9)
    files_train = files[:split_index]
    files_val = files[split_index:]

    # Truncate any previous outputs: the workers open these files in append
    # mode, so stale content from an earlier run must be cleared first.
    for path in (output_file_train, output_file_val):
        with open(path, "w", encoding="utf-8"):
            pass

    vocab_train = process_files_in_parallel(files_train, folder_path, output_file_train)
    vocab_val = process_files_in_parallel(files_val, folder_path, output_file_val)

    # Merge both splits so the vocabulary covers every character seen.
    vocab = vocab_train | vocab_val
    with open(vocab_file, "w", encoding="utf-8") as vfile:
        vfile.writelines(char + "\n" for char in sorted(vocab))
|
|
|
|
|
|
if __name__ == '__main__':
    # freeze_support() is needed when the script is packaged as a frozen
    # Windows executable: worker processes re-import the main module, and
    # this call prevents them from re-launching the pipeline recursively.
    # It is a no-op on other platforms / normal interpreter runs.
    freeze_support()
    main()