File size: 5,911 Bytes
f6b92b7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 | """
SFT data pipeline: loads UltraChat 200K and formats into chat template.
Chat template:
<|user|>
What is gravity?
<|end|>
<|assistant|>
Gravity is a fundamental force...
<|end|>
Labels are shifted left by 1 (standard causal LM), with user turns masked.
"""
import functools

import torch
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
# Text fragments used by format_conversation to render a conversation.
# The marker strings ("<|user|>", "<|assistant|>", "<|end|>") must match the
# marker tokens SFTDataset adds to the tokenizer and looks up by id when
# building the label mask.
CHAT_TEMPLATE = dict(
    user_start="<|user|>\n",
    assistant_start="<|assistant|>\n",
    turn_end="\n<|end|>\n",
)
def format_conversation(messages):
    """Render a list of ``{"role", "content"}`` messages in the chat template.

    Each ``user`` / ``assistant`` message becomes
    ``<role start marker> + content.strip() + turn_end`` (see CHAT_TEMPLATE).
    Messages with any other role (e.g. ``system``) are dropped, and their
    content is not inspected.

    Args:
        messages: iterable of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        The rendered conversation string ("" if no message was kept).
    """
    role_starts = {
        "user": CHAT_TEMPLATE["user_start"],
        "assistant": CHAT_TEMPLATE["assistant_start"],
    }
    turn_end = CHAT_TEMPLATE["turn_end"]
    parts = []
    for msg in messages:
        start = role_starts.get(msg["role"])
        if start is None:
            # Unsupported role: skipped.
            continue
        parts.append(start + msg["content"].strip() + turn_end)
    # join instead of repeated `+=`, which can be quadratic.
    return "".join(parts)
class SFTDataset(Dataset):
    """
    Loads UltraChat 200K conversations, tokenizes them, and builds shifted
    labels with user turns masked so the model only learns to generate
    assistant responses.

    Each item is ``(input_ids, labels)``: two 1-D LongTensors of equal length,
    where ``input_ids = ids[:-1]`` and ``labels`` is ``ids[1:]`` with every
    position that must not be trained set to -100.

    Side effect: chat markers missing from the tokenizer vocabulary are added
    to ``tokenizer`` in place via ``add_tokens(..., special_tokens=True)``.
    This class does not touch the model. The caller is responsible for
    resizing its embeddings to ``len(tokenizer)``.
    """

    def __init__(self, tokenizer, max_seq_len=2048, split="train_sft", cache_dir=None, max_samples=None):
        """
        Args:
            tokenizer: tokenizer providing ``get_vocab``, ``add_tokens`` and
                ``encode`` (Hugging Face-style API).
            max_seq_len: maximum length of ``input_ids`` / ``labels``; longer
                conversations are truncated.
            split: UltraChat 200K split to load.
            cache_dir: forwarded to ``datasets.load_dataset``.
            max_samples: if truthy, only the first ``max_samples``
                conversations are used.

        Raises:
            ValueError: if a chat marker does not encode to exactly one token.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)

        self.assistant_token_id = self._single_token_id(tokenizer, "<|assistant|>")
        self.end_token_id = self._single_token_id(tokenizer, "<|end|>")
        self.user_token_id = self._single_token_id(tokenizer, "<|user|>")

        print(f"[SFT Data] Loading UltraChat 200K ({split})...")
        ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split, cache_dir=cache_dir)
        if max_samples:
            ds = ds.select(range(min(max_samples, len(ds))))
        print(f"[SFT Data] {len(ds)} conversations loaded")

        self.examples = []
        skipped = 0
        for i, row in enumerate(ds):
            example = self._encode_conversation(row["messages"])
            if example is None:
                skipped += 1
            else:
                self.examples.append(example)
            if (i + 1) % 50000 == 0:
                print(f" Processed {i+1} conversations...")
        print(f"[SFT Data] {len(self.examples)} examples ready, {skipped} skipped")

    @staticmethod
    def _single_token_id(tokenizer, token):
        """Return the id of ``token``, which must encode to exactly one id.

        Label masking compares individual ids against the marker ids. A marker
        split into several pieces would silently match the wrong tokens.
        """
        ids = tokenizer.encode(token, add_special_tokens=False)
        if len(ids) != 1:
            raise ValueError(
                f"Chat marker {token!r} encodes to {len(ids)} tokens {ids}; "
                "it must be a single token for label masking to work."
            )
        return ids[0]

    def _encode_conversation(self, messages):
        """Tokenize one conversation into ``(input_ids, labels)``, or return None.

        Returns None (skip) when the conversation has fewer than 2 messages,
        has fewer than 10 tokens, or has no trainable label left after
        truncation.
        """
        if len(messages) < 2:
            return None
        text = format_conversation(messages)
        all_ids = self.tokenizer.encode(text, add_special_tokens=False)
        # Keep at most max_seq_len + 1 tokens, so that after the one-token
        # shift both inputs and targets are at most max_seq_len long.
        all_ids = all_ids[: self.max_seq_len + 1]
        if len(all_ids) < 10:
            return None
        # Shifted: input = all_ids[:-1], target = all_ids[1:]
        input_ids = all_ids[:-1]
        target_ids = all_ids[1:]
        labels = self._build_shifted_labels(input_ids, target_ids)
        # Truncation can cut a conversation before any assistant content, so
        # every label is -100. Such an example gives no gradient. If the
        # training loss is a mean over non-ignored targets (e.g. cross_entropy
        # with ignore_index=-100), a batch made only of these yields NaN.
        if all(label == -100 for label in labels):
            return None
        return input_ids, labels

    def _build_shifted_labels(self, input_ids, target_ids):
        """
        Build per-position labels for the shifted sequence.

        ``labels[i]`` supervises predicting ``target_ids[i]`` after reading
        ``input_ids[: i + 1]``. Walking ``input_ids`` left to right:
          - the position whose input is <|assistant|> is masked (its target is
            template text) and switches training on;
          - each following position is trained, up to and including the one
            whose target is <|end|>, so the model learns to stop;
          - a position whose input is <|end|> or <|user|> is masked and
            switches training off, so user turns and the template text that
            follows <|end|> are never trained.

        Returns:
            A list of the same length as ``target_ids``.
        """
        labels = [-100] * len(target_ids)
        in_assistant = False
        for i, tid in enumerate(input_ids):
            if tid == self.assistant_token_id:
                in_assistant = True
            elif tid in (self.user_token_id, self.end_token_id):
                # <|end|> was the target at the previous position. The token
                # after it is not part of the response.
                in_assistant = False
            elif in_assistant:
                labels[i] = target_ids[i]
        return labels

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` of example ``idx`` as LongTensors."""
        input_ids, labels = self.examples[idx]
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
def sft_collate_fn(batch, pad_id=0):
    """Right-pad a batch of ``(input_ids, labels)`` pairs to a common length.

    Args:
        batch: non-empty sequence of ``(input_ids, labels)`` 1-D LongTensor
            pairs.
        pad_id: id used to pad ``input_ids``. ``None`` (a tokenizer without a
            pad token) falls back to 0. Padded label positions are -100, so
            this value is never a training target.

    Returns:
        ``(inputs, labels)``: two LongTensors of shape ``(batch, max_len)``.
    """
    if pad_id is None:
        pad_id = 0
    input_ids_list, labels_list = zip(*batch)
    inputs = pad_sequence(list(input_ids_list), batch_first=True, padding_value=pad_id)
    labels = pad_sequence(list(labels_list), batch_first=True, padding_value=-100)
    return inputs, labels
def create_sft_dataloader(tokenizer, batch_size=4, max_seq_len=2048,
                          cache_dir=None, max_samples=None, num_workers=4):
    """Build the UltraChat 200K ``train_sft`` dataset and a shuffling DataLoader.

    Args:
        tokenizer: tokenizer passed to SFTDataset (it may gain chat-marker
            tokens as a side effect).
        batch_size: examples per batch.
        max_seq_len: maximum sequence length per example.
        cache_dir: forwarded to the dataset loader.
        max_samples: optional cap on the number of conversations.
        num_workers: DataLoader worker processes.

    Returns:
        ``(dataloader, dataset)``.
    """
    dataset = SFTDataset(
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        split="train_sft",
        cache_dir=cache_dir,
        max_samples=max_samples,
    )
    # Some tokenizers have no pad token (pad_token_id is None). Padded
    # positions are labelled -100, so any valid id works; use 0.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # Pinning only helps host->GPU copies; skip it on CPU-only machines.
        pin_memory=torch.cuda.is_available(),
        # A partial of a module-level function is picklable (a lambda is not).
        # This is required when workers are started with "spawn", and it
        # avoids shipping the tokenizer to every worker.
        collate_fn=functools.partial(sft_collate_fn, pad_id=pad_id),
    ), dataset
|