| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from datasets import load_dataset, interleave_datasets |
| from typing import Optional, List |
| import logging |
| import os |
|
|
| logger = logging.getLogger(__name__) |
|
|
| from data_config import ( |
| GRPO_DATASETS, |
| GRPO_PROMPT_MIX, |
| HF_CACHE_DIR |
| ) |
|
|
|
|
class GRPOPromptDataset(Dataset):
    """Map-style dataset of tokenized prompts built from a named GRPO mix.

    The mix named by ``mix_name`` is looked up in ``GRPO_PROMPT_MIX``; every
    dataset it lists is resolved through ``GRPO_DATASETS`` and loaded with
    ``load_dataset``.  Sources that are unknown, that point at a missing
    local data file, or that fail to load are logged and skipped.  When more
    than one source loads, they are interleaved (seed 42, ``all_exhausted``)
    using the mix weights of the sources that actually loaded.

    Args:
        mix_name: Key into ``GRPO_PROMPT_MIX``.
        tokenizer: Callable tokenizer; invoked with the prompt text plus
            ``max_length``, ``truncation``, ``padding``, ``return_tensors``
            and ``add_special_tokens`` keyword arguments.
        max_length: Length every prompt is padded/truncated to.
        max_samples: Optional cap on the final dataset; the first
            ``max_samples`` rows are kept.

    Raises:
        ValueError: If ``tokenizer`` is None, ``mix_name`` is unknown, or no
            source could be loaded.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 512,
        max_samples: Optional[int] = None
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length

        if mix_name not in GRPO_PROMPT_MIX:
            raise ValueError(
                f"Unknown mix: {mix_name}. "
                f"Available: {list(GRPO_PROMPT_MIX.keys())}"
            )

        mix_config = GRPO_PROMPT_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        # Remember the mix position of every source that loads so its weight
        # can be looked up later; skipped sources must not shift the weights.
        all_datasets = []
        kept_positions = []
        for position, name in enumerate(dataset_names):
            ds = self._load_source(name)
            if ds is None:
                continue
            all_datasets.append(ds)
            kept_positions.append(position)

        if not all_datasets:
            raise ValueError("No datasets loaded successfully")

        if len(all_datasets) == 1:
            self.dataset = all_datasets[0]
        else:
            if len(weights) < len(dataset_names):
                logger.warning(
                    f"Mix {mix_name} has {len(weights)} weights for "
                    f"{len(dataset_names)} datasets; using uniform sampling"
                )
            self.dataset = interleave_datasets(
                all_datasets,
                probabilities=self._mix_probabilities(weights, kept_positions),
                seed=42,
                stopping_strategy='all_exhausted'
            )

        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Total prompts: {len(self.dataset)}")

    @staticmethod
    def _mix_probabilities(weights, kept_positions):
        """Return normalized sampling probabilities for the loaded sources.

        ``kept_positions`` holds the indices (into the mix's dataset list) of
        the sources that loaded; the weight at each of those indices is
        selected and the selection is normalized to sum to 1.  If any selected
        weight is missing, non-numeric or negative, or the total is not
        positive, a uniform distribution is returned instead.
        """
        count = len(kept_positions)
        if count == 0:
            return []

        chosen = [
            weights[pos] if pos < len(weights) else None
            for pos in kept_positions
        ]
        usable = all(
            isinstance(w, (int, float)) and w >= 0 for w in chosen
        )
        total = sum(chosen) if usable else 0
        if not total > 0:
            return [1.0 / count] * count
        return [w / total for w in chosen]

    @staticmethod
    def _load_source(name):
        """Load one configured source; return None when it must be skipped."""
        if name not in GRPO_DATASETS:
            logger.warning(f"Dataset {name} not found")
            return None

        config = GRPO_DATASETS[name]

        # Only a single path can be checked on disk; lists/dicts of files are
        # handed to load_dataset unchanged instead of crashing os.path.exists.
        data_file = config.get('data_files')
        if (
            data_file
            and isinstance(data_file, (str, os.PathLike))
            and not os.path.exists(data_file)
        ):
            logger.error(f"Data file not found: {data_file}")
            return None

        # Best-effort boundary: one broken source must not abort the mix.
        try:
            load_kwargs = {
                'path': config['hf_path'],
                'split': config.get('split', 'train'),
                'cache_dir': HF_CACHE_DIR,
            }

            if 'data_files' in config:
                load_kwargs['data_files'] = config['data_files']

            ds = load_dataset(**load_kwargs)
            if config.get('max_samples'):
                ds = ds.select(range(min(len(ds), config['max_samples'])))
        except Exception as e:
            logger.error(f"Error loading {name}: {e}")
            return None

        return ds

    def __len__(self):
        """Return the number of rows in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize the prompt at ``idx``.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (leading batch
            dimension squeezed away) and the raw ``prompt_text``, or None
            when the row has no prompt or cannot be processed.

        Raises:
            IndexError: If ``idx`` is out of range, so that plain iteration
                over the dataset terminates.
        """
        try:
            sample = self.dataset[idx]
        except IndexError:
            # Must propagate: swallowing it makes ``for x in dataset`` endless.
            raise
        except Exception as e:
            logger.warning(f"Error processing sample {idx}: {e}")
            return None

        try:
            prompt = sample.get('prompt', '')

            if not prompt:
                logger.warning(f"Empty prompt at index {idx}")
                return None

            encoding = self.tokenizer(
                prompt,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt',
                add_special_tokens=True
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'prompt_text': prompt
            }
        except Exception as e:
            # Best-effort: a bad row is dropped (the collate function filters
            # None), but surfaced at warning level rather than hidden.
            logger.warning(f"Error processing sample {idx}: {e}")
            return None
|
|
|
|
def grpo_collate_fn(batch):
    """Collate dataset items into one batch, ignoring entries that are None.

    Returns a dict with stacked ``input_ids`` and ``attention_mask`` tensors
    and a ``prompt_texts`` list, or None when no usable entry remains.
    """
    usable = list(filter(lambda entry: entry is not None, batch))
    if len(usable) == 0:
        return None

    def column(key):
        # Gather one field across every surviving entry, in batch order.
        return [entry[key] for entry in usable]

    collated = {}
    collated['input_ids'] = torch.stack(column('input_ids'))
    collated['attention_mask'] = torch.stack(column('attention_mask'))
    collated['prompt_texts'] = column('prompt_text')
    return collated
|
|
|
|
def create_grpo_prompt_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 4,
    num_workers: int = 2,
    max_length: int = 512,
    max_samples: Optional[int] = None,
    shuffle: bool = True
):
    """Build a ``DataLoader`` over a ``GRPOPromptDataset`` for ``mix_name``.

    The dataset receives ``mix_name``, ``tokenizer``, ``max_length`` and
    ``max_samples``; the loader receives ``batch_size``, ``shuffle`` and
    ``num_workers``, collates with ``grpo_collate_fn``, pins memory and
    keeps the final partial batch.
    """
    dataset_options = {
        'mix_name': mix_name,
        'tokenizer': tokenizer,
        'max_length': max_length,
        'max_samples': max_samples,
    }
    prompt_source = GRPOPromptDataset(**dataset_options)

    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': grpo_collate_fn,
        'pin_memory': True,
        'drop_last': False,
    }
    return DataLoader(prompt_source, **loader_options)