File size: 5,911 Bytes
f6b92b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
SFT data pipeline: loads UltraChat 200K and formats into chat template.

Chat template:
  <|user|>
  What is gravity?
  <|end|>
  <|assistant|>
  Gravity is a fundamental force...
  <|end|>

Labels are the token sequence shifted left by one position (standard causal-LM
next-token targets); positions that belong to user turns are masked with -100.
"""

from functools import partial

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset


# Literal marker strings used to render a conversation as plain text.
# A turn is rendered as <start marker> + stripped content + "turn_end".
CHAT_TEMPLATE = {
    "user_start": "<|user|>\n",  # opens a user turn
    "assistant_start": "<|assistant|>\n",  # opens an assistant turn
    "turn_end": "\n<|end|>\n",  # closes either kind of turn
}


def format_conversation(messages):
    """Render a list of {role, content} messages as one chat-template string.

    Each "user" or "assistant" message becomes its start marker, the
    whitespace-stripped content, and the turn-end marker. Messages with any
    other role contribute nothing to the output.
    """
    openers = {
        "user": CHAT_TEMPLATE["user_start"],
        "assistant": CHAT_TEMPLATE["assistant_start"],
    }
    closer = CHAT_TEMPLATE["turn_end"]

    pieces = []
    for message in messages:
        speaker = message["role"]
        body = message["content"].strip()
        opener = openers.get(speaker)
        if opener is None:
            # Not a user/assistant turn: leave it out of the rendered text.
            continue
        pieces.append(opener + body + closer)
    return "".join(pieces)


class SFTDataset(Dataset):
    """
    Loads UltraChat 200K conversations, tokenizes them, builds shifted labels
    with user turns masked so the model only learns to generate assistant responses.

    Every stored example is an ``(input_ids, labels)`` pair of equal length.
    ``labels[i]`` is the token that follows ``input_ids[i]`` in the conversation,
    or -100 when that position is masked (-100 is the default ``ignore_index``
    of ``torch.nn.CrossEntropyLoss``).

    Note: building the dataset adds the chat markers to ``tokenizer`` as special
    tokens when they are missing from its vocabulary. Resizing the model's
    embedding matrix to match is not done here.
    """

    def __init__(self, tokenizer, max_seq_len=2048, split="train_sft", cache_dir=None, max_samples=None):
        """
        Tokenize and label every usable conversation up front.

        Args:
            tokenizer: object exposing ``get_vocab``, ``add_tokens`` and ``encode``
                (presumably a Hugging Face tokenizer -- confirm with the caller).
                It is mutated if a chat marker is not yet in its vocabulary.
            max_seq_len: maximum length of each stored input / label sequence.
            split: dataset split name forwarded to ``load_dataset``.
            cache_dir: cache directory forwarded to ``load_dataset``.
            max_samples: when truthy, only the first ``max_samples`` rows are used.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)

        self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
        self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
        self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]

        print(f"[SFT Data] Loading UltraChat 200K ({split})...")
        ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split, cache_dir=cache_dir)
        if max_samples:
            ds = ds.select(range(min(max_samples, len(ds))))
        print(f"[SFT Data] {len(ds)} conversations loaded")

        self.examples = []
        skipped = 0
        for i, row in enumerate(ds):
            messages = row["messages"]
            if len(messages) < 2:
                skipped += 1
                continue

            text = format_conversation(messages)
            all_ids = tokenizer.encode(text, add_special_tokens=False)

            # Keep max_seq_len + 1 tokens so that, after the one-position shift,
            # inputs and targets are each at most max_seq_len long.
            if len(all_ids) > max_seq_len + 1:
                all_ids = all_ids[:max_seq_len + 1]

            if len(all_ids) < 10:
                skipped += 1
                continue

            # Shifted: input = all_ids[:-1], target = all_ids[1:]
            input_ids = all_ids[:-1]
            target_ids = all_ids[1:]

            # Build mask: -100 for user turns, real token id for assistant turns
            labels = self._build_shifted_labels(input_ids, target_ids)

            # Truncation can cut a conversation off before any assistant content.
            # Such an example has no trainable position, and a batch made only of
            # them gives a NaN mean loss, so drop it.
            if all(label == -100 for label in labels):
                skipped += 1
                continue

            self.examples.append((input_ids, labels))

            if (i + 1) % 50000 == 0:
                print(f"  Processed {i+1} conversations...")

        print(f"[SFT Data] {len(self.examples)} examples ready, {skipped} skipped")

    def _build_shifted_labels(self, input_ids, target_ids):
        """
        Walk through the token sequence and track whether we're in a user turn
        or assistant turn. Only keep labels for assistant response content.

        ``target_ids[i]`` is the token that follows ``input_ids[i]``, so the
        decision for position ``i`` is made from the INPUT token at ``i``.

        Masking strategy (applied to the SHIFTED target):
        - Everything before and including <|assistant|>\\n: masked
        - Assistant response content and <|end|>: TRAIN
        - Whatever follows <|end|>, <|user|> and user content until the next
          <|assistant|>: masked

        Returns:
            A list the same length as ``target_ids`` holding either the target
            token id or -100 for masked positions.
        """
        labels = [-100] * len(target_ids)
        in_assistant = False

        for i, tid in enumerate(input_ids):
            if tid == self.assistant_token_id:
                # The target here is the token right after <|assistant|>
                # (expected to be the "\n" of the start marker), so it stays masked.
                in_assistant = True
                continue

            if tid == self.user_token_id or tid == self.end_token_id:
                # <|end|> as a TARGET was already labelled at position i - 1.
                # With <|end|> as the INPUT, the target is the token after the
                # turn, which is not part of the assistant response.
                in_assistant = False
                continue

            if in_assistant:
                labels[i] = target_ids[i]

        return labels

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return example ``idx`` as a pair of ``torch.long`` tensors (input_ids, labels)."""
        input_ids, labels = self.examples[idx]
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)


def sft_collate_fn(batch, pad_id=0):
    """Right-pad a batch of (input_ids, labels) tensor pairs to a common length.

    The common length is the longest input sequence in the batch. Inputs are
    padded with ``pad_id`` and labels with -100. Returns the stacked inputs and
    the stacked labels as a 2-tuple.
    """
    sequences, targets = zip(*batch)
    width = max(seq.size(0) for seq in sequences)

    def _tail(count, fill):
        # Padding segment appended after the real tokens.
        return torch.full((count,), fill, dtype=torch.long)

    # The padding amount of each row is derived from its input sequence and is
    # reused for the matching label sequence.
    gaps = [width - seq.size(0) for seq in sequences]

    inputs = torch.stack(
        [torch.cat([seq, _tail(gap, pad_id)]) for seq, gap in zip(sequences, gaps)]
    )
    labels = torch.stack(
        [torch.cat([tgt, _tail(gap, -100)]) for tgt, gap in zip(targets, gaps)]
    )
    return inputs, labels


def create_sft_dataloader(tokenizer, batch_size=4, max_seq_len=2048,
                          cache_dir=None, max_samples=None, num_workers=4):
    """Build the "train_sft" SFTDataset and a shuffling DataLoader over it.

    Args:
        tokenizer: tokenizer handed to ``SFTDataset``; its ``pad_token_id`` is
            used to pad input sequences within a batch.
        batch_size: number of examples per batch.
        max_seq_len: maximum sequence length, forwarded to ``SFTDataset``.
        cache_dir: dataset cache directory, forwarded to ``SFTDataset``.
        max_samples: optional cap on conversations, forwarded to ``SFTDataset``.
        num_workers: number of DataLoader worker processes.

    Returns:
        A ``(dataloader, dataset)`` tuple.
    """
    dataset = SFTDataset(
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        split="train_sft",
        cache_dir=cache_dir,
        max_samples=max_samples,
    )

    # Some tokenizers define no pad token and report None, which torch.full
    # cannot use as a fill value. Fall back to sft_collate_fn's own default;
    # padded label positions are masked with -100 regardless.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0

    # A partial over a module-level function can be pickled, unlike a lambda,
    # which matters when worker processes are started with the spawn method.
    collate = partial(sft_collate_fn, pad_id=pad_id)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate,
    ), dataset