# finetuning/data_loader.py

import sys
from typing import Dict, Any, Optional, List
from datasets import load_dataset, DatasetDict, Dataset
from transformers import PreTrainedTokenizerBase
from .utils import logger

# Note: The group_texts step below is kept commented out, mirroring the
# original script, so the function returns tokenized_datasets directly.
# Uncomment the relevant parts if block-wise text grouping is needed.

def load_and_prepare_dataset(
    dataset_repo_id: str,
    data_dir: Optional[str],
    source_column: str,
    target_column: str,
    tokenizer: PreTrainedTokenizerBase,
    block_size: Optional[int],
    eval_strategy: str  # Kept for potential future use or warnings
) -> DatasetDict:
    """Loads dataset, renames column, tokenizes, and optionally groups texts."""
    logger.info(f"Loading dataset from Hub: {dataset_repo_id} (data_dir: {data_dir})")
    try:
        raw_datasets = load_dataset(dataset_repo_id, data_dir=data_dir)
        logger.info(f"Dataset loaded: {raw_datasets}")
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}", exc_info=True)
        sys.exit(1)

    # --- Preprocessing Steps ---
    # 1. Rename source column to target column (e.g., 'text')
    logger.info(f"Renaming column '{source_column}' to '{target_column}' and removing others.")
    try:
        def rename_and_keep_column(example: Dict[str, Any]) -> Dict[str, Any]:
            if source_column not in example:
                raise KeyError(f"Source column '{source_column}' not found in example: {list(example.keys())}")
            return {target_column: example[source_column]}

        processed_datasets = DatasetDict()
        for split, split_dataset in raw_datasets.items():
            # Remove all original columns; map() re-adds the data under
            # target_column. Columns returned by the function are kept even
            # when listed in remove_columns, so this also covers the case
            # where source_column == target_column.
            processed_datasets[split] = split_dataset.map(
                rename_and_keep_column,
                batched=False,
                remove_columns=split_dataset.column_names,
            )
        logger.info(f"Dataset after column renaming: {processed_datasets}")
        logger.info(f"Dataset after column renaming: {processed_datasets}")

    except KeyError as e:
        logger.error(f"Error during column renaming: {e}. Ensure '{source_column}' exists.", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.error(f"An unexpected error occurred during column renaming/cleanup: {e}", exc_info=True)
        sys.exit(1)

    # 2. Tokenize
    logger.info("Tokenizing dataset...")
    def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, List[Any]]:
        # Truncate to block_size when it is set; otherwise max_length=None
        # falls back to the tokenizer's model_max_length.
        return tokenizer(
            examples[target_column],
            truncation=True,
            max_length=block_size if block_size else None,
        )

    try:
        tokenized_datasets = processed_datasets.map(
            tokenize_function,
            batched=True,
            # Drop the raw text column in every split; avoids assuming a
            # "train" split exists just to read column names.
            remove_columns=[target_column],
            desc="Running tokenizer on dataset",
        )
        logger.info("Tokenization complete.")
    except Exception as e:
        logger.error(f"Error during tokenization: {e}", exc_info=True)
        sys.exit(1)


    # 3. Group texts into blocks (Currently commented out in original script logic)
    # logger.info(f"Grouping texts into blocks of size: {block_size}")
    # def group_texts(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    #     concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    #     total_length = len(concatenated["input_ids"])
    #     if total_length >= block_size:
    #         total_length = (total_length // block_size) * block_size
    #     else:
    #         logger.warning(
    #             f"Total length ({total_length}) < block_size ({block_size}), might return empty batches."
    #         )
    #     result = {
    #         k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
    #         for k, t in concatenated.items()
    #     }
    #     result["labels"] = [list(x) for x in result["input_ids"]] # Deep copy for labels
    #     return result

    # lm_datasets = tokenized_datasets.map(
    #     group_texts,
    #     batched=True,
    #     desc=f"Grouping texts into chunks of {block_size}",
    # )
    # logger.info("Grouping complete.")
    # logger.info(f"Processed dataset structure after grouping: {lm_datasets}")
    # return lm_datasets
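    # With grouping disabled, no "labels" column is produced here. A common
    # pattern (an assumption about the calling code, not confirmed by this
    # module) is to let the data collator derive labels from input_ids:
    #
    #     from transformers import DataCollatorForLanguageModeling
    #     collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    #
    # With mlm=False the collator copies input_ids into labels, masking
    # padding positions to -100 so they are ignored by the loss.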

    logger.info(f"Processed dataset structure (tokenized only): {tokenized_datasets}")
    return tokenized_datasets
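
# Example usage (a minimal sketch; the repo id, tokenizer checkpoint, and
# column names below are illustrative placeholders, not values taken from
# this project):
#
#     from transformers import AutoTokenizer
#     from finetuning.data_loader import load_and_prepare_dataset
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token
#     datasets = load_and_prepare_dataset(
#         dataset_repo_id="username/my-text-dataset",  # hypothetical repo id
#         data_dir=None,
#         source_column="content",
#         target_column="text",
#         tokenizer=tokenizer,
#         block_size=1024,
#         eval_strategy="epoch",
#     )
#     print(datasets["train"].column_names)  # ['input_ids', 'attention_mask']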