File size: 5,402 Bytes
4716563
 
 
 
 
 
 
 
 
fac18b7
4716563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fac18b7
 
 
 
 
 
 
 
 
 
 
 
 
4716563
fac18b7
4716563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fac18b7
4716563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55c158e
4716563
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import json
from typing import List, Dict, Any, Tuple

import torch
from torch.utils.data import Dataset
from PIL import Image

from utils.transforms import build_train_transforms
from pathlib import Path


class PolyvoreTripletDataset(Dataset):
    """
    Serves (anchor, positive, negative) image triplets for training the ResNet embedder.
    Assumes a JSON list or multiple files that describe compatible pairs/sets and item image paths.

    Expected on-disk layout (customize as needed):
      root/
        images/<item_id>.jpg
        splits/train.json  # [{"anchor": id, "positive": id, "negative": id}, ...]
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        """Load the triplet list for *split* from <root>/splits/<split>.json.

        Raises FileNotFoundError when the split file is missing.
        """
        super().__init__()
        self.root = root
        self.split = split
        self.transforms = build_train_transforms(image_size=image_size)
        triplet_path = os.path.join(root, "splits", f"{split}.json")
        if not os.path.exists(triplet_path):
            raise FileNotFoundError(f"Triplet file not found: {triplet_path}")
        with open(triplet_path, "r") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)

    def _find_image_path(self, item_id: str) -> str:
        """Resolve *item_id* to an image file under <root>/images.

        Probes the common extensions directly first; if nothing matches, falls
        back to a recursive substring search of the images tree. Raises
        FileNotFoundError when no candidate exists.
        """
        base = os.path.join(self.root, "images")
        exts = (".jpg", ".jpeg", ".png", ".webp")
        # Fast path: <base>/<item_id><ext> for each known extension.
        for ext in exts:
            candidate = os.path.join(base, f"{item_id}{ext}")
            if os.path.isfile(candidate):
                return candidate
        # Slow path: first recursive match whose name contains the id.
        hit = next(
            (p for p in Path(base).rglob(f"*{item_id}*") if p.suffix.lower() in exts),
            None,
        )
        if hit is not None:
            return str(hit)
        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")

    def _load_image(self, item_id: str) -> "Image.Image":
        """Open the item's image and normalize it to 3-channel RGB."""
        return Image.open(self._find_image_path(item_id)).convert("RGB")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return the transformed (anchor, positive, negative) image tensors."""
        record = self.samples[idx]
        return tuple(
            self.transforms(self._load_image(str(record[role])))
            for role in ("anchor", "positive", "negative")
        )


class PolyvoreOutfitDataset(Dataset):
    """
    Produces (tokens, label) where tokens is a sequence of item embeddings or images preprocessed downstream.
    For simplicity here we return a list of image tensors to be embedded externally or pre-embedded offline.

    Expected structure:
      root/
        images/<item_id>.jpg
        splits/outfits_train.json  # [{"items": [id1,id2,...], "label": 1}, ...]
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        """Load the outfit list for *split* from <root>/splits/outfits_<split>.json.

        Raises FileNotFoundError when the split file is missing.
        """
        super().__init__()
        self.root = root
        self.split = split
        self.transforms = build_train_transforms(image_size=image_size)
        outfit_path = os.path.join(root, "splits", f"outfits_{split}.json")
        if not os.path.exists(outfit_path):
            raise FileNotFoundError(f"Outfit file not found: {outfit_path}")
        with open(outfit_path, "r") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)
        # enforce outfit slot constraints: require at least upper, bottom, shoes, accessory if metadata available
        # If metadata isn't available, we will rely on count >= 3 and let model learn; here, keep as-is.

    def _find_image_path(self, item_id: str) -> str:
        """Resolve *item_id* to an image file under <root>/images.

        NOTE(review): this class previously borrowed the lookup via an unbound
        cross-class call (`PolyvoreTripletDataset._find_image_path(self, ...)`),
        which type-checkers reject and which silently breaks if that method ever
        reads attributes this class lacks. The logic is small, so it lives here.
        """
        base = os.path.join(self.root, "images")
        # direct common extensions
        for ext in (".jpg", ".jpeg", ".png", ".webp"):
            p = os.path.join(base, f"{item_id}{ext}")
            if os.path.isfile(p):
                return p
        # recursive fuzzy search
        for p in Path(base).rglob(f"*{item_id}*"):
            if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp"):
                return str(p)
        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")

    def _load_image(self, item_id: str) -> "Image.Image":
        """Open the item's image and normalize it to 3-channel RGB."""
        return Image.open(self._find_image_path(item_id)).convert("RGB")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return (list of transformed image tensors, long label tensor).

        A missing "label" key defaults to 1 (compatible outfit).
        """
        s = self.samples[idx]
        imgs = [self.transforms(self._load_image(str(i))) for i in s["items"]]
        label = torch.tensor(int(s.get("label", 1)), dtype=torch.long)
        # Returns list of tensors; training loop can embed then pack to (N,D)
        return imgs, label


class PolyvoreOutfitTripletDataset(Dataset):
    """
    Outfit-level triplets for ViT triplet training: (good1, good2, bad).
    Expects file `outfit_triplets_<split>.json` with entries:
      {"good_a": [id...], "good_b": [id...], "bad": [id...]}
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        """Load outfit triplets for *split* from <root>/splits/outfit_triplets_<split>.json.

        Raises FileNotFoundError when the split file is missing.
        """
        super().__init__()
        self.root = root
        self.split = split
        self.transforms = build_train_transforms(image_size=image_size)
        trip_path = os.path.join(root, "splits", f"outfit_triplets_{split}.json")
        if not os.path.exists(trip_path):
            raise FileNotFoundError(f"Outfit triplet file not found: {trip_path}")
        with open(trip_path, "r") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)

    def _find_image_path(self, item_id: str) -> str:
        """Resolve *item_id* to an image file under <root>/images.

        NOTE(review): replaces the previous unbound cross-class call
        (`PolyvoreTripletDataset._find_image_path(self, ...)`), which only worked
        because that method happens to read nothing but `self.root`; owning the
        small lookup locally removes that hidden coupling.
        """
        base = os.path.join(self.root, "images")
        # direct common extensions
        for ext in (".jpg", ".jpeg", ".png", ".webp"):
            p = os.path.join(base, f"{item_id}{ext}")
            if os.path.isfile(p):
                return p
        # recursive fuzzy search
        for p in Path(base).rglob(f"*{item_id}*"):
            if p.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp"):
                return str(p)
        raise FileNotFoundError(f"Image for item {item_id} not found under {base}")

    def _load_image(self, item_id: str) -> "Image.Image":
        """Open the item's image and normalize it to 3-channel RGB."""
        return Image.open(self._find_image_path(item_id)).convert("RGB")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return three lists of transformed image tensors: (good_a, good_b, bad)."""
        s = self.samples[idx]
        ga = [self.transforms(self._load_image(str(i))) for i in s["good_a"]]
        gb = [self.transforms(self._load_image(str(i))) for i in s["good_b"]]
        bd = [self.transforms(self._load_image(str(i))) for i in s["bad"]]
        return ga, gb, bd