"""Search tools for finding datasets on Hugging Face Hub."""

from typing import Optional, List

from utils.hf_client import get_client
from utils.formatting import format_dataset_list
from huggingface_hub import list_datasets


def search_datasets(
    query: str,
    limit: int = 10,
    filter_task: Optional[str] = None,
    sort_by: str = "downloads"
) -> str:
    """
    Search for datasets on Hugging Face Hub by keyword, task, or domain.

    Use this tool to find datasets matching specific criteria. You can search
    by name, description, or filter by ML task category.

    Args:
        query: Search query string (e.g., "sentiment analysis", "image classification", "medical")
        limit: Maximum number of results to return (1-50, default: 10).
            Values outside that range are clamped rather than rejected.
        filter_task: Optional ML task filter (e.g., "text-classification", "image-classification",
                     "question-answering", "summarization", "translation")
        sort_by: Sort results by "downloads", "likes", or "created" (default: "downloads").
            The value is forwarded to the client unchanged as ``sort``.

    Returns:
        Formatted list of matching datasets with their IDs, download counts, and tags.

    Example queries:
        - "sentiment" - Find sentiment analysis datasets
        - "medical imaging" - Find medical image datasets
        - "french translation" - Find French translation datasets
    """
    limit = max(1, min(50, limit))  # Clamp between 1 and 50

    client = get_client()
    datasets = client.search_datasets(
        query=query,
        limit=limit,
        filter_task=filter_task,
        sort=sort_by
    )

    return format_dataset_list(datasets)


def _format_count(value) -> str:
    """Render a numeric count with thousands separators, or a safe fallback.

    Missing (``None``) values become ``"N/A"``; any other value that does not
    support the ``,`` format option is shown via ``str()`` instead of raising.
    """
    try:
        return f"{value:,}"
    except (TypeError, ValueError):
        return "N/A" if value is None else str(value)


def search_by_columns(
    column_names: List[str],
    data_types: Optional[List[str]] = None,
    limit: int = 10
) -> str:
    """
    Find datasets that contain specific column names.

    Use this tool when you need datasets with particular features or columns,
    such as finding all datasets with a "label" column.

    Args:
        column_names: List of column names to search for (e.g., ["text", "label"], ["image", "caption"]).
            Matching is case-insensitive; a dataset qualifies when at least one
            requested name is present. A single bare string is treated as one name.
        data_types: Optional list of data types (e.g., ["Image", "Audio", "ClassLabel"]).
            Accepted for interface compatibility but currently NOT applied:
            results are filtered by column name only.
        limit: Maximum number of results to return (1-30, default: 10).
            Values outside that range are clamped rather than rejected.

    Returns:
        Markdown-formatted list of datasets containing the specified columns,
        or a "No datasets found" message when nothing matches.

    Common column patterns:
        - Text classification: ["text", "label"]
        - Image classification: ["image", "label"]
        - Question answering: ["question", "answer", "context"]
        - Translation: ["source", "target"] or ["en", "fr"]
    """
    limit = max(1, min(30, limit))

    # A bare string would otherwise be iterated character by character.
    if isinstance(column_names, str):
        column_names = [column_names]
    else:
        column_names = list(column_names)

    requested = ', '.join(column_names)
    no_match_message = (
        f"No datasets found with columns matching: {requested}"
        "\n\nTry broader search terms or check column naming conventions."
    )

    # Nothing to match against: skip the remote search and schema lookups.
    if not column_names:
        return no_match_message

    # TODO: data_types is not used yet; the schema layout for types is not
    # visible from this module, so no type filtering is attempted here.

    # Build search query from column names
    search_query = " ".join(column_names)
    wanted = [name.lower() for name in column_names]

    # Over-fetch (3x) because candidates are dropped by the schema check below.
    client = get_client()
    candidates = client.search_datasets(query=search_query, limit=limit * 3)

    # Filter by actually checking schemas (best effort).
    # Entries are (dataset, matched_count, total_columns); the dataset records
    # returned by the client are left unmodified.
    matching = []
    for ds in candidates:
        if len(matching) >= limit:
            break

        try:
            schema = client.get_schema(ds['id'])
            if "error" in schema:
                continue
            columns = schema.get('columns', [])
            available = {c.lower() for c in columns}
            total_columns = len(columns)
        except Exception:
            # Deliberate best effort: a dataset whose schema cannot be fetched
            # or parsed is skipped so one bad entry does not fail the search.
            continue

        # Check if any requested columns match
        matched = sum(1 for name in wanted if name in available)
        if matched > 0:
            matching.append((ds, matched, total_columns))

    if not matching:
        return no_match_message

    # Format results
    lines = [f"## Datasets with columns: {requested}\n"]
    for i, (ds, matched, total_columns) in enumerate(matching, 1):
        lines.append(f"### {i}. {ds['id']}")
        lines.append(f"- Matched columns: {matched}/{len(column_names)}")
        lines.append(f"- Total columns: {total_columns}")
        # Download counts may be absent or None on some records.
        lines.append(f"- Downloads: {_format_count(ds.get('downloads'))}")
        lines.append("")

    return "\n".join(lines)