"""
Data preparation script that adds proper <|endofex|> markers to training data.

This script processes the existing dataset and wraps expressions with
end-of-expression markers so the model learns to stop generation correctly.

Usage:
    python scripts/data/prepare_training_data_fixed.py \
        --dataset_repo_id augustocsc/sintetico_natural \
        --data_dir 700K \
        --data_column i_prompt_n \
        --output_dir ./data/processed/700K_fixed \
        --validate
"""
| |
|
| | import argparse |
| | import logging |
| | import os |
| | import sys |
| | from pathlib import Path |
| | from typing import Dict, Tuple |
| |
|
| | from datasets import load_dataset, Dataset, DatasetDict |
| | import pandas as pd |
| |
|
# Configure root logging once at import time: INFO level with timestamped,
# level-tagged messages. Runs as a module-level side effect.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
| |
|
| |
|
def add_end_markers(example: Dict) -> Dict:
    """
    Add an <|endofex|> end-of-expression marker to one training example.

    This function:
      1. Locates the expression in the text (after the first 'expr:')
      2. Finds the natural end boundary (before 'vars:', blank lines, etc.)
      3. Inserts the <|endofex|> marker at the end of the expression
      4. Preserves any remaining content after the marker

    Args:
        example: Dictionary containing a 'text' field with training data

    Returns:
        Dictionary with a modified 'text' field containing the end marker.
        The text is returned unchanged when no 'expr:' is present or a
        marker already exists (the function is idempotent).
    """
    text = example['text']

    if 'expr:' not in text:
        logger.warning(f"No 'expr:' found in text: {text[:100]}...")
        return {'text': text}

    # After the membership guard above, split with maxsplit=1 always
    # yields exactly two parts, so no extra length check is needed.
    prefix, expression_part = text.split('expr:', 1)

    # Idempotent: skip examples that already carry a marker.
    if '<|endofex|>' in expression_part:
        logger.debug("Marker already present, skipping")
        return {'text': text}

    # The expression ends at the earliest of these boundary strings,
    # or at the end of the text when none is present.
    boundaries = ['\nvars:', '\nVariables:', '\n\n', '\nvar:', '\nVariable:']
    positions = (expression_part.find(b) for b in boundaries)
    end_idx = min((p for p in positions if p != -1), default=len(expression_part))

    clean_expr = expression_part[:end_idx].strip()
    remaining = expression_part[end_idx:]

    # NOTE: this normalizes the separator to 'expr: ' (exactly one space)
    # regardless of the original spacing around the expression.
    new_text = f"{prefix}expr: {clean_expr}<|endofex|>{remaining}"

    return {'text': new_text}
| |
|
| |
|
def validate_markers(example: Dict) -> Dict:
    """
    Check that an example carries the expected end-of-expression marker.

    Args:
        example: Dictionary containing a 'text' field

    Returns:
        Dictionary with validation metadata: the marker counts, a
        'valid' flag (true when at least one <|endofex|> is present),
        and the original text.
    """
    text = example['text']
    marker_counts = {
        'start_count': text.count('<|startofex|>'),
        'end_count': text.count('<|endofex|>'),
    }

    # Validity only requires an end marker; start markers are informational.
    return {
        'valid': marker_counts['end_count'] > 0,
        'text': text,
        **marker_counts,
    }
| |
|
| |
|
def process_dataset(
    dataset_repo_id: str,
    data_dir: str,
    data_column: str,
    output_dir: Path,
    validate: bool = True
) -> Tuple[DatasetDict, Dict]:
    """
    Process the dataset by adding end markers to all splits.

    Loads the dataset from the HuggingFace Hub, renames the chosen data
    column to 'text', maps add_end_markers() over every split, and
    (optionally) validates that each processed example carries at least
    one <|endofex|> marker.

    Args:
        dataset_repo_id: HuggingFace dataset repository ID
        data_dir: Subdirectory within the dataset (e.g., '700K')
        data_column: Column to use for training data
        output_dir: Directory to save processed dataset
            (NOTE(review): not used inside this function — saving happens
            in save_dataset(); confirm the parameter is intentional)
        validate: Whether to run validation after processing

    Returns:
        Tuple of (processed_dataset, statistics). statistics has aggregate
        counters plus a 'splits' mapping with per-split counts
        (total/processed/already_marked, and valid/invalid/valid_rate
        when validate=True).

    Raises:
        Exception: re-raised after logging when the dataset fails to load.
    """
    logger.info(f"Loading dataset from {dataset_repo_id}/{data_dir}...")

    try:
        # split=None loads every available split at once.
        dataset = load_dataset(
            dataset_repo_id,
            data_dir=data_dir,
            split=None
        )

        # A single-split load returns a Dataset rather than a dict-like
        # DatasetDict; normalize so the per-split loop below is uniform.
        if not isinstance(dataset, dict):
            dataset = {'train': dataset}

        logger.info(f"Loaded {len(dataset)} split(s): {list(dataset.keys())}")

        # Log a truncated sample for before/after eyeballing.
        if 'train' in dataset and len(dataset['train']) > 0:
            logger.info(f"\nSample BEFORE processing:")
            logger.info(f"{dataset['train'][0][data_column][:200]}...")

    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        raise

    processed_dataset = {}
    # Aggregate counters across splits; per-split detail goes under 'splits'.
    statistics = {
        'total_examples': 0,
        'processed_examples': 0,
        'already_marked': 0,
        'splits': {}
    }

    for split_name, split_data in dataset.items():
        logger.info(f"\nProcessing {split_name} split ({len(split_data)} examples)...")

        # add_end_markers() reads the 'text' field, so rename the chosen
        # training column first (skipped when it is already 'text').
        if data_column != 'text':
            split_data = split_data.rename_column(data_column, 'text')

        # Count pre-marked examples; add_end_markers() leaves these as-is.
        already_marked = sum(1 for ex in split_data if '<|endofex|>' in ex['text'])
        statistics['already_marked'] += already_marked

        if already_marked > 0:
            logger.info(f"Found {already_marked} examples already with markers")

        processed_split = split_data.map(
            add_end_markers,
            desc=f"Adding markers to {split_name}"
        )

        processed_dataset[split_name] = processed_split

        split_stats = {
            'total': len(split_data),
            'processed': len(processed_split),
            'already_marked': already_marked
        }
        statistics['splits'][split_name] = split_stats
        statistics['total_examples'] += len(split_data)
        statistics['processed_examples'] += len(processed_split)

        if len(processed_split) > 0:
            logger.info(f"\nSample AFTER processing:")
            logger.info(f"{processed_split[0]['text'][:200]}...")

    if validate:
        logger.info("\n" + "="*60)
        logger.info("VALIDATION")
        logger.info("="*60)

        for split_name, split_data in processed_dataset.items():
            logger.info(f"\nValidating {split_name} split...")

            # validate_markers() adds 'valid'/'start_count'/'end_count'
            # columns alongside the text.
            validated = split_data.map(validate_markers)

            valid_count = sum(validated['valid'])
            invalid_count = len(validated) - valid_count

            valid_rate = valid_count / len(validated) * 100

            logger.info(f"Valid examples: {valid_count}/{len(validated)} ({valid_rate:.1f}%)")

            if invalid_count > 0:
                logger.warning(f"Found {invalid_count} invalid examples!")

                # Surface at most three offenders for debugging.
                invalid_examples = [
                    ex for ex in validated if not ex['valid']
                ][:3]

                for i, ex in enumerate(invalid_examples):
                    logger.warning(f"\nInvalid example {i+1}:")
                    logger.warning(f"Start markers: {ex['start_count']}")
                    logger.warning(f"End markers: {ex['end_count']}")
                    logger.warning(f"Text: {ex['text'][:200]}...")

            statistics['splits'][split_name]['valid'] = valid_count
            statistics['splits'][split_name]['invalid'] = invalid_count
            statistics['splits'][split_name]['valid_rate'] = valid_rate

    # Wrap in a DatasetDict so callers get save/push_to_hub uniformly.
    processed_dataset = DatasetDict(processed_dataset)

    return processed_dataset, statistics
| |
|
| |
|
def save_dataset(dataset: DatasetDict, output_dir: Path, data_dir: str):
    """
    Save the processed dataset to a local directory as per-split CSV files.

    Args:
        dataset: Processed dataset to save
        output_dir: Directory to save to (created if missing)
        data_dir: Original data directory name (embedded in each filename)
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"\nSaving processed dataset to {output_dir}...")

    for split_name, split_data in dataset.items():
        # One CSV per split, e.g. train_700K.csv
        target = output_dir / f"{split_name}_{data_dir}.csv"

        # Round-trip through pandas for CSV serialization.
        frame = split_data.to_pandas()
        frame.to_csv(target, index=False)

        logger.info(f"Saved {split_name} split: {target} ({len(frame)} examples)")

    logger.info("Dataset saved successfully!")
| |
|
| |
|
def print_statistics(statistics: Dict):
    """
    Log processing statistics as a formatted report.

    Args:
        statistics: Dictionary containing processing statistics, as
            produced by process_dataset()
    """
    banner = "=" * 60

    logger.info("\n" + banner)
    logger.info("PROCESSING STATISTICS")
    logger.info(banner)

    logger.info(f"\nTotal examples: {statistics['total_examples']}")
    logger.info(f"Processed examples: {statistics['processed_examples']}")
    logger.info(f"Already marked: {statistics['already_marked']}")

    logger.info("\nPer-split statistics:")
    logger.info("-" * 60)

    for split_name, split_stats in statistics['splits'].items():
        logger.info(f"\n{split_name.upper()}:")
        logger.info(f" Total: {split_stats['total']}")
        logger.info(f" Processed: {split_stats['processed']}")
        logger.info(f" Already marked: {split_stats.get('already_marked', 0)}")

        # Validation figures are only present when --validate was used.
        if 'valid' in split_stats:
            logger.info(f" Valid: {split_stats['valid']}")
            logger.info(f" Invalid: {split_stats['invalid']}")
            logger.info(f" Valid rate: {split_stats['valid_rate']:.1f}%")

    logger.info(banner)
| |
|
| |
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Construct the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description="Prepare training data with proper end-of-expression markers"
    )
    parser.add_argument(
        "--dataset_repo_id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Subdirectory within the dataset (e.g., '700K')"
    )
    parser.add_argument(
        "--data_column",
        type=str,
        required=True,
        help="Column to use for training data (e.g., 'i_prompt_n')"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save processed dataset"
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Run validation after processing"
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push processed dataset to HuggingFace Hub"
    )
    parser.add_argument(
        "--hub_repo_id",
        type=str,
        default=None,
        help="HuggingFace repository ID for pushing (if --push_to_hub)"
    )
    return parser


def main():
    """CLI entry point: process, report, save, and optionally push/validate."""
    args = _build_arg_parser().parse_args()

    output_dir = Path(args.output_dir)

    try:
        processed_dataset, statistics = process_dataset(
            dataset_repo_id=args.dataset_repo_id,
            data_dir=args.data_dir,
            data_column=args.data_column,
            output_dir=output_dir,
            validate=args.validate
        )

        print_statistics(statistics)

        save_dataset(processed_dataset, output_dir, args.data_dir)

        if args.push_to_hub:
            # A target repository is mandatory when pushing.
            if not args.hub_repo_id:
                logger.error("--hub_repo_id required when using --push_to_hub")
                sys.exit(1)

            logger.info(f"\nPushing to HuggingFace Hub: {args.hub_repo_id}")
            processed_dataset.push_to_hub(args.hub_repo_id)
            logger.info("Successfully pushed to Hub!")

        if args.validate:
            # Fail the run when any split reported invalid examples.
            total_invalid = sum(
                split_stats.get('invalid', 0)
                for split_stats in statistics['splits'].values()
            )

            if total_invalid > 0:
                logger.error("\n⚠️ Some examples failed validation!")
                sys.exit(1)
            logger.info("\n✅ All examples validated successfully!")

        logger.info("\n✅ Data preparation complete!")

    except Exception as e:
        # SystemExit from the branches above is not caught here
        # (it is a BaseException), so explicit exits pass through.
        logger.error(f"\n❌ Error during processing: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
| |
|
| |
|
# Standard script guard: run the CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
| |
|