# Commit 8605269 by mohsin-devs: "Boot safely when HF token secret is missing"
"""Hugging Face Hub storage backend for DocVault."""
from __future__ import annotations
import os
from datetime import datetime, timezone
from typing import Any, Dict, List
from huggingface_hub import (
CommitOperationAdd,
CommitOperationCopy,
CommitOperationDelete,
HfApi,
hf_hub_url,
)
from server import config
from server.storage.interface import StorageInterface
from server.utils.logger import setup_logger
from server.utils.validators import PathValidator, sanitize_filename
# Module logger obtained from the project's setup_logger helper (its handler
# and level configuration live in server.utils.logger, not visible here).
logger = setup_logger(__name__)
class HuggingFaceStorageManager(StorageInterface):
    """Stores all files and folders in a Hugging Face Hub repository.

    Layout: every item belonging to a user lives under a top-level directory
    named after the validated ``user_id``. A folder exists either implicitly
    (it is a path prefix of some file) or explicitly through an empty
    ``config.FOLDER_MARKER`` file, which is how empty folders are persisted.
    Each mutating operation produces exactly one commit on the repository
    ``config.HF_REPO_ID`` of type ``config.HF_REPO_TYPE``.

    Methods that write (plus ``get_history``) call ``_ensure_token`` and raise
    ``RuntimeError`` when ``config.HF_TOKEN`` is unset; the remaining read
    methods run with whatever client ``__init__`` built.
    """

    def __init__(self) -> None:
        # Build an anonymous client when no token is configured so the app can
        # still start; the repo is only verified/created when a token exists.
        self.api = HfApi(token=config.HF_TOKEN) if config.HF_TOKEN else HfApi()
        # Becomes True once the target repository has been confirmed or created.
        self._repo_ready = False
        if config.HF_TOKEN:
            self._ensure_repo_exists()

    def _ensure_token(self) -> None:
        """Raise ``RuntimeError`` when no Hugging Face token is configured."""
        if not config.HF_TOKEN:
            raise RuntimeError(
                "HF_TOKEN is required for write operations. Set HF_TOKEN or HUGGING_FACE_HUB_TOKEN in the Space secrets."
            )

    def _ensure_repo_exists(self) -> None:
        """Make sure the configured repository exists, creating it (private) if not.

        Raises:
            RuntimeError: if no HF token is configured.
            Exception: whatever ``create_repo`` raises when creation fails.
        """
        self._ensure_token()
        try:
            self.api.repo_info(repo_id=config.HF_REPO_ID, repo_type=config.HF_REPO_TYPE)
        except Exception as exc:
            # Any lookup failure falls back to an idempotent create (exist_ok=True).
            logger.info(
                "Repository %s lookup failed (%s); attempting to create it.",
                config.HF_REPO_ID,
                exc,
            )
            self.api.create_repo(
                repo_id=config.HF_REPO_ID,
                repo_type=config.HF_REPO_TYPE,
                private=True,
                exist_ok=True,
            )
        self._repo_ready = True

    def _timestamp(self) -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    def _validate_relative_path(self, path: str, label: str = "path") -> str:
        """Normalize ``path`` and reject it if ``PathValidator`` deems it invalid.

        Raises:
            ValueError: mentioning ``label`` when the path is invalid.
        """
        normalized = PathValidator._normalize_relative_path(path)
        if not PathValidator.is_valid_path(normalized):
            raise ValueError(f"Invalid {label}: {path}")
        return normalized

    def _user_repo_path(self, user_id: str, path: str = "") -> str:
        """Map a user-relative ``path`` to its full path inside the repository."""
        normalized = self._validate_relative_path(path)
        base = self._validate_relative_path(user_id, "user_id")
        return "/".join(part for part in [base, normalized] if part)

    def _folder_marker_path(self, user_id: str, folder_path: str) -> str:
        """Return the repository path of the marker file for ``folder_path``."""
        repo_folder = self._user_repo_path(user_id, folder_path)
        return "/".join([repo_folder, config.FOLDER_MARKER]) if repo_folder else config.FOLDER_MARKER

    def _list_repo_files(self) -> List[str]:
        """Return every file path in the repository, or ``[]`` if listing fails.

        Failures stay non-fatal on purpose (for example the repository is not
        created at startup when no token is configured). They are logged,
        because callers treat an empty result as "nothing exists", which
        affects existence checks and upload name de-duplication.
        """
        try:
            return self.api.list_repo_files(
                repo_id=config.HF_REPO_ID,
                repo_type=config.HF_REPO_TYPE,
            )
        except Exception as exc:
            logger.warning(
                "Listing files in %s failed (%s); treating the repository as empty.",
                config.HF_REPO_ID,
                exc,
            )
            return []

    @staticmethod
    def _occupied_paths(repo_files: List[str]) -> set[str]:
        """Return every file path plus every implicit folder path in ``repo_files``.

        A repository path cannot be both a file and a folder, so a new item
        may only be placed at a path absent from this set.
        """
        occupied = set(repo_files)
        for item in repo_files:
            parents = item.split("/")[:-1]
            for index in range(1, len(parents) + 1):
                occupied.add("/".join(parents[:index]))
        return occupied

    def _file_exists(self, repo_path: str, repo_files: List[str] | None = None) -> bool:
        """Return True if ``repo_path`` is a file in the (given or fetched) listing."""
        repo_files = repo_files if repo_files is not None else self._list_repo_files()
        return repo_path in repo_files

    def _folder_exists(self, repo_folder_path: str, repo_files: List[str] | None = None) -> bool:
        """Return True if the folder has a marker or any file beneath it."""
        repo_files = repo_files if repo_files is not None else self._list_repo_files()
        marker = f"{repo_folder_path}/{config.FOLDER_MARKER}"
        prefix = f"{repo_folder_path}/"
        return marker in repo_files or any(item.startswith(prefix) for item in repo_files)

    def _split_parent(self, path: str) -> tuple[str, str]:
        """Split a validated relative path into ``(parent, basename)``.

        ``parent`` is ``""`` for top-level items. Only empty and ``"."``
        segments are dropped; dots that belong to a folder name (``.config``,
        ``v1.``) are preserved.
        """
        normalized = self._validate_relative_path(path)
        parent = os.path.dirname(normalized).replace("\\", "/")
        parent = "/".join(seg for seg in parent.split("/") if seg not in ("", "."))
        return parent, os.path.basename(normalized)

    def get_item_type(self, user_id: str, path: str) -> str | None:
        """Return ``"file"``, ``"folder"`` or ``None`` for a user-relative path."""
        normalized = self._validate_relative_path(path)
        repo_path = self._user_repo_path(user_id, normalized)
        repo_files = self._list_repo_files()
        if repo_path in repo_files:
            return "file"
        if any(item.startswith(f"{repo_path}/") for item in repo_files):
            return "folder"
        return None

    def _next_available_file_path(
        self,
        user_id: str,
        folder_path: str,
        filename: str,
        repo_files: List[str] | None = None,
    ) -> tuple[str, str]:
        """Pick a non-colliding name for ``filename`` inside ``folder_path``.

        Tries ``filename`` first, then ``<stem>_1<ext>``, ``<stem>_2<ext>``, ...
        A candidate is rejected when it matches an existing file or an
        existing (implicit) folder.

        Args:
            repo_files: optional pre-fetched repository listing; fetched when None.

        Returns:
            ``(relative_path, repo_path)`` for the chosen name.

        Raises:
            ValueError: if the folder path or sanitized filename is invalid.
        """
        folder_path = self._validate_relative_path(folder_path)
        filename = sanitize_filename(filename)
        if not PathValidator.is_valid_filename(filename):
            raise ValueError("Invalid filename.")
        if repo_files is None:
            repo_files = self._list_repo_files()
        occupied = self._occupied_paths(repo_files)
        name, ext = os.path.splitext(filename)
        counter = 0
        while True:
            candidate_name = filename if counter == 0 else f"{name}_{counter}{ext}"
            relative_path = "/".join(part for part in [folder_path, candidate_name] if part)
            candidate_repo_path = self._user_repo_path(user_id, relative_path)
            if candidate_repo_path not in occupied:
                return relative_path, candidate_repo_path
            counter += 1

    def create_folder(self, user_id: str, folder_path: str) -> Dict[str, Any]:
        """Create an empty folder by committing its marker file.

        Returns a failure dict (``INVALID_FOLDER``, ``FOLDER_EXISTS`` or
        ``CONFLICT`` when a file already occupies the path), otherwise a
        success dict with the standardized folder.

        Raises:
            RuntimeError: if no HF token is configured.
            ValueError: for an invalid ``folder_path``.
        """
        self._ensure_token()
        folder_path = self._validate_relative_path(folder_path, "folder_path")
        if not folder_path:
            return {"success": False, "error": "folder_path is required", "code": "INVALID_FOLDER"}
        repo_folder_path = self._user_repo_path(user_id, folder_path)
        repo_files = self._list_repo_files()
        if self._folder_exists(repo_folder_path, repo_files):
            return {"success": False, "error": "Folder already exists", "code": "FOLDER_EXISTS"}
        if repo_folder_path in repo_files:
            # A file already lives at this path; a folder cannot share it.
            return {"success": False, "error": "An item with this name already exists", "code": "CONFLICT"}
        marker_path = self._folder_marker_path(user_id, folder_path)
        self.api.create_commit(
            repo_id=config.HF_REPO_ID,
            repo_type=config.HF_REPO_TYPE,
            operations=[CommitOperationAdd(path_in_repo=marker_path, path_or_fileobj=b"")],
            commit_message=f"Create folder {repo_folder_path}",
        )
        folder_name = folder_path.split("/")[-1]
        return {
            "success": True,
            "message": f"Folder created: {folder_path}",
            "folder": self.standardize_folder(
                name=folder_name,
                path=folder_path,
                created_at=self._timestamp(),
                storage_type="hf",
            ),
        }

    def upload_file(
        self, user_id: str, folder_path: str, filename: str, file_obj: Any
    ) -> Dict[str, Any]:
        """Upload content into ``folder_path``, renaming it on collision.

        ``file_obj`` may be ``bytes``/``bytearray`` or any object whose
        ``read()`` returns them. If ``folder_path`` has no marker yet, the
        marker is added in the same commit.

        Raises:
            RuntimeError: if no HF token is configured.
            ValueError: for an invalid folder path or filename.
            TypeError: if the content is not bytes or bytearray.
        """
        self._ensure_token()
        folder_path = self._validate_relative_path(folder_path, "folder_path")
        # One listing serves both the name de-duplication and the marker check.
        repo_files = self._list_repo_files()
        relative_path, repo_path = self._next_available_file_path(
            user_id, folder_path, filename, repo_files
        )
        final_filename = relative_path.split("/")[-1]
        file_data = file_obj.read() if hasattr(file_obj, "read") else file_obj
        if not isinstance(file_data, (bytes, bytearray)):
            raise TypeError("Uploaded content must be bytes.")
        # CommitOperationAdd accepts bytes (or paths / binary file objects),
        # not bytearray, so normalize here.
        file_data = bytes(file_data)
        operations = []
        if folder_path:
            marker_path = self._folder_marker_path(user_id, folder_path)
            if marker_path not in repo_files:
                operations.append(
                    CommitOperationAdd(path_in_repo=marker_path, path_or_fileobj=b"")
                )
        operations.append(CommitOperationAdd(path_in_repo=repo_path, path_or_fileobj=file_data))
        self.api.create_commit(
            repo_id=config.HF_REPO_ID,
            repo_type=config.HF_REPO_TYPE,
            operations=operations,
            commit_message=f"Upload {repo_path}",
        )
        return {
            "success": True,
            "message": f"Uploaded file: {final_filename}",
            "file": self.standardize_file(
                name=final_filename,
                path=relative_path,
                size=len(file_data),
                created_at=self._timestamp(),
                storage_type="hf",
            ),
        }

    def delete_file(self, user_id: str, file_path: str) -> Dict[str, Any]:
        """Delete a single file; returns ``FILE_NOT_FOUND`` if it is absent.

        Raises:
            RuntimeError: if no HF token is configured.
            ValueError: for an invalid ``file_path``.
        """
        self._ensure_token()
        file_path = self._validate_relative_path(file_path, "file_path")
        repo_path = self._user_repo_path(user_id, file_path)
        if not self._file_exists(repo_path):
            return {"success": False, "error": "File not found", "code": "FILE_NOT_FOUND"}
        self.api.delete_file(
            path_in_repo=repo_path,
            repo_id=config.HF_REPO_ID,
            repo_type=config.HF_REPO_TYPE,
            commit_message=f"Delete file {repo_path}",
        )
        return {"success": True, "message": f"Deleted file: {file_path}"}

    def rename_file(self, user_id: str, file_path: str, new_name: str) -> Dict[str, Any]:
        """Rename a file within its parent folder (copy + delete in one commit).

        Returns ``INVALID_NAME``, ``NOT_FOUND`` or ``CONFLICT`` (target already
        used by a file or a folder) failure dicts, otherwise the renamed item.

        Raises:
            RuntimeError: if no HF token is configured.
            ValueError: for an invalid ``file_path``.
        """
        self._ensure_token()
        file_path = self._validate_relative_path(file_path, "file_path")
        new_name = sanitize_filename(new_name)
        if not PathValidator.is_valid_filename(new_name):
            return {"success": False, "error": "Invalid characters in name", "code": "INVALID_NAME"}
        parent, _ = self._split_parent(file_path)
        new_relative_path = "/".join(part for part in [parent, new_name] if part)
        source_repo_path = self._user_repo_path(user_id, file_path)
        target_repo_path = self._user_repo_path(user_id, new_relative_path)
        repo_files = self._list_repo_files()
        if source_repo_path not in repo_files:
            return {"success": False, "error": "File not found", "code": "NOT_FOUND"}
        if target_repo_path in self._occupied_paths(repo_files):
            return {"success": False, "error": "An item with this name already exists", "code": "CONFLICT"}
        self.api.create_commit(
            repo_id=config.HF_REPO_ID,
            repo_type=config.HF_REPO_TYPE,
            operations=[
                CommitOperationCopy(
                    src_path_in_repo=source_repo_path, path_in_repo=target_repo_path
                ),
                CommitOperationDelete(path_in_repo=source_repo_path),
            ],
            commit_message=f"Rename file {source_repo_path} -> {target_repo_path}",
        )
        return {
            "success": True,
            "message": f"Renamed to {new_name}",
            "item": {"name": new_name, "type": "file", "path": new_relative_path},
        }

    def delete_folder(self, user_id: str, folder_path: str) -> Dict[str, Any]:
        """Delete every file (including markers) under ``folder_path`` in one commit.

        An empty ``folder_path`` is rejected with ``INVALID_FOLDER``, because
        it would otherwise match the user's entire namespace.

        Raises:
            RuntimeError: if no HF token is configured.
            ValueError: for an invalid ``folder_path``.
        """
        self._ensure_token()
        folder_path = self._validate_relative_path(folder_path, "folder_path")
        if not folder_path:
            return {"success": False, "error": "folder_path is required", "code": "INVALID_FOLDER"}
        repo_folder_path = self._user_repo_path(user_id, folder_path)
        repo_files = self._list_repo_files()
        prefix = f"{repo_folder_path}/"
        matches = [item for item in repo_files if item.startswith(prefix)]
        if not matches:
            return {
                "success": False,
                "error": "Folder not found",
                "code": "FOLDER_NOT_FOUND",
            }
        self.api.create_commit(
            repo_id=config.HF_REPO_ID,
            repo_type=config.HF_REPO_TYPE,
            operations=[CommitOperationDelete(path_in_repo=item) for item in matches],
            commit_message=f"Delete folder {repo_folder_path}",
        )
        return {"success": True, "message": f"Deleted folder: {folder_path}"}

    def rename_folder(
        self, user_id: str, folder_path: str, new_name: str
    ) -> Dict[str, Any]:
        """Rename a folder by copying every contained file to the new prefix.

        All copies and deletes go into a single commit. Returns
        ``INVALID_FOLDER`` (empty path), ``INVALID_NAME``, ``NOT_FOUND`` or
        ``CONFLICT`` (target already used by a file or folder) failure dicts.

        Raises:
            RuntimeError: if no HF token is configured.
            ValueError: for an invalid ``folder_path``.
        """
        self._ensure_token()
        folder_path = self._validate_relative_path(folder_path, "folder_path")
        if not folder_path:
            # Renaming the user's root would move their whole namespace.
            return {"success": False, "error": "folder_path is required", "code": "INVALID_FOLDER"}
        new_name = sanitize_filename(new_name)
        if not PathValidator.is_valid_filename(new_name):
            return {"success": False, "error": "Invalid characters in name", "code": "INVALID_NAME"}
        parent, _ = self._split_parent(folder_path)
        new_relative_path = "/".join(part for part in [parent, new_name] if part)
        source_repo_path = self._user_repo_path(user_id, folder_path)
        target_repo_path = self._user_repo_path(user_id, new_relative_path)
        repo_files = self._list_repo_files()
        source_prefix = f"{source_repo_path}/"
        target_prefix = f"{target_repo_path}/"
        matches = [item for item in repo_files if item.startswith(source_prefix)]
        if not matches:
            return {"success": False, "error": "Folder not found", "code": "NOT_FOUND"}
        if target_repo_path in self._occupied_paths(repo_files):
            return {"success": False, "error": "An item with this name already exists", "code": "CONFLICT"}
        operations = []
        for item in matches:
            suffix = item[len(source_prefix):]
            operations.append(
                CommitOperationCopy(
                    src_path_in_repo=item, path_in_repo=f"{target_prefix}{suffix}"
                )
            )
            operations.append(CommitOperationDelete(path_in_repo=item))
        self.api.create_commit(
            repo_id=config.HF_REPO_ID,
            repo_type=config.HF_REPO_TYPE,
            operations=operations,
            commit_message=f"Rename folder {source_repo_path} -> {target_repo_path}",
        )
        return {
            "success": True,
            "message": f"Renamed to {new_name}",
            "item": {"name": new_name, "type": "folder", "path": new_relative_path},
        }

    def download(self, user_id: str, file_path: str) -> Dict[str, Any]:
        """Return the Hub URL of a file; the content itself is not fetched.

        NOTE(review): the repository is created with ``private=True``, so this
        URL presumably needs an authorization header to fetch. Confirm that
        callers handle this.

        Raises:
            FileNotFoundError: if the file is not in the repository listing.
            ValueError: for an invalid ``file_path``.
        """
        file_path = self._validate_relative_path(file_path, "file_path")
        repo_path = self._user_repo_path(user_id, file_path)
        if not self._file_exists(repo_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        return {
            "url": hf_hub_url(
                repo_id=config.HF_REPO_ID,
                filename=repo_path,
                repo_type=config.HF_REPO_TYPE,
            ),
            "path": file_path,
        }

    def list(self, user_id: str, prefix: str = "") -> Dict[str, List[Dict[str, Any]]]:
        """List the direct child folders and files of ``prefix`` for a user.

        Folder markers are hidden. The listing carries no per-file metadata,
        so ``size`` is always 0 and ``created_at`` is the time of the listing.
        Results are sorted case-insensitively by name.
        """
        prefix = self._validate_relative_path(prefix, "folder_path")
        repo_prefix = self._user_repo_path(user_id, prefix)
        search_prefix = f"{repo_prefix}/"
        repo_files = self._list_repo_files()
        now = self._timestamp()
        folders_map: Dict[str, Dict[str, Any]] = {}
        files: List[Dict[str, Any]] = []
        for repo_item in repo_files:
            if not repo_item.startswith(search_prefix):
                continue
            relative = repo_item[len(search_prefix):]
            if not relative:
                continue
            parts = relative.split("/")
            if len(parts) > 1:
                # Anything deeper than one level contributes its first segment as a folder.
                folder_name = parts[0]
                folder_path = "/".join(part for part in [prefix, folder_name] if part)
                folders_map.setdefault(
                    folder_name,
                    self.standardize_folder(
                        name=folder_name,
                        path=folder_path,
                        created_at=now,
                        storage_type="hf",
                    ),
                )
                continue
            if parts[0] == config.FOLDER_MARKER:
                continue
            file_path = "/".join(part for part in [prefix, parts[0]] if part)
            files.append(
                self.standardize_file(
                    name=parts[0],
                    path=file_path,
                    size=0,
                    created_at=now,
                    storage_type="hf",
                )
            )
        folders = sorted(folders_map.values(), key=lambda item: item["name"].lower())
        files.sort(key=lambda item: item["name"].lower())
        return {
            "success": True,
            "folders": folders,
            "files": files,
            "total_folders": len(folders),
            "total_files": len(files),
        }

    def exists(self, user_id: str, path: str) -> bool:
        """Return True if ``path`` is a file or a non-empty folder; False on any error."""
        try:
            normalized = self._validate_relative_path(path)
            repo_path = self._user_repo_path(user_id, normalized)
            repo_files = self._list_repo_files()
            return repo_path in repo_files or any(
                item.startswith(f"{repo_path}/") for item in repo_files
            )
        except Exception:
            return False

    def get_stats(self, user_id: str) -> Dict[str, Any]:
        """Count a user's files and folders.

        Every ancestor directory of a file or marker counts as a folder, and
        markers themselves are not counted as files. Sizes are not tracked
        (``total_size`` is always 0).
        """
        repo_prefix = f"{self._user_repo_path(user_id)}/"
        file_count = 0
        folder_names: set[str] = set()
        for repo_item in self._list_repo_files():
            if not repo_item.startswith(repo_prefix):
                continue
            relative = repo_item[len(repo_prefix):]
            *parents, leaf = relative.split("/")
            # Register all ancestors so nested marker-only folders count their parents too.
            for index in range(1, len(parents) + 1):
                folder_names.add("/".join(parents[:index]))
            if leaf != config.FOLDER_MARKER:
                file_count += 1
        return {
            "success": True,
            "total_size": 0,
            "total_size_formatted": "0 B",
            "total_files": file_count,
            "total_folders": len(folder_names),
            "storage_type": "hf",
            "repo_id": config.HF_REPO_ID,
        }

    def get_history(self, user_id: str, path: str) -> List[Dict[str, Any]]:
        """Return the repository commits whose title mentions the item's repo path.

        Matching is a plain substring test on the commit title (or message),
        so commits touching longer paths that contain this one also match.

        Raises:
            RuntimeError: if no HF token is configured.
        """
        self._ensure_token()
        repo_path = self._user_repo_path(user_id, path)
        commits = self.api.list_repo_commits(
            repo_id=config.HF_REPO_ID,
            repo_type=config.HF_REPO_TYPE,
        )
        history = []
        for commit in commits:
            title = getattr(commit, "title", "") or getattr(commit, "message", "")
            if repo_path not in title:
                continue
            history.append(
                {
                    "id": commit.commit_id,
                    "message": title,
                    "timestamp": commit.created_at.isoformat(),
                    "author": commit.authors[0] if getattr(commit, "authors", None) else "unknown",
                }
            )
        return history

    def restore(
        self, user_id: str, path: str, revision: str, as_copy: bool = False
    ) -> Dict[str, Any]:
        """Copy ``path`` as it was at ``revision`` back into the current tree.

        With ``as_copy`` the file is written next to the original as
        ``<stem>_<UTC YYYYmmddHHMMSS><ext>``; otherwise the original path is
        overwritten.

        Raises:
            RuntimeError: if no HF token is configured.
            ValueError: for an invalid ``path``.
        """
        self._ensure_token()
        path = self._validate_relative_path(path)
        source_repo_path = self._user_repo_path(user_id, path)
        destination_repo_path = source_repo_path
        destination_relative_path = path
        if as_copy:
            parent, filename = self._split_parent(path)
            stem, ext = os.path.splitext(filename)
            timestamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
            copy_name = f"{stem}_{timestamp}{ext}"
            destination_relative_path = "/".join(
                part for part in [parent, copy_name] if part
            )
            destination_repo_path = self._user_repo_path(user_id, destination_relative_path)
        self.api.create_commit(
            repo_id=config.HF_REPO_ID,
            repo_type=config.HF_REPO_TYPE,
            operations=[
                CommitOperationCopy(
                    src_path_in_repo=source_repo_path,
                    path_in_repo=destination_repo_path,
                    src_revision=revision,
                )
            ],
            commit_message=f"Restore {source_repo_path} from {revision}",
        )
        return {
            "success": True,
            "message": f"Restored {path}",
            "item": {
                "name": destination_relative_path.split("/")[-1],
                "path": destination_relative_path,
            },
        }