| | import json |
| | import os |
| | from dataclasses import dataclass |
| | from typing import TYPE_CHECKING, List, Literal, Optional |
| |
|
| | from ..extras.constants import DATA_CONFIG |
| | from ..extras.misc import use_modelscope |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from ..hparams import DataArguments |
| |
|
| |
|
| | @dataclass |
| | class DatasetAttr: |
| | load_from: Literal["hf_hub", "ms_hub", "script", "file"] |
| | dataset_name: Optional[str] = None |
| | dataset_sha1: Optional[str] = None |
| | subset: Optional[str] = None |
| | folder: Optional[str] = None |
| | ranking: Optional[bool] = False |
| | formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" |
| |
|
| | system: Optional[str] = None |
| |
|
| | prompt: Optional[str] = "instruction" |
| | query: Optional[str] = "input" |
| | response: Optional[str] = "output" |
| | history: Optional[str] = None |
| |
|
| | messages: Optional[str] = "conversations" |
| | tools: Optional[str] = None |
| |
|
| | role_tag: Optional[str] = "from" |
| | content_tag: Optional[str] = "value" |
| | user_tag: Optional[str] = "human" |
| | assistant_tag: Optional[str] = "gpt" |
| | observation_tag: Optional[str] = "observation" |
| | function_tag: Optional[str] = "function_call" |
| |
|
| | def __repr__(self) -> str: |
| | return self.dataset_name |
| |
|
| |
|
| | def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: |
| | dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else [] |
| | try: |
| | with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f: |
| | dataset_info = json.load(f) |
| | except Exception as err: |
| | if data_args.dataset is not None: |
| | raise ValueError( |
| | "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)) |
| | ) |
| | dataset_info = None |
| |
|
| | if data_args.interleave_probs is not None: |
| | data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")] |
| |
|
| | dataset_list: List[DatasetAttr] = [] |
| | for name in dataset_names: |
| | if name not in dataset_info: |
| | raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) |
| |
|
| | has_hf_url = "hf_hub_url" in dataset_info[name] |
| | has_ms_url = "ms_hub_url" in dataset_info[name] |
| |
|
| | if has_hf_url or has_ms_url: |
| | if (use_modelscope() and has_ms_url) or (not has_hf_url): |
| | dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]) |
| | else: |
| | dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) |
| | elif "script_url" in dataset_info[name]: |
| | dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) |
| | else: |
| | dataset_attr = DatasetAttr( |
| | "file", |
| | dataset_name=dataset_info[name]["file_name"], |
| | dataset_sha1=dataset_info[name].get("file_sha1", None), |
| | ) |
| |
|
| | dataset_attr.subset = dataset_info[name].get("subset", None) |
| | dataset_attr.folder = dataset_info[name].get("folder", None) |
| | dataset_attr.ranking = dataset_info[name].get("ranking", False) |
| | dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca") |
| |
|
| | if "columns" in dataset_info[name]: |
| | if dataset_attr.formatting == "alpaca": |
| | column_names = ["prompt", "query", "response", "history"] |
| | else: |
| | column_names = ["messages", "tools"] |
| |
|
| | column_names += ["system"] |
| | for column_name in column_names: |
| | setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None)) |
| |
|
| | if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: |
| | for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]: |
| | setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None)) |
| |
|
| | dataset_list.append(dataset_attr) |
| |
|
| | return dataset_list |
| |
|