| import os |
| import uuid |
| import shutil |
| from pathlib import Path |
| from PIL import Image |
| import numpy as np |
|
|
# Hub repository that stores the images; passed as ``repo_id`` to the Hub calls.
DATA_REPO = "aj406/vton-data"
# Repository kind passed as ``repo_type`` to the Hub calls.
REPO_TYPE = "dataset"
# Access token read from the environment; ``None`` when the variable is unset.
DATASET_HF_TOKEN = os.getenv("DATASET_HF_TOKEN")
# Root directory of the local copy of the data.
LOCAL_DATA = Path("data")


# Lower-case file suffixes that are treated as images.
IMG_EXTS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".webp",
}
|
|
|
|
def is_remote():
    """Report whether remote dataset storage is enabled.

    Returns:
        ``True`` when ``DATASET_HF_TOKEN`` is anything other than ``None``,
        ``False`` otherwise.
    """
    if DATASET_HF_TOKEN is None:
        return False
    return True
|
|
|
|
def _api():
    """Build and return a new ``huggingface_hub.HfApi`` client.

    The import happens inside the function, so ``huggingface_hub`` is only
    loaded when a client is actually requested.
    """
    import huggingface_hub

    return huggingface_hub.HfApi()
|
|
|
|
def _ensure_repo():
    """Ask the Hub to create ``DATA_REPO``; does nothing when not remote.

    ``exist_ok=True`` is passed so an already existing repository is
    tolerated by the call.
    """
    if is_remote():
        client = _api()
        client.create_repo(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            exist_ok=True,
            token=DATASET_HF_TOKEN,
        )
|
|
|
|
def save_image(img, local_path):
    """Write ``img`` to ``local_path`` as a JPEG with quality 85.

    Args:
        img: A PIL image, or a NumPy array (converted with
            ``Image.fromarray``).
        local_path: Destination path (``str`` or ``Path``). Missing parent
            directories are created.

    Images in a mode the JPEG writer cannot store (for example ``RGBA``,
    ``LA`` or ``P``) are converted to ``RGB`` before saving, because saving
    them directly raises ``OSError``.
    """
    local_path = Path(local_path)
    local_path.parent.mkdir(parents=True, exist_ok=True)
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Modes the JPEG writer accepts as-is; anything else (alpha, palette,
    # integer/float modes) must be converted or ``save`` fails.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if img.mode not in jpeg_modes:
        img = img.convert("RGB")
    img.save(local_path, "JPEG", quality=85)
|
|
|
|
def upload_image(local_path, remote_path):
    """Upload a local file to ``remote_path`` inside the dataset repo.

    Does nothing when remote storage is disabled. Otherwise ``_ensure_repo``
    is called first and the file is then sent with ``HfApi.upload_file``.

    Args:
        local_path: Path of the file on disk (passed to the API as ``str``).
        remote_path: Destination path inside the repository.
    """
    if is_remote():
        _ensure_repo()
        client = _api()
        source = str(local_path)
        client.upload_file(
            path_or_fileobj=source,
            path_in_repo=remote_path,
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            token=DATASET_HF_TOKEN,
        )
|
|
|
|
def delete_remote_file(remote_path):
    """Delete ``remote_path`` from the dataset repo; no-op when not remote.

    Errors raised by ``HfApi.delete_file`` are not caught here.
    """
    if is_remote():
        client = _api()
        client.delete_file(
            path_in_repo=remote_path,
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            token=DATASET_HF_TOKEN,
        )
|
|
|
|
def download_dir(remote_prefix):
    """Download everything under ``remote_prefix`` into ``LOCAL_DATA``.

    Does nothing when remote storage is disabled. The download is limited to
    repo paths matching ``"{remote_prefix}/**"`` and uses ``LOCAL_DATA`` as
    the ``local_dir`` of ``snapshot_download``.
    """
    if is_remote():
        import huggingface_hub

        pattern = f"{remote_prefix}/**"
        target_dir = str(LOCAL_DATA)
        huggingface_hub.snapshot_download(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            allow_patterns=pattern,
            local_dir=target_dir,
            token=DATASET_HF_TOKEN,
        )
|
|
|
|
def generate_id():
    """Return a fresh 8-character lower-case hex identifier.

    The value is the first eight hex digits of a random UUID4.
    """
    full_hex = uuid.uuid4().hex
    return full_hex[:8]
|
|
|
|
def make_filename(item_id, item_type):
    """Return the file name ``"<item_id>_<item_type>.jpg"``."""
    return "{}_{}.jpg".format(item_id, item_type)
|
|
|
|
def parse_filename(filename):
    """Split ``"<id>_<type>.<ext>"`` into its id and type.

    The split happens at the last underscore of the file stem.

    Returns:
        ``{"id": ..., "type": ...}``, or ``None`` when the stem contains no
        underscore.
    """
    item_id, separator, item_type = Path(filename).stem.rpartition("_")
    if not separator:
        return None
    return {"id": item_id, "type": item_type}
|
|
|
|
def make_result_filename(portrait_id, garment_id):
    """Return the file name ``"<portrait_id>_<garment_id>_result.jpg"``."""
    return "{}_{}_result.jpg".format(portrait_id, garment_id)
|
|
|
|
def parse_result_filename(filename):
    """Parse a single-garment result name ``"<portrait>_<garment>_result.<ext>"``.

    The stem is split on its last two underscores.

    Returns:
        ``{"portrait_id": ..., "garment_id": ...}``, or ``None`` when the
        stem does not have three parts, does not end in ``result``, or the
        garment part contains ``"-"``.
    """
    pieces = Path(filename).stem.rsplit("_", 2)
    try:
        portrait_id, garment_id, marker = pieces
    except ValueError:
        # Fewer than two underscores in the stem.
        return None
    if marker != "result" or "-" in garment_id:
        return None
    return {"portrait_id": portrait_id, "garment_id": garment_id}
|
|
|
|
def make_multi_result_filename(portrait_id, garment_ids):
    """Return a result file name that encodes one garment slot per person.

    Args:
        portrait_id: Identifier placed at the start of the name.
        garment_ids: One entry per person; a falsy entry (e.g. ``None``) is
            written as ``"x"``.

    Slots are joined with ``"-"``. For example ``portrait_id="abc123"`` with
    ``garment_ids=["ef12ab34", None, "gh56cd78"]`` gives
    ``"abc123_ef12ab34-x-gh56cd78_result.jpg"``.
    """
    encoded = "-".join(garment or "x" for garment in garment_ids)
    return "{}_{}_result.jpg".format(portrait_id, encoded)
|
|
|
|
def parse_multi_result_filename(filename):
    """Parse a multi-garment result name back into its ids.

    After removing the ``_result`` suffix, the stem is split at its first
    underscore into the portrait id and the slot code. The code is split on
    ``"-"`` and each ``"x"`` slot becomes ``None``.

    Returns:
        ``{"portrait_id": ..., "garment_ids": [...]}``, or ``None`` when the
        stem does not end in ``_result`` or has no underscore before the code.
    """
    suffix = "_result"
    stem = Path(filename).stem
    if not stem.endswith(suffix):
        return None
    body = stem[: len(stem) - len(suffix)]
    portrait_id, separator, code = body.partition("_")
    if not separator:
        return None
    garment_ids = []
    for slot in code.split("-"):
        garment_ids.append(slot if slot != "x" else None)
    return {"portrait_id": portrait_id, "garment_ids": garment_ids}
|
|
|
|
def list_local_images(directory):
    """Return the sorted paths of the image files directly inside ``directory``.

    Args:
        directory: Directory to scan (``str`` or ``Path``); not recursive.

    Returns:
        A sorted list of path strings for regular files whose lower-cased
        suffix is in ``IMG_EXTS``. An empty list is returned when
        ``directory`` is missing or is not a directory.
    """
    root = Path(directory)
    # ``is_dir`` rather than ``exists``: iterating a regular file would raise
    # NotADirectoryError.
    if not root.is_dir():
        return []
    return sorted(
        str(entry)
        for entry in root.iterdir()
        # Skip sub-directories that merely happen to be named like an image.
        if entry.suffix.lower() in IMG_EXTS and entry.is_file()
    )
|
|
|
|
def file_url(remote_path):
    """Build the direct ``resolve/main`` URL for ``remote_path`` in the dataset repo.

    The URL is only assembled here; no request is made. (The original note
    on this function says the repo is public.)
    """
    template = "https://huggingface.co/datasets/{}/resolve/main/{}"
    return template.format(DATA_REPO, remote_path)
|
|
|
|
def list_gallery_urls(prefix, subdir):
    """List gallery images under ``prefix/subdir``.

    Args:
        prefix: Top-level folder (both in the repo and under ``LOCAL_DATA``).
        subdir: Sub-folder below ``prefix``.

    Returns:
        When remote storage is enabled: a sorted list of direct dataset URLs
        for the image files listed in the repo folder. When it is disabled,
        or when the remote listing fails for any reason: the sorted local
        file paths from ``LOCAL_DATA / prefix / subdir`` instead.
    """
    import logging

    local_dir = LOCAL_DATA / prefix / subdir
    if not is_remote():
        return list_local_images(local_dir)
    try:
        # Pass the token like every other Hub call in this module does.
        items = _api().list_repo_tree(
            DATA_REPO,
            repo_type=REPO_TYPE,
            path_in_repo=f"{prefix}/{subdir}",
            token=DATASET_HF_TOKEN,
        )
        urls = []
        # Keep the loop inside the ``try``: the listing may be produced
        # lazily, in which case request errors surface while iterating.
        for item in items:
            if hasattr(item, "rfilename"):
                name = item.rfilename
            elif hasattr(item, "path"):
                name = item.path
            else:
                continue
            if Path(name).suffix.lower() in IMG_EXTS:
                urls.append(file_url(name))
        return sorted(urls)
    except Exception:
        # Best-effort fallback, but leave a trace instead of hiding the error.
        logging.getLogger(__name__).warning(
            "Listing %s/%s in %s failed; falling back to local files",
            prefix,
            subdir,
            DATA_REPO,
            exc_info=True,
        )
        return list_local_images(local_dir)
|
|
|
|
# Common prefix of every direct file URL in the dataset repo; used to
# recognise such URLs and to strip them back to a repo-relative path.
HF_URL_PREFIX = "https://huggingface.co/datasets/{}/resolve/main/".format(DATA_REPO)
|
|
|
|
def is_dataset_url(url):
    """Return ``True`` when ``url`` is a string starting with ``HF_URL_PREFIX``.

    Non-string values give ``False``.
    """
    if not isinstance(url, str):
        return False
    return url.startswith(HF_URL_PREFIX)
|
|
|
|
def download_to_local(path_or_url):
    """Resolve ``path_or_url`` to a local file path, downloading if it is a URL.

    Behaviour by input:
        * Not a string: returned unchanged.
        * URL under ``HF_URL_PREFIX``: fetched with ``hf_hub_download``; the
          path that call returns is handed back.
        * Any other ``http://`` / ``https://`` URL: fetched with ``requests``
          (30 second timeout), decoded as an image and re-saved as a JPEG
          under ``LOCAL_DATA / "tmp"``; that new path is returned as ``str``.
        * Any other string: returned unchanged.

    Raises:
        requests.HTTPError: From ``raise_for_status`` when a plain URL
            download does not succeed.
    """
    if not isinstance(path_or_url, str):
        return path_or_url
    if is_dataset_url(path_or_url):
        remote_path = path_or_url[len(HF_URL_PREFIX):]
        from huggingface_hub import hf_hub_download
        return hf_hub_download(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            filename=remote_path,
            token=DATASET_HF_TOKEN,
        )
    if path_or_url.startswith(("http://", "https://")):
        import requests
        from io import BytesIO
        resp = requests.get(path_or_url, timeout=30)
        resp.raise_for_status()
        tmp_path = LOCAL_DATA / "tmp" / f"{generate_id()}.jpg"
        tmp_path.parent.mkdir(parents=True, exist_ok=True)
        # Close the decoded image once it has been written out.
        with Image.open(BytesIO(resp.content)) as img:
            # The file is re-saved as JPEG, which cannot store alpha or
            # palette modes (common for PNG/WebP downloads); convert those
            # to RGB so the save does not raise OSError.
            if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
                img = img.convert("RGB")
            save_image(img, tmp_path)
        return str(tmp_path)
    return path_or_url
|
|
|
|
def load_image_sets(prefix):
    """Collect the portrait/garment/result sets stored under ``prefix``.

    When remote storage is enabled the ``prefix`` folder is downloaded first.
    Every image in ``<prefix>/portraits`` whose name parses with
    ``parse_filename`` starts a set keyed by its id; the matching garment and
    result files are then looked up by name.

    Args:
        prefix: Folder below ``LOCAL_DATA`` (and in the repo) to scan.

    Returns:
        A list of dicts with keys ``"id"``, ``"portrait"``, ``"garment"`` and
        ``"result"`` (paths as ``str``; ``"result"`` is ``None`` when no
        result file exists). Sets without a garment file are left out. An
        empty list is returned when the portraits folder does not exist.
    """
    local_prefix = LOCAL_DATA / prefix
    if is_remote():
        download_dir(prefix)
    portraits_dir = local_prefix / "portraits"
    if not portraits_dir.exists():
        return []
    sets = {}
    for p in portraits_dir.iterdir():
        if p.suffix.lower() not in IMG_EXTS:
            continue
        parsed = parse_filename(p.name)
        if not parsed:
            continue
        item_id = parsed["id"]
        sets[item_id] = {
            "id": item_id,
            "portrait": str(p),
        }
    garments_dir = local_prefix / "garments"
    results_dir = local_prefix / "results"
    for item_id, entry in sets.items():
        garment = garments_dir / make_filename(item_id, "garment")
        # ``save_image_set`` names its result with
        # ``make_result_filename(item_id, item_id)``; the bare
        # ``<id>_result.jpg`` form is still accepted for older files.
        result_candidates = (
            results_dir / make_result_filename(item_id, item_id),
            results_dir / f"{item_id}_result.jpg",
        )
        result = next((c for c in result_candidates if c.exists()), None)
        entry["garment"] = str(garment) if garment.exists() else None
        entry["result"] = str(result) if result is not None else None
    return [v for v in sets.values() if v["garment"] is not None]
|
|
|
|
def save_image_set(prefix, img_portrait, img_garment, img_result=None):
    """Store a portrait, a garment and optionally a result under one new id.

    Each image is written below ``LOCAL_DATA / prefix`` (folders
    ``portraits``, ``garments``, ``results``) and then passed to
    ``upload_image`` with the matching repo path.

    Args:
        prefix: Folder below ``LOCAL_DATA`` and in the repo.
        img_portrait: Portrait image.
        img_garment: Garment image.
        img_result: Optional result image; skipped when ``None``.

    Returns:
        ``(item_id, portrait_path, garment_path, result_path)`` with paths as
        ``str``; ``result_path`` is ``None`` when no result was given.
    """
    set_id = generate_id()
    root = LOCAL_DATA / prefix

    portrait_file = make_filename(set_id, "portrait")
    garment_file = make_filename(set_id, "garment")
    portrait_dest = root / "portraits" / portrait_file
    garment_dest = root / "garments" / garment_file

    # Write both files locally before any upload starts.
    for image, destination in ((img_portrait, portrait_dest), (img_garment, garment_dest)):
        save_image(image, destination)
    upload_image(portrait_dest, f"{prefix}/portraits/{portrait_file}")
    upload_image(garment_dest, f"{prefix}/garments/{garment_file}")

    if img_result is None:
        return set_id, str(portrait_dest), str(garment_dest), None

    # The result of a freshly saved set uses the set id for both id slots.
    result_file = make_result_filename(set_id, set_id)
    result_dest = root / "results" / result_file
    save_image(img_result, result_dest)
    upload_image(result_dest, f"{prefix}/results/{result_file}")
    return set_id, str(portrait_dest), str(garment_dest), str(result_dest)
|
|
|
|
def save_result(prefix, portrait_id, garment_id, img_result):
    """Store a result image named after its portrait and garment ids.

    The file is written to ``LOCAL_DATA / prefix / "results"`` and then
    passed to ``upload_image``.

    Returns:
        The local path of the saved file as ``str``.
    """
    filename = make_result_filename(portrait_id, garment_id)
    destination = LOCAL_DATA / prefix / "results" / filename
    save_image(img_result, destination)
    upload_image(destination, f"{prefix}/results/{filename}")
    return str(destination)
|
|
|
|
def save_multi_result(prefix, portrait_id, assignments, img_result):
    """Store a multi-garment result whose name encodes ``assignments``.

    The name comes from ``make_multi_result_filename(portrait_id,
    assignments)``; the file is written to ``LOCAL_DATA / prefix /
    "results"`` and then passed to ``upload_image``.

    Returns:
        The local path of the saved file as ``str``.
    """
    filename = make_multi_result_filename(portrait_id, assignments)
    destination = LOCAL_DATA / prefix / "results" / filename
    save_image(img_result, destination)
    upload_image(destination, f"{prefix}/results/{filename}")
    return str(destination)
|
|
|
|
def delete_image_set(prefix, item_id):
    """Delete every file that belongs to the image set ``item_id``.

    The ``portraits``, ``garments`` and ``results`` folders under
    ``LOCAL_DATA / prefix`` are scanned. A file belongs to the set when its
    stem equals ``item_id`` or starts with ``item_id`` followed by ``"_"``,
    which also covers multi-garment result names. Each match is removed
    locally and, when remote storage is enabled, from the dataset repo.

    Args:
        prefix: Folder below ``LOCAL_DATA`` and in the repo.
        item_id: Identifier of the set. An empty value is ignored, because
            it would otherwise match every file in the folders.

    Remote deletion is best-effort: a failure is logged and does not stop
    the remaining deletions.
    """
    import logging

    if not item_id:
        return
    # Require the "_" boundary so an id that is merely a prefix of another
    # id cannot take that other set's files with it.
    owned_prefix = f"{item_id}_"
    local_prefix = LOCAL_DATA / prefix
    for subdir in ("portraits", "garments", "results"):
        folder = local_prefix / subdir
        if not folder.is_dir():
            continue
        # Snapshot the listing so entries are not removed while iterating.
        for entry in list(folder.iterdir()):
            if not entry.is_file():
                continue
            if entry.stem != item_id and not entry.stem.startswith(owned_prefix):
                continue
            entry.unlink()
            if is_remote():
                remote_path = f"{prefix}/{subdir}/{entry.name}"
                try:
                    delete_remote_file(remote_path)
                except Exception:
                    logging.getLogger(__name__).warning(
                        "Could not delete remote file %s", remote_path, exc_info=True
                    )
|
|
|
|
def promote_to_example(result_path):
    """Copy a result file into the examples folder under the same file name.

    The copy goes to ``LOCAL_DATA / "examples" / "results"`` (created when
    missing) and is then passed to ``upload_image`` as
    ``examples/results/<name>``.

    Returns:
        The stem (file name without extension) of ``result_path``.
    """
    source = Path(result_path)
    filename = source.name
    examples_dir = LOCAL_DATA / "examples" / "results"
    examples_dir.mkdir(parents=True, exist_ok=True)
    target = examples_dir / filename
    shutil.copy2(str(source), str(target))
    upload_image(target, f"examples/results/{filename}")
    return source.stem
|
|