File size: 4,771 Bytes
df4a21a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | """
Hugging Face Hub service for downloading model repositories.
"""
import os
from pathlib import Path
from typing import Optional
from huggingface_hub import snapshot_download
from huggingface_hub.utils import HfHubHTTPError
from app.core.config import settings
from app.core.errors import HuggingFaceDownloadError
from app.core.logging import get_logger
# Module-level logger used by the service below.
logger = get_logger(__name__)
# Silence huggingface_hub's symlink warning. The warning is aimed at Windows,
# but the variable is set unconditionally on every platform.
# NOTE(review): this assignment runs after the huggingface_hub imports above.
# If the installed huggingface_hub reads this variable only when it is first
# imported, setting it here has no effect — consider moving it above those
# imports. Verify against the installed huggingface_hub version.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
class HFHubService:
    """
    Service for interacting with Hugging Face Hub.

    Downloads model repositories into one flat folder per repository under a
    local cache directory, and looks up previously downloaded copies.
    """

    def __init__(self, cache_dir: Optional[str] = None, token: Optional[str] = None):
        """
        Initialize the HF Hub service.

        Args:
            cache_dir: Local directory for caching downloads.
                Defaults to settings.HF_CACHE_DIR (also used for any falsy
                value such as an empty string).
            token: Hugging Face API token for private repos.
                Defaults to settings.HF_TOKEN (also used for any falsy value).
        """
        self.cache_dir = cache_dir or settings.HF_CACHE_DIR
        self.token = token or settings.HF_TOKEN
        # Ensure the cache directory exists up front so downloads and cache
        # lookups can rely on it.
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        logger.info(f"HF Hub service initialized with cache dir: {self.cache_dir}")

    def _local_dir(self, repo_id: str) -> Path:
        """Return the folder inside the cache dir that holds ``repo_id``."""
        # "org/name" -> "org--name": one flat folder per repository. A plain
        # local_dir (instead of the hub's own cache layout) is used to avoid
        # symlink issues on Windows.
        return Path(self.cache_dir) / repo_id.replace("/", "--")

    def download_repo(
        self,
        repo_id: str,
        revision: Optional[str] = None,
        force_download: bool = False
    ) -> str:
        """
        Download a repository from Hugging Face Hub.

        Uses snapshot_download into this repository's folder under the cache
        directory; files already present there may be reused unless
        force_download is set.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/test-random-a")
            revision: Git revision (branch, tag, or commit hash).
                Defaults to "main" (also used for any falsy value).
            force_download: If True, re-download even if files are present locally

        Returns:
            Local path to the downloaded repository, as returned by
            snapshot_download.

        Raises:
            HuggingFaceDownloadError: If the download fails for any reason.
                The underlying exception is attached as ``__cause__``.
        """
        logger.info(f"Downloading repo: {repo_id} (revision={revision}, force={force_download})")
        try:
            # Computed inside the try so a malformed repo_id is reported as
            # HuggingFaceDownloadError, like any other failure.
            local_dir = self._local_dir(repo_id)
            local_path = snapshot_download(
                repo_id=repo_id,
                revision=revision or "main",
                local_dir=str(local_dir),
                token=self.token,
                force_download=force_download,
                local_files_only=False
            )
        except HfHubHTTPError as e:
            logger.error(f"HTTP error downloading {repo_id}: {e}")
            raise HuggingFaceDownloadError(
                message=f"Failed to download repository: {repo_id}",
                details={"repo_id": repo_id, "error": str(e)}
            ) from e
        except Exception as e:
            logger.error(f"Error downloading {repo_id}: {e}")
            raise HuggingFaceDownloadError(
                message=f"Failed to download repository: {repo_id}",
                details={"repo_id": repo_id, "error": str(e)}
            ) from e
        else:
            logger.info(f"Downloaded {repo_id} to {local_path}")
            return local_path

    def get_cached_path(self, repo_id: str) -> Optional[str]:
        """
        Get the cached path for a repository if it exists.

        A repository counts as cached when its folder is an existing,
        non-empty directory.

        NOTE(review): any non-empty folder qualifies, including one left
        behind by an interrupted download; this check cannot tell the two
        apart.

        Args:
            repo_id: Hugging Face repository ID

        Returns:
            Local path if cached, None otherwise
        """
        local_dir = self._local_dir(repo_id)
        # is_dir() rather than exists(): a stray regular file with the same
        # name must read as "not cached" instead of making iterdir() raise
        # NotADirectoryError.
        if local_dir.is_dir() and any(local_dir.iterdir()):
            return str(local_dir)
        return None

    def is_cached(self, repo_id: str) -> bool:
        """
        Check if a repository is already cached.

        Args:
            repo_id: Hugging Face repository ID

        Returns:
            True if get_cached_path finds a cached copy, False otherwise
        """
        return self.get_cached_path(repo_id) is not None
# Module-level cache for the lazily created service; filled in by the first
# call to get_hf_hub_service() and reused by every later call.
_hf_hub_service: Optional[HFHubService] = None


def get_hf_hub_service() -> HFHubService:
    """
    Return the shared HFHubService, constructing it on first use.

    The instance is created with HFHubService's default arguments and stored
    in a module-level variable, so subsequent calls return the same object.
    No lock is taken, so concurrent first calls may each construct an
    instance.

    Returns:
        HFHubService instance
    """
    global _hf_hub_service
    if _hf_hub_service is not None:
        return _hf_hub_service
    _hf_hub_service = HFHubService()
    return _hf_hub_service
|