File size: 2,554 Bytes
3270dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""RL dataset for HuggingFace datasets."""

from typing import Dict
import torch
from taoTrain.config import TrainingConfig
from taoTrain.data.hf_base import BaseHFDataset


class RLDataset(BaseHFDataset):
    """Prompt-only dataset for RL training.

    The raw dataset is reduced to a single ``"prompt"`` column, which is then
    tokenized. Items carry no labels; per the original author's notes,
    responses are generated by the model during training.
    """

    def _preprocess(self):
        """Reduce ``self.data`` to tokenized prompts, in place.

        Two passes over ``self.data`` (assumed to be a HuggingFace dataset,
        since ``.map`` and ``.column_names`` are used -- TODO confirm against
        ``BaseHFDataset``):

        1. Build a ``"prompt"`` column from ``config.dataset.prompt_column``
           when that setting is truthy, otherwise from
           ``config.dataset.text_column`` (falling back to ``""`` when the
           row has no such key).
        2. Tokenize the prompts with ``self.tokenizer``, truncating and
           padding to ``config.model.max_seq_length``.
        """
        ds_cfg = self.config.dataset

        # Only prompts are needed for RL: no response column is read here.
        if ds_cfg.prompt_column:
            # A dedicated prompt column is configured; index it directly so a
            # missing column surfaces as an error.
            def to_prompt(row):
                return {"prompt": row[ds_cfg.prompt_column]}

            progress_label = "Extracting prompts..."
        else:
            # Generic dataset: treat the text column as the prompt, using an
            # empty string when the row does not have that key.
            def to_prompt(row):
                return {"prompt": row.get(ds_cfg.text_column, "")}

            progress_label = "Preparing prompts..."

        self.data = self.data.map(
            to_prompt,
            remove_columns=self.data.column_names,
            desc=progress_label,
        )

        def encode_batch(batch):
            # Config is read per batch, exactly when the tokenizer is invoked.
            return self.tokenizer(
                batch["prompt"],
                truncation=True,
                max_length=self.config.model.max_seq_length,
                padding="max_length",
                return_attention_mask=True,
            )

        # The "prompt" column is dropped here; only tokenizer outputs remain.
        self.data = self.data.map(
            encode_batch,
            batched=True,
            batch_size=100,
            remove_columns=self.data.column_names,
            desc="Tokenizing prompts...",
        )

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the tokenized prompt at ``idx`` as ``torch.long`` tensors.

        The result has exactly two keys, ``"input_ids"`` and
        ``"attention_mask"``. No ``"labels"`` entry is produced here; the
        original author's note says labels are generated during training or
        set by the trainer.
        """
        row = self.data[idx]
        return {
            field: torch.tensor(row[field], dtype=torch.long)
            for field in ("input_ids", "attention_mask")
        }