# Add a search text in dataset tool
# (Hub commit ca96eb9 — page-scrape residue converted to a comment so the module parses.)
import logging
import gradio as gr
from typing import Dict, Any
from hf_eda_mcp.services.dataset_service import (
DatasetServiceError,
DatasetNotParquetError,
NoTextColumnsError,
get_dataset_service
)
from hf_eda_mcp.integrations.hf_client import DatasetNotFoundError, AuthenticationError, NetworkError
from hf_eda_mcp.validation import (
validate_dataset_id,
validate_config_name,
validate_split_name,
ValidationError,
format_validation_error,
)
from hf_eda_mcp.error_handling import format_error_response, log_error_with_context
# Module-level logger named after this module, per the standard logging convention.
logger = logging.getLogger(__name__)
def search_text_in_dataset(
    dataset_id: str,
    config_name: str,
    split: str,
    query: str,
    offset: int = 0,
    length: int = 10,
    hf_api_token: gr.Header = "",
) -> Dict[str, Any]:
    """
    Search for text in text columns of a dataset using the Dataset Viewer API.

    Only text columns are searched and only parquet datasets are supported
    (builder_name="parquet"). Useful for finding relevant examples or debugging issues.

    Args:
        dataset_id: HuggingFace full dataset identifier (e.g., 'stanfordnlp/imdb', 'rajpurkar/squad', 'nyu-mll/glue')
        config_name: Configuration name
        split: Split name
        query: Search query
        offset: Offset for pagination (default: 0)
        length: Number of examples to return (default: 10). Means that we search in [offset, offset+length[
        hf_api_token: Header parsed by Gradio when hf_api_token is provided in MCP configuration headers

    Returns:
        Dictionary containing search results including:
        - features: List of features from the dataset, including column names and data types
        - rows: List of slice of rows of a dataset and the content contained in each column of a specific row.
        - num_rows_total: Total number of examples in the split
        - num_rows_per_page: Number of examples in the current page
        - partial: Whether the response is partial. If True, it means that the search couldn’t be run on the full dataset because it’s too big.

    Raises:
        ValueError: If input validation fails, the dataset is not in parquet
            format, or the dataset has no text columns.
        DatasetNotFoundError, AuthenticationError, NetworkError: Re-raised
            from the HF client after logging guidance for the caller.
        DatasetServiceError: Wraps any other unexpected failure.
    """
    # Gradio sends empty strings for unset optional fields; normalize to None
    # so downstream validation/service code sees "no config selected".
    if config_name == "":
        config_name = None

    # Input validation using centralized validation. Failures are surfaced as
    # ValueError (chained to the original ValidationError for debuggability).
    try:
        dataset_id = validate_dataset_id(dataset_id)
        config_name = validate_config_name(config_name)
        split = validate_split_name(split)
    except ValidationError as e:
        logger.error("Validation error: %s", format_validation_error(e))
        raise ValueError(format_validation_error(e)) from e

    # Shared context attached to every logged error below.
    context = {
        "dataset_id": dataset_id,
        "config_name": config_name,
        "split": split,
        "query": query,
        "offset": offset,
        "length": length,
        "operation": "search_text_in_dataset",
    }
    logger.info(
        "Searching text %s in dataset: %s, split: %s, config: %s, offset: %s, length: %s",
        query, dataset_id, split, config_name, offset, length,
    )

    try:
        # Get dataset service (token enables access to gated/private datasets).
        service = get_dataset_service(hf_api_token=hf_api_token)
        # Delegate the actual search to the service layer.
        return service.search_text_in_dataset(
            dataset_id=dataset_id,
            config_name=config_name,
            split_name=split,
            query=query,
            offset=offset,
            length=length,
        )
    except DatasetNotParquetError as e:
        # Expected limitation of the Dataset Viewer API, not a server fault:
        # log at WARNING and translate to ValueError for the MCP client.
        log_error_with_context(e, context, level=logging.WARNING)
        logger.info("Dataset is not in parquet format: %s", e)
        raise ValueError(str(e)) from e
    except NoTextColumnsError as e:
        # Same treatment: nothing searchable in this dataset.
        log_error_with_context(e, context, level=logging.WARNING)
        logger.info("Dataset has no text columns: %s", e)
        raise ValueError(str(e)) from e
    except DatasetNotFoundError as e:
        # Log actionable suggestions, then re-raise the original exception
        # unchanged so callers can catch the specific type.
        log_error_with_context(e, context, level=logging.WARNING)
        error_response = format_error_response(e, context)
        logger.info("Dataset/split not found suggestions: %s", error_response.get('suggestions', []))
        raise
    except AuthenticationError as e:
        log_error_with_context(e, context, level=logging.WARNING)
        error_response = format_error_response(e, context)
        logger.info("Authentication error guidance: %s", error_response.get('suggestions', []))
        raise
    except NetworkError as e:
        # Network failures are logged at the default (ERROR) level.
        log_error_with_context(e, context)
        error_response = format_error_response(e, context)
        logger.info("Network error guidance: %s", error_response.get('suggestions', []))
        raise
    except Exception as e:
        # Boundary catch-all: wrap unexpected failures in the service error type.
        log_error_with_context(e, context)
        raise DatasetServiceError(f"Failed to search in dataset: {str(e)}") from e