# Source: gpt2_medium_prefix_682k/scripts/data/prepare_training_data_fixed.py
# Author: augustocsc
# Commit 3742716 (verified): "GPT-2 Medium trained on prefix dataset (682K)"
"""
Data preparation script that adds proper <|endofex|> markers to training data.
This script processes the existing dataset and wraps expressions with end-of-expression
markers so the model learns to stop generation correctly.
Usage:
python scripts/data/prepare_training_data_fixed.py \
--dataset_repo_id augustocsc/sintetico_natural \
--data_dir 700K \
--data_column i_prompt_n \
--output_dir ./data/processed/700K_fixed \
--validate
"""
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Dict, Tuple
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd
# Script-wide logging: INFO level, each line prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def add_end_markers(example: Dict) -> Dict:
    """
    Add an end-of-expression marker (<|endofex|>) to one training example.

    The function:
    1. Locates the expression in the text (everything after the first 'expr:').
    2. Finds the natural end of the expression: the earliest occurrence of a
       newline followed by 'vars:', 'Variables:', 'var:' or 'Variable:', or of
       a blank line (two consecutive newlines). If none is present, the
       expression runs to the end of the text.
    3. Inserts <|endofex|> right after the whitespace-stripped expression.
    4. Preserves any remaining content after the marker unchanged.

    The span between 'expr:' and the marker is normalised to
    'expr: <expression>' (exactly one space after the colon).

    Args:
        example: Dictionary containing a 'text' field with training data.

    Returns:
        Dictionary with the (possibly modified) 'text' field. The text is
        returned unchanged when it is not a string, contains no 'expr:', or
        already contains a marker after 'expr:'.
    """
    text = example['text']

    # Missing cells (e.g. None from an empty field) cannot be marked; pass them
    # through instead of raising a TypeError inside Dataset.map.
    if not isinstance(text, str):
        logger.warning(f"Non-string text encountered ({type(text).__name__}), leaving unchanged")
        return {'text': text}

    # partition() splits at the first 'expr:' only; an empty separator means
    # the keyword is absent.
    prefix, sep, expression_part = text.partition('expr:')
    if not sep:
        logger.warning(f"No 'expr:' found in text: {text[:100]}...")
        return {'text': text}

    # Check if marker already exists
    if '<|endofex|>' in expression_part:
        logger.debug("Marker already present, skipping")
        return {'text': text}

    # The expression ends at the earliest boundary found, or at end of text.
    boundaries = ('\nvars:', '\nVariables:', '\n\n', '\nvar:', '\nVariable:')
    found = [idx for idx in map(expression_part.find, boundaries) if idx != -1]
    end_idx = min(found, default=len(expression_part))

    clean_expr = expression_part[:end_idx].strip()
    remaining = expression_part[end_idx:]

    # Reconstruct text with marker
    return {'text': f"{prefix}expr: {clean_expr}<|endofex|>{remaining}"}
def validate_markers(example: Dict) -> Dict:
    """
    Report how many start/end markers an example contains.

    An example counts as valid as soon as it holds at least one <|endofex|>
    marker. <|startofex|> markers are only counted, because the start marker
    is optional depending on the data format.

    Args:
        example: Dictionary containing a 'text' field.

    Returns:
        Dictionary with keys 'valid', 'start_count', 'end_count' and the
        original 'text'.
    """
    body = example['text']
    n_end = body.count('<|endofex|>')
    return {
        'valid': n_end > 0,
        'start_count': body.count('<|startofex|>'),
        'end_count': n_end,
        'text': body,
    }
def _load_splits(dataset_repo_id: str, data_dir: str, data_column: str) -> Dict:
    """Load every split of ``dataset_repo_id``/``data_dir`` as a ``{split_name: Dataset}`` mapping."""
    logger.info(f"Loading dataset from {dataset_repo_id}/{data_dir}...")
    try:
        # Load dataset from HuggingFace Hub
        dataset = load_dataset(
            dataset_repo_id,
            data_dir=data_dir,
            split=None,  # Load all splits
        )
        if not isinstance(dataset, dict):
            # If single split, convert to dict
            dataset = {'train': dataset}
        logger.info(f"Loaded {len(dataset)} split(s): {list(dataset.keys())}")
        # Show sample before processing
        if 'train' in dataset and len(dataset['train']) > 0:
            logger.info("\nSample BEFORE processing:")
            logger.info(f"{dataset['train'][0][data_column][:200]}...")
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        raise
    return dataset


def _mark_split(split_name: str, split_data, data_column: str) -> Tuple:
    """Add end markers to one split; return ``(processed_split, split_stats)``."""
    logger.info(f"\nProcessing {split_name} split ({len(split_data)} examples)...")

    # add_end_markers reads the 'text' field, so expose data_column under that name.
    if data_column != 'text':
        split_data = split_data.rename_column(data_column, 'text')

    # Count examples that already contain a marker anywhere in the text.
    # Reading the single column avoids decoding every row into a full dict.
    already_marked = sum(
        1 for text in split_data['text']
        if isinstance(text, str) and '<|endofex|>' in text
    )
    if already_marked > 0:
        logger.info(f"Found {already_marked} examples already with markers")

    # Apply marker addition
    processed_split = split_data.map(
        add_end_markers,
        desc=f"Adding markers to {split_name}",
    )

    split_stats = {
        'total': len(split_data),
        'processed': len(processed_split),
        'already_marked': already_marked,
    }

    # Show sample after processing
    if len(processed_split) > 0:
        logger.info("\nSample AFTER processing:")
        logger.info(f"{processed_split[0]['text'][:200]}...")

    return processed_split, split_stats


def _validate_split(split_name: str, split_data, max_examples_shown: int = 3) -> Dict:
    """Check markers in one split, log a summary, and return its valid/invalid/valid_rate stats."""
    logger.info(f"\nValidating {split_name} split...")

    # map() returns a new dataset with the validation columns added;
    # split_data itself (what gets saved) is left untouched.
    validated = split_data.map(validate_markers)

    # Materialise the flag column once; reuse it for counting and for
    # locating invalid rows without iterating the whole dataset again.
    valid_flags = list(validated['valid'])
    total = len(valid_flags)
    valid_count = sum(valid_flags)
    invalid_count = total - valid_count
    # An empty split has nothing invalid; avoid dividing by zero.
    valid_rate = valid_count / total * 100 if total else 100.0

    logger.info(f"Valid examples: {valid_count}/{total} ({valid_rate:.1f}%)")

    if invalid_count > 0:
        logger.warning(f"Found {invalid_count} invalid examples!")
        # Show first few invalid examples; only those rows are decoded.
        invalid_indices = [i for i, ok in enumerate(valid_flags) if not ok]
        for shown, idx in enumerate(invalid_indices[:max_examples_shown], start=1):
            ex = validated[idx]
            logger.warning(f"\nInvalid example {shown}:")
            logger.warning(f"Start markers: {ex['start_count']}")
            logger.warning(f"End markers: {ex['end_count']}")
            logger.warning(f"Text: {ex['text'][:200]}...")

    return {'valid': valid_count, 'invalid': invalid_count, 'valid_rate': valid_rate}


def process_dataset(
    dataset_repo_id: str,
    data_dir: str,
    data_column: str,
    output_dir: Path,
    validate: bool = True
) -> Tuple[DatasetDict, Dict]:
    """
    Process the dataset by adding end markers to all splits.

    Args:
        dataset_repo_id: HuggingFace dataset repository ID.
        data_dir: Subdirectory within the dataset (e.g., '700K').
        data_column: Column to use for training data; it is renamed to
            'text' in the processed splits.
        output_dir: Not used by this function (saving is done by
            save_dataset); accepted for backward compatibility.
        validate: Whether to run validate_markers over every processed split
            and record valid/invalid/valid_rate per split.

    Returns:
        Tuple of (processed_dataset, statistics). statistics holds overall
        'total_examples', 'processed_examples' and 'already_marked' counts,
        plus a per-split dict under 'splits'.

    Raises:
        Any exception raised while loading the dataset is logged and re-raised.
    """
    dataset = _load_splits(dataset_repo_id, data_dir, data_column)

    processed_dataset = {}
    statistics = {
        'total_examples': 0,
        'processed_examples': 0,
        'already_marked': 0,
        'splits': {}
    }

    for split_name, split_data in dataset.items():
        processed_split, split_stats = _mark_split(split_name, split_data, data_column)
        processed_dataset[split_name] = processed_split
        statistics['splits'][split_name] = split_stats
        statistics['total_examples'] += split_stats['total']
        statistics['processed_examples'] += split_stats['processed']
        statistics['already_marked'] += split_stats['already_marked']

    if validate:
        logger.info("\n" + "=" * 60)
        logger.info("VALIDATION")
        logger.info("=" * 60)
        for split_name, split_data in processed_dataset.items():
            statistics['splits'][split_name].update(
                _validate_split(split_name, split_data)
            )

    # Convert back to DatasetDict
    return DatasetDict(processed_dataset), statistics
def save_dataset(dataset: DatasetDict, output_dir: Path, data_dir: str):
    """
    Write every split of ``dataset`` to ``output_dir`` as a CSV file.

    ``output_dir`` (and any missing parents) is created first. Each split is
    converted to a pandas DataFrame and written without its index to
    ``<output_dir>/<split_name>_<data_dir>.csv``.

    Args:
        dataset: Processed dataset to save.
        output_dir: Directory to save to.
        data_dir: Original data directory name, used as a filename suffix.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"\nSaving processed dataset to {output_dir}...")

    for name, split in dataset.items():
        frame = split.to_pandas()
        csv_path = output_dir / f"{name}_{data_dir}.csv"
        frame.to_csv(csv_path, index=False)
        logger.info(f"Saved {name} split: {csv_path} ({len(frame)} examples)")

    logger.info("Dataset saved successfully!")
def print_statistics(statistics: Dict):
    """
    Log the statistics dictionary built by process_dataset().

    Overall totals are logged first, then one section per split. The
    validation figures (valid / invalid / valid rate) appear only for
    splits that carry a 'valid' entry.

    Args:
        statistics: Dictionary containing processing statistics.
    """
    rule = "=" * 60
    logger.info("\n" + rule)
    logger.info("PROCESSING STATISTICS")
    logger.info(rule)
    logger.info(f"\nTotal examples: {statistics['total_examples']}")
    logger.info(f"Processed examples: {statistics['processed_examples']}")
    logger.info(f"Already marked: {statistics['already_marked']}")
    logger.info("\nPer-split statistics:")
    logger.info("-" * 60)

    for name, stats in statistics['splits'].items():
        logger.info(f"\n{name.upper()}:")
        logger.info(f" Total: {stats['total']}")
        logger.info(f" Processed: {stats['processed']}")
        logger.info(f" Already marked: {stats.get('already_marked', 0)}")
        if 'valid' not in stats:
            continue
        logger.info(f" Valid: {stats['valid']}")
        logger.info(f" Invalid: {stats['invalid']}")
        logger.info(f" Valid rate: {stats['valid_rate']:.1f}%")

    logger.info(rule)
def main():
    """
    Command-line entry point.

    Parses arguments, adds end markers to every split, prints statistics and
    saves CSVs locally. With --validate it checks the markers, and with
    --push_to_hub it uploads the result to the Hub. Exits with status 1 when:
    - --push_to_hub is given without --hub_repo_id (checked before any work);
    - --validate finds examples without an end marker, in which case nothing
      is pushed to the Hub;
    - any processing step raises.
    """
    parser = argparse.ArgumentParser(
        description="Prepare training data with proper end-of-expression markers"
    )
    parser.add_argument(
        "--dataset_repo_id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Subdirectory within the dataset (e.g., '700K')"
    )
    parser.add_argument(
        "--data_column",
        type=str,
        required=True,
        help="Column to use for training data (e.g., 'i_prompt_n')"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save processed dataset"
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Run validation after processing"
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push processed dataset to HuggingFace Hub"
    )
    parser.add_argument(
        "--hub_repo_id",
        type=str,
        default=None,
        help="HuggingFace repository ID for pushing (if --push_to_hub)"
    )
    args = parser.parse_args()

    # Reject an inconsistent flag combination before the long download and
    # processing steps run.
    if args.push_to_hub and not args.hub_repo_id:
        logger.error("--hub_repo_id required when using --push_to_hub")
        sys.exit(1)

    output_dir = Path(args.output_dir)

    # SystemExit raised by sys.exit() below is not an Exception subclass, so
    # it propagates past the handler with its exit code intact.
    try:
        processed_dataset, statistics = process_dataset(
            dataset_repo_id=args.dataset_repo_id,
            data_dir=args.data_dir,
            data_column=args.data_column,
            output_dir=output_dir,
            validate=args.validate
        )

        # Print statistics
        print_statistics(statistics)

        # Save locally even if validation later fails, so the output can be inspected.
        save_dataset(processed_dataset, output_dir, args.data_dir)

        # Check validation BEFORE publishing, so data that failed validation
        # is never pushed to the Hub.
        if args.validate:
            all_valid = all(
                split_stats.get('invalid', 0) == 0
                for split_stats in statistics['splits'].values()
            )
            if not all_valid:
                logger.error("\n⚠️ Some examples failed validation!")
                sys.exit(1)
            logger.info("\n✅ All examples validated successfully!")

        # Push to Hub if requested
        if args.push_to_hub:
            logger.info(f"\nPushing to HuggingFace Hub: {args.hub_repo_id}")
            processed_dataset.push_to_hub(args.hub_repo_id)
            logger.info("Successfully pushed to Hub!")

        logger.info("\n✅ Data preparation complete!")
    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"\n❌ Error during processing: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()