"""RL dataset for HuggingFace datasets."""

from typing import Dict

import torch

from taoTrain.config import TrainingConfig
from taoTrain.data.hf_base import BaseHFDataset


class RLDataset(BaseHFDataset):
    """Dataset for RL training with prompts.

    Only prompts are prepared here. No responses or ``labels`` are produced
    by this class; ``__getitem__`` returns just ``input_ids`` and
    ``attention_mask``.
    """

    # Number of rows passed to the tokenizer per ``Dataset.map`` batch.
    # Class attribute so subclasses (or instances) can tune it; the default
    # matches the previously hard-coded value.
    tokenize_batch_size: int = 100

    def _preprocess(self):
        """Build a ``prompt`` column from ``self.data`` and tokenize it.

        The prompt source is ``self.config.dataset.prompt_column`` when it is
        set (a missing column raises ``KeyError``). Otherwise it is
        ``self.config.dataset.text_column``, where a missing column yields "".
        A ``None`` cell value is replaced by "" in both cases so the tokenizer
        does not reject the batch.

        After this call ``self.data`` holds only the tokenizer outputs (the
        ``prompt`` text column is removed by the tokenization ``map``).
        """
        dataset_config = self.config.dataset
        prompt_column = dataset_config.prompt_column

        if prompt_column:
            # Use the explicitly configured prompt column (strict lookup).
            def extract_prompt(example):
                value = example[prompt_column]
                return {"prompt": "" if value is None else value}

            desc = "Extracting prompts..."
        else:
            # Fall back to the general text column as the prompt.
            text_column = dataset_config.text_column

            def extract_prompt(example):
                value = example.get(text_column, "")
                return {"prompt": "" if value is None else value}

            desc = "Preparing prompts..."

        self.data = self.data.map(
            extract_prompt,
            remove_columns=self.data.column_names,
            desc=desc,
        )

        # Tokenize prompts, truncating/padding every row to max_seq_length.
        # NOTE(review): for decoder-only generation the tokenizer should
        # presumably use left padding (tokenizer.padding_side == "left");
        # that is configured outside this class — confirm.
        def tokenize_function(examples):
            return self.tokenizer(
                examples["prompt"],
                truncation=True,
                max_length=self.config.model.max_seq_length,
                padding="max_length",
                return_attention_mask=True,
            )

        self.data = self.data.map(
            tokenize_function,
            batched=True,
            batch_size=self.tokenize_batch_size,
            remove_columns=self.data.column_names,
            desc="Tokenizing prompts...",
        )

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the tokenized prompt at ``idx`` as ``torch.long`` tensors.

        Only ``input_ids`` and ``attention_mask`` are returned. No ``labels``
        key is included.
        """
        item = self.data[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(item["attention_mask"], dtype=torch.long),
        }