# dataview-mcp/utils/hf_client.py
"""Hugging Face API client wrapper for dataset operations."""
import os
from typing import Optional, List, Dict, Any

from huggingface_hub import HfApi, list_datasets, DatasetCard
from datasets import load_dataset, get_dataset_config_names, get_dataset_split_names
from dotenv import load_dotenv

# Load HF_TOKEN (and any other settings) from a local .env file, if present.
load_dotenv()


class HFDatasetClient:
    """Client for interacting with Hugging Face datasets."""

    def __init__(self, token: Optional[str] = None):
        # Fall back to the HF_TOKEN env var; anonymous access still works for public datasets.
        self.token = token or os.getenv("HF_TOKEN")
        self.api = HfApi(token=self.token)

    def search_datasets(
        self,
        query: str,
        limit: int = 10,
        filter_task: Optional[str] = None,
        sort: str = "downloads",
    ) -> List[Dict[str, Any]]:
        """Search for datasets on Hugging Face Hub."""
        datasets = list(list_datasets(
            search=query,
            limit=limit,
            sort=sort,
            task_categories=filter_task or None,  # None disables the task filter
        ))
        return [
            {
                "id": ds.id,
                "downloads": ds.downloads,
                "likes": ds.likes,
                "tags": ds.tags[:5] if ds.tags else [],  # cap tags to keep the payload small
                "created_at": str(ds.created_at) if ds.created_at else None,
            }
            for ds in datasets
        ]
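
    # Usage sketch (the query string is illustrative, not from this repo):
    #   client = HFDatasetClient()
    #   hits = client.search_datasets("sentiment", limit=3, sort="likes")
    #   ids = [h["id"] for h in hits]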

    def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]:
        """Get detailed information about a dataset."""
        info = self.api.dataset_info(dataset_id)
        # Try to fetch the dataset card; a missing card is not fatal.
        card_content = None
        try:
            card = DatasetCard.load(dataset_id)
            card_content = card.text[:2000] if card.text else None  # limit card size
        except Exception:
            pass
        return {
            "id": info.id,
            "author": info.author,
            "downloads": info.downloads,
            "likes": info.likes,
            "tags": info.tags,
            "license": getattr(info, "license", None),  # not present on every hub payload
            "created_at": str(info.created_at) if info.created_at else None,
            "last_modified": str(info.last_modified) if info.last_modified else None,
            "card_summary": card_content,
        }
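
    # Usage sketch ("squad" is an arbitrary public dataset):
    #   info = client.get_dataset_info("squad")
    #   info["id"], info["downloads"], info["license"]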

    def get_configs_and_splits(self, dataset_id: str) -> Dict[str, List[str]]:
        """Get available configs and splits for a dataset."""
        try:
            configs = get_dataset_config_names(dataset_id, trust_remote_code=True)
        except Exception:
            configs = ["default"]
        result = {}
        for config in configs[:3]:  # limit to the first 3 configs
            try:
                splits = get_dataset_split_names(dataset_id, config, trust_remote_code=True)
                result[config] = splits
            except Exception:
                result[config] = ["train"]  # fall back to the most common split name
        return result
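
    # Usage sketch ("glue" is an arbitrary multi-config dataset):
    #   client.get_configs_and_splits("glue")
    #   # -> at most 3 configs, e.g. {"cola": ["train", "validation", "test"], ...}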

    def load_sample(
        self,
        dataset_id: str,
        config: Optional[str] = None,
        split: str = "train",
        n_rows: int = 5,
        streaming: bool = True,
    ) -> List[Dict[str, Any]]:
        """Load a sample of rows from a dataset."""
        try:
            ds = load_dataset(
                dataset_id,
                config,  # config name; None selects the default config
                split=split,
                streaming=streaming,
                trust_remote_code=True,
            )
            if streaming:
                # Streaming avoids downloading the whole dataset; take the first n_rows.
                samples = []
                for i, row in enumerate(ds):
                    if i >= n_rows:
                        break
                    # Convert row to serializable format
                    samples.append(self._serialize_row(row))
                return samples
            else:
                return [
                    self._serialize_row(row)
                    for row in ds.select(range(min(n_rows, len(ds))))
                ]
        except Exception as e:
            return [{"error": str(e)}]
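
    # Usage sketch (streaming keeps the transfer small; "imdb" is illustrative):
    #   rows = client.load_sample("imdb", split="train", n_rows=2)
    #   # each row is a plain dict with binary/image values stubbed out by _serialize_row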

    def get_schema(self, dataset_id: str, config: Optional[str] = None, split: str = "train") -> Dict[str, Any]:
        """Get the schema/features of a dataset."""
        try:
            ds = load_dataset(
                dataset_id,
                config,
                split=split,
                streaming=True,  # schema only, so never download the data
                trust_remote_code=True,
            )
            # features can be None for some streamed formats; the except below covers that.
            features = ds.features
            schema = {name: str(feature) for name, feature in features.items()}
            return {
                "columns": list(features.keys()),
                "features": schema,
                "num_columns": len(features),
            }
        except Exception as e:
            return {"error": str(e)}
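
    # Usage sketch:
    #   client.get_schema("imdb")
    #   # -> e.g. {"columns": ["text", "label"], "features": {...}, "num_columns": 2}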

    def _serialize_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a row to JSON-serializable format."""
        result = {}
        for key, value in row.items():
            if hasattr(value, "tolist"):  # numpy array
                result[key] = value.tolist()
            elif hasattr(value, "__dict__"):  # PIL Image or similar rich object
                result[key] = f"<{type(value).__name__}>"
            elif isinstance(value, bytes):
                result[key] = f"<bytes: {len(value)} bytes>"
            else:
                result[key] = value
        return result
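
    # Mapping sketch, assuming a row holding a numpy array and raw bytes:
    #   {"vec": np.array([1, 2]), "img": b"\x89PNG"}
    #   # -> {"vec": [1, 2], "img": "<bytes: 4 bytes>"}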


# Singleton instance
_client: Optional[HFDatasetClient] = None


def get_client() -> HFDatasetClient:
    """Get or create the HF client singleton."""
    global _client
    if _client is None:
        _client = HFDatasetClient()
    return _client
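

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the MCP server;
    # "squad" is an arbitrary public dataset chosen for illustration.
    client = get_client()
    print(client.search_datasets("squad", limit=2))
    print(client.get_configs_and_splits("squad"))
    print(client.load_sample("squad", n_rows=1))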