File size: 5,308 Bytes
493809a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
DPO data pipeline: loads UltraFeedback preference pairs.

Each example has a prompt + chosen response + rejected response.
We tokenize both (prompt+chosen) and (prompt+rejected), apply the same
chat template, and return them as pairs for DPO training.
"""

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset


# Literal marker strings used to delimit turns in the chat-formatted text.
CHAT_TEMPLATE = {
    "user_start": "<|user|>\n",
    "assistant_start": "<|assistant|>\n",
    "turn_end": "\n<|end|>\n",
}


def format_preference_pair(prompt, chosen_msgs, rejected_msgs):
    """Build chat-templated strings for the chosen and rejected conversations.

    Both strings open with ``prompt`` rendered as a user turn, followed by
    the messages of the respective conversation rendered with the markers in
    ``CHAT_TEMPLATE``.

    Args:
        prompt: The user prompt shared by both conversations.
        chosen_msgs: Iterable of ``{"role": ..., "content": ...}`` dicts for
            the preferred conversation.
        rejected_msgs: Same structure for the dispreferred conversation.

    Returns:
        A ``(chosen_text, rejected_text)`` tuple of formatted strings.

    Notes:
        * A leading user message whose content equals ``prompt`` is skipped,
          because the prompt is already emitted as the opening turn and
          would otherwise appear twice in the formatted text.
        * Messages with a role other than ``"user"`` or ``"assistant"`` are
          dropped; a missing role is treated as ``"assistant"``.
        * Missing or ``None`` content is rendered as an empty string.
    """
    prompt_text = prompt.strip()
    turn_end = CHAT_TEMPLATE["turn_end"]
    opening_turn = CHAT_TEMPLATE["user_start"] + prompt_text + turn_end
    role_markers = {
        "assistant": CHAT_TEMPLATE["assistant_start"],
        "user": CHAT_TEMPLATE["user_start"],
    }

    def build(messages):
        parts = [opening_turn]
        for index, msg in enumerate(messages):
            role = msg.get("role", "assistant")
            content = (msg.get("content") or "").strip()
            # Avoid duplicating the prompt when the message list already
            # starts with it as a user turn.
            if index == 0 and role == "user" and content == prompt_text:
                continue
            marker = role_markers.get(role)
            if marker is None:
                continue
            parts.append(marker + content + turn_end)
        return "".join(parts)

    return build(chosen_msgs), build(rejected_msgs)


class DPODataset(Dataset):
    """
    Loads UltraFeedback preference pairs and tokenizes them.

    Every stored example holds the token ids of the chat-formatted
    (prompt + chosen) and (prompt + rejected) texts, each truncated to at
    most ``max_seq_len + 1`` tokens, plus ``prompt_len``: the index of the
    first response token in the chosen sequence.

    Indexing returns a dict with ``chosen_ids`` and ``rejected_ids`` as 1-D
    ``torch.long`` tensors and ``prompt_len`` as a plain int.
    """

    def __init__(self, tokenizer, max_seq_len=2048, split="train",
                 cache_dir=None, max_samples=None):
        """Load the dataset split and pre-tokenize every preference pair.

        Args:
            tokenizer: Tokenizer exposing ``get_vocab``, ``add_tokens`` and
                ``encode``. It is mutated in place: the chat markers are
                added as special tokens when missing.
                NOTE(review): the model's embedding table presumably has to
                be resized by the caller when tokens are added; confirm.
            max_seq_len: Sequences are truncated to ``max_seq_len + 1``
                tokens.
            split: Dataset split passed to ``load_dataset``.
            cache_dir: Cache directory passed to ``load_dataset``.
            max_samples: If truthy, only the first ``max_samples`` rows are
                used.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)

        self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
        self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
        self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]

        print(f"[DPO Data] Loading UltraFeedback preferences ({split})...")
        ds = load_dataset(
            "argilla/ultrafeedback-binarized-preferences-cleaned",
            split=split,
            cache_dir=cache_dir,
        )
        if max_samples:
            ds = ds.select(range(min(max_samples, len(ds))))
        print(f"[DPO Data] {len(ds)} preference pairs loaded")

        self.examples = []
        skipped = 0
        max_tokens = max_seq_len + 1
        for i, row in enumerate(ds):
            prompt = row.get("prompt", "")
            chosen = row.get("chosen", [])
            rejected = row.get("rejected", [])

            if not prompt or not chosen or not rejected:
                skipped += 1
                continue

            chosen_text, rejected_text = format_preference_pair(prompt, chosen, rejected)

            # Slicing is a no-op for sequences already within the limit.
            chosen_ids = tokenizer.encode(chosen_text, add_special_tokens=False)[:max_tokens]
            rejected_ids = tokenizer.encode(rejected_text, add_special_tokens=False)[:max_tokens]

            if len(chosen_ids) < 10 or len(rejected_ids) < 10:
                skipped += 1
                continue

            prompt_len = self._find_prompt_len(chosen_ids, self.assistant_token_id)
            # Skip pairs whose prompt fills the whole (truncated) sequence:
            # either the assistant marker was cut off, or no response token
            # remains after it. Keeping them would record a prompt_len that
            # does not separate prompt from response.
            if prompt_len is None or prompt_len >= min(len(chosen_ids), len(rejected_ids)):
                skipped += 1
                continue

            self.examples.append({
                "chosen_ids": chosen_ids,
                "rejected_ids": rejected_ids,
                "prompt_len": prompt_len,
            })

            if (i + 1) % 20000 == 0:
                print(f"  Processed {i+1} pairs...")

        print(f"[DPO Data] {len(self.examples)} pairs ready, {skipped} skipped")

    @staticmethod
    def _find_prompt_len(token_ids, assistant_token_id):
        """Return the index just past the first assistant marker, or None.

        The offset of 2 skips the ``<|assistant|>`` token and the newline
        after it (assumes that newline encodes to exactly one token; TODO
        confirm for the tokenizer in use). ``None`` means the marker does
        not occur in ``token_ids``.
        """
        try:
            return token_ids.index(assistant_token_id) + 2
        except ValueError:
            return None

    def __len__(self):
        """Return the number of usable preference pairs."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the pair at ``idx`` with token ids converted to tensors."""
        ex = self.examples[idx]
        return {
            "chosen_ids": torch.tensor(ex["chosen_ids"], dtype=torch.long),
            "rejected_ids": torch.tensor(ex["rejected_ids"], dtype=torch.long),
            "prompt_len": ex["prompt_len"],
        }


def dpo_collate_fn(batch, pad_id=0):
    """Collate DPO examples into right-padded batch tensors.

    Chosen and rejected sequences are padded independently, each to the
    longest sequence of its own kind in the batch, using ``pad_id``.

    Args:
        batch: List of dicts with 1-D tensors ``chosen_ids`` and
            ``rejected_ids`` and an integer ``prompt_len``.
        pad_id: Token id used to fill the padded positions.

    Returns:
        Dict with stacked ``chosen_ids`` and ``rejected_ids`` tensors and a
        ``prompt_lens`` long tensor holding one entry per example.
    """
    def _right_pad(seq, target_len):
        # Append as many pad ids as needed to reach target_len.
        filler = torch.full((target_len - seq.size(0),), pad_id, dtype=torch.long)
        return torch.cat([seq, filler])

    chosen_width = max(item["chosen_ids"].size(0) for item in batch)
    rejected_width = max(item["rejected_ids"].size(0) for item in batch)

    chosen_rows = [_right_pad(item["chosen_ids"], chosen_width) for item in batch]
    rejected_rows = [_right_pad(item["rejected_ids"], rejected_width) for item in batch]
    lengths = [item["prompt_len"] for item in batch]

    return {
        "chosen_ids": torch.stack(chosen_rows),
        "rejected_ids": torch.stack(rejected_rows),
        "prompt_lens": torch.tensor(lengths, dtype=torch.long),
    }