hf-eda-mcp / src /hf_eda_mcp /validation.py
KhalilGuetari's picture
fix validation when string is empty for config
1b746fa
"""
Input validation utilities for HF EDA MCP Server.
This module provides centralized validation functions for all tool inputs,
ensuring consistent error messages and validation logic across the application.
"""
import re
from typing import Optional, List
from hf_eda_mcp.config import get_config
class ValidationError(ValueError):
    """
    Raised when a tool input fails validation.

    Subclasses ``ValueError`` so generic value-error handlers still catch it.
    Carries an optional list of human-readable hints that
    ``format_validation_error`` can render after the message.

    Attributes:
        suggestions: Hints for fixing the input; always a list (possibly empty).
    """

    def __init__(self, message: str, suggestions: Optional[List[str]] = None):
        super().__init__(message)
        # Missing or empty suggestions are normalised to a fresh empty list.
        self.suggestions = suggestions if suggestions else []
def validate_dataset_id(dataset_id: str) -> str:
    """
    Validate and normalize a HuggingFace dataset identifier.

    Accepts either a bare name (``'imdb'``) or a namespaced identifier
    (``'username/dataset-name'``).

    Args:
        dataset_id: Dataset identifier to validate.

    Returns:
        The dataset_id with surrounding whitespace stripped.

    Raises:
        ValidationError: If dataset_id is missing, not a string, blank,
            has leading/trailing or doubled slashes, contains disallowed
            characters, or is longer than 100 characters.
    """
    if not dataset_id:
        raise ValidationError(
            "dataset_id is required and cannot be empty",
            suggestions=[
                "Provide a valid HuggingFace dataset identifier",
                "Examples: 'imdb', 'squad', 'glue', 'username/dataset-name'",
            ],
        )
    if not isinstance(dataset_id, str):
        raise ValidationError(
            f"dataset_id must be a string, got {type(dataset_id).__name__}",
            suggestions=["Ensure dataset_id is passed as a string value"],
        )

    dataset_id = dataset_id.strip()
    if not dataset_id:
        raise ValidationError(
            "dataset_id cannot be empty or contain only whitespace",
            suggestions=["Provide a non-empty dataset identifier"],
        )

    # Check common slash mistakes BEFORE the general format regex. The regex
    # below can never match these inputs, so when these checks ran afterwards
    # their more specific messages were unreachable.
    if dataset_id.startswith("/") or dataset_id.endswith("/"):
        raise ValidationError(
            f"Invalid dataset_id: '{dataset_id}' - cannot start or end with '/'",
            suggestions=["Remove leading or trailing slashes from the dataset_id"],
        )
    if "//" in dataset_id:
        raise ValidationError(
            f"Invalid dataset_id: '{dataset_id}' - contains consecutive slashes",
            suggestions=["Use single slashes to separate username from dataset name"],
        )

    # Format: an alphanumeric first character, then word characters, hyphens,
    # dots or '@'; optionally followed by a single '/name' segment.
    pattern = r"^[a-zA-Z0-9][\w\-\.@]*(/[\w\-\.]+)?$"
    if not re.match(pattern, dataset_id):
        raise ValidationError(
            f"Invalid dataset_id format: '{dataset_id}'",
            suggestions=[
                "Dataset IDs should contain only letters, numbers, hyphens, underscores, dots, and slashes",
                "Valid formats: 'dataset-name' or 'username/dataset-name'",
                "Examples: 'imdb', 'squad', 'huggingface/dataset-name'",
            ],
        )

    # Very long identifiers are almost certainly a mistake.
    if len(dataset_id) > 100:
        raise ValidationError(
            f"dataset_id is unusually long ({len(dataset_id)} characters)",
            suggestions=[
                "Check if the dataset_id is correct",
                "Dataset IDs are typically shorter than 100 characters",
            ],
        )

    return dataset_id
def validate_config_name(config_name: Optional[str]) -> Optional[str]:
    """
    Validate and normalize a dataset configuration name.

    Args:
        config_name: Configuration name to check. ``None`` (or a blank string)
            means no specific configuration was requested.

    Returns:
        The stripped configuration name, or None when the input is None or
        contains only whitespace.

    Raises:
        ValidationError: If config_name is not a string, does not match the
            allowed format, or is longer than 50 characters.
    """
    if config_name is None:
        return None

    if not isinstance(config_name, str):
        raise ValidationError(
            f"config_name must be a string or None, got {type(config_name).__name__}",
            suggestions=["Pass config_name as a string or omit it for default configuration"],
        )

    cleaned = config_name.strip()
    # An empty/blank string is treated exactly like "no config given".
    if not cleaned:
        return None

    # Alphanumeric first character, then word characters, hyphens or dots.
    if re.match(r"^[a-zA-Z0-9][\w\-\.]*$", cleaned) is None:
        raise ValidationError(
            f"Invalid config_name format: '{cleaned}'",
            suggestions=[
                "Configuration names should contain only letters, numbers, hyphens, underscores, and dots",
                "Examples: 'cola', 'sst2', 'plain_text'",
            ],
        )

    too_long = len(cleaned) > 50
    if too_long:
        raise ValidationError(
            f"config_name is unusually long ({len(cleaned)} characters)",
            suggestions=[
                "Check if the config_name is correct",
                "Configuration names are typically shorter than 50 characters",
            ],
        )

    return cleaned
def validate_split_name(split: str) -> str:
    """
    Validate and normalize a dataset split name.

    Custom split names are allowed. Only the character set and an upper
    length bound are enforced. Split names may be made of dot-separated
    segments (e.g. ``'train.clean.100'``).

    Args:
        split: Split name to validate.

    Returns:
        Normalized split name (lowercase, stripped).

    Raises:
        ValidationError: If split is missing, not a string, blank, malformed,
            or longer than 100 characters.
    """
    if not split:
        raise ValidationError(
            "split is required and cannot be empty",
            suggestions=[
                "Provide a valid split name",
                "Common splits: 'train', 'validation', 'test'",
            ],
        )
    if not isinstance(split, str):
        raise ValidationError(
            f"split must be a string, got {type(split).__name__}",
            suggestions=["Ensure split is passed as a string value"],
        )

    split = split.strip().lower()
    if not split:
        raise ValidationError(
            "split cannot be empty or contain only whitespace",
            suggestions=["Provide a non-empty split name"],
        )

    # One or more segments of letters, digits, underscores and hyphens,
    # joined by single dots. Dotted names are used by real datasets, e.g.
    # librispeech_asr's 'train.clean.100'; the previous dot-free pattern
    # rejected them.
    pattern = r"^[a-zA-Z0-9][\w\-]*(\.[\w\-]+)*$"
    if not re.match(pattern, split):
        raise ValidationError(
            f"Invalid split name format: '{split}'",
            suggestions=[
                "Split names should contain only letters, numbers, hyphens, underscores, and dots",
                "Common splits: 'train', 'validation', 'test', 'dev'",
            ],
        )

    # We don't enforce a specific set of split names, since datasets can define
    # custom splits, so only guard against absurd lengths. The old 20-character
    # limit rejected real splits such as 'validation_mismatched'
    # (glue/mnli, 21 characters).
    common_splits = {"train", "validation", "test", "dev", "val"}
    max_length = 100
    if len(split) > max_length:
        raise ValidationError(
            f"Unusual split name: '{split}' (length: {len(split)})",
            suggestions=[
                "Check if the split name is correct",
                f"Common splits are: {', '.join(sorted(common_splits))}",
                "Some datasets may have custom split names",
            ],
        )

    return split
def validate_sample_size(num_samples: int, parameter_name: str = "num_samples") -> int:
    """
    Validate a sample-size parameter.

    Integral floats (e.g. ``10.0``) are accepted and converted to ``int``.

    Args:
        num_samples: Number of samples to validate.
        parameter_name: Name of the parameter, used in error messages.

    Returns:
        The validated sample size as an ``int``.

    Raises:
        ValidationError: If num_samples is not an integer (booleans and
            non-integral floats are rejected), is not positive, or exceeds
            the configured ``max_sample_size``.
    """
    # Accept floats that hold a whole number, e.g. 100.0 -> 100.
    if isinstance(num_samples, float) and num_samples.is_integer():
        num_samples = int(num_samples)

    # bool is a subclass of int. Without the explicit check, True would be
    # accepted as a sample size of 1, and False would get a misleading
    # "must be positive" message.
    if isinstance(num_samples, bool) or not isinstance(num_samples, int):
        raise ValidationError(
            f"{parameter_name} must be an integer, got {type(num_samples).__name__}",
            suggestions=[
                f"Provide {parameter_name} as an integer value",
                "Example: num_samples=100",
            ],
        )

    if num_samples <= 0:
        raise ValidationError(
            f"{parameter_name} must be positive, got {num_samples}",
            suggestions=[
                f"Provide a positive integer for {parameter_name}",
                "Example: num_samples=10 or num_samples=1000",
            ],
        )

    # The upper bound comes from the server configuration.
    max_sample_size = get_config().max_sample_size
    if num_samples > max_sample_size:
        raise ValidationError(
            f"{parameter_name} ({num_samples}) exceeds maximum allowed ({max_sample_size})",
            suggestions=[
                f"Reduce {parameter_name} to {max_sample_size} or less",
                f"Current maximum is configured as {max_sample_size}",
                "For larger samples, consider using streaming or batch processing",
            ],
        )

    return num_samples
def validate_indices(indices: List[int]) -> List[int]:
    """
    Validate a list of row indices for sampling.

    Args:
        indices: Non-empty list of non-negative integer indices.

    Returns:
        The same list object, unchanged, if it is valid.

    Raises:
        ValidationError: If indices is empty or not a list, has more entries
            than the configured ``max_sample_size``, or contains a
            non-integer (including bool) or negative element.
    """
    if not indices:
        raise ValidationError(
            "indices list is required and cannot be empty",
            suggestions=[
                "Provide a non-empty list of indices",
                "Example: indices=[0, 1, 2, 10, 20]",
            ],
        )
    if not isinstance(indices, list):
        raise ValidationError(
            f"indices must be a list, got {type(indices).__name__}",
            suggestions=[
                "Provide indices as a list of integers",
                "Example: indices=[0, 1, 2]",
            ],
        )

    # Check the overall size first. It is a cheap O(1) test, and it avoids
    # walking an oversized list element by element only to reject it anyway.
    max_sample_size = get_config().max_sample_size
    if len(indices) > max_sample_size:
        raise ValidationError(
            f"Too many indices requested ({len(indices)}), maximum is {max_sample_size}",
            suggestions=[
                f"Reduce the number of indices to {max_sample_size} or less",
                "Consider using regular sampling instead of specific indices",
            ],
        )

    for i, idx in enumerate(indices):
        # bool is a subclass of int. Reject True/False explicitly so they are
        # not silently used as rows 1/0.
        if isinstance(idx, bool) or not isinstance(idx, int):
            raise ValidationError(
                f"All indices must be integers, got {type(idx).__name__} at position {i}",
                suggestions=[
                    "Ensure all indices are integer values",
                    "Example: indices=[0, 1, 2] (not [0.5, 1.2])",
                ],
            )
        if idx < 0:
            raise ValidationError(
                f"All indices must be non-negative, got {idx} at position {i}",
                suggestions=[
                    "Provide only non-negative indices (0 or greater)",
                    "Example: indices=[0, 1, 2, 10]",
                ],
            )

    return indices
def format_validation_error(error: ValidationError) -> str:
    """
    Render a ValidationError as a user-facing message.

    The error text comes first. If the error carries suggestions, a blank
    line, a ``Suggestions:`` header and one `` - `` bullet per suggestion
    follow it.

    Args:
        error: The validation error to render.

    Returns:
        The formatted message string.
    """
    text = str(error)
    if not error.suggestions:
        return text

    bullets = "".join(f"\n - {hint}" for hint in error.suggestions)
    return f"{text}\n\nSuggestions:{bullets}"