import contextlib
import json
import os
import tempfile
from typing import Any, Dict, List, Optional

from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

# Configuration - uses same ledger repo as subscriptions
LEDGER_REPO = os.getenv("LEDGER_DATASET_ID", "")
REGISTRY_FILE = "datasets.json"
HF_TOKEN = os.getenv("HF_TOKEN")

# Fallback to local file if LEDGER_DATASET_ID not set (for local dev)
LOCAL_REGISTRY_FILE = "datasets.json"

# Initialize HF API client only when a token is available.
api = HfApi(token=HF_TOKEN) if HF_TOKEN else None


def _use_hf_storage() -> bool:
    """Check if we should use HF Dataset storage.

    True only when a ledger repo id, a token, and an initialized API
    client are all present.
    """
    return bool(LEDGER_REPO and HF_TOKEN and api)


def _download_registry() -> Optional[str]:
    """Download the current registry file from the HF Dataset.

    Returns:
        Local cache path of the downloaded file, or None when HF storage
        is disabled, the file does not exist yet, or the download fails.
    """
    if not _use_hf_storage():
        return None
    try:
        return hf_hub_download(
            repo_id=LEDGER_REPO,
            filename=REGISTRY_FILE,
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except EntryNotFoundError:
        # File doesn't exist yet in the dataset
        return None
    except Exception as e:
        print(f"Error downloading registry: {e}")
        return None


def _upload_registry(local_path: str) -> bool:
    """Upload the registry file at *local_path* to the HF Dataset.

    Returns True on success, False when HF storage is disabled or the
    upload raises.
    """
    if not _use_hf_storage():
        return False
    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=REGISTRY_FILE,
            repo_id=LEDGER_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            # Fixed: was a pointless f-string with no placeholders.
            commit_message="Update dataset registry",
        )
        return True
    except Exception as e:
        print(f"Error uploading registry: {e}")
        return False


def load_registry() -> List[Dict[str, Any]]:
    """Loads the dataset registry from HF Dataset or local file.

    Returns:
        The registry as a list of dataset dicts; an empty list when no
        registry exists or its JSON is invalid.
    """
    if _use_hf_storage():
        hf_path = _download_registry()
        if hf_path:
            try:
                with open(hf_path, "r", encoding="utf-8") as f:
                    return json.load(f)
            except json.JSONDecodeError:
                print(f"Error decoding {hf_path}")
                return []

    # Fallback to local file
    if not os.path.exists(LOCAL_REGISTRY_FILE):
        return []
    try:
        with open(LOCAL_REGISTRY_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError:
        print(f"Error decoding {LOCAL_REGISTRY_FILE}")
        return []


def save_registry(registry: List[Dict[str, Any]]) -> bool:
    """Saves the dataset registry to HF Dataset or local file.

    Args:
        registry: Full registry content to persist.

    Returns:
        True when the registry was written, False when the HF upload failed.
    """
    if _use_hf_storage():
        # Write to a named temp file and CLOSE it before uploading so the
        # content is fully flushed to disk (also required on Windows,
        # where an open NamedTemporaryFile cannot be reopened).
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False, encoding="utf-8"
        ) as tmp:
            tmp_path = tmp.name
            json.dump(registry, tmp, indent=2)

        success = _upload_registry(tmp_path)

        # Best-effort cleanup; narrow suppression instead of bare except.
        with contextlib.suppress(OSError):
            os.unlink(tmp_path)
        return success

    # Local file storage
    with open(LOCAL_REGISTRY_FILE, "w", encoding="utf-8") as f:
        json.dump(registry, f, indent=2)
    return True


def get_dataset_by_id(dataset_id: str) -> Optional[Dict[str, Any]]:
    """Finds a dataset by its ID, or None if absent."""
    registry = load_registry()
    return next(
        (d for d in registry if d.get("dataset_id") == dataset_id), None
    )


def get_dataset_by_slug(slug: str) -> Optional[Dict[str, Any]]:
    """Finds a dataset by its slug, or None if absent."""
    registry = load_registry()
    return next((d for d in registry if d.get("slug") == slug), None)


def get_plan_by_price_id(price_id: str) -> Optional[Dict[str, Any]]:
    """Finds a plan and its dataset by Stripe price ID.

    Returns:
        {"dataset": <dataset dict>, "plan": <plan dict>} or None.
    """
    for dataset in load_registry():
        for plan in dataset.get("plans", []):
            if plan.get("stripe_price_id") == price_id:
                return {"dataset": dataset, "plan": plan}
    return None


def get_free_plan(dataset_id: str) -> Optional[Dict[str, Any]]:
    """
    Securely finds a free plan for a dataset.
    Returns the plan dict if found, None otherwise.
    """
    dataset = get_dataset_by_id(dataset_id)
    if not dataset:
        return None
    # Explicitly check for free markers (string and numeric forms).
    for plan in dataset.get("plans", []):
        if plan.get("stripe_price_id") in ["free", "0", 0]:
            return plan
    return None


def detect_dataset_format(dataset_id: str) -> Dict[str, Any]:
    """
    Detects the format and parquet path for a HuggingFace dataset.
    Returns info about the dataset including the correct parquet URL pattern.
    """
    if not api:
        return {
            "dataset_id": dataset_id,
            "error": "HF API not initialized (HF_TOKEN not set)",
            "parquet_url_pattern": None,
        }

    try:
        # Get dataset info from main branch
        info = api.dataset_info(dataset_id, token=HF_TOKEN)

        # Check for native parquet files in main branch
        parquet_paths = [
            s.rfilename
            for s in (info.siblings or [])
            if s.rfilename.endswith(".parquet")
        ]
        has_native_parquet = bool(parquet_paths)

        # Check for auto-converted parquet in refs/convert/parquet
        converted_parquet_paths: List[str] = []
        try:
            convert_info = api.dataset_info(
                dataset_id, token=HF_TOKEN, revision="refs/convert/parquet"
            )
            converted_parquet_paths = [
                s.rfilename
                for s in (convert_info.siblings or [])
                if s.rfilename.endswith(".parquet")
            ]
        except Exception:
            # refs/convert/parquet doesn't exist for this dataset
            pass
        has_converted_parquet = bool(converted_parquet_paths)

        # Determine the best parquet URL pattern
        if has_native_parquet:
            # Dataset has native parquet files in main branch
            parquet_url_pattern = f"hf://datasets/{dataset_id}/**/*.parquet"
            parquet_count = len(parquet_paths)
        elif has_converted_parquet:
            # Dataset was auto-converted, use refs/convert/parquet
            # Note: The revision path must be URL-encoded for DuckDB
            parquet_url_pattern = (
                f"hf://datasets/{dataset_id}@refs%2Fconvert%2Fparquet/**/*.parquet"
            )
            parquet_count = len(converted_parquet_paths)
        else:
            # No parquet files found
            parquet_url_pattern = None
            parquet_count = 0

        return {
            "dataset_id": dataset_id,
            "has_native_parquet": has_native_parquet,
            "has_converted_parquet": has_converted_parquet,
            "parquet_url_pattern": parquet_url_pattern,
            "parquet_files_count": parquet_count,
            "card_data": info.card_data.__dict__ if info.card_data else None,
        }
    except Exception as e:
        return {
            "dataset_id": dataset_id,
            "error": str(e),
            "parquet_url_pattern": None,
        }


def get_parquet_url(dataset_id: str) -> str:
    """
    Gets the best parquet URL pattern for a dataset.
    Checks registry first, then tries to detect automatically.
    """
    # Check if dataset has a stored parquet_url_pattern in registry
    dataset = get_dataset_by_id(dataset_id)
    if dataset and dataset.get("parquet_url_pattern"):
        return dataset["parquet_url_pattern"]

    # Try to detect the format
    format_info = detect_dataset_format(dataset_id)
    if format_info.get("parquet_url_pattern"):
        return format_info["parquet_url_pattern"]

    # Fallback to standard pattern
    return f"hf://datasets/{dataset_id}/**/*.parquet"