File size: 5,406 Bytes
f60a6c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
from logging import getLogger
from pathlib import Path
from asyncio import Semaphore, gather, to_thread
from huggingface_hub import HfApi
from scorevision.utils.settings import get_settings
# Module-level logger, named after this module.
logger = getLogger(__name__)
def get_huggingface_repo_name() -> str:
    """Return the repo id ``<HUGGINGFACE_USERNAME>/ScoreVision`` from settings."""
    username = get_settings().HUGGINGFACE_USERNAME
    return f"{username}/ScoreVision"
def verify_huggingface_repo_name_exists(hf_api: HfApi) -> None:
    """Check that the user's model repo can be fetched with ``hf_api``.

    Raises:
        ValueError: wrapping whatever ``hf_api.repo_info`` raised, with a hint
            that a model path should be supplied on first use. The original
            exception is chained as ``__cause__``.
    """
    name = get_huggingface_repo_name()
    try:
        # Only the success/failure of the lookup matters; the info is unused.
        hf_api.repo_info(repo_id=name, repo_type="model")
    except Exception as e:
        raise ValueError(
            f"{e}.\n\nIf this is your first time, specify the path to the model to upload"
        ) from e
def verify_huggingface_repo_revision_exists(revision: str, hf_api: HfApi) -> None:
    """Ensure ``revision`` resolves to exactly that commit in the user's repo.

    The literal ``revision`` string is compared to the ``sha`` (or ``oid``)
    reported by ``repo_info``. A branch name or abbreviated hash therefore
    fails the check.

    Raises:
        ValueError: if the reported commit does not equal ``revision``.
    """
    repo_info = hf_api.repo_info(
        repo_id=get_huggingface_repo_name(),
        repo_type="model",
        revision=revision,
    )
    logger.info(f"Repo Info:{repo_info}")
    resolved = getattr(repo_info, "sha", None)
    if not resolved:
        resolved = getattr(repo_info, "oid", None)
    if resolved == revision:
        return
    raise ValueError(
        f"HF revision not accessible (gated/missing?): {resolved} != {revision}"
    )
def get_paths_in_directory(path_dir: Path) -> list[Path]:
    """Return the files under ``path_dir`` that should be uploaded.

    Walks ``path_dir`` recursively and keeps regular files. It skips:

    * anything with a dot-prefixed component below ``path_dir``
      (e.g. ``.git/...``, ``.gitattributes``);
    * files whose name ends in ``.lock``.

    Only the portion of each path relative to ``path_dir`` is inspected. A
    ``path_dir`` that itself sits inside a hidden directory (such as
    ``~/.cache/...``) therefore still yields its files.
    """

    def is_hidden(path: Path) -> bool:
        # Ancestors of path_dir must not count, otherwise every file under a
        # dot-directory root would be skipped.
        return any(part.startswith(".") for part in path.relative_to(path_dir).parts)

    def is_lock(path: Path) -> bool:
        # Dot-prefixed names are already covered by is_hidden.
        return path.name.endswith(".lock")

    paths = [
        path
        for path in path_dir.rglob("*")
        if path.is_file() and not is_hidden(path) and not is_lock(path)
    ]
    logger.info(f"{len(paths)} files found")
    return paths
async def upload_file_to_huggingface_repo(
    name: str, path_file: Path, path_dir: Path, semaphore: Semaphore, hf_api: HfApi
) -> None:
    """Upload one file to the model repo ``name``.

    The file's location relative to ``path_dir`` becomes its path in the repo.
    The blocking ``hf_api.upload_file`` call runs in a worker thread. The
    shared ``semaphore`` bounds how many uploads run at once.
    """
    # path_in_repo must use "/" separators; str() of a relative Path would
    # produce backslashes on Windows.
    path_in_repo = path_file.relative_to(path_dir).as_posix()
    async with semaphore:
        await to_thread(
            hf_api.upload_file,
            path_or_fileobj=str(path_file),
            path_in_repo=path_in_repo,
            repo_id=name,
            repo_type="model",
            commit_message="scorevision: push artifact",
        )
async def upload_directory_to_huggingface_repo(path_dir: Path, hf_api: HfApi) -> None:
    """Upload every file selected by ``get_paths_in_directory`` from ``path_dir``.

    One upload task is created per file. All tasks are awaited together, with
    concurrency capped at ``HUGGINGFACE_CONCURRENCY`` via a shared semaphore.
    """
    logger.info(f"Uploading {path_dir}")
    limiter = Semaphore(get_settings().HUGGINGFACE_CONCURRENCY)
    target_repo = get_huggingface_repo_name()
    uploads = [
        upload_file_to_huggingface_repo(
            name=target_repo,
            path_file=file_path,
            path_dir=path_dir,
            semaphore=limiter,
            hf_api=hf_api,
        )
        for file_path in get_paths_in_directory(path_dir=path_dir)
    ]
    await gather(*uploads)
async def create_or_update_huggingface_repo(model_path: Path, hf_api: HfApi) -> None:
    """Ensure the user's model repo exists and is private, then upload ``model_path``.

    Making the repo private is best effort: a failure is logged and the upload
    still proceeds.
    """
    name = get_huggingface_repo_name()
    # With exist_ok=True an already-existing repo is reused as-is, so
    # visibility is forced to private explicitly below.
    hf_api.create_repo(repo_id=name, repo_type="model", private=True, exist_ok=True)
    try:
        # update_repo_settings(private=...) supersedes the deprecated
        # update_repo_visibility.
        hf_api.update_repo_settings(repo_id=name, repo_type="model", private=True)
    except Exception as e:
        logger.error(f"Error making hf repo private: {e}")
    await upload_directory_to_huggingface_repo(path_dir=model_path, hf_api=hf_api)
async def get_huggingface_repo_revision(hf_api: HfApi) -> str:
    """Return the current commit hash of the user's repo, or "" if none is reported.

    ``sha`` is preferred, with ``oid`` as the fallback.
    """
    name = get_huggingface_repo_name()
    info = hf_api.repo_info(repo_id=name, repo_type="model")
    # A getattr default only applies when the attribute is absent. If `sha`
    # exists but is None, the `or` chain still falls back to `oid`.
    revision = getattr(info, "sha", None) or getattr(info, "oid", None) or ""
    logger.info(f"Detected revision: {revision}")
    return revision
async def create_update_or_verify_huggingface_repo(
    model_path: Path | None, hf_revision: str | None
) -> str:
    """Prepare the user's Hugging Face repo and return the revision to use.

    - If ``model_path`` is provided, the repo is created (or updated if it
      already exists) and the contents of ``model_path`` are uploaded.
    - Otherwise the repo must already exist. If it does not, an error asks the
      user to specify a path to a model to upload.
    - If ``hf_revision`` is provided, it is verified but nothing is updated.
      Otherwise the repo's latest revision is used.

    Finally, the repo is switched to public (best effort: failures are logged).

    Returns:
        The revision (commit hash) to use.

    Raises:
        ValueError: if HUGGINGFACE_USERNAME or HUGGINGFACE_API_KEY is missing,
            if the repo does not exist and no ``model_path`` was given, or if
            ``hf_revision`` cannot be verified.
    """
    settings = get_settings()
    api_key = settings.HUGGINGFACE_API_KEY.get_secret_value()
    # Both are required: the username forms the repo id and the key
    # authenticates the HfApi client, so either one missing is an error.
    if not settings.HUGGINGFACE_USERNAME or not api_key:
        raise ValueError("HUGGINGFACE_USERNAME/HUGGINGFACE_API_KEY required")
    hf_api = HfApi(token=api_key)

    if model_path:
        logger.info("Creating/Updating repo")
        await create_or_update_huggingface_repo(model_path=model_path, hf_api=hf_api)
    else:
        verify_huggingface_repo_name_exists(hf_api=hf_api)
        logger.info("Using existing repo")

    if hf_revision:
        verify_huggingface_repo_revision_exists(revision=hf_revision, hf_api=hf_api)
        logger.info(f"Using provided revision: {hf_revision}")
    else:
        hf_revision = await get_huggingface_repo_revision(hf_api=hf_api)
    logger.info(f"Hf revision: {hf_revision}")

    try:
        hf_api.update_repo_settings(
            repo_id=get_huggingface_repo_name(), repo_type="model", private=False
        )
    except Exception as e:
        # Best effort: the revision is still returned if visibility can't change.
        logger.error(f"Error making hf repo public: {e}")
    return hf_revision
|