dkumar15 commited on
Commit
493809a
·
verified ·
1 Parent(s): f6b92b7

Upload training_code/model/dpo_data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training_code/model/dpo_data.py +144 -0
training_code/model/dpo_data.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DPO data pipeline: loads UltraFeedback preference pairs.
3
+
4
+ Each example has a prompt + chosen response + rejected response.
5
+ We tokenize both (prompt+chosen) and (prompt+rejected), apply the same
6
+ chat template, and return them as pairs for DPO training.
7
+ """
8
+
9
+ import torch
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from datasets import load_dataset
12
+
13
+
14
CHAT_TEMPLATE = {
    "user_start": "<|user|>\n",
    "assistant_start": "<|assistant|>\n",
    "turn_end": "\n<|end|>\n",
}


def format_preference_pair(prompt, chosen_msgs, rejected_msgs):
    """Render (chosen, rejected) message lists into chat-templated strings.

    The prompt is always emitted first as a user turn; each message is then
    appended as its own templated turn. Messages with a role other than
    "user"/"assistant" are silently dropped; a missing role defaults to
    "assistant".
    NOTE(review): if the dataset's message lists already contain the user
    prompt as their first message, the prompt appears twice — confirm the
    schema against the caller.
    """
    role_start = {
        "user": CHAT_TEMPLATE["user_start"],
        "assistant": CHAT_TEMPLATE["assistant_start"],
    }
    turn_end = CHAT_TEMPLATE["turn_end"]

    def render(messages):
        # Leading user turn carrying the prompt, then one turn per message.
        pieces = [role_start["user"], prompt.strip(), turn_end]
        for msg in messages:
            start = role_start.get(msg.get("role", "assistant"))
            if start is not None:
                pieces.extend((start, msg.get("content", "").strip(), turn_end))
        return "".join(pieces)

    return render(chosen_msgs), render(rejected_msgs)
35
+
36
+
37
class DPODataset(Dataset):
    """
    Loads UltraFeedback preference pairs and tokenizes them eagerly at init.

    Each item is a dict with:
        chosen_ids   -- LongTensor of token ids for prompt + chosen response
        rejected_ids -- LongTensor of token ids for prompt + rejected response
        prompt_len   -- int index where the prompt ends in ``chosen_ids``
                        (position just after the first <|assistant|> marker),
                        or 0 if no assistant marker was found
    """

    def __init__(self, tokenizer, max_seq_len: int = 2048, split: str = "train",
                 cache_dir=None, max_samples=None):
        # tokenizer: a HF-style tokenizer; it is MUTATED here if the chat
        # special tokens are not already in its vocab.
        # max_samples: optional cap on the number of raw rows processed.
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        # Register the chat-template markers as special tokens if absent.
        # NOTE(review): when tokens are added, the model's embedding matrix
        # must be resized by the caller — confirm training code does this.
        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)

        # [0] assumes each marker encodes to exactly one id; that holds once
        # they are registered as special tokens above — presumably also true
        # when they pre-existed in the vocab; verify for custom tokenizers.
        self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
        self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
        self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]

        # Downloads from the HF Hub on first use (network/disk side effect).
        print(f"[DPO Data] Loading UltraFeedback preferences ({split})...")
        ds = load_dataset(
            "argilla/ultrafeedback-binarized-preferences-cleaned",
            split=split,
            cache_dir=cache_dir,
        )
        if max_samples:
            ds = ds.select(range(min(max_samples, len(ds))))
        print(f"[DPO Data] {len(ds)} preference pairs loaded")

        self.examples = []
        skipped = 0
        for i, row in enumerate(ds):
            prompt = row.get("prompt", "")
            chosen = row.get("chosen", [])
            rejected = row.get("rejected", [])

            # Drop rows missing any of the three fields.
            if not prompt or not chosen or not rejected:
                skipped += 1
                continue

            chosen_text, rejected_text = format_preference_pair(prompt, chosen, rejected)

            chosen_ids = tokenizer.encode(chosen_text, add_special_tokens=False)
            rejected_ids = tokenizer.encode(rejected_text, add_special_tokens=False)

            # Truncate if needed; +1 presumably leaves room for the
            # input/target shift downstream — confirm against the trainer.
            if len(chosen_ids) > max_seq_len + 1:
                chosen_ids = chosen_ids[:max_seq_len + 1]
            if len(rejected_ids) > max_seq_len + 1:
                rejected_ids = rejected_ids[:max_seq_len + 1]

            # Discard degenerate (near-empty) sequences.
            if len(chosen_ids) < 10 or len(rejected_ids) < 10:
                skipped += 1
                continue

            # Find where the prompt ends (first <|assistant|> token).
            # NOTE(review): prompt_end stays 0 when no marker is found, and
            # j + 2 assumes the "\n" after <|assistant|> tokenizes to exactly
            # one id — TODO confirm for the tokenizer in use.
            prompt_end = 0
            for j, tid in enumerate(chosen_ids):
                if tid == self.assistant_token_id:
                    prompt_end = j + 2  # skip <|assistant|> and \n
                    break

            self.examples.append({
                "chosen_ids": chosen_ids,
                "rejected_ids": rejected_ids,
                "prompt_len": prompt_end,
            })

            # Progress heartbeat for large splits.
            if (i + 1) % 20000 == 0:
                print(f"  Processed {i+1} pairs...")

        print(f"[DPO Data] {len(self.examples)} pairs ready, {skipped} skipped")

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict:
        # Tensor conversion is deferred to access time; padding happens in
        # dpo_collate_fn, so sequences here are variable-length.
        ex = self.examples[idx]
        return {
            "chosen_ids": torch.tensor(ex["chosen_ids"], dtype=torch.long),
            "rejected_ids": torch.tensor(ex["rejected_ids"], dtype=torch.long),
            "prompt_len": ex["prompt_len"],
        }
122
+
123
+
124
def dpo_collate_fn(batch, pad_id=0):
    """Pad chosen and rejected sequences separately.

    Right-pads each sequence with ``pad_id`` up to the batch-wise maximum
    length of its own side (chosen and rejected may have different widths),
    and stacks the results into (B, L) LongTensors. ``prompt_lens`` is
    returned as a 1-D LongTensor.
    """
    def right_pad(seq, width):
        # Preallocate a pad-filled row and copy the sequence into its head.
        row = torch.full((width,), pad_id, dtype=torch.long)
        row[: seq.size(0)] = seq
        return row

    chosen_width = max(item["chosen_ids"].size(0) for item in batch)
    rejected_width = max(item["rejected_ids"].size(0) for item in batch)

    return {
        "chosen_ids": torch.stack(
            [right_pad(item["chosen_ids"], chosen_width) for item in batch]
        ),
        "rejected_ids": torch.stack(
            [right_pad(item["rejected_ids"], rejected_width) for item in batch]
        ),
        "prompt_lens": torch.tensor(
            [item["prompt_len"] for item in batch], dtype=torch.long
        ),
    }