""" Dataset sampling tool for retrieving samples from HuggingFace datasets. This module provides tools for efficiently sampling data from HuggingFace datasets with support for different splits, configurable sample sizes, and streaming for large datasets. """ import logging import gradio as gr from typing import Optional, Dict, Any from hf_eda_mcp.config import get_config from hf_eda_mcp.services.dataset_service import get_dataset_service, DatasetServiceError from hf_eda_mcp.integrations.hf_client import DatasetNotFoundError, AuthenticationError, NetworkError from hf_eda_mcp.validation import ( validate_dataset_id, validate_config_name, validate_split_name, validate_sample_size, ValidationError, format_validation_error, ) from hf_eda_mcp.error_handling import format_error_response, log_error_with_context logger = logging.getLogger(__name__) # Default constants (can be overridden by config) DEFAULT_SAMPLE_SIZE = 10 VALID_SPLITS = {"train", "validation", "test", "dev", "val"} def get_dataset_sample( dataset_id: str, split: str = "train", num_samples: int = DEFAULT_SAMPLE_SIZE, config_name: Optional[str] = None, streaming: bool = True, hf_api_token: gr.Header = "", ) -> Dict[str, Any]: """ Retrieve a sample of rows from a HuggingFace dataset. This function efficiently samples data from datasets with support for different splits and configurable sample sizes. It uses streaming by default for large datasets to minimize memory usage and loading time. Args: dataset_id: HuggingFace dataset identifier (e.g., 'imdb', 'squad', 'glue') split: Dataset split to sample from (default: 'train') num_samples: Number of samples to retrieve (default: 10, max: 10000) config_name: Optional configuration name for multi-config datasets streaming: Whether to use streaming mode for efficient loading (default: True) hf_api_token: Header parsed by Gradio when hf_api_token is provided in MCP configuration headers Returns: Dictionary containing sampled data and metadata: - dataset_id: Original dataset identifier - config_name: Configuration name used (if any) - split: Split name sampled from - num_samples: Actual number of samples returned - requested_samples: Number of samples originally requested - data: List of sample dictionaries - schema: Dictionary describing the dataset features/columns - sample_info: Additional information about the sampling process Raises: ValueError: If inputs are invalid (empty dataset_id, invalid split, etc.) DatasetNotFoundError: If dataset or split doesn't exist AuthenticationError: If dataset is private and authentication fails DatasetServiceError: If sampling fails for other reasons Example: >>> # Basic sampling >>> sample = get_dataset_sample("imdb", split="train", num_samples=5) >>> print(f"Got {sample['num_samples']} samples from {sample['dataset_id']}") >>> for i, row in enumerate(sample['data']): ... print(f"Sample {i+1}: {list(row.keys())}") >>> # Multi-config dataset sampling >>> sample = get_dataset_sample("glue", split="validation", ... num_samples=3, config_name="cola") >>> print(f"Schema: {sample['schema']}") """ # Handle empty strings from Gradio (convert to None) if config_name == "": config_name = None # Input validation using centralized validation try: dataset_id = validate_dataset_id(dataset_id) config_name = validate_config_name(config_name) split = validate_split_name(split) num_samples = validate_sample_size(num_samples, "num_samples") except ValidationError as e: logger.error(f"Validation error: {format_validation_error(e)}") raise ValueError(format_validation_error(e)) context = { "dataset_id": dataset_id, "split": split, "num_samples": num_samples, "config_name": config_name, "operation": "get_dataset_sample" } logger.info( f"Sampling {num_samples} rows from dataset: {dataset_id}, " f"split: {split}" + (f", config: {config_name}" if config_name else "") ) try: # Get dataset service and load sample service = get_dataset_service(hf_api_token=hf_api_token) sample_data = service.load_dataset_sample( dataset_id=dataset_id, split=split, num_samples=num_samples, config_name=config_name, streaming=streaming, ) # Enhance the response with additional metadata config = get_config() sample_data["sample_info"] = { "streaming_used": streaming, "sampling_strategy": "sequential_head", # We take first N samples "max_sample_size": config.max_sample_size, "truncated": sample_data["num_samples"] < sample_data["requested_samples"], } # Add data preview information if sample_data["data"]: first_sample = sample_data["data"][0] sample_data["sample_info"]["preview"] = { "columns": list(first_sample.keys()) if isinstance(first_sample, dict) else [], "first_sample_types": { k: type(v).__name__ for k, v in first_sample.items() } if isinstance(first_sample, dict) else {}, } # Add summary sample_data["summary"] = _generate_sample_summary(sample_data) logger.info( f"Successfully sampled {sample_data['num_samples']} rows from {dataset_id}" ) return sample_data except DatasetNotFoundError as e: log_error_with_context(e, context, level=logging.WARNING) error_response = format_error_response(e, context) logger.info(f"Dataset/split not found suggestions: {error_response.get('suggestions', [])}") raise except AuthenticationError as e: log_error_with_context(e, context, level=logging.WARNING) error_response = format_error_response(e, context) logger.info(f"Authentication error guidance: {error_response.get('suggestions', [])}") raise except NetworkError as e: log_error_with_context(e, context) error_response = format_error_response(e, context) logger.info(f"Network error guidance: {error_response.get('suggestions', [])}") raise except Exception as e: log_error_with_context(e, context) raise DatasetServiceError(f"Failed to sample dataset: {str(e)}") from e # def get_dataset_sample_with_indices( # dataset_id: str, # indices: List[int], # split: str = "train", # config_name: Optional[str] = None, # ) -> Dict[str, Any]: # """ # Retrieve specific samples by their indices from a HuggingFace dataset. # This function allows for targeted sampling by specifying exact row indices. # Note: This requires loading the dataset in non-streaming mode. # Args: # dataset_id: HuggingFace dataset identifier # indices: List of row indices to retrieve # split: Dataset split to sample from (default: 'train') # config_name: Optional configuration name for multi-config datasets # Returns: # Dictionary containing the requested samples and metadata # Raises: # ValueError: If inputs are invalid # DatasetServiceError: If sampling fails # """ # # Handle empty strings from Gradio (convert to None) # if config_name == "": # config_name = None # # Input validation using centralized validation # try: # dataset_id = validate_dataset_id(dataset_id) # config_name = validate_config_name(config_name) # split = validate_split_name(split) # indices = validate_indices(indices) # except ValidationError as e: # logger.error(f"Validation error: {format_validation_error(e)}") # raise ValueError(format_validation_error(e)) # logger.info(f"Sampling {len(indices)} specific indices from dataset: {dataset_id}") # try: # from datasets import load_dataset # # Load dataset without streaming to access by index # dataset = load_dataset( # dataset_id, name=config_name, split=split, streaming=False # ) # # Validate indices are within bounds # max_index = max(indices) # if max_index >= len(dataset): # raise ValueError( # f"Index {max_index} is out of bounds for dataset with {len(dataset)} rows" # ) # # Get samples by indices # samples = [dataset[i] for i in indices] # # Get dataset info for schema # service = get_dataset_service(hf_api_token=hf_api_token) # dataset_info = service.load_dataset_info(dataset_id, config_name) # # Prepare response # sample_data = { # "dataset_id": dataset_id, # "config_name": config_name, # "split": split, # "num_samples": len(samples), # "requested_indices": indices, # "data": samples, # "schema": dataset_info.get("features", {}), # "sample_info": { # "sampling_strategy": "by_indices", # "streaming_used": False, # "indices_requested": len(indices), # }, # } # sample_data["summary"] = _generate_sample_summary(sample_data) # return sample_data # except Exception as e: # logger.error(f"Failed to sample by indices from {dataset_id}: {str(e)}") # raise DatasetServiceError(f"Failed to sample by indices: {str(e)}") def _generate_sample_summary(sample_data: Dict[str, Any]) -> str: """Generate a human-readable summary of the sample data.""" summary_parts = [] # Basic info summary_parts.append(f"Dataset: {sample_data.get('dataset_id', 'Unknown')}") summary_parts.append(f"Split: {sample_data.get('split', 'Unknown')}") if sample_data.get("config_name"): summary_parts.append(f"Config: {sample_data['config_name']}") # Sample info num_samples = sample_data.get("num_samples", 0) requested = sample_data.get("requested_samples", num_samples) if num_samples == requested: summary_parts.append(f"Samples: {num_samples}") else: summary_parts.append(f"Samples: {num_samples}/{requested} (truncated)") # Schema info schema = sample_data.get("schema", {}) if schema: summary_parts.append(f"Columns: {len(schema)}") # Sampling strategy sample_info = sample_data.get("sample_info", {}) strategy = sample_info.get("sampling_strategy", "unknown") if strategy == "by_indices": summary_parts.append("Strategy: by indices") elif strategy == "sequential_head": summary_parts.append("Strategy: first N rows") return " | ".join(summary_parts)