| """Hub utilities for uploading/downloading step data to HF Dataset repo.""" |
| import os |
| import logging |
| from pathlib import Path |
| from huggingface_hub import HfApi, hf_hub_download, list_repo_tree |
|
|
| logger = logging.getLogger(__name__) |
|
|
| HF_DATASET_REPO_ID = "baenacoco/talking-head-avatar" |
|
|
|
|
def _get_api():
    """Build an authenticated ``HfApi`` client for the dataset repo.

    The access token is read from the ``HF_TOKEN`` environment variable.
    Before the client is returned, ``create_repo`` is called with
    ``exist_ok=True`` for ``HF_DATASET_REPO_ID`` (repo type ``dataset``).

    Returns:
        The ``HfApi`` instance constructed with that token.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset or empty.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        client = HfApi(token=hf_token)
        # Issued on every call; exist_ok=True is passed so an already
        # existing repo is not treated as an error.
        client.create_repo(
            repo_type="dataset",
            repo_id=HF_DATASET_REPO_ID,
            exist_ok=True,
        )
        return client
    raise ValueError("HF_TOKEN no encontrado en variables de entorno")
|
|
|
|
def upload_step(name: str, step_folder: str, local_dir: str):
    """Push a local directory to ``{name}/{step_folder}`` in the dataset repo.

    Args:
        name: Project name; first path component inside the repo.
        step_folder: Step sub-folder placed under the project.
        local_dir: Local directory passed to ``upload_folder`` as the source.

    Returns:
        A short status message naming the destination path in the repo.

    Raises:
        Whatever ``_get_api`` or ``upload_folder`` raise; nothing is caught
        here.
    """
    # Destination inside the repo, shared by the upload call and both messages.
    remote_path = f"{name}/{step_folder}"
    _get_api().upload_folder(
        repo_id=HF_DATASET_REPO_ID,
        repo_type="dataset",
        path_in_repo=remote_path,
        folder_path=local_dir,
    )
    logger.info(f"Uploaded {local_dir} -> {remote_path}")
    return f"Subido a Hub: {remote_path}"
|
|
|
|
def download_step(name: str, step_folder: str, local_dir: str):
    """Fetch the files under ``{name}/{step_folder}/`` from the dataset repo.

    Only repo paths matching ``{name}/{step_folder}/**`` are requested, and
    ``local_dir`` is handed to ``snapshot_download`` as the download target.

    NOTE(review): ``snapshot_download`` mirrors repo-relative paths, so the
    files presumably end up in ``local_dir/{name}/{step_folder}/`` rather
    than directly in ``local_dir`` -- confirm callers expect that layout.

    Args:
        name: Project name; first path component inside the repo.
        step_folder: Step sub-folder under the project.
        local_dir: Local directory passed to ``snapshot_download``.

    Returns:
        A short status message naming the repo path that was requested.
    """
    from huggingface_hub import snapshot_download

    remote_prefix = f"{name}/{step_folder}"
    snapshot_download(
        # No token check here: None is passed through when HF_TOKEN is unset.
        token=os.environ.get("HF_TOKEN"),
        allow_patterns=[f"{remote_prefix}/**"],
        local_dir=local_dir,
        repo_type="dataset",
        repo_id=HF_DATASET_REPO_ID,
    )
    logger.info(f"Downloaded {remote_prefix} -> {local_dir}")
    return f"Descargado de Hub: {remote_prefix}"
|
|
|
|
def list_projects() -> list[str]:
    """List project names (top-level folders) in the dataset repo.

    The repo root is listed with ``HfApi.list_repo_tree`` and only folder
    entries are reported. Files that sit directly at the root (for example
    ``README.md`` or ``.gitattributes``) are skipped, since they are not
    projects. Should an entry ever carry a nested path, its first path
    component is reported as well.

    The token is taken from ``HF_TOKEN`` when set; ``None`` is passed
    otherwise.

    Returns:
        Sorted, de-duplicated project names. This is a best-effort call: if
        anything fails while listing, the error is logged as a warning and an
        empty list is returned instead of raising.
    """
    token = os.environ.get("HF_TOKEN")
    try:
        # Imported inside the try so that a missing symbol degrades to the
        # same best-effort empty result as any other listing failure.
        from huggingface_hub.hf_api import RepoFolder

        api = HfApi(token=token)
        projects = set()
        for entry in api.list_repo_tree(
            repo_id=HF_DATASET_REPO_ID, repo_type="dataset", path_in_repo="",
        ):
            path = entry.path
            # Both files and folders expose ``.path``, so its presence alone
            # cannot tell them apart; keep folders and nested paths only.
            if isinstance(entry, RepoFolder) or "/" in path:
                projects.add(path.split("/", 1)[0])
        return sorted(projects)
    except Exception as e:
        logger.warning("Could not list projects: %s", e)
        return []
|
|