File size: 1,831 Bytes
fcca8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
helper script to parse chat datasets into a usable yaml
"""

import click
import yaml
from datasets import load_dataset


def _first_present(candidates, available):
    """Return the first entry of ``candidates`` found in ``available``, or ``None``."""
    return next((name for name in candidates if name in available), None)


def _message_field_names(messages_feature, field_name):
    """Return the per-message field names declared by a messages column schema.

    Args:
        messages_feature: Schema entry of the messages column. Either a
            list/tuple wrapping the per-message schema, or a wrapper object
            exposing the per-message schema as ``.feature``.
        field_name: Name of the messages column; only used in error messages.

    Returns:
        A list with the field names of a single message.

    Raises:
        ValueError: If the per-message schema is not a mapping of fields.
    """
    if isinstance(messages_feature, (list, tuple)):
        # Plain-list schema: the first element describes one message.
        inner = messages_feature[0] if messages_feature else None
    else:
        # Wrapper feature objects (e.g. Sequence / List) are not subscriptable;
        # they carry the per-message schema on ``.feature`` instead.
        inner = getattr(messages_feature, "feature", None)

    if not hasattr(inner, "keys"):
        raise ValueError(
            f"Field '{field_name}' does not describe a list of message dicts"
        )
    return list(inner.keys())


@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
    """Print a YAML ``datasets`` config entry for the chat dataset DATASET.

    The dataset split is loaded and its schema inspected to detect the
    column holding the conversation and the per-message fields holding the
    role and the content. The resulting entry uses the ``chat_template`` type
    with a placeholder chat template that has to be replaced by hand.

    A ValueError is raised when no conversation column, role field or content
    field can be detected.
    """
    ds_cfg = {
        "path": dataset,
        "split": split,
        "type": "chat_template",
        "chat_template": "<<<Replace based on your model>>>",
    }

    # Keep the loaded dataset under its own name instead of rebinding the
    # ``dataset`` CLI argument.
    loaded = load_dataset(dataset, split=split)
    features = loaded.features
    feature_keys = list(features.keys())

    # Candidates are tried in order; the first column present wins.
    field_messages = _first_present(
        ["conversation", "conversations", "messages"], feature_keys
    )
    if not field_messages:
        raise ValueError(
            f'No conversation field found in dataset: {", ".join(feature_keys)}'
        )
    ds_cfg["field_messages"] = field_messages

    message_fields = _message_field_names(features[field_messages], field_messages)

    role_field = _first_present(["from", "role"], message_fields)
    if not role_field:
        raise ValueError(
            f'No role field found in messages: {", ".join(message_fields)}'
        )

    content_field = _first_present(["content", "text", "value"], message_fields)
    if not content_field:
        raise ValueError(
            f'No content field found in messages: {", ".join(message_fields)}'
        )

    ds_cfg["message_property_mappings"] = {
        "role": role_field,
        "content": content_field,
    }

    print(yaml.dump({"datasets": [ds_cfg]}))


# Script entry point: run the click command only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    parse_dataset()