File size: 5,391 Bytes
b67578f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""Hugging Face API client wrapper for dataset operations."""

import os
from typing import Optional, List, Dict, Any
from huggingface_hub import HfApi, list_datasets, DatasetCard
from datasets import load_dataset, get_dataset_config_names, get_dataset_split_names
from dotenv import load_dotenv

load_dotenv()


class HFDatasetClient:
    """Client for interacting with Hugging Face datasets.

    Wraps a single authenticated ``HfApi`` instance so that search,
    metadata, and sampling calls all use the same token and therefore
    see any gated/private datasets the token grants access to.
    """

    def __init__(self, token: Optional[str] = None):
        """Create a client.

        Args:
            token: Hugging Face access token. Falls back to the
                ``HF_TOKEN`` environment variable when not provided.
        """
        self.token = token or os.getenv("HF_TOKEN")
        self.api = HfApi(token=self.token)

    def search_datasets(
        self,
        query: str,
        limit: int = 10,
        filter_task: Optional[str] = None,
        sort: str = "downloads"
    ) -> List[Dict[str, Any]]:
        """Search for datasets on Hugging Face Hub.

        Args:
            query: Free-text search string.
            limit: Maximum number of results to return.
            filter_task: Optional task-category filter
                (e.g. ``"text-classification"``).
            sort: Hub sort key (e.g. ``"downloads"``, ``"likes"``).

        Returns:
            A list of lightweight summary dicts, one per matching dataset.
        """
        # Fix: go through the authenticated HfApi instance rather than the
        # module-level list_datasets helper, so the token configured in
        # __init__ is actually applied and accessible gated/private
        # datasets appear in search results.
        datasets = list(self.api.list_datasets(
            search=query,
            limit=limit,
            sort=sort,
            task_categories=filter_task if filter_task else None
        ))

        return [
            {
                "id": ds.id,
                "downloads": ds.downloads,
                "likes": ds.likes,
                # Cap tags to keep each summary small.
                "tags": ds.tags[:5] if ds.tags else [],
                "created_at": str(ds.created_at) if ds.created_at else None,
            }
            for ds in datasets
        ]

    def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]:
        """Get detailed information about a dataset.

        Args:
            dataset_id: Hub dataset identifier (e.g. ``"squad"`` or
                ``"user/name"``).

        Returns:
            A dict of metadata fields; ``card_summary`` holds up to the
            first 2000 characters of the dataset card, or ``None`` if the
            card could not be loaded.
        """
        info = self.api.dataset_info(dataset_id)

        # Best-effort: the card is optional context, so any load failure
        # (missing card, network issue) is deliberately swallowed.
        card_content = None
        try:
            card = DatasetCard.load(dataset_id)
            card_content = card.text[:2000] if card.text else None  # Limit card size
        except Exception:
            pass

        return {
            "id": info.id,
            "author": info.author,
            "downloads": info.downloads,
            "likes": info.likes,
            "tags": info.tags,
            # Not all hub_hub versions expose .license on DatasetInfo.
            "license": getattr(info, 'license', None),
            "created_at": str(info.created_at) if info.created_at else None,
            "last_modified": str(info.last_modified) if info.last_modified else None,
            "card_summary": card_content,
        }

    def get_configs_and_splits(self, dataset_id: str) -> Dict[str, List[str]]:
        """Get available configs and splits for a dataset.

        Returns:
            Mapping of config name -> list of split names. Falls back to
            ``{"default": ...}`` / ``["train"]`` when discovery fails.
        """
        # NOTE(security): trust_remote_code=True executes dataset loading
        # scripts from the Hub; only use with datasets you trust.
        try:
            configs = get_dataset_config_names(dataset_id, trust_remote_code=True)
        except Exception:
            configs = ["default"]

        result = {}
        for config in configs[:3]:  # Limit to first 3 configs
            try:
                splits = get_dataset_split_names(dataset_id, config, trust_remote_code=True)
                result[config] = splits
            except Exception:
                result[config] = ["train"]

        return result

    def load_sample(
        self,
        dataset_id: str,
        config: Optional[str] = None,
        split: str = "train",
        n_rows: int = 5,
        streaming: bool = True
    ) -> List[Dict[str, Any]]:
        """Load a sample of rows from a dataset.

        Args:
            dataset_id: Hub dataset identifier.
            config: Optional config name; ``None`` uses the default.
            split: Split to read from.
            n_rows: Maximum number of rows to return.
            streaming: If True (default), stream rows without downloading
                the full dataset.

        Returns:
            A list of JSON-serializable row dicts, or a single-element
            list ``[{"error": ...}]`` if loading fails.
        """
        try:
            # NOTE(security): trust_remote_code=True executes dataset
            # loading scripts from the Hub; only use with trusted datasets.
            ds = load_dataset(
                dataset_id,
                config,
                split=split,
                streaming=streaming,
                trust_remote_code=True
            )

            if streaming:
                samples = []
                for i, row in enumerate(ds):
                    if i >= n_rows:
                        break
                    # Convert row to serializable format
                    samples.append(self._serialize_row(row))
                return samples
            else:
                # Non-streaming datasets support len() and random access.
                return [self._serialize_row(row) for row in ds.select(range(min(n_rows, len(ds))))]
        except Exception as e:
            # Surface the failure as data rather than raising, so callers
            # always get a list back.
            return [{"error": str(e)}]

    def get_schema(self, dataset_id: str, config: Optional[str] = None, split: str = "train") -> Dict[str, Any]:
        """Get the schema/features of a dataset.

        Returns:
            ``{"columns": [...], "features": {name: str(feature)},
            "num_columns": int}``, or ``{"error": ...}`` on failure.
        """
        try:
            # Streaming avoids downloading data just to inspect features.
            ds = load_dataset(
                dataset_id,
                config,
                split=split,
                streaming=True,
                trust_remote_code=True
            )

            features = ds.features
            schema = {name: str(feature) for name, feature in features.items()}

            return {
                "columns": list(features.keys()),
                "features": schema,
                "num_columns": len(features)
            }
        except Exception as e:
            return {"error": str(e)}

    def _serialize_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a row to JSON-serializable format.

        Heuristics, checked in order:
          * objects with ``tolist`` (e.g. numpy arrays) are converted;
          * objects carrying ``__dict__`` (presumably PIL Images or
            similar rich objects -- a heuristic, not exact) become a
            placeholder with the type name;
          * raw bytes become a length placeholder;
          * everything else passes through unchanged.
        """
        result = {}
        for key, value in row.items():
            if hasattr(value, 'tolist'):  # numpy array
                result[key] = value.tolist()
            elif hasattr(value, '__dict__'):  # PIL Image or similar
                result[key] = f"<{type(value).__name__}>"
            elif isinstance(value, bytes):
                result[key] = f"<bytes: {len(value)} bytes>"
            else:
                result[key] = value
        return result


# Lazily-created, process-wide client shared by all callers.
_client = None

def get_client() -> HFDatasetClient:
    """Return the shared HFDatasetClient, constructing it on first use."""
    global _client
    client = _client
    if client is None:
        # First call: build the client and publish it for later calls.
        client = _client = HFDatasetClient()
    return client