# gpt2_base_prefix_682k / scripts/data/prepare_experiment_data.py
# Author: augustocsc
# GPT-2 Base trained on prefix dataset (682K)
# Revision: 5faf2eb (verified)
#!/usr/bin/env python3
"""
Data preparation script for training experiments.
Prepares data in two formats:
- EXP-A: JSON structured format
- EXP-B: EOS token format (GPT-2's <|endoftext|>)
Usage:
python scripts/data/prepare_experiment_data.py \
--dataset_repo_id augustocsc/sintetico_natural \
--data_dir 700K \
--data_column i_prompt_n \
--output_base_dir ./data/experiments
"""
import argparse
import json
import logging
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd
# Script-wide logging: INFO level, each record prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def parse_original_format(text: str) -> Optional[Dict]:
    """
    Parse the original format into components.

    Original format:
        vars: x_1, x_2
        oper: *, +, sin
        cons: C
        expr: C*sin(x_1) + x_2

    The long-form labels ``Variables:``, ``Operators:`` and ``Constants:``
    are accepted as aliases for ``vars:``, ``oper:`` and ``cons:``.

    Args:
        text: Raw sample text. Non-string values (e.g. ``None`` for a
            missing dataset cell) are treated as unparseable.

    Returns:
        Dictionary with vars, ops, cons, expr (plus the untouched
        ``raw_text``), or None if parsing fails, i.e. ``text`` is not a
        string or no non-empty ``expr:`` line is present.
    """
    if not isinstance(text, str):
        return None

    result = {
        'vars': [],
        'ops': [],
        'cons': None,
        'expr': None,
        'raw_text': text,
    }

    for raw_line in text.strip().split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        # Value is everything after the FIRST ':' so colons inside the
        # expression itself are preserved.
        value = line.partition(':')[2].strip()

        if line.startswith(('vars:', 'Variables:')):
            result['vars'] = [v.strip() for v in value.split(',') if v.strip()]
        elif line.startswith(('oper:', 'Operators:')):
            result['ops'] = [o.strip() for o in value.split(',') if o.strip()]
        elif line.startswith(('cons:', 'Constants:')):
            result['cons'] = value or None
        elif line.startswith('expr:'):
            # Drop any special-token marker (e.g. '<|endoftext|>') that is
            # already present so it is not duplicated downstream.
            result['expr'] = value.split('<|')[0].strip()

    # The expression is the only mandatory component.
    if not result['expr']:
        return None
    return result
def convert_to_json_format(parsed: Dict) -> str:
    """
    Serialize a parsed sample as a single-line JSON object (EXP-A).

    Key order is always vars, ops, [cons], expr; ``cons`` is emitted only
    when it is truthy. Non-ASCII characters are written as-is.

    Example output:
    {"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], "cons": "C", "expr": "C*sin(x_1) + x_2"}
    """
    payload = {
        'vars': parsed['vars'],
        'ops': parsed['ops'],
        **({'cons': parsed['cons']} if parsed['cons'] else {}),
        'expr': parsed['expr'],
    }
    return json.dumps(payload, ensure_ascii=False)
def convert_to_eos_format(parsed: Dict) -> str:
    """
    Render a parsed sample as labelled lines terminated by GPT-2's EOS token (EXP-B).

    The ``vars``/``oper``/``cons`` header lines are emitted only for truthy
    values; the ``expr`` line is always last and carries ``<|endoftext|>``.

    Example output:
    vars: x_1, x_2
    oper: *, +, sin
    cons: C
    expr: C*sin(x_1) + x_2<|endoftext|>
    """
    header = [
        f"{label}: {render(value)}"
        for label, value, render in (
            ('vars', parsed['vars'], ', '.join),
            ('oper', parsed['ops'], ', '.join),
            ('cons', parsed['cons'], lambda c: c),
        )
        if value
    ]
    return '\n'.join(header + [f"expr: {parsed['expr']}<|endoftext|>"])
def process_example_json(example: Dict) -> Dict:
    """
    Process a single example into JSON format (EXP-A).

    Args:
        example: Dataset row; the raw sample is read from its ``'text'`` key.

    Returns:
        ``{'text': <JSON string>, 'valid': True}`` on success, or
        ``{'text': '', 'valid': False}`` when the cell is missing/non-string
        or cannot be parsed.
    """
    text = example['text']
    # A null or non-string cell must be marked invalid rather than raising,
    # otherwise a single bad row aborts the whole Dataset.map() call.
    parsed = parse_original_format(text) if isinstance(text, str) else None
    if parsed is None:
        # '%.100s' truncates lazily and works for any value, including None.
        logger.warning("Failed to parse: %.100s...", text)
        return {'text': '', 'valid': False}
    return {'text': convert_to_json_format(parsed), 'valid': True}
def process_example_eos(example: Dict) -> Dict:
    """
    Process a single example into EOS format (EXP-B).

    Args:
        example: Dataset row; the raw sample is read from its ``'text'`` key.

    Returns:
        ``{'text': <EOS-terminated string>, 'valid': True}`` on success, or
        ``{'text': '', 'valid': False}`` when the cell is missing/non-string
        or cannot be parsed.
    """
    text = example['text']
    # A null or non-string cell must be marked invalid rather than raising,
    # otherwise a single bad row aborts the whole Dataset.map() call.
    parsed = parse_original_format(text) if isinstance(text, str) else None
    if parsed is None:
        # '%.100s' truncates lazily and works for any value, including None.
        logger.warning("Failed to parse: %.100s...", text)
        return {'text': '', 'valid': False}
    return {'text': convert_to_eos_format(parsed), 'valid': True}
def validate_json_format(text: str) -> bool:
    """
    Validate JSON format is correct.

    A sample is valid when it decodes to a JSON *object* containing the
    ``vars``, ``ops`` and ``expr`` keys.

    Args:
        text: Candidate JSON string. Non-string values (e.g. a NaN that
            pandas produced for an empty CSV cell) are reported as invalid.

    Returns:
        True if the sample is well-formed, False otherwise.
    """
    try:
        obj = json.loads(text)
    except (TypeError, ValueError):
        # TypeError: not str/bytes (None, NaN, ...); ValueError also covers
        # json.JSONDecodeError. Anything else (e.g. KeyboardInterrupt) propagates.
        return False
    # Require a dict: membership on a JSON list/string would give false positives.
    return isinstance(obj, dict) and all(key in obj for key in ('expr', 'vars', 'ops'))
def validate_eos_format(text: str) -> bool:
    """
    Validate EOS format is correct.

    A sample is valid when it is a string that contains an ``expr:`` line
    and is terminated by GPT-2's ``<|endoftext|>`` token (as produced by
    ``convert_to_eos_format``). A token in the middle of the text would
    teach the model to stop early, so it is not accepted.

    Args:
        text: Candidate sample. Non-string values (e.g. a NaN that pandas
            produced for an empty CSV cell) are reported as invalid instead
            of raising TypeError.

    Returns:
        True if the sample is well-formed, False otherwise.
    """
    return (
        isinstance(text, str)
        and 'expr:' in text
        and text.endswith('<|endoftext|>')
    )
def _convert_split(split_data, processor, format_name: str, exp_label: str, split_name: str):
    """Map ``processor`` over one split, keep valid rows, and return ``(valid_rows, invalid_count)``."""
    logger.info(f"\nConverting to {format_name} ({exp_label})...")
    converted = split_data.map(processor, desc=f"{format_name} ({split_name})")
    # Only the boolean flag is needed; input_columns avoids decoding 'text' per row.
    valid_rows = converted.filter(lambda is_valid: is_valid, input_columns='valid')
    invalid_count = len(converted) - len(valid_rows)
    logger.info(f"{format_name}: {len(valid_rows)}/{len(converted)} valid")
    if len(valid_rows) > 0:
        logger.info(f"\nSample {format_name}:\n{valid_rows[0]['text']}\n")
    return valid_rows, invalid_count


def _save_texts(rows, csv_path: Path, short_label: str) -> None:
    """Write the ``text`` column of ``rows`` to a single-column CSV at ``csv_path``."""
    # Fetch the whole column at once instead of materialising every row dict.
    frame = pd.DataFrame({'text': list(rows['text'])})
    frame.to_csv(csv_path, index=False)
    logger.info(f"Saved {short_label}: {csv_path} ({len(frame)} examples)")


def process_dataset(
    dataset_repo_id: str,
    data_dir: str,
    data_column: str,
    output_base_dir: Path,
    max_samples: Optional[int] = None
) -> Dict:
    """
    Process the dataset into both formats.

    For every split, writes ``<output_base_dir>/exp_a_json/<split>.csv`` and
    ``<output_base_dir>/exp_b_eos/<split>.csv``, each with a single ``text``
    column holding only the rows that parsed successfully.

    Args:
        dataset_repo_id: HuggingFace dataset repository ID
        data_dir: Subdirectory within the dataset
        data_column: Column containing the text data
        output_base_dir: Base directory for output
        max_samples: Optional limit on number of samples per split (for
            testing); a falsy value (None or 0) disables the limit.

    Returns:
        Dictionary with processing statistics: overall ``total``,
        ``json_valid``/``json_invalid``, ``eos_valid``/``eos_invalid`` and a
        per-split breakdown under ``splits``.

    Raises:
        ValueError: If ``data_column`` is missing from any split.
    """
    logger.info(f"Loading dataset from {dataset_repo_id}/{data_dir}...")
    dataset = load_dataset(
        dataset_repo_id,
        data_dir=data_dir,
        split=None
    )
    if not isinstance(dataset, dict):
        dataset = {'train': dataset}
    logger.info(f"Loaded {len(dataset)} split(s): {list(dataset.keys())}")

    # Fail fast, before creating any output, if the text column is missing.
    for split_name, split_data in dataset.items():
        if data_column not in split_data.column_names:
            raise ValueError(
                f"Column '{data_column}' not found in split '{split_name}'; "
                f"available columns: {split_data.column_names}"
            )

    # Show a sample (guarded: an empty train split has no row 0).
    if 'train' in dataset and len(dataset['train']) > 0:
        sample = dataset['train'][0][data_column]
        logger.info(f"\nSample ORIGINAL format:\n{sample}\n")

    output_json = output_base_dir / 'exp_a_json'
    output_eos = output_base_dir / 'exp_b_eos'
    output_json.mkdir(parents=True, exist_ok=True)
    output_eos.mkdir(parents=True, exist_ok=True)

    statistics = {
        'total': 0,
        'json_valid': 0,
        'eos_valid': 0,
        'json_invalid': 0,
        'eos_invalid': 0,
        'splits': {}
    }

    for split_name, split_data in dataset.items():
        logger.info(f"\n{'='*60}")
        logger.info(f"Processing {split_name} split ({len(split_data)} examples)")
        logger.info('='*60)

        # Keep only the source column: nothing else is used downstream, and
        # dropping the rest avoids a rename clash with an existing 'text' column.
        extra_columns = [c for c in split_data.column_names if c != data_column]
        if extra_columns:
            split_data = split_data.remove_columns(extra_columns)
        if data_column != 'text':
            split_data = split_data.rename_column(data_column, 'text')

        if max_samples and len(split_data) > max_samples:
            logger.info(f"Limiting to {max_samples} samples for testing")
            split_data = split_data.select(range(max_samples))

        statistics['total'] += len(split_data)

        json_valid, json_invalid_count = _convert_split(
            split_data, process_example_json, "JSON format", "EXP-A", split_name
        )
        eos_valid, eos_invalid_count = _convert_split(
            split_data, process_example_eos, "EOS format", "EXP-B", split_name
        )

        statistics['json_valid'] += len(json_valid)
        statistics['json_invalid'] += json_invalid_count
        statistics['eos_valid'] += len(eos_valid)
        statistics['eos_invalid'] += eos_invalid_count
        statistics['splits'][split_name] = {
            'total': len(split_data),
            'json_valid': len(json_valid),
            'eos_valid': len(eos_valid)
        }

        _save_texts(json_valid, output_json / f'{split_name}.csv', "JSON")
        _save_texts(eos_valid, output_eos / f'{split_name}.csv', "EOS")

    return statistics
def _validate_csv_dir(csv_dir: Path, validator, result: Dict) -> None:
    """Check every ``text`` cell of each CSV in ``csv_dir`` with ``validator``, updating ``result`` in place."""
    csv_files = sorted(csv_dir.glob('*.csv'))
    if not csv_files:
        # A missing/empty output directory must not be reported as a pass.
        logger.warning(f"No CSV files found in {csv_dir}")
        result['valid'] = False
        result['issues'].append(f"No CSV files found in {csv_dir}")
        return

    for csv_file in csv_files:
        logger.info(f"\nValidating {csv_file.name}...")
        # Read as str with NA parsing disabled so empty or 'NA'-like cells stay
        # strings instead of becoming float NaN (which would crash the checks).
        df = pd.read_csv(csv_file, dtype=str, keep_default_na=False)
        valid_count = 0
        invalid_samples = []
        for text in df['text']:
            if validator(text):
                valid_count += 1
            elif len(invalid_samples) < 3:
                invalid_samples.append(text[:100])
        total = len(df)
        rate = valid_count / total * 100 if total > 0 else 0
        logger.info(f" Valid: {valid_count}/{total} ({rate:.1f}%)")
        if invalid_samples:
            result['valid'] = False
            result['issues'].extend(invalid_samples)


def validate_output_files(output_base_dir: Path) -> Dict:
    """
    Validate the generated output files.

    Every ``*.csv`` in ``exp_a_json/`` is checked with ``validate_json_format``
    and every ``*.csv`` in ``exp_b_eos/`` with ``validate_eos_format``. A
    format fails if any row is invalid or its directory holds no CSV files.

    Args:
        output_base_dir: Base directory that ``process_dataset`` wrote to.

    Returns:
        Validation results dictionary keyed by ``'exp_a_json'`` and
        ``'exp_b_eos'``; each value has ``'valid'`` (bool) and ``'issues'``
        (up to 3 truncated invalid samples per file, or a missing-files note).
    """
    logger.info("\n" + "="*60)
    logger.info("VALIDATION OF OUTPUT FILES")
    logger.info("="*60)
    results = {
        'exp_a_json': {'valid': True, 'issues': []},
        'exp_b_eos': {'valid': True, 'issues': []}
    }
    _validate_csv_dir(output_base_dir / 'exp_a_json', validate_json_format, results['exp_a_json'])
    _validate_csv_dir(output_base_dir / 'exp_b_eos', validate_eos_format, results['exp_b_eos'])
    return results
def print_final_report(statistics: Dict, validation: Dict):
    """
    Log a summary of conversion statistics and validation results.

    Args:
        statistics: Counters returned by ``process_dataset``.
        validation: Results returned by ``validate_output_files`` (or the
            all-pass placeholder used when validation is skipped).

    Returns:
        Truthy if both the JSON and EOS outputs passed validation.
    """
    banner = "="*60
    total = statistics['total']

    logger.info("\n" + banner)
    logger.info("FINAL REPORT")
    logger.info(banner)
    logger.info(f"\nTotal examples processed: {total}")

    sections = (
        ("\nEXP-A (JSON Format):", 'json_valid', 'json_invalid', 'exp_a_json'),
        ("\nEXP-B (EOS Format):", 'eos_valid', 'eos_invalid', 'exp_b_eos'),
    )
    for heading, valid_key, invalid_key, result_key in sections:
        logger.info(heading)
        logger.info(f" Valid: {statistics[valid_key]}")
        logger.info(f" Invalid: {statistics[invalid_key]}")
        success = statistics[valid_key] / total * 100 if total > 0 else 0
        logger.info(f" Success rate: {success:.1f}%")
        verdict = 'PASS' if validation[result_key]['valid'] else 'FAIL'
        logger.info(f" Validation: {verdict}")

    logger.info("\nPer-split breakdown:")
    for name, per_split in statistics['splits'].items():
        logger.info(f"\n {name.upper()}:")
        logger.info(f" Total: {per_split['total']}")
        logger.info(f" JSON valid: {per_split['json_valid']}")
        logger.info(f" EOS valid: {per_split['eos_valid']}")

    logger.info("\n" + banner)
    passed = validation['exp_a_json']['valid'] and validation['exp_b_eos']['valid']
    logger.info("STATUS: ALL VALIDATIONS PASSED" if passed else "STATUS: SOME VALIDATIONS FAILED")
    logger.info(banner)
    return passed
def main():
    """
    Command-line entry point: convert the dataset, validate, report, and exit.

    Exits with status 0 when every validation passes and 1 on validation
    failure or on any exception raised while preparing the data.
    """
    parser = argparse.ArgumentParser(
        description="Prepare experiment data in JSON and EOS formats"
    )
    parser.add_argument(
        "--dataset_repo_id",
        type=str,
        default="augustocsc/sintetico_natural",
        help="HuggingFace dataset repository ID"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="700K",
        help="Subdirectory within the dataset"
    )
    parser.add_argument(
        "--data_column",
        type=str,
        default="i_prompt_n",
        help="Column containing text data"
    )
    parser.add_argument(
        "--output_base_dir",
        type=str,
        default="./data/experiments",
        help="Base directory for output"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum samples per split (for testing)"
    )
    parser.add_argument(
        "--skip_validation",
        action="store_true",
        help="Skip output file validation"
    )
    args = parser.parse_args()
    # A negative limit would silently select zero rows and write empty outputs.
    if args.max_samples is not None and args.max_samples < 0:
        parser.error("--max_samples must be a non-negative integer")

    output_base_dir = Path(args.output_base_dir)
    logger.info("="*60)
    logger.info("EXPERIMENT DATA PREPARATION")
    logger.info("="*60)
    logger.info(f"Dataset: {args.dataset_repo_id}/{args.data_dir}")
    logger.info(f"Column: {args.data_column}")
    logger.info(f"Output: {output_base_dir}")
    if args.max_samples:
        logger.info(f"Max samples: {args.max_samples}")
    logger.info("="*60)

    try:
        statistics = process_dataset(
            dataset_repo_id=args.dataset_repo_id,
            data_dir=args.data_dir,
            data_column=args.data_column,
            output_base_dir=output_base_dir,
            max_samples=args.max_samples
        )
        if not args.skip_validation:
            validation = validate_output_files(output_base_dir)
        else:
            # Placeholder result so the report treats skipped validation as a pass.
            validation = {
                'exp_a_json': {'valid': True, 'issues': []},
                'exp_b_eos': {'valid': True, 'issues': []}
            }
        all_valid = print_final_report(statistics, validation)
    except Exception as e:
        # logger.exception records the traceback through the logging handlers
        # instead of writing it to stderr separately.
        logger.exception(f"\nFailed to prepare data: {e}")
        sys.exit(1)

    if all_valid:
        logger.info("\nData preparation completed successfully!")
        logger.info("\nOutput directories:")
        logger.info(f" EXP-A (JSON): {output_base_dir / 'exp_a_json'}")
        logger.info(f" EXP-B (EOS): {output_base_dir / 'exp_b_eos'}")
        sys.exit(0)
    logger.error("\nData preparation completed with validation errors!")
    sys.exit(1)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()