File size: 19,807 Bytes
f6e3d73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import os
from huggingface_hub import HfApi, HfFolder, Repository, create_repo, get_full_repo_name
from huggingface_hub.utils import HfHubHTTPError
import logging
from pathlib import Path
import json # Added for example usage

logger = logging.getLogger(__name__)

class HuggingFaceWrapper:
    """
    A wrapper for interacting with the Hugging Face Hub API.
    Handles authentication, model and dataset uploads/downloads,
    and repository creation.
    """
    def __init__(self, token: str | None = None, default_repo_prefix: str = "museum-sexoskop"):
        """
        Initializes the HuggingFaceWrapper.

        Args:
            token: Your Hugging Face API token. If None, it will try to use
                   a token saved locally via `huggingface-cli login`.
            default_repo_prefix: A default prefix for repository names.
        """
        self.api = HfApi()
        if token:
            self.token = token
            # Note: HfApi uses the token from HfFolder by default if logged in.
            # To explicitly use a provided token for all operations,
            # some HfApi methods accept it directly.
            # For operations like Repository, ensure the environment or HfFolder is set.
            HfFolder.save_token(token)
            logger.info("Hugging Face token saved for the session.")
        else:
            self.token = HfFolder.get_token()
            if not self.token:
                logger.warning("No Hugging Face token provided or found locally. "
                               "Please login using `huggingface-cli login` or provide a token.")
            else:
                logger.info("Using locally saved Hugging Face token.")
        
        self.default_repo_prefix = default_repo_prefix

    def _get_full_repo_id(self, repo_name: str, repo_type: str | None = None) -> str:
        """Helper to construct full repo ID, ensuring it includes the username/org."""
        # If repo_name already contains a slash, it's likely a full ID (user/repo or org/repo)
        if "/" in repo_name:
            # Further check: if it doesn't have the prefix and prefix is defined,
            # this might be an attempt to use a non-prefixed name directly.
            # For simplicity, we assume if '/' is present, it's a deliberate full ID.
            return repo_name

        user_or_org = self.api.whoami(token=self.token).get("name") if self.token else None
        if not user_or_org:
            raise ValueError("Could not determine Hugging Face username/org. Ensure you are logged in or token is valid.")
        
        effective_repo_name = repo_name
        if self.default_repo_prefix and not repo_name.startswith(self.default_repo_prefix):
            effective_repo_name = f"{self.default_repo_prefix}-{repo_name}"
            
        return f"{user_or_org}/{effective_repo_name}"

    def create_repository(self, repo_name: str, repo_type: str | None = None, private: bool = True, organization: str | None = None) -> str:
        """
        Creates a new repository on the Hugging Face Hub.
        If organization is provided, repo_name should be the base name.
        If organization is None, repo_name can be a base name (username will be prepended) 
        or a full name like 'username/repo_name'.

        Args:
            repo_name: The name of the repository. Can be 'my-repo' or 'username/my-repo'.
            repo_type: Type of the repository ('model', 'dataset', 'space').
            private: Whether the repository should be private.
            organization: Optional organization name to create the repo under. If provided,
                          repo_name should be the base name for the repo within that org.

        Returns:
            The full repository ID (e.g., "username/repo_name" or "orgname/repo_name").
        """
        if organization:
            # If org is specified, repo_name must be the base name for that org
            if "/" in repo_name:
                raise ValueError("When organization is specified, repo_name should be a base name, not 'org/repo'.")
            full_repo_id = f"{organization}/{repo_name}"
        elif "/" in repo_name:
            # User provided a full name like "username/repo_name"
            full_repo_id = repo_name
        else:
            # User provided a base name, prepend current user
            user = self.api.whoami(token=self.token).get("name")
            if not user:
                raise ConnectionError("Could not determine Hugging Face username. Ensure token is valid and you are logged in.")
            full_repo_id = f"{user}/{repo_name}"
        
        try:
            url = create_repo(repo_id=full_repo_id, token=self.token, private=private, repo_type=repo_type, exist_ok=True)
            logger.info(f"Repository '{full_repo_id}' ensured to exist. URL: {url}")
            return full_repo_id
        except HfHubHTTPError as e:
            logger.error(f"Error creating repository '{full_repo_id}': {e}")
            # If error indicates it's because it's a user repo and trying to use org logic or vice-versa
            # it might be complex to auto-fix, so better to raise.
            raise

    def upload_file_or_folder(self, local_path: str | Path, repo_id: str, path_in_repo: str | None = None, repo_type: str | None = None, commit_message: str = "Upload content"):
        """Helper to upload a single file or an entire folder."""
        local_path_obj = Path(local_path)
        
        if not path_in_repo and local_path_obj.is_file():
            path_in_repo = local_path_obj.name
        elif not path_in_repo and local_path_obj.is_dir():
             # For folders, path_in_repo is relative to the repo root.
             # If None, files will be uploaded to the root.
             # If you want to upload contents of 'my_folder' into 'target_folder_in_repo/',
             # then path_in_repo should be 'target_folder_in_repo'
             # For simplicity here, if path_in_repo is None for a folder, we upload its contents to the root.
             pass


        if local_path_obj.is_file():
            self.api.upload_file(
                path_or_fileobj=str(local_path_obj),
                path_in_repo=path_in_repo if path_in_repo else local_path_obj.name,
                repo_id=repo_id,
                repo_type=repo_type,
                token=self.token,
                commit_message=commit_message,
            )
            logger.info(f"File '{local_path_obj}' uploaded to '{repo_id}/{path_in_repo if path_in_repo else local_path_obj.name}'.")
        elif local_path_obj.is_dir():
            # upload_folder uploads the *contents* of folder_path into the repo_id,
            # optionally under a path_in_repo.
            self.api.upload_folder(
                folder_path=str(local_path_obj),
                path_in_repo=path_in_repo if path_in_repo else ".", # Upload to root if no path_in_repo
                repo_id=repo_id,
                repo_type=repo_type,
                token=self.token,
                commit_message=commit_message,
                ignore_patterns=["*.git*", ".gitattributes"],
            )
            logger.info(f"Folder '{local_path_obj}' contents uploaded to '{repo_id}{'/' + path_in_repo if path_in_repo and path_in_repo != '.' else ''}'.")
        else:
            raise FileNotFoundError(f"Local path '{local_path}' not found or is not a file/directory.")

    def upload_model(self, model_path: str | Path, repo_name: str, private: bool = True, commit_message: str = "Upload model", organization: str | None = None) -> str:
        """
        Uploads a model to the Hugging Face Hub.

        Args:
            model_path: Path to the local model directory or file.
            repo_name: Base name of the repository (e.g., "my-lora-model").
                       The prefix from __init__ and username/org will be added.
            private: Whether the repository should be private.
            commit_message: Commit message for the upload.
            organization: Optional organization to host this model. If None, uses the logged-in user.


        Returns:
            The URL of the uploaded model repository.
        """
        # Construct the effective repo name, possibly prefixed
        effective_repo_name = repo_name
        if self.default_repo_prefix and not repo_name.startswith(self.default_repo_prefix):
             effective_repo_name = f"{self.default_repo_prefix}-{repo_name}"

        # Create the repository
        target_repo_id = self.create_repository(repo_name=effective_repo_name, repo_type="model", private=private, organization=organization)
        
        logger.info(f"Uploading model from '{model_path}' to '{target_repo_id}'...")
        self.upload_file_or_folder(local_path=model_path, repo_id=target_repo_id, repo_type="model", commit_message=commit_message)
            
        repo_url = f"https://huggingface.co/{target_repo_id}"
        logger.info(f"Model uploaded to {repo_url}")
        return repo_url

    def download_model(self, repo_name: str, local_dir: str | Path, revision: str | None = None, organization: str | None = None) -> str:
        """
        Downloads a model from the Hugging Face Hub.

        Args:
            repo_name: Name of the repository. Can be a base name (e.g., "my-lora-model")
                       or a full ID (e.g., "username/my-lora-model").
                       If base name and no organization, prefix and username are added.
                       If base name and organization, prefix is added.
            local_dir: Local directory to save the model.
            revision: Optional model revision (branch, tag, commit hash).
            organization: Optional organization if repo_name is a base name under an org.

        Returns:
            Path to the downloaded model.
        """
        if "/" in repo_name: # User provided full ID like "user/repo" or "org/repo"
            target_repo_id = repo_name
        else: # User provided base name
            effective_repo_name = repo_name
            if self.default_repo_prefix and not repo_name.startswith(self.default_repo_prefix):
                effective_repo_name = f"{self.default_repo_prefix}-{repo_name}"
            
            if organization:
                target_repo_id = f"{organization}/{effective_repo_name}"
            else:
                user = self.api.whoami(token=self.token).get("name")
                if not user:
                    raise ConnectionError("Could not determine Hugging Face username for downloading.")
                target_repo_id = f"{user}/{effective_repo_name}"

        logger.info(f"Downloading model '{target_repo_id}' to '{local_dir}'...")
        
        downloaded_path = self.api.snapshot_download(
            repo_id=target_repo_id,
            repo_type="model", # Can be omitted, snapshot_download infers if possible
            local_dir=str(local_dir),
            token=self.token,
            revision=revision,
        )
        logger.info(f"Model '{target_repo_id}' downloaded to '{downloaded_path}'.")
        return downloaded_path

    def upload_dataset(self, dataset_path: str | Path, repo_name: str, private: bool = True, commit_message: str = "Upload dataset", organization: str | None = None) -> str:
        """
        Uploads a dataset to the Hugging Face Hub. (Similar to upload_model)
        """
        effective_repo_name = repo_name
        if self.default_repo_prefix and not repo_name.startswith(self.default_repo_prefix):
             effective_repo_name = f"{self.default_repo_prefix}-{repo_name}"

        target_repo_id = self.create_repository(repo_name=effective_repo_name, repo_type="dataset", private=private, organization=organization)

        logger.info(f"Uploading dataset from '{dataset_path}' to '{target_repo_id}'...")
        self.upload_file_or_folder(local_path=dataset_path, repo_id=target_repo_id, repo_type="dataset", commit_message=commit_message)

        repo_url = f"https://huggingface.co/{target_repo_id}"
        logger.info(f"Dataset uploaded to {repo_url}")
        return repo_url

    def download_dataset(self, repo_name: str, local_dir: str | Path, revision: str | None = None, organization: str | None = None) -> str:
        """
        Downloads a dataset from the Hugging Face Hub. (Similar to download_model)
        """
        if "/" in repo_name:
            target_repo_id = repo_name
        else:
            effective_repo_name = repo_name
            if self.default_repo_prefix and not repo_name.startswith(self.default_repo_prefix):
                effective_repo_name = f"{self.default_repo_prefix}-{repo_name}"
            if organization:
                target_repo_id = f"{organization}/{effective_repo_name}"
            else:
                user = self.api.whoami(token=self.token).get("name")
                if not user:
                    raise ConnectionError("Could not determine Hugging Face username for downloading.")
                target_repo_id = f"{user}/{effective_repo_name}"
                
        logger.info(f"Downloading dataset '{target_repo_id}' to '{local_dir}'...")
        
        downloaded_path = self.api.snapshot_download(
            repo_id=target_repo_id,
            repo_type="dataset", # Can be omitted
            local_dir=str(local_dir),
            token=self.token,
            revision=revision,
        )
        logger.info(f"Dataset '{target_repo_id}' downloaded to '{downloaded_path}'.")
        return downloaded_path

    def initiate_training(self, model_repo_id: str, dataset_repo_id: str, training_params: dict):
        logger.warning("initiate_training is a placeholder and not fully implemented.")
        logger.info(f"Would attempt to train model {model_repo_id} with dataset {dataset_repo_id} using params: {training_params}")
        pass

# Example usage: exercises repository creation, upload and download against the
# live Hub. Requires the HF_TOKEN environment variable; creates private test repos.
if __name__ == "__main__":
    import shutil

    logging.basicConfig(level=logging.INFO)

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        logger.warning("HF_TOKEN environment variable not set. Please set it or log in via huggingface-cli.")
        logger.warning("Skipping example usage.")
    else:
        # Use a different prefix for examples to avoid conflict with actual app prefix
        hf_wrapper = HuggingFaceWrapper(token=hf_token, default_repo_prefix="hf-wrapper-test")

        # Determine current Hugging Face username for constructing repo IDs in tests
        try:
            current_hf_user = hf_wrapper.api.whoami(token=hf_wrapper.token).get("name")
            if not current_hf_user:
                raise ValueError("Could not retrieve HuggingFace username.")
        except Exception as e:
            logger.error(f"Failed to get HuggingFace username for tests: {e}. Skipping examples.")
            current_hf_user = None

        if current_hf_user:
            test_model_repo_basename = "my-test-model"
            test_dataset_repo_basename = "my-test-dataset"

            # create_repository does NOT apply default_repo_prefix, while
            # upload_*/download_* do. Pass the prefixed names explicitly so the
            # repos created here are the same ones the uploads below target,
            # e.g. "username/hf-wrapper-test-my-test-model".
            prefix = hf_wrapper.default_repo_prefix
            prefixed_model_repo = f"{prefix}-{test_model_repo_basename}"
            prefixed_dataset_repo = f"{prefix}-{test_dataset_repo_basename}"

            dummy_model_dir = Path("dummy_model_for_hf_upload")
            dummy_dataset_file = Path("dummy_dataset_for_hf_upload.jsonl")

            try:
                # --- Test Repository Creation ---
                logger.info("\n--- Testing Model Repository Creation ---")
                model_repo_id = hf_wrapper.create_repository(repo_name=prefixed_model_repo, repo_type="model", private=True)
                logger.info(f"Model repository created/ensured: {model_repo_id}")

                logger.info("\n--- Testing Dataset Repository Creation ---")
                dataset_repo_id = hf_wrapper.create_repository(repo_name=prefixed_dataset_repo, repo_type="dataset", private=True)
                logger.info(f"Dataset repository created/ensured: {dataset_repo_id}")

                # --- Test File/Folder Upload & Download ---
                dummy_model_dir.mkdir(exist_ok=True)
                with open(dummy_model_dir / "config.json", "w", encoding="utf-8") as f:
                    json.dump({"model_type": "dummy", "_comment": "Test model config"}, f, indent=2)
                with open(dummy_model_dir / "model.safetensors", "w", encoding="utf-8") as f:
                    f.write("This is a dummy safetensors file content.")

                # JSON Lines: one JSON object per line, separated by real newlines.
                with open(dummy_dataset_file, "w", encoding="utf-8") as f:
                    f.write(json.dumps({"text": "example line 1 for hf dataset"}) + "\n")
                    f.write(json.dumps({"text": "example line 2 for hf dataset"}) + "\n")

                logger.info(f"\n--- Testing Model Upload (folder to {test_model_repo_basename}) ---")
                # upload_model takes the base name; prefix and user/org are added internally
                hf_wrapper.upload_model(model_path=dummy_model_dir, repo_name=test_model_repo_basename, private=True)

                logger.info(f"\n--- Testing Dataset Upload (file to {test_dataset_repo_basename}) ---")
                hf_wrapper.upload_dataset(dataset_path=dummy_dataset_file, repo_name=test_dataset_repo_basename, private=True)

                # Download via the full IDs returned by create_repository; these now
                # match the repos the uploads above wrote to.
                downloaded_model_path_base = Path("downloaded_hf_models")
                downloaded_model_path_base.mkdir(exist_ok=True)

                logger.info(f"\n--- Testing Model Download (from {model_repo_id}) ---")
                hf_wrapper.download_model(repo_name=model_repo_id, local_dir=downloaded_model_path_base / test_model_repo_basename)
                logger.info(f"Model downloaded to: {downloaded_model_path_base / test_model_repo_basename}")

                downloaded_dataset_path_base = Path("downloaded_hf_datasets")
                downloaded_dataset_path_base.mkdir(exist_ok=True)

                logger.info(f"\n--- Testing Dataset Download (from {dataset_repo_id}) ---")
                hf_wrapper.download_dataset(repo_name=dataset_repo_id, local_dir=downloaded_dataset_path_base / test_dataset_repo_basename)
                logger.info(f"Dataset downloaded to: {downloaded_dataset_path_base / test_dataset_repo_basename}")

                logger.info("\nExample usage complete. Check your Hugging Face account for new repositories.")
                logger.info(f"Consider deleting test repositories: {model_repo_id}, {dataset_repo_id}")

            except Exception as e:
                logger.error(f"An error occurred during example usage: {e}", exc_info=True)
            finally:
                # Always remove the local dummy inputs, even if an upload failed.
                # Downloaded folders are kept on purpose for manual inspection.
                shutil.rmtree(dummy_model_dir, ignore_errors=True)
                dummy_dataset_file.unlink(missing_ok=True)
                logger.info("Local dummy files and folders cleaned up. Downloaded content remains for inspection.")