import sys
from typing import Any, Dict, List, Optional

from datasets import DatasetDict, load_dataset
from transformers import PreTrainedTokenizerBase

from .utils import logger


def load_and_prepare_dataset(
    dataset_repo_id: str,
    data_dir: Optional[str],
    source_column: str,
    target_column: str,
    tokenizer: PreTrainedTokenizerBase,
    block_size: int,
    eval_strategy: str,  # NOTE: currently unused in this function
) -> DatasetDict:
"""Loads dataset, renames column, tokenizes, and optionally groups texts.""" |
|
|
logger.info(f"Loading dataset from Hub: {dataset_repo_id} (data_dir: {data_dir})") |
|
|
try: |
|
|
raw_datasets = load_dataset(dataset_repo_id, data_dir=data_dir) |
|
|
logger.info(f"Dataset loaded: {raw_datasets}") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to load dataset: {e}", exc_info=True) |
|
|
sys.exit(1) |
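
    # `load_dataset` returns a DatasetDict keyed by split (for example "train"
    # and "validation", depending on what the repo defines); `data_dir` points
    # at a subdirectory of data files within the Hub repo.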

    logger.info(f"Renaming column '{source_column}' to '{target_column}' and removing others.")
    try:
        def rename_and_keep_column(example: Dict[str, Any]) -> Dict[str, Any]:
            if source_column not in example:
                raise KeyError(f"Source column '{source_column}' not found in example: {list(example.keys())}")
            return {target_column: example[source_column]}

        processed_datasets = DatasetDict()
        for split, dataset in raw_datasets.items():
            # Remove every original column, including the source column:
            # columns produced by the mapped function survive `remove_columns`,
            # so only `target_column` remains afterwards.
            processed_datasets[split] = dataset.map(
                rename_and_keep_column,
                batched=False,
                remove_columns=dataset.column_names,
            )
        logger.info(f"Dataset after column renaming: {processed_datasets}")
    except KeyError as e:
        logger.error(f"Error during column renaming: {e}. Ensure '{source_column}' exists.", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.error(f"An unexpected error occurred during column renaming/cleanup: {e}", exc_info=True)
        sys.exit(1)
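
    # Illustration (hypothetical column names): with source_column="body" and
    # target_column="text", the map turns {"body": "hi", "id": 7} into
    # {"text": "hi"}; the original "body" and "id" columns are dropped.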

    logger.info("Tokenizing dataset...")

    def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, List[Any]]:
        # Truncate to block_size when one is given; max_length=None makes the
        # tokenizer fall back to the model's maximum input length.
        return tokenizer(examples[target_column], truncation=True, max_length=block_size if block_size else None)

    try:
        tokenized_datasets = processed_datasets.map(
            tokenize_function,
            batched=True,
            # Only `target_column` is left after the rename step, so dropping
            # it leaves purely tokenizer outputs (input_ids, attention_mask).
            remove_columns=[target_column],
            desc="Running tokenizer on dataset",
        )
        logger.info("Tokenization complete.")
    except Exception as e:
        logger.error(f"Error during tokenization: {e}", exc_info=True)
        sys.exit(1)
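
    # At this point each row holds the tokenized form of one document. For
    # causal-LM training, a common follow-up is to concatenate all tokens and
    # re-split them into fixed-size blocks. A minimal sketch of that idiom
    # (an assumption about intended use; it is not wired in, since this
    # function returns per-example tokenized data):
    #
    #     def group_texts(examples):
    #         # Concatenate each list-valued field, then chop into blocks.
    #         concatenated = {k: sum(examples[k], []) for k in examples}
    #         total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    #         return {
    #             k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
    #             for k, t in concatenated.items()
    #         }
    #
    #     tokenized_datasets = tokenized_datasets.map(group_texts, batched=True)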

    logger.info(f"Processed dataset structure (tokenized only): {tokenized_datasets}")
    return tokenized_datasets
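

# Example call (illustrative values; the repo id, column names, and tokenizer
# are hypothetical, not requirements of this module):
#
#     from transformers import AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     ds = load_and_prepare_dataset(
#         dataset_repo_id="username/my-corpus",
#         data_dir=None,
#         source_column="content",
#         target_column="text",
#         tokenizer=tok,
#         block_size=512,
#         eval_strategy="epoch",
#     )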