Spaces:
Running
Running
| """Search tools for finding datasets on Hugging Face Hub.""" | |
| from typing import Optional, List | |
| from utils.hf_client import get_client | |
| from utils.formatting import format_dataset_list | |
| from huggingface_hub import list_datasets | |
def search_datasets(
    query: str,
    limit: int = 10,
    filter_task: Optional[str] = None,
    sort_by: str = "downloads"
) -> str:
    """
    Search the Hugging Face Hub for datasets by keyword, ML task, or domain.

    Reach for this tool to discover datasets that fit given criteria; the query
    is matched against dataset names/descriptions and can be narrowed by task.

    Args:
        query: Free-text search string (e.g., "sentiment analysis", "image classification", "medical")
        limit: Upper bound on the number of results (clamped to 1-50, default: 10)
        filter_task: Optional ML task category to restrict results (e.g., "text-classification",
            "image-classification", "question-answering", "summarization", "translation")
        sort_by: Ordering of results: "downloads", "likes", or "created" (default: "downloads")

    Returns:
        Formatted list of matching datasets with their IDs, download counts, and tags.

    Example queries:
        - "sentiment" - Find sentiment analysis datasets
        - "medical imaging" - Find medical image datasets
        - "french translation" - Find French translation datasets
    """
    # Cap from above first, then from below, so the request stays in [1, 50].
    upper_capped = min(50, limit)
    effective_limit = max(1, upper_capped)

    hub = get_client()
    found = hub.search_datasets(
        query=query,
        limit=effective_limit,
        filter_task=filter_task,
        sort=sort_by,
    )
    return format_dataset_list(found)
def _match_schema_columns(schema: dict, wanted_lower: List[str]) -> tuple:
    """Count how many requested column names appear in a dataset schema.

    Args:
        schema: Schema mapping returned by ``client.get_schema``. A schema that
            contains an ``"error"`` key is treated as having no usable columns.
        wanted_lower: Requested column names, already lower-cased.

    Returns:
        ``(matched, total)``: how many requested names are present in the schema
        (case-insensitive) and the schema's total number of columns.
    """
    if "error" in schema:
        return 0, 0
    columns = schema.get("columns", [])
    present = {column.lower() for column in columns}  # set: O(1) membership tests
    matched = sum(1 for name in wanted_lower if name in present)
    return matched, len(columns)


def _format_column_results(requested: str, requested_count: int, matching: List[dict]) -> str:
    """Render matched datasets as a Markdown list.

    Args:
        requested: Human-readable, comma-joined list of requested column names.
        requested_count: Number of requested column names (denominator of the match ratio).
        matching: Dataset dicts carrying ``id``, ``matched_columns``, ``total_columns``
            and, optionally, ``downloads``.

    Returns:
        A Markdown string with one section per dataset.
    """
    lines = [f"## Datasets with columns: {requested}\n"]
    for rank, ds in enumerate(matching, 1):
        downloads = ds.get("downloads")
        # The ',' thousands-separator format spec is only valid for numbers;
        # applying it to a missing value ("N/A") or None would raise.
        downloads_text = f"{downloads:,}" if isinstance(downloads, (int, float)) else "N/A"
        lines.append(f"### {rank}. {ds['id']}")
        lines.append(f"- Matched columns: {ds.get('matched_columns', 'N/A')}/{requested_count}")
        lines.append(f"- Total columns: {ds.get('total_columns', 'N/A')}")
        lines.append(f"- Downloads: {downloads_text}")
        lines.append("")
    return "\n".join(lines)


def search_by_columns(
    column_names: List[str],
    data_types: Optional[List[str]] = None,
    limit: int = 10
) -> str:
    """
    Find datasets that contain specific column names.

    Use this tool when you need datasets with particular features or columns,
    such as finding all datasets with a "label" or "image" column. Candidates
    come from a keyword search on the column names; each candidate's schema is
    then checked (case-insensitively) and datasets matching at least one
    requested column are kept.

    Args:
        column_names: List of column names to search for (e.g., ["text", "label"], ["image", "caption"]).
            A comma-separated string such as "text, label" is also accepted.
        data_types: Optional list of data types (e.g., ["Image", "Audio", "ClassLabel"]).
            Currently accepted for compatibility but not yet applied as a filter.
        limit: Maximum number of results to return (1-30, default: 10)

    Returns:
        Formatted list of datasets containing the specified columns, or a
        "No datasets found" message when nothing matches.

    Common column patterns:
        - Text classification: ["text", "label"]
        - Image classification: ["image", "label"]
        - Question answering: ["question", "answer", "context"]
        - Translation: ["source", "target"] or ["en", "fr"]
    """
    # TODO: apply `data_types` once the schema returned by get_schema exposes
    # column types; it is intentionally ignored for now.
    if isinstance(column_names, str):
        # Joining or iterating a bare str would treat every character as a column.
        column_names = [part.strip() for part in column_names.split(",") if part.strip()]
    column_names = list(column_names or [])

    requested = ", ".join(column_names)
    no_match_message = (
        f"No datasets found with columns matching: {requested}\n\n"
        "Try broader search terms or check column naming conventions."
    )
    if not column_names:
        # Nothing can match; skip the Hub search and per-dataset schema calls.
        return no_match_message

    limit = max(1, min(30, limit))
    wanted_lower = [name.lower() for name in column_names]

    client = get_client()
    # Over-fetch: many keyword hits will not actually contain the columns.
    candidates = client.search_datasets(query=" ".join(column_names), limit=limit * 3)

    matching_datasets = []
    for ds in candidates:
        if len(matching_datasets) >= limit:
            break
        try:
            schema = client.get_schema(ds["id"])
            matched, total = _match_schema_columns(schema, wanted_lower)
        except Exception:
            # Best effort: datasets whose schema cannot be fetched or parsed are skipped.
            continue
        if matched > 0:
            # Copy rather than mutate the client's dict, which may be shared/cached.
            matching_datasets.append({**ds, "matched_columns": matched, "total_columns": total})

    if not matching_datasets:
        return no_match_message
    return _format_column_results(requested, len(column_names), matching_datasets)