Spaces:
Running
Running
| """ | |
| Input validation utilities for HF EDA MCP Server. | |
| This module provides centralized validation functions for all tool inputs, | |
| ensuring consistent error messages and validation logic across the application. | |
| """ | |
| import re | |
| from typing import Optional, List | |
| from hf_eda_mcp.config import get_config | |
class ValidationError(ValueError):
    """Validation failure that carries optional hints on how to fix the input."""

    def __init__(self, message: str, suggestions: Optional[List[str]] = None):
        """
        Build the error.

        Args:
            message: Description of what failed validation.
            suggestions: Optional list of remediation hints. A falsy value
                (None or an empty list) is stored as a fresh empty list.
        """
        super().__init__(message)
        # Always expose a list so consumers can iterate without a None check.
        if suggestions:
            self.suggestions = suggestions
        else:
            self.suggestions = []
def validate_dataset_id(dataset_id: str) -> str:
    """
    Validate and normalize a HuggingFace dataset identifier.

    Checks run from most specific to most general so the caller gets the most
    helpful message: slash mistakes are diagnosed before the generic format
    check (which would otherwise reject them first with a vaguer message).

    Args:
        dataset_id: Dataset identifier to validate

    Returns:
        Normalized dataset_id (stripped of whitespace)

    Raises:
        ValidationError: If dataset_id is invalid with helpful error message
    """
    if not dataset_id:
        raise ValidationError(
            "dataset_id is required and cannot be empty",
            suggestions=[
                "Provide a valid HuggingFace dataset identifier",
                "Examples: 'imdb', 'squad', 'glue', 'username/dataset-name'",
            ],
        )
    if not isinstance(dataset_id, str):
        raise ValidationError(
            f"dataset_id must be a string, got {type(dataset_id).__name__}",
            suggestions=["Ensure dataset_id is passed as a string value"],
        )

    dataset_id = dataset_id.strip()
    if not dataset_id:
        raise ValidationError(
            "dataset_id cannot be empty or contain only whitespace",
            suggestions=["Provide a non-empty dataset identifier"],
        )

    # Check for common slash mistakes BEFORE the format regex: the regex also
    # rejects these inputs, so running it first would make these more specific
    # messages unreachable.
    if dataset_id.startswith("/") or dataset_id.endswith("/"):
        raise ValidationError(
            f"Invalid dataset_id: '{dataset_id}' - cannot start or end with '/'",
            suggestions=["Remove leading or trailing slashes from the dataset_id"],
        )
    if "//" in dataset_id:
        raise ValidationError(
            f"Invalid dataset_id: '{dataset_id}' - contains consecutive slashes",
            suggestions=["Use single slashes to separate username from dataset name"],
        )

    # Validate format: word characters, hyphens, dots, @ in the first segment,
    # optionally followed by a single '/' and a dataset name.
    pattern = r"^[a-zA-Z0-9][\w\-\.@]*(/[\w\-\.]+)?$"
    if not re.match(pattern, dataset_id):
        raise ValidationError(
            f"Invalid dataset_id format: '{dataset_id}'",
            suggestions=[
                "Dataset IDs should contain only letters, numbers, hyphens, underscores, dots, and slashes",
                "Valid formats: 'dataset-name' or 'username/dataset-name'",
                "Examples: 'imdb', 'squad', 'huggingface/dataset-name'",
            ],
        )

    # Reject very long dataset IDs (likely an error)
    if len(dataset_id) > 100:
        raise ValidationError(
            f"dataset_id is unusually long ({len(dataset_id)} characters)",
            suggestions=[
                "Check if the dataset_id is correct",
                "Dataset IDs are typically shorter than 100 characters",
            ],
        )
    return dataset_id
def validate_config_name(config_name: Optional[str]) -> Optional[str]:
    """
    Check a dataset configuration name and return it in normalized form.

    Args:
        config_name: Configuration name to check; None is allowed

    Returns:
        The name with surrounding whitespace removed, or None when the input
        is None or blank

    Raises:
        ValidationError: If config_name is not a string, has an invalid
            format, or is longer than 50 characters
    """
    if config_name is None:
        return None

    if not isinstance(config_name, str):
        raise ValidationError(
            f"config_name must be a string or None, got {type(config_name).__name__}",
            suggestions=["Pass config_name as a string or omit it for default configuration"],
        )

    cleaned = config_name.strip()
    # A blank name is treated the same as "no configuration given".
    if cleaned == "":
        return None

    # Must start with a letter or digit; the rest may be word characters,
    # hyphens, or dots.
    if re.match(r"^[a-zA-Z0-9][\w\-\.]*$", cleaned) is None:
        raise ValidationError(
            f"Invalid config_name format: '{cleaned}'",
            suggestions=[
                "Configuration names should contain only letters, numbers, hyphens, underscores, and dots",
                "Examples: 'cola', 'sst2', 'plain_text'",
            ],
        )

    name_length = len(cleaned)
    if name_length > 50:
        raise ValidationError(
            f"config_name is unusually long ({name_length} characters)",
            suggestions=[
                "Check if the config_name is correct",
                "Configuration names are typically shorter than 50 characters",
            ],
        )

    return cleaned
def validate_split_name(split: str) -> str:
    """
    Check a dataset split name and return it in normalized form.

    Args:
        split: Split name to check

    Returns:
        The split name stripped of surrounding whitespace and lowercased

    Raises:
        ValidationError: If split is missing, not a string, blank, badly
            formatted, or an uncommon name longer than 20 characters
    """
    if not split:
        raise ValidationError(
            "split is required and cannot be empty",
            suggestions=[
                "Provide a valid split name",
                "Common splits: 'train', 'validation', 'test'",
            ],
        )

    if not isinstance(split, str):
        raise ValidationError(
            f"split must be a string, got {type(split).__name__}",
            suggestions=["Ensure split is passed as a string value"],
        )

    normalized = split.strip().lower()
    if normalized == "":
        raise ValidationError(
            "split cannot be empty or contain only whitespace",
            suggestions=["Provide a non-empty split name"],
        )

    # Must start with a letter or digit; the rest may be word characters or hyphens.
    if re.match(r"^[a-zA-Z0-9][\w\-]*$", normalized) is None:
        raise ValidationError(
            f"Invalid split name format: '{normalized}'",
            suggestions=[
                "Split names should contain only letters, numbers, hyphens, and underscores",
                "Common splits: 'train', 'validation', 'test', 'dev'",
            ],
        )

    # Custom split names are allowed; only suspiciously long names that are
    # not one of the well-known splits are rejected.
    well_known = {"train", "validation", "test", "dev", "val"}
    split_length = len(normalized)
    if split_length > 20 and normalized not in well_known:
        known_listing = ", ".join(sorted(well_known))
        raise ValidationError(
            f"Unusual split name: '{normalized}' (length: {split_length})",
            suggestions=[
                "Check if the split name is correct",
                f"Common splits are: {known_listing}",
                "Some datasets may have custom split names",
            ],
        )

    return normalized
def validate_sample_size(num_samples: int, parameter_name: str = "num_samples") -> int:
    """
    Validate sample size parameter.

    Integral floats (e.g. ``10.0``) are accepted and converted to ``int``.
    Booleans are rejected even though ``bool`` is a subclass of ``int``,
    because ``True``/``False`` are not meaningful sample counts.

    Args:
        num_samples: Number of samples to validate
        parameter_name: Name of the parameter (for error messages)

    Returns:
        Validated num_samples as an int

    Raises:
        ValidationError: If num_samples is not an integer, is not positive,
            or exceeds the configured maximum sample size
    """
    # bool passes isinstance(..., int), so it must be excluded explicitly.
    if isinstance(num_samples, bool) or not isinstance(num_samples, int):
        # Accept a float that holds an exact integer value.
        if isinstance(num_samples, float) and num_samples.is_integer():
            num_samples = int(num_samples)
        else:
            raise ValidationError(
                f"{parameter_name} must be an integer, got {type(num_samples).__name__}",
                suggestions=[
                    f"Provide {parameter_name} as an integer value",
                    f"Example: {parameter_name}=100",
                ],
            )

    if num_samples <= 0:
        raise ValidationError(
            f"{parameter_name} must be positive, got {num_samples}",
            suggestions=[
                f"Provide a positive integer for {parameter_name}",
                f"Example: {parameter_name}=10 or {parameter_name}=1000",
            ],
        )

    # Upper bound comes from the application configuration.
    max_sample_size = get_config().max_sample_size
    if num_samples > max_sample_size:
        raise ValidationError(
            f"{parameter_name} ({num_samples}) exceeds maximum allowed ({max_sample_size})",
            suggestions=[
                f"Reduce {parameter_name} to {max_sample_size} or less",
                f"Current maximum is configured as {max_sample_size}",
                "For larger samples, consider using streaming or batch processing",
            ],
        )

    return num_samples
def validate_indices(indices: List[int]) -> List[int]:
    """
    Validate a list of indices for sampling.

    The type check runs before the emptiness check so that non-list inputs
    (including ones whose truth value is ambiguous or falsy) get the
    "must be a list" message instead of a misleading or unrelated error.
    Booleans are rejected as indices even though ``bool`` is a subclass of
    ``int``.

    Args:
        indices: List of indices to validate

    Returns:
        Validated indices list (the same list object that was passed in)

    Raises:
        ValidationError: If indices is missing, not a list, empty, contains a
            non-integer or negative value, or has more entries than the
            configured maximum sample size
    """
    if indices is None:
        raise ValidationError(
            "indices list is required and cannot be empty",
            suggestions=[
                "Provide a non-empty list of indices",
                "Example: indices=[0, 1, 2, 10, 20]",
            ],
        )

    if not isinstance(indices, list):
        raise ValidationError(
            f"indices must be a list, got {type(indices).__name__}",
            suggestions=[
                "Provide indices as a list of integers",
                "Example: indices=[0, 1, 2]",
            ],
        )

    if not indices:
        raise ValidationError(
            "indices list is required and cannot be empty",
            suggestions=[
                "Provide a non-empty list of indices",
                "Example: indices=[0, 1, 2, 10, 20]",
            ],
        )

    # Validate each index
    for position, idx in enumerate(indices):
        # bool passes isinstance(..., int), so it must be excluded explicitly.
        if isinstance(idx, bool) or not isinstance(idx, int):
            raise ValidationError(
                f"All indices must be integers, got {type(idx).__name__} at position {position}",
                suggestions=[
                    "Ensure all indices are integer values",
                    "Example: indices=[0, 1, 2] (not [0.5, 1.2])",
                ],
            )
        if idx < 0:
            raise ValidationError(
                f"All indices must be non-negative, got {idx} at position {position}",
                suggestions=[
                    "Provide only non-negative indices (0 or greater)",
                    "Example: indices=[0, 1, 2, 10]",
                ],
            )

    # The number of requested indices is capped by the configured maximum.
    max_sample_size = get_config().max_sample_size
    if len(indices) > max_sample_size:
        raise ValidationError(
            f"Too many indices requested ({len(indices)}), maximum is {max_sample_size}",
            suggestions=[
                f"Reduce the number of indices to {max_sample_size} or less",
                "Consider using regular sampling instead of specific indices",
            ],
        )

    return indices
def format_validation_error(error: ValidationError) -> str:
    """
    Format a validation error with suggestions into a user-friendly message.

    Args:
        error: ValidationError to format

    Returns:
        The error message; when the error has suggestions, a blank line, a
        "Suggestions:" header and one " - " bullet line per suggestion follow
    """
    message = str(error)
    if not error.suggestions:
        return message

    # Build all bullet lines in one join instead of repeated string +=.
    bullet_lines = "".join(f"\n - {suggestion}" for suggestion in error.suggestions)
    return f"{message}\n\nSuggestions:{bullet_lines}"