StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""RL dataset for HuggingFace datasets."""
from typing import Dict
import torch
from taoTrain.config import TrainingConfig
from taoTrain.data.hf_base import BaseHFDataset
class RLDataset(BaseHFDataset):
    """Dataset for RL training with prompts.

    Each item holds only a tokenized prompt (``input_ids`` and
    ``attention_mask``). Responses are generated by the model during
    training, so no ``labels`` are produced here.
    """

    def _preprocess(self):
        """Build a ``prompt`` column from the raw data and tokenize it.

        The prompt is read from ``config.dataset.prompt_column`` when that is
        set (a missing column raises ``KeyError``). Otherwise it is read from
        ``config.dataset.text_column``, falling back to ``""`` when the column
        is absent. ``None`` cell values are mapped to ``""`` too, so the
        tokenizer never receives a null input. All original columns are
        dropped, and prompts are truncated/padded to
        ``config.model.max_seq_length``.
        """
        dataset_config = self.config.dataset

        # For RL we only need prompts (no responses); the responses are
        # generated by the model during training.
        if dataset_config.prompt_column:
            # Use the explicitly configured prompt column (strict lookup).
            def extract_prompt(example):
                value = example[dataset_config.prompt_column]
                return {"prompt": "" if value is None else value}

            desc = "Extracting prompts..."
        else:
            # For general datasets, use the text column as the prompt
            # (lenient lookup: a missing column yields an empty prompt).
            def extract_prompt(example):
                value = example.get(dataset_config.text_column, "")
                return {"prompt": "" if value is None else value}

            desc = "Preparing prompts..."

        self.data = self.data.map(
            extract_prompt,
            remove_columns=self.data.column_names,
            desc=desc,
        )

        def tokenize_function(examples):
            # Fixed-length padding keeps every item the same length, so
            # items can be stacked into a batch without a custom collator.
            # NOTE(review): padding is applied on the tokenizer's
            # `padding_side`; decoder-only generation usually expects left
            # padding. Confirm the trainer or tokenizer setup accounts for
            # this.
            return self.tokenizer(
                examples["prompt"],
                truncation=True,
                max_length=self.config.model.max_seq_length,
                padding="max_length",
                return_attention_mask=True,
            )

        self.data = self.data.map(
            tokenize_function,
            batched=True,
            batch_size=100,
            remove_columns=self.data.column_names,
            desc="Tokenizing prompts...",
        )

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the tokenized prompt at ``idx`` as ``torch.long`` tensors.

        Only ``input_ids`` and ``attention_mask`` are returned. There are no
        ``labels`` at this stage, because responses are generated during
        training.
        """
        item = self.data[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(
                item["attention_mask"], dtype=torch.long
            ),
        }