# Source: GTP-on-Reddit / src/data-extraction.py
# Uploaded by saumilyajj using huggingface_hub (commit 47df44d, verified).
"""
OpenWebText Data Extraction Pipeline
====================================
This module processes compressed OpenWebText dataset files (.xz) in parallel,
extracting text content and building character vocabularies for GPT training.
Features:
- Parallel processing using ProcessPoolExecutor
- 90/10 train/validation split
- Character-level vocabulary extraction
- Windows multiprocessing support
Author: Your Name
Date: September 2025
"""
import os
import lzma
from tqdm import tqdm
from multiprocessing import Pool, cpu_count, freeze_support
import concurrent.futures
def process_file(args):
    """
    Decompress one .xz file, append its text to an output file, and return
    the set of characters the text contains.

    Args:
        args (tuple): ``(directory, filename, output_file, vocab)`` packed into
            a single tuple so the function can be passed to ``Executor.map``.
            - directory (str): Directory containing the file.
            - filename (str): Name of the .xz file to process.
            - output_file (str): Path of the text file to append to.
            - vocab (set): Not used; accepted only to keep the tuple shape.

    Returns:
        set: Unique characters found in the decompressed text.

    Note:
        Every call opens ``output_file`` in append mode. If several processes
        call this concurrently with the same ``output_file``, this function
        does not control the order of the appended texts, and on platforms
        where appends are not atomic the writes may interleave.
    """
    directory, filename, output_file, _unused_vocab = args
    file_path = os.path.join(directory, filename)
    # "rt" uses universal newlines, so "\r\n" / "\r" in the archive become "\n".
    with lzma.open(file_path, "rt", encoding="utf-8") as infile:
        text = infile.read()
    # newline="" disables "\n" -> os.linesep translation on write. The bytes on
    # disk then hold exactly the characters returned below. Without it, Windows
    # writes "\r\n" while the returned vocabulary only ever contains "\n".
    with open(output_file, "a", encoding="utf-8", newline="") as outfile:
        outfile.write(text)
    return set(text)
def xz_files_in_dir(directory):
    """
    List the .xz files directly inside ``directory`` (non-recursive).

    Args:
        directory (str): Path of the directory to scan.

    Returns:
        list: Names (not full paths) of files ending in ".xz", sorted.
        ``os.listdir`` returns entries in arbitrary order; sorting keeps any
        slicing of this list (e.g. a train/val split) reproducible.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if name.endswith(".xz") and os.path.isfile(os.path.join(directory, name))
    )
def _extract_xz_text(file_path):
    """Decompress one .xz file and return ``(text, set_of_characters)``."""
    # Defined at module level so ProcessPoolExecutor can pickle it for workers.
    # "rt" uses universal newlines, so "\r\n" / "\r" become "\n".
    with lzma.open(file_path, "rt", encoding="utf-8") as infile:
        text = infile.read()
    return text, set(text)


def process_files_in_parallel(files, folder_path, output_file):
    """
    Decompress .xz files in worker processes and append their text to one file.

    Worker processes only decompress the files and collect their characters.
    The parent process does all the writing, so concurrent workers can never
    interleave their output. ``Executor.map`` yields results in input order,
    so the texts appear in ``output_file`` in the same order as ``files``.

    Args:
        files (list): Filenames to process, relative to ``folder_path``.
        folder_path (str): Directory containing the files.
        output_file (str): Path of the text file to append extracted text to.

    Returns:
        set: Union of the characters found in all processed files.
    """
    vocab = set()
    if not files:
        # Nothing to do: avoid starting a process pool or touching output_file.
        return vocab

    paths = [os.path.join(folder_path, filename) for filename in files]
    # newline="" disables "\n" -> os.linesep translation, so on every platform
    # the file content matches the characters collected into ``vocab``.
    with open(output_file, "a", encoding="utf-8", newline="") as outfile:
        with concurrent.futures.ProcessPoolExecutor(max_workers=cpu_count()) as executor:
            results = executor.map(_extract_xz_text, paths)
            for text, characters in tqdm(results, total=len(paths)):
                outfile.write(text)
                vocab.update(characters)
    return vocab
def main(
    *,
    folder_path="openwebtext",
    output_file_train="output_train.txt",
    output_file_val="output_val.txt",
    vocab_file="vocab.txt",
    train_fraction=0.9,
):
    """
    Run the OpenWebText extraction pipeline.

    Process flow:
        1. Find the .xz files in ``folder_path``, sorted by name.
        2. Split them by file: the first ``train_fraction`` go to training,
           the rest to validation.
        3. Decompress each split in parallel and append its text to the
           corresponding output file (both are truncated first).
        4. Merge the character sets of both splits and write them to
           ``vocab_file``, one character per line in sorted order.

    All arguments are keyword-only and default to the original hard-coded
    values, so ``main()`` behaves as before.

    Args:
        folder_path (str): Directory containing the .xz files.
        output_file_train (str): Output path for training text.
        output_file_val (str): Output path for validation text.
        vocab_file (str): Output path for the character vocabulary.
        train_fraction (float): Fraction of files used for training, in [0, 1].

    Raises:
        ValueError: If ``train_fraction`` is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")

    # Sort explicitly: os.listdir order is arbitrary, and the split below must be
    # reproducible across runs and machines.
    files = sorted(xz_files_in_dir(folder_path))
    split_index = int(len(files) * train_fraction)
    files_train = files[:split_index]
    files_val = files[split_index:]

    # Start from empty output files; the processing below only appends.
    for path in (output_file_train, output_file_val):
        with open(path, "w", encoding="utf-8"):
            pass

    vocab_train = process_files_in_parallel(files_train, folder_path, output_file_train)
    vocab_val = process_files_in_parallel(files_val, folder_path, output_file_val)

    vocab = vocab_train | vocab_val
    # newline="" keeps "\n" separators as-is on every platform. Note that "\n"
    # is itself usually a vocabulary entry, so it appears as an empty-looking line.
    with open(vocab_file, "w", encoding="utf-8", newline="") as vfile:
        vfile.writelines(char + "\n" for char in sorted(vocab))
if __name__ == '__main__':
    # Run the pipeline only when executed as a script. The guard matters for
    # multiprocessing: under the "spawn" start method (e.g. the default on
    # Windows), worker processes re-import this module and must not re-run main().
    # freeze_support() only has an effect in a program frozen into a Windows
    # executable; the multiprocessing docs require it to be called immediately
    # after this guard, before any other code.
    freeze_support()
    main()