| import re |
| import textwrap |
| from collections import Counter |
| from itertools import groupby |
| from operator import itemgetter |
| from typing import Any, ClassVar, Optional |
|
|
| import yaml |
| from huggingface_hub import DatasetCardData |
|
|
| from ..config import METADATA_CONFIGS_FIELD |
| from ..features import Features |
| from ..info import DatasetInfo, DatasetInfosDict |
| from ..naming import _split_re |
| from ..utils.logging import get_logger |
|
|
|
|
# Module-level logger obtained from the project's logging utility, named after this module.
logger = get_logger(__name__)
|
|
|
|
class _NoDuplicateSafeLoader(yaml.SafeLoader):
    """``yaml.SafeLoader`` variant that refuses mappings with repeated keys.

    Each mapping is first built by the parent loader; its keys are then counted and
    a ``TypeError`` listing every repeated key is raised if any key occurs twice.

    NOTE(review): the keys are read from ``node.value`` after the parent has built the
    mapping. If the parent flattens ``<<`` merge keys into ``node.value`` (PyYAML's
    SafeConstructor appears to), an explicit key that overrides a merged one would also
    be reported as a duplicate. Confirm whether merge keys need to be supported.
    """

    def _check_no_duplicates_on_constructed_node(self, node):
        """Raise ``TypeError`` when the already-constructed keys of ``node`` repeat.

        List-valued keys are converted to tuples so that they can be counted.
        """
        constructed_keys = [self.constructed_objects[key_node] for key_node, _ in node.value]
        occurrences = Counter(tuple(k) if isinstance(k, list) else k for k in constructed_keys)
        repeated = [k for k, count in occurrences.items() if count > 1]
        if repeated:
            raise TypeError(f"Got duplicate yaml keys: {repeated}")

    def construct_mapping(self, node, deep=False):
        """Build the mapping with the parent loader, then reject duplicate keys."""
        result = super().construct_mapping(node, deep=deep)
        # The parent call populates self.constructed_objects for the key nodes,
        # which the duplicate check reads back.
        self._check_no_duplicates_on_constructed_node(node)
        return result
|
|
|
|
| def _split_yaml_from_readme(readme_content: str) -> tuple[Optional[str], str]: |
| full_content = list(readme_content.splitlines()) |
| if full_content and full_content[0] == "---" and "---" in full_content[1:]: |
| sep_idx = full_content[1:].index("---") + 1 |
| yamlblock = "\n".join(full_content[1:sep_idx]) |
| return yamlblock, "\n".join(full_content[sep_idx + 1 :]) |
|
|
| return None, "\n".join(full_content) |
|
|
|
|
class MetadataConfigs(dict[str, dict[str, Any]]):
    """Dataset configs declared in a dataset card's YAML metadata.

    Should be in format ``{config_name: {**config_params}}``. Inside the card the
    configs are stored under ``FIELD_NAME`` as a list of dicts, each carrying its
    name in a ``config_name`` key.
    """

    FIELD_NAME: ClassVar[str] = METADATA_CONFIGS_FIELD

    @staticmethod
    def _raise_if_data_files_field_not_valid(metadata_config: dict):
        """Check the optional ``data_files`` entry of a single config.

        Accepted values are a string, or a list whose items are either strings or
        dicts with exactly two keys: ``split`` (a string matching ``_split_re``) and
        ``path`` (a string or a list).

        Args:
            metadata_config: Parameters of one config.

        Raises:
            ValueError: If ``data_files`` is present but has any other shape,
                including a ``split`` value that is not a string.
        """

        def is_valid_item(item: Any) -> bool:
            if isinstance(item, str):
                return True
            if not isinstance(item, dict) or len(item) != 2:
                return False
            split = item.get("split")
            # Check the type before matching: re.match() raises TypeError on
            # non-strings (e.g. `split: 1` in YAML), hiding the helpful message below.
            return (
                isinstance(split, str)
                and re.match(_split_re, split) is not None
                and isinstance(item.get("path"), (str, list))
            )

        yaml_data_files = metadata_config.get("data_files")
        if yaml_data_files is None or isinstance(yaml_data_files, str):
            return
        if isinstance(yaml_data_files, list) and all(is_valid_item(item) for item in yaml_data_files):
            return
        # Dedent the template before inserting the offending value, so the value
        # cannot change how much indentation gets stripped.
        yaml_error_message = textwrap.dedent(
            """
            Expected data_files in YAML to be either a string or a list of strings
            or a list of dicts with two keys: 'split' and 'path', but got {}
            Examples of data_files in YAML:

               data_files: data.csv

               data_files: data/*.png

               data_files:
                - part0/*
                - part1/*

               data_files:
                - split: train
                  path: train/*
                - split: test
                  path: test/*

               data_files:
                - split: train
                  path:
                  - train/part1/*
                  - train/part2/*
                - split: test
                  path: test/*

            PS: some symbols like dashes '-' are not allowed in split names
            """
        ).format(yaml_data_files)
        raise ValueError(yaml_error_message)

    @classmethod
    def _from_exported_parquet_files_and_dataset_infos(
        cls,
        parquet_commit_hash: str,
        exported_parquet_files: list[dict[str, Any]],
        dataset_infos: DatasetInfosDict,
    ) -> "MetadataConfigs":
        """Build configs that point at the files of a dataset's Parquet export.

        Args:
            parquet_commit_hash: Revision substituted for the URL-encoded
                ``refs%2Fconvert%2Fparquet`` ref in every file URL.
            exported_parquet_files: Entries with at least ``"config"``, ``"split"``
                and ``"url"`` keys.
            dataset_infos: Infos keyed by config name. Supplies each config's version
                (``"0.0.0"`` when unknown) and, when non-empty, the config and split
                order.

        Returns:
            One config per exported config name, each with a ``data_files`` list
            holding one ``{"split": ..., "path": [...urls]}`` entry per split. When
            ``dataset_infos`` is non-empty, only configs listed in both it and the
            export are kept, ordered as in ``dataset_infos``.
        """
        # Group URLs by config, then by split. Plain dicts keep first-seen order and,
        # unlike itertools.groupby, do not split or overwrite a group whose entries
        # are not adjacent in the input.
        urls_by_config: dict[str, dict[str, list[str]]] = {}
        for parquet_file in exported_parquet_files:
            url = parquet_file["url"].replace("refs%2Fconvert%2Fparquet", parquet_commit_hash)
            urls_by_split = urls_by_config.setdefault(parquet_file["config"], {})
            urls_by_split.setdefault(parquet_file["split"], []).append(url)

        metadata_configs: dict[str, dict[str, Any]] = {}
        for config_name, urls_by_split in urls_by_config.items():
            dataset_info = dataset_infos.get(config_name)
            if dataset_info is None:
                # Only build the fallback when it is actually needed.
                dataset_info = DatasetInfo()
            metadata_configs[config_name] = {
                "data_files": [{"split": split_name, "path": paths} for split_name, paths in urls_by_split.items()],
                "version": str(dataset_info.version or "0.0.0"),
            }

        if dataset_infos:
            # Follow the config and split order recorded in dataset_infos.
            # - Configs listed there without exported files are skipped rather than
            #   raising KeyError.
            # - If a config's splits are unknown (None), the export order is kept.
            reordered: dict[str, dict[str, Any]] = {}
            for config_name, dataset_info in dataset_infos.items():
                if config_name not in metadata_configs:
                    continue
                config = metadata_configs[config_name]
                if dataset_info.splits is not None:
                    files_by_split = {data_file["split"]: data_file for data_file in config["data_files"]}
                    config = {
                        "data_files": [
                            files_by_split[split_name]
                            for split_name in dataset_info.splits
                            if split_name in files_by_split
                        ],
                        "version": config["version"],
                    }
                reordered[config_name] = config
            metadata_configs = reordered
        return cls(metadata_configs)

    @classmethod
    def from_dataset_card_data(cls, dataset_card_data: DatasetCardData) -> "MetadataConfigs":
        """Read configs from the ``FIELD_NAME`` entry of a dataset card's metadata.

        A ``features`` parameter is converted with ``Features._from_yaml_list``.
        Returns an empty instance when the entry is missing or empty.

        Raises:
            ValueError: If the entry is not a list, if an item is not a dict with a
                ``config_name`` key, or if an item has invalid ``data_files``.
        """
        metadata_configs = dataset_card_data.get(cls.FIELD_NAME)
        if not metadata_configs:
            return cls()
        if not isinstance(metadata_configs, list):
            raise ValueError(f"Expected {cls.FIELD_NAME} to be a list, but got '{metadata_configs}'")
        for metadata_config in metadata_configs:
            # The dict check keeps `in` from doing a substring test on a string item.
            if not isinstance(metadata_config, dict) or "config_name" not in metadata_config:
                raise ValueError(
                    f"Each config must include `config_name` field with a string name of a config, "
                    f"but got {metadata_config}. "
                )
            cls._raise_if_data_files_field_not_valid(metadata_config)
        return cls(
            {
                metadata_config["config_name"]: {
                    param: Features._from_yaml_list(value) if param == "features" else value
                    for param, value in metadata_config.items()
                    if param != "config_name"
                }
                for metadata_config in metadata_configs
            }
        )

    @staticmethod
    def _to_yaml_params(config_metadata: dict[str, Any]) -> dict[str, Any]:
        """Return a YAML-ready copy of one config's parameters, without ``config_name``."""
        return {
            # Undo the Features conversion done by from_dataset_card_data, so the card
            # gets the same list form it was read from.
            param: value._to_yaml_list() if param == "features" and isinstance(value, Features) else value
            for param, value in config_metadata.items()
            if param != "config_name"
        }

    def to_dataset_card_data(self, dataset_card_data: DatasetCardData) -> None:
        """Merge these configs into the ``FIELD_NAME`` entry of ``dataset_card_data``.

        Configs already in the card are kept unless a config of the same name exists
        here. The merged configs are written back as a list sorted by config name.
        Does nothing when this mapping is empty. Neither ``self`` nor the configs
        read from the card are mutated.

        Raises:
            ValueError: If a config here or in the card has invalid ``data_files``.
        """
        if not self:
            return
        for metadata_config in self.values():
            self._raise_if_data_files_field_not_valid(metadata_config)
        current_metadata_configs = self.from_dataset_card_data(dataset_card_data)
        total_metadata_configs = dict(sorted({**current_metadata_configs, **self}.items()))
        dataset_card_data[self.FIELD_NAME] = [
            {"config_name": config_name, **self._to_yaml_params(config_metadata)}
            for config_name, config_metadata in total_metadata_configs.items()
        ]

    def get_default_config_name(self) -> Optional[str]:
        """Return the name of the default config, or ``None`` if there is none.

        A config is the default if it is the only config, is named ``"default"``, or
        has a truthy ``default`` parameter.

        Raises:
            ValueError: If more than one config qualifies as the default.
        """
        single_config = len(self) == 1
        default_config_name = None
        for config_name, metadata_config in self.items():
            if single_config or config_name == "default" or metadata_config.get("default"):
                if default_config_name is not None:
                    raise ValueError(
                        f"Dataset has several default configs: '{default_config_name}' and '{config_name}'."
                    )
                default_config_name = config_name
        return default_config_name
|
|
|
|
| |
| |
| |
# Known task identifiers, each mapped to its own (distinct) empty list.
# NOTE(review): what the lists are meant to hold is not visible here — presumably
# sub-task ids; confirm against the callers.
known_task_ids = {
    task_id: []
    for task_id in (
        "image-classification",
        "translation",
        "image-segmentation",
        "fill-mask",
        "automatic-speech-recognition",
        "token-classification",
        "sentence-similarity",
        "audio-classification",
        "question-answering",
        "summarization",
        "zero-shot-classification",
        "table-to-text",
        "feature-extraction",
        "other",
        "multiple-choice",
        "text-classification",
        "text-to-image",
        "text2text-generation",
        "zero-shot-image-classification",
        "tabular-classification",
        "tabular-regression",
        "image-to-image",
        "tabular-to-text",
        "unconditional-image-generation",
        "text-retrieval",
        "text-to-speech",
        "object-detection",
        "audio-to-audio",
        "text-generation",
        "conversational",
        "table-question-answering",
        "visual-question-answering",
        "image-to-text",
        "reinforcement-learning",
        "voice-activity-detection",
        "time-series-forecasting",
        "document-question-answering",
    )
}
|
|