File size: 4,771 Bytes
df4a21a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Hugging Face Hub service for downloading model repositories.
"""

import os
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download
from huggingface_hub.utils import HfHubHTTPError

from app.core.config import settings
from app.core.errors import HuggingFaceDownloadError
from app.core.logging import get_logger

logger = get_logger(__name__)

# Disable huggingface_hub's symlink warning on Windows.
# huggingface_hub reads HF_HUB_DISABLE_SYMLINKS_WARNING into its `constants`
# module when the package is first imported, and that import already happened
# at the top of this file. Setting the environment variable here is therefore
# too late on its own, so the already-loaded constant is updated as well
# (best-effort: skipped if this huggingface_hub version does not define it).
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
try:
    from huggingface_hub import constants as _hf_constants
except ImportError:  # huggingface_hub without a `constants` module
    pass
else:
    if hasattr(_hf_constants, "HF_HUB_DISABLE_SYMLINKS_WARNING"):
        _hf_constants.HF_HUB_DISABLE_SYMLINKS_WARNING = True


class HFHubService:
    """
    Service for interacting with Hugging Face Hub.

    Downloads model repositories into per-repository folders under a local
    cache directory, and reports whether such a folder already holds files.
    """

    def __init__(self, cache_dir: Optional[str] = None, token: Optional[str] = None):
        """
        Initialize the HF Hub service.

        Args:
            cache_dir: Local directory for caching downloads.
                       Defaults to settings.HF_CACHE_DIR
            token: Hugging Face API token for private repos.
                   Defaults to settings.HF_TOKEN

        Side effects:
            Creates ``cache_dir`` (including missing parents) if it does not
            already exist.
        """
        self.cache_dir = cache_dir or settings.HF_CACHE_DIR
        self.token = token or settings.HF_TOKEN

        # Ensure cache directory exists
        Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        logger.info(f"HF Hub service initialized with cache dir: {self.cache_dir}")

    def _local_dir_for(self, repo_id: str) -> Path:
        """
        Return the folder that ``repo_id`` is downloaded into.

        ``"org/name"`` maps to ``<cache_dir>/org--name``. Both
        ``download_repo`` and ``get_cached_path`` use this helper so they
        always agree on the on-disk layout.
        """
        return Path(self.cache_dir) / repo_id.replace("/", "--")

    def download_repo(
        self,
        repo_id: str,
        revision: Optional[str] = None,
        force_download: bool = False
    ) -> str:
        """
        Download a repository from Hugging Face Hub.

        Uses snapshot_download with a per-repo ``local_dir``; files that are
        already present and up to date are not fetched again unless
        ``force_download`` is set.

        Args:
            repo_id: Hugging Face repository ID (e.g., "DeepFakeDetector/test-random-a")
            revision: Git revision (branch, tag, or commit hash). Defaults to "main"
            force_download: If True, re-download even if cached

        Returns:
            Local path to the downloaded repository

        Raises:
            HuggingFaceDownloadError: If download fails. The underlying
                exception is attached as ``__cause__``.
        """
        logger.info(f"Downloading repo: {repo_id} (revision={revision}, force={force_download})")

        try:
            # Use local_dir instead of cache_dir to avoid symlink issues on Windows
            local_dir = self._local_dir_for(repo_id)

            local_path = snapshot_download(
                repo_id=repo_id,
                revision=revision or "main",
                local_dir=str(local_dir),
                token=self.token,
                force_download=force_download,
                local_files_only=False,
            )
        except HfHubHTTPError as e:
            logger.error(f"HTTP error downloading {repo_id}: {e}")
            # Chain the cause so the original traceback is not lost.
            raise HuggingFaceDownloadError(
                message=f"Failed to download repository: {repo_id}",
                details={"repo_id": repo_id, "error": str(e)},
            ) from e
        except Exception as e:
            logger.error(f"Error downloading {repo_id}: {e}")
            raise HuggingFaceDownloadError(
                message=f"Failed to download repository: {repo_id}",
                details={"repo_id": repo_id, "error": str(e)},
            ) from e

        logger.info(f"Downloaded {repo_id} to {local_path}")
        return local_path

    def get_cached_path(self, repo_id: str) -> Optional[str]:
        """
        Get the cached path for a repository if it exists.

        A repository counts as cached when its folder (see ``_local_dir_for``)
        is a directory that contains at least one entry other than ``.cache``.

        Args:
            repo_id: Hugging Face repository ID

        Returns:
            Local path if cached, None otherwise
        """
        # Check local_dir path format (used to avoid symlinks on Windows)
        local_dir = self._local_dir_for(repo_id)

        # is_dir() (not exists()): a plain file at this path would make
        # iterdir() raise NotADirectoryError.
        if not local_dir.is_dir():
            return None

        # Recent huggingface_hub versions keep download metadata under
        # <local_dir>/.cache. A folder containing only that (e.g. after an
        # interrupted download) is not a usable copy of the repository.
        has_content = any(entry.name != ".cache" for entry in local_dir.iterdir())
        return str(local_dir) if has_content else None

    def is_cached(self, repo_id: str) -> bool:
        """
        Check if a repository is already cached.

        Args:
            repo_id: Hugging Face repository ID

        Returns:
            True if cached, False otherwise
        """
        return self.get_cached_path(repo_id) is not None


# Module-level shared HFHubService instance; stays None until
# get_hf_hub_service() creates it on its first call.
_hf_hub_service: Optional[HFHubService] = None


def get_hf_hub_service() -> HFHubService:
    """
    Return the shared HFHubService, constructing it lazily on first use.

    The first call builds an HFHubService with default arguments and stores
    it in the module-level ``_hf_hub_service``; later calls return that same
    object.

    Returns:
        HFHubService instance
    """
    global _hf_hub_service
    if _hf_hub_service is not None:
        return _hf_hub_service
    _hf_hub_service = HFHubService()
    return _hf_hub_service