"""Hugging Face API client wrapper for dataset operations."""

import os
from typing import Optional, List, Dict, Any

from huggingface_hub import HfApi, list_datasets, DatasetCard
from datasets import load_dataset, get_dataset_config_names, get_dataset_split_names
from dotenv import load_dotenv

load_dotenv()
| class HFDatasetClient: | |
| """Client for interacting with Hugging Face datasets.""" | |
| def __init__(self, token: Optional[str] = None): | |
| self.token = token or os.getenv("HF_TOKEN") | |
| self.api = HfApi(token=self.token) | |
| def search_datasets( | |
| self, | |
| query: str, | |
| limit: int = 10, | |
| filter_task: Optional[str] = None, | |
| sort: str = "downloads" | |
| ) -> List[Dict[str, Any]]: | |
| """Search for datasets on Hugging Face Hub.""" | |
| datasets = list(list_datasets( | |
| search=query, | |
| limit=limit, | |
| sort=sort, | |
| task_categories=filter_task if filter_task else None | |
| )) | |
| return [ | |
| { | |
| "id": ds.id, | |
| "downloads": ds.downloads, | |
| "likes": ds.likes, | |
| "tags": ds.tags[:5] if ds.tags else [], | |
| "created_at": str(ds.created_at) if ds.created_at else None, | |
| } | |
| for ds in datasets | |
| ] | |
| def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]: | |
| """Get detailed information about a dataset.""" | |
| info = self.api.dataset_info(dataset_id) | |
| # Try to get the dataset card | |
| card_content = None | |
| try: | |
| card = DatasetCard.load(dataset_id) | |
| card_content = card.text[:2000] if card.text else None # Limit card size | |
| except Exception: | |
| pass | |
| return { | |
| "id": info.id, | |
| "author": info.author, | |
| "downloads": info.downloads, | |
| "likes": info.likes, | |
| "tags": info.tags, | |
| "license": getattr(info, 'license', None), | |
| "created_at": str(info.created_at) if info.created_at else None, | |
| "last_modified": str(info.last_modified) if info.last_modified else None, | |
| "card_summary": card_content, | |
| } | |
| def get_configs_and_splits(self, dataset_id: str) -> Dict[str, List[str]]: | |
| """Get available configs and splits for a dataset.""" | |
| try: | |
| configs = get_dataset_config_names(dataset_id, trust_remote_code=True) | |
| except Exception: | |
| configs = ["default"] | |
| result = {} | |
| for config in configs[:3]: # Limit to first 3 configs | |
| try: | |
| splits = get_dataset_split_names(dataset_id, config, trust_remote_code=True) | |
| result[config] = splits | |
| except Exception: | |
| result[config] = ["train"] | |
| return result | |
| def load_sample( | |
| self, | |
| dataset_id: str, | |
| config: Optional[str] = None, | |
| split: str = "train", | |
| n_rows: int = 5, | |
| streaming: bool = True | |
| ) -> List[Dict[str, Any]]: | |
| """Load a sample of rows from a dataset.""" | |
| try: | |
| ds = load_dataset( | |
| dataset_id, | |
| config, | |
| split=split, | |
| streaming=streaming, | |
| trust_remote_code=True | |
| ) | |
| if streaming: | |
| samples = [] | |
| for i, row in enumerate(ds): | |
| if i >= n_rows: | |
| break | |
| # Convert row to serializable format | |
| samples.append(self._serialize_row(row)) | |
| return samples | |
| else: | |
| return [self._serialize_row(row) for row in ds.select(range(min(n_rows, len(ds))))] | |
| except Exception as e: | |
| return [{"error": str(e)}] | |
| def get_schema(self, dataset_id: str, config: Optional[str] = None, split: str = "train") -> Dict[str, Any]: | |
| """Get the schema/features of a dataset.""" | |
| try: | |
| ds = load_dataset( | |
| dataset_id, | |
| config, | |
| split=split, | |
| streaming=True, | |
| trust_remote_code=True | |
| ) | |
| features = ds.features | |
| schema = {} | |
| for name, feature in features.items(): | |
| schema[name] = str(feature) | |
| return { | |
| "columns": list(features.keys()), | |
| "features": schema, | |
| "num_columns": len(features) | |
| } | |
| except Exception as e: | |
| return {"error": str(e)} | |
| def _serialize_row(self, row: Dict[str, Any]) -> Dict[str, Any]: | |
| """Convert a row to JSON-serializable format.""" | |
| result = {} | |
| for key, value in row.items(): | |
| if hasattr(value, 'tolist'): # numpy array | |
| result[key] = value.tolist() | |
| elif hasattr(value, '__dict__'): # PIL Image or similar | |
| result[key] = f"<{type(value).__name__}>" | |
| elif isinstance(value, bytes): | |
| result[key] = f"<bytes: {len(value)} bytes>" | |
| else: | |
| result[key] = value | |
| return result | |
# Lazily-created process-wide client instance.
_client = None


def get_client() -> HFDatasetClient:
    """Get or create the HF client singleton."""
    global _client
    if _client is not None:
        return _client
    _client = HFDatasetClient()
    return _client