MultiModal / grpo_dataloader.py
szxllm's picture
Update grpo_dataloader.py
ddb2b53 verified
Raw
History Blame Contribute Delete
4.94 kB
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, interleave_datasets
from typing import Optional, List
import logging
import os
logger = logging.getLogger(__name__)
from data_config import (
GRPO_DATASETS,
GRPO_PROMPT_MIX,
HF_CACHE_DIR
)
class GRPOPromptDataset(Dataset):
    """Map-style dataset of tokenized prompts for GRPO training.

    The datasets named in ``GRPO_PROMPT_MIX[mix_name]['datasets']`` are looked
    up in ``GRPO_DATASETS`` and loaded with ``datasets.load_dataset``. When more
    than one loads successfully, they are mixed with
    ``datasets.interleave_datasets`` using the mix's ``'weights'``. Each item is
    the ``'prompt'`` field of a sample, tokenized and padded to ``max_length``.

    Args:
        mix_name: Key into ``GRPO_PROMPT_MIX``.
        tokenizer: Required callable tokenizer. It is called with the prompt
            text plus ``max_length``/``truncation``/``padding``/
            ``return_tensors``/``add_special_tokens`` keyword arguments. Its
            result must expose ``'input_ids'`` and ``'attention_mask'``.
        max_length: Length every prompt is truncated / padded to.
        max_samples: If truthy, keep only the first ``max_samples`` rows of
            the (possibly interleaved) dataset.

    Raises:
        ValueError: If ``tokenizer`` is None, ``mix_name`` is unknown, or no
            dataset of the mix could be loaded.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 512,
        max_samples: Optional[int] = None
    ):
        super().__init__()
        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")
        self.tokenizer = tokenizer
        self.max_length = max_length

        if mix_name not in GRPO_PROMPT_MIX:
            raise ValueError(
                f"Unknown mix: {mix_name}. "
                f"Available: {list(GRPO_PROMPT_MIX.keys())}"
            )
        mix_config = GRPO_PROMPT_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        all_datasets = []
        # Weight of each *successfully loaded* dataset, kept in lockstep with
        # ``all_datasets``. Pairing by position in the mix (rather than slicing
        # ``weights`` afterwards) keeps weights aligned with their datasets
        # when an earlier dataset is skipped. None marks "no weight given".
        loaded_weights = []
        for position, name in enumerate(dataset_names):
            ds = self._load_source(name)
            if ds is None:
                continue
            all_datasets.append(ds)
            loaded_weights.append(
                weights[position] if position < len(weights) else None
            )

        if not all_datasets:
            raise ValueError("No datasets loaded successfully")

        if len(all_datasets) == 1:
            self.dataset = all_datasets[0]
        else:
            probabilities = self._mix_probabilities(loaded_weights)
            if probabilities is None:
                logger.warning(
                    f"Missing or invalid weights for mix {mix_name}; "
                    f"using uniform sampling"
                )
                probabilities = [1.0 / len(all_datasets)] * len(all_datasets)
            self.dataset = interleave_datasets(
                all_datasets,
                probabilities=probabilities,
                seed=42,
                stopping_strategy='all_exhausted'
            )

        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))
        logger.info(f"Total prompts: {len(self.dataset)}")

    @staticmethod
    def _load_source(name):
        """Load one dataset of the mix; log and return None if it cannot be used."""
        if name not in GRPO_DATASETS:
            logger.warning(f"Dataset {name} not found")
            return None
        config = GRPO_DATASETS[name]

        data_files = config.get('data_files')
        # ``data_files`` may be a single path or a list of paths. Only
        # non-empty string entries are checked locally: os.path.exists raises
        # TypeError on a list, which previously aborted the whole mix.
        if isinstance(data_files, str):
            local_paths = [data_files] if data_files else []
        elif isinstance(data_files, (list, tuple)):
            local_paths = [p for p in data_files if isinstance(p, str) and p]
        else:
            local_paths = []
        for path in local_paths:
            if not os.path.exists(path):
                logger.error(f"Data file not found: {path}")
                return None

        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'cache_dir': HF_CACHE_DIR,
            }
            if 'data_files' in config:
                load_kwargs['data_files'] = data_files
            ds = load_dataset(**load_kwargs)
            if config.get('max_samples'):
                ds = ds.select(range(min(len(ds), config['max_samples'])))
            return ds
        except Exception as e:
            logger.error(f"Error loading {name}: {e}")
            return None

    @staticmethod
    def _mix_probabilities(weights):
        """Normalize per-dataset weights into sampling probabilities.

        Returns None when the weights cannot be used: the list is empty, a
        weight is missing (None), or the weights do not sum to a positive
        number.
        """
        if not weights or any(w is None for w in weights):
            return None
        total = float(sum(weights))
        if total <= 0:
            return None
        return [w / total for w in weights]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize the prompt stored at ``idx``.

        Returns a dict with ``'input_ids'`` and ``'attention_mask'`` tensors
        (the tokenizer's leading batch dimension of size 1 squeezed away) and
        the raw ``'prompt_text'``. Returns None when the prompt is empty or
        processing fails; ``grpo_collate_fn`` drops such items.
        """
        try:
            sample = self.dataset[idx]
            prompt = sample.get('prompt', '')
            if not prompt:
                logger.warning(f"Empty prompt at index {idx}")
                return None
            encoding = self.tokenizer(
                prompt,
                max_length=self.max_length,
                truncation=True,
                # Fixed-length padding lets grpo_collate_fn torch.stack samples.
                padding='max_length',
                return_tensors='pt',
                add_special_tokens=True
            )
            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'prompt_text': prompt
            }
        except Exception as e:
            # Best-effort: a bad sample is skipped rather than aborting the
            # epoch, but it is logged at WARNING so failures stay visible.
            logger.warning(f"Error processing sample {idx}: {e}")
            return None
def grpo_collate_fn(batch):
    """Collate tokenized prompt samples into a single batch.

    Samples rejected by the dataset (``None``) are dropped first. Returns
    ``None`` when nothing usable remains. Otherwise returns a dict with
    ``torch.stack``-ed ``'input_ids'`` / ``'attention_mask'`` tensors and the
    raw ``'prompt_texts'`` list, all in the original batch order.
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    if len(kept) == 0:
        return None

    def _column(key):
        # Gather one field across every surviving sample.
        return [sample[key] for sample in kept]

    return {
        'input_ids': torch.stack(_column('input_ids')),
        'attention_mask': torch.stack(_column('attention_mask')),
        'prompt_texts': _column('prompt_text'),
    }
def create_grpo_prompt_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 4,
    num_workers: int = 2,
    max_length: int = 512,
    max_samples: Optional[int] = None,
    shuffle: bool = True,
    pin_memory: Optional[bool] = None
):
    """Build a DataLoader over a ``GRPOPromptDataset`` prompt mix.

    Args:
        mix_name: Key into ``GRPO_PROMPT_MIX`` selecting the datasets to mix.
        tokenizer: Tokenizer passed to ``GRPOPromptDataset`` (required).
        batch_size: Number of samples per batch (before None-filtering).
        num_workers: DataLoader worker processes.
        max_length: Length every prompt is truncated / padded to.
        max_samples: Optional cap on the number of prompts in the dataset.
        shuffle: Whether to shuffle sample order each epoch.
        pin_memory: Whether the DataLoader pins host memory. ``None``
            (default) enables pinning only when CUDA is available. This keeps
            the previous behaviour on GPU hosts and avoids DataLoader's
            "no accelerator" warning on CPU-only hosts, where pinning is
            unused.

    Returns:
        A ``torch.utils.data.DataLoader`` whose batches come from
        ``grpo_collate_fn``. A batch is ``None`` if all of its samples were
        rejected.

    Raises:
        ValueError: Propagated from ``GRPOPromptDataset`` (e.g. missing
            tokenizer, unknown mix, no dataset loaded).
    """
    dataset = GRPOPromptDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples
    )
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=grpo_collate_fn,
        pin_memory=pin_memory,
        # Keep the final partial batch so no prompt is silently dropped.
        drop_last=False
    )