# NOTE(review): the three original lines here ("Spaces:" / "Runtime error" x2)
# were paste artifacts from a Hugging Face Space status page, not code.
from typing import Any, Dict, Iterator, List, Optional

import requests
from huggingface_hub import add_collection_item, create_collection
from tqdm.auto import tqdm
class DatasetSearchClient:
    """Thin client for the librarian-bots dataset-column-search API.

    Wraps the `/search` endpoint of the Space and transparently walks
    through all result pages.
    """

    def __init__(
        self,
        base_url: str = "https://librarian-bots-dataset-column-search-api.hf.space",
    ):
        # Base URL of the search API; endpoint paths are appended to it.
        self.base_url = base_url

    def search(
        self,
        columns: List[str],
        match_all: bool = False,
        page_size: int = 100,
        timeout: float = 30.0,
    ) -> Iterator[Dict[str, Any]]:
        """
        Search datasets using the provided API, automatically handling pagination.

        Args:
            columns (List[str]): List of column names to search for.
            match_all (bool, optional): If True, match all columns. If False,
                match any column. Defaults to False.
            page_size (int, optional): Number of results per page. Defaults to 100.
            timeout (float, optional): Per-request timeout in seconds. Defaults
                to 30.0. Added because `requests` has no default timeout, so a
                stalled server would otherwise hang this generator forever.

        Yields:
            Dict[str, Any]: Each dataset result from all pages.

        Raises:
            requests.RequestException: If there's an error with the HTTP request.
            ValueError: If the API returns an unexpected response format.
        """
        page = 1
        total_results = None
        # First iteration always runs (total unknown); afterwards keep going
        # until the items yielded by previous pages cover the reported total.
        while total_results is None or (page - 1) * page_size < total_results:
            params = {
                "columns": columns,
                # API expects a lowercase string flag, not a Python bool.
                "match_all": str(match_all).lower(),
                "page": page,
                "page_size": page_size,
            }
            try:
                response = requests.get(
                    f"{self.base_url}/search", params=params, timeout=timeout
                )
                response.raise_for_status()
                data = response.json()
                # Validate the envelope before trusting any of its fields.
                if not {"total", "page", "page_size", "results"}.issubset(data.keys()):
                    raise ValueError("Unexpected response format from the API")
                if total_results is None:
                    total_results = data["total"]
                yield from data["results"]
                page += 1
            except requests.RequestException as e:
                raise requests.RequestException(
                    f"Error connecting to the API: {str(e)}"
                ) from e
            except ValueError as e:
                # Also catches invalid-JSON bodies (json.JSONDecodeError
                # subclasses ValueError) and re-wraps our own envelope error.
                raise ValueError(f"Error processing API response: {str(e)}") from e
# Create an instance of the client
# Module-level singleton reused by update_collection_for_dataset below.
client = DatasetSearchClient()
def update_collection_for_dataset(
    collection_name: Optional[str] = None,
    dataset_columns: Optional[List[str]] = None,
    collection_description: Optional[str] = None,
    collection_namespace: Optional[str] = None,
):
    """Create (or reuse) a Hub collection and fill it with matching datasets.

    Searches the column-search API for datasets containing ALL of
    `dataset_columns`, then adds each hit to the collection.

    Args:
        collection_name: Title of the collection to create/reuse.
        dataset_columns: Column names a dataset must contain (match-all).
        collection_description: Description for a newly created collection.
        collection_namespace: Hub namespace (org/user) to create the
            collection under; when None, the caller's default namespace is used.

    Returns:
        str: URL of the collection on the Hugging Face Hub.
    """
    # BUG FIX: the original branched on `not collection_name`, which passed
    # `namespace=` exactly when it shouldn't have (and created a collection
    # titled None when no name was given). The two branches differ only in
    # whether a namespace is supplied, so that is what we must test.
    if collection_namespace is None:
        collection = create_collection(
            collection_name, exists_ok=True, description=collection_description
        )
    else:
        collection = create_collection(
            collection_name,
            exists_ok=True,
            description=collection_description,
            namespace=collection_namespace,
        )
    results = list(
        tqdm(
            client.search(dataset_columns, match_all=True),
            desc="Searching datasets...",
            leave=False,
        )
    )
    for result in tqdm(results, desc="Adding datasets to collection...", leave=False):
        try:
            add_collection_item(
                collection.slug, result["hub_id"], item_type="dataset", exists_ok=True
            )
        except Exception as e:
            # Best-effort: one failed dataset should not abort the whole run.
            print(
                f"Error adding dataset {result['hub_id']} to collection {collection_name}: {str(e)}"
            )
    return f"https://huggingface.co/collections/{collection.slug}"
# Collection specs consumed by update_collection_for_dataset(**spec).
# NOTE(review): this name shadows the stdlib `collections` module; harmless
# here (the module is never imported), but worth renaming if this file grows.
collections = [
    {
        "dataset_columns": ["chosen", "rejected", "prompt"],
        "collection_description": "Datasets suitable for DPO based on having 'chosen', 'rejected', and 'prompt' columns. Created using librarian-bots/dataset-column-search-api",
        "collection_name": "Direct Preference Optimization Datasets",
    },
    {
        "dataset_columns": ["image", "chosen", "rejected"],
        "collection_description": "Datasets suitable for Image Preference Optimization based on having 'image','chosen', and 'rejected' columns",
        "collection_name": "Image Preference Optimization Datasets",
    },
    {
        "collection_name": "Alpaca Style Datasets",
        "dataset_columns": ["instruction", "input", "output"],
        "collection_description": "Datasets which follow the Alpaca Style format based on having 'instruction', 'input', and 'output' columns",
    },
]
| # results = [ | |
| # update_collection_for_dataset(**collection, collection_namespace="librarian-bots") | |
| # for collection in collections | |
| # ] | |
| # print(results) | |