|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader, IterableDataset |
|
|
from datasets import load_dataset, concatenate_datasets, interleave_datasets |
|
|
from typing import Dict, List, Optional, Any, Union |
|
|
import random |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
import warnings |
|
|
from PIL import Image |
|
|
import requests |
|
|
from io import BytesIO |
|
|
from torchvision import transforms |
|
|
import logging |
|
|
|
|
|
|
|
|
# Configure the root logger at import time so the loaders in this module
# report their progress at INFO level, then create the module-level logger
# used by every class and function below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress every UserWarning process-wide for the lifetime of the process.
warnings.filterwarnings(action="ignore", category=UserWarning)
|
|
|
|
|
from data_config import ( |
|
|
PRETRAIN_DATASETS, |
|
|
POSTTRAIN_DATASETS, |
|
|
TEST_DATASETS, |
|
|
PRETRAIN_MIX, |
|
|
POSTTRAIN_MIX, |
|
|
PREPROCESSING_CONFIG, |
|
|
DATASET_CACHE_DIR, |
|
|
HF_CACHE_DIR |
|
|
) |
|
|
|
|
|
|
|
|
# Shared image preprocessing used by all datasets in this module: resize to
# 224x224, convert to a tensor, then normalize each of the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet statistics -
# confirm they match the vision encoder that consumes these tensors.
image_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
class PreTrainDataset(IterableDataset):
    """Iterable pre-training dataset that samples from a weighted mix of sources.

    The mix named ``mix_name`` is looked up in ``PRETRAIN_MIX``; every dataset it
    lists is resolved through ``PRETRAIN_DATASETS`` and loaded with
    ``load_dataset``. Iterating picks a source at random according to the
    normalized mix weights, processes its next raw sample and yields a dict
    with ``input_ids``, ``attention_mask`` and ``type`` (plus ``image`` for
    image-text sources). Exhausted sources are restarted, so iteration only
    ends when ``max_samples`` is reached or no usable source remains.
    """

    # How many times in a row a source may end without producing a single raw
    # sample before it is considered empty and removed from sampling.
    _MAX_EMPTY_RESTARTS = 3

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        streaming: bool = True,
        seed: int = 42,
        max_samples: Optional[int] = None
    ):
        """Load every dataset of the requested mix.

        Args:
            mix_name: Key into ``PRETRAIN_MIX``.
            tokenizer: Callable tokenizer; required.
            max_length: Length every sample is truncated/padded to.
            streaming: Default streaming flag for sources whose config does
                not set ``streaming`` itself.
            seed: Base seed for the sampling RNG (offset by the worker id).
            max_samples: Total number of samples to yield; ``None`` or ``0``
                means unlimited.

        Raises:
            ValueError: If the tokenizer is missing, the mix is unknown or
                empty, no dataset could be loaded, or the weights of the
                loaded datasets do not sum to a positive value.
        """
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.streaming = streaming
        self.seed = seed
        self.max_samples = max_samples
        self.samples_generated = 0

        if mix_name not in PRETRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(PRETRAIN_MIX.keys())}")

        mix_config = PRETRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        # zip() below stops at the shorter list, so missing weights would
        # silently drop datasets. Fall back to uniform sampling when no weights
        # are given, and make a partial mismatch visible.
        if not weights:
            logger.warning(f"Mix {mix_name} defines no weights; sampling datasets uniformly")
            weights = [1.0] * len(dataset_names)
        elif len(weights) != len(dataset_names):
            logger.warning(
                f"Mix {mix_name} has {len(dataset_names)} datasets but {len(weights)} weights; "
                f"only the first {min(len(dataset_names), len(weights))} pairs are used"
            )

        logger.info(f"Loading pretrain mix: {mix_name}")
        logger.info(f"  Datasets: {dataset_names}")
        logger.info(f"  Weights: {weights}")

        # Parallel lists: (name, dataset, config) triples and their weights.
        self.datasets = []
        self.probabilities = []

        for name, weight in zip(dataset_names, weights):
            if name not in PRETRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in PRETRAIN_DATASETS, skipping")
                continue

            config = PRETRAIN_DATASETS[name]
            try:
                ds = self._load_dataset(config)
                if ds is not None:
                    self.datasets.append((name, ds, config))
                    self.probabilities.append(weight)
                    logger.info(f"  Successfully loaded {name}")
            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not self.datasets:
            raise ValueError("No datasets loaded successfully")

        # Normalize over the datasets that actually loaded.
        total = sum(self.probabilities)
        if total <= 0:
            raise ValueError("Sampling weights of the loaded datasets must sum to a positive value")
        self.probabilities = [p / total for p in self.probabilities]

        logger.info(f"Successfully loaded {len(self.datasets)} datasets")

    def _load_dataset(self, config: Dict):
        """Load one source described by ``config``; return ``None`` on failure."""
        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'streaming': config.get('streaming', self.streaming),
                'cache_dir': HF_CACHE_DIR,
            }

            if 'config' in config:
                load_kwargs['name'] = config['config']

            return load_dataset(**load_kwargs)
        except Exception as e:
            logger.error(f"Failed to load {config.get('hf_path', 'unknown')}: {e}")
            return None

    def _tokenize(self, text: str):
        """Tokenize ``text`` truncated/padded to ``self.max_length`` as 'pt' tensors."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

    def _process_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Turn a raw text/code sample into model inputs.

        Returns ``None`` when the text field is missing, not a string, shorter
        than 10 characters after stripping, or when tokenization fails.
        """
        try:
            text_field = config.get('text_field', 'text')
            text = sample.get(text_field, '')

            if not text or not isinstance(text, str):
                return None

            text = text.strip()
            if len(text) < 10:
                return None

            encoding = self._tokenize(text)

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'type': 'text'
            }
        except Exception as e:
            logger.debug(f"Error processing text sample: {e}")
            return None

    def _process_image_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Turn a raw image/caption sample into model inputs.

        The image may be a PIL image or a URL string (downloaded with a 5 s
        timeout). Returns ``None`` when text or image is missing, the image
        cannot be obtained, or any processing step fails.
        """
        try:
            text_field = config.get('text_field', 'caption')
            image_field = config.get('image_field', 'image')

            text = sample.get(text_field, '')
            image = sample.get(image_field)

            if not text or image is None:
                return None

            if isinstance(image, str):
                try:
                    response = requests.get(image, timeout=5)
                    # Reject HTTP error pages explicitly instead of handing
                    # their body to the image decoder.
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert('RGB')
                except Exception as img_error:
                    logger.debug(f"Failed to load image from URL: {img_error}")
                    return None
            elif isinstance(image, Image.Image):
                image = image.convert('RGB')
            else:
                return None

            image_tensor = image_transform(image)
            encoding = self._tokenize(text)

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'image': image_tensor,
                'type': 'image_text'
            }
        except Exception as e:
            logger.debug(f"Error processing image-text sample: {e}")
            return None

    @staticmethod
    def _without_source(probabilities, idx):
        """Return ``probabilities`` renormalized with source ``idx`` removed.

        Returns ``None`` when no source with positive probability remains.
        """
        remaining = list(probabilities)
        remaining[idx] = 0.0
        total = sum(remaining)
        if total <= 0:
            return None
        return [p / total for p in remaining]

    def __iter__(self):
        """Yield processed samples drawn at random from the loaded sources."""
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1

        # Give every worker its own, reproducible random stream.
        random.seed(self.seed + worker_id)
        np.random.seed(self.seed + worker_id)

        # Each DataLoader worker runs its own copy of this loop, so the limit
        # is divided among the workers to keep the overall total at
        # max_samples. A falsy max_samples (None or 0) means "no limit".
        budget = None
        if self.max_samples:
            budget = self.max_samples // num_workers
            if worker_id < self.max_samples % num_workers:
                budget += 1

        iterators = [iter(ds) for _, ds, _ in self.datasets]
        probabilities = list(self.probabilities)
        # Consecutive exhaustions per source without any raw sample in between.
        empty_restarts = [0] * len(self.datasets)
        self.samples_generated = 0

        while True:
            if budget is not None and self.samples_generated >= budget:
                break

            idx = np.random.choice(len(self.datasets), p=probabilities)
            name, dataset, config = self.datasets[idx]

            try:
                sample = next(iterators[idx])
            except StopIteration:
                empty_restarts[idx] += 1
                if empty_restarts[idx] > self._MAX_EMPTY_RESTARTS:
                    # The source keeps ending without producing anything;
                    # restarting it again would spin forever.
                    logger.error(f"Dataset {name} produced no samples, removing it from the mix")
                    probabilities = self._without_source(probabilities, idx)
                    if probabilities is None:
                        break
                    continue
                try:
                    iterators[idx] = iter(dataset)
                except Exception as e:
                    logger.error(f"Failed to recreate iterator for dataset {idx}: {e}")
                    break
                continue
            except Exception as e:
                logger.debug(f"Error in iterator: {e}")
                continue

            empty_restarts[idx] = 0

            sample_type = config.get('type')
            if sample_type in ['text', 'code']:
                processed = self._process_text_sample(sample, config)
            elif sample_type == 'image_text':
                processed = self._process_image_text_sample(sample, config)
            else:
                # This source can never yield a usable sample; drop it rather
                # than drawing from it (and possibly looping) forever.
                logger.warning(f"Unknown type: {sample_type} for dataset {name}, removing it from the mix")
                probabilities = self._without_source(probabilities, idx)
                if probabilities is None:
                    break
                continue

            if processed is not None:
                self.samples_generated += 1
                yield processed
|
|
|
|
|
|
|
|
class PostTrainDataset(Dataset):
    """Map-style instruction-tuning dataset built from a weighted mix of sources.

    The mix named ``mix_name`` is looked up in ``POSTTRAIN_MIX``; each listed
    dataset is resolved through ``POSTTRAIN_DATASETS``, loaded, tagged with its
    ``_source`` name and ``_config`` and, when several sources load, combined
    with ``interleave_datasets``. ``__getitem__`` returns separately padded
    instruction and response token tensors, each ``max_length // 2`` long.

    NOTE(review): ``__getitem__`` indexes ``self.dataset`` and ``__len__``
    reports 0 without ``__len__`` support, so sources configured with
    ``streaming=True`` presumably cannot be served by this class - confirm.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        """Load and combine every dataset of the requested mix.

        Args:
            mix_name: Key into ``POSTTRAIN_MIX``.
            tokenizer: Callable tokenizer exposing ``pad_token_id`` and
                ``eos_token_id``; required.
            max_length: Combined token budget; half goes to the instruction
                and half to the response.
            max_samples: Optional cap on the combined dataset size.
            split: Split name passed to ``load_dataset`` for every source.

        Raises:
            ValueError: If the tokenizer is missing, the mix is unknown or
                empty, or no dataset could be loaded.
        """
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        if mix_name not in POSTTRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(POSTTRAIN_MIX.keys())}")

        mix_config = POSTTRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading posttrain mix: {mix_name}")
        logger.info(f"  Datasets: {dataset_names}")

        # Parallel lists so that a skipped or failed source cannot shift the
        # weights of the sources that did load.
        all_datasets = []
        loaded_weights = []

        for position, name in enumerate(dataset_names):
            if name not in POSTTRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in POSTTRAIN_DATASETS")
                continue

            config = POSTTRAIN_DATASETS[name]
            try:
                load_kwargs = {
                    'path': config['hf_path'],
                    'split': split,
                    'streaming': config.get('streaming', False),
                    'cache_dir': HF_CACHE_DIR,
                }
                if 'data_files' in config:
                    load_kwargs['data_files'] = config['data_files']
                if 'config' in config:
                    load_kwargs['name'] = config['config']

                ds = load_dataset(**load_kwargs)

                # Per-source cap, using whichever truncation API the loaded
                # object offers.
                if config.get('max_samples'):
                    if hasattr(ds, 'take'):
                        ds = ds.take(config['max_samples'])
                    elif hasattr(ds, 'select'):
                        ds = ds.select(range(min(len(ds), config['max_samples'])))

                # The tagger is built by a factory so it captures THIS
                # iteration's name/config; a closure over the loop variables
                # would see the last source's values if map() runs lazily.
                ds = ds.map(self._make_source_tagger(name, config))

                weight = weights[position] if position < len(weights) else None
                all_datasets.append(ds)
                loaded_weights.append(weight)

                ds_len = len(ds) if hasattr(ds, '__len__') else 'streaming'
                logger.info(f"  Loaded {name}: {ds_len} samples")

            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not all_datasets:
            raise ValueError("No datasets loaded successfully")

        if len(all_datasets) == 1:
            self.dataset = all_datasets[0]
        else:
            if any(w is None for w in loaded_weights) or sum(loaded_weights) <= 0:
                logger.warning("Missing or non-positive mix weights; interleaving datasets uniformly")
                loaded_weights = [1.0] * len(all_datasets)
            total_weight = sum(loaded_weights)
            probabilities = [w / total_weight for w in loaded_weights]
            self.dataset = interleave_datasets(
                all_datasets,
                probabilities=probabilities,
                seed=42,
                stopping_strategy='all_exhausted'
            )

        if max_samples and hasattr(self.dataset, '__len__'):
            actual_len = min(len(self.dataset), max_samples)
            self.dataset = self.dataset.select(range(actual_len))

        dataset_len = len(self.dataset) if hasattr(self.dataset, '__len__') else 'streaming'
        logger.info(f"Total samples: {dataset_len}")

    @staticmethod
    def _make_source_tagger(name, config):
        """Return a ``map`` function that stamps ``_source``/``_config`` on a row."""
        def add_source(example):
            example['_source'] = name
            example['_config'] = config
            return example

        return add_source

    @staticmethod
    def _render_dialogue(turns, role_key: str, content_key: str) -> str:
        """Render all turns but the last as ``role: content`` lines plus an assistant cue."""
        lines = [
            f"{turn.get(role_key, 'user')}: {turn.get(content_key, '')}"
            for turn in turns[:-1]
        ]
        return "\n".join(lines) + "\nassistant:"

    @staticmethod
    def _pad_to_length(token_ids, target_len: int, pad_token_id):
        """Right-pad (or clip) ``token_ids`` to ``target_len``.

        Returns:
            ``(ids, mask)`` - two 1-D long tensors of length ``target_len``;
            the mask is 1 on real tokens and 0 on padding.
        """
        token_ids = token_ids[:target_len].to(torch.long)
        used = token_ids.size(0)
        mask = torch.zeros(target_len, dtype=torch.long)
        mask[:used] = 1
        if used < target_len:
            padding = torch.full((target_len - used,), pad_token_id, dtype=torch.long)
            token_ids = torch.cat([token_ids, padding])
        return token_ids, mask

    def _format_instruction(self, sample: Dict, config: Dict) -> str:
        """Build the prompt text for ``sample`` according to ``config['type']``.

        Returns an empty string when formatting fails or a multimodal sample
        has no usable conversation list.
        """
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                input_field = config.get('input_field', 'input')
                context_field = config.get('context_field', None)

                instruction = sample.get(instruction_field, '')
                input_text = sample.get(input_field, '')
                context = sample.get(context_field, '') if context_field else ''

                prompt_parts = [f"Instruction: {instruction}"]
                if context:
                    prompt_parts.append(f"Context: {context}")
                if input_text:
                    prompt_parts.append(f"Input: {input_text}")
                prompt_parts.append("Response:")
                return "\n".join(prompt_parts)

            elif data_type == 'conversation':
                if 'conversations' in sample:
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        return self._render_dialogue(conversations, 'from', 'value')
                elif 'messages' in sample:
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        return self._render_dialogue(messages, 'role', 'content')

                # No usable turn list: fall back to a plain text field.
                return sample.get('text', '')

            elif data_type == 'code_instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                instruction = sample.get(instruction_field, '')
                return f"### Instruction:\n{instruction}\n### Response:"

            elif data_type == 'multimodal_instruction':
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    return self._render_dialogue(conversations, 'from', 'value')
                return ""

            else:
                return sample.get(config.get('instruction_field', 'text'), '')
        except Exception as e:
            logger.debug(f"Error formatting instruction: {e}")
            return ""

    def _get_response(self, sample: Dict, config: Dict) -> str:
        """Extract the target response text for ``sample``.

        For conversation-like types this is the last turn; otherwise it is the
        configured ``response_field`` (default ``'output'``). Returns an empty
        string when nothing usable is found or extraction fails.
        """
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction' or data_type == 'code_instruction':
                response_field = config.get('response_field', 'output')
                return sample.get(response_field, '')

            elif data_type == 'conversation':
                if 'conversations' in sample:
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        return conversations[-1].get('value', '')
                elif 'messages' in sample:
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        return messages[-1].get('content', '')
                return ""

            elif data_type == 'multimodal_instruction':
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    return conversations[-1].get('value', '')
                return ""

            else:
                response_field = config.get('response_field', 'output')
                return sample.get(response_field, '')
        except Exception as e:
            logger.debug(f"Error getting response: {e}")
            return ""

    def __len__(self):
        """Return the dataset size, or 0 when the dataset has no length."""
        return len(self.dataset) if hasattr(self.dataset, '__len__') else 0

    def __getitem__(self, idx):
        """Return the tokenized instruction/response pair at ``idx``.

        Returns:
            A dict with ``instruction``, ``response``, ``instruction_mask``,
            ``response_mask`` (1-D long tensors of length ``max_length // 2``),
            ``task`` (the source name) and ``modality_data`` (``None`` or
            ``{'image': tensor}``); or ``None`` when the sample is unusable.
        """
        try:
            sample = self.dataset[idx]

            if '_config' not in sample:
                logger.warning(f"Sample at index {idx} missing _config")
                return None

            config = sample['_config']

            instruction_text = self._format_instruction(sample, config)
            response_text = self._get_response(sample, config)

            if not instruction_text or not response_text:
                return None

            # Tokenizers without a pad token are padded with EOS instead.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            instruction_max_len = self.max_length // 2
            response_max_len = self.max_length // 2

            instruction_enc = self.tokenizer(
                instruction_text,
                truncation=True,
                max_length=instruction_max_len,
                add_special_tokens=False,
                return_tensors='pt'
            )
            instr_ids, instr_mask = self._pad_to_length(
                instruction_enc['input_ids'].squeeze(0), instruction_max_len, pad_token_id
            )

            # One position is reserved for the EOS token appended below.
            response_enc = self.tokenizer(
                response_text,
                truncation=True,
                max_length=response_max_len - 1,
                add_special_tokens=False,
                return_tensors='pt'
            )
            eos_token = torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long)
            resp_ids = torch.cat([response_enc['input_ids'].squeeze(0).to(torch.long), eos_token])
            resp_ids, resp_mask = self._pad_to_length(resp_ids, response_max_len, pad_token_id)

            result = {
                'instruction': instr_ids,
                'response': resp_ids,
                'instruction_mask': instr_mask,
                'response_mask': resp_mask,
                'task': sample.get('_source', 'unknown'),
                'modality_data': None
            }

            if config.get('type') == 'multimodal_instruction' and 'image' in sample:
                try:
                    image = sample['image']
                    if isinstance(image, Image.Image):
                        image = image.convert('RGB')
                        image_tensor = image_transform(image)
                        result['modality_data'] = {'image': image_tensor}
                except Exception as e:
                    # A broken image only drops the image, not the sample.
                    logger.debug(f"Error processing image: {e}")

            return result

        except Exception as e:
            # logger.exception records the traceback through the logging
            # system instead of printing it straight to stderr.
            logger.exception(f"Error getting item at index {idx}: {e}")
            return None
|
|
|
|
|
|
|
|
class PreferenceDataset(Dataset):
    """Map-style dataset of (chosen, rejected) text pairs for preference training.

    The dataset is resolved through ``POSTTRAIN_DATASETS`` and must be declared
    with ``type == 'preference'``. Each item is a 4-tuple of tensors:
    ``(chosen_ids, rejected_ids, chosen_mask, rejected_mask)``.
    """

    def __init__(
        self,
        dataset_name: str = 'hh_rlhf',
        tokenizer=None,
        max_length: int = 1024,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        """Load the preference dataset.

        Args:
            dataset_name: Key into ``POSTTRAIN_DATASETS``.
            tokenizer: Callable tokenizer; required.
            max_length: Length both texts are truncated/padded to.
            max_samples: Optional cap on the number of pairs kept.
            split: Split name passed to ``load_dataset``.

        Raises:
            ValueError: If the tokenizer is missing, the dataset name is
                unknown, or the dataset is not of type ``'preference'``.
        """
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")
        self.tokenizer = tokenizer
        self.max_length = max_length

        if dataset_name not in POSTTRAIN_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(POSTTRAIN_DATASETS.keys())}")

        config = POSTTRAIN_DATASETS[dataset_name]
        dataset_type = config.get('type')
        if dataset_type != 'preference':
            raise ValueError(f"{dataset_name} is not a preference dataset (type: {dataset_type})")

        logger.info(f"Loading preference dataset: {dataset_name}")

        load_kwargs = dict(path=config['hf_path'], split=split, cache_dir=HF_CACHE_DIR)
        if 'config' in config:
            load_kwargs['name'] = config['config']
        self.dataset = load_dataset(**load_kwargs)

        # Column names holding the preferred and the dispreferred text.
        self.chosen_field = config.get('chosen_field', 'chosen')
        self.rejected_field = config.get('rejected_field', 'rejected')

        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Loaded {len(self.dataset)} preference pairs")

    def __len__(self):
        """Return the number of preference pairs."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the tokenized pair at ``idx``, or ``None`` if it is unusable."""
        try:
            row = self.dataset[idx]
            pair = (
                row.get(self.chosen_field, ''),
                row.get(self.rejected_field, ''),
            )

            # Both sides must be non-empty for the pair to be usable.
            if not all(pair):
                return None

            chosen_enc, rejected_enc = [
                self.tokenizer(
                    text,
                    max_length=self.max_length,
                    truncation=True,
                    padding='max_length',
                    return_tensors='pt'
                )
                for text in pair
            ]

            return (
                chosen_enc['input_ids'].squeeze(0),
                rejected_enc['input_ids'].squeeze(0),
                chosen_enc['attention_mask'].squeeze(0),
                rejected_enc['attention_mask'].squeeze(0)
            )

        except Exception as e:
            logger.debug(f"Error getting preference item at index {idx}: {e}")
            return None
|
|
|
|
|
|
|
|
def collate_fn_v2(batch):
    """Collate samples from any dataset in this module into one batch dict.

    ``None`` items (samples a dataset could not produce) are dropped first.

    * Tuple samples (preference pairs) are stacked into ``chosen``/``rejected``
      and, for 4-tuples, ``chosen_mask``/``rejected_mask``.
    * Dict samples are merged key by key over the union of all keys in the
      batch: the known tensor keys are stacked, ``modality_data`` images are
      stacked into ``{'image': tensor}`` (or ``None``), and every other key
      becomes a list aligned with the batch, holding ``None`` for samples that
      lack the key. Using the union means a batch mixing text and image-text
      samples neither raises ``KeyError`` nor loses its ``image`` entries.

    Returns:
        A dict as described above; for an empty batch, a dict with empty
        ``input_ids`` and ``attention_mask`` tensors.
    """
    batch = [item for item in batch if item is not None]

    if not batch:
        logger.warning("Empty batch after filtering None values")
        return {
            'input_ids': torch.empty(0),
            'attention_mask': torch.empty(0)
        }

    if isinstance(batch[0], tuple):
        names = ('chosen', 'rejected', 'chosen_mask', 'rejected_mask')
        # Anything other than a 4-tuple only contributes its first two fields.
        if len(batch[0]) != 4:
            names = names[:2]
        return {
            name: torch.stack([item[position] for item in batch])
            for position, name in enumerate(names)
        }

    tensor_keys = {
        'instruction', 'response', 'instruction_mask',
        'response_mask', 'input_ids', 'attention_mask',
    }
    # Union of keys in first-seen order, not just the first sample's keys.
    keys = dict.fromkeys(key for item in batch for key in item)
    collated = {}

    for key in keys:
        if key in tensor_keys:
            tensors = [item[key] for item in batch if item.get(key) is not None]
            collated[key] = torch.stack(tensors) if tensors else None
        elif key == 'modality_data':
            # Only samples that actually carry an image contribute, so the
            # stacked images may be fewer than the batch size.
            images = [
                modality['image']
                for modality in (item.get(key) for item in batch)
                if modality and modality.get('image') is not None
            ]
            collated[key] = {'image': torch.stack(images)} if images else None
        else:
            collated[key] = [item.get(key) for item in batch]

    return collated
|
|
|
|
|
|
|
|
def create_pretrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None
):
    """Build a ``DataLoader`` over a streaming ``PreTrainDataset``.

    Args:
        mix_name: Pre-training mix to load.
        tokenizer: Tokenizer forwarded to the dataset.
        batch_size: Samples per batch.
        num_workers: Number of DataLoader worker processes.
        max_length: Token length forwarded to the dataset.
        max_samples: Optional sample limit forwarded to the dataset.

    Returns:
        A ``DataLoader`` that collates with ``collate_fn_v2``.
    """
    dataset_options = {
        'mix_name': mix_name,
        'tokenizer': tokenizer,
        'max_length': max_length,
        'streaming': True,
        'max_samples': max_samples,
    }
    pretrain_dataset = PreTrainDataset(**dataset_options)

    loader_options = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'collate_fn': collate_fn_v2,
    }
    return DataLoader(pretrain_dataset, **loader_options)
|
|
|
|
|
|
|
|
def create_posttrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a ``DataLoader`` over a ``PostTrainDataset``.

    Args:
        mix_name: Post-training mix to load.
        tokenizer: Tokenizer forwarded to the dataset.
        batch_size: Samples per batch.
        num_workers: Number of DataLoader worker processes.
        max_length: Token budget forwarded to the dataset.
        max_samples: Optional sample cap forwarded to the dataset.
        split: Dataset split to load.
        shuffle: Whether the loader shuffles the samples.

    Returns:
        A ``DataLoader`` with pinned memory that keeps the final partial batch
        and collates with ``collate_fn_v2``.
    """
    dataset_options = {
        'mix_name': mix_name,
        'tokenizer': tokenizer,
        'max_length': max_length,
        'max_samples': max_samples,
        'split': split,
    }
    posttrain_dataset = PostTrainDataset(**dataset_options)

    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_fn_v2,
        'pin_memory': True,
        'drop_last': False,
    }
    return DataLoader(posttrain_dataset, **loader_options)
|
|
|
|
|
|
|
|
def create_preference_dataloader(
    dataset_name: str = 'hh_rlhf',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 1024,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a ``DataLoader`` over a ``PreferenceDataset``.

    Args:
        dataset_name: Preference dataset to load.
        tokenizer: Tokenizer forwarded to the dataset.
        batch_size: Pairs per batch.
        num_workers: Number of DataLoader worker processes.
        max_length: Token length forwarded to the dataset.
        max_samples: Optional cap on the number of pairs.
        split: Dataset split to load.
        shuffle: Whether the loader shuffles the pairs.

    Returns:
        A ``DataLoader`` with pinned memory that collates with
        ``collate_fn_v2``.
    """
    dataset_options = {
        'dataset_name': dataset_name,
        'tokenizer': tokenizer,
        'max_length': max_length,
        'max_samples': max_samples,
        'split': split,
    }
    preference_dataset = PreferenceDataset(**dataset_options)

    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_fn_v2,
        'pin_memory': True,
    }
    return DataLoader(preference_dataset, **loader_options)