import os
import json
from typing import List, Dict, Any, Tuple
import torch
from torch.utils.data import Dataset
from PIL import Image
from utils.transforms import build_train_transforms
from pathlib import Path

# Image extensions accepted when resolving an item id to a file on disk.
_IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".webp")


def _load_split(path: str, kind: str) -> List[Dict[str, Any]]:
    """Load a JSON split file as a list of sample dicts.

    Args:
        path: Full path to the JSON file.
        kind: Human-readable file description used in the error message
              (kept identical to the original per-class messages).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"{kind} not found: {path}")
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _find_image_path(root: str, item_id: str) -> str:
    """Resolve ``item_id`` to an image file under ``<root>/images``.

    First tries exact ``<item_id><ext>`` matches for the common extensions,
    then falls back to a recursive substring search. The fallback is sorted
    so repeated runs resolve the same file regardless of OS directory order.

    NOTE(review): the substring fallback can match the wrong file when one
    id is a substring of another (e.g. "1" also matches "11"); exact matches
    are always preferred.

    Raises:
        FileNotFoundError: If no candidate image exists.
    """
    base = os.path.join(root, "images")
    # Fast path: direct filename with a known extension.
    for ext in _IMAGE_EXTS:
        candidate = os.path.join(base, f"{item_id}{ext}")
        if os.path.isfile(candidate):
            return candidate
    # Slow path: recursive fuzzy search (deterministic via sorted()).
    for p in sorted(Path(base).rglob(f"*{item_id}*")):
        if p.suffix.lower() in _IMAGE_EXTS:
            return str(p)
    raise FileNotFoundError(f"Image for item {item_id} not found under {base}")


def _load_image(root: str, item_id: str) -> Image.Image:
    """Open the image for ``item_id`` and normalize it to RGB."""
    return Image.open(_find_image_path(root, item_id)).convert("RGB")


class PolyvoreTripletDataset(Dataset):
    """
    Creates (anchor, positive, negative) image triplets for training the
    ResNet embedder. Assumes a JSON list or multiple files that describe
    compatible pairs/sets and item image paths.

    Expected structure (customize as needed):
        root/
            images/<item_id>.jpg
            splits/train.json  # [{"anchor": id, "positive": id, "negative": id}, ...]
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        super().__init__()
        self.root = root
        self.split = split
        self.transforms = build_train_transforms(image_size=image_size)
        triplet_path = os.path.join(root, "splits", f"{split}.json")
        self.samples: List[Dict[str, Any]] = _load_split(triplet_path, "Triplet file")

    def _find_image_path(self, item_id: str) -> str:
        # Kept as a method for backward compatibility; the resolution logic
        # now lives in the module-level helper shared by all datasets here.
        return _find_image_path(self.root, item_id)

    def _load_image(self, item_id: str) -> Image.Image:
        return _load_image(self.root, item_id)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        s = self.samples[idx]
        a = self._load_image(str(s["anchor"]))
        p = self._load_image(str(s["positive"]))
        n = self._load_image(str(s["negative"]))
        return self.transforms(a), self.transforms(p), self.transforms(n)


class PolyvoreOutfitDataset(Dataset):
    """
    Produces (tokens, label) where tokens is a sequence of item embeddings
    or images preprocessed downstream. For simplicity here we return a list
    of image tensors to be embedded externally or pre-embedded offline.

    Expected structure:
        root/
            images/<item_id>.jpg
            splits/outfits_train.json  # [{"items": [id1, id2, ...], "label": 1}, ...]
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        super().__init__()
        self.root = root
        self.split = split
        self.transforms = build_train_transforms(image_size=image_size)
        outfit_path = os.path.join(root, "splits", f"outfits_{split}.json")
        self.samples: List[Dict[str, Any]] = _load_split(outfit_path, "Outfit file")
        # Enforce outfit slot constraints (upper, bottom, shoes, accessory)
        # only when metadata is available. If metadata isn't available, we
        # rely on count >= 3 and let the model learn; here, keep as-is.

    def _load_image(self, item_id: str) -> Image.Image:
        # Fix: the original borrowed PolyvoreTripletDataset._find_image_path
        # with a foreign `self`; use the shared module-level helper instead.
        return _load_image(self.root, item_id)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        s = self.samples[idx]
        imgs = [self.transforms(self._load_image(str(i))) for i in s["items"]]
        label = torch.tensor(int(s.get("label", 1)), dtype=torch.long)
        # Returns list of tensors; training loop can embed then pack to (N, D).
        return imgs, label


class PolyvoreOutfitTripletDataset(Dataset):
    """
    Outfit-level triplets for ViT triplet training: (good1, good2, bad).

    Expects file ``outfit_triplets_<split>.json`` with entries:
        {"good_a": [id...], "good_b": [id...], "bad": [id...]}
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        super().__init__()
        self.root = root
        self.split = split
        self.transforms = build_train_transforms(image_size=image_size)
        trip_path = os.path.join(root, "splits", f"outfit_triplets_{split}.json")
        self.samples: List[Dict[str, Any]] = _load_split(trip_path, "Outfit triplet file")

    def _load_image(self, item_id: str) -> Image.Image:
        # Fix: avoid cross-class method borrowing; use the shared helper.
        return _load_image(self.root, item_id)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        s = self.samples[idx]
        ga = [self.transforms(self._load_image(str(i))) for i in s["good_a"]]
        gb = [self.transforms(self._load_image(str(i))) for i in s["good_b"]]
        bd = [self.transforms(self._load_image(str(i))) for i in s["bad"]]
        return ga, gb, bd