""" Input validation utilities for HF EDA MCP Server. This module provides centralized validation functions for all tool inputs, ensuring consistent error messages and validation logic across the application. """ import re from typing import Optional, List from hf_eda_mcp.config import get_config class ValidationError(ValueError): """Custom exception for validation errors with helpful messages.""" def __init__(self, message: str, suggestions: Optional[List[str]] = None): super().__init__(message) self.suggestions = suggestions or [] def validate_dataset_id(dataset_id: str) -> str: """ Validate and normalize a HuggingFace dataset identifier. Args: dataset_id: Dataset identifier to validate Returns: Normalized dataset_id (stripped of whitespace) Raises: ValidationError: If dataset_id is invalid with helpful error message """ if not dataset_id: raise ValidationError( "dataset_id is required and cannot be empty", suggestions=[ "Provide a valid HuggingFace dataset identifier", "Examples: 'imdb', 'squad', 'glue', 'username/dataset-name'", ], ) if not isinstance(dataset_id, str): raise ValidationError( f"dataset_id must be a string, got {type(dataset_id).__name__}", suggestions=["Ensure dataset_id is passed as a string value"], ) dataset_id = dataset_id.strip() if not dataset_id: raise ValidationError( "dataset_id cannot be empty or contain only whitespace", suggestions=["Provide a non-empty dataset identifier"], ) # Validate format: alphanumeric, hyphens, underscores, slashes, dots, @ # Pattern: optional username/ followed by dataset name pattern = r"^[a-zA-Z0-9][\w\-\.@]*(/[\w\-\.]+)?$" if not re.match(pattern, dataset_id): raise ValidationError( f"Invalid dataset_id format: '{dataset_id}'", suggestions=[ "Dataset IDs should contain only letters, numbers, hyphens, underscores, dots, and slashes", "Valid formats: 'dataset-name' or 'username/dataset-name'", "Examples: 'imdb', 'squad', 'huggingface/dataset-name'", ], ) # Check for common mistakes if dataset_id.startswith("/") or dataset_id.endswith("/"): raise ValidationError( f"Invalid dataset_id: '{dataset_id}' - cannot start or end with '/'", suggestions=["Remove leading or trailing slashes from the dataset_id"], ) if "//" in dataset_id: raise ValidationError( f"Invalid dataset_id: '{dataset_id}' - contains consecutive slashes", suggestions=["Use single slashes to separate username from dataset name"], ) # Warn about very long dataset IDs (likely an error) if len(dataset_id) > 100: raise ValidationError( f"dataset_id is unusually long ({len(dataset_id)} characters)", suggestions=[ "Check if the dataset_id is correct", "Dataset IDs are typically shorter than 100 characters", ], ) return dataset_id def validate_config_name(config_name: Optional[str]) -> Optional[str]: """ Validate and normalize a dataset configuration name. Args: config_name: Configuration name to validate (can be None) Returns: Normalized config_name or None Raises: ValidationError: If config_name is invalid """ if config_name is None: return None if not isinstance(config_name, str): raise ValidationError( f"config_name must be a string or None, got {type(config_name).__name__}", suggestions=["Pass config_name as a string or omit it for default configuration"], ) config_name = config_name.strip() if len(config_name) == 0: return None # Validate format: alphanumeric, hyphens, underscores, dots pattern = r"^[a-zA-Z0-9][\w\-\.]*$" if not re.match(pattern, config_name): raise ValidationError( f"Invalid config_name format: '{config_name}'", suggestions=[ "Configuration names should contain only letters, numbers, hyphens, underscores, and dots", "Examples: 'cola', 'sst2', 'plain_text'", ], ) if len(config_name) > 50: raise ValidationError( f"config_name is unusually long ({len(config_name)} characters)", suggestions=[ "Check if the config_name is correct", "Configuration names are typically shorter than 50 characters", ], ) return config_name def validate_split_name(split: str) -> str: """ Validate and normalize a dataset split name. Args: split: Split name to validate Returns: Normalized split name (lowercase, stripped) Raises: ValidationError: If split is invalid """ if not split: raise ValidationError( "split is required and cannot be empty", suggestions=[ "Provide a valid split name", "Common splits: 'train', 'validation', 'test'", ], ) if not isinstance(split, str): raise ValidationError( f"split must be a string, got {type(split).__name__}", suggestions=["Ensure split is passed as a string value"], ) split = split.strip().lower() if not split: raise ValidationError( "split cannot be empty or contain only whitespace", suggestions=["Provide a non-empty split name"], ) # Validate format: alphanumeric, hyphens, underscores pattern = r"^[a-zA-Z0-9][\w\-]*$" if not re.match(pattern, split): raise ValidationError( f"Invalid split name format: '{split}'", suggestions=[ "Split names should contain only letters, numbers, hyphens, and underscores", "Common splits: 'train', 'validation', 'test', 'dev'", ], ) # Note: We don't enforce a specific set of split names as datasets can have custom splits # Common splits for reference common_splits = {"train", "validation", "test", "dev", "val"} if split not in common_splits and len(split) > 20: raise ValidationError( f"Unusual split name: '{split}' (length: {len(split)})", suggestions=[ "Check if the split name is correct", f"Common splits are: {', '.join(sorted(common_splits))}", "Some datasets may have custom split names", ], ) return split def validate_sample_size(num_samples: int, parameter_name: str = "num_samples") -> int: """ Validate sample size parameter. Args: num_samples: Number of samples to validate parameter_name: Name of the parameter (for error messages) Returns: Validated num_samples Raises: ValidationError: If num_samples is invalid """ if not isinstance(num_samples, int): # Check if it's a float that's actually an integer if isinstance(num_samples, float) and num_samples.is_integer(): num_samples = int(num_samples) else: raise ValidationError( f"{parameter_name} must be an integer, got {type(num_samples).__name__}", suggestions=[ f"Provide {parameter_name} as an integer value", "Example: num_samples=100", ], ) if num_samples <= 0: raise ValidationError( f"{parameter_name} must be positive, got {num_samples}", suggestions=[ f"Provide a positive integer for {parameter_name}", "Example: num_samples=10 or num_samples=1000", ], ) # Get max sample size from config config = get_config() max_sample_size = config.max_sample_size if num_samples > max_sample_size: raise ValidationError( f"{parameter_name} ({num_samples}) exceeds maximum allowed ({max_sample_size})", suggestions=[ f"Reduce {parameter_name} to {max_sample_size} or less", f"Current maximum is configured as {max_sample_size}", "For larger samples, consider using streaming or batch processing", ], ) # Warn about very small samples (might not be useful) if num_samples < 5: # This is just a soft warning, not an error pass return num_samples def validate_indices(indices: List[int]) -> List[int]: """ Validate a list of indices for sampling. Args: indices: List of indices to validate Returns: Validated indices list Raises: ValidationError: If indices are invalid """ if not indices: raise ValidationError( "indices list is required and cannot be empty", suggestions=[ "Provide a non-empty list of indices", "Example: indices=[0, 1, 2, 10, 20]", ], ) if not isinstance(indices, list): raise ValidationError( f"indices must be a list, got {type(indices).__name__}", suggestions=[ "Provide indices as a list of integers", "Example: indices=[0, 1, 2]", ], ) # Validate each index for i, idx in enumerate(indices): if not isinstance(idx, int): raise ValidationError( f"All indices must be integers, got {type(idx).__name__} at position {i}", suggestions=[ "Ensure all indices are integer values", "Example: indices=[0, 1, 2] (not [0.5, 1.2])", ], ) if idx < 0: raise ValidationError( f"All indices must be non-negative, got {idx} at position {i}", suggestions=[ "Provide only non-negative indices (0 or greater)", "Example: indices=[0, 1, 2, 10]", ], ) # Check for reasonable list size config = get_config() max_sample_size = config.max_sample_size if len(indices) > max_sample_size: raise ValidationError( f"Too many indices requested ({len(indices)}), maximum is {max_sample_size}", suggestions=[ f"Reduce the number of indices to {max_sample_size} or less", "Consider using regular sampling instead of specific indices", ], ) return indices def format_validation_error(error: ValidationError) -> str: """ Format a validation error with suggestions into a user-friendly message. Args: error: ValidationError to format Returns: Formatted error message with suggestions """ message = str(error) if error.suggestions: message += "\n\nSuggestions:" for suggestion in error.suggestions: message += f"\n - {suggestion}" return message