"""Prompt dataset, collate function and DataLoader factory for GRPO prompts.

Dataset sources and mixes are read from ``GRPO_DATASETS`` and
``GRPO_PROMPT_MIX`` in ``data_config``.
"""

import logging
import os
from typing import Optional, List

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, interleave_datasets

from data_config import (
    GRPO_DATASETS,
    GRPO_PROMPT_MIX,
    HF_CACHE_DIR,
)

logger = logging.getLogger(__name__)


class GRPOPromptDataset(Dataset):
    """Map-style dataset yielding tokenized prompts from a named mix.

    Each entry of the mix is loaded with ``datasets.load_dataset``; when more
    than one source loads successfully they are combined with
    ``interleave_datasets`` using the mix weights as sampling probabilities.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 512,
        max_samples: Optional[int] = None
    ):
        """Load and combine the datasets of ``mix_name``.

        Args:
            mix_name: Key into ``GRPO_PROMPT_MIX``.
            tokenizer: Callable tokenizer applied to each prompt; required.
            max_length: Length prompts are truncated/padded to.
            max_samples: Optional cap on the size of the combined dataset.

        Raises:
            ValueError: If ``tokenizer`` is None, ``mix_name`` is unknown, or
                no dataset of the mix could be loaded.
        """
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length

        if mix_name not in GRPO_PROMPT_MIX:
            raise ValueError(
                f"Unknown mix: {mix_name}. "
                f"Available: {list(GRPO_PROMPT_MIX.keys())}"
            )

        mix_config = GRPO_PROMPT_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        # Keep every loaded dataset paired with ITS OWN weight: sources that
        # are skipped must not shift the weights of the ones that follow.
        all_datasets = []
        loaded_weights = []
        for position, name in enumerate(dataset_names):
            ds = self._load_source(name)
            if ds is None:
                continue
            all_datasets.append(ds)
            loaded_weights.append(
                weights[position] if position < len(weights) else None
            )

        if not all_datasets:
            raise ValueError("No datasets loaded successfully")

        if len(all_datasets) == 1:
            self.dataset = all_datasets[0]
        else:
            self.dataset = interleave_datasets(
                all_datasets,
                probabilities=self._mix_probabilities(loaded_weights),
                seed=42,
                stopping_strategy='all_exhausted'
            )

        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Total prompts: {len(self.dataset)}")

    @staticmethod
    def _load_source(name):
        """Load one configured dataset; return None (after logging) on failure."""
        if name not in GRPO_DATASETS:
            logger.warning(f"Dataset {name} not found")
            return None

        config = GRPO_DATASETS[name]

        # ``data_files`` may also be a list or dict of paths; only a single
        # path can be checked with os.path.exists without raising TypeError.
        data_file = config.get('data_files')
        if (
            data_file
            and isinstance(data_file, (str, os.PathLike))
            and not os.path.exists(data_file)
        ):
            logger.error(f"Data file not found: {data_file}")
            return None

        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'cache_dir': HF_CACHE_DIR,
            }
            if 'data_files' in config:
                load_kwargs['data_files'] = config['data_files']

            ds = load_dataset(**load_kwargs)

            if config.get('max_samples'):
                ds = ds.select(range(min(len(ds), config['max_samples'])))
            return ds
        except Exception as e:
            # Best-effort: one broken source must not prevent the others
            # from loading.
            logger.error(f"Error loading {name}: {e}")
            return None

    @staticmethod
    def _mix_probabilities(weights):
        """Normalize ``weights`` to probabilities, falling back to uniform."""
        count = len(weights)
        uniform = [1.0 / count] * count

        if any(w is None for w in weights):
            logger.warning(
                "Missing weights for some loaded datasets; "
                "using uniform probabilities"
            )
            return uniform

        total = sum(weights)
        if total <= 0:
            logger.warning(
                "Dataset weights do not sum to a positive value; "
                "using uniform probabilities"
            )
            return uniform

        return [w / total for w in weights]

    def __len__(self):
        """Return the number of prompts in the combined dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the tokenized prompt at ``idx``.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` (the leading
            dimension added by ``return_tensors='pt'`` is squeezed away) and
            ``prompt_text``; or None when the prompt is empty or cannot be
            processed, so the collate function can drop it.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        try:
            sample = self.dataset[idx]
        except IndexError:
            # Must propagate: Python's sequence iteration protocol relies on
            # IndexError to stop, otherwise ``for x in dataset`` never ends.
            raise
        except Exception as e:
            logger.warning(f"Error processing sample {idx}: {e}")
            return None

        try:
            prompt = sample.get('prompt', '')

            if not prompt:
                logger.warning(f"Empty prompt at index {idx}")
                return None

            encoding = self.tokenizer(
                prompt,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt',
                add_special_tokens=True
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'prompt_text': prompt
            }
        except Exception as e:
            # Deliberately best-effort, but logged at warning level so that
            # dropped samples are visible in normal logs.
            logger.warning(f"Error processing sample {idx}: {e}")
            return None


def grpo_collate_fn(batch):
    """Stack the valid items of ``batch`` into batched tensors.

    Items that are None (samples the dataset could not produce) are dropped.

    Returns:
        A dict with stacked ``input_ids`` and ``attention_mask`` tensors and
        the list ``prompt_texts``, or None when no valid item remains;
        consumers of the DataLoader must skip None batches.
    """
    batch = [item for item in batch if item is not None]

    if not batch:
        return None

    return {
        'input_ids': torch.stack([item['input_ids'] for item in batch]),
        'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
        'prompt_texts': [item['prompt_text'] for item in batch]
    }


def create_grpo_prompt_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 4,
    num_workers: int = 2,
    max_length: int = 512,
    max_samples: Optional[int] = None,
    shuffle: bool = True
):
    """Build a DataLoader over a ``GRPOPromptDataset``.

    Args:
        mix_name: Key into ``GRPO_PROMPT_MIX``.
        tokenizer: Tokenizer forwarded to the dataset; required.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        max_length: Length prompts are truncated/padded to.
        max_samples: Optional cap on the dataset size.
        shuffle: Whether the DataLoader shuffles samples.

    Returns:
        A DataLoader using ``grpo_collate_fn``; a batch is None when every
        sample in it was dropped.

    Raises:
        ValueError: Propagated from ``GRPOPromptDataset`` construction.
    """
    dataset = GRPOPromptDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=grpo_collate_fn,
        pin_memory=True,
        drop_last=False
    )