datapass-server / datasets_registry.py
waroca's picture
Upload folder using huggingface_hub
f1b8a40 verified
import json
import os
import tempfile
from typing import List, Dict, Any, Optional
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
# Configuration: the registry is stored in the same ledger repo as subscriptions.
LEDGER_REPO = os.environ.get("LEDGER_DATASET_ID", "")
REGISTRY_FILE = "datasets.json"
HF_TOKEN = os.environ.get("HF_TOKEN")

# Local fallback file, used when LEDGER_DATASET_ID is not set (local development).
LOCAL_REGISTRY_FILE = "datasets.json"

# Hub API client; remains None when no HF_TOKEN is configured.
if HF_TOKEN:
    api = HfApi(token=HF_TOKEN)
else:
    api = None
def _use_hf_storage() -> bool:
    """Report whether the registry should live in the HF Dataset repo.

    Returns:
        True only when LEDGER_REPO, HF_TOKEN and the module-level ``api``
        client are all truthy; False otherwise.
    """
    required_settings = (LEDGER_REPO, HF_TOKEN, api)
    return all(required_settings)
def _download_registry() -> Optional[str]:
    """Fetch the current registry file from the HF Dataset repo.

    Returns:
        The local path returned by ``hf_hub_download``, or None when HF
        storage is not configured, the file is absent from the repo, or the
        download fails for any other reason (that error is printed).
    """
    downloaded_path = None
    if _use_hf_storage():
        try:
            downloaded_path = hf_hub_download(
                repo_id=LEDGER_REPO,
                filename=REGISTRY_FILE,
                repo_type="dataset",
                token=HF_TOKEN,
            )
        except EntryNotFoundError:
            # The registry file has not been created in the dataset yet.
            pass
        except Exception as err:
            print(f"Error downloading registry: {err}")
    return downloaded_path
def _upload_registry(local_path: str) -> bool:
    """Push a local registry file to the HF Dataset repo.

    Args:
        local_path: Path of the file to upload as REGISTRY_FILE.

    Returns:
        True when the upload call completes; False when HF storage is not
        configured or the upload raises (the error is printed).
    """
    if _use_hf_storage():
        try:
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=REGISTRY_FILE,
                repo_id=LEDGER_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
                commit_message="Update dataset registry",
            )
        except Exception as err:
            print(f"Error uploading registry: {err}")
        else:
            return True
    return False
def _read_registry_file(path: str) -> List[Dict[str, Any]]:
    """Parse one registry JSON file, returning [] when its content is unusable."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            registry = json.load(f)
    except ValueError:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError
        # (the file is not valid UTF-8).
        print(f"Error decoding {path}")
        return []
    if not isinstance(registry, list):
        # Callers iterate the registry and call .get() on each entry, so a
        # top-level object, null or scalar would crash them later.
        print(f"Error decoding {path}: expected a JSON list, got {type(registry).__name__}")
        return []
    return registry


def load_registry() -> List[Dict[str, Any]]:
    """Load the dataset registry from the HF Dataset repo or the local file.

    When HF storage is enabled and the registry downloads successfully, that
    copy is used. Otherwise (HF storage disabled, or nothing could be
    downloaded) the function falls back to LOCAL_REGISTRY_FILE.

    Returns:
        The parsed registry list. An empty list is returned when the chosen
        file does not exist locally, is not valid UTF-8 JSON, or does not
        contain a JSON list at the top level (the problem is printed).
    """
    if _use_hf_storage():
        hf_path = _download_registry()
        if hf_path:
            return _read_registry_file(hf_path)

    # Fallback to the local file.
    if not os.path.exists(LOCAL_REGISTRY_FILE):
        return []
    return _read_registry_file(LOCAL_REGISTRY_FILE)
def save_registry(registry: List[Dict[str, Any]]) -> bool:
    """Save the dataset registry to the HF Dataset repo or the local file.

    Args:
        registry: The full registry list to persist; must be JSON-serializable.

    Returns:
        With HF storage enabled, the result of the upload (True on success,
        False on failure). With local storage, True once the file is written.

    Raises:
        TypeError, ValueError: If ``registry`` cannot be serialized to JSON.
            This is raised before any file is created or truncated.
        OSError: If the local or temporary file cannot be written.
    """
    # Serialize first: a failure here must not leave behind a truncated local
    # registry or an orphaned temporary file.
    payload = json.dumps(registry, indent=2)

    if not _use_hf_storage():
        # Local file storage.
        with open(LOCAL_REGISTRY_FILE, "w", encoding="utf-8") as f:
            f.write(payload)
        return True

    tmp_path = None
    try:
        # delete=False so the file survives the `with` block and can be
        # uploaded by path; it is removed in the `finally` below.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False, encoding="utf-8"
        ) as tmp:
            tmp_path = tmp.name
            tmp.write(payload)
        return _upload_registry(tmp_path)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not mask the
                # upload result or an in-flight exception.
                pass
def get_dataset_by_id(dataset_id: str) -> Optional[Dict[str, Any]]:
    """Look up a registry entry by its ``dataset_id`` field.

    Returns:
        The first entry of the loaded registry whose ``dataset_id`` equals
        the argument, or None when no entry matches.
    """
    matching_entries = (
        entry for entry in load_registry()
        if entry.get("dataset_id") == dataset_id
    )
    return next(matching_entries, None)
def get_dataset_by_slug(slug: str) -> Optional[Dict[str, Any]]:
    """Look up a registry entry by its ``slug`` field.

    Returns:
        The first entry of the loaded registry whose ``slug`` equals the
        argument, or None when no entry matches.
    """
    matching_entries = (
        entry for entry in load_registry()
        if entry.get("slug") == slug
    )
    return next(matching_entries, None)
def get_plan_by_price_id(price_id: str) -> Optional[Dict[str, Any]]:
    """Locate a plan, together with its dataset, by Stripe price ID.

    Returns:
        ``{"dataset": <entry>, "plan": <plan>}`` for the first plan in the
        loaded registry whose ``stripe_price_id`` equals ``price_id``, or
        None when no plan matches.
    """
    hits = (
        {"dataset": entry, "plan": candidate}
        for entry in load_registry()
        for candidate in entry.get("plans", [])
        if candidate.get("stripe_price_id") == price_id
    )
    return next(hits, None)
def get_free_plan(dataset_id: str) -> Optional[Dict[str, Any]]:
    """Securely find a free plan for a dataset.

    A plan counts as free only when its ``stripe_price_id`` is one of the
    explicit markers ``"free"``, ``"0"`` or numeric zero.

    Args:
        dataset_id: The ``dataset_id`` of the registry entry to inspect.

    Returns:
        The first free plan dict of that dataset, or None when the dataset is
        unknown, has no plans, or has no plan carrying a free marker.
    """
    dataset = get_dataset_by_id(dataset_id)
    if not dataset:
        return None

    # `or []` also covers an explicit `"plans": null` in the registry.
    for plan in dataset.get("plans") or []:
        price_id = plan.get("stripe_price_id")
        # bool is a subclass of int and False == 0, so a plain membership
        # test would accept a JSON `false` price ID as the free marker 0.
        if isinstance(price_id, bool):
            continue
        if price_id in ("free", "0", 0):
            return plan
    return None
def detect_dataset_format(dataset_id: str) -> Dict[str, Any]:
    """Inspect a Hugging Face dataset and pick a parquet URL pattern for it.

    The main branch is checked for ``.parquet`` files first; the
    ``refs/convert/parquet`` revision is checked as well and is used only
    when the main branch has none.

    Returns:
        On success, a dict with ``dataset_id``, ``has_native_parquet``,
        ``has_converted_parquet``, ``parquet_url_pattern`` (None when no
        parquet files were found), ``parquet_files_count`` and ``card_data``.
        When the API client is missing or the main-branch lookup fails, a
        dict with ``dataset_id``, ``error`` and ``parquet_url_pattern`` set
        to None.
    """
    if not api:
        return {
            "dataset_id": dataset_id,
            "error": "HF API not initialized (HF_TOKEN not set)",
            "parquet_url_pattern": None
        }

    try:
        # Parquet files that live directly in the main branch.
        main_info = api.dataset_info(dataset_id, token=HF_TOKEN)
        main_names = (entry.rfilename for entry in main_info.siblings or [])
        native_files = [name for name in main_names if name.endswith('.parquet')]

        # Parquet files in the refs/convert/parquet revision.
        converted_files = []
        try:
            converted_info = api.dataset_info(
                dataset_id, token=HF_TOKEN, revision='refs/convert/parquet'
            )
            converted_names = (
                entry.rfilename for entry in converted_info.siblings or []
            )
            converted_files.extend(
                name for name in converted_names if name.endswith('.parquet')
            )
        except Exception:
            # Any failure here is treated as "no converted parquet available"
            # (e.g. the revision does not exist for this dataset).
            pass

        native_found = bool(native_files)
        converted_found = bool(converted_files)

        # Native files win over the converted revision.
        if native_found:
            url_pattern = f"hf://datasets/{dataset_id}/**/*.parquet"
            file_count = len(native_files)
        elif converted_found:
            # The revision is written URL-encoded (%2F for "/") for DuckDB.
            url_pattern = f"hf://datasets/{dataset_id}@refs%2Fconvert%2Fparquet/**/*.parquet"
            file_count = len(converted_files)
        else:
            url_pattern = None
            file_count = 0

        card = main_info.card_data
        return {
            "dataset_id": dataset_id,
            "has_native_parquet": native_found,
            "has_converted_parquet": converted_found,
            "parquet_url_pattern": url_pattern,
            "parquet_files_count": file_count,
            "card_data": card.__dict__ if card else None,
        }
    except Exception as err:
        return {
            "dataset_id": dataset_id,
            "error": str(err),
            "parquet_url_pattern": None
        }
def get_parquet_url(dataset_id: str) -> str:
    """Return the best parquet URL pattern for a dataset.

    Resolution order: a ``parquet_url_pattern`` stored on the registry entry,
    then the pattern reported by ``detect_dataset_format``, then the standard
    ``hf://datasets/<dataset_id>/**/*.parquet`` pattern.
    """
    # A pattern stored in the registry takes precedence.
    entry = get_dataset_by_id(dataset_id)
    stored_pattern = entry.get("parquet_url_pattern") if entry else None
    if stored_pattern:
        return stored_pattern

    # Otherwise try automatic detection, falling back to the standard pattern.
    detected_pattern = detect_dataset_format(dataset_id).get("parquet_url_pattern")
    return detected_pattern or f"hf://datasets/{dataset_id}/**/*.parquet"