import json
import os
from pathlib import Path
from typing import Any, Dict, List

import torch
from torch.utils.data import Dataset
from PIL import Image

from utils.transforms import build_train_transforms

class PolyvoreTripletDataset(Dataset):
"""
Creates (anchor, positive, negative) image triplets for training the ResNet embedder.
Assumes a JSON list or multiple files that describe compatible pairs/sets and item image paths.
Expected structure (customize as needed):
root/
images/<item_id>.jpg
splits/train.json # [{"anchor": id, "positive": id, "negative": id}, ...]
"""
def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
super().__init__()
self.root = root
self.split = split
self.transforms = build_train_transforms(image_size=image_size)
triplet_path = os.path.join(root, "splits", f"{split}.json")
if not os.path.exists(triplet_path):
raise FileNotFoundError(f"Triplet file not found: {triplet_path}")
with open(triplet_path, "r") as f:
self.samples: List[Dict[str, Any]] = json.load(f)
    @staticmethod
    def _find_image_path(root: str, item_id: str) -> str:
        base = os.path.join(root, "images")
        # Fast path: try the common extensions directly under images/.
        for ext in (".jpg", ".jpeg", ".png", ".webp"):
            p = os.path.join(base, f"{item_id}{ext}")
            if os.path.isfile(p):
                return p
        # Fallback: recursive fuzzy search for nested or prefixed filenames.
        for p in Path(base).rglob(f"*{item_id}*"):
            if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp"):
                return str(p)
        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")
def _load_image(self, item_id: str) -> Image.Image:
        img_path = self._find_image_path(self.root, item_id)
return Image.open(img_path).convert("RGB")
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
s = self.samples[idx]
a = self._load_image(str(s["anchor"]))
p = self._load_image(str(s["positive"]))
n = self._load_image(str(s["negative"]))
return self.transforms(a), self.transforms(p), self.transforms(n)
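
# --- Usage sketch (illustrative only, hence commented out): wiring this dataset
# into a DataLoader with torch.nn.TripletMarginLoss. The ResNet-18 trunk and the
# "data/polyvore" path are assumptions for the example, not fixed by this file.
#
#     import torch.nn as nn
#     from torch.utils.data import DataLoader
#     from torchvision.models import resnet18
#
#     embedder = nn.Sequential(*list(resnet18(weights=None).children())[:-1], nn.Flatten())
#     dataset = PolyvoreTripletDataset(root="data/polyvore", split="train")
#     loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
#     criterion = nn.TripletMarginLoss(margin=0.2)
#     for anchor, positive, negative in loader:
#         loss = criterion(embedder(anchor), embedder(positive), embedder(negative))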
class PolyvoreOutfitDataset(Dataset):
"""
Produces (tokens, label) where tokens is a sequence of item embeddings or images preprocessed downstream.
For simplicity here we return a list of image tensors to be embedded externally or pre-embedded offline.
Expected structure:
root/
images/<item_id>.jpg
splits/outfits_train.json # [{"items": [id1,id2,...], "label": 1}, ...]
"""
def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
super().__init__()
self.root = root
self.split = split
self.transforms = build_train_transforms(image_size=image_size)
outfit_path = os.path.join(root, "splits", f"outfits_{split}.json")
if not os.path.exists(outfit_path):
raise FileNotFoundError(f"Outfit file not found: {outfit_path}")
with open(outfit_path, "r") as f:
self.samples: List[Dict[str, Any]] = json.load(f)
        # Slot constraints (e.g. requiring a top, bottom, shoes, and accessory) could be
        # enforced here when category metadata is available; without metadata, outfits
        # are kept as-is and the model learns compatibility directly from the items.
def _load_image(self, item_id: str) -> Image.Image:
        # Reuse the path-resolution logic from PolyvoreTripletDataset.
        img_path = PolyvoreTripletDataset._find_image_path(self.root, item_id)
return Image.open(img_path).convert("RGB")
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
s = self.samples[idx]
imgs = [self.transforms(self._load_image(str(i))) for i in s["items"]]
label = torch.tensor(int(s.get("label", 1)), dtype=torch.long)
        # Returns a ragged list of image tensors; the training loop can embed each item
        # and pack the outfit to (N, D). See the collate sketch after this class.
return imgs, label
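
# A minimal collate sketch, assuming the training loop embeds items itself: outfits
# have variable length, so the default collate cannot stack them. This helper is an
# illustrative addition (not part of the original splits contract) that keeps each
# outfit as a ragged list of image tensors and stacks only the labels.
def outfit_collate(batch):
    outfits = [imgs for imgs, _ in batch]                # ragged List[List[Tensor]]
    labels = torch.stack([label for _, label in batch])  # (B,) long tensor
    return outfits, labels
# Example: DataLoader(PolyvoreOutfitDataset("data/polyvore"), batch_size=16,
#                     collate_fn=outfit_collate)  # path is a placeholder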
class PolyvoreOutfitTripletDataset(Dataset):
"""
Outfit-level triplets for ViT triplet training: (good1, good2, bad).
Expects file `outfit_triplets_<split>.json` with entries:
{"good_a": [id...], "good_b": [id...], "bad": [id...]}
"""
def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
super().__init__()
self.root = root
self.split = split
self.transforms = build_train_transforms(image_size=image_size)
trip_path = os.path.join(root, "splits", f"outfit_triplets_{split}.json")
if not os.path.exists(trip_path):
raise FileNotFoundError(f"Outfit triplet file not found: {trip_path}")
with open(trip_path, "r") as f:
self.samples: List[Dict[str, Any]] = json.load(f)
def _load_image(self, item_id: str) -> Image.Image:
        # Same path-resolution logic as the other datasets.
        img_path = PolyvoreTripletDataset._find_image_path(self.root, item_id)
return Image.open(img_path).convert("RGB")
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int):
s = self.samples[idx]
ga = [self.transforms(self._load_image(str(i))) for i in s["good_a"]]
gb = [self.transforms(self._load_image(str(i))) for i in s["good_b"]]
bd = [self.transforms(self._load_image(str(i))) for i in s["bad"]]
return ga, gb, bd
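
# Analogous collate sketch for outfit triplets (again an illustrative helper, not
# original API): torch.stack folds each outfit into one (N, C, H, W) tensor, and
# each branch of the batch stays a list because N varies across outfits.
def outfit_triplet_collate(batch):
    good_a = [torch.stack(sample[0]) for sample in batch]
    good_b = [torch.stack(sample[1]) for sample in batch]
    bad = [torch.stack(sample[2]) for sample in batch]
    return good_a, good_b, bad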