File size: 11,231 Bytes
11df203
2762e2a
11df203
2762e2a
 
11df203
 
2762e2a
2b910cc
 
ab96cfe
2b910cc
2a623ac
 
 
 
 
 
 
 
 
 
2762e2a
 
 
ab96cfe
2762e2a
ab96cfe
2762e2a
 
 
 
 
 
 
ab96cfe
c2830c1
2762e2a
 
 
ab96cfe
2762e2a
 
 
ab96cfe
2762e2a
 
 
 
 
 
c2830c1
ab96cfe
2762e2a
 
 
 
 
 
 
 
 
 
ab96cfe
2762e2a
 
 
 
 
ab96cfe
2762e2a
 
 
 
 
 
ab96cfe
2762e2a
ab96cfe
2762e2a
 
 
aefe0b6
 
 
 
2a623ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab96cfe
 
 
 
 
2762e2a
 
2b910cc
2762e2a
 
 
 
 
ab96cfe
2762e2a
ab96cfe
2762e2a
ab96cfe
 
 
 
 
 
2762e2a
ab96cfe
2762e2a
ab96cfe
 
 
 
 
 
 
2762e2a
ab96cfe
 
 
2762e2a
ab96cfe
2762e2a
ab96cfe
 
 
 
 
2762e2a
ab96cfe
2a623ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2762e2a
2a623ac
2762e2a
2a623ac
 
2762e2a
 
2b910cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aefe0b6
2b910cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2762e2a
 
 
 
 
ab96cfe
2762e2a
 
 
ab96cfe
 
2762e2a
ab96cfe
2762e2a
ab96cfe
 
 
2762e2a
 
 
 
ab96cfe
2762e2a
ab96cfe
2762e2a
 
ab96cfe
2762e2a
ab96cfe
 
 
2762e2a
ab96cfe
2762e2a
ab96cfe
2762e2a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
"""
Dataset sampling tool for retrieving samples from HuggingFace datasets.

This module provides tools for efficiently sampling data from HuggingFace datasets
with support for different splits, configurable sample sizes, and streaming for large datasets.
"""

import logging
import gradio as gr
from typing import Optional, Dict, Any
from hf_eda_mcp.config import get_config
from hf_eda_mcp.services.dataset_service import get_dataset_service, DatasetServiceError
from hf_eda_mcp.integrations.hf_client import DatasetNotFoundError, AuthenticationError, NetworkError
from hf_eda_mcp.validation import (
    validate_dataset_id,
    validate_config_name,
    validate_split_name,
    validate_sample_size,
    ValidationError,
    format_validation_error,
)
from hf_eda_mcp.error_handling import format_error_response, log_error_with_context

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Default constants (can be overridden by config)
# Fallback sample count used when the caller omits num_samples.
DEFAULT_SAMPLE_SIZE = 10
# NOTE(review): VALID_SPLITS is not referenced anywhere in this module's visible
# code — split validation is delegated to validate_split_name(). Confirm whether
# external callers import it before removing.
VALID_SPLITS = {"train", "validation", "test", "dev", "val"}


def get_dataset_sample(
    dataset_id: str,
    split: str = "train",
    num_samples: int = DEFAULT_SAMPLE_SIZE,
    config_name: Optional[str] = None,
    streaming: bool = True,
    hf_api_token: gr.Header = "",
) -> Dict[str, Any]:
    """
    Retrieve a sample of rows from a HuggingFace dataset.

    This function efficiently samples data from datasets with support for different
    splits and configurable sample sizes. It uses streaming by default for large
    datasets to minimize memory usage and loading time.

    Args:
        dataset_id: HuggingFace dataset identifier (e.g., 'imdb', 'squad', 'glue')
        split: Dataset split to sample from (default: 'train')
        num_samples: Number of samples to retrieve (default: 10, max: 10000)
        config_name: Optional configuration name for multi-config datasets
        streaming: Whether to use streaming mode for efficient loading (default: True)
        hf_api_token: Header parsed by Gradio when hf_api_token is provided in MCP configuration headers

    Returns:
        Dictionary containing sampled data and metadata:
        - dataset_id: Original dataset identifier
        - config_name: Configuration name used (if any)
        - split: Split name sampled from
        - num_samples: Actual number of samples returned
        - requested_samples: Number of samples originally requested
        - data: List of sample dictionaries
        - schema: Dictionary describing the dataset features/columns
        - sample_info: Additional information about the sampling process

    Raises:
        ValueError: If inputs are invalid (empty dataset_id, invalid split, etc.)
        DatasetNotFoundError: If dataset or split doesn't exist
        AuthenticationError: If dataset is private and authentication fails
        DatasetServiceError: If sampling fails for other reasons

    Example:
        >>> # Basic sampling
        >>> sample = get_dataset_sample("imdb", split="train", num_samples=5)
        >>> print(f"Got {sample['num_samples']} samples from {sample['dataset_id']}")
        >>> for i, row in enumerate(sample['data']):
        ...     print(f"Sample {i+1}: {list(row.keys())}")

        >>> # Multi-config dataset sampling
        >>> sample = get_dataset_sample("glue", split="validation",
        ...                           num_samples=3, config_name="cola")
        >>> print(f"Schema: {sample['schema']}")
    """
    # Handle empty strings from Gradio (convert to None)
    if config_name == "":
        config_name = None

    # Input validation using centralized validation
    try:
        dataset_id = validate_dataset_id(dataset_id)
        config_name = validate_config_name(config_name)
        split = validate_split_name(split)
        num_samples = validate_sample_size(num_samples, "num_samples")
    except ValidationError as e:
        # Format once, reuse for both the log record and the raised error;
        # chain the original ValidationError so the cause is preserved.
        message = format_validation_error(e)
        logger.error("Validation error: %s", message)
        raise ValueError(message) from e

    # Shared context for structured error logging/reporting below.
    context = {
        "dataset_id": dataset_id,
        "split": split,
        "num_samples": num_samples,
        "config_name": config_name,
        "operation": "get_dataset_sample"
    }

    logger.info(
        "Sampling %s rows from dataset: %s, split: %s%s",
        num_samples,
        dataset_id,
        split,
        f", config: {config_name}" if config_name else "",
    )

    try:
        # Get dataset service and load sample
        service = get_dataset_service(hf_api_token=hf_api_token)
        sample_data = service.load_dataset_sample(
            dataset_id=dataset_id,
            split=split,
            num_samples=num_samples,
            config_name=config_name,
            streaming=streaming,
        )

        # Enhance the response with additional metadata
        config = get_config()
        sample_data["sample_info"] = {
            "streaming_used": streaming,
            "sampling_strategy": "sequential_head",  # We take first N samples
            "max_sample_size": config.max_sample_size,
            "truncated": sample_data["num_samples"] < sample_data["requested_samples"],
        }

        # Add data preview information (column names and per-column Python types
        # of the first row, when rows are dicts)
        if sample_data["data"]:
            first_sample = sample_data["data"][0]
            sample_data["sample_info"]["preview"] = {
                "columns": list(first_sample.keys())
                if isinstance(first_sample, dict)
                else [],
                "first_sample_types": {
                    k: type(v).__name__ for k, v in first_sample.items()
                }
                if isinstance(first_sample, dict)
                else {},
            }

        # Add human-readable summary
        sample_data["summary"] = _generate_sample_summary(sample_data)

        logger.info(
            "Successfully sampled %s rows from %s",
            sample_data["num_samples"],
            dataset_id,
        )
        return sample_data

    except DatasetNotFoundError as e:
        # Known-missing dataset/split: log at WARNING, surface suggestions, re-raise.
        log_error_with_context(e, context, level=logging.WARNING)
        error_response = format_error_response(e, context)
        logger.info(f"Dataset/split not found suggestions: {error_response.get('suggestions', [])}")
        raise

    except AuthenticationError as e:
        log_error_with_context(e, context, level=logging.WARNING)
        error_response = format_error_response(e, context)
        logger.info(f"Authentication error guidance: {error_response.get('suggestions', [])}")
        raise

    except NetworkError as e:
        log_error_with_context(e, context)
        error_response = format_error_response(e, context)
        logger.info(f"Network error guidance: {error_response.get('suggestions', [])}")
        raise

    except Exception as e:
        # Anything unexpected is wrapped in the service's error type, with the
        # original exception chained for debugging.
        log_error_with_context(e, context)
        raise DatasetServiceError(f"Failed to sample dataset: {str(e)}") from e


# def get_dataset_sample_with_indices(
#     dataset_id: str,
#     indices: List[int],
#     split: str = "train",
#     config_name: Optional[str] = None,
# ) -> Dict[str, Any]:
#     """
#     Retrieve specific samples by their indices from a HuggingFace dataset.

#     This function allows for targeted sampling by specifying exact row indices.
#     Note: This requires loading the dataset in non-streaming mode.

#     Args:
#         dataset_id: HuggingFace dataset identifier
#         indices: List of row indices to retrieve
#         split: Dataset split to sample from (default: 'train')
#         config_name: Optional configuration name for multi-config datasets

#     Returns:
#         Dictionary containing the requested samples and metadata

#     Raises:
#         ValueError: If inputs are invalid
#         DatasetServiceError: If sampling fails
#     """
#     # Handle empty strings from Gradio (convert to None)
#     if config_name == "":
#         config_name = None
    
#     # Input validation using centralized validation
#     try:
#         dataset_id = validate_dataset_id(dataset_id)
#         config_name = validate_config_name(config_name)
#         split = validate_split_name(split)
#         indices = validate_indices(indices)
#     except ValidationError as e:
#         logger.error(f"Validation error: {format_validation_error(e)}")
#         raise ValueError(format_validation_error(e))

#     logger.info(f"Sampling {len(indices)} specific indices from dataset: {dataset_id}")

#     try:
#         from datasets import load_dataset

#         # Load dataset without streaming to access by index
#         dataset = load_dataset(
#             dataset_id, name=config_name, split=split, streaming=False
#         )

#         # Validate indices are within bounds
#         max_index = max(indices)
#         if max_index >= len(dataset):
#             raise ValueError(
#                 f"Index {max_index} is out of bounds for dataset with {len(dataset)} rows"
#             )

#         # Get samples by indices
#         samples = [dataset[i] for i in indices]

#         # Get dataset info for schema
#         service = get_dataset_service(hf_api_token=hf_api_token)
#         dataset_info = service.load_dataset_info(dataset_id, config_name)

#         # Prepare response
#         sample_data = {
#             "dataset_id": dataset_id,
#             "config_name": config_name,
#             "split": split,
#             "num_samples": len(samples),
#             "requested_indices": indices,
#             "data": samples,
#             "schema": dataset_info.get("features", {}),
#             "sample_info": {
#                 "sampling_strategy": "by_indices",
#                 "streaming_used": False,
#                 "indices_requested": len(indices),
#             },
#         }

#         sample_data["summary"] = _generate_sample_summary(sample_data)

#         return sample_data

#     except Exception as e:
#         logger.error(f"Failed to sample by indices from {dataset_id}: {str(e)}")
#         raise DatasetServiceError(f"Failed to sample by indices: {str(e)}")


def _generate_sample_summary(sample_data: Dict[str, Any]) -> str:
    """Generate a human-readable summary of the sample data."""
    summary_parts = []

    # Basic info
    summary_parts.append(f"Dataset: {sample_data.get('dataset_id', 'Unknown')}")
    summary_parts.append(f"Split: {sample_data.get('split', 'Unknown')}")

    if sample_data.get("config_name"):
        summary_parts.append(f"Config: {sample_data['config_name']}")

    # Sample info
    num_samples = sample_data.get("num_samples", 0)
    requested = sample_data.get("requested_samples", num_samples)

    if num_samples == requested:
        summary_parts.append(f"Samples: {num_samples}")
    else:
        summary_parts.append(f"Samples: {num_samples}/{requested} (truncated)")

    # Schema info
    schema = sample_data.get("schema", {})
    if schema:
        summary_parts.append(f"Columns: {len(schema)}")

    # Sampling strategy
    sample_info = sample_data.get("sample_info", {})
    strategy = sample_info.get("sampling_strategy", "unknown")
    if strategy == "by_indices":
        summary_parts.append("Strategy: by indices")
    elif strategy == "sequential_head":
        summary_parts.append("Strategy: first N rows")

    return " | ".join(summary_parts)