|
|
import os |
|
|
import pickle |
|
|
import torch |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from datasets import load_from_disk |
|
|
|
|
|
from .collators import DataCollatorForMultitaskCellClassification |
|
|
|
|
|
|
|
|
class StreamingMultiTaskDataset(Dataset):
    """Map-style PyTorch dataset for multi-task cell classification.

    Loads a Hugging Face dataset from disk and yields one record per cell:
    an ``input_ids`` tensor plus one integer label per configured task.
    For non-test splits, label<->index mappings are derived from the data and
    persisted under ``config["results_dir"]``; test splits instead reload the
    mappings saved at train time so the label space stays consistent.
    """

    def __init__(self, dataset_path, config, is_test=False, dataset_type=""):
        """Initialize the streaming dataset.

        Args:
            dataset_path: Path to a dataset saved with ``datasets.save_to_disk``.
            config: Dict with at least ``task_columns`` and ``results_dir``.
                NOTE: mutated in place — ``config["task_names"]`` is added.
            is_test: If True, reuse saved label mappings and emit -1 labels.
            dataset_type: Split name ("train", "validation", "test"); selects
                the label-mapping pickle filename.
        """
        self.dataset = load_from_disk(dataset_path)
        self.config = config
        self.is_test = is_test
        self.dataset_type = dataset_type
        # Filled lazily by __getitem__ (idx -> cell id); empty until records
        # have actually been fetched.
        self.cell_id_mapping = {}

        # Synthetic task names ("task1", "task2", ...) paired positionally
        # with the configured label columns.
        self.task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
        self.task_to_column = dict(zip(self.task_names, config["task_columns"]))
        # Side effect: written back so later callers (e.g.
        # validate_label_mappings) can iterate config["task_names"].
        config["task_names"] = self.task_names

        self.has_unique_cell_ids = "unique_cell_id" in self.dataset.column_names
        print(f"{'Found' if self.has_unique_cell_ids else 'No'} unique_cell_id column in {dataset_type} dataset")

        # Validation mappings get a "_val" suffix; train and test share the
        # unsuffixed path, so a test split loads the train-time mappings.
        self.label_mappings_path = os.path.join(
            config["results_dir"],
            f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
        )

        if not is_test:
            self._validate_columns()
            self.task_label_mappings, self.num_labels_list = self._create_label_mappings()
            self._save_label_mappings()
        else:
            # Test split: reuse the label space established during training.
            self.task_label_mappings = self._load_label_mappings()
            self.num_labels_list = [len(mapping) for mapping in self.task_label_mappings.values()]

    def _validate_columns(self):
        """Ensures required columns are present in the dataset.

        Raises:
            KeyError: If any configured task column is missing, listing both
                the missing and the available columns.
        """
        missing_columns = [col for col in self.task_to_column.values()
                           if col not in self.dataset.column_names]
        if missing_columns:
            raise KeyError(
                f"Missing columns in {self.dataset_type} dataset: {missing_columns}. "
                f"Available columns: {self.dataset.column_names}"
            )

    def _create_label_mappings(self):
        """Creates label mappings for the dataset.

        Returns:
            Tuple of (``{task: {label_value: index}}``, per-task label counts).
            Indices follow the sorted order of each column's unique values, so
            identical label sets map identically across splits.
        """
        task_label_mappings = {}
        num_labels_list = []

        for task, column in self.task_to_column.items():
            unique_values = sorted(set(self.dataset[column]))
            mapping = {label: idx for idx, label in enumerate(unique_values)}
            task_label_mappings[task] = mapping
            num_labels_list.append(len(unique_values))

        return task_label_mappings, num_labels_list

    def _save_label_mappings(self):
        """Saves label mappings to a pickle file."""
        with open(self.label_mappings_path, "wb") as f:
            pickle.dump(self.task_label_mappings, f)

    def _load_label_mappings(self):
        """Loads label mappings from a pickle file.

        NOTE(review): pickle.load executes arbitrary code if the file is not
        trusted; here the file is one this class wrote itself at train time.
        """
        with open(self.label_mappings_path, "rb") as f:
            return pickle.load(f)

    def __len__(self):
        # One item per record in the underlying dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return one record as a dict with "input_ids", "cell_id", "label".

        Side effect: records idx -> cell id in ``self.cell_id_mapping``, using
        the dataset's "unique_cell_id" column when present, otherwise a
        synthetic "cell_<idx>" id.
        """
        record = self.dataset[idx]

        if self.has_unique_cell_ids:
            unique_cell_id = record["unique_cell_id"]
            self.cell_id_mapping[idx] = unique_cell_id
        else:
            self.cell_id_mapping[idx] = f"cell_{idx}"

        transformed_record = {
            "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
            "cell_id": idx,
        }

        if not self.is_test:
            # Map each task's raw column value to its integer label index.
            label_dict = {
                task: self.task_label_mappings[task][record[column]]
                for task, column in self.task_to_column.items()
            }
        else:
            # No ground truth at test time; -1 is a placeholder label.
            label_dict = {task: -1 for task in self.config["task_names"]}

        transformed_record["label"] = label_dict

        return transformed_record
|
|
|
|
|
|
|
|
def get_data_loader(dataset, batch_size, sampler=None, shuffle=True,
                    num_workers=0, pin_memory=True):
    """Create a DataLoader with the given dataset and parameters.

    Args:
        dataset: Dataset to wrap.
        batch_size: Batch size for each iteration.
        sampler: Optional sampler; when provided, shuffling is disabled
            because DataLoader forbids shuffle=True together with a sampler.
        shuffle: Whether to shuffle (ignored when a sampler is given).
        num_workers: Worker processes for loading (default 0: load in the
            main process, which keeps lazy dataset-side state such as
            cell_id_mapping visible to the caller).
        pin_memory: Whether to pin host memory for faster GPU transfer.

    Returns:
        A DataLoader using DataCollatorForMultitaskCellClassification.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        # Mutually exclusive with sampler, so force False when one is set.
        shuffle=shuffle if sampler is None else False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=DataCollatorForMultitaskCellClassification(),
    )
|
|
|
|
|
|
|
|
def prepare_data_loaders(config, include_test=False):
    """Prepare data loaders for training, validation, and optionally test.

    Args:
        config: Dict with "train_path", "val_path", "batch_size",
            "results_dir" (and "test_path" when include_test is True).
        include_test: Also build the test loader when a "test_path" is set.

    Returns:
        Dict with "train_loader", "val_loader", the corresponding
        "*_cell_mapping" dicts, "num_labels_list", and (optionally)
        "test_loader" / "test_cell_mapping".

    Raises:
        ValueError: If train and validation label mappings disagree.
    """
    result = {}

    train_dataset = StreamingMultiTaskDataset(
        config["train_path"],
        config,
        dataset_type="train"
    )
    result["train_loader"] = get_data_loader(train_dataset, config["batch_size"])

    # BUG FIX: cell_id_mapping is populated lazily by __getitem__, so it must
    # be materialized by touching every record. The val/test branches already
    # did this, but train previously copied the mapping without iterating and
    # always collected an empty dict.
    result["train_cell_mapping"] = _materialize_cell_mapping(train_dataset)
    print(f"Collected {len(result['train_cell_mapping'])} cell IDs from training dataset")

    result["num_labels_list"] = train_dataset.num_labels_list

    val_dataset = StreamingMultiTaskDataset(
        config["val_path"],
        config,
        dataset_type="validation"
    )
    result["val_loader"] = get_data_loader(val_dataset, config["batch_size"])

    result["val_cell_mapping"] = _materialize_cell_mapping(val_dataset)
    print(f"Collected {len(result['val_cell_mapping'])} cell IDs from validation dataset")

    # Both datasets have now written their label-mapping pickles; verify the
    # train and validation label spaces agree before returning.
    validate_label_mappings(config)

    if include_test and "test_path" in config:
        test_dataset = StreamingMultiTaskDataset(
            config["test_path"],
            config,
            is_test=True,
            dataset_type="test"
        )
        result["test_loader"] = get_data_loader(test_dataset, config["batch_size"])

        result["test_cell_mapping"] = _materialize_cell_mapping(test_dataset)
        print(f"Collected {len(result['test_cell_mapping'])} cell IDs from test dataset")

    return result


def _materialize_cell_mapping(dataset):
    """Fetch every record so the dataset's lazy cell_id_mapping is fully
    populated, then return a snapshot copy of it."""
    for idx in range(len(dataset)):
        _ = dataset[idx]
    return dict(dataset.cell_id_mapping)
|
|
|
|
|
|
|
|
def validate_label_mappings(config):
    """Ensures train and validation label mappings are consistent.

    Loads the two pickles written under config["results_dir"] and compares
    them task by task.

    Raises:
        ValueError: On the first task whose train and validation label
            mappings differ.
    """
    results_dir = config["results_dir"]

    def _read_mappings(filename):
        # Helper: load one pickled {task: {label: index}} dict.
        with open(os.path.join(results_dir, filename), "rb") as handle:
            return pickle.load(handle)

    train_mappings = _read_mappings("task_label_mappings.pkl")
    val_mappings = _read_mappings("task_label_mappings_val.pkl")

    for task_name in config["task_names"]:
        train_map = train_mappings[task_name]
        val_map = val_mappings[task_name]
        if train_map != val_map:
            raise ValueError(
                f"Mismatch in label mappings for task '{task_name}'.\n"
                f"Train Mapping: {train_map}\n"
                f"Validation Mapping: {val_map}"
            )
|
|
|
|
|
|
|
|
|
|
|
def preload_and_process_data(config):
    """Preloads and preprocesses train and validation datasets.

    Returns:
        Tuple of (train dataset, train cell-id mapping, val dataset,
        val cell-id mapping, per-task label counts).
    """
    loaders = prepare_data_loaders(config)
    train_loader = loaders["train_loader"]
    val_loader = loaders["val_loader"]
    return (
        train_loader.dataset,
        loaders["train_cell_mapping"],
        val_loader.dataset,
        loaders["val_cell_mapping"],
        loaders["num_labels_list"],
    )
|
|
|
|
|
|
|
|
def preload_data(config):
    """Preprocesses train and validation data for trials.

    Returns:
        Tuple of (train DataLoader, validation DataLoader).
    """
    loaders = prepare_data_loaders(config)
    return loaders["train_loader"], loaders["val_loader"]
|
|
|
|
|
|
|
|
def load_and_preprocess_test_data(config):
    """Loads and preprocesses test data.

    Returns:
        Tuple of (test dataset, its cell-id mapping, per-task label counts).
        NOTE: cell_id_mapping is filled lazily by __getitem__, so it stays
        empty until the returned dataset has been iterated.
    """
    dataset = StreamingMultiTaskDataset(
        config["test_path"],
        config,
        is_test=True,
        dataset_type="test",
    )
    return dataset, dataset.cell_id_mapping, dataset.num_labels_list
|
|
|
|
|
|
|
|
def prepare_test_loader(config):
    """Prepares DataLoader for test data.

    Returns:
        Tuple of (test DataLoader, test cell-id mapping, per-task label
        counts).
    """
    loaders = prepare_data_loaders(config, include_test=True)
    return (
        loaders["test_loader"],
        loaders["test_cell_mapping"],
        loaders["num_labels_list"],
    )