Spaces:
Paused
Paused
| import os | |
| import json | |
| from typing import List, Dict, Any, Tuple | |
| import torch | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
| from utils.transforms import build_train_transforms | |
| from pathlib import Path | |
class PolyvoreTripletDataset(Dataset):
    """
    Creates (anchor, positive, negative) image triplets for training the ResNet embedder.

    Assumes a JSON list that describes compatible pairs/sets and item image paths.
    Expected structure (customize as needed):
        root/
          images/<item_id>.jpg
          splits/train.json   # [{"anchor": id, "positive": id, "negative": id}, ...]
    """

    # Extensions considered when resolving an item id to an image file.
    IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp")

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        super().__init__()
        self.root = root
        self.split = split
        # NOTE(review): train-time augmentations are applied regardless of split —
        # confirm whether eval splits should use deterministic transforms instead.
        self.transforms = build_train_transforms(image_size=image_size)
        triplet_path = os.path.join(root, "splits", f"{split}.json")
        if not os.path.exists(triplet_path):
            raise FileNotFoundError(f"Triplet file not found: {triplet_path}")
        with open(triplet_path, "r") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)

    def _find_image_path(self, item_id: str) -> str:
        """Resolve an item id to an image path under <root>/images.

        Tries <images>/<item_id><ext> directly first, then a recursive search.
        The recursive search prefers an exact filename-stem match so that e.g.
        id "12" cannot accidentally resolve to "123.jpg" (the previous
        fuzzy-only search could return the wrong item when one id was a
        substring of another); a fuzzy substring hit is kept as a last resort.

        Raises:
            FileNotFoundError: if no matching image exists.
        """
        base = os.path.join(self.root, "images")
        # direct common extensions
        for ext in self.IMAGE_EXTS:
            p = os.path.join(base, f"{item_id}{ext}")
            if os.path.isfile(p):
                return p
        # recursive search: exact stem preferred, fuzzy substring as fallback
        fuzzy_match = None
        for p in Path(base).rglob(f"*{item_id}*"):
            if p.suffix.lower() not in self.IMAGE_EXTS:
                continue
            if p.stem == item_id:
                return str(p)
            if fuzzy_match is None:
                fuzzy_match = str(p)
        if fuzzy_match is not None:
            return fuzzy_match
        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")

    def _load_image(self, item_id: str) -> Image.Image:
        """Open the item's image as RGB, closing the file handle eagerly."""
        img_path = self._find_image_path(item_id)
        # Context manager releases the underlying file handle (PIL opens
        # lazily and would otherwise keep it alive); convert() forces the
        # pixel data to be read before the file is closed.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return transformed (anchor, positive, negative) image tensors."""
        s = self.samples[idx]
        a = self._load_image(str(s["anchor"]))
        p = self._load_image(str(s["positive"]))
        n = self._load_image(str(s["negative"]))
        return self.transforms(a), self.transforms(p), self.transforms(n)
class PolyvoreOutfitDataset(Dataset):
    """
    Produces (tokens, label) where tokens is a sequence of item images to be
    embedded downstream (externally or pre-embedded offline).

    Expected structure:
        root/
          images/<item_id>.jpg
          splits/outfits_train.json  # [{"items": [id1,id2,...], "label": 1}, ...]
    """

    # Extensions considered when resolving an item id to an image file.
    IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp")

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        super().__init__()
        self.root = root
        self.split = split
        # NOTE(review): train-time augmentations are applied regardless of split —
        # confirm whether eval splits should use deterministic transforms instead.
        self.transforms = build_train_transforms(image_size=image_size)
        outfit_path = os.path.join(root, "splits", f"outfits_{split}.json")
        if not os.path.exists(outfit_path):
            raise FileNotFoundError(f"Outfit file not found: {outfit_path}")
        with open(outfit_path, "r") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)
        # enforce outfit slot constraints: require at least upper, bottom, shoes, accessory if metadata available
        # If metadata isn't available, we will rely on count >= 3 and let model learn; here, keep as-is.

    def _find_image_path(self, item_id: str) -> str:
        """Resolve an item id to an image file under <root>/images.

        Own implementation — this previously borrowed
        PolyvoreTripletDataset._find_image_path via an unbound cross-class
        call, a hidden coupling that breaks if that class changes.

        Raises:
            FileNotFoundError: if no matching image exists.
        """
        base = os.path.join(self.root, "images")
        # direct common extensions
        for ext in self.IMAGE_EXTS:
            candidate = os.path.join(base, f"{item_id}{ext}")
            if os.path.isfile(candidate):
                return candidate
        # recursive search; prefer exact stem so id "12" never resolves to "123.jpg"
        fuzzy_match = None
        for p in Path(base).rglob(f"*{item_id}*"):
            if p.suffix.lower() not in self.IMAGE_EXTS:
                continue
            if p.stem == item_id:
                return str(p)
            if fuzzy_match is None:
                fuzzy_match = str(p)
        if fuzzy_match is not None:
            return fuzzy_match
        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")

    def _load_image(self, item_id: str) -> Image.Image:
        """Open the item's image as RGB, closing the file handle eagerly."""
        img_path = self._find_image_path(item_id)
        # Context manager releases the file handle; convert() forces the read.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return (list of transformed image tensors, long-tensor label)."""
        s = self.samples[idx]
        imgs = [self.transforms(self._load_image(str(i))) for i in s["items"]]
        # Missing labels default to 1 (compatible), matching the original contract.
        label = torch.tensor(int(s.get("label", 1)), dtype=torch.long)
        # Returns list of tensors; training loop can embed then pack to (N,D)
        return imgs, label
class PolyvoreOutfitTripletDataset(Dataset):
    """
    Outfit-level triplets for ViT triplet training: (good1, good2, bad).

    Expects file `outfit_triplets_<split>.json` with entries:
        {"good_a": [id...], "good_b": [id...], "bad": [id...]}
    """

    # Extensions considered when resolving an item id to an image file.
    IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp")

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        super().__init__()
        self.root = root
        self.split = split
        # NOTE(review): train-time augmentations are applied regardless of split —
        # confirm whether eval splits should use deterministic transforms instead.
        self.transforms = build_train_transforms(image_size=image_size)
        trip_path = os.path.join(root, "splits", f"outfit_triplets_{split}.json")
        if not os.path.exists(trip_path):
            raise FileNotFoundError(f"Outfit triplet file not found: {trip_path}")
        with open(trip_path, "r") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)

    def _find_image_path(self, item_id: str) -> str:
        """Resolve an item id to an image file under <root>/images.

        Own implementation — this previously borrowed
        PolyvoreTripletDataset._find_image_path via an unbound cross-class
        call, a hidden coupling that breaks if that class changes.

        Raises:
            FileNotFoundError: if no matching image exists.
        """
        base = os.path.join(self.root, "images")
        # direct common extensions
        for ext in self.IMAGE_EXTS:
            candidate = os.path.join(base, f"{item_id}{ext}")
            if os.path.isfile(candidate):
                return candidate
        # recursive search; prefer exact stem so id "12" never resolves to "123.jpg"
        fuzzy_match = None
        for p in Path(base).rglob(f"*{item_id}*"):
            if p.suffix.lower() not in self.IMAGE_EXTS:
                continue
            if p.stem == item_id:
                return str(p)
            if fuzzy_match is None:
                fuzzy_match = str(p)
        if fuzzy_match is not None:
            return fuzzy_match
        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")

    def _load_image(self, item_id: str) -> Image.Image:
        """Open the item's image as RGB, closing the file handle eagerly."""
        img_path = self._find_image_path(item_id)
        # Context manager releases the file handle; convert() forces the read.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return three lists of transformed image tensors: (good_a, good_b, bad)."""
        s = self.samples[idx]
        ga = [self.transforms(self._load_image(str(i))) for i in s["good_a"]]
        gb = [self.transforms(self._load_image(str(i))) for i in s["good_b"]]
        bd = [self.transforms(self._load_image(str(i))) for i in s["bad"]]
        return ga, gb, bd