| """ |
| Data preparation script that adds proper <|endofex|> markers to training data. |
| |
| This script processes the existing dataset and wraps expressions with end-of-expression |
| markers so the model learns to stop generation correctly. |
| |
| Usage: |
| python scripts/data/prepare_training_data_fixed.py \ |
| --dataset_repo_id augustocsc/sintetico_natural \ |
| --data_dir 700K \ |
| --data_column i_prompt_n \ |
| --output_dir ./data/processed/700K_fixed \ |
| --validate |
| """ |
|
|
| import argparse |
| import logging |
| import os |
| import sys |
| from pathlib import Path |
| from typing import Dict, Tuple |
|
|
| from datasets import load_dataset, Dataset, DatasetDict |
| import pandas as pd |
|
|
# Configure root logging once at import time; all output goes through `logger`.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)
|
|
|
|
def add_end_markers(example: Dict) -> Dict:
    """
    Add an end-of-expression marker to a single training example.

    This function:
    1. Locates the expression in the text (after 'expr:')
    2. Finds the natural end boundary (before 'vars:', blank lines, etc.)
    3. Inserts <|endofex|> at that boundary
    4. Preserves any remaining content after the marker

    Args:
        example: Dictionary containing a 'text' field with training data.

    Returns:
        Dictionary with a modified 'text' field containing the end marker.
        Texts without an 'expr:' section, or that already contain the
        marker, are returned unchanged.
    """
    text = example['text']

    if 'expr:' not in text:
        # Lazy %-formatting: the slice is only rendered if the record is emitted.
        logger.warning("No 'expr:' found in text: %s...", text[:100])
        return {'text': text}

    # 'expr:' is known to be present, so this split always yields exactly two
    # parts — the previous `len(parts) != 2` guard was unreachable and has
    # been removed.
    prefix, expression_part = text.split('expr:', 1)

    if '<|endofex|>' in expression_part:
        logger.debug("Marker already present, skipping")
        return {'text': text}

    # The expression ends at the earliest boundary found, or at end-of-text
    # when no boundary is present.
    boundaries = ('\nvars:', '\nVariables:', '\n\n', '\nvar:', '\nVariable:')
    hits = [idx for idx in (expression_part.find(b) for b in boundaries) if idx != -1]
    end_idx = min(hits, default=len(expression_part))

    clean_expr = expression_part[:end_idx].strip()
    remaining = expression_part[end_idx:]

    # NOTE(review): this normalizes the separator to exactly 'expr: ' (one
    # space) regardless of the original spacing — intentional, matches the
    # original implementation.
    new_text = f"{prefix}expr: {clean_expr}<|endofex|>{remaining}"

    return {'text': new_text}
|
|
|
|
def validate_markers(example: Dict) -> Dict:
    """
    Check that end-of-expression markers are present in an example's text.

    Args:
        example: Dictionary containing a 'text' field.

    Returns:
        Dictionary with 'valid' (at least one end marker present),
        'start_count', 'end_count', and the untouched 'text'.
    """
    text = example['text']
    n_start = text.count('<|startofex|>')
    n_end = text.count('<|endofex|>')

    # Only the end marker is required for validity; the start marker count is
    # reported for diagnostics.
    return {
        'valid': n_end > 0,
        'start_count': n_start,
        'end_count': n_end,
        'text': text,
    }
|
|
|
|
def process_dataset(
    dataset_repo_id: str,
    data_dir: str,
    data_column: str,
    output_dir: Path,
    validate: bool = True
) -> Tuple[DatasetDict, Dict]:
    """
    Process the dataset by adding end markers to all splits.

    Loads the dataset from the HuggingFace Hub, renames the chosen data
    column to 'text', maps add_end_markers() over every split, and
    optionally validates the result with validate_markers().

    Args:
        dataset_repo_id: HuggingFace dataset repository ID
        data_dir: Subdirectory within the dataset (e.g., '700K')
        data_column: Column to use for training data
        output_dir: Directory to save processed dataset
            (NOTE(review): currently unused inside this function — saving is
            done by save_dataset() in main; kept for interface stability)
        validate: Whether to run validation after processing

    Returns:
        Tuple of (processed_dataset, statistics); statistics holds per-split
        counts and, when validate=True, validity rates.

    Raises:
        Exception: re-raises whatever load_dataset() fails with, after logging.
    """
    logger.info(f"Loading dataset from {dataset_repo_id}/{data_dir}...")

    try:
        # split=None requests all available splits (yields a DatasetDict).
        dataset = load_dataset(
            dataset_repo_id,
            data_dir=data_dir,
            split=None
        )

        # A single-split result is a bare Dataset; normalize to a dict so the
        # loop below is uniform. DatasetDict subclasses dict, so it passes
        # this check unchanged.
        if not isinstance(dataset, dict):
            dataset = {'train': dataset}

        logger.info(f"Loaded {len(dataset)} split(s): {list(dataset.keys())}")

        # Preview one raw example so marker insertion can be eyeballed in logs.
        if 'train' in dataset and len(dataset['train']) > 0:
            logger.info(f"\nSample BEFORE processing:")
            logger.info(f"{dataset['train'][0][data_column][:200]}...")

    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        raise

    processed_dataset = {}
    statistics = {
        'total_examples': 0,
        'processed_examples': 0,
        'already_marked': 0,
        'splits': {}
    }

    for split_name, split_data in dataset.items():
        logger.info(f"\nProcessing {split_name} split ({len(split_data)} examples)...")

        # add_end_markers() expects the training text under a 'text' column.
        if data_column != 'text':
            split_data = split_data.rename_column(data_column, 'text')

        # Count pre-existing markers (full pass over the split) so statistics
        # distinguish newly marked examples from already-marked ones.
        already_marked = sum(1 for ex in split_data if '<|endofex|>' in ex['text'])
        statistics['already_marked'] += already_marked

        if already_marked > 0:
            logger.info(f"Found {already_marked} examples already with markers")

        # add_end_markers() is a no-op for already-marked examples, so mapping
        # over the whole split is safe.
        processed_split = split_data.map(
            add_end_markers,
            desc=f"Adding markers to {split_name}"
        )

        processed_dataset[split_name] = processed_split

        split_stats = {
            'total': len(split_data),
            'processed': len(processed_split),
            'already_marked': already_marked
        }
        statistics['splits'][split_name] = split_stats
        statistics['total_examples'] += len(split_data)
        statistics['processed_examples'] += len(processed_split)

        if len(processed_split) > 0:
            logger.info(f"\nSample AFTER processing:")
            logger.info(f"{processed_split[0]['text'][:200]}...")

    if validate:
        logger.info("\n" + "="*60)
        logger.info("VALIDATION")
        logger.info("="*60)

        for split_name, split_data in processed_dataset.items():
            logger.info(f"\nValidating {split_name} split...")

            # validate_markers() adds 'valid'/'start_count'/'end_count' columns.
            validated = split_data.map(validate_markers)

            valid_count = sum(validated['valid'])
            invalid_count = len(validated) - valid_count

            # NOTE(review): raises ZeroDivisionError for an empty split.
            valid_rate = valid_count / len(validated) * 100

            logger.info(f"Valid examples: {valid_count}/{len(validated)} ({valid_rate:.1f}%)")

            if invalid_count > 0:
                logger.warning(f"Found {invalid_count} invalid examples!")

                # Show at most three offenders for debugging.
                invalid_examples = [
                    ex for ex in validated if not ex['valid']
                ][:3]

                for i, ex in enumerate(invalid_examples):
                    logger.warning(f"\nInvalid example {i+1}:")
                    logger.warning(f"Start markers: {ex['start_count']}")
                    logger.warning(f"End markers: {ex['end_count']}")
                    logger.warning(f"Text: {ex['text'][:200]}...")

            statistics['splits'][split_name]['valid'] = valid_count
            statistics['splits'][split_name]['invalid'] = invalid_count
            statistics['splits'][split_name]['valid_rate'] = valid_rate

    # Return a proper DatasetDict so callers can use save_to_disk/push_to_hub.
    processed_dataset = DatasetDict(processed_dataset)

    return processed_dataset, statistics
|
|
|
|
def save_dataset(dataset: DatasetDict, output_dir: Path, data_dir: str):
    """
    Write each split of the processed dataset to a CSV file.

    Args:
        dataset: Processed dataset whose splits are written out
        output_dir: Destination directory (created if missing)
        data_dir: Original data directory name, embedded in each filename
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"\nSaving processed dataset to {output_dir}...")

    for name, split in dataset.items():
        # One CSV per split, e.g. train_700K.csv.
        target = output_dir / f"{name}_{data_dir}.csv"

        # Round-trip through pandas for the CSV writer.
        frame = split.to_pandas()
        frame.to_csv(target, index=False)

        logger.info(f"Saved {name} split: {target} ({len(frame)} examples)")

    logger.info("Dataset saved successfully!")
|
|
|
|
def print_statistics(statistics: Dict):
    """
    Log a formatted summary of the processing statistics.

    Args:
        statistics: Dictionary containing processing statistics
    """
    rule = "=" * 60
    logger.info("\n" + rule)
    logger.info("PROCESSING STATISTICS")
    logger.info(rule)

    # Lazy %-args render identically to the original f-strings.
    logger.info("\nTotal examples: %s", statistics['total_examples'])
    logger.info("Processed examples: %s", statistics['processed_examples'])
    logger.info("Already marked: %s", statistics['already_marked'])

    logger.info("\nPer-split statistics:")
    logger.info("-" * 60)

    for name, stats in statistics['splits'].items():
        logger.info("\n%s:", name.upper())
        logger.info(" Total: %s", stats['total'])
        logger.info(" Processed: %s", stats['processed'])
        logger.info(" Already marked: %s", stats.get('already_marked', 0))

        # Validation fields only exist when process_dataset ran with validate=True.
        if 'valid' in stats:
            logger.info(" Valid: %s", stats['valid'])
            logger.info(" Invalid: %s", stats['invalid'])
            logger.info(" Valid rate: %.1f%%", stats['valid_rate'])

    logger.info(rule)
|
|
|
|
def main():
    """
    CLI entry point: parse arguments, process the dataset, save it locally,
    and optionally push it to the HuggingFace Hub.

    Exits with status 1 when processing fails, when --push_to_hub is given
    without --hub_repo_id, or when --validate finds invalid examples.
    """
    parser = argparse.ArgumentParser(
        description="Prepare training data with proper end-of-expression markers"
    )
    parser.add_argument(
        "--dataset_repo_id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Subdirectory within the dataset (e.g., '700K')"
    )
    parser.add_argument(
        "--data_column",
        type=str,
        required=True,
        help="Column to use for training data (e.g., 'i_prompt_n')"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save processed dataset"
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Run validation after processing"
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push processed dataset to HuggingFace Hub"
    )
    parser.add_argument(
        "--hub_repo_id",
        type=str,
        default=None,
        help="HuggingFace repository ID for pushing (if --push_to_hub)"
    )

    args = parser.parse_args()

    output_dir = Path(args.output_dir)

    try:
        processed_dataset, statistics = process_dataset(
            dataset_repo_id=args.dataset_repo_id,
            data_dir=args.data_dir,
            data_column=args.data_column,
            output_dir=output_dir,
            validate=args.validate
        )

        print_statistics(statistics)

        save_dataset(processed_dataset, output_dir, args.data_dir)

        # Hub push is opt-in and requires an explicit target repository.
        if args.push_to_hub:
            if not args.hub_repo_id:
                logger.error("--hub_repo_id required when using --push_to_hub")
                sys.exit(1)

            logger.info(f"\nPushing to HuggingFace Hub: {args.hub_repo_id}")
            processed_dataset.push_to_hub(args.hub_repo_id)
            logger.info("Successfully pushed to Hub!")

        # Fail the run (non-zero exit) if any split had invalid examples, so
        # pipelines can catch bad data. Note: the dataset has already been
        # saved (and possibly pushed) at this point.
        if args.validate:
            all_valid = all(
                split_stats.get('invalid', 0) == 0
                for split_stats in statistics['splits'].values()
            )

            if not all_valid:
                logger.error("\n⚠️ Some examples failed validation!")
                sys.exit(1)
            else:
                logger.info("\n✅ All examples validated successfully!")

        logger.info("\n✅ Data preparation complete!")

    except Exception as e:
        # Broad catch at the CLI boundary only: log the full traceback and
        # exit non-zero. (sys.exit raises SystemExit, which is not an
        # Exception subclass, so the exits above are not caught here.)
        logger.error(f"\n❌ Error during processing: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|