"""Hugging Face API client wrapper for dataset operations.""" import os from typing import Optional, List, Dict, Any from huggingface_hub import HfApi, list_datasets, DatasetCard from datasets import load_dataset, get_dataset_config_names, get_dataset_split_names from dotenv import load_dotenv load_dotenv() class HFDatasetClient: """Client for interacting with Hugging Face datasets.""" def __init__(self, token: Optional[str] = None): self.token = token or os.getenv("HF_TOKEN") self.api = HfApi(token=self.token) def search_datasets( self, query: str, limit: int = 10, filter_task: Optional[str] = None, sort: str = "downloads" ) -> List[Dict[str, Any]]: """Search for datasets on Hugging Face Hub.""" datasets = list(list_datasets( search=query, limit=limit, sort=sort, task_categories=filter_task if filter_task else None )) return [ { "id": ds.id, "downloads": ds.downloads, "likes": ds.likes, "tags": ds.tags[:5] if ds.tags else [], "created_at": str(ds.created_at) if ds.created_at else None, } for ds in datasets ] def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]: """Get detailed information about a dataset.""" info = self.api.dataset_info(dataset_id) # Try to get the dataset card card_content = None try: card = DatasetCard.load(dataset_id) card_content = card.text[:2000] if card.text else None # Limit card size except Exception: pass return { "id": info.id, "author": info.author, "downloads": info.downloads, "likes": info.likes, "tags": info.tags, "license": getattr(info, 'license', None), "created_at": str(info.created_at) if info.created_at else None, "last_modified": str(info.last_modified) if info.last_modified else None, "card_summary": card_content, } def get_configs_and_splits(self, dataset_id: str) -> Dict[str, List[str]]: """Get available configs and splits for a dataset.""" try: configs = get_dataset_config_names(dataset_id, trust_remote_code=True) except Exception: configs = ["default"] result = {} for config in configs[:3]: # Limit to first 3 configs try: splits = get_dataset_split_names(dataset_id, config, trust_remote_code=True) result[config] = splits except Exception: result[config] = ["train"] return result def load_sample( self, dataset_id: str, config: Optional[str] = None, split: str = "train", n_rows: int = 5, streaming: bool = True ) -> List[Dict[str, Any]]: """Load a sample of rows from a dataset.""" try: ds = load_dataset( dataset_id, config, split=split, streaming=streaming, trust_remote_code=True ) if streaming: samples = [] for i, row in enumerate(ds): if i >= n_rows: break # Convert row to serializable format samples.append(self._serialize_row(row)) return samples else: return [self._serialize_row(row) for row in ds.select(range(min(n_rows, len(ds))))] except Exception as e: return [{"error": str(e)}] def get_schema(self, dataset_id: str, config: Optional[str] = None, split: str = "train") -> Dict[str, Any]: """Get the schema/features of a dataset.""" try: ds = load_dataset( dataset_id, config, split=split, streaming=True, trust_remote_code=True ) features = ds.features schema = {} for name, feature in features.items(): schema[name] = str(feature) return { "columns": list(features.keys()), "features": schema, "num_columns": len(features) } except Exception as e: return {"error": str(e)} def _serialize_row(self, row: Dict[str, Any]) -> Dict[str, Any]: """Convert a row to JSON-serializable format.""" result = {} for key, value in row.items(): if hasattr(value, 'tolist'): # numpy array result[key] = value.tolist() elif hasattr(value, '__dict__'): # PIL 
    def _serialize_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a row to a JSON-serializable format."""
        result = {}
        for key, value in row.items():
            if hasattr(value, "tolist"):  # numpy array
                result[key] = value.tolist()
            elif hasattr(value, "__dict__"):  # PIL Image or similar object
                result[key] = f"<{type(value).__name__}>"
            elif isinstance(value, bytes):
                # Replace raw bytes with a short placeholder rather than
                # embedding binary data in the JSON output.
                result[key] = f"<bytes len={len(value)}>"
            else:
                result[key] = value
        return result


# Singleton instance
_client = None


def get_client() -> HFDatasetClient:
    """Get or create the HF client singleton."""
    global _client
    if _client is None:
        _client = HFDatasetClient()
    return _client
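

# --- Usage sketch. Assumes HF_TOKEN is set in the environment or a .env file,
# and that network access to the Hub is available. The dataset id "squad" and
# the row/result counts below are illustrative choices for a quick smoke test,
# not anything this module requires.
if __name__ == "__main__":
    client = get_client()
    for hit in client.search_datasets("question answering", limit=3):
        print(hit["id"], hit["downloads"])
    print(client.get_configs_and_splits("squad"))
    print(client.get_schema("squad"))
    for row in client.load_sample("squad", n_rows=2):
        print(row)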