File size: 1,536 Bytes
fd0b01f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# format each (raw, clean) pair as a chatml conversation for sft.
# this module only builds the raw text: a hf datasets.Dataset with a single
# "text" column. tokenization into input_ids/attention_mask is left to trl, and
# labels are built by trl's DataCollatorForCompletionOnlyLM via
# RESPONSE_TEMPLATE, which masks every token before the assistant turn.

from datasets import Dataset

from cleanup.prompts import SYSTEM_PROMPT

# marker that opens the assistant turn. DataCollatorForCompletionOnlyLM looks
# for this template inside input_ids and sets every label up to the end of the
# match to -100, so cross entropy only counts assistant tokens. it therefore
# has to match, byte for byte and including the trailing newline, the
# assistant header that format_chat writes (and, presumably, what qwen's
# apply_chat_template emits before the assistant turn — verify if the chat
# template changes).
RESPONSE_TEMPLATE: str = (
    "<|im_start|>assistant"
    "\n"  # the trailing newline is part of the template
)


def format_chat(pair: dict) -> str:
    """render one (raw, clean) pair as a single string in qwen's chatml format.

    the conversation is: system turn (SYSTEM_PROMPT), user turn (pair["raw"]),
    assistant turn (pair["clean"]). each turn is wrapped in
    <|im_start|>ROLE\\n ... <|im_end|>.

    Args:
        pair: mapping with a "raw" key (user text) and a "clean" key
            (target assistant text).

    Returns:
        the formatted conversation. turns are separated by "\\n", and there is
        no newline after the final <|im_end|>.

    Raises:
        KeyError: if pair has no "raw" or no "clean" key.
    """
    user = pair["raw"]
    assistant = pair["clean"]
    # reuse RESPONSE_TEMPLATE for the assistant header instead of repeating the
    # literal: the completion-only collator masks labels by searching for that
    # exact string, so the header written here must never drift from it.
    return (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{user}<|im_end|>\n"
        f"{RESPONSE_TEMPLATE}{assistant}<|im_end|>"
    )


def to_dataset(pairs: list[dict]) -> Dataset:
    """build a hf datasets.Dataset with one row per pair.

    each row is {"text": format_chat(pair)}, in the same order as pairs.
    """
    formatted = map(format_chat, pairs)
    return Dataset.from_list([{"text": text} for text in formatted])


def formatting_func(example: dict) -> list[str]:
    """return the "text" field of a row or batch as a list of strings.

    trl's SFTTrainer calls this with either a single example (whose "text" is
    one string) or a batch (whose "text" is already a list). a list value is
    returned as-is (the same object, not a copy). any other value is wrapped
    in a one-element list.
    """
    text = example["text"]
    return text if isinstance(text, list) else [text]