szxllm commited on
Commit
b7e4db5
·
verified ·
1 Parent(s): 6d0972d

Update data_loader.py

Browse files
Files changed (1) hide show
  1. data_loader.py +179 -241
data_loader.py CHANGED
@@ -12,8 +12,7 @@ import requests
12
  from io import BytesIO
13
  from torchvision import transforms
14
  import logging
15
-
16
- # 设置日志
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
@@ -29,8 +28,6 @@ from data_config import (
29
  DATASET_CACHE_DIR,
30
  HF_CACHE_DIR
31
  )
32
-
33
- # 图像变换
34
  image_transform = transforms.Compose([
35
  transforms.Resize((224, 224)),
36
  transforms.ToTensor(),
@@ -59,7 +56,6 @@ class PreTrainDataset(IterableDataset):
59
  self.max_samples = max_samples
60
  self.samples_generated = 0
61
 
62
- # 获取混合配置
63
  if mix_name not in PRETRAIN_MIX:
64
  raise ValueError(f"Unknown mix: {mix_name}. Available: {list(PRETRAIN_MIX.keys())}")
65
 
@@ -69,12 +65,6 @@ class PreTrainDataset(IterableDataset):
69
 
70
  if not dataset_names:
71
  raise ValueError(f"No datasets found in mix: {mix_name}")
72
-
73
- logger.info(f"Loading pretrain mix: {mix_name}")
74
- logger.info(f" Datasets: {dataset_names}")
75
- logger.info(f" Weights: {weights}")
76
-
77
- # 加载数据集
78
  self.datasets = []
79
  self.probabilities = []
80
 
@@ -97,7 +87,6 @@ class PreTrainDataset(IterableDataset):
97
  if not self.datasets:
98
  raise ValueError("No datasets loaded successfully")
99
 
100
- # 归一化概率
101
  total = sum(self.probabilities)
102
  self.probabilities = [p / total for p in self.probabilities]
103
 
@@ -111,14 +100,34 @@ class PreTrainDataset(IterableDataset):
111
  'streaming': config.get('streaming', self.streaming),
112
  'cache_dir': HF_CACHE_DIR,
113
  }
114
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  if 'config' in config:
116
  load_kwargs['name'] = config['config']
117
 
 
118
  ds = load_dataset(**load_kwargs)
119
  return ds
 
120
  except Exception as e:
121
  logger.error(f"Failed to load {config.get('hf_path', 'unknown')}: {e}")
 
 
122
  return None
123
 
124
  def _process_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
@@ -130,21 +139,25 @@ class PreTrainDataset(IterableDataset):
130
  return None
131
 
132
  text = text.strip()
133
- if len(text) < 10:
134
  return None
135
-
136
- # Tokenize
137
  encoding = self.tokenizer(
138
  text,
139
- max_length=self.max_length,
140
  truncation=True,
141
- padding='max_length',
142
- return_tensors='pt'
 
143
  )
144
 
 
 
 
 
145
  return {
146
- 'input_ids': encoding['input_ids'].squeeze(0),
147
- 'attention_mask': encoding['attention_mask'].squeeze(0),
148
  'type': 'text'
149
  }
150
  except Exception as e:
@@ -161,8 +174,6 @@ class PreTrainDataset(IterableDataset):
161
 
162
  if not text or image is None:
163
  return None
164
-
165
- # 处理图像
166
  if isinstance(image, str):
167
  try:
168
  response = requests.get(image, timeout=5)
@@ -174,11 +185,8 @@ class PreTrainDataset(IterableDataset):
174
  image = image.convert('RGB')
175
  else:
176
  return None
177
-
178
- # 转换图像
179
  image_tensor = image_transform(image)
180
-
181
- # Tokenize文本
182
  encoding = self.tokenizer(
183
  text,
184
  max_length=self.max_length,
@@ -198,34 +206,26 @@ class PreTrainDataset(IterableDataset):
198
  return None
199
 
200
  def __iter__(self):
201
- """迭代器"""
202
  worker_info = torch.utils.data.get_worker_info()
203
  if worker_info is not None:
204
- # 多worker时设置不同的随机种子
205
  random.seed(self.seed + worker_info.id)
206
  np.random.seed(self.seed + worker_info.id)
207
  else:
208
  random.seed(self.seed)
209
  np.random.seed(self.seed)
210
 
211
- # 创建数据集迭代器
212
  iterators = [iter(ds) for _, ds, _ in self.datasets]
213
  self.samples_generated = 0
214
 
215
  while True:
216
- # 检查是否达到最大样本数
217
  if self.max_samples and self.samples_generated >= self.max_samples:
218
  break
219
 
220
  try:
221
- # 根据概率选择数据集
222
  idx = np.random.choice(len(self.datasets), p=self.probabilities)
223
  name, _, config = self.datasets[idx]
224
-
225
- # 从选中的数据集获取样本
226
  sample = next(iterators[idx])
227
 
228
- # 处理样本
229
  processed = None
230
  if config.get('type') in ['text', 'code']:
231
  processed = self._process_text_sample(sample, config)
@@ -240,7 +240,6 @@ class PreTrainDataset(IterableDataset):
240
  yield processed
241
 
242
  except StopIteration:
243
- # 重新创建迭代器
244
  try:
245
  iterators[idx] = iter(self.datasets[idx][1])
246
  except Exception as e:
@@ -269,7 +268,6 @@ class PostTrainDataset(Dataset):
269
  self.max_length = max_length
270
  self.split = split
271
 
272
- # 获取混合配置
273
  if mix_name not in POSTTRAIN_MIX:
274
  raise ValueError(f"Unknown mix: {mix_name}. Available: {list(POSTTRAIN_MIX.keys())}")
275
 
@@ -283,7 +281,6 @@ class PostTrainDataset(Dataset):
283
  logger.info(f"Loading posttrain mix: {mix_name}")
284
  logger.info(f" Datasets: {dataset_names}")
285
 
286
- # 加载和合并数据集
287
  all_datasets = []
288
 
289
  for name in dataset_names:
@@ -306,14 +303,12 @@ class PostTrainDataset(Dataset):
306
 
307
  ds = load_dataset(**load_kwargs)
308
 
309
- # 限制样本数
310
  if config.get('max_samples'):
311
  if hasattr(ds, 'take'):
312
  ds = ds.take(config['max_samples'])
313
  elif hasattr(ds, 'select'):
314
  ds = ds.select(range(min(len(ds), config['max_samples'])))
315
 
316
- # 添加数据集标识
317
  def add_source(example):
318
  example['_source'] = name
319
  example['_config'] = config
@@ -329,14 +324,12 @@ class PostTrainDataset(Dataset):
329
  logger.error(f"Error loading {name}: {e}")
330
  continue
331
 
332
- # 合并数据集
333
  if not all_datasets:
334
  raise ValueError("No datasets loaded successfully")
335
 
336
  if len(all_datasets) == 1:
337
  self.dataset = all_datasets[0]
338
  else:
339
- # 交织数据集
340
  probabilities = [w / sum(weights[:len(all_datasets)])
341
  for w in weights[:len(all_datasets)]]
342
  self.dataset = interleave_datasets(
@@ -345,8 +338,6 @@ class PostTrainDataset(Dataset):
345
  seed=42,
346
  stopping_strategy='all_exhausted'
347
  )
348
-
349
- # 限制总样本数
350
  if max_samples and hasattr(self.dataset, '__len__'):
351
  actual_len = min(len(self.dataset), max_samples)
352
  self.dataset = self.dataset.select(range(actual_len))
@@ -355,7 +346,6 @@ class PostTrainDataset(Dataset):
355
  logger.info(f"Total samples: {dataset_len}")
356
 
357
  def _format_instruction(self, sample: Dict, config: Dict) -> str:
358
- """格式化instruction"""
359
  try:
360
  data_type = config.get('type', 'instruction')
361
 
@@ -367,8 +357,6 @@ class PostTrainDataset(Dataset):
367
  instruction = sample.get(instruction_field, '')
368
  input_text = sample.get(input_field, '')
369
  context = sample.get(context_field, '') if context_field else ''
370
-
371
- # 构建prompt
372
  prompt_parts = [f"Instruction: {instruction}"]
373
 
374
  if context:
@@ -385,40 +373,47 @@ class PostTrainDataset(Dataset):
385
  conversations = sample['conversations']
386
  if isinstance(conversations, list) and len(conversations) > 0:
387
  dialogue = []
388
- for conv in conversations[:-1]:
389
- role = conv.get('from', 'user')
390
- content = conv.get('value', '')
 
 
 
 
391
  dialogue.append(f"{role}: {content}")
392
  return "\n".join(dialogue) + "\nassistant:"
393
 
394
  elif 'messages' in sample:
395
- # 标准消息格式
396
  messages = sample['messages']
397
  if isinstance(messages, list) and len(messages) > 0:
398
  dialogue = []
399
- for msg in messages[:-1]:
 
 
 
 
400
  role = msg.get('role', 'user')
401
  content = msg.get('content', '')
402
  dialogue.append(f"{role}: {content}")
403
  return "\n".join(dialogue) + "\nassistant:"
404
 
405
- # 如果没有标准格式,尝试使用text字段
406
  return sample.get('text', '')
407
 
408
  elif data_type == 'code_instruction':
409
- # 代码instruction格式
410
  instruction_field = config.get('instruction_field', 'instruction')
411
  instruction = sample.get(instruction_field, '')
412
  return f"### Instruction:\n{instruction}\n### Response:"
413
 
414
  elif data_type == 'multimodal_instruction':
415
- # 多模态instruction
416
  instruction_field = config.get('instruction_field', 'conversations')
417
  conversations = sample.get(instruction_field, [])
418
  if isinstance(conversations, list) and len(conversations) > 0:
419
- # 提取对话历史(除了最后一条回复)
420
  dialogue = []
421
- for conv in conversations[:-1]:
 
 
 
 
422
  role = conv.get('from', 'user')
423
  content = conv.get('value', '')
424
  dialogue.append(f"{role}: {content}")
@@ -432,6 +427,7 @@ class PostTrainDataset(Dataset):
432
  return ""
433
 
434
  def _get_response(self, sample: Dict, config: Dict) -> str:
 
435
  try:
436
  data_type = config.get('type', 'instruction')
437
 
@@ -444,20 +440,33 @@ class PostTrainDataset(Dataset):
444
  if 'conversations' in sample:
445
  conversations = sample['conversations']
446
  if isinstance(conversations, list) and len(conversations) > 0:
447
- return conversations[-1].get('value', '')
448
-
 
 
 
 
 
 
 
 
 
449
  elif 'messages' in sample:
450
  messages = sample['messages']
451
  if isinstance(messages, list) and len(messages) > 0:
452
  return messages[-1].get('content', '')
453
-
454
  return ""
455
 
456
  elif data_type == 'multimodal_instruction':
457
  instruction_field = config.get('instruction_field', 'conversations')
458
  conversations = sample.get(instruction_field, [])
459
  if isinstance(conversations, list) and len(conversations) > 0:
460
- return conversations[-1].get('value', '')
 
 
 
 
 
461
  return ""
462
 
463
  else:
@@ -473,75 +482,47 @@ class PostTrainDataset(Dataset):
473
  def __getitem__(self, idx):
474
  try:
475
  sample = self.dataset[idx]
476
-
477
- # 获取配置
478
  if '_config' not in sample:
479
  logger.warning(f"Sample at index {idx} missing _config")
480
  return None
481
 
482
  config = sample['_config']
483
-
484
- # 格式化 instruction 和 response
485
  instruction_text = self._format_instruction(sample, config)
486
  response_text = self._get_response(sample, config)
487
 
488
  if not instruction_text or not response_text:
489
  return None
490
-
491
  pad_token_id = self.tokenizer.pad_token_id
492
  if pad_token_id is None:
493
  pad_token_id = self.tokenizer.eos_token_id
494
- instruction_max_len = self.max_length // 2
495
-
496
- # Tokenize 不做 padding,手动处理
497
  instruction_enc = self.tokenizer(
498
  instruction_text,
499
  truncation=True,
500
  max_length=instruction_max_len,
501
- add_special_tokens=False,
502
- return_tensors='pt'
503
  )
504
- instr_ids = instruction_enc['input_ids'].squeeze(0)
505
-
506
- # Instruction 手动 Padding
507
- instr_len = instr_ids.size(0)
508
- if instr_len < instruction_max_len:
509
- padding = torch.full((instruction_max_len - instr_len,), pad_token_id, dtype=torch.long)
510
- instr_ids = torch.cat([instr_ids, padding])
511
-
512
- instr_mask = torch.cat([torch.ones(instr_len, dtype=torch.long), torch.zeros(instruction_max_len - instr_len, dtype=torch.long)])
513
- else:
514
- instr_mask = torch.ones(instruction_max_len, dtype=torch.long)
515
 
516
- response_max_len = self.max_length // 2
517
 
518
- # Tokenize: 预留1个位置给EOS
519
  response_enc = self.tokenizer(
520
  response_text,
521
  truncation=True,
522
  max_length=response_max_len - 1,
523
  add_special_tokens=False,
524
- return_tensors='pt'
525
  )
526
- resp_ids = response_enc['input_ids'].squeeze(0)
527
-
528
- eos_token = torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long)
529
- resp_ids = torch.cat([resp_ids, eos_token])
530
-
531
- # Response 手动 Padding
532
- curr_resp_len = resp_ids.size(0)
533
- if curr_resp_len < response_max_len:
534
- padding = torch.full((response_max_len - curr_resp_len,), pad_token_id, dtype=torch.long)
535
- resp_ids = torch.cat([resp_ids, padding])
536
- resp_mask = torch.cat([torch.ones(curr_resp_len, dtype=torch.long), torch.zeros(response_max_len - curr_resp_len, dtype=torch.long)])
537
- else:
538
- resp_mask = torch.ones(response_max_len, dtype=torch.long)
539
-
540
  result = {
541
  'instruction': instr_ids,
542
  'response': resp_ids,
543
- 'instruction_mask': instr_mask,
544
- 'response_mask': resp_mask,
545
  'task': sample.get('_source', 'unknown'),
546
  'modality_data': None
547
  }
@@ -564,150 +545,93 @@ class PostTrainDataset(Dataset):
564
  traceback.print_exc()
565
  return None
566
 
 
567
 
568
- class PreferenceDataset(Dataset):
569
- def __init__(
570
- self,
571
- dataset_name: str = 'hh_rlhf',
572
- tokenizer=None,
573
- max_length: int = 1024,
574
- max_samples: Optional[int] = None,
575
- split: str = 'train'
576
- ):
577
- super().__init__()
578
-
579
- if tokenizer is None:
580
- raise ValueError("tokenizer cannot be None")
581
-
582
- self.tokenizer = tokenizer
583
- self.max_length = max_length
584
-
585
- if dataset_name not in POSTTRAIN_DATASETS:
586
- raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(POSTTRAIN_DATASETS.keys())}")
587
-
588
- config = POSTTRAIN_DATASETS[dataset_name]
589
- if config.get('type') != 'preference':
590
- raise ValueError(f"{dataset_name} is not a preference dataset (type: {config.get('type')})")
591
-
592
- logger.info(f"Loading preference dataset: {dataset_name}")
593
 
594
- load_kwargs = {
595
- 'path': config['hf_path'],
596
- 'split': split,
597
- 'cache_dir': HF_CACHE_DIR,
598
  }
599
-
600
- if 'config' in config:
601
- load_kwargs['name'] = config['config']
602
-
603
- self.dataset = load_dataset(**load_kwargs)
604
-
605
- self.chosen_field = config.get('chosen_field', 'chosen')
606
- self.rejected_field = config.get('rejected_field', 'rejected')
607
-
608
- if max_samples and len(self.dataset) > max_samples:
609
- self.dataset = self.dataset.select(range(max_samples))
610
-
611
- logger.info(f"Loaded {len(self.dataset)} preference pairs")
612
 
613
- def __len__(self):
614
- return len(self.dataset)
615
-
616
- def __getitem__(self, idx):
617
- try:
618
- sample = self.dataset[idx]
619
-
620
- chosen_text = sample.get(self.chosen_field, '')
621
- rejected_text = sample.get(self.rejected_field, '')
622
-
623
- if not chosen_text or not rejected_text:
624
- return None
625
-
626
- # Tokenize
627
- chosen_enc = self.tokenizer(
628
- chosen_text,
629
- max_length=self.max_length,
630
- truncation=True,
631
- padding='max_length',
632
- return_tensors='pt'
633
- )
634
-
635
- rejected_enc = self.tokenizer(
636
- rejected_text,
637
- max_length=self.max_length,
638
- truncation=True,
639
- padding='max_length',
640
- return_tensors='pt'
641
- )
642
-
643
- return (
644
- chosen_enc['input_ids'].squeeze(0),
645
- rejected_enc['input_ids'].squeeze(0),
646
- chosen_enc['attention_mask'].squeeze(0),
647
- rejected_enc['attention_mask'].squeeze(0)
648
- )
649
 
650
- except Exception as e:
651
- logger.debug(f"Error getting preference item at index {idx}: {e}")
652
  return None
653
-
654
-
655
- def collate_fn_v2(batch):
656
- batch = [item for item in batch if item is not None]
657
-
658
- if not batch:
659
- logger.warning("Empty batch after filtering None values")
660
- # 返回一个空的占位batch而不是None
661
- return {
662
- 'input_ids': torch.empty(0),
663
- 'attention_mask': torch.empty(0)
664
- }
665
-
666
- # 检查是否是preference数据
667
- if isinstance(batch[0], tuple):
668
- if len(batch[0]) == 4: # 包含attention_mask
669
- chosen = torch.stack([item[0] for item in batch])
670
- rejected = torch.stack([item[1] for item in batch])
671
- chosen_mask = torch.stack([item[2] for item in batch])
672
- rejected_mask = torch.stack([item[3] for item in batch])
673
- return {
674
- 'chosen': chosen,
675
- 'rejected': rejected,
676
- 'chosen_mask': chosen_mask,
677
- 'rejected_mask': rejected_mask
678
- }
679
- else:
680
- chosen = torch.stack([item[0] for item in batch])
681
- rejected = torch.stack([item[1] for item in batch])
682
- return {'chosen': chosen, 'rejected': rejected}
683
-
684
- keys = batch[0].keys()
685
- collated = {}
686
-
687
- for key in keys:
688
- if key in ['instruction', 'response', 'instruction_mask',
689
- 'response_mask', 'input_ids', 'attention_mask']:
690
- tensors = [item[key] for item in batch if item.get(key) is not None]
691
- if tensors:
692
- collated[key] = torch.stack(tensors)
693
- else:
694
- collated[key] = None
695
- elif key == 'modality_data':
696
- # 处理多模态数据
697
- modality_list = [item[key] for item in batch if item.get(key) is not None]
698
- if modality_list and any(m is not None for m in modality_list):
699
- # 收集图像
700
- images = [m.get('image') for m in modality_list if m and 'image' in m]
701
- if images:
702
- collated[key] = {'image': torch.stack(images)}
703
- else:
704
- collated[key] = None
705
  else:
706
- collated[key] = None
707
  else:
708
- collated[key] = [item[key] for item in batch]
709
 
710
- return collated
 
 
 
 
711
 
712
 
713
  def create_pretrain_dataloader(
@@ -718,18 +642,26 @@ def create_pretrain_dataloader(
718
  max_length: int = 2048,
719
  max_samples: Optional[int] = None
720
  ):
 
 
 
 
721
  dataset = PreTrainDataset(
722
  mix_name=mix_name,
723
  tokenizer=tokenizer,
724
  max_length=max_length,
725
- streaming=True,
726
  max_samples=max_samples
727
  )
 
 
 
728
  return DataLoader(
729
  dataset,
730
  batch_size=batch_size,
731
  num_workers=num_workers,
732
- collate_fn=collate_fn_v2
 
733
  )
734
 
735
 
@@ -743,6 +675,10 @@ def create_posttrain_dataloader(
743
  split: str = 'train',
744
  shuffle: bool = True
745
  ):
 
 
 
 
746
  dataset = PostTrainDataset(
747
  mix_name=mix_name,
748
  tokenizer=tokenizer,
@@ -750,14 +686,16 @@ def create_posttrain_dataloader(
750
  max_samples=max_samples,
751
  split=split
752
  )
 
 
753
  return DataLoader(
754
  dataset,
755
  batch_size=batch_size,
756
  shuffle=shuffle,
757
  num_workers=num_workers,
758
- collate_fn=collate_fn_v2,
759
  pin_memory=True,
760
- drop_last=False
761
  )
762
 
763
 
@@ -783,6 +721,6 @@ def create_preference_dataloader(
783
  batch_size=batch_size,
784
  shuffle=shuffle,
785
  num_workers=num_workers,
786
- collate_fn=collate_fn_v2,
787
  pin_memory=True
788
  )
 
12
  from io import BytesIO
13
  from torchvision import transforms
14
  import logging
15
+ import os
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
 
28
  DATASET_CACHE_DIR,
29
  HF_CACHE_DIR
30
  )
 
 
31
  image_transform = transforms.Compose([
32
  transforms.Resize((224, 224)),
33
  transforms.ToTensor(),
 
56
  self.max_samples = max_samples
57
  self.samples_generated = 0
58
 
 
59
  if mix_name not in PRETRAIN_MIX:
60
  raise ValueError(f"Unknown mix: {mix_name}. Available: {list(PRETRAIN_MIX.keys())}")
61
 
 
65
 
66
  if not dataset_names:
67
  raise ValueError(f"No datasets found in mix: {mix_name}")
 
 
 
 
 
 
68
  self.datasets = []
69
  self.probabilities = []
70
 
 
87
  if not self.datasets:
88
  raise ValueError("No datasets loaded successfully")
89
 
 
90
  total = sum(self.probabilities)
91
  self.probabilities = [p / total for p in self.probabilities]
92
 
 
100
  'streaming': config.get('streaming', self.streaming),
101
  'cache_dir': HF_CACHE_DIR,
102
  }
103
+ if 'data_files' in config:
104
+ files = config['data_files']
105
+ if isinstance(files, list):
106
+ for f in files:
107
+ if not os.path.exists(f):
108
+ logger.error(f" Data file not found in list: {f}")
109
+ return None
110
+ logger.info(f" Verified {len(files)} local files.")
111
+
112
+ elif isinstance(files, str):
113
+ if not os.path.exists(files):
114
+ logger.error(f" Data file not found: {files}")
115
+ return None
116
+ logger.info(f" Verified local file: {files}")
117
+
118
+ load_kwargs['data_files'] = files
119
+
120
  if 'config' in config:
121
  load_kwargs['name'] = config['config']
122
 
123
+ logger.info(f" Loading HF dataset: {config['hf_path']}...")
124
  ds = load_dataset(**load_kwargs)
125
  return ds
126
+
127
  except Exception as e:
128
  logger.error(f"Failed to load {config.get('hf_path', 'unknown')}: {e}")
129
+ import traceback
130
+ traceback.print_exc()
131
  return None
132
 
133
  def _process_text_sample(self, sample: Dict, config: Dict) -> Optional[Dict]:
 
139
  return None
140
 
141
  text = text.strip()
142
+ if len(text) < 10:
143
  return None
144
+ max_input_len = self.max_length - 1
145
+
146
  encoding = self.tokenizer(
147
  text,
148
+ max_length=max_input_len,
149
  truncation=True,
150
+ padding=False,
151
+ add_special_tokens=False,
152
+ return_tensors=None
153
  )
154
 
155
+ input_ids = encoding['input_ids']
156
+ input_ids.append(self.tokenizer.eos_token_id)
157
+ input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)
158
+
159
  return {
160
+ 'input_ids': input_ids_tensor,
 
161
  'type': 'text'
162
  }
163
  except Exception as e:
 
174
 
175
  if not text or image is None:
176
  return None
 
 
177
  if isinstance(image, str):
178
  try:
179
  response = requests.get(image, timeout=5)
 
185
  image = image.convert('RGB')
186
  else:
187
  return None
 
 
188
  image_tensor = image_transform(image)
189
+
 
190
  encoding = self.tokenizer(
191
  text,
192
  max_length=self.max_length,
 
206
  return None
207
 
208
  def __iter__(self):
 
209
  worker_info = torch.utils.data.get_worker_info()
210
  if worker_info is not None:
 
211
  random.seed(self.seed + worker_info.id)
212
  np.random.seed(self.seed + worker_info.id)
213
  else:
214
  random.seed(self.seed)
215
  np.random.seed(self.seed)
216
 
 
217
  iterators = [iter(ds) for _, ds, _ in self.datasets]
218
  self.samples_generated = 0
219
 
220
  while True:
 
221
  if self.max_samples and self.samples_generated >= self.max_samples:
222
  break
223
 
224
  try:
 
225
  idx = np.random.choice(len(self.datasets), p=self.probabilities)
226
  name, _, config = self.datasets[idx]
 
 
227
  sample = next(iterators[idx])
228
 
 
229
  processed = None
230
  if config.get('type') in ['text', 'code']:
231
  processed = self._process_text_sample(sample, config)
 
240
  yield processed
241
 
242
  except StopIteration:
 
243
  try:
244
  iterators[idx] = iter(self.datasets[idx][1])
245
  except Exception as e:
 
268
  self.max_length = max_length
269
  self.split = split
270
 
 
271
  if mix_name not in POSTTRAIN_MIX:
272
  raise ValueError(f"Unknown mix: {mix_name}. Available: {list(POSTTRAIN_MIX.keys())}")
273
 
 
281
  logger.info(f"Loading posttrain mix: {mix_name}")
282
  logger.info(f" Datasets: {dataset_names}")
283
 
 
284
  all_datasets = []
285
 
286
  for name in dataset_names:
 
303
 
304
  ds = load_dataset(**load_kwargs)
305
 
 
306
  if config.get('max_samples'):
307
  if hasattr(ds, 'take'):
308
  ds = ds.take(config['max_samples'])
309
  elif hasattr(ds, 'select'):
310
  ds = ds.select(range(min(len(ds), config['max_samples'])))
311
 
 
312
  def add_source(example):
313
  example['_source'] = name
314
  example['_config'] = config
 
324
  logger.error(f"Error loading {name}: {e}")
325
  continue
326
 
 
327
  if not all_datasets:
328
  raise ValueError("No datasets loaded successfully")
329
 
330
  if len(all_datasets) == 1:
331
  self.dataset = all_datasets[0]
332
  else:
 
333
  probabilities = [w / sum(weights[:len(all_datasets)])
334
  for w in weights[:len(all_datasets)]]
335
  self.dataset = interleave_datasets(
 
338
  seed=42,
339
  stopping_strategy='all_exhausted'
340
  )
 
 
341
  if max_samples and hasattr(self.dataset, '__len__'):
342
  actual_len = min(len(self.dataset), max_samples)
343
  self.dataset = self.dataset.select(range(actual_len))
 
346
  logger.info(f"Total samples: {dataset_len}")
347
 
348
  def _format_instruction(self, sample: Dict, config: Dict) -> str:
 
349
  try:
350
  data_type = config.get('type', 'instruction')
351
 
 
357
  instruction = sample.get(instruction_field, '')
358
  input_text = sample.get(input_field, '')
359
  context = sample.get(context_field, '') if context_field else ''
 
 
360
  prompt_parts = [f"Instruction: {instruction}"]
361
 
362
  if context:
 
373
  conversations = sample['conversations']
374
  if isinstance(conversations, list) and len(conversations) > 0:
375
  dialogue = []
376
+ last_role = conversations[-1].get('role', conversations[-1].get('from', 'user')).lower()
377
+ upto = len(conversations)
378
+ if last_role == 'assistant':
379
+ upto = len(conversations) - 1
380
+ for conv in conversations[:upto]:
381
+ role = conv.get('role', conv.get('from', 'user'))
382
+ content = conv.get('content', conv.get('value', ''))
383
  dialogue.append(f"{role}: {content}")
384
  return "\n".join(dialogue) + "\nassistant:"
385
 
386
  elif 'messages' in sample:
 
387
  messages = sample['messages']
388
  if isinstance(messages, list) and len(messages) > 0:
389
  dialogue = []
390
+ last_role = messages[-1].get('role', 'user').lower()
391
+ upto = len(messages)
392
+ if last_role == 'assistant':
393
+ upto = len(messages) - 1
394
+ for msg in messages[:upto]:
395
  role = msg.get('role', 'user')
396
  content = msg.get('content', '')
397
  dialogue.append(f"{role}: {content}")
398
  return "\n".join(dialogue) + "\nassistant:"
399
 
 
400
  return sample.get('text', '')
401
 
402
  elif data_type == 'code_instruction':
 
403
  instruction_field = config.get('instruction_field', 'instruction')
404
  instruction = sample.get(instruction_field, '')
405
  return f"### Instruction:\n{instruction}\n### Response:"
406
 
407
  elif data_type == 'multimodal_instruction':
 
408
  instruction_field = config.get('instruction_field', 'conversations')
409
  conversations = sample.get(instruction_field, [])
410
  if isinstance(conversations, list) and len(conversations) > 0:
 
411
  dialogue = []
412
+ last_role = conversations[-1].get('from', 'user').lower() if isinstance(conversations[-1].get('from', 'user'), str) else 'user'
413
+ upto = len(conversations)
414
+ if last_role == 'assistant':
415
+ upto = len(conversations) - 1
416
+ for conv in conversations[:upto]:
417
  role = conv.get('from', 'user')
418
  content = conv.get('value', '')
419
  dialogue.append(f"{role}: {content}")
 
427
  return ""
428
 
429
  def _get_response(self, sample: Dict, config: Dict) -> str:
430
+ """获取响应(兼容 <think>/<answer> 标签)"""
431
  try:
432
  data_type = config.get('type', 'instruction')
433
 
 
440
  if 'conversations' in sample:
441
  conversations = sample['conversations']
442
  if isinstance(conversations, list) and len(conversations) > 0:
443
+ last_turn = conversations[-1]
444
+ content = last_turn.get('content', last_turn.get('value', ''))
445
+ if not isinstance(content, str):
446
+ return ''
447
+
448
+
449
+ # 仅当最后一条 role 为 assistant 时返回
450
+ role = last_turn.get('role', last_turn.get('from', '')).lower()
451
+ if role != 'assistant':
452
+ return ''
453
+ return str(content).strip() if content else ""
454
  elif 'messages' in sample:
455
  messages = sample['messages']
456
  if isinstance(messages, list) and len(messages) > 0:
457
  return messages[-1].get('content', '')
 
458
  return ""
459
 
460
  elif data_type == 'multimodal_instruction':
461
  instruction_field = config.get('instruction_field', 'conversations')
462
  conversations = sample.get(instruction_field, [])
463
  if isinstance(conversations, list) and len(conversations) > 0:
464
+ last = conversations[-1].get('value', '')
465
+ import re
466
+ m = re.search(r'<answer>([\\s\\S]*?)</answer>', last, re.IGNORECASE)
467
+ if m:
468
+ return m.group(1).strip()
469
+ return re.sub(r'<think>[\\s\\S]*?</think>', '', last, flags=re.IGNORECASE).strip()
470
  return ""
471
 
472
  else:
 
482
  def __getitem__(self, idx):
483
  try:
484
  sample = self.dataset[idx]
 
 
485
  if '_config' not in sample:
486
  logger.warning(f"Sample at index {idx} missing _config")
487
  return None
488
 
489
  config = sample['_config']
 
 
490
  instruction_text = self._format_instruction(sample, config)
491
  response_text = self._get_response(sample, config)
492
 
493
  if not instruction_text or not response_text:
494
  return None
 
495
  pad_token_id = self.tokenizer.pad_token_id
496
  if pad_token_id is None:
497
  pad_token_id = self.tokenizer.eos_token_id
498
+
499
+ instruction_max_len = 256
500
+
501
  instruction_enc = self.tokenizer(
502
  instruction_text,
503
  truncation=True,
504
  max_length=instruction_max_len,
505
+ add_special_tokens=False,
506
+ return_tensors=None
507
  )
508
+ instr_ids_list = instruction_enc['input_ids']
509
+ instr_ids = torch.tensor(instr_ids_list, dtype=torch.long)
 
 
 
 
 
 
 
 
 
510
 
511
+ response_max_len = self.max_length - len(instr_ids)
512
 
 
513
  response_enc = self.tokenizer(
514
  response_text,
515
  truncation=True,
516
  max_length=response_max_len - 1,
517
  add_special_tokens=False,
518
+ return_tensors=None
519
  )
520
+ resp_ids_list = response_enc['input_ids']
521
+ resp_ids_list = resp_ids_list + [self.tokenizer.eos_token_id]
522
+ resp_ids = torch.tensor(resp_ids_list, dtype=torch.long)
 
 
 
 
 
 
 
 
 
 
 
523
  result = {
524
  'instruction': instr_ids,
525
  'response': resp_ids,
 
 
526
  'task': sample.get('_source', 'unknown'),
527
  'modality_data': None
528
  }
 
545
  traceback.print_exc()
546
  return None
547
 
548
+ from torch.nn.utils.rnn import pad_sequence
549
 
550
+ class DynamicCollate:
551
+ def __init__(self, pad_token_id: int):
552
+ self.pad_token_id = pad_token_id
553
+
554
+ def __call__(self, batch):
555
+ batch = [item for item in batch if item is not None]
556
+ if not batch:
557
+ return {
558
+ 'input_ids': torch.empty(0),
559
+ 'attention_mask': torch.empty(0)
560
+ }
561
+ input_ids_list = [item['input_ids'] for item in batch]
562
+ padded_input_ids = pad_sequence(
563
+ input_ids_list,
564
+ batch_first=True,
565
+ padding_value=self.pad_token_id
566
+ )
567
+ attention_mask = (padded_input_ids != self.pad_token_id).long()
 
 
 
 
 
 
 
568
 
569
+ return {
570
+ 'input_ids': padded_input_ids,
571
+ 'attention_mask': attention_mask
 
572
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
+ def collate_fn_v2_factory(pad_token_id: int):
575
+ def collate_fn_v2(batch):
576
+ batch = [item for item in batch if item is not None]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
 
578
+ if not batch:
579
+ logger.warning("Empty batch after filtering None values")
580
  return None
581
+
582
+ if isinstance(batch[0], tuple):
583
+ if len(batch[0]) == 4:
584
+ chosen = torch.stack([item[0] for item in batch])
585
+ rejected = torch.stack([item[1] for item in batch])
586
+ chosen_mask = torch.stack([item[2] for item in batch])
587
+ rejected_mask = torch.stack([item[3] for item in batch])
588
+ return {
589
+ 'chosen': chosen,
590
+ 'rejected': rejected,
591
+ 'chosen_mask': chosen_mask,
592
+ 'rejected_mask': rejected_mask
593
+ }
594
+ else:
595
+ chosen = torch.stack([item[0] for item in batch])
596
+ rejected = torch.stack([item[1] for item in batch])
597
+ return {'chosen': chosen, 'rejected': rejected}
598
+
599
+ collated = {}
600
+ instr_list = [item['instruction'] for item in batch if item.get('instruction') is not None]
601
+ if instr_list:
602
+ padded_instr = pad_sequence(instr_list, batch_first=True, padding_value=pad_token_id)
603
+ instr_mask = (padded_instr != pad_token_id).long()
604
+ collated['instruction'] = padded_instr
605
+ collated['instruction_mask'] = instr_mask
606
+ else:
607
+ collated['instruction'] = None
608
+ collated['instruction_mask'] = None
609
+
610
+ resp_list = [item['response'] for item in batch if item.get('response') is not None]
611
+ if resp_list:
612
+ padded_resp = pad_sequence(resp_list, batch_first=True, padding_value=pad_token_id)
613
+ resp_mask = (padded_resp != pad_token_id).long()
614
+ collated['response'] = padded_resp
615
+ collated['response_mask'] = resp_mask
616
+ else:
617
+ collated['response'] = None
618
+ collated['response_mask'] = None
619
+
620
+ modality_list = [item.get('modality_data') for item in batch if item.get('modality_data') is not None]
621
+ if modality_list and any(m is not None for m in modality_list):
622
+ images = [m.get('image') for m in modality_list if m and 'image' in m]
623
+ if images:
624
+ collated['modality_data'] = {'image': torch.stack(images)}
 
 
 
 
 
 
 
 
625
  else:
626
+ collated['modality_data'] = None
627
  else:
628
+ collated['modality_data'] = None
629
 
630
+ collated['task'] = [item.get('task', 'unknown') for item in batch]
631
+
632
+ return collated
633
+
634
+ return collate_fn_v2
635
 
636
 
637
  def create_pretrain_dataloader(
 
642
  max_length: int = 2048,
643
  max_samples: Optional[int] = None
644
  ):
645
+
646
+ if tokenizer.pad_token_id is None:
647
+ tokenizer.pad_token_id = tokenizer.eos_token_id
648
+
649
  dataset = PreTrainDataset(
650
  mix_name=mix_name,
651
  tokenizer=tokenizer,
652
  max_length=max_length,
653
+ streaming=True,
654
  max_samples=max_samples
655
  )
656
+
657
+ collate_fn = DynamicCollate(pad_token_id=tokenizer.pad_token_id)
658
+
659
  return DataLoader(
660
  dataset,
661
  batch_size=batch_size,
662
  num_workers=num_workers,
663
+ collate_fn=collate_fn,
664
+ pin_memory=True
665
  )
666
 
667
 
 
675
  split: str = 'train',
676
  shuffle: bool = True
677
  ):
678
+
679
+ if tokenizer.pad_token_id is None:
680
+ tokenizer.pad_token_id = tokenizer.eos_token_id
681
+
682
  dataset = PostTrainDataset(
683
  mix_name=mix_name,
684
  tokenizer=tokenizer,
 
686
  max_samples=max_samples,
687
  split=split
688
  )
689
+ collate_fn = collate_fn_v2_factory(pad_token_id=tokenizer.pad_token_id)
690
+
691
  return DataLoader(
692
  dataset,
693
  batch_size=batch_size,
694
  shuffle=shuffle,
695
  num_workers=num_workers,
696
+ collate_fn=collate_fn,
697
  pin_memory=True,
698
+ drop_last=False
699
  )
700
 
701
 
 
721
  batch_size=batch_size,
722
  shuffle=shuffle,
723
  num_workers=num_workers,
724
+ collate_fn=collate_fn_v2_factory(pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id),
725
  pin_memory=True
726
  )