import json  # Used by the example usage below.
import logging
import os
import shutil
from pathlib import Path

from huggingface_hub import HfApi, HfFolder, Repository, create_repo, get_full_repo_name
from huggingface_hub.utils import HfHubHTTPError

# NOTE(review): `Repository` and `get_full_repo_name` are imported but not used in this
# module (and `Repository` is deprecated upstream); consider dropping them.

logger = logging.getLogger(__name__)


class HuggingFaceWrapper:
    """
    A wrapper for interacting with the Hugging Face Hub API.

    Handles authentication, model and dataset uploads/downloads, and repository creation.
    Bare repository names passed to the upload/download helpers are expanded to
    "<namespace>/<default_repo_prefix>-<name>"; names that already contain a "/" are
    treated as full repo IDs and used verbatim.
    """

    def __init__(self, token: str | None = None, default_repo_prefix: str = "museum-sexoskop"):
        """
        Initializes the HuggingFaceWrapper.

        Args:
            token: Your Hugging Face API token. If None (or empty), the token saved locally
                via `huggingface-cli login` is looked up with `HfFolder.get_token()`.
            default_repo_prefix: Prefix prepended (as "<prefix>-") to bare repository names.
                An empty string disables prefixing.
        """
        if token:
            self.token = token
            # The token is bound to this client and passed explicitly to every Hub call.
            # It is deliberately NOT persisted with HfFolder.save_token(), which writes it to
            # the user's on-disk credential store and would overwrite an existing login.
            logger.info("Using the Hugging Face token provided to HuggingFaceWrapper.")
        else:
            self.token = HfFolder.get_token()
            if not self.token:
                logger.warning("No Hugging Face token provided or found locally. "
                               "Please login using `huggingface-cli login` or provide a token.")
            else:
                logger.info("Using locally saved Hugging Face token.")
        self.api = HfApi(token=self.token)
        self.default_repo_prefix = default_repo_prefix
        # Username of the authenticated account; resolved lazily by _current_user().
        self._username: str | None = None

    # ------------------------------------------------------------------ helpers

    def _current_user(self) -> str | None:
        """Return the authenticated account's username, calling `whoami` until one is found."""
        # getattr() keeps this working on instances created without __init__.
        if getattr(self, "_username", None) is None:
            self._username = self.api.whoami(token=self.token).get("name")
        return self._username

    def _apply_prefix(self, repo_name: str) -> str:
        """Prepend `default_repo_prefix` to a bare repo name unless it already starts with it."""
        if self.default_repo_prefix and not repo_name.startswith(self.default_repo_prefix):
            return f"{self.default_repo_prefix}-{repo_name}"
        return repo_name

    @staticmethod
    def _repo_url(repo_id: str, repo_type: str | None) -> str:
        """Return the web URL of a repo; dataset and space pages live under a type-specific path."""
        if repo_type in ("dataset", "space"):
            return f"https://huggingface.co/{repo_type}s/{repo_id}"
        return f"https://huggingface.co/{repo_id}"

    def _get_full_repo_id(self, repo_name: str, repo_type: str | None = None) -> str:
        """
        Helper to construct a full repo ID, ensuring it includes the username/org.

        A name containing "/" is assumed to be a deliberate full ID and returned unchanged.
        Otherwise the default prefix is applied and the current username is prepended.
        `repo_type` is accepted for API symmetry but not used.

        Raises:
            ValueError: If no token is available or the username cannot be determined.
        """
        if "/" in repo_name:
            return repo_name

        user_or_org = self._current_user() if self.token else None
        if not user_or_org:
            raise ValueError("Could not determine Hugging Face username/org. Ensure you are logged in or token is valid.")

        return f"{user_or_org}/{self._apply_prefix(repo_name)}"

    def _resolve_repo_id_for_download(self, repo_name: str, organization: str | None) -> str:
        """
        Turn a user-supplied name into a full repo ID for downloading.

        Full IDs ("user/repo", "org/repo") are used verbatim. Bare names get the default
        prefix, and then the organization (if given) or the current username as namespace.

        Raises:
            ConnectionError: If a username is needed but cannot be determined.
        """
        if "/" in repo_name:
            return repo_name
        effective_repo_name = self._apply_prefix(repo_name)
        if organization:
            return f"{organization}/{effective_repo_name}"
        user = self._current_user()
        if not user:
            raise ConnectionError("Could not determine Hugging Face username for downloading.")
        return f"{user}/{effective_repo_name}"

    def _upload_to_repo(self, local_path: str | Path, repo_name: str, repo_type: str, private: bool,
                        commit_message: str, organization: str | None) -> str:
        """Ensure the target repo exists, upload `local_path` into it, and return the repo URL."""
        # Validate the local input before touching the Hub, so a wrong path does not leave
        # an empty remote repository behind.
        local_path_obj = Path(local_path)
        if not (local_path_obj.is_file() or local_path_obj.is_dir()):
            raise FileNotFoundError(f"Local path '{local_path}' not found or is not a file/directory.")

        # Full "namespace/name" IDs are used verbatim (matching the download methods);
        # only bare names receive the default prefix.
        effective_repo_name = repo_name if "/" in repo_name else self._apply_prefix(repo_name)
        target_repo_id = self.create_repository(repo_name=effective_repo_name, repo_type=repo_type,
                                                private=private, organization=organization)

        logger.info(f"Uploading {repo_type} from '{local_path}' to '{target_repo_id}'...")
        self.upload_file_or_folder(local_path=local_path, repo_id=target_repo_id,
                                   repo_type=repo_type, commit_message=commit_message)

        repo_url = self._repo_url(target_repo_id, repo_type)
        logger.info(f"{repo_type.capitalize()} uploaded to {repo_url}")
        return repo_url

    def _download_snapshot(self, repo_name: str, local_dir: str | Path, revision: str | None,
                           organization: str | None, repo_type: str) -> str:
        """Resolve the repo ID and download a full snapshot of it into `local_dir`."""
        target_repo_id = self._resolve_repo_id_for_download(repo_name, organization)

        logger.info(f"Downloading {repo_type} '{target_repo_id}' to '{local_dir}'...")
        downloaded_path = self.api.snapshot_download(
            repo_id=target_repo_id,
            # Must be explicit: snapshot_download defaults to "model" and does not infer the type.
            repo_type=repo_type,
            local_dir=str(local_dir),
            token=self.token,
            revision=revision,
        )
        logger.info(f"{repo_type.capitalize()} '{target_repo_id}' downloaded to '{downloaded_path}'.")
        return downloaded_path

    # --------------------------------------------------------------- public API

    def create_repository(self, repo_name: str, repo_type: str | None = None, private: bool = True,
                          organization: str | None = None) -> str:
        """
        Creates a new repository on the Hugging Face Hub (no-op if it already exists).

        Note: this method does NOT apply `default_repo_prefix`; pass the exact name wanted.

        Args:
            repo_name: The name of the repository. Can be 'my-repo' or 'username/my-repo'.
            repo_type: Type of the repository ('model', 'dataset', 'space').
            private: Whether the repository should be private.
            organization: Optional organization name to create the repo under.
                If provided, repo_name must be the base name for the repo within that org.

        Returns:
            The full repository ID (e.g., "username/repo_name" or "orgname/repo_name").

        Raises:
            ValueError: If `organization` is given together with a full "org/repo" name.
            ConnectionError: If the current username cannot be determined.
            HfHubHTTPError: If the Hub rejects the creation request.
        """
        if organization:
            if "/" in repo_name:
                raise ValueError("When organization is specified, repo_name should be a base name, not 'org/repo'.")
            full_repo_id = f"{organization}/{repo_name}"
        elif "/" in repo_name:
            # User provided a full name like "username/repo_name".
            full_repo_id = repo_name
        else:
            # User provided a base name: prepend the current user.
            user = self._current_user()
            if not user:
                raise ConnectionError("Could not determine Hugging Face username. Ensure token is valid and you are logged in.")
            full_repo_id = f"{user}/{repo_name}"

        try:
            url = create_repo(repo_id=full_repo_id, token=self.token, private=private,
                              repo_type=repo_type, exist_ok=True)
        except HfHubHTTPError as e:
            logger.error(f"Error creating repository '{full_repo_id}': {e}")
            raise
        logger.info(f"Repository '{full_repo_id}' ensured to exist. URL: {url}")
        return full_repo_id

    def upload_file_or_folder(self, local_path: str | Path, repo_id: str, path_in_repo: str | None = None,
                              repo_type: str | None = None, commit_message: str = "Upload content"):
        """
        Upload a single file or the contents of a folder to an existing repository.

        Args:
            local_path: Local file or directory.
            repo_id: Full repository ID ("namespace/name").
            path_in_repo: Destination inside the repo. For a file, defaults to the file's
                name at the repo root. For a folder, defaults to the repo root (the folder's
                *contents* are uploaded, not the folder itself); "." also means root.
            repo_type: 'model', 'dataset' or 'space' (None lets the Hub default to model).
            commit_message: Commit message for the upload.

        Raises:
            FileNotFoundError: If `local_path` is neither a file nor a directory.
        """
        local_path_obj = Path(local_path)

        if local_path_obj.is_file():
            target_path = path_in_repo or local_path_obj.name
            self.api.upload_file(
                path_or_fileobj=str(local_path_obj),
                path_in_repo=target_path,
                repo_id=repo_id,
                repo_type=repo_type,
                token=self.token,
                commit_message=commit_message,
            )
            logger.info(f"File '{local_path_obj}' uploaded to '{repo_id}/{target_path}'.")
        elif local_path_obj.is_dir():
            # None is upload_folder's documented value for "repository root".
            folder_target = path_in_repo if path_in_repo and path_in_repo != "." else None
            self.api.upload_folder(
                folder_path=str(local_path_obj),
                path_in_repo=folder_target,
                repo_id=repo_id,
                repo_type=repo_type,
                token=self.token,
                commit_message=commit_message,
                ignore_patterns=["*.git*", ".gitattributes"],
            )
            destination_suffix = f"/{folder_target}" if folder_target else ""
            logger.info(f"Folder '{local_path_obj}' contents uploaded to '{repo_id}{destination_suffix}'.")
        else:
            raise FileNotFoundError(f"Local path '{local_path}' not found or is not a file/directory.")

    def upload_model(self, model_path: str | Path, repo_name: str, private: bool = True,
                     commit_message: str = "Upload model", organization: str | None = None) -> str:
        """
        Uploads a model to the Hugging Face Hub.

        Args:
            model_path: Path to the local model directory or file.
            repo_name: Base name of the repository (e.g., "my-lora-model"); the default prefix
                and username/org are added. A full "namespace/name" ID is used verbatim.
            private: Whether the repository should be private.
            commit_message: Commit message for the upload.
            organization: Optional organization to host this model. If None, uses the logged-in user.

        Returns:
            The URL of the uploaded model repository.

        Raises:
            FileNotFoundError: If `model_path` does not exist (checked before the repo is created).
        """
        return self._upload_to_repo(model_path, repo_name, "model", private, commit_message, organization)

    def download_model(self, repo_name: str, local_dir: str | Path, revision: str | None = None,
                       organization: str | None = None) -> str:
        """
        Downloads a model from the Hugging Face Hub.

        Args:
            repo_name: Base name (e.g., "my-lora-model") or full ID (e.g., "username/my-lora-model").
                Base names get the default prefix plus the organization or current username.
            local_dir: Local directory to save the model.
            revision: Optional model revision (branch, tag, commit hash).
            organization: Optional organization if repo_name is a base name under an org.

        Returns:
            Path to the downloaded model.
        """
        return self._download_snapshot(repo_name, local_dir, revision, organization, "model")

    def upload_dataset(self, dataset_path: str | Path, repo_name: str, private: bool = True,
                       commit_message: str = "Upload dataset", organization: str | None = None) -> str:
        """
        Uploads a dataset to the Hugging Face Hub (same naming rules as `upload_model`).

        Returns:
            The URL of the dataset repository ("https://huggingface.co/datasets/<id>").
        """
        return self._upload_to_repo(dataset_path, repo_name, "dataset", private, commit_message, organization)

    def download_dataset(self, repo_name: str, local_dir: str | Path, revision: str | None = None,
                         organization: str | None = None) -> str:
        """
        Downloads a dataset from the Hugging Face Hub (same naming rules as `download_model`).

        Returns:
            Path to the downloaded dataset.
        """
        return self._download_snapshot(repo_name, local_dir, revision, organization, "dataset")

    def initiate_training(self, model_repo_id: str, dataset_repo_id: str, training_params: dict):
        """Placeholder: only logs what would be trained; no training is started."""
        logger.warning("initiate_training is a placeholder and not fully implemented.")
        logger.info(f"Would attempt to train model {model_repo_id} with dataset {dataset_repo_id} using params: {training_params}")


def _run_example() -> None:
    """Run an end-to-end smoke test against the live Hub. Requires the HF_TOKEN env var."""
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        logger.warning("HF_TOKEN environment variable not set. Please set it or log in via huggingface-cli.")
        logger.warning("Skipping example usage.")
        return

    # Use a different prefix for examples to avoid conflict with actual app prefix.
    hf_wrapper = HuggingFaceWrapper(token=hf_token, default_repo_prefix="hf-wrapper-test")

    # Determine the current Hugging Face username; nothing below works without it.
    try:
        current_hf_user = hf_wrapper.api.whoami(token=hf_wrapper.token).get("name")
        if not current_hf_user:
            raise ValueError("Could not retrieve HuggingFace username.")
    except Exception as e:
        logger.error(f"Failed to get HuggingFace username for tests: {e}. Skipping examples.")
        return

    test_model_repo_basename = "my-test-model"
    test_dataset_repo_basename = "my-test-dataset"
    dummy_model_dir = Path("dummy_model_for_hf_upload")
    dummy_dataset_file = Path("dummy_dataset_for_hf_upload.jsonl")

    try:
        # --- Test Repository Creation ---
        # create_repository() does not apply default_repo_prefix, so resolve the prefixed full
        # ID ("<user>/hf-wrapper-test-my-test-model") first. That is the same repo that
        # upload_model()/upload_dataset() target when given the bare names below.
        logger.info("\n--- Testing Model Repository Creation ---")
        model_repo_id = hf_wrapper.create_repository(
            repo_name=hf_wrapper._get_full_repo_id(test_model_repo_basename), repo_type="model", private=True)
        logger.info(f"Model repository created/ensured: {model_repo_id}")

        logger.info("\n--- Testing Dataset Repository Creation ---")
        dataset_repo_id = hf_wrapper.create_repository(
            repo_name=hf_wrapper._get_full_repo_id(test_dataset_repo_basename), repo_type="dataset", private=True)
        logger.info(f"Dataset repository created/ensured: {dataset_repo_id}")

        # --- Test File/Folder Upload & Download ---
        dummy_model_dir.mkdir(exist_ok=True)
        (dummy_model_dir / "config.json").write_text(
            json.dumps({"model_type": "dummy", "_comment": "Test model config"}, indent=2), encoding="utf-8")
        (dummy_model_dir / "model.safetensors").write_text(
            "This is a dummy safetensors file content.", encoding="utf-8")
        # JSON Lines: one JSON object per line, terminated by a real newline character.
        dataset_lines = ("example line 1 for hf dataset", "example line 2 for hf dataset")
        dummy_dataset_file.write_text(
            "".join(json.dumps({"text": text}) + "\n" for text in dataset_lines), encoding="utf-8")

        logger.info(f"\n--- Testing Model Upload (folder to {test_model_repo_basename}) ---")
        # upload_model uses the base name; prefixing and user/org are handled internally.
        hf_wrapper.upload_model(model_path=dummy_model_dir, repo_name=test_model_repo_basename, private=True)

        logger.info(f"\n--- Testing Dataset Upload (file to {test_dataset_repo_basename}) ---")
        hf_wrapper.upload_dataset(dataset_path=dummy_dataset_file, repo_name=test_dataset_repo_basename, private=True)

        downloaded_model_path_base = Path("downloaded_hf_models")
        downloaded_model_path_base.mkdir(exist_ok=True)
        logger.info(f"\n--- Testing Model Download (from {model_repo_id}) ---")
        hf_wrapper.download_model(repo_name=model_repo_id,
                                  local_dir=downloaded_model_path_base / test_model_repo_basename)
        logger.info(f"Model downloaded to: {downloaded_model_path_base / test_model_repo_basename}")

        downloaded_dataset_path_base = Path("downloaded_hf_datasets")
        downloaded_dataset_path_base.mkdir(exist_ok=True)
        logger.info(f"\n--- Testing Dataset Download (from {dataset_repo_id}) ---")
        hf_wrapper.download_dataset(repo_name=dataset_repo_id,
                                    local_dir=downloaded_dataset_path_base / test_dataset_repo_basename)
        logger.info(f"Dataset downloaded to: {downloaded_dataset_path_base / test_dataset_repo_basename}")

        logger.info("\nExample usage complete. Check your Hugging Face account for new repositories.")
        logger.info(f"Consider deleting test repositories: {model_repo_id}, {dataset_repo_id}")
    except Exception as e:
        logger.error(f"An error occurred during example usage: {e}", exc_info=True)
    finally:
        # Always remove the local dummy inputs, even if an upload/download step failed.
        # Downloaded folders are intentionally kept for manual inspection.
        shutil.rmtree(dummy_model_dir, ignore_errors=True)
        dummy_dataset_file.unlink(missing_ok=True)
        logger.info("Local dummy files and folders cleaned up. Downloaded content remains for inspection.")


# Example Usage
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    _run_example()