import sys
from typing import Any, Dict, List, Optional

from datasets import load_dataset, DatasetDict
from transformers import PreTrainedTokenizerBase

from .utils import logger


def load_and_prepare_dataset(
    dataset_repo_id: str,
    data_dir: Optional[str],
    source_column: str,
    target_column: str,
    tokenizer: PreTrainedTokenizerBase,
    block_size: Optional[int],
    eval_strategy: str,
) -> DatasetDict:
    """Loads the dataset, renames the source column, and tokenizes the result.

    Returns tokenized splits only (no grouping into fixed-length blocks).
    `eval_strategy` is part of the signature but is not used in this step.
    """
| logger.info(f"Loading dataset from Hub: {dataset_repo_id} (data_dir: {data_dir})") |
| try: |
| raw_datasets = load_dataset(dataset_repo_id, data_dir=data_dir) |
| logger.info(f"Dataset loaded: {raw_datasets}") |
| except Exception as e: |
| logger.error(f"Failed to load dataset: {e}", exc_info=True) |
| sys.exit(1) |
|
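    # Normalize the schema first: keep only the source column, renamed to the
    # target column, so every split looks identical before tokenization.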
    logger.info(f"Renaming column '{source_column}' to '{target_column}' and removing others.")
    try:
        def rename_and_keep_column(example: Dict[str, Any]) -> Dict[str, Any]:
            if source_column not in example:
                raise KeyError(f"Source column '{source_column}' not found in example: {list(example.keys())}")
            return {target_column: example[source_column]}

        processed_datasets = DatasetDict()
        for split, split_dataset in raw_datasets.items():
            # Every column except the source column is dropped; the map
            # function re-emits the source column under the target name.
            cols_to_remove = [col for col in split_dataset.column_names if col != source_column]
            processed_datasets[split] = split_dataset.map(
                rename_and_keep_column,
                batched=False,
                remove_columns=cols_to_remove,
            )
        logger.info(f"Dataset after column renaming: {processed_datasets}")
    except KeyError as e:
        logger.error(f"Error during column renaming: {e}. Ensure '{source_column}' exists.", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.error(f"An unexpected error occurred during column renaming/cleanup: {e}", exc_info=True)
        sys.exit(1)
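    # Tokenization replaces the text column with the tokenizer's output
    # columns; truncation caps each example at block_size tokens when set.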
    logger.info("Tokenizing dataset...")

    def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, List[Any]]:
        # With max_length=None the tokenizer falls back to the model's maximum length.
        return tokenizer(examples[target_column], truncation=True, max_length=block_size if block_size else None)
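    # Batched mapping passes lists of examples to tokenize_function; removing
    # the remaining text column leaves only the tokenizer outputs
    # (input_ids, attention_mask, and token_type_ids where applicable).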
    try:
        tokenized_datasets = processed_datasets.map(
            tokenize_function,
            batched=True,
            remove_columns=processed_datasets["train"].column_names,
            desc="Running tokenizer on dataset",
        )
        logger.info("Tokenization complete.")
    except Exception as e:
        logger.error(f"Error during tokenization: {e}", exc_info=True)
        sys.exit(1)
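    # A fixed-length grouping step (the "group texts" stage of the standard
    # language-modeling recipe) is not applied here; the function returns
    # tokenized examples only. For reference, a minimal sketch of that step,
    # assuming blocks of exactly block_size tokens, would be:
    #
    #     def group_texts(examples):
    #         concatenated = {k: sum(examples[k], []) for k in examples}
    #         total = (len(concatenated["input_ids"]) // block_size) * block_size
    #         return {
    #             k: [t[i:i + block_size] for i in range(0, total, block_size)]
    #             for k, t in concatenated.items()
    #         }
    #
    #     tokenized_datasets = tokenized_datasets.map(group_texts, batched=True)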
| logger.info(f"Processed dataset structure (tokenized only): {tokenized_datasets}") |
| return tokenized_datasets |
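

# Example usage (illustrative only: the repo id, column names, and tokenizer
# choice below are assumptions, not values taken from this project):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     splits = load_and_prepare_dataset(
#         dataset_repo_id="your-org/your-dataset",
#         data_dir=None,
#         source_column="content",
#         target_column="text",
#         tokenizer=tokenizer,
#         block_size=1024,
#         eval_strategy="no",
#     )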