import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union
from .. import constants
from ..file_download import repo_folder_name
from .sha import git_hash, sha_fileobj
if TYPE_CHECKING:
from ..hf_api import RepoFile, RepoFolder
# using fullmatch for clarity and strictness
# Matches a complete 40-character lowercase hex string (a git commit hash).
# NOTE(review): the ^/$ anchors are redundant with fullmatch() but harmless.
_REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")
# Typed structure describing a checksum mismatch
class Mismatch(TypedDict):
    """A single checksum mismatch between a remote entry and a local file."""

    # Repo-relative POSIX path of the mismatching file.
    path: str
    # Checksum advertised by the remote entry (lowercased hex).
    expected: str
    # Checksum computed from the local file.
    actual: str
    # Hash algorithm used for the comparison ("sha256" or "git-sha1").
    algorithm: str
# Supported checksum algorithms: "sha256" for LFS files, "git-sha1" for regular git blobs.
HashAlgo = Literal["sha256", "git-sha1"]
@dataclass(frozen=True)
class FolderVerification:
    """Immutable result of comparing a remote repo tree against a local folder."""

    # Revision (commit hash) that was verified.
    revision: str
    # Number of paths present on both sides whose checksums were compared.
    checked_count: int
    # Files present on both sides whose checksums differ.
    mismatches: list[Mismatch]
    # Repo-relative paths present remotely but absent locally.
    missing_paths: list[str]
    # Repo-relative paths present locally but absent remotely.
    extra_paths: list[str]
    # Local root directory that was scanned.
    verified_path: Path
def collect_local_files(root: Path) -> dict[str, Path]:
    """
    Return a mapping of repo-relative path -> absolute path for all files under `root`.

    Directories are skipped; only regular files are listed. Keys always use
    forward slashes, regardless of platform.
    """
    files: dict[str, Path] = {}
    for candidate in root.rglob("*"):
        if not candidate.is_file():
            continue
        files[candidate.relative_to(root).as_posix()] = candidate
    return files
def _resolve_commit_hash_from_cache(storage_folder: Path, revision: Optional[str]) -> str:
    """
    Resolve a commit hash from a cache repo folder and an optional revision.

    Args:
        storage_folder: cache folder of the repo (containing `refs/` and `snapshots/`).
        revision: a commit hash, a ref name (e.g. a branch), or None to use defaults.

    Returns:
        The resolved 40-character commit hash.

    Raises:
        ValueError: if the revision cannot be resolved from the cache contents.
    """
    # A full commit hash needs no resolution.
    if revision and _REGEX_COMMIT_HASH.fullmatch(revision):
        return revision
    refs_dir = storage_folder / "refs"
    snapshots_dir = storage_folder / "snapshots"
    if revision:
        # Named revision: must map to a ref file written at download time.
        ref_path = refs_dir / revision
        if ref_path.is_file():
            return ref_path.read_text(encoding="utf-8").strip()
        raise ValueError(f"Revision '{revision}' could not be resolved in cache (expected file '{ref_path}').")
    # No revision provided: try common defaults
    main_ref = refs_dir / "main"
    if main_ref.is_file():
        return main_ref.read_text(encoding="utf-8").strip()
    if not snapshots_dir.is_dir():
        raise ValueError(f"Cache repo is missing snapshots directory: {snapshots_dir}. Provide --revision explicitly.")
    # Fall back to the snapshot folders themselves (named by commit hash).
    candidates = [p.name for p in snapshots_dir.iterdir() if p.is_dir() and _REGEX_COMMIT_HASH.fullmatch(p.name)]
    if len(candidates) == 1:
        return candidates[0]
    if not candidates:
        # Bug fix: previously this case fell through to the "multiple snapshots"
        # error even though there were none, which was misleading.
        raise ValueError(
            f"No cached snapshots found in {snapshots_dir}. Use 'hf download' first or pass --revision."
        )
    raise ValueError(
        "Ambiguous cached revision: multiple snapshots found and no refs to disambiguate. Please pass --revision."
    )
def compute_file_hash(path: Path, algorithm: HashAlgo) -> str:
    """
    Compute the checksum of a local file using the requested algorithm.

    Raises:
        ValueError: if `algorithm` is not a supported value.
    """
    with path.open("rb") as fileobj:
        if algorithm == "git-sha1":
            # Git's blob hash is computed over the entire file content at once.
            return git_hash(fileobj.read())
        if algorithm == "sha256":
            # sha_fileobj streams the file, so pass the open handle directly.
            return sha_fileobj(fileobj).hex()
    raise ValueError(f"Unsupported hash algorithm: {algorithm}")
def verify_maps(
    *,
    remote_by_path: dict[str, Union["RepoFile", "RepoFolder"]],
    local_by_path: dict[str, Path],
    revision: str,
    verified_path: Path,
) -> FolderVerification:
    """Compare remote entries and local files and return a verification result.

    Paths present on both sides are checksum-verified; one-sided paths are
    reported as missing (remote-only) or extra (local-only).
    """
    remote_keys = set(remote_by_path)
    local_keys = set(local_by_path)
    shared = sorted(remote_keys & local_keys)
    mismatches: list[Mismatch] = []
    for rel_path in shared:
        entry = remote_by_path[rel_path]
        # LFS entries carry a sha256; the `lfs` attribute may be an object or a raw dict.
        lfs_info = getattr(entry, "lfs", None)
        sha256_value = None
        if lfs_info is not None:
            sha256_value = getattr(lfs_info, "sha256", None)
            if sha256_value is None and isinstance(lfs_info, dict):
                sha256_value = lfs_info.get("sha256")
        if sha256_value:
            algorithm: HashAlgo = "sha256"
            expected = str(sha256_value).lower()
        else:
            # Non-LFS files are compared against their git blob id instead.
            algorithm = "git-sha1"
            expected = str(entry.blob_id).lower()  # type: ignore
        actual = compute_file_hash(local_by_path[rel_path], algorithm)
        if actual != expected:
            mismatches.append(Mismatch(path=rel_path, expected=expected, actual=actual, algorithm=algorithm))
    return FolderVerification(
        revision=revision,
        checked_count=len(shared),
        mismatches=mismatches,
        missing_paths=sorted(remote_keys - local_keys),
        extra_paths=sorted(local_keys - remote_keys),
        verified_path=verified_path,
    )
def resolve_local_root(
    *,
    repo_id: str,
    repo_type: str,
    revision: Optional[str],
    cache_dir: Optional[Path],
    local_dir: Optional[Path],
) -> tuple[Path, str]:
    """
    Resolve the root directory to scan locally and the remote revision to verify.

    An explicit `local_dir` takes precedence; otherwise the repo snapshot is
    located inside the HF cache (optionally rooted at `cache_dir`).
    """
    if local_dir is not None:
        local_root = Path(local_dir).expanduser().resolve()
        if not local_root.is_dir():
            raise ValueError(f"Local directory does not exist or is not a directory: {local_root}")
        return local_root, (revision or constants.DEFAULT_REVISION)
    # Cache mode: locate the repo folder, then the snapshot for the resolved commit.
    cache_base = Path(cache_dir or constants.HF_HUB_CACHE).expanduser().resolve()
    repo_cache = cache_base / repo_folder_name(repo_id=repo_id, repo_type=repo_type)
    if not repo_cache.exists():
        raise ValueError(
            f"Repo is not present in cache: {repo_cache}. Use 'hf download' first or pass --local-dir."
        )
    commit_hash = _resolve_commit_hash_from_cache(repo_cache, revision)
    snapshot_path = repo_cache / "snapshots" / commit_hash
    if not snapshot_path.is_dir():
        raise ValueError(f"Snapshot directory does not exist for revision '{commit_hash}': {snapshot_path}.")
    return snapshot_path, commit_hash