"""
改进的数据加载器 - 支持预训练和后训练数据集
"""
|
| | import torch
|
| | import torch.nn.functional as F
|
| | from torch.utils.data import Dataset, DataLoader, IterableDataset
|
| | from datasets import load_dataset, concatenate_datasets, interleave_datasets
|
| | from typing import Dict, List, Optional, Any, Union
|
| | import random
|
| | import numpy as np
|
| | from tqdm import tqdm
|
| | import warnings
|
| | from PIL import Image
|
| | import requests
|
| | from io import BytesIO
|
| | from torchvision import transforms
|
| | import logging
|
| |
|
| |
|
# Configure root logging once at import time and get a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Silence UserWarnings globally — presumably noisy third-party libraries
# (PIL/torchvision); NOTE(review): this hides warnings process-wide, confirm intended.
warnings.filterwarnings("ignore", category=UserWarning)
|
| |
|
| | from data_config import (
|
| | PRETRAIN_DATASETS,
|
| | POSTTRAIN_DATASETS,
|
| | TEST_DATASETS,
|
| | PRETRAIN_MIX,
|
| | POSTTRAIN_MIX,
|
| | PREPROCESSING_CONFIG,
|
| | DATASET_CACHE_DIR,
|
| | HF_CACHE_DIR
|
| | )
|
| |
|
| |
|
# Shared image preprocessing pipeline: resize to 224x224, convert to a
# CHW float tensor, then normalize with the standard ImageNet statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
|
| |
|
class PreTrainDataset(IterableDataset):
    """Streaming pretraining dataset with weighted mixture sampling.

    Loads every dataset listed in PRETRAIN_MIX[mix_name] (per-dataset
    configs come from PRETRAIN_DATASETS), normalizes the mix weights into
    sampling probabilities, and yields tokenized text or image-text samples
    drawn at random from the component datasets.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        streaming: bool = True,
        seed: int = 42,
        max_samples: Optional[int] = None
    ):
        """Build the dataset mixture.

        Args:
            mix_name: key into PRETRAIN_MIX selecting datasets and weights.
            tokenizer: HF-style tokenizer (required); used with truncation
                and 'max_length' padding.
            max_length: token length of every yielded sample.
            streaming: default streaming flag; per-dataset configs may override.
            seed: base RNG seed, offset by worker id inside __iter__.
            max_samples: stop iteration after this many yielded samples
                (None/0 = unbounded).

        Raises:
            ValueError: tokenizer is None, unknown mix name, empty mix,
                or no component dataset could be loaded.
        """
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.streaming = streaming
        self.seed = seed
        self.max_samples = max_samples
        self.samples_generated = 0  # samples yielded by the current iterator

        if mix_name not in PRETRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(PRETRAIN_MIX.keys())}")

        mix_config = PRETRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading pretrain mix: {mix_name}")
        logger.info(f" Datasets: {dataset_names}")
        logger.info(f" Weights: {weights}")

        # (name, dataset, config) triples plus their unnormalized weights.
        # NOTE(review): zip() silently drops trailing datasets if `weights`
        # is shorter than `datasets` — confirm the mix configs always match.
        self.datasets = []
        self.probabilities = []

        for name, weight in zip(dataset_names, weights):
            if name not in PRETRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in PRETRAIN_DATASETS, skipping")
                continue

            config = PRETRAIN_DATASETS[name]
            try:
                ds = self._load_dataset(config)
                if ds is not None:
                    self.datasets.append((name, ds, config))
                    self.probabilities.append(weight)
                    logger.info(f" Successfully loaded {name}")
            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not self.datasets:
            raise ValueError("No datasets loaded successfully")

        # Normalize the surviving weights into a probability distribution.
        total = sum(self.probabilities)
        self.probabilities = [p / total for p in self.probabilities]

        logger.info(f"Successfully loaded {len(self.datasets)} datasets")

    def _load_dataset(self, config: Dict):
        """Load one HF dataset described by `config`; return None on failure."""
        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'streaming': config.get('streaming', self.streaming),
                'cache_dir': HF_CACHE_DIR,
            }

            # Optional dataset sub-configuration (the HF "name" argument).
            if 'config' in config:
                load_kwargs['name'] = config['config']

            ds = load_dataset(**load_kwargs)
            return ds
        except Exception as e:
            logger.error(f"Failed to load {config.get('hf_path', 'unknown')}: {e}")
            return None

    def _process_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize a plain-text sample to fixed length; None if unusable."""
        try:
            text_field = config.get('text_field', 'text')
            text = sample.get(text_field, '')

            if not text or not isinstance(text, str):
                return None

            text = text.strip()
            if len(text) < 10:
                # Drop near-empty fragments.
                return None

            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'type': 'text'
            }
        except Exception as e:
            logger.debug(f"Error processing text sample: {e}")
            return None

    def _process_image_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize the caption and transform the image; None if either fails."""
        try:
            text_field = config.get('text_field', 'caption')
            image_field = config.get('image_field', 'image')

            text = sample.get(text_field, '')
            image = sample.get(image_field)

            if not text or image is None:
                return None

            # The image may arrive as a URL string or an in-memory PIL image.
            if isinstance(image, str):
                try:
                    response = requests.get(image, timeout=5)
                    image = Image.open(BytesIO(response.content)).convert('RGB')
                except Exception as img_error:
                    logger.debug(f"Failed to load image from URL: {img_error}")
                    return None
            elif isinstance(image, Image.Image):
                image = image.convert('RGB')
            else:
                return None

            image_tensor = image_transform(image)

            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'image': image_tensor,
                'type': 'image_text'
            }
        except Exception as e:
            logger.debug(f"Error processing image-text sample: {e}")
            return None

    def __iter__(self):
        """Yield processed samples drawn from the weighted dataset mixture."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Per-worker seeds so each worker draws a different mixture order.
            # NOTE(review): there is no data sharding across workers — with
            # num_workers > 1 every worker streams the same underlying data,
            # only the sampling order differs; confirm this is intended.
            random.seed(self.seed + worker_info.id)
            np.random.seed(self.seed + worker_info.id)
        else:
            random.seed(self.seed)
            np.random.seed(self.seed)

        iterators = [iter(ds) for _, ds, _ in self.datasets]
        self.samples_generated = 0

        while True:
            # Honor the global sample cap (0/None disables it).
            if self.max_samples and self.samples_generated >= self.max_samples:
                break

            try:
                # Pick a component dataset according to the mix probabilities.
                idx = np.random.choice(len(self.datasets), p=self.probabilities)
                name, _, config = self.datasets[idx]

                sample = next(iterators[idx])

                processed = None
                if config.get('type') in ['text', 'code']:
                    processed = self._process_text_sample(sample, config)
                elif config.get('type') == 'image_text':
                    processed = self._process_image_text_sample(sample, config)
                else:
                    logger.debug(f"Unknown type: {config.get('type')}")
                    continue

                if processed is not None:
                    self.samples_generated += 1
                    yield processed

            except StopIteration:
                # Chosen dataset is exhausted — restart it from the beginning.
                try:
                    iterators[idx] = iter(self.datasets[idx][1])
                except Exception as e:
                    logger.error(f"Failed to recreate iterator for dataset {idx}: {e}")
                    break
            except Exception as e:
                logger.debug(f"Error in iterator: {e}")
                continue
|
| |
|
| |
|
class PostTrainDataset(Dataset):
    """Post-training (SFT) dataset for instruction tuning and dialogue.

    Loads and interleaves the datasets named in POSTTRAIN_MIX[mix_name];
    each item is a dict of fixed-length instruction/response token tensors
    and masks, plus optional image data for multimodal samples.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        """Load and interleave the post-training mixture.

        Args:
            mix_name: key into POSTTRAIN_MIX selecting datasets and weights.
            tokenizer: HF-style tokenizer (required).
            max_length: total token budget; split half/half between
                instruction and response in __getitem__.
            max_samples: optional global cap applied after interleaving.
            split: HF split name passed to load_dataset.

        Raises:
            ValueError: tokenizer is None, unknown mix name, empty mix,
                or no dataset loaded successfully.
        """
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        if mix_name not in POSTTRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(POSTTRAIN_MIX.keys())}")

        mix_config = POSTTRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading posttrain mix: {mix_name}")
        logger.info(f" Datasets: {dataset_names}")

        all_datasets = []

        for name in dataset_names:
            if name not in POSTTRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in POSTTRAIN_DATASETS")
                continue

            config = POSTTRAIN_DATASETS[name]
            try:
                load_kwargs = {
                    'path': config['hf_path'],
                    'split': split,
                    'streaming': config.get('streaming', False),
                    'cache_dir': HF_CACHE_DIR,
                }

                if 'data_files' in config:
                    load_kwargs['data_files'] = config['data_files']

                if 'config' in config:
                    load_kwargs['name'] = config['config']

                ds = load_dataset(**load_kwargs)

                # Per-dataset sample cap: streaming datasets expose take(),
                # map-style datasets expose select().
                if config.get('max_samples'):
                    if hasattr(ds, 'take'):
                        ds = ds.take(config['max_samples'])
                    elif hasattr(ds, 'select'):
                        ds = ds.select(range(min(len(ds), config['max_samples'])))

                # Tag every example with its source name and config so
                # __getitem__ knows how to format it.
                def add_source(example):
                    example['_source'] = name
                    example['_config'] = config
                    return example

                ds = ds.map(add_source)
                all_datasets.append(ds)

                ds_len = len(ds) if hasattr(ds, '__len__') else 'streaming'
                logger.info(f" Loaded {name}: {ds_len} samples")

            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not all_datasets:
            raise ValueError("No datasets loaded successfully")

        if len(all_datasets) == 1:
            self.dataset = all_datasets[0]
        else:
            # NOTE(review): weights are paired with all_datasets positionally;
            # if any dataset failed to load above, weights[:len(all_datasets)]
            # no longer lines up with the surviving datasets — confirm.
            probabilities = [w / sum(weights[:len(all_datasets)])
                             for w in weights[:len(all_datasets)]]
            self.dataset = interleave_datasets(
                all_datasets,
                probabilities=probabilities,
                seed=42,
                stopping_strategy='all_exhausted'
            )

        # Global cap (only applicable to map-style, sized datasets).
        if max_samples and hasattr(self.dataset, '__len__'):
            actual_len = min(len(self.dataset), max_samples)
            self.dataset = self.dataset.select(range(actual_len))

        dataset_len = len(self.dataset) if hasattr(self.dataset, '__len__') else 'streaming'
        logger.info(f"Total samples: {dataset_len}")

    def _format_instruction(self, sample: Dict, config: Dict) -> str:
        """Build the prompt text for one sample according to config['type'].

        Supported types: 'instruction', 'conversation', 'code_instruction',
        'multimodal_instruction'; anything else falls back to a raw field.
        Returns '' when the sample cannot be formatted.
        """
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                input_field = config.get('input_field', 'input')
                context_field = config.get('context_field', None)

                instruction = sample.get(instruction_field, '')
                input_text = sample.get(input_field, '')
                context = sample.get(context_field, '') if context_field else ''

                # Assemble an "Instruction / Context / Input / Response:" prompt.
                prompt_parts = [f"Instruction: {instruction}"]

                if context:
                    prompt_parts.append(f"Context: {context}")

                if input_text:
                    prompt_parts.append(f"Input: {input_text}")

                prompt_parts.append("Response:")
                return "\n".join(prompt_parts)

            elif data_type == 'conversation':
                # All turns except the last form the context; the last turn
                # becomes the target (see _get_response).
                if 'conversations' in sample:
                    # ShareGPT-style format: [{'from': ..., 'value': ...}, ...]
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        dialogue = []
                        for conv in conversations[:-1]:
                            role = conv.get('from', 'user')
                            content = conv.get('value', '')
                            dialogue.append(f"{role}: {content}")
                        return "\n".join(dialogue) + "\nassistant:"

                elif 'messages' in sample:
                    # Chat format: [{'role': ..., 'content': ...}, ...]
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        dialogue = []
                        for msg in messages[:-1]:
                            role = msg.get('role', 'user')
                            content = msg.get('content', '')
                            dialogue.append(f"{role}: {content}")
                        return "\n".join(dialogue) + "\nassistant:"

                # Fallback for conversation samples with neither key.
                return sample.get('text', '')

            elif data_type == 'code_instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                instruction = sample.get(instruction_field, '')
                return f"### Instruction:\n{instruction}\n### Response:"

            elif data_type == 'multimodal_instruction':
                # Same turn structure as 'conversation', field name configurable.
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    dialogue = []
                    for conv in conversations[:-1]:
                        role = conv.get('from', 'user')
                        content = conv.get('value', '')
                        dialogue.append(f"{role}: {content}")
                    return "\n".join(dialogue) + "\nassistant:"
                return ""

            else:
                return sample.get(config.get('instruction_field', 'text'), '')
        except Exception as e:
            logger.debug(f"Error formatting instruction: {e}")
            return ""

    def _get_response(self, sample: Dict, config: Dict) -> str:
        """Extract the target/response text for one sample; '' when absent."""
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction' or data_type == 'code_instruction':
                response_field = config.get('response_field', 'output')
                return sample.get(response_field, '')

            elif data_type == 'conversation':
                # The final turn of the dialogue is the response.
                if 'conversations' in sample:
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        return conversations[-1].get('value', '')

                elif 'messages' in sample:
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        return messages[-1].get('content', '')

                return ""

            elif data_type == 'multimodal_instruction':
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    return conversations[-1].get('value', '')
                return ""

            else:
                response_field = config.get('response_field', 'output')
                return sample.get(response_field, '')
        except Exception as e:
            logger.debug(f"Error getting response: {e}")
            return ""

    def __len__(self):
        # Streaming datasets have no length; report 0 in that case.
        return len(self.dataset) if hasattr(self.dataset, '__len__') else 0

    def __getitem__(self, idx):
        """Return a dict of fixed-length instruction/response tensors, or None.

        None is returned for malformed samples; collate_fn_v2 filters those
        out of the batch, so a custom collate_fn is required downstream.
        """
        try:
            sample = self.dataset[idx]

            if '_config' not in sample:
                logger.warning(f"Sample at index {idx} missing _config")
                return None

            config = sample['_config']

            instruction_text = self._format_instruction(sample, config)
            response_text = self._get_response(sample, config)

            if not instruction_text or not response_text:
                return None

            # Fall back to EOS for padding when the tokenizer has no pad token.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            # The instruction gets the first half of the token budget.
            instruction_max_len = self.max_length // 2

            instruction_enc = self.tokenizer(
                instruction_text,
                truncation=True,
                max_length=instruction_max_len,
                add_special_tokens=False,
                return_tensors='pt'
            )
            instr_ids = instruction_enc['input_ids'].squeeze(0)

            # Right-pad the instruction to exactly instruction_max_len.
            instr_len = instr_ids.size(0)
            if instr_len < instruction_max_len:
                padding = torch.full((instruction_max_len - instr_len,), pad_token_id, dtype=torch.long)
                instr_ids = torch.cat([instr_ids, padding])

                instr_mask = torch.cat([torch.ones(instr_len, dtype=torch.long), torch.zeros(instruction_max_len - instr_len, dtype=torch.long)])
            else:
                instr_mask = torch.ones(instruction_max_len, dtype=torch.long)

            # The response gets the second half; one slot is reserved for EOS.
            response_max_len = self.max_length // 2

            response_enc = self.tokenizer(
                response_text,
                truncation=True,
                max_length=response_max_len - 1,
                add_special_tokens=False,
                return_tensors='pt'
            )
            resp_ids = response_enc['input_ids'].squeeze(0)

            # Always terminate the response with EOS.
            eos_token = torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long)
            resp_ids = torch.cat([resp_ids, eos_token])

            # Right-pad the response to exactly response_max_len.
            curr_resp_len = resp_ids.size(0)
            if curr_resp_len < response_max_len:
                padding = torch.full((response_max_len - curr_resp_len,), pad_token_id, dtype=torch.long)
                resp_ids = torch.cat([resp_ids, padding])

                resp_mask = torch.cat([torch.ones(curr_resp_len, dtype=torch.long), torch.zeros(response_max_len - curr_resp_len, dtype=torch.long)])
            else:
                resp_mask = torch.ones(response_max_len, dtype=torch.long)

            result = {
                'instruction': instr_ids,
                'response': resp_ids,
                'instruction_mask': instr_mask,
                'response_mask': resp_mask,
                'task': sample.get('_source', 'unknown'),
                'modality_data': None
            }

            # Attach the image tensor for multimodal instruction samples.
            if config.get('type') == 'multimodal_instruction' and 'image' in sample:
                try:
                    image = sample['image']
                    if isinstance(image, Image.Image):
                        image = image.convert('RGB')
                    image_tensor = image_transform(image)
                    result['modality_data'] = {'image': image_tensor}
                except Exception as e:
                    logger.debug(f"Error processing image: {e}")

            return result

        except Exception as e:
            logger.debug(f"Error getting item at index {idx}: {e}")
            import traceback
            traceback.print_exc()
            return None
|
| |
|
| |
|
class PreferenceDataset(Dataset):
    """Preference-pair dataset for RLHF-style training.

    Each item is a 4-tuple of fixed-length tensors:
    (chosen_ids, rejected_ids, chosen_mask, rejected_mask).
    """

    def __init__(
        self,
        dataset_name: str = 'hh_rlhf',
        tokenizer=None,
        max_length: int = 1024,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        """Load a preference dataset declared in POSTTRAIN_DATASETS.

        Raises:
            ValueError: tokenizer is None, unknown dataset name, or the
                named dataset is not of type 'preference'.
        """
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length

        if dataset_name not in POSTTRAIN_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(POSTTRAIN_DATASETS.keys())}")

        config = POSTTRAIN_DATASETS[dataset_name]
        if config.get('type') != 'preference':
            raise ValueError(f"{dataset_name} is not a preference dataset (type: {config.get('type')})")

        logger.info(f"Loading preference dataset: {dataset_name}")

        # Assemble load_dataset() kwargs; 'name' is the optional HF sub-config.
        load_kwargs = dict(path=config['hf_path'], split=split, cache_dir=HF_CACHE_DIR)
        if 'config' in config:
            load_kwargs['name'] = config['config']
        self.dataset = load_dataset(**load_kwargs)

        # Column names holding the preferred / dispreferred completions.
        self.chosen_field = config.get('chosen_field', 'chosen')
        self.rejected_field = config.get('rejected_field', 'rejected')

        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Loaded {len(self.dataset)} preference pairs")

    def __len__(self):
        return len(self.dataset)

    def _encode(self, text):
        # Tokenize to a fixed-length, max_length-padded (ids, mask) pair.
        enc = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        return enc['input_ids'].squeeze(0), enc['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """Return (chosen_ids, rejected_ids, chosen_mask, rejected_mask) or None."""
        try:
            row = self.dataset[idx]
            chosen_text = row.get(self.chosen_field, '')
            rejected_text = row.get(self.rejected_field, '')

            # Skip rows where either side of the pair is missing/empty.
            if not chosen_text or not rejected_text:
                return None

            chosen_ids, chosen_mask = self._encode(chosen_text)
            rejected_ids, rejected_mask = self._encode(rejected_text)
            return (chosen_ids, rejected_ids, chosen_mask, rejected_mask)

        except Exception as e:
            logger.debug(f"Error getting preference item at index {idx}: {e}")
            return None
|
| |
|
| |
|
def collate_fn_v2(batch):
    """Collate function for mixed batch shapes.

    Handles three item layouts:
      * None items — dropped (produced by datasets on malformed samples);
      * tuples — preference pairs (4-tuple -> chosen/rejected + masks,
        otherwise the first two elements become chosen/rejected);
      * dicts — known tensor fields are stacked, 'modality_data' images are
        batched, everything else is collected into plain lists.
    """
    samples = [s for s in batch if s is not None]

    if not samples:
        logger.warning("Empty batch after filtering None values")
        # Degenerate batch: return empty tensors so training code can skip it.
        return {
            'input_ids': torch.empty(0),
            'attention_mask': torch.empty(0)
        }

    first = samples[0]

    # Tuple items come from PreferenceDataset.
    if isinstance(first, tuple):
        columns = list(zip(*samples))
        if len(first) == 4:
            names = ('chosen', 'rejected', 'chosen_mask', 'rejected_mask')
        else:
            names = ('chosen', 'rejected')
        return {key: torch.stack(list(col)) for key, col in zip(names, columns)}

    # Dict items: stack tensor fields, batch modality data, list the rest.
    tensor_keys = {'instruction', 'response', 'instruction_mask',
                   'response_mask', 'input_ids', 'attention_mask'}
    collated = {}

    for key in first.keys():
        if key in tensor_keys:
            values = [s[key] for s in samples if s.get(key) is not None]
            collated[key] = torch.stack(values) if values else None
        elif key == 'modality_data':
            present = [s[key] for s in samples if s.get(key) is not None]
            images = [m['image'] for m in present if m and 'image' in m]
            collated[key] = {'image': torch.stack(images)} if images else None
        else:
            collated[key] = [s[key] for s in samples]

    return collated
|
| |
|
| |
|
def create_pretrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None
):
    """Create a DataLoader over the streaming PreTrainDataset.

    PreTrainDataset is an IterableDataset, so `shuffle` is not applicable;
    ordering comes from the mixture sampling inside PreTrainDataset.__iter__.
    NOTE(review): with num_workers > 1 each worker iterates its own copy of
    the streams (no sharding in PreTrainDataset) — confirm that is intended.

    Args:
        mix_name: key into PRETRAIN_MIX selecting the dataset mixture.
        tokenizer: HF-style tokenizer, forwarded to PreTrainDataset.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        max_length: token length of each sample.
        max_samples: optional cap on yielded samples (None = unbounded).

    Returns:
        torch.utils.data.DataLoader yielding collate_fn_v2 batches.
    """
    dataset = PreTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        streaming=True,
        max_samples=max_samples
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
        # Consistency fix: pin host memory like the other two loader
        # factories (posttrain/preference) for uniform GPU-transfer behavior.
        pin_memory=True
    )
|
| |
|
| |
|
def create_posttrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over PostTrainDataset for SFT.

    Batches are assembled by collate_fn_v2, which drops None items
    produced by PostTrainDataset.__getitem__ on malformed samples.
    """
    sft_dataset = PostTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    loader_kwargs = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_fn_v2,
        'pin_memory': True,
        'drop_last': False,
    }
    return DataLoader(sft_dataset, **loader_kwargs)
|
| |
|
| |
|
def create_preference_dataloader(
    dataset_name: str = 'hh_rlhf',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 1024,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over PreferenceDataset for RLHF training.

    Batches are assembled by collate_fn_v2, which turns the dataset's
    4-tuples into chosen/rejected tensors with their masks.
    """
    pref_dataset = PreferenceDataset(
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    loader_kwargs = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_fn_v2,
        'pin_memory': True,
    }
    return DataLoader(pref_dataset, **loader_kwargs)