Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import tempfile | |
| from typing import List, Dict, Any, Optional | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from huggingface_hub.utils import EntryNotFoundError | |
# Configuration - uses same ledger repo as subscriptions
LEDGER_REPO = os.getenv("LEDGER_DATASET_ID", "")  # HF dataset repo id; empty string disables HF storage
REGISTRY_FILE = "datasets.json"  # path of the registry file inside the HF dataset repo
HF_TOKEN = os.getenv("HF_TOKEN")  # None when unset; without it the HF client is never created
# Fallback to local file if LEDGER_DATASET_ID not set (for local dev)
LOCAL_REGISTRY_FILE = "datasets.json"
# Initialize HF API
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None  # None => local-only mode (see _use_hf_storage)
def _use_hf_storage() -> bool:
    """Return True when the HF Dataset backend is fully configured.

    Requires a ledger repo id, an auth token, and an initialized API client.
    """
    return all((LEDGER_REPO, HF_TOKEN, api))
def _download_registry() -> Optional[str]:
    """Fetch the registry file from the HF Dataset repo.

    Returns the local cache path of the downloaded file, or None when HF
    storage is not configured, the file does not exist in the repo yet,
    or the download fails (the error is printed, not raised).
    """
    if not _use_hf_storage():
        return None
    try:
        return hf_hub_download(
            repo_id=LEDGER_REPO,
            filename=REGISTRY_FILE,
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except EntryNotFoundError:
        # Registry file has not been created in the dataset yet.
        return None
    except Exception as e:
        print(f"Error downloading registry: {e}")
        return None
def _upload_registry(local_path: str) -> bool:
    """Upload the registry file at *local_path* to the HF Dataset repo.

    Args:
        local_path: Path of a local JSON file to publish as REGISTRY_FILE.

    Returns:
        True on success; False when HF storage is not configured or the
        upload fails (the error is printed, not raised).
    """
    if not _use_hf_storage():
        return False
    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=REGISTRY_FILE,
            repo_id=LEDGER_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            # Plain literal: the original f-string had no placeholders.
            commit_message="Update dataset registry",
        )
        return True
    except Exception as e:
        print(f"Error uploading registry: {e}")
        return False
def load_registry() -> List[Dict[str, Any]]:
    """Load the dataset registry.

    Prefers the HF Dataset copy when configured; otherwise reads the local
    file. Returns [] when no registry exists or it cannot be parsed. A
    corrupt (or successfully downloaded but unparseable) HF copy returns []
    without consulting the local file, matching the original behavior.
    """

    def _read(path: str) -> Optional[List[Dict[str, Any]]]:
        # Parse one registry file; None signals "unusable" (decode error
        # or a JSON document that is not usable as a registry payload).
        try:
            with open(path, "r") as f:
                return json.load(f)
        except json.JSONDecodeError:
            print(f"Error decoding {path}")
            return None

    if _use_hf_storage():
        hf_path = _download_registry()
        if hf_path:
            data = _read(hf_path)
            # Guard against JSON `null`: always honor the List return type.
            return data if data is not None else []

    # Fallback to local file
    if not os.path.exists(LOCAL_REGISTRY_FILE):
        return []
    data = _read(LOCAL_REGISTRY_FILE)
    return data if data is not None else []
def save_registry(registry: List[Dict[str, Any]]) -> bool:
    """Persist the dataset registry.

    Writes to the HF Dataset repo when configured, otherwise to the local
    file.

    Args:
        registry: Full list of dataset entries to store.

    Returns:
        True on success (local writes always report True unless they raise).
    """
    if _use_hf_storage():
        # Serialize to a temp file first; upload_file needs a real path.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
            tmp_path = tmp.name
            json.dump(registry, tmp, indent=2)
        try:
            return _upload_registry(tmp_path)
        finally:
            # Always clean up the temp file, even if the upload raises.
            # Narrowed from a bare `except:` — only swallow filesystem errors.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
    else:
        # Local file storage
        with open(LOCAL_REGISTRY_FILE, "w") as f:
            json.dump(registry, f, indent=2)
        return True
def get_dataset_by_id(dataset_id: str) -> Optional[Dict[str, Any]]:
    """Finds a dataset by its ID."""
    return next(
        (entry for entry in load_registry() if entry.get("dataset_id") == dataset_id),
        None,
    )
def get_dataset_by_slug(slug: str) -> Optional[Dict[str, Any]]:
    """Finds a dataset by its slug."""
    return next(
        (entry for entry in load_registry() if entry.get("slug") == slug),
        None,
    )
def get_plan_by_price_id(price_id: str) -> Optional[Dict[str, Any]]:
    """Finds a plan and its dataset by Stripe price ID.

    Returns {"dataset": ..., "plan": ...} for the first match, else None.
    """
    for entry in load_registry():
        matching = next(
            (p for p in entry.get("plans", []) if p.get("stripe_price_id") == price_id),
            None,
        )
        if matching is not None:
            return {"dataset": entry, "plan": matching}
    return None
def get_free_plan(dataset_id: str) -> Optional[Dict[str, Any]]:
    """
    Securely finds a free plan for a dataset.
    Returns the plan dict if found, None otherwise.
    """
    # Explicit markers a plan may use to declare itself free.
    free_markers = ("free", "0", 0)
    dataset = get_dataset_by_id(dataset_id)
    if dataset is None:
        return None
    for plan in dataset.get("plans", []):
        if plan.get("stripe_price_id") in free_markers:
            return plan
    return None
def detect_dataset_format(dataset_id: str) -> Dict[str, Any]:
    """
    Detects the format and parquet path for a HuggingFace dataset.

    Checks the main branch for native parquet files, then the
    refs/convert/parquet revision for auto-converted ones, and returns the
    DuckDB-compatible hf:// URL pattern plus file counts. On any API
    failure, returns {"dataset_id", "error", "parquet_url_pattern": None}.
    """
    if not api:
        return {
            "dataset_id": dataset_id,
            "error": "HF API not initialized (HF_TOKEN not set)",
            "parquet_url_pattern": None
        }

    def _parquet_files(repo_info) -> List[str]:
        # Names of all .parquet files among the repo's listed siblings.
        return [
            s.rfilename
            for s in (repo_info.siblings or [])
            if s.rfilename.endswith('.parquet')
        ]

    try:
        # Native parquet files in the main branch.
        info = api.dataset_info(dataset_id, token=HF_TOKEN)
        parquet_paths = _parquet_files(info)

        # Auto-converted parquet in refs/convert/parquet (may not exist).
        converted_parquet_paths: List[str] = []
        try:
            convert_info = api.dataset_info(
                dataset_id, token=HF_TOKEN, revision='refs/convert/parquet'
            )
            converted_parquet_paths = _parquet_files(convert_info)
        except Exception:
            # refs/convert/parquet doesn't exist for this dataset
            pass

        # Determine the best parquet URL pattern; native files win.
        if parquet_paths:
            parquet_url_pattern = f"hf://datasets/{dataset_id}/**/*.parquet"
            parquet_count = len(parquet_paths)
        elif converted_parquet_paths:
            # Note: The revision path must be URL-encoded for DuckDB
            parquet_url_pattern = f"hf://datasets/{dataset_id}@refs%2Fconvert%2Fparquet/**/*.parquet"
            parquet_count = len(converted_parquet_paths)
        else:
            # No parquet files found
            parquet_url_pattern = None
            parquet_count = 0

        return {
            "dataset_id": dataset_id,
            "has_native_parquet": bool(parquet_paths),
            "has_converted_parquet": bool(converted_parquet_paths),
            "parquet_url_pattern": parquet_url_pattern,
            "parquet_files_count": parquet_count,
            "card_data": info.card_data.__dict__ if info.card_data else None,
        }
    except Exception as e:
        return {
            "dataset_id": dataset_id,
            "error": str(e),
            "parquet_url_pattern": None
        }
def get_parquet_url(dataset_id: str) -> str:
    """
    Gets the best parquet URL pattern for a dataset.
    Checks registry first, then tries to detect automatically.
    """
    # A pattern stored in the registry takes precedence.
    entry = get_dataset_by_id(dataset_id)
    if entry:
        stored_pattern = entry.get("parquet_url_pattern")
        if stored_pattern:
            return stored_pattern

    # Otherwise, probe the hub for the dataset's actual layout.
    detected = detect_dataset_format(dataset_id).get("parquet_url_pattern")
    if detected:
        return detected

    # Fallback to standard pattern
    return f"hf://datasets/{dataset_id}/**/*.parquet"