Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import logging | |
| from typing import Any, Dict, List, Optional | |
| from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationDelete, create_commit, hf_hub_url | |
| logger = logging.getLogger(__name__) | |
| GALLERY_FILE_PATH = "gallery/gallery.json" | |
def build_dataset_resolve_url(repo_id: str, path_in_repo: str, revision: str = "main") -> str:
    """
    Return the Hub URL for a file stored in a Hugging Face dataset repo.

    Thin wrapper around ``hf_hub_url`` that pins ``repo_type`` to ``"dataset"``.

    Args:
        repo_id: Dataset repository id.
        path_in_repo: Path of the file inside the repository.
        revision: Branch, tag or commit to resolve against (defaults to "main").

    Returns:
        The URL string produced by ``hf_hub_url``.
    """
    resolved_url = hf_hub_url(
        repo_id=repo_id,
        filename=path_in_repo,
        revision=revision,
        repo_type="dataset",
    )
    return resolved_url
class HFStorageClient:
    """
    Simple helper around huggingface_hub for storing run artifacts and gallery metadata
    in a Dataset repository.

    Repo format:
    - runs/YYYY/MM/DD/<job_id>/content.jpg
    - runs/YYYY/MM/DD/<job_id>/style.jpg
    - runs/YYYY/MM/DD/<job_id>/result.jpg
    - gallery/gallery.json
    """

    def __init__(self, dataset_repo: str, hf_token: Optional[str] = None, revision: str = "main"):
        """
        Create a client bound to one dataset repository and revision.

        Args:
            dataset_repo: Dataset repository id; must be non-empty.
            hf_token: Optional access token. A falsy value (None or "") means
                the ``HfApi`` instance is created without an explicit token.
            revision: Branch/revision that reads and commits target.

        Raises:
            ValueError: If ``dataset_repo`` is empty.
        """
        if not dataset_repo:
            raise ValueError("HF_DATASET_REPO is not set. Please configure the dataset repository id.")
        self.dataset_repo = dataset_repo
        self.revision = revision
        self.api = HfApi(token=hf_token) if hf_token else HfApi()

    def _commit(self, operations: List[Any], commit_message: str) -> None:
        """Push ``operations`` as a single commit to the configured dataset repo and revision."""
        create_commit(
            repo_id=self.dataset_repo,
            repo_type="dataset",
            operations=operations,
            commit_message=commit_message,
            revision=self.revision,
            token=self.api.token,
        )

    @staticmethod
    def _extract_repo_path(url: Optional[str]) -> Optional[str]:
        """
        Recover the path-in-repo from a ``.../resolve/<rev>/<path_in_repo>`` URL.

        Only the path component of the URL is inspected, so a query string or
        fragment (e.g. ``?download=true``) never leaks into the result, and
        percent-encoding is decoded so the returned value is the literal file
        path rather than its URL-escaped form.

        Returns:
            The decoded path, or None when ``url`` is missing, not a string,
            has no ``/resolve/`` segment, or has no path after the revision.
        """
        if not isinstance(url, str) or not url:
            return None

        # Local import, mirroring how this module already imports `requests` lazily.
        from urllib.parse import unquote, urlsplit

        _, marker, tail = urlsplit(url).path.partition("/resolve/")
        if not marker:
            return None

        # Split off the revision BEFORE decoding: an encoded revision such as
        # "refs%2Fpr%2F1" must stay a single segment at this point.
        _revision, separator, encoded_path = tail.partition("/")
        if not separator or not encoded_path:
            return None
        return unquote(encoded_path)

    def load_gallery(self) -> List[Dict[str, Any]]:
        """
        Download and parse gallery.json from the dataset. If missing, return [].

        The file is fetched over HTTP from its resolve URL, sending a Bearer
        token when the API client has one. Any non-200 response, and any
        exception raised while fetching or parsing, is logged and results in
        an empty list.

        NOTE(review): because failures are indistinguishable from "no gallery
        yet", a caller that appends to the returned list and then calls
        ``save_gallery`` after a transient failure will commit a gallery
        containing only the new entries.
        """
        try:
            # Try to get the raw file content via the hub URL
            url = build_dataset_resolve_url(self.dataset_repo, GALLERY_FILE_PATH, self.revision)
            import requests  # local import to avoid hard dependency elsewhere

            headers = {}
            if self.api.token:
                headers["Authorization"] = f"Bearer {self.api.token}"
            resp = requests.get(url, timeout=10, headers=headers)
            if resp.status_code == 200:
                return resp.json()
            logger.info("Gallery not found at %s (status %s). Initializing empty gallery.", url, resp.status_code)
            return []
        except Exception as e:
            logger.error("Failed to load gallery from HF: %s", str(e))
            return []

    def save_gallery(self, gallery: List[Dict[str, Any]]) -> None:
        """
        Commit a new version of gallery.json to the dataset repo.

        The gallery is serialized as compact UTF-8 JSON (non-ASCII characters
        are written as-is, not escaped).

        Raises:
            Exception: Any serialization or commit error is logged and re-raised.
        """
        try:
            payload = json.dumps(gallery, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
            self._commit(
                [CommitOperationAdd(path_in_repo=GALLERY_FILE_PATH, path_or_fileobj=payload)],
                "Update gallery.json",
            )
        except Exception as e:
            logger.error("Failed to save gallery to HF: %s", str(e))
            raise

    def upload_file(self, local_path: str, dst_path: str) -> str:
        """
        Upload a local file to the dataset repo at dst_path. Returns the path_in_repo.

        Raises:
            FileNotFoundError: If ``local_path`` does not exist.
            Exception: Any error while opening the file or committing is
                logged and re-raised.
        """
        if not os.path.exists(local_path):
            raise FileNotFoundError(local_path)
        try:
            # The commit must happen while the handle is still open, since the
            # operation holds the file object rather than its contents.
            with open(local_path, "rb") as f:
                self._commit(
                    [CommitOperationAdd(path_in_repo=dst_path, path_or_fileobj=f)],
                    f"Upload {dst_path}",
                )
            return dst_path
        except Exception as e:
            logger.error("Failed to upload %s to HF at %s: %s", local_path, dst_path, str(e))
            raise

    def delete_run_artifacts(self, gallery_item: Dict[str, Any]) -> None:
        """
        Attempt to delete the three image artifacts associated with a run.

        Repo paths are derived from the item's ``contentImageUrl``,
        ``styleImageUrl`` and ``resultImageUrl`` resolve URLs. Entries that
        are missing or cannot be parsed are skipped; if nothing is left, no
        commit is made. Deletion is best-effort: a failed commit is logged
        and NOT re-raised.
        """
        candidates = (
            self._extract_repo_path(gallery_item.get(key))
            for key in ("contentImageUrl", "styleImageUrl", "resultImageUrl")
        )
        # dict.fromkeys keeps first-seen order while ensuring the same file is
        # not listed twice in one commit (e.g. two fields sharing one URL).
        paths: List[str] = list(dict.fromkeys(p for p in candidates if p))
        if not paths:
            return
        try:
            self._commit(
                [CommitOperationDelete(path_in_repo=path) for path in paths],
                f"Delete artifacts for run {gallery_item.get('id', '')}",
            )
        except Exception as e:
            logger.error("Failed to delete artifacts %s: %s", paths, str(e))