File size: 2,224 Bytes
b64777f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Hub utilities for uploading/downloading step data to HF Dataset repo."""
import os
import logging
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_tree

# Logger shared by every helper in this module.
logger = logging.getLogger(name=__name__)

# Target Hugging Face dataset repository. Each project's step data is stored
# under ``{name}/{step_folder}/`` paths inside it.
HF_DATASET_REPO_ID: str = "baenacoco/talking-head-avatar"


def _get_api():
    """Return an authenticated ``HfApi`` client for the dataset repo.

    The token is read from the ``HF_TOKEN`` environment variable. Before
    returning, the dataset repository is created if needed; ``exist_ok=True``
    keeps this from failing when the repo already exists.

    Raises:
        ValueError: if ``HF_TOKEN`` is unset or empty.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        client = HfApi(token=hf_token)
        # Ensure the target repo exists; a no-op error-wise if it already does.
        client.create_repo(repo_id=HF_DATASET_REPO_ID, repo_type="dataset", exist_ok=True)
        return client
    raise ValueError("HF_TOKEN no encontrado en variables de entorno")


def upload_step(name: str, step_folder: str, local_dir: str):
    """Push the contents of ``local_dir`` to ``{name}/{step_folder}`` in the dataset repo.

    Args:
        name: Project name, used as the top-level folder in the repo.
        step_folder: Step sub-folder placed under the project folder.
        local_dir: Local directory whose contents are uploaded.

    Returns:
        A short status message (in Spanish) naming the destination path.

    Raises:
        ValueError: propagated from ``_get_api`` when ``HF_TOKEN`` is missing.
    """
    client = _get_api()
    remote_path = f"{name}/{step_folder}"
    client.upload_folder(
        repo_id=HF_DATASET_REPO_ID,
        repo_type="dataset",
        folder_path=local_dir,
        path_in_repo=remote_path,
    )
    logger.info(f"Uploaded {local_dir} -> {remote_path}")
    return f"Subido a Hub: {remote_path}"


def download_step(name: str, step_folder: str, local_dir: str):
    """Fetch ``{name}/{step_folder}/`` from the dataset repo into ``local_dir``.

    Only repo files matching ``{name}/{step_folder}/**`` are requested. The
    token comes from ``HF_TOKEN`` when set. Unlike ``upload_step``, a missing
    token is not an error here: ``None`` is passed to ``snapshot_download``.

    NOTE(review): ``snapshot_download(local_dir=...)`` mirrors repo paths, so
    files are expected under ``local_dir/{name}/{step_folder}/`` rather than
    directly in ``local_dir``. Confirm this matches what callers expect.

    Returns:
        A short status message (in Spanish) naming the downloaded path.
    """
    from huggingface_hub import snapshot_download

    remote_path = f"{name}/{step_folder}"
    snapshot_download(
        repo_id=HF_DATASET_REPO_ID,
        repo_type="dataset",
        local_dir=local_dir,
        allow_patterns=[f"{remote_path}/**"],
        token=os.environ.get("HF_TOKEN"),
    )
    logger.info(f"Downloaded {remote_path} -> {local_dir}")
    return f"Descargado de Hub: {remote_path}"


def list_projects() -> list[str]:
    """List project names (top-level folders) in the dataset repo.

    Only directories at the repository root count as projects. Root-level
    files such as ``.gitattributes`` or ``README.md`` are skipped.

    Returns:
        Sorted, de-duplicated folder names. If the listing fails for any
        reason, a warning is logged and an empty list is returned.
    """
    # Entry type that ``list_repo_tree`` yields for directories. Local import,
    # matching how ``snapshot_download`` is imported elsewhere in this module.
    from huggingface_hub.hf_api import RepoFolder

    token = os.environ.get("HF_TOKEN")
    try:
        api = HfApi(token=token)
        entries = api.list_repo_tree(
            repo_id=HF_DATASET_REPO_ID, repo_type="dataset", path_in_repo="",
        )
        # Both file and folder entries expose ``path``, so attribute checks
        # cannot tell them apart; filter on the entry type instead.
        folders = {
            entry.path.split("/")[0]
            for entry in entries
            if isinstance(entry, RepoFolder)
        }
    except Exception as e:
        # Deliberate best-effort: callers get an empty list, not an exception.
        logger.warning("Could not list projects: %s", e)
        return []
    return sorted(folders)