"""
DPO data pipeline: loads UltraFeedback preference pairs.
Each example has a prompt + chosen response + rejected response.
We tokenize both (prompt+chosen) and (prompt+rejected), apply the same
chat template, and return them as pairs for DPO training.
"""
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
# Markers for the chat template applied to every turn. These literal strings
# are also registered as special tokens in DPODataset.__init__, so they must
# stay in sync with the special_tokens list there.
CHAT_TEMPLATE = {
    "user_start": "<|user|>\n",
    "assistant_start": "<|assistant|>\n",
    "turn_end": "\n<|end|>\n",
}
def format_preference_pair(prompt, chosen_msgs, rejected_msgs):
    """Render chat-templated text for the chosen and the rejected message list.

    Both renderings share the same prompt prefix; messages with roles other
    than "user"/"assistant" are silently skipped.
    Returns a (chosen_text, rejected_text) tuple.
    """
    role_prefix = {
        "assistant": CHAT_TEMPLATE["assistant_start"],
        "user": CHAT_TEMPLATE["user_start"],
    }

    def render(messages):
        pieces = [CHAT_TEMPLATE["user_start"], prompt.strip(), CHAT_TEMPLATE["turn_end"]]
        for msg in messages:
            prefix = role_prefix.get(msg.get("role", "assistant"))
            if prefix is None:
                continue  # unknown role: dropped, matching original behavior
            pieces.append(prefix)
            pieces.append(msg.get("content", "").strip())
            pieces.append(CHAT_TEMPLATE["turn_end"])
        return "".join(pieces)

    return render(chosen_msgs), render(rejected_msgs)
class DPODataset(Dataset):
    """
    Loads UltraFeedback preference pairs and tokenizes them for DPO training.

    Each stored example is a dict with:
      chosen_ids / rejected_ids: token-id lists, at most max_seq_len + 1 long
                                 (the +1 leaves room for input/target shifting)
      prompt_len: index of the first response token in both sequences (the
                  prompt prefix is identical for chosen and rejected), used
                  downstream to mask prompt tokens out of the DPO loss.
    """
    def __init__(self, tokenizer, max_seq_len=2048, split="train",
                 cache_dir=None, max_samples=None):
        """
        Args:
            tokenizer: HF-style tokenizer; special chat tokens are added to it
                in place if missing (caller must resize model embeddings).
            max_seq_len: max usable sequence length; sequences are truncated
                to max_seq_len + 1 tokens.
            split: dataset split name passed to load_dataset.
            cache_dir: optional HF datasets cache directory.
            max_samples: optional cap on the number of raw rows processed.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        # Register the chat-template markers as special tokens (only the ones
        # not already present), so each renders as a single token id.
        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)
        self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
        self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
        self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]
        print(f"[DPO Data] Loading UltraFeedback preferences ({split})...")
        ds = load_dataset(
            "argilla/ultrafeedback-binarized-preferences-cleaned",
            split=split,
            cache_dir=cache_dir,
        )
        if max_samples:
            ds = ds.select(range(min(max_samples, len(ds))))
        print(f"[DPO Data] {len(ds)} preference pairs loaded")
        self.examples = []
        skipped = 0
        for i, row in enumerate(ds):
            prompt = row.get("prompt", "")
            chosen = row.get("chosen", [])
            rejected = row.get("rejected", [])
            if not prompt or not chosen or not rejected:
                skipped += 1
                continue
            chosen_text, rejected_text = format_preference_pair(prompt, chosen, rejected)
            chosen_ids = tokenizer.encode(chosen_text, add_special_tokens=False)
            rejected_ids = tokenizer.encode(rejected_text, add_special_tokens=False)
            # Truncate if needed (+1 leaves room for input/target shifting).
            if len(chosen_ids) > max_seq_len + 1:
                chosen_ids = chosen_ids[:max_seq_len + 1]
            if len(rejected_ids) > max_seq_len + 1:
                rejected_ids = rejected_ids[:max_seq_len + 1]
            if len(chosen_ids) < 10 or len(rejected_ids) < 10:
                skipped += 1
                continue
            # Find where the prompt ends (first <|assistant|> token). The
            # prompt prefix is identical in chosen/rejected, so one scan of
            # chosen_ids determines prompt_len for both.
            prompt_end = 0
            for j, tid in enumerate(chosen_ids):
                if tid == self.assistant_token_id:
                    # skip <|assistant|> and the following "\n" -- assumes
                    # "\n" encodes to exactly one token; TODO confirm for
                    # the tokenizer in use.
                    prompt_end = j + 2
                    break
            # BUGFIX: previously, when truncation cut off the <|assistant|>
            # marker (or the prompt filled the whole window), prompt_end
            # stayed 0 / out of range and the DPO loss would be computed
            # over prompt tokens. Skip such degenerate pairs instead.
            if prompt_end == 0 or prompt_end >= min(len(chosen_ids), len(rejected_ids)):
                skipped += 1
                continue
            self.examples.append({
                "chosen_ids": chosen_ids,
                "rejected_ids": rejected_ids,
                "prompt_len": prompt_end,
            })
            if (i + 1) % 20000 == 0:
                print(f"  Processed {i+1} pairs...")
        print(f"[DPO Data] {len(self.examples)} pairs ready, {skipped} skipped")

    def __len__(self):
        """Number of usable preference pairs."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return one pair as long tensors plus the (int) prompt length."""
        ex = self.examples[idx]
        return {
            "chosen_ids": torch.tensor(ex["chosen_ids"], dtype=torch.long),
            "rejected_ids": torch.tensor(ex["rejected_ids"], dtype=torch.long),
            "prompt_len": ex["prompt_len"],
        }
def dpo_collate_fn(batch, pad_id=0):
    """Collate DPO examples: right-pad chosen and rejected to their own
    per-batch maximum lengths (the two sides may have different widths).

    Returns a dict of stacked long tensors: "chosen_ids" (B, Lc),
    "rejected_ids" (B, Lr), and "prompt_lens" (B,).
    """
    def pad_stack(seqs):
        # Preallocate a pad-filled matrix, then copy each sequence in.
        width = max(s.size(0) for s in seqs)
        out = torch.full((len(seqs), width), pad_id, dtype=torch.long)
        for row, seq in zip(out, seqs):
            row[: seq.size(0)] = seq
        return out

    return {
        "chosen_ids": pad_stack([b["chosen_ids"] for b in batch]),
        "rejected_ids": pad_stack([b["rejected_ids"] for b in batch]),
        "prompt_lens": torch.tensor([b["prompt_len"] for b in batch], dtype=torch.long),
    }