File size: 9,699 Bytes
6933b0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
import os
import uuid
import shutil
from pathlib import Path
from PIL import Image
import numpy as np

# Hugging Face Hub repository that stores the image data, and its repo type.
DATA_REPO = "aj406/vton-data"
REPO_TYPE = "dataset"
# Access token read from the environment; when the variable is unset this is
# None and the module works on local files only (see is_remote()).
DATASET_HF_TOKEN = os.environ.get("DATASET_HF_TOKEN")
# Root directory of the local copy of the data (a relative path).
LOCAL_DATA = Path("data")

# Lower-case file suffixes that are treated as images when listing files.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp"}


def is_remote():
    """Tell whether a dataset token is configured, i.e. remote storage is enabled."""
    token_missing = DATASET_HF_TOKEN is None
    return not token_missing


def _api():
    """Build a new ``HfApi`` client; huggingface_hub is imported inside the function."""
    import huggingface_hub

    client = huggingface_hub.HfApi()
    return client


def _ensure_repo():
    """Create the dataset repo on the Hub if it is missing; no-op without a token."""
    if is_remote():
        _api().create_repo(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            exist_ok=True,
            token=DATASET_HF_TOKEN,
        )


def save_image(img, local_path):
    """Write an image to ``local_path`` as a JPEG with quality 85.

    Args:
        img: A PIL image, or a NumPy array (converted with ``Image.fromarray``).
        local_path: Destination path; missing parent directories are created.

    Images whose mode the JPEG encoder cannot store (for example ``RGBA``,
    ``LA`` or palette ``P`` images, typical for PNG/WebP inputs) are converted
    to ``RGB`` first, which drops any alpha channel, instead of letting
    ``save`` fail with ``OSError``.
    """
    local_path = Path(local_path)
    local_path.parent.mkdir(parents=True, exist_ok=True)
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Only these modes can be written by Pillow's JPEG plugin; anything else
    # (alpha, palette, 16/32-bit) must be flattened to RGB beforehand.
    mode = getattr(img, "mode", None)
    if mode is not None and mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    img.save(local_path, "JPEG", quality=85)


def upload_image(local_path, remote_path):
    """Push a local file to ``remote_path`` in the dataset repo (no-op without a token)."""
    if is_remote():
        # Make sure the target repo exists before the first upload.
        _ensure_repo()
        client = _api()
        client.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=remote_path,
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            token=DATASET_HF_TOKEN,
        )


def delete_remote_file(remote_path):
    """Remove ``remote_path`` from the dataset repo (no-op without a token)."""
    if is_remote():
        request = dict(
            path_in_repo=remote_path,
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            token=DATASET_HF_TOKEN,
        )
        _api().delete_file(**request)


def download_dir(remote_prefix):
    """Fetch everything under ``remote_prefix`` from the repo into ``LOCAL_DATA``.

    Does nothing when no dataset token is configured.
    """
    if is_remote():
        import huggingface_hub

        # Restrict the snapshot to files below the requested prefix.
        pattern = f"{remote_prefix}/**"
        huggingface_hub.snapshot_download(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            allow_patterns=pattern,
            local_dir=str(LOCAL_DATA),
            token=DATASET_HF_TOKEN,
        )


def generate_id():
    """Return a random 8-character hex identifier (the start of a UUID4)."""
    full_hex = uuid.uuid4().hex
    return full_hex[:8]


def make_filename(item_id, item_type):
    """Build the ``<id>_<type>.jpg`` filename for an item."""
    return "{}_{}.jpg".format(item_id, item_type)


def parse_filename(filename):
    """Split ``<id>_<type>.<ext>`` into ``{"id", "type"}``; None if there is no underscore.

    The split happens at the last underscore of the stem, so the id itself
    may contain underscores.
    """
    item_id, separator, item_type = Path(filename).stem.rpartition("_")
    if not separator:
        return None
    return {"id": item_id, "type": item_type}


def make_result_filename(portrait_id, garment_id):
    """Build the ``<portrait>_<garment>_result.jpg`` filename for a try-on result."""
    return "{}_{}_result.jpg".format(portrait_id, garment_id)


def parse_result_filename(filename):
    """Parse ``<portrait>_<garment>_result.<ext>`` into its two ids.

    Returns None when the stem does not end in ``_result`` with two id
    fields before it, or when the garment field contains ``-`` (that marks a
    multi-garment result filename).
    """
    pieces = Path(filename).stem.rsplit("_", 2)
    if len(pieces) == 3:
        portrait_id, garment_id, marker = pieces
        if marker == "result" and "-" not in garment_id:
            return {"portrait_id": portrait_id, "garment_id": garment_id}
    return None


def make_multi_result_filename(portrait_id, garment_ids):
    """Encode per-person garment assignments into a result filename.

    garment_ids holds one entry per person: a garment id string, or None
    (any falsy value) for a person without a garment, written as "x".
    Example: portrait_id="abc123", garment_ids=["ef12ab34", None, "gh56cd78"]
    gives "abc123_ef12ab34-x-gh56cd78_result.jpg".
    """
    code = "-".join(gid or "x" for gid in garment_ids)
    return "{}_{}_result.jpg".format(portrait_id, code)


def parse_multi_result_filename(filename):
    """Decode a filename produced by ``make_multi_result_filename``.

    Args:
        filename: Name or path of the form
            ``<portrait_id>_<slot>-<slot>-..._result.<ext>``.

    Returns:
        ``{"portrait_id": str, "garment_ids": list}`` where each slot is a
        garment id, or None for an ``"x"`` placeholder. Returns None when the
        stem does not end in ``_result`` or has no ``_`` separating the
        portrait id from the slot code.

    The portrait id and the slot code are separated at the LAST underscore,
    matching ``parse_filename`` and ``parse_result_filename``; portrait ids
    that contain underscores therefore round-trip correctly.
    """
    stem = Path(filename).stem
    suffix = "_result"
    if not stem.endswith(suffix):
        return None
    portrait_id, separator, code = stem[:-len(suffix)].rpartition("_")
    if not separator:
        return None
    garment_ids = [None if slot == "x" else slot for slot in code.split("-")]
    return {"portrait_id": portrait_id, "garment_ids": garment_ids}


def list_local_images(directory):
    """Return the sorted paths (as strings) of the image files in ``directory``.

    A file counts as an image when its lower-cased suffix is in ``IMG_EXTS``.
    An empty list is returned when ``directory`` is missing or is not a
    directory. Sub-directories are never listed, even if their name ends in
    an image suffix.
    """
    d = Path(directory)
    # is_dir() rather than exists(): a regular file must not reach iterdir(),
    # which would raise NotADirectoryError.
    if not d.is_dir():
        return []
    return sorted(
        str(p) for p in d.iterdir() if p.suffix.lower() in IMG_EXTS and p.is_file()
    )


def file_url(remote_path):
    """Build the direct ``resolve/main`` download URL of a file in the dataset repo.

    The URL carries no credentials, so it is only usable for a public repo.
    """
    return "https://huggingface.co/datasets/{}/resolve/main/{}".format(DATA_REPO, remote_path)


def list_gallery_urls(prefix, subdir):
    """List the images under ``{prefix}/{subdir}`` for gallery display.

    Args:
        prefix: Top-level folder in the dataset (and under ``LOCAL_DATA``).
        subdir: Sub-folder below ``prefix``.

    Returns:
        A sorted list of direct dataset URLs when a token is configured and
        the repo listing succeeds. Otherwise, the sorted LOCAL file paths
        from ``list_local_images`` (local-only mode, or fallback after a
        listing failure).
    """
    if not is_remote():
        return list_local_images(LOCAL_DATA / prefix / subdir)
    try:
        # Pass the token like every other Hub call in this module, so the
        # listing also works when the repo is not publicly readable.
        items = _api().list_repo_tree(
            DATA_REPO,
            repo_type=REPO_TYPE,
            path_in_repo=f"{prefix}/{subdir}",
            token=DATASET_HF_TOKEN,
        )
        urls = []
        # The listing is consumed inside the try block so that errors raised
        # while iterating are handled by the fallback as well.
        for item in items:
            if hasattr(item, "rfilename"):
                name = item.rfilename
            elif hasattr(item, "path"):
                name = item.path
            else:
                continue
            if Path(name).suffix.lower() in IMG_EXTS:
                urls.append(file_url(name))
        return sorted(urls)
    except Exception:
        # Best-effort fallback to the local copy, but leave a trace of the
        # failure instead of hiding it completely.
        import logging

        logging.getLogger(__name__).warning(
            "Listing %s/%s in %s failed; falling back to local files",
            prefix,
            subdir,
            DATA_REPO,
            exc_info=True,
        )
        return list_local_images(LOCAL_DATA / prefix / subdir)


# Common prefix of every URL built by file_url(); used to recognise dataset
# URLs and to strip them back down to the path inside the repo.
HF_URL_PREFIX = f"https://huggingface.co/datasets/{DATA_REPO}/resolve/main/"


def is_dataset_url(url):
    """Return True when ``url`` is a string that starts with our dataset URL prefix."""
    if not isinstance(url, str):
        return False
    return url.startswith(HF_URL_PREFIX)


def download_to_local(path_or_url):
    """Resolve ``path_or_url`` to something readable from the local disk.

    - Non-string values and strings that are not http(s) URLs are returned
      unchanged.
    - URLs of our dataset repo are fetched with ``hf_hub_download`` and the
      path it returns is passed back.
    - Any other http(s) URL is fetched with ``requests`` (30 s timeout),
      re-encoded via ``save_image`` and stored under ``LOCAL_DATA / "tmp"``;
      the new file's path is returned as a string.
    """
    if not isinstance(path_or_url, str):
        return path_or_url

    if is_dataset_url(path_or_url):
        # Everything after the fixed prefix is the file's path inside the repo.
        repo_file = path_or_url[len(HF_URL_PREFIX):]
        from huggingface_hub import hf_hub_download

        return hf_hub_download(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            filename=repo_file,
            token=DATASET_HF_TOKEN,
        )

    is_http = path_or_url.startswith("http://") or path_or_url.startswith("https://")
    if not is_http:
        return path_or_url

    import requests
    from io import BytesIO

    response = requests.get(path_or_url, timeout=30)
    response.raise_for_status()
    picture = Image.open(BytesIO(response.content))
    target = LOCAL_DATA / "tmp" / f"{generate_id()}.jpg"
    target.parent.mkdir(parents=True, exist_ok=True)
    save_image(picture, target)
    return str(target)


def load_image_sets(prefix):
    """Collect the portrait/garment/result sets stored under ``prefix``.

    In remote mode the ``prefix`` folder is downloaded first. Every image in
    ``{prefix}/portraits/`` whose name parses with ``parse_filename`` starts a
    set keyed by its id; the matching garment and result files are then
    looked up by name.

    Args:
        prefix: Folder below ``LOCAL_DATA`` (and in the dataset repo).

    Returns:
        A list of dicts with keys ``id``, ``portrait``, ``garment`` and
        ``result`` (paths as strings; ``result`` is None when no result file
        exists). Sets without a garment file are left out. Portraits are
        scanned in sorted order, so the output order is deterministic.
    """
    local_prefix = LOCAL_DATA / prefix
    if is_remote():
        download_dir(prefix)
    portraits_dir = local_prefix / "portraits"
    if not portraits_dir.exists():
        return []
    sets = {}
    for p in sorted(portraits_dir.iterdir()):
        if p.suffix.lower() not in IMG_EXTS:
            continue
        parsed = parse_filename(p.name)
        if not parsed:
            continue
        item_id = parsed["id"]
        sets[item_id] = {
            "id": item_id,
            "portrait": str(p),
        }
    garments_dir = local_prefix / "garments"
    results_dir = local_prefix / "results"
    for item_id, entry in sets.items():
        garment = garments_dir / make_filename(item_id, "garment")
        # save_image_set() names its result make_result_filename(id, id),
        # i.e. "<id>_<id>_result.jpg". The older "<id>_result.jpg" name is
        # checked first so existing data keeps resolving exactly as before.
        candidates = (
            results_dir / f"{item_id}_result.jpg",
            results_dir / make_result_filename(item_id, item_id),
        )
        result = next((c for c in candidates if c.exists()), None)
        entry["garment"] = str(garment) if garment.exists() else None
        entry["result"] = str(result) if result is not None else None
    return [v for v in sets.values() if v["garment"] is not None]


def save_image_set(prefix, img_portrait, img_garment, img_result=None):
    """Store a portrait, a garment and optionally a result under one new id.

    Each image is written below ``LOCAL_DATA / prefix`` and passed to
    ``upload_image``. Returns ``(item_id, portrait_path, garment_path,
    result_path)`` with paths as strings; ``result_path`` is None when no
    result image was given.
    """
    item_id = generate_id()
    base_dir = LOCAL_DATA / prefix

    portrait_file = make_filename(item_id, "portrait")
    garment_file = make_filename(item_id, "garment")
    portrait_path = base_dir / "portraits" / portrait_file
    garment_path = base_dir / "garments" / garment_file

    # Write both inputs locally first, then hand them to the uploader.
    save_image(img_portrait, portrait_path)
    save_image(img_garment, garment_path)
    upload_image(portrait_path, f"{prefix}/portraits/{portrait_file}")
    upload_image(garment_path, f"{prefix}/garments/{garment_file}")

    if img_result is None:
        return item_id, str(portrait_path), str(garment_path), None

    # The result of a freshly saved set uses the same id for both fields.
    result_file = make_result_filename(item_id, item_id)
    result_path = base_dir / "results" / result_file
    save_image(img_result, result_path)
    upload_image(result_path, f"{prefix}/results/{result_file}")
    return item_id, str(portrait_path), str(garment_path), str(result_path)


def save_result(prefix, portrait_id, garment_id, img_result):
    """Store a result image named after its portrait and garment ids.

    The file goes to ``LOCAL_DATA / prefix / "results"`` and is passed to
    ``upload_image``; the local path is returned as a string.
    """
    filename = make_result_filename(portrait_id, garment_id)
    target = LOCAL_DATA / prefix / "results" / filename
    save_image(img_result, target)
    upload_image(target, f"{prefix}/results/{filename}")
    return str(target)


def save_multi_result(prefix, portrait_id, assignments, img_result):
    """Store a multi-garment result whose filename encodes ``assignments``.

    ``assignments`` is forwarded to ``make_multi_result_filename``. The file
    goes to ``LOCAL_DATA / prefix / "results"`` and is passed to
    ``upload_image``; the local path is returned as a string.
    """
    filename = make_multi_result_filename(portrait_id, assignments)
    target = LOCAL_DATA / prefix / "results" / filename
    save_image(img_result, target)
    upload_image(target, f"{prefix}/results/{filename}")
    return str(target)


def delete_image_set(prefix, item_id):
    """Delete every local (and, in remote mode, remote) file of an image set.

    The ``portraits``, ``garments`` and ``results`` folders under
    ``LOCAL_DATA / prefix`` are scanned for files whose stem is ``item_id``
    or starts with ``item_id + "_"``. This catches ``<id>_portrait``,
    ``<id>_garment`` and every ``<id>_..._result`` (multi-garment) file,
    without touching files of a different id that merely shares the prefix.

    Args:
        prefix: Folder below ``LOCAL_DATA`` (and in the dataset repo).
        item_id: Id of the set to delete. An empty id is ignored.

    Remote deletion is best-effort: a failure is logged and does not stop
    the remaining files from being processed.
    """
    if not item_id:
        # An empty id would match every file; refuse rather than wipe the data.
        return
    local_prefix = LOCAL_DATA / prefix
    id_prefix = f"{item_id}_"
    for subdir in ("portraits", "garments", "results"):
        d = local_prefix / subdir
        if not d.exists():
            continue
        # Snapshot the listing so files are not removed mid-iteration.
        for f in list(d.iterdir()):
            if f.stem != item_id and not f.stem.startswith(id_prefix):
                continue
            f.unlink()
            if is_remote():
                try:
                    delete_remote_file(f"{prefix}/{subdir}/{f.name}")
                except Exception:
                    import logging

                    logging.getLogger(__name__).warning(
                        "Could not delete remote file %s/%s/%s",
                        prefix,
                        subdir,
                        f.name,
                        exc_info=True,
                    )


def promote_to_example(result_path):
    """Copy a result file into ``examples/results`` under the same filename.

    The copy is also passed to ``upload_image``. Returns the filename
    without its extension.
    """
    source = Path(result_path)
    target = LOCAL_DATA / "examples" / "results" / source.name
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(str(source), str(target))
    upload_image(target, f"examples/results/{source.name}")
    return source.stem