File size: 4,291 Bytes
b67578f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Search tools for finding datasets on Hugging Face Hub."""

from typing import Optional, List
from utils.hf_client import get_client
from utils.formatting import format_dataset_list
from huggingface_hub import list_datasets


def search_datasets(
    query: str,
    limit: int = 10,
    filter_task: Optional[str] = None,
    sort_by: str = "downloads"
) -> str:
    """
    Search the Hugging Face Hub for datasets matching a query.

    Results can optionally be restricted to a single ML task category and
    sorted by popularity or recency.

    Args:
        query: Free-text search string (e.g., "sentiment analysis",
            "image classification", "medical").
        limit: Maximum number of results, clamped to the range 1-50
            (default: 10).
        filter_task: Optional task-category filter such as
            "text-classification", "question-answering", "summarization",
            or "translation".
        sort_by: Ordering key — "downloads", "likes", or "created"
            (default: "downloads").

    Returns:
        A formatted listing of matching datasets (IDs, download counts, tags).

    Example queries:
        - "sentiment" - Find sentiment analysis datasets
        - "medical imaging" - Find medical image datasets
        - "french translation" - Find French translation datasets
    """
    # Keep the requested result count inside the supported 1-50 window.
    if limit < 1:
        limit = 1
    elif limit > 50:
        limit = 50

    results = get_client().search_datasets(
        query=query,
        limit=limit,
        filter_task=filter_task,
        sort=sort_by,
    )
    return format_dataset_list(results)


def search_by_columns(
    column_names: List[str],
    data_types: Optional[List[str]] = None,
    limit: int = 10
) -> str:
    """
    Find datasets that contain specific column names or data types.

    Use this tool when you need datasets with particular features or columns,
    such as finding all datasets with a "label" column or "image" type.

    Args:
        column_names: List of column names to search for (e.g., ["text", "label"], ["image", "caption"])
        data_types: Optional list of data types to filter by (e.g., ["Image", "Audio", "ClassLabel"])
        limit: Maximum number of results to return (1-30, default: 10)

    Returns:
        Formatted list of datasets containing the specified columns/types.

    Common column patterns:
        - Text classification: ["text", "label"]
        - Image classification: ["image", "label"]
        - Question answering: ["question", "answer", "context"]
        - Translation: ["source", "target"] or ["en", "fr"]
    """
    # NOTE(review): data_types is accepted but currently not applied to the
    # filtering below — TODO wire it into the schema check or document why not.
    limit = max(1, min(30, limit))  # Clamp between 1 and 30

    # Build a free-text search query from the requested column names.
    search_query = " ".join(column_names)

    client = get_client()
    # Over-fetch (3x) because many candidates will be discarded once their
    # schemas are inspected and found not to contain the requested columns.
    all_datasets = client.search_datasets(query=search_query, limit=limit * 3)

    # Filter by actually checking schemas (best effort).
    matching_datasets = []
    for ds in all_datasets:
        if len(matching_datasets) >= limit:
            break

        try:
            schema = client.get_schema(ds['id'])
            if "error" in schema:
                continue
            columns = schema.get('columns', [])
            columns_lower = [c.lower() for c in columns]

            # Count how many of the requested columns appear (case-insensitive).
            matches = sum(1 for col in column_names if col.lower() in columns_lower)
            if matches > 0:
                ds['matched_columns'] = matches
                ds['total_columns'] = len(columns)
                matching_datasets.append(ds)
        except Exception:
            # Schema lookup is best-effort: skip datasets we cannot inspect.
            continue

    if not matching_datasets:
        return f"No datasets found with columns matching: {', '.join(column_names)}\n\nTry broader search terms or check column naming conventions."

    return _format_column_matches(matching_datasets, column_names)


def _format_column_matches(matching_datasets: List[dict], column_names: List[str]) -> str:
    """Render the matched datasets as a markdown listing.

    Args:
        matching_datasets: Dataset dicts annotated with 'matched_columns'
            and 'total_columns' by search_by_columns.
        column_names: The originally requested column names (for the header
            and the matched/requested ratio).

    Returns:
        Markdown text with one numbered section per dataset.
    """
    lines = [f"## Datasets with columns: {', '.join(column_names)}\n"]
    for i, ds in enumerate(matching_datasets, 1):
        # BUGFIX: applying the ',' format spec to the 'N/A' fallback string
        # raised ValueError; only thousands-format real numeric counts.
        downloads = ds.get('downloads')
        downloads_str = f"{downloads:,}" if isinstance(downloads, int) else "N/A"
        lines.append(f"### {i}. {ds['id']}")
        lines.append(f"- Matched columns: {ds.get('matched_columns', 'N/A')}/{len(column_names)}")
        lines.append(f"- Total columns: {ds.get('total_columns', 'N/A')}")
        lines.append(f"- Downloads: {downloads_str}")
        lines.append("")

    return "\n".join(lines)