"""
Dataset sampling tool for retrieving samples from HuggingFace datasets.
This module provides tools for efficiently sampling data from HuggingFace datasets
with support for different splits, configurable sample sizes, and streaming for large datasets.
"""
import logging
from typing import Any, Dict, Optional

import gradio as gr
from hf_eda_mcp.config import get_config
from hf_eda_mcp.services.dataset_service import get_dataset_service, DatasetServiceError
from hf_eda_mcp.integrations.hf_client import DatasetNotFoundError, AuthenticationError, NetworkError
from hf_eda_mcp.validation import (
validate_dataset_id,
validate_config_name,
validate_split_name,
validate_sample_size,
ValidationError,
format_validation_error,
)
from hf_eda_mcp.error_handling import format_error_response, log_error_with_context
logger = logging.getLogger(__name__)
# Default constants (can be overridden by config)
DEFAULT_SAMPLE_SIZE = 10
VALID_SPLITS = {"train", "validation", "test", "dev", "val"}
def get_dataset_sample(
dataset_id: str,
split: str = "train",
num_samples: int = DEFAULT_SAMPLE_SIZE,
config_name: Optional[str] = None,
streaming: bool = True,
hf_api_token: gr.Header = "",
) -> Dict[str, Any]:
"""
Retrieve a sample of rows from a HuggingFace dataset.
This function efficiently samples data from datasets with support for different
splits and configurable sample sizes. It uses streaming by default for large
datasets to minimize memory usage and loading time.
Args:
dataset_id: HuggingFace dataset identifier (e.g., 'imdb', 'squad', 'glue')
split: Dataset split to sample from (default: 'train')
num_samples: Number of samples to retrieve (default: 10, max: 10000)
config_name: Optional configuration name for multi-config datasets
streaming: Whether to use streaming mode for efficient loading (default: True)
        hf_api_token: HF API token parsed by Gradio from the "hf_api_token" request header when the MCP client supplies it in its configuration headers
Returns:
Dictionary containing sampled data and metadata:
- dataset_id: Original dataset identifier
- config_name: Configuration name used (if any)
- split: Split name sampled from
- num_samples: Actual number of samples returned
- requested_samples: Number of samples originally requested
- data: List of sample dictionaries
- schema: Dictionary describing the dataset features/columns
- sample_info: Additional information about the sampling process
Raises:
ValueError: If inputs are invalid (empty dataset_id, invalid split, etc.)
DatasetNotFoundError: If dataset or split doesn't exist
AuthenticationError: If dataset is private and authentication fails
DatasetServiceError: If sampling fails for other reasons
Example:
>>> # Basic sampling
>>> sample = get_dataset_sample("imdb", split="train", num_samples=5)
>>> print(f"Got {sample['num_samples']} samples from {sample['dataset_id']}")
>>> for i, row in enumerate(sample['data']):
... print(f"Sample {i+1}: {list(row.keys())}")
>>> # Multi-config dataset sampling
>>> sample = get_dataset_sample("glue", split="validation",
... num_samples=3, config_name="cola")
>>> print(f"Schema: {sample['schema']}")
"""
# Handle empty strings from Gradio (convert to None)
if config_name == "":
config_name = None
# Input validation using centralized validation
try:
dataset_id = validate_dataset_id(dataset_id)
config_name = validate_config_name(config_name)
split = validate_split_name(split)
num_samples = validate_sample_size(num_samples, "num_samples")
    except ValidationError as e:
        logger.error(f"Validation error: {format_validation_error(e)}")
        raise ValueError(format_validation_error(e)) from e
context = {
"dataset_id": dataset_id,
"split": split,
"num_samples": num_samples,
"config_name": config_name,
"operation": "get_dataset_sample"
}
logger.info(
f"Sampling {num_samples} rows from dataset: {dataset_id}, "
f"split: {split}" + (f", config: {config_name}" if config_name else "")
)
try:
# Get dataset service and load sample
service = get_dataset_service(hf_api_token=hf_api_token)
sample_data = service.load_dataset_sample(
dataset_id=dataset_id,
split=split,
num_samples=num_samples,
config_name=config_name,
streaming=streaming,
)
# Enhance the response with additional metadata
config = get_config()
sample_data["sample_info"] = {
"streaming_used": streaming,
"sampling_strategy": "sequential_head", # We take first N samples
"max_sample_size": config.max_sample_size,
"truncated": sample_data["num_samples"] < sample_data["requested_samples"],
}
# Add data preview information
if sample_data["data"]:
first_sample = sample_data["data"][0]
sample_data["sample_info"]["preview"] = {
"columns": list(first_sample.keys())
if isinstance(first_sample, dict)
else [],
"first_sample_types": {
k: type(v).__name__ for k, v in first_sample.items()
}
if isinstance(first_sample, dict)
else {},
}
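        # Illustrative shape of the preview block for a hypothetical
        # text-classification dataset (example values only, not guaranteed):
        #   {"columns": ["text", "label"],
        #    "first_sample_types": {"text": "str", "label": "int"}}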
# Add summary
sample_data["summary"] = _generate_sample_summary(sample_data)
logger.info(
f"Successfully sampled {sample_data['num_samples']} rows from {dataset_id}"
)
return sample_data
except DatasetNotFoundError as e:
log_error_with_context(e, context, level=logging.WARNING)
error_response = format_error_response(e, context)
logger.info(f"Dataset/split not found suggestions: {error_response.get('suggestions', [])}")
raise
except AuthenticationError as e:
log_error_with_context(e, context, level=logging.WARNING)
error_response = format_error_response(e, context)
logger.info(f"Authentication error guidance: {error_response.get('suggestions', [])}")
raise
except NetworkError as e:
log_error_with_context(e, context)
error_response = format_error_response(e, context)
logger.info(f"Network error guidance: {error_response.get('suggestions', [])}")
raise
except Exception as e:
log_error_with_context(e, context)
raise DatasetServiceError(f"Failed to sample dataset: {str(e)}") from e
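# NOTE: If the disabled indexed-sampling variant below is re-enabled, it will
# also need `List` from typing, a `validate_indices` helper (assumed to live in
# hf_eda_mcp.validation), and an hf_api_token parameter threaded through to
# get_dataset_service().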
# def get_dataset_sample_with_indices(
# dataset_id: str,
# indices: List[int],
# split: str = "train",
# config_name: Optional[str] = None,
# ) -> Dict[str, Any]:
# """
# Retrieve specific samples by their indices from a HuggingFace dataset.
# This function allows for targeted sampling by specifying exact row indices.
# Note: This requires loading the dataset in non-streaming mode.
# Args:
# dataset_id: HuggingFace dataset identifier
# indices: List of row indices to retrieve
# split: Dataset split to sample from (default: 'train')
# config_name: Optional configuration name for multi-config datasets
# Returns:
# Dictionary containing the requested samples and metadata
# Raises:
# ValueError: If inputs are invalid
# DatasetServiceError: If sampling fails
# """
# # Handle empty strings from Gradio (convert to None)
# if config_name == "":
# config_name = None
# # Input validation using centralized validation
# try:
# dataset_id = validate_dataset_id(dataset_id)
# config_name = validate_config_name(config_name)
# split = validate_split_name(split)
# indices = validate_indices(indices)
# except ValidationError as e:
# logger.error(f"Validation error: {format_validation_error(e)}")
# raise ValueError(format_validation_error(e))
# logger.info(f"Sampling {len(indices)} specific indices from dataset: {dataset_id}")
# try:
# from datasets import load_dataset
# # Load dataset without streaming to access by index
# dataset = load_dataset(
# dataset_id, name=config_name, split=split, streaming=False
# )
# # Validate indices are within bounds
# max_index = max(indices)
# if max_index >= len(dataset):
# raise ValueError(
# f"Index {max_index} is out of bounds for dataset with {len(dataset)} rows"
# )
# # Get samples by indices
# samples = [dataset[i] for i in indices]
# # Get dataset info for schema
#         # No hf_api_token parameter is in scope in this disabled function;
#         # call without a token (or add the parameter when re-enabling)
#         service = get_dataset_service()
# dataset_info = service.load_dataset_info(dataset_id, config_name)
# # Prepare response
# sample_data = {
# "dataset_id": dataset_id,
# "config_name": config_name,
# "split": split,
# "num_samples": len(samples),
# "requested_indices": indices,
# "data": samples,
# "schema": dataset_info.get("features", {}),
# "sample_info": {
# "sampling_strategy": "by_indices",
# "streaming_used": False,
# "indices_requested": len(indices),
# },
# }
# sample_data["summary"] = _generate_sample_summary(sample_data)
# return sample_data
# except Exception as e:
# logger.error(f"Failed to sample by indices from {dataset_id}: {str(e)}")
#         raise DatasetServiceError(f"Failed to sample by indices: {str(e)}") from e
def _generate_sample_summary(sample_data: Dict[str, Any]) -> str:
"""Generate a human-readable summary of the sample data."""
summary_parts = []
# Basic info
summary_parts.append(f"Dataset: {sample_data.get('dataset_id', 'Unknown')}")
summary_parts.append(f"Split: {sample_data.get('split', 'Unknown')}")
if sample_data.get("config_name"):
summary_parts.append(f"Config: {sample_data['config_name']}")
# Sample info
num_samples = sample_data.get("num_samples", 0)
requested = sample_data.get("requested_samples", num_samples)
if num_samples == requested:
summary_parts.append(f"Samples: {num_samples}")
else:
summary_parts.append(f"Samples: {num_samples}/{requested} (truncated)")
# Schema info
schema = sample_data.get("schema", {})
if schema:
summary_parts.append(f"Columns: {len(schema)}")
# Sampling strategy
sample_info = sample_data.get("sample_info", {})
strategy = sample_info.get("sampling_strategy", "unknown")
if strategy == "by_indices":
summary_parts.append("Strategy: by indices")
elif strategy == "sequential_head":
summary_parts.append("Strategy: first N rows")
return " | ".join(summary_parts)