# recomendation/data/polyvore.py
# Ali Mohsin
# more fix
# 55c158e
import os
import json
from typing import List, Dict, Any, Tuple
import torch
from torch.utils.data import Dataset
from PIL import Image
from utils.transforms import build_train_transforms
from pathlib import Path
class PolyvoreTripletDataset(Dataset):
    """Item-level (anchor, positive, negative) image triplets.

    Intended for training the ResNet embedder. Triplets are read from a JSON
    file and every referenced item id is resolved to an image file on disk.

    Expected structure (customize as needed):
        root/
            images/<item_id>.jpg
            splits/<split>.json  # [{"anchor": id, "positive": id, "negative": id}, ...]
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        """Load the triplet list for ``split``.

        Args:
            root: Dataset root containing ``images/`` and ``splits/``.
            split: Split name; ``splits/<split>.json`` is read.
            image_size: Forwarded to ``build_train_transforms``.

        Raises:
            FileNotFoundError: If the triplet file does not exist.
        """
        super().__init__()
        self.root = root
        self.split = split
        # NOTE(review): transforms come from `build_train_transforms` whatever
        # `split` is -- confirm that is intended for non-train splits.
        self.transforms = build_train_transforms(image_size=image_size)

        triplet_path = os.path.join(root, "splits", f"{split}.json")
        if not os.path.exists(triplet_path):
            raise FileNotFoundError(f"Triplet file not found: {triplet_path}")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(triplet_path, "r", encoding="utf-8") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)

    def _find_image_path(self, item_id: str) -> str:
        """Resolve ``item_id`` to an image file under ``<root>/images``.

        Lookup order:
            1. ``images/<item_id><ext>`` for the common extensions.
            2. Recursive search: a file whose stem equals ``item_id``.
            3. Recursive search: a file whose name merely contains ``item_id``
               (the smallest path is chosen so the result is reproducible).

        Only ``self.root`` is read, so the method also works when invoked
        unbound with any object exposing ``root``, as the outfit datasets in
        this module do.

        Raises:
            FileNotFoundError: If no matching image file exists.
        """
        extensions = (".jpg", ".jpeg", ".png", ".webp")
        base = os.path.join(self.root, "images")

        # Fast path: the image sits directly in images/ with a common extension.
        for ext in extensions:
            candidate = os.path.join(base, f"{item_id}{ext}")
            if os.path.isfile(candidate):
                return candidate

        # Slow path: walk the tree. An exact stem match wins immediately; a
        # substring match (e.g. id "12" inside "3125.jpg") is only a fallback,
        # otherwise a different item's image could be returned silently.
        fallback = None
        for path in Path(base).rglob(f"*{item_id}*"):
            if path.suffix.lower() not in extensions or not path.is_file():
                continue  # skip non-images and directories named like images
            if path.stem == item_id:
                return str(path)
            if fallback is None or path < fallback:
                fallback = path
        if fallback is not None:
            return str(fallback)

        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")

    def _load_image(self, item_id: str) -> Image.Image:
        """Open the image for ``item_id`` and return it converted to RGB."""
        img_path = self._find_image_path(item_id)
        # `convert` returns a new image, so the file handle can be released
        # as soon as the conversion is done.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    def __len__(self) -> int:
        """Return the number of triplets in the split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any, Any]:
        """Return the transformed (anchor, positive, negative) images at ``idx``."""
        sample = self.samples[idx]
        # Ids may be stored as ints in the JSON; paths are built from strings.
        anchor = self._load_image(str(sample["anchor"]))
        positive = self._load_image(str(sample["positive"]))
        negative = self._load_image(str(sample["negative"]))
        return (
            self.transforms(anchor),
            self.transforms(positive),
            self.transforms(negative),
        )
class PolyvoreOutfitDataset(Dataset):
    """Outfit-level samples: a list of transformed item images plus a label.

    The images are returned as a plain list so they can be embedded externally
    (or replaced by embeddings computed offline) before being packed into a
    sequence for the downstream model.

    Expected structure:
        root/
            images/<item_id>.jpg
            splits/outfits_<split>.json  # [{"items": [id1, id2, ...], "label": 1}, ...]
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        """Load the outfit list for ``split``.

        Args:
            root: Dataset root containing ``images/`` and ``splits/``.
            split: Split name; ``splits/outfits_<split>.json`` is read.
            image_size: Forwarded to ``build_train_transforms``.

        Raises:
            FileNotFoundError: If the outfit file does not exist.
        """
        super().__init__()
        self.root = root
        self.split = split
        # NOTE(review): transforms come from `build_train_transforms` whatever
        # `split` is -- confirm that is intended for non-train splits.
        self.transforms = build_train_transforms(image_size=image_size)

        outfit_path = os.path.join(root, "splits", f"outfits_{split}.json")
        if not os.path.exists(outfit_path):
            raise FileNotFoundError(f"Outfit file not found: {outfit_path}")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(outfit_path, "r", encoding="utf-8") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)
        # TODO: enforce outfit slot constraints (at least upper, bottom, shoes,
        # accessory) when item metadata is available. Without metadata the
        # intent is to rely on count >= 3 and let the model learn. No filtering
        # is applied here; samples are kept exactly as loaded.

    def _load_image(self, item_id: str) -> Image.Image:
        """Open the image for ``item_id`` and return it converted to RGB."""
        # Reuse the lookup of PolyvoreTripletDataset; it is called unbound with
        # this instance because it only needs `self.root`.
        img_path = PolyvoreTripletDataset._find_image_path(self, item_id)
        # `convert` returns a new image, so the file handle can be released
        # as soon as the conversion is done.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    def __len__(self) -> int:
        """Return the number of outfits in the split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[List[Any], torch.Tensor]:
        """Return ``(images, label)`` for the outfit at ``idx``.

        ``images`` holds one transformed image per entry of ``"items"``;
        ``label`` is a scalar ``torch.long`` tensor that defaults to 1 when the
        sample has no ``"label"`` key.
        """
        sample = self.samples[idx]
        # Ids may be stored as ints in the JSON; paths are built from strings.
        imgs = [
            self.transforms(self._load_image(str(item_id)))
            for item_id in sample["items"]
        ]
        label = torch.tensor(int(sample.get("label", 1)), dtype=torch.long)
        # A list is returned; the training loop can embed then pack to (N, D).
        return imgs, label
class PolyvoreOutfitTripletDataset(Dataset):
    """Outfit-level triplets for ViT triplet training: (good1, good2, bad).

    Expects the file ``splits/outfit_triplets_<split>.json`` with entries:
        {"good_a": [id...], "good_b": [id...], "bad": [id...]}
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        """Load the outfit triplet list for ``split``.

        Args:
            root: Dataset root containing ``images/`` and ``splits/``.
            split: Split name; ``splits/outfit_triplets_<split>.json`` is read.
            image_size: Forwarded to ``build_train_transforms``.

        Raises:
            FileNotFoundError: If the outfit triplet file does not exist.
        """
        super().__init__()
        self.root = root
        self.split = split
        # NOTE(review): transforms come from `build_train_transforms` whatever
        # `split` is -- confirm that is intended for non-train splits.
        self.transforms = build_train_transforms(image_size=image_size)

        trip_path = os.path.join(root, "splits", f"outfit_triplets_{split}.json")
        if not os.path.exists(trip_path):
            raise FileNotFoundError(f"Outfit triplet file not found: {trip_path}")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(trip_path, "r", encoding="utf-8") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)

    def _load_image(self, item_id: str) -> Image.Image:
        """Open the image for ``item_id`` and return it converted to RGB."""
        # Reuse the lookup of PolyvoreTripletDataset; it is called unbound with
        # this instance because it only needs `self.root`.
        img_path = PolyvoreTripletDataset._find_image_path(self, item_id)
        # `convert` returns a new image, so the file handle can be released
        # as soon as the conversion is done.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    def __len__(self) -> int:
        """Return the number of outfit triplets in the split."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[List[Any], List[Any], List[Any]]:
        """Return ``(good_a, good_b, bad)`` for the triplet at ``idx``.

        Each element is a list with one transformed image per item id of the
        corresponding outfit.
        """
        sample = self.samples[idx]

        def load_outfit(key: str) -> List[Any]:
            # Ids may be stored as ints in the JSON; paths are built from strings.
            return [
                self.transforms(self._load_image(str(item_id)))
                for item_id in sample[key]
            ]

        return load_outfit("good_a"), load_outfit("good_b"), load_outfit("bad")