szxllm commited on
Commit
6419b37
·
verified ·
1 Parent(s): b17ba29

Update data_loader.py

Browse files
Files changed (1) hide show
  1. data_loader.py +787 -831
data_loader.py CHANGED
@@ -1,832 +1,788 @@
1
- # data_loader.py
2
- """
3
- 改进的数据加载器 - 支持预训练和后训练数据集
4
- """
5
- import torch
6
- import torch.nn.functional as F
7
- from torch.utils.data import Dataset, DataLoader, IterableDataset
8
- from datasets import load_dataset, concatenate_datasets, interleave_datasets
9
- from typing import Dict, List, Optional, Any, Union
10
- import random
11
- import numpy as np
12
- from tqdm import tqdm
13
- import warnings
14
- from PIL import Image
15
- import requests
16
- from io import BytesIO
17
- from torchvision import transforms
18
- import logging
19
-
20
- # 设置日志
21
- logging.basicConfig(level=logging.INFO)
22
- logger = logging.getLogger(__name__)
23
-
24
- warnings.filterwarnings("ignore", category=UserWarning)
25
-
26
- from data_config import (
27
- PRETRAIN_DATASETS,
28
- POSTTRAIN_DATASETS,
29
- TEST_DATASETS,
30
- PRETRAIN_MIX,
31
- POSTTRAIN_MIX,
32
- PREPROCESSING_CONFIG,
33
- DATASET_CACHE_DIR,
34
- HF_CACHE_DIR
35
- )
36
-
37
# Image preprocessing pipeline: resize to 224x224, convert to tensor, and
# normalize with the ImageNet mean/std used by torchvision pretrained backbones.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
43
-
44
class PreTrainDataset(IterableDataset):
    """Streaming pre-training dataset with weighted mixing of HF datasets.

    Each draw picks a source dataset according to the (normalized) mix
    weights, pulls the next raw sample from it, and converts it on the fly:
    text/code rows are tokenized to fixed length, image-text rows are
    additionally passed through ``image_transform``. Exhausted sources are
    restarted so sampling can continue indefinitely (up to ``max_samples``).
    """
    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        streaming: bool = True,
        seed: int = 42,
        max_samples: Optional[int] = None
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.streaming = streaming
        self.seed = seed
        self.max_samples = max_samples  # optional cap on yielded samples
        self.samples_generated = 0

        # Resolve the mix configuration (dataset names + sampling weights).
        if mix_name not in PRETRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(PRETRAIN_MIX.keys())}")

        mix_config = PRETRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading pretrain mix: {mix_name}")
        logger.info(f" Datasets: {dataset_names}")
        logger.info(f" Weights: {weights}")

        # Load each member dataset; failures are logged and skipped so one
        # bad dataset does not abort the whole mix.
        self.datasets = []
        self.probabilities = []

        for name, weight in zip(dataset_names, weights):
            if name not in PRETRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in PRETRAIN_DATASETS, skipping")
                continue

            config = PRETRAIN_DATASETS[name]
            try:
                ds = self._load_dataset(config)
                if ds is not None:
                    self.datasets.append((name, ds, config))
                    self.probabilities.append(weight)
                    logger.info(f" Successfully loaded {name}")
            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not self.datasets:
            raise ValueError("No datasets loaded successfully")

        # Normalize the surviving weights into a probability distribution.
        total = sum(self.probabilities)
        self.probabilities = [p / total for p in self.probabilities]

        logger.info(f"Successfully loaded {len(self.datasets)} datasets")

    def _load_dataset(self, config: Dict):
        """Load a single HF dataset described by *config*; return None on failure."""
        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'streaming': config.get('streaming', self.streaming),
                'cache_dir': HF_CACHE_DIR,
            }

            # Pass the dataset sub-config name when one is specified.
            if 'config' in config:
                load_kwargs['name'] = config['config']

            ds = load_dataset(**load_kwargs)
            return ds
        except Exception as e:
            logger.error(f"Failed to load {config.get('hf_path', 'unknown')}: {e}")
            return None

    def _process_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize a plain-text/code row; return None for unusable rows."""
        try:
            text_field = config.get('text_field', 'text')
            text = sample.get(text_field, '')

            if not text or not isinstance(text, str):
                return None

            text = text.strip()
            # Drop near-empty fragments.
            if len(text) < 10:
                return None

            # Fixed-length tokenization (pad/truncate to max_length).
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'type': 'text'
            }
        except Exception as e:
            logger.debug(f"Error processing text sample: {e}")
            return None

    def _process_image_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize the caption and transform the image of an image-text row.

        Accepts a URL string (downloaded with a 5s timeout) or a PIL image;
        anything else is rejected. Returns None for unusable rows.
        """
        try:
            text_field = config.get('text_field', 'caption')
            image_field = config.get('image_field', 'image')

            text = sample.get(text_field, '')
            image = sample.get(image_field)

            if not text or image is None:
                return None

            # Normalize the image to an RGB PIL image.
            if isinstance(image, str):
                # URL: fetch with a timeout; skip the row on any fetch error.
                try:
                    response = requests.get(image, timeout=5)
                    image = Image.open(BytesIO(response.content)).convert('RGB')
                except Exception as img_error:
                    logger.debug(f"Failed to load image from URL: {img_error}")
                    return None
            elif isinstance(image, Image.Image):
                image = image.convert('RGB')
            else:
                return None

            # Resize/normalize to the model's expected input.
            image_tensor = image_transform(image)

            # Fixed-length tokenization of the caption.
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'image': image_tensor,
                'type': 'image_text'
            }
        except Exception as e:
            logger.debug(f"Error processing image-text sample: {e}")
            return None

    def __iter__(self):
        """Yield processed samples, choosing a source per draw by mix weight."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Give each DataLoader worker a distinct seed so draws differ.
            random.seed(self.seed + worker_info.id)
            np.random.seed(self.seed + worker_info.id)
        else:
            random.seed(self.seed)
            np.random.seed(self.seed)

        # NOTE(review): with num_workers > 1 every worker iterates the same
        # underlying streams (no sharding here) — only the sampling order
        # differs. Confirm that duplicated data across workers is acceptable.
        iterators = [iter(ds) for _, ds, _ in self.datasets]
        self.samples_generated = 0

        while True:
            # Stop once the optional sample budget is exhausted.
            if self.max_samples and self.samples_generated >= self.max_samples:
                break

            try:
                # Pick a dataset index according to the mix probabilities.
                idx = np.random.choice(len(self.datasets), p=self.probabilities)
                name, _, config = self.datasets[idx]

                # Pull the next raw sample from the chosen source.
                sample = next(iterators[idx])

                # Dispatch on the configured sample type.
                processed = None
                if config.get('type') in ['text', 'code']:
                    processed = self._process_text_sample(sample, config)
                elif config.get('type') == 'image_text':
                    processed = self._process_image_text_sample(sample, config)
                else:
                    logger.debug(f"Unknown type: {config.get('type')}")
                    continue

                if processed is not None:
                    self.samples_generated += 1
                    yield processed

            except StopIteration:
                # Source exhausted: restart its iterator so mixing continues.
                try:
                    iterators[idx] = iter(self.datasets[idx][1])
                except Exception as e:
                    logger.error(f"Failed to recreate iterator for dataset {idx}: {e}")
                    break
            except Exception as e:
                logger.debug(f"Error in iterator: {e}")
                continue
262
-
263
-
264
class PostTrainDataset(Dataset):
    """Map-style post-training (SFT) dataset: instruction tuning and dialogue.

    Loads every dataset in the configured mix, tags each row with its source
    name and config, interleaves the datasets by weight, and — in
    ``__getitem__`` — formats and tokenizes instruction/response pairs into
    fixed-length tensors (each half gets ``max_length // 2`` tokens).
    """
    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split

        # Resolve the mix configuration.
        if mix_name not in POSTTRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(POSTTRAIN_MIX.keys())}")

        mix_config = POSTTRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading posttrain mix: {mix_name}")
        logger.info(f" Datasets: {dataset_names}")

        # Load and merge the member datasets.
        all_datasets = []

        for name in dataset_names:
            if name not in POSTTRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in POSTTRAIN_DATASETS")
                continue

            config = POSTTRAIN_DATASETS[name]
            try:
                load_kwargs = {
                    'path': config['hf_path'],
                    'split': split,
                    'streaming': config.get('streaming', False),
                    'cache_dir': HF_CACHE_DIR,
                }
                # Forward explicit data files when the config provides them.
                if 'data_files' in config:
                    load_kwargs['data_files'] = config['data_files']
                # Pass the dataset sub-config name when one is specified.
                if 'config' in config:
                    load_kwargs['name'] = config['config']

                ds = load_dataset(**load_kwargs)

                # Per-dataset sample cap ('take' for streaming, 'select' otherwise).
                if config.get('max_samples'):
                    if hasattr(ds, 'take'):
                        ds = ds.take(config['max_samples'])
                    elif hasattr(ds, 'select'):
                        ds = ds.select(range(min(len(ds), config['max_samples'])))

                # Tag every row with its source name and full config so
                # __getitem__ knows how to format it later.
                def add_source(example):
                    example['_source'] = name
                    example['_config'] = config
                    return example

                ds = ds.map(add_source)
                all_datasets.append(ds)

                ds_len = len(ds) if hasattr(ds, '__len__') else 'streaming'
                logger.info(f" Loaded {name}: {ds_len} samples")

            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        # Merge whatever loaded successfully.
        if not all_datasets:
            raise ValueError("No datasets loaded successfully")

        if len(all_datasets) == 1:
            self.dataset = all_datasets[0]
        else:
            # Weighted interleaving of the loaded datasets.
            # NOTE(review): weights[:len(all_datasets)] assumes the first
            # len(all_datasets) weights line up with the datasets that loaded;
            # if a dataset in the middle of the list failed, the weights are
            # misaligned with the survivors — verify.
            probabilities = [w / sum(weights[:len(all_datasets)])
                             for w in weights[:len(all_datasets)]]
            self.dataset = interleave_datasets(
                all_datasets,
                probabilities=probabilities,
                seed=42,
                stopping_strategy='all_exhausted'
            )

        # Global cap on the merged dataset size.
        if max_samples and hasattr(self.dataset, '__len__'):
            actual_len = min(len(self.dataset), max_samples)
            self.dataset = self.dataset.select(range(actual_len))

        dataset_len = len(self.dataset) if hasattr(self.dataset, '__len__') else 'streaming'
        logger.info(f"Total samples: {dataset_len}")

    def _format_instruction(self, sample: Dict, config: Dict) -> str:
        """Build the prompt text for one row according to its configured type.

        Returns "" when the row cannot be formatted (caller treats that as
        a skip signal).
        """
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                input_field = config.get('input_field', 'input')
                context_field = config.get('context_field', None)

                instruction = sample.get(instruction_field, '')
                input_text = sample.get(input_field, '')
                context = sample.get(context_field, '') if context_field else ''

                # Assemble the prompt sections in a fixed order.
                prompt_parts = [f"Instruction: {instruction}"]

                if context:
                    prompt_parts.append(f"Context: {context}")

                if input_text:
                    prompt_parts.append(f"Input: {input_text}")

                prompt_parts.append("Response:")
                return "\n".join(prompt_parts)

            elif data_type == 'conversation':
                # Dialogue rows may use the LLaVA 'conversations' layout or
                # the standard 'messages' layout; the final turn is excluded
                # because it is the training target.
                if 'conversations' in sample:
                    # LLaVA-style format.
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        dialogue = []
                        for conv in conversations[:-1]:
                            role = conv.get('from', 'user')
                            content = conv.get('value', '')
                            dialogue.append(f"{role}: {content}")
                        return "\n".join(dialogue) + "\nassistant:"

                elif 'messages' in sample:
                    # Standard chat-message format.
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        dialogue = []
                        for msg in messages[:-1]:
                            role = msg.get('role', 'user')
                            content = msg.get('content', '')
                            dialogue.append(f"{role}: {content}")
                        return "\n".join(dialogue) + "\nassistant:"

                # No recognized structure: fall back to a raw text field.
                return sample.get('text', '')

            elif data_type == 'code_instruction':
                # Code-instruction format uses a markdown-style template.
                instruction_field = config.get('instruction_field', 'instruction')
                instruction = sample.get(instruction_field, '')
                return f"### Instruction:\n{instruction}\n### Response:"

            elif data_type == 'multimodal_instruction':
                # Multimodal rows carry dialogue under the instruction field;
                # keep the history only (last turn is the target).
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    dialogue = []
                    for conv in conversations[:-1]:
                        role = conv.get('from', 'user')
                        content = conv.get('value', '')
                        dialogue.append(f"{role}: {content}")
                    return "\n".join(dialogue) + "\nassistant:"
                return ""

            else:
                # Unknown type: best-effort fallback on the instruction field.
                return sample.get(config.get('instruction_field', 'text'), '')
        except Exception as e:
            logger.debug(f"Error formatting instruction: {e}")
            return ""

    def _get_response(self, sample: Dict, config: Dict) -> str:
        """Extract the target response text for one row based on its type.

        Returns "" when no response can be found (caller skips the row).
        """
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction' or data_type == 'code_instruction':
                response_field = config.get('response_field', 'output')
                return sample.get(response_field, '')

            elif data_type == 'conversation':
                # Target is the last turn of the dialogue.
                if 'conversations' in sample:
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        return conversations[-1].get('value', '')

                elif 'messages' in sample:
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        return messages[-1].get('content', '')

                return ""

            elif data_type == 'multimodal_instruction':
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    return conversations[-1].get('value', '')
                return ""

            else:
                response_field = config.get('response_field', 'output')
                return sample.get(response_field, '')
        except Exception as e:
            logger.debug(f"Error getting response: {e}")
            return ""

    def __len__(self):
        # 0 for streaming datasets that expose no length.
        return len(self.dataset) if hasattr(self.dataset, '__len__') else 0

    def __getitem__(self, idx):
        """Return tokenized instruction/response tensors for one row, or None.

        Result keys: 'instruction', 'response', 'instruction_mask',
        'response_mask' (each a LongTensor of length max_length // 2),
        'task' (source dataset name), and 'modality_data' (image dict or None).
        """
        try:
            sample = self.dataset[idx]

            # Rows are expected to carry the config attached in __init__.
            if '_config' not in sample:
                logger.warning(f"Sample at index {idx} missing _config")
                return None

            config = sample['_config']

            # Format prompt and target; either being empty skips the row.
            instruction_text = self._format_instruction(sample, config)
            response_text = self._get_response(sample, config)

            if not instruction_text or not response_text:
                return None

            # Fall back to EOS when the tokenizer defines no pad token.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            # =======================================================
            # 1. Instruction half (no EOS — the response follows it)
            # =======================================================
            instruction_max_len = self.max_length // 2

            # Tokenize without padding; padding is applied manually below.
            instruction_enc = self.tokenizer(
                instruction_text,
                truncation=True,
                max_length=instruction_max_len,
                add_special_tokens=False,  # special tokens handled manually
                return_tensors='pt'
            )
            instr_ids = instruction_enc['input_ids'].squeeze(0)

            # Right-pad the instruction to the fixed half-length.
            instr_len = instr_ids.size(0)
            if instr_len < instruction_max_len:
                padding = torch.full((instruction_max_len - instr_len,), pad_token_id, dtype=torch.long)
                instr_ids = torch.cat([instr_ids, padding])

                # Mask: 1 for real tokens, 0 for padding.
                instr_mask = torch.cat([torch.ones(instr_len, dtype=torch.long), torch.zeros(instruction_max_len - instr_len, dtype=torch.long)])
            else:
                instr_mask = torch.ones(instruction_max_len, dtype=torch.long)

            # =======================================================
            # 2. Response half (EOS is mandatory so generation can stop)
            # =======================================================
            response_max_len = self.max_length // 2

            # Tokenize, reserving one slot for the EOS token.
            response_enc = self.tokenizer(
                response_text,
                truncation=True,
                max_length=response_max_len - 1,  # leave room for EOS
                add_special_tokens=False,
                return_tensors='pt'
            )
            resp_ids = response_enc['input_ids'].squeeze(0)

            # Always append EOS.
            eos_token = torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long)
            resp_ids = torch.cat([resp_ids, eos_token])

            # Right-pad the response to the fixed half-length.
            curr_resp_len = resp_ids.size(0)
            if curr_resp_len < response_max_len:
                padding = torch.full((response_max_len - curr_resp_len,), pad_token_id, dtype=torch.long)
                resp_ids = torch.cat([resp_ids, padding])

                # Mask: 1 for content + EOS, 0 for padding.
                resp_mask = torch.cat([torch.ones(curr_resp_len, dtype=torch.long), torch.zeros(response_max_len - curr_resp_len, dtype=torch.long)])
            else:
                resp_mask = torch.ones(response_max_len, dtype=torch.long)

            # =======================================================
            # 3. Assemble the result
            # =======================================================
            result = {
                'instruction': instr_ids,
                'response': resp_ids,
                'instruction_mask': instr_mask,
                'response_mask': resp_mask,
                'task': sample.get('_source', 'unknown'),
                'modality_data': None
            }

            # Attach the image tensor for multimodal instruction rows;
            # image errors are non-fatal (the text pair is still usable).
            if config.get('type') == 'multimodal_instruction' and 'image' in sample:
                try:
                    image = sample['image']
                    if isinstance(image, Image.Image):
                        image = image.convert('RGB')
                    image_tensor = image_transform(image)
                    result['modality_data'] = {'image': image_tensor}
                except Exception as e:
                    logger.debug(f"Error processing image: {e}")

            return result

        except Exception as e:
            logger.debug(f"Error getting item at index {idx}: {e}")
            import traceback
            traceback.print_exc()
            return None
602
-
603
-
604
class PreferenceDataset(Dataset):
    """Preference-pair dataset (chosen vs. rejected responses) for RLHF.

    Each item is a 4-tuple:
    (chosen_ids, rejected_ids, chosen_mask, rejected_mask), every tensor
    padded/truncated to ``max_length``. Unusable rows yield None.
    """

    def __init__(
        self,
        dataset_name: str = 'hh_rlhf',
        tokenizer=None,
        max_length: int = 1024,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length

        if dataset_name not in POSTTRAIN_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(POSTTRAIN_DATASETS.keys())}")

        ds_cfg = POSTTRAIN_DATASETS[dataset_name]
        if ds_cfg.get('type') != 'preference':
            raise ValueError(f"{dataset_name} is not a preference dataset (type: {ds_cfg.get('type')})")

        logger.info(f"Loading preference dataset: {dataset_name}")

        # Assemble load_dataset arguments; a sub-config name is optional.
        load_kwargs = {
            'path': ds_cfg['hf_path'],
            'split': split,
            'cache_dir': HF_CACHE_DIR,
        }
        if 'config' in ds_cfg:
            load_kwargs['name'] = ds_cfg['config']

        self.dataset = load_dataset(**load_kwargs)

        # Column names holding the preferred / dispreferred responses.
        self.chosen_field = ds_cfg.get('chosen_field', 'chosen')
        self.rejected_field = ds_cfg.get('rejected_field', 'rejected')

        # Optionally cap the number of pairs.
        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Loaded {len(self.dataset)} preference pairs")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize one chosen/rejected pair; return None if either side is empty."""
        try:
            row = self.dataset[idx]

            chosen_text = row.get(self.chosen_field, '')
            rejected_text = row.get(self.rejected_field, '')

            if not chosen_text or not rejected_text:
                return None

            def _encode(text):
                # Fixed-length tokenization shared by both sides of the pair.
                return self.tokenizer(
                    text,
                    max_length=self.max_length,
                    truncation=True,
                    padding='max_length',
                    return_tensors='pt'
                )

            chosen_enc = _encode(chosen_text)
            rejected_enc = _encode(rejected_text)

            return (
                chosen_enc['input_ids'].squeeze(0),
                rejected_enc['input_ids'].squeeze(0),
                chosen_enc['attention_mask'].squeeze(0),
                rejected_enc['attention_mask'].squeeze(0),
            )

        except Exception as e:
            logger.debug(f"Error getting preference item at index {idx}: {e}")
            return None
691
-
692
-
693
def collate_fn_v2(batch):
    """Collate SFT dicts, preference tuples, and multimodal data into a batch.

    None items are filtered out. Tuple items are treated as preference pairs
    (4-tuples carry attention masks; 2-tuples are the legacy format). Dict
    items are stacked field-by-field; non-tensor fields (e.g. 'task') are
    kept as plain lists, and image tensors are gathered under 'modality_data'.
    """
    batch = [item for item in batch if item is not None]

    if not batch:
        logger.warning("Empty batch after filtering None values")
        # Emit an empty placeholder batch rather than None.
        return {
            'input_ids': torch.empty(0),
            'attention_mask': torch.empty(0)
        }

    first = batch[0]

    # Preference data arrives as tuples: transpose into columns and stack.
    if isinstance(first, tuple):
        columns = list(zip(*batch))
        if len(first) == 4:  # includes attention masks
            return {
                'chosen': torch.stack(columns[0]),
                'rejected': torch.stack(columns[1]),
                'chosen_mask': torch.stack(columns[2]),
                'rejected_mask': torch.stack(columns[3]),
            }
        # Legacy two-element format.
        return {
            'chosen': torch.stack(columns[0]),
            'rejected': torch.stack(columns[1]),
        }

    # Regular dict batch.
    tensor_keys = {'instruction', 'response', 'instruction_mask',
                   'response_mask', 'input_ids', 'attention_mask'}
    collated = {}

    for key in first.keys():
        if key in tensor_keys:
            values = [item[key] for item in batch if item.get(key) is not None]
            collated[key] = torch.stack(values) if values else None
        elif key == 'modality_data':
            # Gather whatever multimodal payloads are present.
            present = [item[key] for item in batch if item.get(key) is not None]
            images = [m.get('image') for m in present if m and 'image' in m]
            collated[key] = {'image': torch.stack(images)} if images else None
        else:
            # Non-tensor metadata stays as a plain list.
            collated[key] = [item[key] for item in batch]

    return collated
752
-
753
-
754
def create_pretrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None
):
    """Build a streaming DataLoader over the given pre-training mix.

    Wraps a PreTrainDataset (always streaming) with collate_fn_v2.
    """
    ds = PreTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        streaming=True,
        max_samples=max_samples,
    )
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
    )
    return loader
776
-
777
-
778
def create_posttrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over the given post-training (SFT) mix.

    Wraps a PostTrainDataset with collate_fn_v2; the final partial batch
    is kept (drop_last=False) and host memory is pinned for GPU transfer.
    """
    ds = PostTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
        pin_memory=True,
        drop_last=False,
    )
    return loader
805
-
806
-
807
def create_preference_dataloader(
    dataset_name: str = 'hh_rlhf',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 1024,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader over a preference (RLHF) dataset.

    Wraps a PreferenceDataset with collate_fn_v2 and pinned host memory.
    """
    ds = PreferenceDataset(
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
        pin_memory=True,
    )
    return loader
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.utils.data import Dataset, DataLoader, IterableDataset
4
+ from datasets import load_dataset, concatenate_datasets, interleave_datasets
5
+ from typing import Dict, List, Optional, Any, Union
6
+ import random
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import warnings
10
+ from PIL import Image
11
+ import requests
12
+ from io import BytesIO
13
+ from torchvision import transforms
14
+ import logging
15
+
16
+ # 设置日志
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ warnings.filterwarnings("ignore", category=UserWarning)
21
+
22
+ from data_config import (
23
+ PRETRAIN_DATASETS,
24
+ POSTTRAIN_DATASETS,
25
+ TEST_DATASETS,
26
+ PRETRAIN_MIX,
27
+ POSTTRAIN_MIX,
28
+ PREPROCESSING_CONFIG,
29
+ DATASET_CACHE_DIR,
30
+ HF_CACHE_DIR
31
+ )
32
+
33
# Image preprocessing pipeline: resize to 224x224, convert to tensor, and
# normalize with the ImageNet mean/std used by torchvision pretrained backbones.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
39
+
40
class PreTrainDataset(IterableDataset):
    """Streaming pre-training dataset with weighted mixing of HF datasets.

    Each draw picks a source dataset according to the (normalized) mix
    weights, pulls the next raw sample from it, and converts it on the fly:
    text/code rows are tokenized to fixed length, image-text rows are
    additionally passed through ``image_transform``. Exhausted sources are
    restarted so sampling can continue indefinitely (up to ``max_samples``).
    """
    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        streaming: bool = True,
        seed: int = 42,
        max_samples: Optional[int] = None
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.streaming = streaming
        self.seed = seed
        self.max_samples = max_samples  # optional cap on yielded samples
        self.samples_generated = 0

        # Resolve the mix configuration (dataset names + sampling weights).
        if mix_name not in PRETRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(PRETRAIN_MIX.keys())}")

        mix_config = PRETRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading pretrain mix: {mix_name}")
        logger.info(f" Datasets: {dataset_names}")
        logger.info(f" Weights: {weights}")

        # Load each member dataset; failures are logged and skipped so one
        # bad dataset does not abort the whole mix.
        self.datasets = []
        self.probabilities = []

        for name, weight in zip(dataset_names, weights):
            if name not in PRETRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in PRETRAIN_DATASETS, skipping")
                continue

            config = PRETRAIN_DATASETS[name]
            try:
                ds = self._load_dataset(config)
                if ds is not None:
                    self.datasets.append((name, ds, config))
                    self.probabilities.append(weight)
                    logger.info(f" Successfully loaded {name}")
            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not self.datasets:
            raise ValueError("No datasets loaded successfully")

        # Normalize the surviving weights into a probability distribution.
        total = sum(self.probabilities)
        self.probabilities = [p / total for p in self.probabilities]

        logger.info(f"Successfully loaded {len(self.datasets)} datasets")

    def _load_dataset(self, config: Dict):
        """Load a single HF dataset described by *config*; return None on failure."""
        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'streaming': config.get('streaming', self.streaming),
                'cache_dir': HF_CACHE_DIR,
            }

            # Pass the dataset sub-config name when one is specified.
            if 'config' in config:
                load_kwargs['name'] = config['config']

            ds = load_dataset(**load_kwargs)
            return ds
        except Exception as e:
            logger.error(f"Failed to load {config.get('hf_path', 'unknown')}: {e}")
            return None

    def _process_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize a plain-text/code row; return None for unusable rows."""
        try:
            text_field = config.get('text_field', 'text')
            text = sample.get(text_field, '')

            if not text or not isinstance(text, str):
                return None

            text = text.strip()
            # Drop near-empty fragments.
            if len(text) < 10:
                return None

            # Fixed-length tokenization (pad/truncate to max_length).
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'type': 'text'
            }
        except Exception as e:
            logger.debug(f"Error processing text sample: {e}")
            return None

    def _process_image_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
        """Tokenize the caption and transform the image of an image-text row.

        Accepts a URL string (downloaded with a 5s timeout) or a PIL image;
        anything else is rejected. Returns None for unusable rows.
        """
        try:
            text_field = config.get('text_field', 'caption')
            image_field = config.get('image_field', 'image')

            text = sample.get(text_field, '')
            image = sample.get(image_field)

            if not text or image is None:
                return None

            # Normalize the image to an RGB PIL image.
            if isinstance(image, str):
                # URL: fetch with a timeout; skip the row on any fetch error.
                try:
                    response = requests.get(image, timeout=5)
                    image = Image.open(BytesIO(response.content)).convert('RGB')
                except Exception as img_error:
                    logger.debug(f"Failed to load image from URL: {img_error}")
                    return None
            elif isinstance(image, Image.Image):
                image = image.convert('RGB')
            else:
                return None

            # Resize/normalize to the model's expected input.
            image_tensor = image_transform(image)

            # Fixed-length tokenization of the caption.
            encoding = self.tokenizer(
                text,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'image': image_tensor,
                'type': 'image_text'
            }
        except Exception as e:
            logger.debug(f"Error processing image-text sample: {e}")
            return None

    def __iter__(self):
        """Yield processed samples, choosing a source per draw by mix weight."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Give each DataLoader worker a distinct seed so draws differ.
            random.seed(self.seed + worker_info.id)
            np.random.seed(self.seed + worker_info.id)
        else:
            random.seed(self.seed)
            np.random.seed(self.seed)

        # NOTE(review): with num_workers > 1 every worker iterates the same
        # underlying streams (no sharding here) — only the sampling order
        # differs. Confirm that duplicated data across workers is acceptable.
        iterators = [iter(ds) for _, ds, _ in self.datasets]
        self.samples_generated = 0

        while True:
            # Stop once the optional sample budget is exhausted.
            if self.max_samples and self.samples_generated >= self.max_samples:
                break

            try:
                # Pick a dataset index according to the mix probabilities.
                idx = np.random.choice(len(self.datasets), p=self.probabilities)
                name, _, config = self.datasets[idx]

                # Pull the next raw sample from the chosen source.
                sample = next(iterators[idx])

                # Dispatch on the configured sample type.
                processed = None
                if config.get('type') in ['text', 'code']:
                    processed = self._process_text_sample(sample, config)
                elif config.get('type') == 'image_text':
                    processed = self._process_image_text_sample(sample, config)
                else:
                    logger.debug(f"Unknown type: {config.get('type')}")
                    continue

                if processed is not None:
                    self.samples_generated += 1
                    yield processed

            except StopIteration:
                # Source exhausted: restart its iterator so mixing continues.
                try:
                    iterators[idx] = iter(self.datasets[idx][1])
                except Exception as e:
                    logger.error(f"Failed to recreate iterator for dataset {idx}: {e}")
                    break
            except Exception as e:
                logger.debug(f"Error in iterator: {e}")
                continue
252
+
253
+
254
class PostTrainDataset(Dataset):
    """Map-style SFT dataset built from an interleaved POSTTRAIN_MIX.

    Each item is a dict with fixed-length ``instruction`` / ``response``
    token-id tensors plus their attention masks, the source dataset name
    under ``task`` and optional ``modality_data`` (image tensor).

    Args:
        mix_name: key into POSTTRAIN_MIX selecting datasets and weights.
        tokenizer: HF tokenizer; must not be None.
        max_length: total token budget, split evenly between instruction
            and response halves.
        max_samples: optional cap on the merged dataset size.
        split: dataset split to load (e.g. 'train').
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 2048,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.split = split
        # Per-source preprocessing configs, resolved in __getitem__ via the
        # lightweight '_source' tag.  Previously the whole config dict was
        # embedded into every example via .map(), which is fragile: nested
        # heterogeneous dicts can break Arrow schema inference and
        # interleave_datasets feature matching.
        self._configs: Dict[str, Dict] = {}

        if mix_name not in POSTTRAIN_MIX:
            raise ValueError(f"Unknown mix: {mix_name}. Available: {list(POSTTRAIN_MIX.keys())}")

        mix_config = POSTTRAIN_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        if not dataset_names:
            raise ValueError(f"No datasets found in mix: {mix_name}")

        logger.info(f"Loading posttrain mix: {mix_name}")
        logger.info(f"  Datasets: {dataset_names}")

        all_datasets = []
        # Weight of each *successfully loaded* dataset.  BUGFIX: the old code
        # used weights[:len(all_datasets)], which misaligns weights with
        # datasets whenever an entry is skipped or fails to load, and can
        # divide by zero when the weights list is empty.
        loaded_weights = []

        for pos, name in enumerate(dataset_names):
            if name not in POSTTRAIN_DATASETS:
                logger.warning(f"Dataset {name} not found in POSTTRAIN_DATASETS")
                continue

            config = POSTTRAIN_DATASETS[name]
            try:
                load_kwargs = {
                    'path': config['hf_path'],
                    'split': split,
                    'streaming': config.get('streaming', False),
                    'cache_dir': HF_CACHE_DIR,
                }
                if 'data_files' in config:
                    load_kwargs['data_files'] = config['data_files']
                if 'config' in config:
                    load_kwargs['name'] = config['config']

                ds = load_dataset(**load_kwargs)

                # Optional per-dataset sample cap (streaming vs. map-style).
                if config.get('max_samples'):
                    if hasattr(ds, 'take'):
                        ds = ds.take(config['max_samples'])
                    elif hasattr(ds, 'select'):
                        ds = ds.select(range(min(len(ds), config['max_samples'])))

                # Tag each example with its source dataset name; bind `name`
                # as a default argument to avoid late-binding in the closure.
                def add_source(example, _name=name):
                    example['_source'] = _name
                    return example

                ds = ds.map(add_source)
                self._configs[name] = config
                all_datasets.append(ds)
                # Weight aligned to this loaded dataset; short/missing weight
                # lists default to 1.0.
                loaded_weights.append(float(weights[pos]) if pos < len(weights) else 1.0)

                ds_len = len(ds) if hasattr(ds, '__len__') else 'streaming'
                logger.info(f"  Loaded {name}: {ds_len} samples")

            except Exception as e:
                logger.error(f"Error loading {name}: {e}")
                continue

        if not all_datasets:
            raise ValueError("No datasets loaded successfully")

        if len(all_datasets) == 1:
            self.dataset = all_datasets[0]
        else:
            total = sum(loaded_weights)
            if total > 0:
                probabilities = [w / total for w in loaded_weights]
            else:
                # Degenerate weights: fall back to a uniform mix.
                probabilities = [1.0 / len(all_datasets)] * len(all_datasets)
            self.dataset = interleave_datasets(
                all_datasets,
                probabilities=probabilities,
                seed=42,
                stopping_strategy='all_exhausted'
            )

        # Cap the merged dataset size (map-style datasets only).
        if max_samples and hasattr(self.dataset, '__len__'):
            actual_len = min(len(self.dataset), max_samples)
            self.dataset = self.dataset.select(range(actual_len))

        dataset_len = len(self.dataset) if hasattr(self.dataset, '__len__') else 'streaming'
        logger.info(f"Total samples: {dataset_len}")

    def _format_instruction(self, sample: Dict, config: Dict) -> str:
        """Build the prompt text for *sample* according to its dataset type.

        Supports 'instruction', 'conversation', 'code_instruction' and
        'multimodal_instruction' layouts; returns '' when the sample cannot
        be formatted.
        """
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                input_field = config.get('input_field', 'input')
                context_field = config.get('context_field', None)

                instruction = sample.get(instruction_field, '')
                input_text = sample.get(input_field, '')
                context = sample.get(context_field, '') if context_field else ''

                # Assemble only the prompt sections that are present.
                prompt_parts = [f"Instruction: {instruction}"]
                if context:
                    prompt_parts.append(f"Context: {context}")
                if input_text:
                    prompt_parts.append(f"Input: {input_text}")
                prompt_parts.append("Response:")
                return "\n".join(prompt_parts)

            elif data_type == 'conversation':
                if 'conversations' in sample:
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        # All turns except the last form the prompt history.
                        dialogue = []
                        for conv in conversations[:-1]:
                            role = conv.get('from', 'user')
                            content = conv.get('value', '')
                            dialogue.append(f"{role}: {content}")
                        return "\n".join(dialogue) + "\nassistant:"

                elif 'messages' in sample:
                    # Standard chat-messages format.
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        dialogue = []
                        for msg in messages[:-1]:
                            role = msg.get('role', 'user')
                            content = msg.get('content', '')
                            dialogue.append(f"{role}: {content}")
                        return "\n".join(dialogue) + "\nassistant:"

                # No recognized conversation structure: fall back to raw text.
                return sample.get('text', '')

            elif data_type == 'code_instruction':
                instruction_field = config.get('instruction_field', 'instruction')
                instruction = sample.get(instruction_field, '')
                return f"### Instruction:\n{instruction}\n### Response:"

            elif data_type == 'multimodal_instruction':
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    # Dialogue history minus the final assistant reply.
                    dialogue = []
                    for conv in conversations[:-1]:
                        role = conv.get('from', 'user')
                        content = conv.get('value', '')
                        dialogue.append(f"{role}: {content}")
                    return "\n".join(dialogue) + "\nassistant:"
                return ""

            else:
                return sample.get(config.get('instruction_field', 'text'), '')
        except Exception as e:
            logger.debug(f"Error formatting instruction: {e}")
            return ""

    def _get_response(self, sample: Dict, config: Dict) -> str:
        """Extract the target response text for *sample*; '' when absent."""
        try:
            data_type = config.get('type', 'instruction')

            if data_type == 'instruction' or data_type == 'code_instruction':
                response_field = config.get('response_field', 'output')
                return sample.get(response_field, '')

            elif data_type == 'conversation':
                # The response is the final assistant turn of the dialogue.
                if 'conversations' in sample:
                    conversations = sample['conversations']
                    if isinstance(conversations, list) and len(conversations) > 0:
                        return conversations[-1].get('value', '')

                elif 'messages' in sample:
                    messages = sample['messages']
                    if isinstance(messages, list) and len(messages) > 0:
                        return messages[-1].get('content', '')

                return ""

            elif data_type == 'multimodal_instruction':
                instruction_field = config.get('instruction_field', 'conversations')
                conversations = sample.get(instruction_field, [])
                if isinstance(conversations, list) and len(conversations) > 0:
                    return conversations[-1].get('value', '')
                return ""

            else:
                response_field = config.get('response_field', 'output')
                return sample.get(response_field, '')
        except Exception as e:
            logger.debug(f"Error getting response: {e}")
            return ""

    def __len__(self):
        return len(self.dataset) if hasattr(self.dataset, '__len__') else 0

    def __getitem__(self, idx):
        """Return one formatted sample dict, or None when it cannot be built.

        None results are expected and filtered out by ``collate_fn_v2``.
        """
        try:
            sample = self.dataset[idx]

            # Resolve the preprocessing config from the source tag, with a
            # backward-compatible fallback for examples that still carry an
            # embedded '_config' dict.
            config = sample.get('_config') or self._configs.get(sample.get('_source'))
            if config is None:
                logger.warning(f"Sample at index {idx} missing _config")
                return None

            instruction_text = self._format_instruction(sample, config)
            response_text = self._get_response(sample, config)

            if not instruction_text or not response_text:
                return None

            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                # Tokenizers without a pad token (e.g. GPT-2) reuse EOS.
                pad_token_id = self.tokenizer.eos_token_id
            instruction_max_len = self.max_length // 2

            # Tokenize without padding; padding is applied manually below so
            # the attention masks can be built explicitly.
            instruction_enc = self.tokenizer(
                instruction_text,
                truncation=True,
                max_length=instruction_max_len,
                add_special_tokens=False,
                return_tensors='pt'
            )
            instr_ids = instruction_enc['input_ids'].squeeze(0)

            # Manual right-padding of the instruction half.
            instr_len = instr_ids.size(0)
            if instr_len < instruction_max_len:
                padding = torch.full((instruction_max_len - instr_len,), pad_token_id, dtype=torch.long)
                instr_ids = torch.cat([instr_ids, padding])
                instr_mask = torch.cat([torch.ones(instr_len, dtype=torch.long), torch.zeros(instruction_max_len - instr_len, dtype=torch.long)])
            else:
                instr_mask = torch.ones(instruction_max_len, dtype=torch.long)

            response_max_len = self.max_length // 2

            # Reserve one position for the EOS token appended below.
            response_enc = self.tokenizer(
                response_text,
                truncation=True,
                max_length=response_max_len - 1,
                add_special_tokens=False,
                return_tensors='pt'
            )
            resp_ids = response_enc['input_ids'].squeeze(0)

            eos_token = torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long)
            resp_ids = torch.cat([resp_ids, eos_token])

            # Manual right-padding of the response half.
            curr_resp_len = resp_ids.size(0)
            if curr_resp_len < response_max_len:
                padding = torch.full((response_max_len - curr_resp_len,), pad_token_id, dtype=torch.long)
                resp_ids = torch.cat([resp_ids, padding])
                resp_mask = torch.cat([torch.ones(curr_resp_len, dtype=torch.long), torch.zeros(response_max_len - curr_resp_len, dtype=torch.long)])
            else:
                resp_mask = torch.ones(response_max_len, dtype=torch.long)

            result = {
                'instruction': instr_ids,
                'response': resp_ids,
                'instruction_mask': instr_mask,
                'response_mask': resp_mask,
                'task': sample.get('_source', 'unknown'),
                'modality_data': None
            }

            # Attach the image tensor for multimodal samples (best effort).
            if config.get('type') == 'multimodal_instruction' and 'image' in sample:
                try:
                    image = sample['image']
                    if isinstance(image, Image.Image):
                        image = image.convert('RGB')
                        image_tensor = image_transform(image)
                        result['modality_data'] = {'image': image_tensor}
                except Exception as e:
                    logger.debug(f"Error processing image: {e}")

            return result

        except Exception as e:
            logger.debug(f"Error getting item at index {idx}: {e}", exc_info=True)
            return None
566
+
567
+
568
class PreferenceDataset(Dataset):
    """Pairwise (chosen, rejected) preference dataset for reward/DPO training.

    Each item is a 4-tuple of fixed-length tensors:
    ``(chosen_ids, rejected_ids, chosen_mask, rejected_mask)``, or None for
    unusable rows (filtered out later by the collate function).
    """

    def __init__(
        self,
        dataset_name: str = 'hh_rlhf',
        tokenizer=None,
        max_length: int = 1024,
        max_samples: Optional[int] = None,
        split: str = 'train'
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length

        if dataset_name not in POSTTRAIN_DATASETS:
            raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(POSTTRAIN_DATASETS.keys())}")

        config = POSTTRAIN_DATASETS[dataset_name]
        if config.get('type') != 'preference':
            raise ValueError(f"{dataset_name} is not a preference dataset (type: {config.get('type')})")

        logger.info(f"Loading preference dataset: {dataset_name}")

        # Assemble load_dataset arguments; the HF config name is optional.
        load_kwargs = dict(
            path=config['hf_path'],
            split=split,
            cache_dir=HF_CACHE_DIR,
        )
        if 'config' in config:
            load_kwargs['name'] = config['config']

        self.dataset = load_dataset(**load_kwargs)

        # Column names holding the preferred / dispreferred completions.
        self.chosen_field = config.get('chosen_field', 'chosen')
        self.rejected_field = config.get('rejected_field', 'rejected')

        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Loaded {len(self.dataset)} preference pairs")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            row = self.dataset[idx]

            chosen_text = row.get(self.chosen_field, '')
            rejected_text = row.get(self.rejected_field, '')

            # Both sides of the pair are required.
            if not chosen_text or not rejected_text:
                return None

            def encode(text):
                # Fixed-length encoding so pairs stack cleanly in a batch.
                return self.tokenizer(
                    text,
                    max_length=self.max_length,
                    truncation=True,
                    padding='max_length',
                    return_tensors='pt'
                )

            chosen_enc = encode(chosen_text)
            rejected_enc = encode(rejected_text)

            return (
                chosen_enc['input_ids'].squeeze(0),
                rejected_enc['input_ids'].squeeze(0),
                chosen_enc['attention_mask'].squeeze(0),
                rejected_enc['attention_mask'].squeeze(0)
            )

        except Exception as e:
            logger.debug(f"Error getting preference item at index {idx}: {e}")
            return None
653
+
654
+
655
def collate_fn_v2(batch):
    """Collate function tolerant of None samples and mixed modalities.

    Handles three shapes of item:
    - preference tuples (2 or 4 tensors) -> chosen/rejected dict,
    - dicts with tensor fields -> stacked tensors,
    - anything else -> collected into plain lists.
    """
    # Drop samples that upstream __getitem__ rejected.
    batch = [sample for sample in batch if sample is not None]

    if not batch:
        logger.warning("Empty batch after filtering None values")
        # Placeholder batch instead of None so training loops can skip it.
        return {
            'input_ids': torch.empty(0),
            'attention_mask': torch.empty(0)
        }

    first = batch[0]

    # Preference data arrives as tuples rather than dicts.
    if isinstance(first, tuple):
        if len(first) == 4:  # includes attention masks
            return {
                'chosen': torch.stack([row[0] for row in batch]),
                'rejected': torch.stack([row[1] for row in batch]),
                'chosen_mask': torch.stack([row[2] for row in batch]),
                'rejected_mask': torch.stack([row[3] for row in batch]),
            }
        return {
            'chosen': torch.stack([row[0] for row in batch]),
            'rejected': torch.stack([row[1] for row in batch]),
        }

    tensor_keys = {'instruction', 'response', 'instruction_mask',
                   'response_mask', 'input_ids', 'attention_mask'}
    collated = {}

    for key in first.keys():
        if key in tensor_keys:
            values = [row[key] for row in batch if row.get(key) is not None]
            collated[key] = torch.stack(values) if values else None
        elif key == 'modality_data':
            # Gather per-sample image tensors, if any sample carries one.
            present = [row[key] for row in batch if row.get(key) is not None]
            images = [entry.get('image') for entry in present
                      if entry and 'image' in entry]
            collated[key] = {'image': torch.stack(images)} if images else None
        else:
            # Non-tensor metadata (e.g. task names) stays as a list.
            collated[key] = [row[key] for row in batch]

    return collated
711
+
712
+
713
def create_pretrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None
):
    """Build a streaming DataLoader over a pre-training mix.

    The underlying IterableDataset handles shuffling/mixing itself, so no
    shuffle flag is passed to the DataLoader.
    """
    stream = PreTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        streaming=True,
        max_samples=max_samples,
    )
    loader = DataLoader(
        stream,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
    )
    return loader
734
+
735
+
736
def create_posttrain_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 2048,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a map-style DataLoader over a post-training (SFT) mix."""
    sft_dataset = PostTrainDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    loader = DataLoader(
        sft_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
        pin_memory=True,
        drop_last=False,
    )
    return loader
762
+
763
+
764
def create_preference_dataloader(
    dataset_name: str = 'hh_rlhf',
    tokenizer=None,
    batch_size: int = 8,
    num_workers: int = 4,
    max_length: int = 1024,
    max_samples: Optional[int] = None,
    split: str = 'train',
    shuffle: bool = True
):
    """Build a DataLoader of (chosen, rejected) preference pairs."""
    pref_dataset = PreferenceDataset(
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
        split=split,
    )
    loader = DataLoader(
        pref_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn_v2,
        pin_memory=True,
    )
    return loader