File size: 5,491 Bytes
c4a0359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import json
import logging
from typing import Any, Dict, List, Optional

from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationDelete, create_commit, hf_hub_url

# Module-level logger, named after this module so handlers/levels can be
# configured per module by the application.
logger = logging.getLogger(__name__)


# Path (inside the dataset repo) of the JSON file holding the gallery metadata.
GALLERY_FILE_PATH = "gallery/gallery.json"


def build_dataset_resolve_url(repo_id: str, path_in_repo: str, revision: str = "main") -> str:
    """
    Return the Hub URL of a file stored in a Hugging Face dataset repo.

    The URL is produced by ``hf_hub_url`` with ``repo_type="dataset"``.

    Args:
        repo_id: Identifier of the dataset repository.
        path_in_repo: Path of the file inside the repository.
        revision: Revision to resolve against; defaults to ``"main"``.

    Returns:
        The URL string returned by ``hf_hub_url``.
    """
    url_kwargs = {
        "repo_id": repo_id,
        "filename": path_in_repo,
        "repo_type": "dataset",
        "revision": revision,
    }
    resolved_url = hf_hub_url(**url_kwargs)
    return resolved_url


class HFStorageClient:
    """
    Simple helper around huggingface_hub for storing run artifacts and gallery metadata
    in a Dataset repository.

    Repo format:
      - runs/YYYY/MM/DD/<job_id>/content.jpg
      - runs/YYYY/MM/DD/<job_id>/style.jpg
      - runs/YYYY/MM/DD/<job_id>/result.jpg
      - gallery/gallery.json
    """

    # Keys of a gallery item whose values are resolve URLs of the run's images.
    _ARTIFACT_URL_KEYS = ("contentImageUrl", "styleImageUrl", "resultImageUrl")

    def __init__(self, dataset_repo: str, hf_token: Optional[str] = None, revision: str = "main"):
        """
        Create a client bound to one dataset repository and revision.

        Args:
            dataset_repo: Identifier of the dataset repository. Must be non-empty.
            hf_token: Optional access token handed to ``HfApi``. An empty value
                is treated the same as ``None``.
            revision: Revision used for reads and commits; defaults to ``"main"``.

        Raises:
            ValueError: If ``dataset_repo`` is empty.
        """
        if not dataset_repo:
            raise ValueError("HF_DATASET_REPO is not set. Please configure the dataset repository id.")
        self.dataset_repo = dataset_repo
        self.revision = revision
        self.api = HfApi(token=hf_token) if hf_token else HfApi()

    @staticmethod
    def _extract_repo_path(url: Optional[str]) -> Optional[str]:
        """
        Return the path-in-repo encoded in a ``.../resolve/<rev>/<path>`` URL.

        Any query string or fragment is dropped and percent-escapes are decoded,
        so the result can be matched against the file name stored in the repo.

        Returns:
            The decoded path, or ``None`` when ``url`` is empty, not a string,
            malformed, or does not contain a ``/resolve/<rev>/<path>`` section.
        """
        # Function-scope import, in line with how load_gallery imports requests.
        from urllib.parse import unquote, urlsplit

        if not url or not isinstance(url, str):
            return None
        try:
            url_path = urlsplit(url).path  # strips "?query" and "#fragment"
        except ValueError:
            return None

        _, marker, after_marker = url_path.partition("/resolve/")
        if not marker:
            return None
        # after_marker is "<rev>/<path_in_repo>"; the first segment is the revision.
        _revision, separator, encoded_path = after_marker.partition("/")
        if not separator or not encoded_path:
            return None
        return unquote(encoded_path)

    def load_gallery(self) -> List[Dict[str, Any]]:
        """
        Download and parse gallery.json from the dataset.

        This is best-effort: a missing file, an unexpected HTTP status, a
        network/parsing error, or a payload that is not a JSON list all result
        in ``[]`` (with a log entry) rather than an exception.

        NOTE(review): because failures also yield ``[]``, a caller that saves
        the returned gallery after a transient error may overwrite existing
        data - confirm callers account for this.

        Returns:
            The parsed gallery list, or ``[]`` as described above.
        """
        try:
            import requests  # local import to avoid hard dependency elsewhere

            # Fetch the raw file content via the hub URL.
            url = build_dataset_resolve_url(self.dataset_repo, GALLERY_FILE_PATH, self.revision)
            headers = {}
            if self.api.token:
                headers["Authorization"] = f"Bearer {self.api.token}"
            resp = requests.get(url, timeout=10, headers=headers)
            if resp.status_code != 200:
                if resp.status_code == 404:
                    logger.info(
                        "Gallery not found at %s (status %s). Initializing empty gallery.",
                        url,
                        resp.status_code,
                    )
                else:
                    # Auth or server errors are not "missing file": make them visible.
                    logger.warning(
                        "Unexpected status %s while loading gallery from %s. Returning empty gallery.",
                        resp.status_code,
                        url,
                    )
                return []
            gallery = resp.json()
        except Exception as e:  # deliberate best-effort boundary
            logger.error("Failed to load gallery from HF: %s", e)
            return []

        if not isinstance(gallery, list):
            logger.warning(
                "Gallery payload is %s, expected a list. Returning empty gallery.",
                type(gallery).__name__,
            )
            return []
        return gallery

    def save_gallery(self, gallery: List[Dict[str, Any]]) -> None:
        """
        Commit a new version of gallery.json to the dataset repo.

        The gallery is serialized as compact UTF-8 JSON (non-ASCII characters
        are kept as-is).

        Raises:
            Exception: Any error raised while serializing or committing is
                logged and re-raised unchanged.
        """
        try:
            payload = json.dumps(gallery, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
            operations = [
                CommitOperationAdd(path_in_repo=GALLERY_FILE_PATH, path_or_fileobj=payload)
            ]
            create_commit(
                repo_id=self.dataset_repo,
                repo_type="dataset",
                operations=operations,
                commit_message="Update gallery.json",
                revision=self.revision,
                token=self.api.token,
            )
        except Exception as e:
            logger.error("Failed to save gallery to HF: %s", e)
            raise

    def upload_file(self, local_path: str, dst_path: str) -> str:
        """
        Upload a local file to the dataset repo at dst_path.

        Args:
            local_path: Path of the file on the local filesystem.
            dst_path: Destination path inside the dataset repo.

        Returns:
            ``dst_path``, i.e. the path_in_repo of the uploaded file.

        Raises:
            FileNotFoundError: If ``local_path`` does not exist.
            Exception: Any error raised while reading or committing is logged
                and re-raised unchanged.
        """
        if not os.path.exists(local_path):
            raise FileNotFoundError(local_path)

        try:
            # The file stays open for the duration of the commit because the
            # commit operation reads from the file object.
            with open(local_path, "rb") as f:
                operations = [
                    CommitOperationAdd(path_in_repo=dst_path, path_or_fileobj=f)
                ]
                create_commit(
                    repo_id=self.dataset_repo,
                    repo_type="dataset",
                    operations=operations,
                    commit_message=f"Upload {dst_path}",
                    revision=self.revision,
                    token=self.api.token,
                )
            return dst_path
        except Exception as e:
            logger.error("Failed to upload %s to HF at %s: %s", local_path, dst_path, e)
            raise

    def delete_run_artifacts(self, gallery_item: Dict[str, Any]) -> None:
        """
        Attempt to delete the three image artifacts associated with a run.

        The paths in the repo are recovered from the resolve URLs stored under
        ``contentImageUrl``, ``styleImageUrl`` and ``resultImageUrl``. URLs that
        cannot be parsed are skipped and duplicate paths are deleted only once.
        This is best-effort: a failing commit is logged, not raised.

        Args:
            gallery_item: Gallery entry describing the run; ``id`` (if present)
                is used in the commit message.
        """
        candidates = (
            self._extract_repo_path(gallery_item.get(key))
            for key in self._ARTIFACT_URL_KEYS
        )
        # dict.fromkeys de-duplicates while preserving order, so a path shared by
        # two keys does not produce two delete operations in one commit.
        paths: List[str] = list(dict.fromkeys(p for p in candidates if p))

        if not paths:
            return

        try:
            operations = [CommitOperationDelete(path_in_repo=path) for path in paths]
            create_commit(
                repo_id=self.dataset_repo,
                repo_type="dataset",
                operations=operations,
                commit_message=f"Delete artifacts for run {gallery_item.get('id', '')}",
                revision=self.revision,
                token=self.api.token,
            )
        except Exception as e:  # deliberate best-effort boundary
            logger.error("Failed to delete artifacts %s: %s", paths, e)