File size: 8,484 Bytes
35c7764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import os
import pickle
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_from_disk

from .collators import DataCollatorForMultitaskCellClassification


class StreamingMultiTaskDataset(Dataset):
    """Dataset wrapper that builds multi-task classification samples on access.

    Records come from a dataset loaded with ``load_from_disk`` and are
    converted one at a time in ``__getitem__``. Label mappings are derived
    from the data for non-test datasets and loaded from a pickle file for
    test datasets.
    """

    def __init__(self, dataset_path, config, is_test=False, dataset_type=""):
        """Initialize the streaming dataset.

        Args:
            dataset_path: Path handed to ``load_from_disk``.
            config: Mapping providing ``"task_columns"`` and ``"results_dir"``.
                It is mutated: ``config["task_names"]`` is set to the
                generated task names (``task1``, ``task2``, ...).
            is_test: When True, label mappings are loaded from disk instead
                of being built from the data, and every label is ``-1``.
            dataset_type: Free-form tag used in messages; the value
                ``"validation"`` selects the ``_val`` mappings file name.

        Raises:
            KeyError: If a configured task column is missing from a
                non-test dataset.
        """
        self.dataset = load_from_disk(dataset_path)
        self.config = config
        self.is_test = is_test
        self.dataset_type = dataset_type
        # Filled lazily by __getitem__: sample index -> cell identifier.
        self.cell_id_mapping = {}

        # Setup task and column mappings
        self.task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
        self.task_to_column = dict(zip(self.task_names, config["task_columns"]))
        # Published on the shared config; validate_label_mappings reads it.
        config["task_names"] = self.task_names

        # Check if unique_cell_id column exists in the dataset
        self.has_unique_cell_ids = "unique_cell_id" in self.dataset.column_names
        print(f"{'Found' if self.has_unique_cell_ids else 'No'} unique_cell_id column in {dataset_type} dataset")

        # Validation mappings get their own file; every other dataset type
        # (including test) uses the unsuffixed file name.
        self.label_mappings_path = os.path.join(
            config["results_dir"],
            f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
        )

        if not is_test:
            self._validate_columns()
            self.task_label_mappings, self.num_labels_list = self._create_label_mappings()
            self._save_label_mappings()
        else:
            # Load existing mappings for test data
            self.task_label_mappings = self._load_label_mappings()
            self.num_labels_list = [len(mapping) for mapping in self.task_label_mappings.values()]

    def _validate_columns(self):
        """Raise ``KeyError`` if any configured task column is absent."""
        missing_columns = [col for col in self.task_to_column.values()
                           if col not in self.dataset.column_names]
        if missing_columns:
            raise KeyError(
                f"Missing columns in {self.dataset_type} dataset: {missing_columns}. "
                f"Available columns: {self.dataset.column_names}"
            )

    def _create_label_mappings(self):
        """Build per-task ``label -> index`` mappings from the dataset columns.

        Returns:
            A ``(task_label_mappings, num_labels_list)`` tuple: a dict keyed
            by task name, and the number of distinct labels per task in task
            order. Indices follow the sorted order of the distinct labels.
        """
        task_label_mappings = {}
        num_labels_list = []

        for task, column in self.task_to_column.items():
            unique_values = sorted(set(self.dataset[column]))
            task_label_mappings[task] = {
                label: idx for idx, label in enumerate(unique_values)
            }
            num_labels_list.append(len(unique_values))

        return task_label_mappings, num_labels_list

    def _save_label_mappings(self):
        """Pickle the label mappings to ``self.label_mappings_path``."""
        # The results directory may not exist yet on a fresh run; create it
        # so saving does not fail with FileNotFoundError.
        parent_dir = os.path.dirname(self.label_mappings_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(self.label_mappings_path, "wb") as f:
            pickle.dump(self.task_label_mappings, f)

    def _load_label_mappings(self):
        """Load label mappings from ``self.label_mappings_path``."""
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # results_dir at files written by this pipeline.
        with open(self.label_mappings_path, "rb") as f:
            return pickle.load(f)

    def __len__(self):
        """Return the number of records in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the transformed sample at ``idx``.

        Side effect: records the cell identifier for ``idx`` in
        ``self.cell_id_mapping``.

        Returns:
            A dict with ``"input_ids"`` (``torch.long`` tensor), ``"cell_id"``
            (the index itself) and ``"label"`` (task name -> label index, or
            ``-1`` for every task on test data).
        """
        record = self.dataset[idx]

        # Store cell ID mapping; fall back to a synthetic ID when the
        # dataset has no unique_cell_id column.
        if self.has_unique_cell_ids:
            self.cell_id_mapping[idx] = record["unique_cell_id"]
        else:
            self.cell_id_mapping[idx] = f"cell_{idx}"

        # Add labels
        if not self.is_test:
            label_dict = {
                task: self.task_label_mappings[task][record[column]]
                for task, column in self.task_to_column.items()
            }
        else:
            # Test data has no ground truth; -1 is a placeholder label.
            label_dict = {task: -1 for task in self.task_names}

        return {
            "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
            "cell_id": idx,
            "label": label_dict,
        }


def get_data_loader(dataset, batch_size, sampler=None, shuffle=True):
    """Build a single-process DataLoader using the multitask cell collator.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        sampler: Optional sampler; when given, shuffling is switched off.
        shuffle: Whether to shuffle; only honoured when no sampler is given.

    Returns:
        The configured ``DataLoader``.
    """
    # A sampler decides the ordering itself, so shuffle is forced off.
    if sampler is not None:
        shuffle = False

    loader_options = {
        "batch_size": batch_size,
        "sampler": sampler,
        "shuffle": shuffle,
        "num_workers": 0,
        "pin_memory": True,
        "collate_fn": DataCollatorForMultitaskCellClassification(),
    }
    return DataLoader(dataset, **loader_options)


def _collect_cell_id_mapping(dataset):
    """Fill ``dataset.cell_id_mapping`` for every index and return a copy."""
    if dataset.has_unique_cell_ids:
        # Read the ID column directly (the same column access the label
        # mappings use) so no input tensors are built just to learn the IDs.
        dataset.cell_id_mapping.update(enumerate(dataset.dataset["unique_cell_id"]))
    else:
        dataset.cell_id_mapping.update(
            (idx, f"cell_{idx}") for idx in range(len(dataset))
        )
    return dict(dataset.cell_id_mapping)


def prepare_data_loaders(config, include_test=False):
    """Prepare data loaders for training, validation, and optionally test.

    Args:
        config: Mapping providing ``"train_path"``, ``"val_path"``,
            ``"batch_size"`` and the keys the dataset class needs;
            ``"test_path"`` is optional.
        include_test: When True and ``"test_path"`` is in ``config``, a test
            loader and its cell mapping are added as well.

    Returns:
        A dict with ``"train_loader"``, ``"train_cell_mapping"``,
        ``"num_labels_list"``, ``"val_loader"``, ``"val_cell_mapping"`` and,
        when a test set was processed, ``"test_loader"`` and
        ``"test_cell_mapping"``.

    Raises:
        ValueError: If the train and validation label mappings differ.
    """
    result = {}

    # Process train data
    train_dataset = StreamingMultiTaskDataset(
        config["train_path"],
        config,
        dataset_type="train"
    )
    result["train_loader"] = get_data_loader(train_dataset, config["batch_size"])

    # The dataset only records cell IDs lazily as items are accessed, so the
    # mapping has to be populated explicitly; copying it straight after
    # construction would always yield an empty dict.
    result["train_cell_mapping"] = _collect_cell_id_mapping(train_dataset)
    print(f"Collected {len(result['train_cell_mapping'])} cell IDs from training dataset")

    result["num_labels_list"] = train_dataset.num_labels_list

    # Process validation data
    val_dataset = StreamingMultiTaskDataset(
        config["val_path"],
        config,
        dataset_type="validation"
    )
    result["val_loader"] = get_data_loader(val_dataset, config["batch_size"])

    # Store the complete cell ID mapping for validation
    result["val_cell_mapping"] = _collect_cell_id_mapping(val_dataset)
    print(f"Collected {len(result['val_cell_mapping'])} cell IDs from validation dataset")

    # Validate label mappings
    validate_label_mappings(config)

    # Process test data if requested
    if include_test and "test_path" in config:
        test_dataset = StreamingMultiTaskDataset(
            config["test_path"],
            config,
            is_test=True,
            dataset_type="test"
        )
        result["test_loader"] = get_data_loader(test_dataset, config["batch_size"])

        result["test_cell_mapping"] = _collect_cell_id_mapping(test_dataset)
        print(f"Collected {len(result['test_cell_mapping'])} cell IDs from test dataset")

    return result


def validate_label_mappings(config):
    """Check that the saved train and validation label mappings agree.

    Reads ``task_label_mappings.pkl`` and ``task_label_mappings_val.pkl``
    from ``config["results_dir"]`` and compares them for every name in
    ``config["task_names"]``.

    Raises:
        ValueError: On the first task whose two mappings differ.
    """
    results_dir = config["results_dir"]
    mapping_files = ("task_label_mappings.pkl", "task_label_mappings_val.pkl")
    mapping_paths = [os.path.join(results_dir, name) for name in mapping_files]

    # Train mappings are loaded first, then validation.
    loaded = []
    for path in mapping_paths:
        with open(path, "rb") as handle:
            loaded.append(pickle.load(handle))
    train_mappings, val_mappings = loaded

    for task_name in config["task_names"]:
        train_map = train_mappings[task_name]
        val_map = val_mappings[task_name]
        if train_map == val_map:
            continue
        raise ValueError(
            f"Mismatch in label mappings for task '{task_name}'.\n"
            f"Train Mapping: {train_map}\n"
            f"Validation Mapping: {val_map}"
        )


# Legacy functions for backward compatibility
def preload_and_process_data(config):
    """Return train/validation datasets, their cell mappings and label counts.

    Legacy wrapper around ``prepare_data_loaders``.

    Returns:
        ``(train_dataset, train_cell_mapping, val_dataset, val_cell_mapping,
        num_labels_list)``.
    """
    loaders = prepare_data_loaders(config)

    train_dataset = loaders["train_loader"].dataset
    train_cells = loaders["train_cell_mapping"]
    val_dataset = loaders["val_loader"].dataset
    val_cells = loaders["val_cell_mapping"]
    label_counts = loaders["num_labels_list"]

    return train_dataset, train_cells, val_dataset, val_cells, label_counts


def preload_data(config):
    """Return the ``(train_loader, val_loader)`` pair built from ``config``.

    Legacy wrapper around ``prepare_data_loaders``.
    """
    loaders = prepare_data_loaders(config)
    train_loader = loaders["train_loader"]
    val_loader = loaders["val_loader"]
    return train_loader, val_loader


def load_and_preprocess_test_data(config):
    """Build the test dataset from ``config["test_path"]``.

    Returns:
        ``(test_dataset, cell_id_mapping, num_labels_list)``. The mapping is
        the dataset's own ``cell_id_mapping`` attribute, not a copy.
    """
    dataset = StreamingMultiTaskDataset(
        config["test_path"], config, is_test=True, dataset_type="test"
    )
    cell_ids = dataset.cell_id_mapping
    label_counts = dataset.num_labels_list
    return dataset, cell_ids, label_counts


def prepare_test_loader(config):
    """Prepare the DataLoader for test data.

    Args:
        config: Configuration mapping; must contain ``"test_path"`` plus
            everything ``prepare_data_loaders`` needs.

    Returns:
        ``(test_loader, test_cell_mapping, num_labels_list)``.

    Raises:
        KeyError: If ``config`` has no ``"test_path"`` entry.
    """
    # Fail before any train/validation work is done; otherwise the missing
    # path only surfaces afterwards as an opaque KeyError on 'test_loader'.
    if "test_path" not in config:
        raise KeyError("config must define 'test_path' to prepare the test loader")

    data = prepare_data_loaders(config, include_test=True)
    return data["test_loader"], data["test_cell_mapping"], data["num_labels_list"]