# Repository page header (not part of the module): lukhsaankumar
# Deploy DeepFake Detector API - 2026-03-07 09:12:00
# df4a21a
"""
Hugging Face Hub service for downloading model repositories.
"""
import os
from pathlib import Path
from typing import Optional
from huggingface_hub import snapshot_download
from huggingface_hub.utils import HfHubHTTPError
from app.core.config import settings
from app.core.errors import HuggingFaceDownloadError
from app.core.logging import get_logger
# Logger obtained from the project's logging helper, keyed by this module's name.
logger = get_logger(__name__)
# Disable symlink warnings on Windows
# NOTE(review): this assignment runs after huggingface_hub has already been
# imported above; if the library reads the variable at import time, the
# setting may come too late to take effect - confirm, and move it ahead of
# the import if so.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
class HFHubService:
    """
    Service for interacting with Hugging Face Hub.

    Handles downloading model repositories and caching them locally. Each
    repository is stored in its own directory under ``cache_dir``, named
    after the repo ID with every "/" replaced by "--".
    """

    def __init__(self, cache_dir: Optional[str] = None, token: Optional[str] = None):
        """
        Initialize the HF Hub service.

        Args:
            cache_dir: Local directory for caching downloads.
                Defaults to settings.HF_CACHE_DIR
            token: Hugging Face API token for private repos.
                Defaults to settings.HF_TOKEN
        """
        self.cache_dir = cache_dir or settings.HF_CACHE_DIR
        self.token = token or settings.HF_TOKEN
        # Ensure cache directory exists
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        logger.info(f"HF Hub service initialized with cache dir: {self.cache_dir}")

    def _repo_dir(self, repo_id: str) -> Path:
        """Return the per-repository directory under the cache dir."""
        # "/" is not valid inside a single path component, so flatten the
        # "owner/name" repo ID into "owner--name".
        return Path(self.cache_dir) / repo_id.replace("/", "--")

    def download_repo(
        self,
        repo_id: str,
        revision: Optional[str] = None,
        force_download: bool = False
    ) -> str:
        """
        Download a repository from Hugging Face Hub.

        Delegates to snapshot_download, targeting a per-repository directory
        under the cache dir (passed as ``local_dir``).

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/test-random-a")
            revision: Git revision (branch, tag, or commit hash). Defaults to "main"
            force_download: If True, re-download even if cached

        Returns:
            Local path to the downloaded repository, as returned by
            snapshot_download

        Raises:
            HuggingFaceDownloadError: If download fails. The underlying
                exception is attached as ``__cause__``.
        """
        logger.info(f"Downloading repo: {repo_id} (revision={revision}, force={force_download})")
        try:
            # Use local_dir instead of cache_dir to avoid symlink issues on Windows
            local_dir = self._repo_dir(repo_id)
            local_path = snapshot_download(
                repo_id=repo_id,
                revision=revision or "main",
                local_dir=str(local_dir),
                token=self.token,
                force_download=force_download,
                local_files_only=False
            )
        except HfHubHTTPError as e:
            logger.error(f"HTTP error downloading {repo_id}: {e}")
            # Chain the original error so the traceback keeps the root cause.
            raise HuggingFaceDownloadError(
                message=f"Failed to download repository: {repo_id}",
                details={"repo_id": repo_id, "error": str(e)}
            ) from e
        except Exception as e:
            # Boundary handler: any other failure is surfaced to callers as
            # the same domain error type.
            logger.error(f"Error downloading {repo_id}: {e}")
            raise HuggingFaceDownloadError(
                message=f"Failed to download repository: {repo_id}",
                details={"repo_id": repo_id, "error": str(e)}
            ) from e
        else:
            logger.info(f"Downloaded {repo_id} to {local_path}")
            return local_path

    def get_cached_path(self, repo_id: str) -> Optional[str]:
        """
        Get the cached path for a repository if it exists.

        A repository counts as cached only when its directory exists and
        contains at least one entry.

        Args:
            repo_id: Hugging Face repository ID

        Returns:
            Local path if cached, None otherwise
        """
        # Check local_dir path format (used to avoid symlinks on Windows)
        local_dir = self._repo_dir(repo_id)
        # is_dir() rather than exists(): a stray regular file at this path
        # would make iterdir() raise NotADirectoryError.
        if local_dir.is_dir() and any(local_dir.iterdir()):
            return str(local_dir)
        return None

    def is_cached(self, repo_id: str) -> bool:
        """
        Check if a repository is already cached.

        Args:
            repo_id: Hugging Face repository ID

        Returns:
            True if cached, False otherwise
        """
        return self.get_cached_path(repo_id) is not None
# Global singleton instance. Starts as None; get_hf_hub_service() creates the
# HFHubService on first call and stores it here for reuse.
_hf_hub_service: Optional[HFHubService] = None
def get_hf_hub_service() -> HFHubService:
    """
    Return the shared HF Hub service, creating it on first use.

    Returns:
        The module-wide HFHubService instance
    """
    global _hf_hub_service
    # Fast path: the instance was already built by an earlier call.
    if _hf_hub_service is not None:
        return _hf_hub_service
    # First call: build with default settings and remember it.
    _hf_hub_service = HFHubService()
    return _hf_hub_service