# VTON_TEST / storage.py
# CI
# deploy
# 6933b0e
import os
import uuid
import shutil
from pathlib import Path
from PIL import Image
import numpy as np
# Hugging Face Hub repository id used by every hub call in this module.
DATA_REPO = "aj406/vton-data"
# Repo type passed alongside DATA_REPO to the hub calls.
REPO_TYPE = "dataset"
# Access token read from the environment; None when the variable is unset.
DATASET_HF_TOKEN = os.environ.get("DATASET_HF_TOKEN")
# Root directory (relative path) under which images are stored locally.
LOCAL_DATA = Path("data")
# Lower-case file suffixes that are treated as images.
IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
def is_remote():
    """Report whether remote (Hugging Face Hub) storage is configured.

    Returns:
        bool: True when ``DATASET_HF_TOKEN`` holds a non-empty value.
    """
    # An exported-but-empty variable yields "" rather than None; treat that
    # as "no token" so the module stays in local-only mode instead of
    # attempting hub calls with an empty credential.
    return bool(DATASET_HF_TOKEN)
def _api():
    """Return a new ``HfApi`` client.

    huggingface_hub is imported at call time rather than at module import.
    """
    import huggingface_hub

    client = huggingface_hub.HfApi()
    return client
def _ensure_repo():
    """Create the dataset repo on the Hub (``exist_ok=True``); no-op in local mode."""
    if is_remote():
        _api().create_repo(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            exist_ok=True,
            token=DATASET_HF_TOKEN,
        )
def save_image(img, local_path):
    """Write an image to ``local_path`` as a JPEG (quality 85).

    Args:
        img: A PIL image, or a NumPy array, which is first converted with
            ``Image.fromarray``.
        local_path: Destination path (str or Path). Missing parent
            directories are created.
    """
    local_path = Path(local_path)
    local_path.parent.mkdir(parents=True, exist_ok=True)
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Pillow's JPEG writer only accepts these modes; anything else (RGBA, LA,
    # P, ... as produced by PNG/WebP sources) raises OSError. Convert those to
    # RGB and leave already-supported modes untouched.
    jpeg_modes = {"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"}
    if img.mode not in jpeg_modes:
        img = img.convert("RGB")
    img.save(local_path, "JPEG", quality=85)
def upload_image(local_path, remote_path):
    """Upload a local file to ``remote_path`` in the dataset repo.

    Does nothing in local mode. In remote mode the repo is ensured to exist
    before the upload is issued.
    """
    if is_remote():
        _ensure_repo()
        client = _api()
        client.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=remote_path,
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            token=DATASET_HF_TOKEN,
        )
def delete_remote_file(remote_path):
    """Delete ``remote_path`` from the dataset repo; no-op in local mode."""
    if is_remote():
        client = _api()
        client.delete_file(
            path_in_repo=remote_path,
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            token=DATASET_HF_TOKEN,
        )
def download_dir(remote_prefix):
    """Fetch repo files matching ``{remote_prefix}/**`` into ``LOCAL_DATA``.

    Uses ``snapshot_download``; does nothing in local mode.
    """
    if is_remote():
        import huggingface_hub

        pattern = f"{remote_prefix}/**"
        huggingface_hub.snapshot_download(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            allow_patterns=pattern,
            local_dir=str(LOCAL_DATA),
            token=DATASET_HF_TOKEN,
        )
def generate_id():
    """Return a random identifier: the first 8 hex characters of a UUID4."""
    full_hex = uuid.uuid4().hex
    return full_hex[:8]
def make_filename(item_id, item_type):
    """Build the ``<item_id>_<item_type>.jpg`` filename for an item."""
    return "{}_{}.jpg".format(item_id, item_type)
def parse_filename(filename):
    """Split a ``<id>_<type>.<ext>`` filename into its id and type.

    The split happens at the last underscore of the stem. Returns a dict with
    keys ``"id"`` and ``"type"``, or None when the stem has no underscore.
    """
    item_id, separator, item_type = Path(filename).stem.rpartition("_")
    if not separator:
        return None
    return {"id": item_id, "type": item_type}
def make_result_filename(portrait_id, garment_id):
    """Build the ``<portrait_id>_<garment_id>_result.jpg`` filename."""
    return "{}_{}_result.jpg".format(portrait_id, garment_id)
def parse_result_filename(filename):
    """Parse a single-garment result filename.

    Expects ``<portrait_id>_<garment_id>_result.<ext>``, split from the right.
    Returns a dict with ``"portrait_id"`` and ``"garment_id"``, or None when
    the name does not fit that shape or the garment part contains ``-``
    (the separator used by multi-garment result names).
    """
    pieces = Path(filename).stem.rsplit("_", 2)
    if len(pieces) == 3:
        portrait_id, garment_id, marker = pieces
        if marker == "result" and "-" not in garment_id:
            return {"portrait_id": portrait_id, "garment_id": garment_id}
    return None
def make_multi_result_filename(portrait_id, garment_ids):
    """Build a result filename that encodes one garment slot per person.

    garment_ids is a list with one entry per person: a garment id (str), or
    None (any falsy value) for a person without a garment, encoded as "x".
    Slots are joined with "-".

    Example: portrait_id=abc123, garment_ids=["ef12ab34", None, "gh56cd78"]
    → "abc123_ef12ab34-x-gh56cd78_result.jpg"
    """
    code = "-".join(garment_id or "x" for garment_id in garment_ids)
    return f"{portrait_id}_{code}_result.jpg"
def parse_multi_result_filename(filename):
    """Parse a multi-garment result filename.

    Expects ``<portrait_id>_<slot>-<slot>-..._result.<ext>``; the portrait id
    is everything before the first underscore. Returns a dict with
    ``"portrait_id"`` and ``"garment_ids"`` (a list where an "x" slot becomes
    None), or None when the name does not fit that shape.
    """
    marker = "_result"
    stem = Path(filename).stem
    if not stem.endswith(marker):
        return None
    portrait_id, separator, code = stem[: -len(marker)].partition("_")
    if not separator:
        return None
    garment_ids = []
    for slot in code.split("-"):
        garment_ids.append(None if slot == "x" else slot)
    return {"portrait_id": portrait_id, "garment_ids": garment_ids}
def list_local_images(directory):
    """Return the sorted paths (as str) of image files directly inside ``directory``.

    A file counts as an image when its lower-cased suffix is in ``IMG_EXTS``.
    Returns an empty list when the directory does not exist.
    """
    root = Path(directory)
    if not root.exists():
        return []
    found = []
    for entry in root.iterdir():
        if entry.suffix.lower() in IMG_EXTS:
            found.append(str(entry))
    found.sort()
    return found
def file_url(remote_path):
    """Build the direct ``resolve/main`` URL of a file in the dataset repo."""
    url = f"https://huggingface.co/datasets/{DATA_REPO}/resolve/main/{remote_path}"
    return url
def list_gallery_urls(prefix, subdir):
    """Return sorted image locations under ``<prefix>/<subdir>`` for gallery display.

    In local mode this is the list of local image paths. In remote mode the
    repo tree is listed and each image entry is turned into a direct URL; if
    anything in that process raises, the local listing is returned instead.
    """
    if not is_remote():
        return list_local_images(LOCAL_DATA / prefix / subdir)
    try:
        entries = _api().list_repo_tree(
            DATA_REPO, repo_type=REPO_TYPE, path_in_repo=f"{prefix}/{subdir}"
        )
        urls = []
        for entry in entries:
            # Use the first name-bearing attribute the entry exposes;
            # entries with neither attribute are skipped.
            for attr in ("rfilename", "path"):
                if hasattr(entry, attr):
                    name = getattr(entry, attr)
                    if Path(name).suffix.lower() in IMG_EXTS:
                        urls.append(file_url(name))
                    break
        urls.sort()
        return urls
    except Exception:
        return list_local_images(LOCAL_DATA / prefix / subdir)
# Common prefix of direct "resolve/main" URLs into the dataset repo; used to
# recognise such URLs and to strip them back to a repo-relative path.
HF_URL_PREFIX = f"https://huggingface.co/datasets/{DATA_REPO}/resolve/main/"
def is_dataset_url(url):
    """Return True when ``url`` is a string starting with ``HF_URL_PREFIX``."""
    if not isinstance(url, str):
        return False
    return url.startswith(HF_URL_PREFIX)
def download_to_local(path_or_url):
    """Resolve a path or URL to something available on local disk.

    - Non-string values and strings that are not http(s) URLs are returned
      unchanged.
    - Dataset-repo URLs are fetched with ``hf_hub_download`` and the path it
      returns is passed back.
    - Any other http(s) URL is fetched with ``requests`` (30 s timeout),
      re-saved as a JPEG under ``LOCAL_DATA / "tmp"`` and that path is
      returned as a str.
    """
    if not isinstance(path_or_url, str):
        return path_or_url

    if is_dataset_url(path_or_url):
        from huggingface_hub import hf_hub_download

        return hf_hub_download(
            repo_id=DATA_REPO,
            repo_type=REPO_TYPE,
            filename=path_or_url[len(HF_URL_PREFIX):],
            token=DATASET_HF_TOKEN,
        )

    if not path_or_url.startswith(("http://", "https://")):
        return path_or_url

    import requests
    from io import BytesIO

    response = requests.get(path_or_url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    target = LOCAL_DATA / "tmp" / f"{generate_id()}.jpg"
    target.parent.mkdir(parents=True, exist_ok=True)
    save_image(image, target)
    return str(target)
def load_image_sets(prefix):
    """Collect portrait/garment/result sets stored under ``LOCAL_DATA / prefix``.

    In remote mode the prefix is first downloaded from the dataset repo. Every
    image in ``portraits/`` whose name parses as ``<id>_<type>`` defines a set;
    the matching garment is ``garments/<id>_garment.jpg``.

    The result is looked up under two names: the single-ID form
    ``<id>_result.jpg`` and the form written by ``save_image_set``
    (``make_result_filename(id, id)``). Previously only the first was checked,
    so results saved by ``save_image_set`` were never found.

    Args:
        prefix: Sub-directory of ``LOCAL_DATA`` (and of the repo) to scan.

    Returns:
        list[dict]: One dict per set that has a garment, with keys ``"id"``,
        ``"portrait"``, ``"garment"`` (paths as str) and ``"result"`` (path as
        str, or None when no result file exists). Empty when the portraits
        directory is missing.
    """
    local_prefix = LOCAL_DATA / prefix
    if is_remote():
        download_dir(prefix)
    portraits_dir = local_prefix / "portraits"
    if not portraits_dir.exists():
        return []

    sets = {}
    for p in portraits_dir.iterdir():
        if p.suffix.lower() not in IMG_EXTS:
            continue
        parsed = parse_filename(p.name)
        if not parsed:
            continue
        item_id = parsed["id"]
        sets[item_id] = {
            "id": item_id,
            "portrait": str(p),
        }

    garments_dir = local_prefix / "garments"
    results_dir = local_prefix / "results"
    for item_id, entry in sets.items():
        garment = garments_dir / f"{item_id}_garment.jpg"
        entry["garment"] = str(garment) if garment.exists() else None

        # Prefer the single-ID name so existing data resolves exactly as
        # before, then fall back to the name save_image_set produces.
        result_candidates = (
            results_dir / f"{item_id}_result.jpg",
            results_dir / make_result_filename(item_id, item_id),
        )
        entry["result"] = next(
            (str(c) for c in result_candidates if c.exists()), None
        )

    return [v for v in sets.values() if v["garment"] is not None]
def save_image_set(prefix, img_portrait, img_garment, img_result=None):
    """Store a portrait, a garment and optionally a result under one new id.

    Both required images are saved locally first and then uploaded; the
    result, when given, is saved and uploaded afterwards.

    Returns a tuple ``(item_id, portrait_path, garment_path, result_path)``
    with paths as str; ``result_path`` is None when no result was given.
    """
    item_id = generate_id()
    base = LOCAL_DATA / prefix

    portrait_name = make_filename(item_id, "portrait")
    portrait_path = base / "portraits" / portrait_name
    garment_name = make_filename(item_id, "garment")
    garment_path = base / "garments" / garment_name

    save_image(img_portrait, portrait_path)
    save_image(img_garment, garment_path)
    upload_image(portrait_path, f"{prefix}/portraits/{portrait_name}")
    upload_image(garment_path, f"{prefix}/garments/{garment_name}")

    if img_result is None:
        return item_id, str(portrait_path), str(garment_path), None

    # Portrait and garment share one id in a set, so it fills both slots.
    result_name = make_result_filename(item_id, item_id)
    result_path = base / "results" / result_name
    save_image(img_result, result_path)
    upload_image(result_path, f"{prefix}/results/{result_name}")
    return item_id, str(portrait_path), str(garment_path), str(result_path)
def save_result(prefix, portrait_id, garment_id, img_result):
    """Save and upload a result image named after its portrait and garment ids.

    Returns the local path of the saved file as a str.
    """
    filename = make_result_filename(portrait_id, garment_id)
    target = LOCAL_DATA / prefix / "results" / filename
    save_image(img_result, target)
    upload_image(target, f"{prefix}/results/{filename}")
    return str(target)
def save_multi_result(prefix, portrait_id, assignments, img_result):
    """Save and upload a multi-garment result image.

    ``assignments`` is passed to ``make_multi_result_filename`` as the
    per-person garment list. Returns the local path as a str.
    """
    filename = make_multi_result_filename(portrait_id, assignments)
    target = LOCAL_DATA / prefix / "results" / filename
    save_image(img_result, target)
    upload_image(target, f"{prefix}/results/{filename}")
    return str(target)
def delete_image_set(prefix, item_id):
    """Delete every file belonging to an image set, locally and remotely.

    Scans ``portraits/``, ``garments/`` and ``results/`` under
    ``LOCAL_DATA / prefix`` and removes each file whose stem is ``item_id``
    or starts with ``item_id + "_"`` (this also catches multi-garment result
    files). In remote mode the matching repo file is deleted too, on a
    best-effort basis: a failed remote delete is logged and does not stop the
    remaining deletions.

    Args:
        prefix: Sub-directory of ``LOCAL_DATA`` (and of the repo) to clean.
        item_id: Id of the set to remove. An empty or None id is ignored.
    """
    if not item_id:
        # An empty id would match every file in all three directories.
        return

    import logging

    # Require the "_" boundary so that deleting id "abc" cannot also remove
    # files of a different id that merely starts with "abc".
    id_prefix = f"{item_id}_"
    local_prefix = LOCAL_DATA / prefix
    for subdir in ("portraits", "garments", "results"):
        d = local_prefix / subdir
        if not d.exists():
            continue
        # Snapshot the listing: entries are removed while we walk it.
        for f in list(d.iterdir()):
            if f.stem != item_id and not f.stem.startswith(id_prefix):
                continue
            f.unlink()
            if is_remote():
                try:
                    delete_remote_file(f"{prefix}/{subdir}/{f.name}")
                except Exception:
                    logging.getLogger(__name__).warning(
                        "Could not delete remote file %s/%s/%s",
                        prefix, subdir, f.name, exc_info=True,
                    )
def promote_to_example(result_path):
    """Copy a result file into ``examples/results`` under the same filename.

    The copy is made locally with ``shutil.copy2`` and then uploaded to
    ``examples/results/<filename>``. Returns the filename stem.
    """
    source = Path(result_path)
    filename = source.name
    target = LOCAL_DATA / "examples" / "results" / filename
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(str(source), str(target))
    upload_image(target, f"examples/results/{filename}")
    return source.stem