| |
| """ |
| Data preparation script for training experiments. |
| |
| Prepares data in two formats: |
| - EXP-A: JSON structured format |
| - EXP-B: EOS token format (GPT-2's <|endoftext|>) |
| |
| Usage: |
| python scripts/data/prepare_experiment_data.py \ |
| --dataset_repo_id augustocsc/sintetico_natural \ |
| --data_dir 700K \ |
| --data_column i_prompt_n \ |
| --output_base_dir ./data/experiments |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import re |
| import sys |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| from datasets import load_dataset, Dataset, DatasetDict |
| import pandas as pd |
|
|
# Root logging configuration: timestamped INFO-level messages for the whole run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger used by every function below.
logger = logging.getLogger(__name__)
|
|
|
|
def parse_original_format(text: str) -> Optional[Dict]:
    """
    Parse the original key/value format into components.

    Original format:
        vars: x_1, x_2
        oper: *, +, sin
        cons: C
        expr: C*sin(x_1) + x_2

    Args:
        text: Raw example text, one "key: value" field per line.  Capitalized
            variants ('Variables:', 'Operators:', 'Constants:') are also
            accepted for the non-expression fields.

    Returns:
        Dict with keys 'vars' (list), 'ops' (list), 'cons' (str or None),
        'expr' (str) and 'raw_text' (the unmodified input), or None when no
        'expr:' line was found — expr is the only mandatory field.
    """
    result = {
        'vars': [],
        'ops': [],
        'cons': None,
        'expr': None,
        'raw_text': text
    }

    for line in text.strip().split('\n'):
        line = line.strip()
        if not line:
            continue

        if line.startswith('vars:') or line.startswith('Variables:'):
            # Comma-separated variable names; drop empty entries.
            var_part = line.split(':', 1)[1].strip()
            result['vars'] = [v.strip() for v in var_part.split(',') if v.strip()]

        elif line.startswith('oper:') or line.startswith('Operators:'):
            # Comma-separated operator symbols/names.
            op_part = line.split(':', 1)[1].strip()
            result['ops'] = [o.strip() for o in op_part.split(',') if o.strip()]

        elif line.startswith('cons:') or line.startswith('Constants:'):
            # Single constant token; an empty value normalizes to None.
            cons_part = line.split(':', 1)[1].strip()
            result['cons'] = cons_part if cons_part else None

        elif line.startswith('expr:'):
            # Strip any trailing special token (e.g. '<|endoftext|>') from the
            # expression.  NOTE: the original also re-split on '\n' here, which
            # was dead code — `line` is already newline-free — so it is removed.
            expr_part = line.split(':', 1)[1].strip()
            result['expr'] = expr_part.split('<|')[0].strip()

    # The expression is mandatory; without it the example is unusable.
    if not result['expr']:
        return None

    return result
|
|
|
|
def convert_to_json_format(parsed: Dict) -> str:
    """
    Serialize parsed components as a single-line JSON object (EXP-A).

    Output format:
        {"vars": ["x_1", "x_2"], "ops": ["*", "+", "sin"], "cons": "C", "expr": "C*sin(x_1) + x_2"}

    The 'cons' key is emitted only when a constant is present; key order is
    always vars, ops, [cons], expr.
    """
    payload = {
        'vars': parsed['vars'],
        'ops': parsed['ops'],
    }

    # Optional field: skip entirely when falsy (None or empty string).
    if parsed['cons']:
        payload['cons'] = parsed['cons']

    payload['expr'] = parsed['expr']

    # ensure_ascii=False keeps any non-ASCII symbols readable in the output.
    return json.dumps(payload, ensure_ascii=False)
|
|
|
|
def convert_to_eos_format(parsed: Dict) -> str:
    """
    Render parsed components in the EOS token format (EXP-B).

    Output format:
        vars: x_1, x_2
        oper: *, +, sin
        cons: C
        expr: C*sin(x_1) + x_2<|endoftext|>

    Empty/missing fields are omitted; the expression line always terminates
    with GPT-2's <|endoftext|> token.
    """
    parts = []

    if parsed['vars']:
        parts.append("vars: " + ", ".join(parsed['vars']))

    if parsed['ops']:
        parts.append("oper: " + ", ".join(parsed['ops']))

    if parsed['cons']:
        parts.append("cons: " + parsed['cons'])

    # The expression is always present (parse step rejects examples without it).
    parts.append("expr: " + parsed['expr'] + "<|endoftext|>")

    return "\n".join(parts)
|
|
|
|
def process_example_json(example: Dict) -> Dict:
    """Map one dataset example to its JSON-format text.

    Returns {'text': ..., 'valid': True} on success, or an empty-text record
    flagged invalid when the original format cannot be parsed.
    """
    raw = example['text']
    parsed = parse_original_format(raw)

    if parsed is None:
        # Keep the row (so map() output lines up) but mark it for filtering.
        logger.warning(f"Failed to parse: {raw[:100]}...")
        return {'text': '', 'valid': False}

    return {'text': convert_to_json_format(parsed), 'valid': True}
|
|
|
|
def process_example_eos(example: Dict) -> Dict:
    """Map one dataset example to its EOS-format text.

    Returns {'text': ..., 'valid': True} on success, or an empty-text record
    flagged invalid when the original format cannot be parsed.
    """
    raw = example['text']
    parsed = parse_original_format(raw)

    if parsed is None:
        # Keep the row (so map() output lines up) but mark it for filtering.
        logger.warning(f"Failed to parse: {raw[:100]}...")
        return {'text': '', 'valid': False}

    return {'text': convert_to_eos_format(parsed), 'valid': True}
|
|
|
|
def validate_json_format(text: str) -> bool:
    """Return True when *text* is a JSON object with 'expr', 'vars' and 'ops' keys.

    Non-string input (e.g. a NaN read back from CSV by pandas) and malformed
    JSON both yield False.
    """
    try:
        obj = json.loads(text)
    # Narrowed from a bare `except:`, which also swallowed KeyboardInterrupt
    # and SystemExit.  JSONDecodeError covers malformed JSON; TypeError covers
    # non-string input.
    except (json.JSONDecodeError, TypeError):
        return False
    # Require an actual object: a JSON array like ["expr","vars","ops"] would
    # otherwise pass the `in` membership checks.
    return isinstance(obj, dict) and 'expr' in obj and 'vars' in obj and 'ops' in obj
|
|
|
|
def validate_eos_format(text: str) -> bool:
    """Return True when *text* contains both an expression line and the GPT-2 EOS marker."""
    has_eos = '<|endoftext|>' in text
    has_expr = 'expr:' in text
    return has_eos and has_expr
|
|
|
|
def process_dataset(
    dataset_repo_id: str,
    data_dir: str,
    data_column: str,
    output_base_dir: Path,
    max_samples: Optional[int] = None
) -> Dict:
    """
    Process the dataset into both formats and write one CSV per split/format.

    Args:
        dataset_repo_id: HuggingFace dataset repository ID
        data_dir: Subdirectory within the dataset
        data_column: Column containing the text data
        output_base_dir: Base directory for output
        max_samples: Optional limit on number of samples (for testing)

    Returns:
        Dictionary with processing statistics: aggregate counters
        ('total', 'json_valid', 'json_invalid', 'eos_valid', 'eos_invalid')
        plus a per-split breakdown under 'splits'.
    """
    logger.info(f"Loading dataset from {dataset_repo_id}/{data_dir}...")

    # split=None requests all available splits — presumably this returns a
    # DatasetDict mapping split name -> Dataset; confirm against the installed
    # `datasets` version.
    dataset = load_dataset(
        dataset_repo_id,
        data_dir=data_dir,
        split=None
    )

    # A single-split load returns a plain Dataset (not a dict subclass);
    # normalize to a {'train': ...} mapping so the loop below is uniform.
    if not isinstance(dataset, dict):
        dataset = {'train': dataset}

    logger.info(f"Loaded {len(dataset)} split(s): {list(dataset.keys())}")

    # Log one raw example so the run log documents the input format.
    if 'train' in dataset:
        sample = dataset['train'][0][data_column]
        logger.info(f"\nSample ORIGINAL format:\n{sample}\n")

    # One output directory per experiment format.
    output_json = output_base_dir / 'exp_a_json'
    output_eos = output_base_dir / 'exp_b_eos'
    output_json.mkdir(parents=True, exist_ok=True)
    output_eos.mkdir(parents=True, exist_ok=True)

    # Aggregate counters across all splits; per-split numbers go in 'splits'.
    statistics = {
        'total': 0,
        'json_valid': 0,
        'eos_valid': 0,
        'json_invalid': 0,
        'eos_invalid': 0,
        'splits': {}
    }

    for split_name, split_data in dataset.items():
        logger.info(f"\n{'='*60}")
        logger.info(f"Processing {split_name} split ({len(split_data)} examples)")
        logger.info('='*60)

        # The process_example_* helpers read example['text'], so rename the
        # configured source column to 'text' first.
        if data_column != 'text':
            split_data = split_data.rename_column(data_column, 'text')

        # Optional truncation for quick test runs (applied before counting).
        if max_samples and len(split_data) > max_samples:
            logger.info(f"Limiting to {max_samples} samples for testing")
            split_data = split_data.select(range(max_samples))

        statistics['total'] += len(split_data)

        # --- EXP-A: JSON format ---
        logger.info("\nConverting to JSON format (EXP-A)...")
        json_data = split_data.map(
            process_example_json,
            desc=f"JSON format ({split_name})"
        )

        # Drop rows whose original text failed to parse (valid=False).
        json_valid = json_data.filter(lambda x: x['valid'])
        json_invalid_count = len(json_data) - len(json_valid)

        logger.info(f"JSON format: {len(json_valid)}/{len(json_data)} valid")

        if len(json_valid) > 0:
            logger.info(f"\nSample JSON format:\n{json_valid[0]['text']}\n")

        # --- EXP-B: EOS format ---
        logger.info("\nConverting to EOS format (EXP-B)...")
        eos_data = split_data.map(
            process_example_eos,
            desc=f"EOS format ({split_name})"
        )

        eos_valid = eos_data.filter(lambda x: x['valid'])
        eos_invalid_count = len(eos_data) - len(eos_valid)

        logger.info(f"EOS format: {len(eos_valid)}/{len(eos_data)} valid")

        if len(eos_valid) > 0:
            logger.info(f"\nSample EOS format:\n{eos_valid[0]['text']}\n")

        # Accumulate global and per-split statistics.
        statistics['json_valid'] += len(json_valid)
        statistics['json_invalid'] += json_invalid_count
        statistics['eos_valid'] += len(eos_valid)
        statistics['eos_invalid'] += eos_invalid_count
        statistics['splits'][split_name] = {
            'total': len(split_data),
            'json_valid': len(json_valid),
            'eos_valid': len(eos_valid)
        }

        # Persist only the valid rows, one CSV per split and format.
        json_df = pd.DataFrame({'text': [ex['text'] for ex in json_valid]})
        json_file = output_json / f'{split_name}.csv'
        json_df.to_csv(json_file, index=False)
        logger.info(f"Saved JSON: {json_file} ({len(json_df)} examples)")

        eos_df = pd.DataFrame({'text': [ex['text'] for ex in eos_valid]})
        eos_file = output_eos / f'{split_name}.csv'
        eos_df.to_csv(eos_file, index=False)
        logger.info(f"Saved EOS: {eos_file} ({len(eos_df)} examples)")

    return statistics
|
|
|
|
def _validate_csv_dir(directory: Path, validator) -> List[str]:
    """Validate the 'text' column of every CSV in *directory* with *validator*.

    Logs a per-file validity rate and returns up to three invalid sample
    prefixes per file (empty list when every row validates).
    """
    issues: List[str] = []

    for csv_file in directory.glob('*.csv'):
        logger.info(f"\nValidating {csv_file.name}...")
        df = pd.read_csv(csv_file)

        valid_count = 0
        invalid_samples = []

        for _, row in df.iterrows():
            text = row['text']
            if validator(text):
                valid_count += 1
            elif len(invalid_samples) < 3:
                # str() guards against non-string cells (e.g. NaN that pandas
                # reads back for an empty row), which would crash on slicing.
                invalid_samples.append(str(text)[:100])

        rate = valid_count / len(df) * 100 if len(df) > 0 else 0
        logger.info(f"  Valid: {valid_count}/{len(df)} ({rate:.1f}%)")

        issues.extend(invalid_samples)

    return issues


def validate_output_files(output_base_dir: Path) -> Dict:
    """
    Validate the generated output files.

    Returns:
        Validation results dictionary mapping each experiment directory name
        to {'valid': bool, 'issues': [sample prefixes of invalid rows]}.
    """
    logger.info("\n" + "="*60)
    logger.info("VALIDATION OF OUTPUT FILES")
    logger.info("="*60)

    results = {
        'exp_a_json': {'valid': True, 'issues': []},
        'exp_b_eos': {'valid': True, 'issues': []}
    }

    # Both experiment outputs get identical treatment (the original duplicated
    # this loop verbatim); only the per-row validator differs.
    for dir_name, validator in (
        ('exp_a_json', validate_json_format),
        ('exp_b_eos', validate_eos_format),
    ):
        issues = _validate_csv_dir(output_base_dir / dir_name, validator)
        if issues:
            results[dir_name]['valid'] = False
            results[dir_name]['issues'].extend(issues)

    return results
|
|
|
|
def print_final_report(statistics: Dict, validation: Dict):
    """Log a summary of processing statistics and validation outcomes.

    Returns True when both output formats passed validation, False otherwise.
    """
    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("FINAL REPORT")
    logger.info(banner)

    total = statistics['total']
    logger.info(f"\nTotal examples processed: {total}")

    # Both experiment formats get identical reporting, driven by a table of
    # (section title, valid-counter key, invalid-counter key, validation key).
    sections = (
        ("EXP-A (JSON Format)", 'json_valid', 'json_invalid', 'exp_a_json'),
        ("EXP-B (EOS Format)", 'eos_valid', 'eos_invalid', 'exp_b_eos'),
    )
    for title, ok_key, bad_key, val_key in sections:
        logger.info(f"\n{title}:")
        logger.info(f"  Valid: {statistics[ok_key]}")
        logger.info(f"  Invalid: {statistics[bad_key]}")
        rate = statistics[ok_key] / total * 100 if total > 0 else 0
        logger.info(f"  Success rate: {rate:.1f}%")
        logger.info(f"  Validation: {'PASS' if validation[val_key]['valid'] else 'FAIL'}")

    logger.info("\nPer-split breakdown:")
    for split_name, split_stats in statistics['splits'].items():
        logger.info(f"\n  {split_name.upper()}:")
        logger.info(f"    Total: {split_stats['total']}")
        logger.info(f"    JSON valid: {split_stats['json_valid']}")
        logger.info(f"    EOS valid: {split_stats['eos_valid']}")

    logger.info("\n" + banner)

    all_valid = validation['exp_a_json']['valid'] and validation['exp_b_eos']['valid']
    if all_valid:
        logger.info("STATUS: ALL VALIDATIONS PASSED")
    else:
        logger.info("STATUS: SOME VALIDATIONS FAILED")

    logger.info(banner)

    return all_valid
|
|
|
|
def main():
    """CLI entry point: parse arguments, process the dataset, validate outputs.

    Exits 0 on full success, 1 on validation failure or any unhandled error.
    """
    parser = argparse.ArgumentParser(
        description="Prepare experiment data in JSON and EOS formats"
    )
    parser.add_argument(
        "--dataset_repo_id",
        type=str,
        default="augustocsc/sintetico_natural",
        help="HuggingFace dataset repository ID"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="700K",
        help="Subdirectory within the dataset"
    )
    parser.add_argument(
        "--data_column",
        type=str,
        default="i_prompt_n",
        help="Column containing text data"
    )
    parser.add_argument(
        "--output_base_dir",
        type=str,
        default="./data/experiments",
        help="Base directory for output"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum samples per split (for testing)"
    )
    parser.add_argument(
        "--skip_validation",
        action="store_true",
        help="Skip output file validation"
    )

    args = parser.parse_args()

    output_base_dir = Path(args.output_base_dir)

    # Echo the effective configuration at the top of the run log.
    logger.info("="*60)
    logger.info("EXPERIMENT DATA PREPARATION")
    logger.info("="*60)
    logger.info(f"Dataset: {args.dataset_repo_id}/{args.data_dir}")
    logger.info(f"Column: {args.data_column}")
    logger.info(f"Output: {output_base_dir}")
    if args.max_samples:
        logger.info(f"Max samples: {args.max_samples}")
    logger.info("="*60)

    try:
        statistics = process_dataset(
            dataset_repo_id=args.dataset_repo_id,
            data_dir=args.data_dir,
            data_column=args.data_column,
            output_base_dir=output_base_dir,
            max_samples=args.max_samples
        )

        if not args.skip_validation:
            validation = validate_output_files(output_base_dir)
        else:
            # Skipped validation counts as a pass so the report still renders.
            validation = {
                'exp_a_json': {'valid': True, 'issues': []},
                'exp_b_eos': {'valid': True, 'issues': []}
            }

        all_valid = print_final_report(statistics, validation)

        if all_valid:
            logger.info("\nData preparation completed successfully!")
            # Was a placeholder-free f-string; plain literal is equivalent.
            logger.info("\nOutput directories:")
            logger.info(f"  EXP-A (JSON): {output_base_dir / 'exp_a_json'}")
            logger.info(f"  EXP-B (EOS): {output_base_dir / 'exp_b_eos'}")
            sys.exit(0)
        else:
            logger.error("\nData preparation completed with validation errors!")
            sys.exit(1)

    except Exception as e:
        # logger.exception logs the message at ERROR level together with the
        # traceback, replacing the manual `import traceback; print_exc()`.
        # (SystemExit from the sys.exit calls above is not caught here — it
        # derives from BaseException, as in the original.)
        logger.exception(f"\nFailed to prepare data: {e}")
        sys.exit(1)
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|