""" helper script to parse chat datasets into a usable yaml """ import click import yaml from datasets import load_dataset @click.command() @click.argument("dataset", type=str) @click.option("--split", type=str, default="train") def parse_dataset(dataset=None, split="train"): ds_cfg = {} ds_cfg["path"] = dataset ds_cfg["split"] = split ds_cfg["type"] = "chat_template" ds_cfg["chat_template"] = "<<>>" dataset = load_dataset(dataset, split=split) features = dataset.features feature_keys = features.keys() field_messages = None for key in ["conversation", "conversations", "messages"]: if key in feature_keys: field_messages = key break if not field_messages: raise ValueError( f'No conversation field found in dataset: {", ".join(feature_keys)}' ) ds_cfg["field_messages"] = field_messages message_fields = features[field_messages][0].keys() message_property_mappings = {"role": None, "content": None} for key in ["from", "role"]: if key in message_fields: message_property_mappings["role"] = key break if not message_property_mappings["role"]: raise ValueError( f'No role field found in messages: {", ".join(message_fields)}' ) for key in ["content", "text", "value"]: if key in message_fields: message_property_mappings["content"] = key break if not message_property_mappings["content"]: raise ValueError( f'No content field found in messages: {", ".join(message_fields)}' ) ds_cfg["message_property_mappings"] = message_property_mappings print(yaml.dump({"datasets": [ds_cfg]})) if __name__ == "__main__": parse_dataset()