File size: 1,831 Bytes
fcca8c8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | """
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset
@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
ds_cfg = {}
ds_cfg["path"] = dataset
ds_cfg["split"] = split
ds_cfg["type"] = "chat_template"
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
dataset = load_dataset(dataset, split=split)
features = dataset.features
feature_keys = features.keys()
field_messages = None
for key in ["conversation", "conversations", "messages"]:
if key in feature_keys:
field_messages = key
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
message_fields = features[field_messages][0].keys()
message_property_mappings = {"role": None, "content": None}
for key in ["from", "role"]:
if key in message_fields:
message_property_mappings["role"] = key
break
if not message_property_mappings["role"]:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
for key in ["content", "text", "value"]:
if key in message_fields:
message_property_mappings["content"] = key
break
if not message_property_mappings["content"]:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_property_mappings"] = message_property_mappings
print(yaml.dump({"datasets": [ds_cfg]}))
if __name__ == "__main__":
parse_dataset()
|