Ali Mohsin committed on
Commit 4716563 · 1 Parent(s): 4ef6571

initial starter
Dockerfile ADDED
@@ -0,0 +1,31 @@
+ FROM python:3.11-slim
+
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PIP_NO_CACHE_DIR=1 \
+     HF_HUB_ENABLE_HF_TRANSFER=1
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     git \
+     curl \
+     ca-certificates \
+     libgomp1 \
+     && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ COPY recommendation/requirements.txt /app/requirements.txt
+ RUN pip install --upgrade pip && pip install -r /app/requirements.txt
+
+ COPY recommendation /app/
+
+ EXPOSE 8000
+ EXPOSE 7860
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
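The image serves the FastAPI app (and the Gradio UI mounted at the root path) on port 8000. A minimal sketch of checking the service once the container is running, assuming `docker run -p 8000:8000 <image>` and the /health route defined in app.py below:

    # Minimal sketch: verify the containerized service is up (assumes localhost:8000).
    import requests

    resp = requests.get("http://localhost:8000/health", timeout=10)
    resp.raise_for_status()
    print(resp.json())  # e.g. {"status": "ok", "device": "cpu", "resnet": "resnet_v1", "vit": "vit_v1"}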
README.md CHANGED
@@ -1,4 +1,3 @@
- ---
  title: Recomendation
  emoji: 🏆
  colorFrom: purple
@@ -6,7 +5,4 @@ colorTo: green
  sdk: gradio
  sdk_version: 5.44.1
  app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ pinned: false
app.py ADDED
@@ -0,0 +1,263 @@
+ import os
+ import base64
+ import io
+ import threading
+ from typing import List, Optional, Any, Dict
+
+ import gradio as gr
+ import numpy as np
+ import requests
+ from fastapi import FastAPI, Header, HTTPException
+ from pydantic import BaseModel
+ from PIL import Image
+ from starlette.staticfiles import StaticFiles
+
+ from inference import InferenceService
+ from utils.data_fetch import ensure_dataset_ready
+
+
+ AI_API_KEY = os.getenv("AI_API_KEY")
+
+
+ def require_api_key(x_api_key: Optional[str]):
+     if AI_API_KEY and x_api_key != AI_API_KEY:
+         raise HTTPException(status_code=401, detail="Invalid API key")
+
+
+ class EmbedRequest(BaseModel):
+     image_urls: Optional[List[str]] = None
+     images_base64: Optional[List[str]] = None
+
+
+ class Item(BaseModel):
+     id: str
+     embedding: Optional[List[float]] = None
+     category: Optional[str] = None
+     image_url: Optional[str] = None
+
+
+ class ComposeRequest(BaseModel):
+     items: List[Item]
+     context: Optional[Dict[str, Any]] = None
+
+
+ app = FastAPI(title="Dressify Recommendation Service")
+ service = InferenceService()
+
+ # Non-blocking bootstrap: fetch data, prepare splits, and train if needed in the background
+ BOOT_STATUS = "idle"
+
+
+ def _background_bootstrap():
+     global BOOT_STATUS
+     try:
+         BOOT_STATUS = "preparing-dataset"
+         ds_root = ensure_dataset_ready()
+         if not ds_root:
+             BOOT_STATUS = "dataset-not-prepared"
+             return
+
+         # Prepare 70/10/10 splits if missing
+         splits_dir = os.path.join(ds_root, "splits")
+         need_prepare = not (
+             os.path.isfile(os.path.join(splits_dir, "train.json")) or
+             os.path.isfile(os.path.join(splits_dir, "outfit_triplets_train.json"))
+         )
+         if need_prepare:
+             BOOT_STATUS = "creating-splits"
+             os.makedirs(splits_dir, exist_ok=True)
+             from scripts.prepare_polyvore import main as prepare_main
+             os.environ.setdefault("PYTHONWARNINGS", "ignore")
+             import sys
+             argv_bak = sys.argv
+             try:
+                 sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--random_split"]
+                 prepare_main()
+             finally:
+                 sys.argv = argv_bak
+
+         # Train if checkpoints are absent
+         export_dir = os.getenv("EXPORT_DIR", "models/exports")
+         os.makedirs(export_dir, exist_ok=True)
+         resnet_ckpt = os.path.join(export_dir, "resnet_item_embedder_best.pth")
+         vit_ckpt = os.path.join(export_dir, "vit_outfit_model_best.pth")
+         import subprocess
+         if not os.path.exists(resnet_ckpt):
+             BOOT_STATUS = "training-resnet"
+             subprocess.run([
+                 "python", "train_resnet.py", "--data_root", ds_root, "--epochs", "3",
+                 "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
+             ], check=False)
+         if not os.path.exists(vit_ckpt):
+             BOOT_STATUS = "training-vit"
+             subprocess.run([
+                 "python", "train_vit_triplet.py", "--data_root", ds_root, "--epochs", "3",
+                 "--export", os.path.join(export_dir, "vit_outfit_model.pth")
+             ], check=False)
+         service.reload_models()
+         BOOT_STATUS = "ready"
+     except Exception as e:
+         BOOT_STATUS = f"error: {e}"
+
+
+ threading.Thread(target=_background_bootstrap, daemon=True).start()
+
+
+ @app.get("/health")
+ def health() -> dict:
+     return {"status": "ok", "device": service.device, "resnet": service.resnet_version, "vit": service.vit_version}
+
+
+ @app.post("/embed")
+ def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict:
+     require_api_key(x_api_key)
+     images: List[Image.Image] = []
+     if req.image_urls:
+         for url in req.image_urls:
+             resp = requests.get(url, timeout=20)
+             resp.raise_for_status()
+             images.append(Image.open(io.BytesIO(resp.content)).convert("RGB"))
+     if req.images_base64:
+         for b64 in req.images_base64:
+             images.append(Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB"))
+     if not images:
+         raise HTTPException(status_code=400, detail="No images provided")
+     embs = service.embed_images(images)
+     return {"embeddings": [e.tolist() for e in embs], "model_version": service.resnet_version}
+
+
+ @app.post("/compose")
+ def compose(req: ComposeRequest, x_api_key: Optional[str] = Header(None)) -> dict:
+     require_api_key(x_api_key)
+     items = [
+         {
+             "id": it.id,
+             "embedding": np.array(it.embedding, dtype=np.float32) if it.embedding is not None else None,
+             "category": it.category,
+             "image_url": it.image_url,
+         }
+         for it in req.items
+     ]
+     outfits = service.compose_outfits(items, context=req.context or {})
+     return {"outfits": outfits, "version": service.vit_version}
+
+
+ @app.get("/artifacts")
+ def artifacts() -> dict:
+     # List exported model artifacts for download
+     export_dir = os.getenv("EXPORT_DIR", "models/exports")
+     files = []
+     if os.path.isdir(export_dir):
+         for fn in os.listdir(export_dir):
+             if fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")):
+                 files.append({
+                     "name": fn,
+                     "path": f"{export_dir}/{fn}",
+                     "url": f"/files/{fn}",
+                 })
+     return {"artifacts": files}
+
+
+ # --------- Gradio UI ---------
+
+ def _load_images_from_files(files: List[str]) -> List[Image.Image]:
+     images: List[Image.Image] = []
+     for fp in files:
+         try:
+             with Image.open(fp) as im:
+                 images.append(im.convert("RGB"))
+         except Exception:
+             continue
+     return images
+
+
+ def gradio_embed(files: List[str]):
+     if not files:
+         return "[]"
+     images = _load_images_from_files(files)
+     if not images:
+         return "[]"
+     embs = service.embed_images(images)
+     return str([e.tolist() for e in embs])
+
+
+ def gradio_compose(files: List[str], occasion: str, weather: str, num_outfits: int):
+     if not files:
+         return ""
+     images = _load_images_from_files(files)
+     if not images:
+         return ""
+     embs = service.embed_images(images)
+     items = [
+         {"id": f"item_{i}", "embedding": embs[i], "category": None, "image_url": None}
+         for i in range(len(embs))
+     ]
+     results = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
+     # Render as a simple text summary
+     lines = []
+     for r in results:
+         lines.append(f"score={r['score']:.3f}, items={r['item_ids']}")
+     return "\n".join(lines)
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Dressify Recommendations – HF Test UI")
+     with gr.Tab("Embed"):
+         inp = gr.Files(label="Upload Items (multiple images)")
+         out = gr.Textbox(label="Embeddings (JSON)")
+         btn = gr.Button("Compute Embeddings")
+         btn.click(fn=gradio_embed, inputs=inp, outputs=out)
+     with gr.Tab("Compose"):
+         inp2 = gr.Files(label="Upload Wardrobe (multiple images)")
+         occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
+         weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
+         num_outfits = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Num outfits")
+         out2 = gr.Textbox(label="Recommendations")
+         btn2 = gr.Button("Generate")
+         btn2.click(fn=gradio_compose, inputs=[inp2, occasion, weather, num_outfits], outputs=out2)
+     with gr.Tab("Downloads"):
+         gr.Markdown("Download trained artifacts from models/exports")
+         file_list = gr.JSON(label="Artifacts JSON")
+
+         def list_artifacts_for_ui():
+             # Same listing logic as the /artifacts endpoint
+             export_dir = os.getenv("EXPORT_DIR", "models/exports")
+             files = []
+             if os.path.isdir(export_dir):
+                 for fn in os.listdir(export_dir):
+                     if fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")):
+                         files.append({
+                             "name": fn,
+                             "path": f"{export_dir}/{fn}",
+                             "url": f"/files/{fn}",
+                         })
+             return {"artifacts": files}
+
+         refresh = gr.Button("Refresh")
+         refresh.click(fn=list_artifacts_for_ui, inputs=[], outputs=file_list)
+     with gr.Tab("Status"):
+         gr.Markdown("Startup & training status")
+         status = gr.Textbox(label="Status", value=lambda: BOOT_STATUS)
+         refresh_status = gr.Button("Refresh Status")
+         refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status)
+
+
+ try:
+     # Mount Gradio onto the FastAPI root path
+     app = gr.mount_gradio_app(app, demo, path="/")
+ except Exception:
+     # If mounting fails in certain runners, keep the plain FastAPI app available
+     pass
+
+ # Mount static files for direct artifact download
+ export_dir = os.getenv("EXPORT_DIR", "models/exports")
+ os.makedirs(export_dir, exist_ok=True)
+ try:
+     app.mount("/files", StaticFiles(directory=export_dir), name="files")
+ except Exception:
+     pass
+
+
+ if __name__ == "__main__":
+     # Local testing
+     demo.launch()
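A sketch of a client for the two main endpoints, assuming the service runs on localhost:8000; the image paths are hypothetical, and the X-API-Key header only matters when AI_API_KEY is set:

    # Hedged client sketch for /embed and /compose (assumes localhost:8000).
    import base64
    import requests

    BASE = "http://localhost:8000"
    headers = {"X-API-Key": "change-me"}  # only needed when AI_API_KEY is set

    def b64(path):  # hypothetical local wardrobe photos
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode()

    emb = requests.post(f"{BASE}/embed",
                        json={"images_base64": [b64("shirt.jpg"), b64("jeans.jpg"), b64("sneakers.jpg")]},
                        headers=headers).json()

    items = [{"id": f"item_{i}", "embedding": e} for i, e in enumerate(emb["embeddings"])]
    out = requests.post(f"{BASE}/compose",
                        json={"items": items, "context": {"occasion": "casual", "num_outfits": 2}},
                        headers=headers).json()
    print(out["outfits"])  # e.g. [{"item_ids": [...], "score": ...}, ...]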
configs/default.yaml ADDED
@@ -0,0 +1,33 @@
+ # Default training/eval configuration
+
+ seed: 42
+ image_size: 224
+
+ resnet:
+   backbone: resnet50
+   embedding_dim: 512
+   batch_size: 64
+   epochs: 20
+   lr: 0.001
+   weight_decay: 0.0001
+   use_imagenet_pretrained: true
+
+ vit:
+   embedding_dim: 512
+   num_layers: 6
+   num_heads: 8
+   ff_multiplier: 4
+   dropout: 0.1
+   batch_size: 64
+   epochs: 30
+   lr: 0.0005
+   weight_decay: 0.05
+
+ paths:
+   polyvore_root: /kaggle/input/polyvore-outfits
+   resnet_ckpt: models/exports/resnet_item_embedder.pth
+   vit_ckpt: models/exports/vit_outfit_model.pth
+   export_dir: models/exports
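pyyaml is pinned in requirements.txt, so the config can be read the usual way; a small sketch:

    # Sketch: reading the default config with PyYAML.
    import yaml

    with open("configs/default.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["resnet"]["embedding_dim"])  # 512
    print(cfg["paths"]["export_dir"])      # models/exports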
data/__init__.py ADDED
@@ -0,0 +1,5 @@
+
+
+
+
+
data/polyvore.py ADDED
@@ -0,0 +1,127 @@
+ import os
+ import json
+ from typing import List, Dict, Any
+
+ import torch
+ from torch.utils.data import Dataset
+ from PIL import Image
+
+ from utils.transforms import build_train_transforms
+
+
+ class PolyvoreTripletDataset(Dataset):
+     """
+     Creates (anchor, positive, negative) image triplets for training the ResNet embedder.
+     Assumes a JSON list (or several files) describing compatible pairs/sets and item image paths.
+
+     Expected structure (customize as needed):
+       root/
+         images/<item_id>.jpg
+         splits/train.json  # [{"anchor": id, "positive": id, "negative": id}, ...]
+     """
+
+     def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
+         super().__init__()
+         self.root = root
+         self.split = split
+         self.transforms = build_train_transforms(image_size=image_size)
+         triplet_path = os.path.join(root, "splits", f"{split}.json")
+         if not os.path.exists(triplet_path):
+             raise FileNotFoundError(f"Triplet file not found: {triplet_path}")
+         with open(triplet_path, "r") as f:
+             self.samples: List[Dict[str, Any]] = json.load(f)
+
+     def _load_image(self, item_id: str) -> Image.Image:
+         # Customize if images are arranged differently
+         img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
+         if not os.path.exists(img_path):
+             raise FileNotFoundError(img_path)
+         return Image.open(img_path).convert("RGB")
+
+     def __len__(self) -> int:
+         return len(self.samples)
+
+     def __getitem__(self, idx: int):
+         s = self.samples[idx]
+         a = self._load_image(str(s["anchor"]))
+         p = self._load_image(str(s["positive"]))
+         n = self._load_image(str(s["negative"]))
+         return self.transforms(a), self.transforms(p), self.transforms(n)
+
+
+ class PolyvoreOutfitDataset(Dataset):
+     """
+     Produces (tokens, label), where tokens is a sequence of item images to be embedded downstream.
+     For simplicity we return a list of image tensors; embed them externally or pre-embed offline.
+
+     Expected structure:
+       root/
+         images/<item_id>.jpg
+         splits/outfits_train.json  # [{"items": [id1, id2, ...], "label": 1}, ...]
+     """
+
+     def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
+         super().__init__()
+         self.root = root
+         self.split = split
+         self.transforms = build_train_transforms(image_size=image_size)
+         outfit_path = os.path.join(root, "splits", f"outfits_{split}.json")
+         if not os.path.exists(outfit_path):
+             raise FileNotFoundError(f"Outfit file not found: {outfit_path}")
+         with open(outfit_path, "r") as f:
+             self.samples: List[Dict[str, Any]] = json.load(f)
+         # Outfit slot constraints (upper, bottom, shoes, accessory) could be enforced here when
+         # category metadata is available; without it we rely on item count and let the model learn.
+
+     def _load_image(self, item_id: str) -> Image.Image:
+         img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
+         if not os.path.exists(img_path):
+             raise FileNotFoundError(img_path)
+         return Image.open(img_path).convert("RGB")
+
+     def __len__(self) -> int:
+         return len(self.samples)
+
+     def __getitem__(self, idx: int):
+         s = self.samples[idx]
+         imgs = [self.transforms(self._load_image(str(i))) for i in s["items"]]
+         label = torch.tensor(int(s.get("label", 1)), dtype=torch.long)
+         # Returns a list of tensors; the training loop can embed them, then pack to (N, D)
+         return imgs, label
+
+
+ class PolyvoreOutfitTripletDataset(Dataset):
+     """
+     Outfit-level triplets for ViT triplet training: (good_a, good_b, bad).
+     Expects a file `outfit_triplets_<split>.json` with entries:
+       {"good_a": [id, ...], "good_b": [id, ...], "bad": [id, ...]}
+     """
+
+     def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
+         super().__init__()
+         self.root = root
+         self.split = split
+         self.transforms = build_train_transforms(image_size=image_size)
+         trip_path = os.path.join(root, "splits", f"outfit_triplets_{split}.json")
+         if not os.path.exists(trip_path):
+             raise FileNotFoundError(f"Outfit triplet file not found: {trip_path}")
+         with open(trip_path, "r") as f:
+             self.samples: List[Dict[str, Any]] = json.load(f)
+
+     def _load_image(self, item_id: str) -> Image.Image:
+         img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
+         if not os.path.exists(img_path):
+             raise FileNotFoundError(img_path)
+         return Image.open(img_path).convert("RGB")
+
+     def __len__(self) -> int:
+         return len(self.samples)
+
+     def __getitem__(self, idx: int):
+         s = self.samples[idx]
+         ga = [self.transforms(self._load_image(str(i))) for i in s["good_a"]]
+         gb = [self.transforms(self._load_image(str(i))) for i in s["good_b"]]
+         bd = [self.transforms(self._load_image(str(i))) for i in s["bad"]]
+         return ga, gb, bd
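For reference, a sketch of the on-disk layout these datasets expect and how a loader would consume one; the root path and item ids are hypothetical, and real image files must exist under root/images:

    # splits/train.json:          [{"anchor": "101", "positive": "102", "negative": "903"}, ...]
    # splits/outfits_train.json:  [{"items": ["101", "102", "103"], "label": 1}, ...]
    from torch.utils.data import DataLoader
    from data.polyvore import PolyvoreTripletDataset

    ds = PolyvoreTripletDataset("./data/Polyvore", split="train", image_size=224)
    loader = DataLoader(ds, batch_size=8, shuffle=True)
    anchor, positive, negative = next(iter(loader))  # each of shape (8, 3, 224, 224)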
inference.py ADDED
@@ -0,0 +1,156 @@
+ import os
+ from typing import List, Dict, Any
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+
+ from utils.transforms import build_inference_transform
+ from models.resnet_embedder import ResNetItemEmbedder
+ from models.vit_outfit import OutfitCompatibilityModel
+
+
+ def _get_device() -> str:
+     if torch.cuda.is_available():
+         return "cuda"
+     if torch.backends.mps.is_available():
+         return "mps"
+     return "cpu"
+
+
+ class InferenceService:
+     def __init__(self) -> None:
+         self.device = _get_device()
+         self.transform = build_inference_transform()
+         self.embed_dim = int(os.getenv("EMBED_DIM", "512"))
+         self.resnet_version = "resnet_v1"
+         self.vit_version = "vit_v1"
+
+         self.resnet = self._load_resnet().to(self.device).eval()
+         self.vit = self._load_vit().to(self.device).eval()
+
+         for m in [self.resnet, self.vit]:
+             for p in m.parameters():
+                 p.requires_grad_(False)
+
+     def _load_resnet(self) -> nn.Module:
+         strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
+         ckpt_path = os.getenv("RESNET_CHECKPOINT", "models/exports/resnet_item_embedder.pth")
+         model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
+         if strategy == "random":
+             return model
+         # Prefer the "best" checkpoint if present
+         best_path = os.path.join(os.path.dirname(ckpt_path), "resnet_item_embedder_best.pth")
+         ckpt_to_use = best_path if os.path.exists(best_path) else ckpt_path
+         if os.path.exists(ckpt_to_use):
+             state = torch.load(ckpt_to_use, map_location="cpu")
+             # Accept either a bare state_dict or {"state_dict": ...}
+             state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
+             model.load_state_dict(state_dict, strict=False)
+         return model
+
+     def _load_vit(self) -> nn.Module:
+         strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
+         ckpt_path = os.getenv("VIT_CHECKPOINT", "models/exports/vit_outfit_model.pth")
+         model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
+         if strategy == "random":
+             return model
+         best_path = os.path.join(os.path.dirname(ckpt_path), "vit_outfit_model_best.pth")
+         ckpt_to_use = best_path if os.path.exists(best_path) else ckpt_path
+         if os.path.exists(ckpt_to_use):
+             state = torch.load(ckpt_to_use, map_location="cpu")
+             state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
+             model.load_state_dict(state_dict, strict=False)
+         return model
+
+     def reload_models(self) -> None:
+         """Reload weights from the current checkpoint locations (used after background training)."""
+         self.resnet = self._load_resnet().to(self.device).eval()
+         self.vit = self._load_vit().to(self.device).eval()
+         for m in [self.resnet, self.vit]:
+             for p in m.parameters():
+                 p.requires_grad_(False)
+
+     @torch.inference_mode()
+     def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
+         if len(images) == 0:
+             return []
+         batch = torch.stack([self.transform(img) for img in images]).to(self.device)
+         emb = self.resnet(batch)
+         emb = nn.functional.normalize(emb, dim=-1)
+         return [e.detach().cpu().numpy().astype(np.float32) for e in emb]
+
+     @torch.inference_mode()
+     def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
+         # Ensure every item has an embedding
+         proc_items: List[Dict[str, Any]] = []
+         for it in items:
+             e = it.get("embedding")
+             if e is None and it.get("image") is not None:
+                 # Not used in the Gradio path, but kept for completeness
+                 emb = self.embed_images([it["image"]])[0]
+             elif e is None:
+                 # Skip items with neither embedding nor image
+                 continue
+             else:
+                 emb = np.asarray(e, dtype=np.float32)
+             proc_items.append({"id": it.get("id"), "embedding": emb, "category": it.get("category")})
+
+         if len(proc_items) < 2:
+             return []
+
+         # Candidate generation: enforce minimum slots (upper, bottom, shoes, accessory) if categories are provided
+         rng = np.random.default_rng(42)
+         num_outfits = int(context.get("num_outfits", 3))
+         min_size, max_size = 3, 5
+         candidates: List[List[int]] = []
+         ids = list(range(len(proc_items)))
+
+         # Slot-aware sampling when categories exist
+         def cat_of(i: int) -> str:
+             return (proc_items[i].get("category") or "").lower()
+
+         uppers = [i for i in ids if any(k in cat_of(i) for k in ["top", "shirt", "tshirt", "blouse", "jacket", "hoodie"])]
+         bottoms = [i for i in ids if any(k in cat_of(i) for k in ["pant", "trouser", "jean", "skirt", "short"])]
+         shoes = [i for i in ids if "shoe" in cat_of(i)]
+         accs = [i for i in ids if any(k in cat_of(i) for k in ["watch", "belt", "ring", "bracelet", "accessor"])]
+
+         for _ in range(num_outfits * 10):
+             if uppers and bottoms and shoes and accs:
+                 subset = [rng.choice(uppers).item(), rng.choice(bottoms).item(), rng.choice(shoes).item(), rng.choice(accs).item()]
+                 # Optionally add one more random item
+                 remain = list(set(ids) - set(subset))
+                 if remain and rng.random() < 0.5:
+                     subset.append(rng.choice(remain).item())
+             else:
+                 # Cap k at the number of available items so choice(replace=False) cannot fail
+                 k = int(min(rng.integers(min_size, max_size + 1), len(ids)))
+                 subset = rng.choice(ids, size=k, replace=False).tolist()
+             candidates.append(subset)
+
+         # Score each candidate with the ViT compatibility model
+         def score_subset(idx_subset: List[int]) -> float:
+             embs = torch.tensor(np.stack([proc_items[i]["embedding"] for i in idx_subset]), dtype=torch.float32, device=self.device)
+             embs = embs.unsqueeze(0)  # (1, N, D)
+             s = self.vit.score_compatibility(embs).item()
+             return float(s)
+
+         scored = [(subset, score_subset(subset)) for subset in candidates]
+         scored.sort(key=lambda x: x[1], reverse=True)
+         topk = scored[:num_outfits]
+         results = [
+             {
+                 "item_ids": [proc_items[i]["id"] for i in subset],
+                 "score": float(score),
+             }
+             for subset, score in topk
+         ]
+         return results
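A sketch of driving InferenceService directly, without the HTTP layer. Setting MODEL_LOAD_STRATEGY=random skips checkpoint loading, so this runs before any training has produced weights:

    # Sketch: direct use of InferenceService with synthetic images.
    import os
    os.environ["MODEL_LOAD_STRATEGY"] = "random"  # skip checkpoint loading for a smoke test

    from PIL import Image
    from inference import InferenceService

    svc = InferenceService()
    imgs = [Image.new("RGB", (224, 224), c) for c in ["red", "blue", "green"]]
    embs = svc.embed_images(imgs)  # three L2-normalized (512,) float32 vectors
    items = [{"id": f"item_{i}", "embedding": e, "category": None} for i, e in enumerate(embs)]
    print(svc.compose_outfits(items, context={"num_outfits": 2}))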
models/__init__.py ADDED
@@ -0,0 +1,5 @@
+
+
+
+
+
models/resnet_embedder.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import torch.nn as nn
+ import torchvision.models as tvm
+
+
+ class ResNetItemEmbedder(nn.Module):
+     def __init__(self, embedding_dim: int = 512, backbone: str = "resnet50", pretrained: bool = True) -> None:
+         super().__init__()
+         if backbone == "resnet50":
+             model = tvm.resnet50(weights=tvm.ResNet50_Weights.DEFAULT if pretrained else None)
+             feat_dim = 2048
+         elif backbone == "resnet101":
+             model = tvm.resnet101(weights=tvm.ResNet101_Weights.DEFAULT if pretrained else None)
+             feat_dim = 2048
+         else:
+             raise ValueError(f"Unsupported backbone: {backbone}")
+
+         # Remove the classifier; keep the global-average-pooled features
+         modules = list(model.children())[:-1]  # drop fc
+         self.backbone = nn.Sequential(*modules)
+         self.proj = nn.Linear(feat_dim, embedding_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, 3, H, W)
+         feats = self.backbone(x)  # (B, C, 1, 1)
+         feats = feats.flatten(1)  # (B, C)
+         emb = self.proj(feats)    # (B, D)
+         return emb
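A quick shape check for the embedder (pretrained=False avoids the ImageNet weights download):

    # Sketch: verify the (B, 3, 224, 224) -> (B, 512) mapping.
    import torch
    from models.resnet_embedder import ResNetItemEmbedder

    model = ResNetItemEmbedder(embedding_dim=512, pretrained=False).eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 224, 224))
    print(out.shape)  # torch.Size([2, 512])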
models/vit_outfit.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ import torch.nn as nn
+
+
+ class OutfitCompatibilityModel(nn.Module):
+     def __init__(self, embedding_dim: int = 512, num_layers: int = 4, num_heads: int = 8, ff_multiplier: int = 4, dropout: float = 0.1) -> None:
+         super().__init__()
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=embedding_dim,
+             nhead=num_heads,
+             dim_feedforward=ff_multiplier * embedding_dim,
+             dropout=dropout,
+             batch_first=True,
+             activation="gelu",
+             norm_first=True,
+         )
+         self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+         self.compatibility_head = nn.Sequential(
+             nn.LayerNorm(embedding_dim),
+             nn.Linear(embedding_dim, embedding_dim // 2),
+             nn.GELU(),
+             nn.Linear(embedding_dim // 2, 1),
+         )
+
+     def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+         # tokens: (B, N, D) sequence of item embeddings
+         h = self.encoder(tokens)                 # (B, N, D)
+         pooled = h.mean(dim=1)                   # (B, D)
+         score = self.compatibility_head(pooled)  # (B, 1)
+         return score.squeeze(-1)                 # (B,)
+
+     def score_compatibility(self, tokens: torch.Tensor) -> torch.Tensor:
+         return self.forward(tokens)
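Despite the ViT naming, this is a plain Transformer encoder over item-embedding tokens: it maps a batch of outfits (B, N, D) to one compatibility score per outfit. A shape-check sketch with random tokens:

    # Sketch: score a batch of 4 outfits, each with 5 item embeddings.
    import torch
    from models.vit_outfit import OutfitCompatibilityModel

    model = OutfitCompatibilityModel(embedding_dim=512).eval()
    tokens = torch.randn(4, 5, 512)  # (B=4 outfits, N=5 items, D=512)
    with torch.no_grad():
        scores = model.score_compatibility(tokens)
    print(scores.shape)  # torch.Size([4])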
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ fastapi==0.111.0
+ uvicorn[standard]==0.30.1
+ gradio==5.44.1
+ pydantic==2.8.2
+ python-multipart==0.0.9
+ Pillow==10.4.0
+ numpy==2.0.1
+ torch==2.3.1
+ torchvision==0.18.1
+ timm==1.0.8
+ requests==2.32.3
+ scikit-learn==1.5.1
+ faiss-cpu==1.8.0.post1
+ pyyaml==6.0.2
+ open-clip-torch==2.26.1
+ huggingface-hub==0.24.5
+ kagglehub[pandas-datasets]
scripts/prepare_polyvore.py ADDED
@@ -0,0 +1,187 @@
+ import os
+ import json
+ import random
+ import argparse
+ from pathlib import Path
+ from typing import Dict, Any, List, Set
+
+
+ def load_outfits_json(root: str, split: str) -> List[Dict[str, Any]]:
+     candidates = [
+         os.path.join(root, f"{split}.json"),
+         os.path.join(root, f"{split}_no_dup.json"),
+         os.path.join(root, "splits", f"{split}.json"),
+         os.path.join(root, "splits", f"{split}_no_dup.json"),
+     ]
+     for p in candidates:
+         if os.path.exists(p):
+             with open(p, "r") as f:
+                 data = json.load(f)
+             # Expect a list where each entry has an "items" key listing item ids
+             return data
+     raise FileNotFoundError(f"Could not find {split}.json in {root} or {root}/splits")
+
+
+ def try_load_any_outfits(root: str) -> List[Dict[str, Any]]:
+     candidates = [
+         os.path.join(root, "outfits.json"),
+         os.path.join(root, "all.json"),
+         os.path.join(root, "data.json"),
+     ]
+     for p in candidates:
+         if os.path.exists(p):
+             with open(p, "r") as f:
+                 return json.load(f)
+     # As a last resort, merge whatever splits are available
+     merged: List[Dict[str, Any]] = []
+     for sp in ["train", "valid", "test"]:
+         try:
+             merged.extend(load_outfits_json(root, sp))
+         except FileNotFoundError:
+             continue
+     return merged
+
+
+ def collect_all_items(outfits: List[Dict[str, Any]]) -> List[str]:
+     s: Set[str] = set()
+     for o in outfits:
+         for it in o.get("items", []):
+             s.add(str(it))
+     return sorted(s)
+
+
+ def build_triplets(outfits: List[Dict[str, Any]], all_items: List[str], max_triplets: int = 200000) -> List[Dict[str, str]]:
+     rng = random.Random(42)
+     all_items_set = set(all_items)
+     triplets: List[Dict[str, str]] = []
+     for o in outfits:
+         items = [str(i) for i in o.get("items", [])]
+         if len(items) < 2:
+             continue
+         local_set = set(items)
+         for i in range(len(items) - 1):
+             a = items[i]
+             p = items[i + 1]
+             # Pick a negative that is not in this outfit
+             negatives = list(all_items_set - local_set)
+             if not negatives:
+                 continue
+             n = rng.choice(negatives)
+             triplets.append({"anchor": a, "positive": p, "negative": n})
+             if len(triplets) >= max_triplets:
+                 return triplets
+     return triplets
+
+
+ def build_outfit_pairs(outfits: List[Dict[str, Any]], num_negatives_per_pos: int = 1) -> List[Dict[str, Any]]:
+     rng = random.Random(123)
+     all_items = collect_all_items(outfits)
+     all_set = set(all_items)
+     pairs: List[Dict[str, Any]] = []
+     # Positive samples
+     for o in outfits:
+         items = [str(i) for i in o.get("items", [])]
+         if len(items) < 2:
+             continue
+         pairs.append({"items": items, "label": 1})
+         # Negatives: corrupt one item
+         for _ in range(num_negatives_per_pos):
+             if not items:
+                 continue
+             idx = rng.randrange(len(items))
+             neg_pool = list(all_set - set(items))
+             if not neg_pool:
+                 continue
+             neg_item = rng.choice(neg_pool)
+             neg_items = items.copy()
+             neg_items[idx] = neg_item
+             pairs.append({"items": neg_items, "label": 0})
+     return pairs
+
+
+ def build_outfit_triplets(outfits: List[Dict[str, Any]], num_triplets: int = 200000) -> List[Dict[str, Any]]:
+     rng = random.Random(999)
+     # Keep only valid positive outfits (len >= 3, ideally slot-complete)
+     pos = [o for o in outfits if len(o.get("items", [])) >= 3]
+     all_items = collect_all_items(outfits)
+     all_set = set(all_items)
+     triplets: List[Dict[str, Any]] = []
+     for _ in range(num_triplets):
+         if len(pos) < 2:
+             break
+         ga = rng.choice(pos)
+         gb = rng.choice(pos)
+         # Ensure ga != gb
+         if ga is gb:
+             continue
+         # Create the bad outfit by corrupting one item of ga
+         items_ga = [str(i) for i in ga.get("items", [])]
+         if not items_ga:
+             continue
+         corrupt_idx = rng.randrange(len(items_ga))
+         neg_pool = list(all_set - set(items_ga))
+         if not neg_pool:
+             continue
+         neg_item = rng.choice(neg_pool)
+         bad = items_ga.copy()
+         bad[corrupt_idx] = neg_item
+         triplets.append({"good_a": items_ga, "good_b": [str(i) for i in gb.get("items", [])], "bad": bad})
+     return triplets
+
+
+ def main() -> None:
+     ap = argparse.ArgumentParser()
+     ap.add_argument("--root", type=str, required=True, help="Polyvore dataset root")
+     ap.add_argument("--out", type=str, default=None, help="Output directory for splits (default: <root>/splits)")
+     ap.add_argument("--max_triplets", type=int, default=200000)
+     ap.add_argument("--neg_per_pos", type=int, default=1)
+     ap.add_argument("--random_split", action="store_true", help="Create a 70/10/10 random split if official splits are missing")
+     args = ap.parse_args()
+
+     out_dir = args.out or os.path.join(args.root, "splits")
+     Path(out_dir).mkdir(parents=True, exist_ok=True)
+
+     if args.random_split:
+         all_outfits = try_load_any_outfits(args.root)
+         if not all_outfits:
+             raise FileNotFoundError("No outfits found to split. Provide official splits or an outfits.json file.")
+         rng = random.Random(2024)
+         rng.shuffle(all_outfits)
+         n = len(all_outfits)
+         n_train = int(0.7 * n)
+         n_valid = int(0.1 * n)
+         splits = {
+             "train": all_outfits[:n_train],
+             "valid": all_outfits[n_train:n_train + n_valid],
+             "test": all_outfits[n_train + n_valid:],
+         }
+     else:
+         splits = {}
+         for split in ["train", "valid", "test"]:
+             try:
+                 splits[split] = load_outfits_json(args.root, split)
+             except FileNotFoundError as e:
+                 print(f"Skipping {split}: {e}")
+                 splits[split] = []
+
+     for split, outfits in splits.items():
+         if not outfits:
+             continue
+         all_items = collect_all_items(outfits)
+         triplets = build_triplets(outfits, all_items, max_triplets=args.max_triplets)
+         pairs = build_outfit_pairs(outfits, num_negatives_per_pos=args.neg_per_pos)
+
+         with open(os.path.join(out_dir, f"{split}.json"), "w") as f:
+             json.dump(triplets, f)
+         with open(os.path.join(out_dir, f"outfits_{split}.json"), "w") as f:
+             json.dump(pairs, f)
+         triplets_o = build_outfit_triplets(outfits)
+         with open(os.path.join(out_dir, f"outfit_triplets_{split}.json"), "w") as f:
+             json.dump(triplets_o, f)
+         print(f"Wrote {split}: {len(triplets)} item-triplets, {len(pairs)} outfit-pairs, {len(triplets_o)} outfit-triplets -> {out_dir}")
+
+
+ if __name__ == "__main__":
+     main()
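A sketch of preparing splits from a minimal, hypothetical outfits.json, mirroring how app.py invokes main() by patching sys.argv:

    # Sketch: generate splits from a tiny hand-made outfits.json (item ids are hypothetical).
    import json, os, sys
    from scripts.prepare_polyvore import main as prepare_main

    root = "./data/Polyvore"
    os.makedirs(root, exist_ok=True)
    outfits = [{"items": ["1", "2", "3"]}, {"items": ["4", "5", "6"]}, {"items": ["1", "5", "7"]}]
    with open(os.path.join(root, "outfits.json"), "w") as f:
        json.dump(outfits, f)

    argv_bak = sys.argv
    try:
        sys.argv = ["prepare_polyvore.py", "--root", root, "--random_split"]
        prepare_main()  # writes triplets, pairs, and outfit triplets under root/splits
    finally:
        sys.argv = argv_bak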
train_resnet.py ADDED
@@ -0,0 +1,80 @@
+ import os
+ import json
+ import argparse
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+
+ from data.polyvore import PolyvoreTripletDataset
+ from models.resnet_embedder import ResNetItemEmbedder
+ from utils.export import ensure_export_dir
+
+
+ def parse_args() -> argparse.Namespace:
+     p = argparse.ArgumentParser()
+     p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/kaggle/input/polyvore-outfits"))
+     p.add_argument("--epochs", type=int, default=20)
+     p.add_argument("--batch_size", type=int, default=64)
+     p.add_argument("--lr", type=float, default=1e-3)
+     p.add_argument("--embedding_dim", type=int, default=512)
+     p.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth")
+     return p.parse_args()
+
+
+ def main() -> None:
+     args = parse_args()
+     device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
+
+     dataset = PolyvoreTripletDataset(args.data_root, split="train")
+     loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
+     model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
+     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
+     criterion = nn.TripletMarginLoss(margin=0.2, p=2)
+
+     export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
+     best_loss = float("inf")
+     history = []
+     for epoch in range(args.epochs):
+         model.train()
+         running = 0.0
+         steps = 0
+         for batch in loader:
+             # Each batch is an (anchor, positive, negative) triple of image tensors
+             anchor, positive, negative = batch
+             anchor = anchor.to(device)
+             positive = positive.to(device)
+             negative = negative.to(device)
+             emb_a = model(anchor)
+             emb_p = model(positive)
+             emb_n = model(negative)
+             loss = criterion(emb_a, emb_p, emb_n)
+             optimizer.zero_grad(set_to_none=True)
+             loss.backward()
+             optimizer.step()
+             running += loss.item()
+             steps += 1
+         avg_loss = running / max(1, steps)
+         out_path = args.out
+         if not out_path.startswith("models/"):
+             out_path = os.path.join(export_dir, os.path.basename(args.out))
+         torch.save({"state_dict": model.state_dict()}, out_path)
+         print(f"Epoch {epoch+1}/{args.epochs} avg_triplet_loss={avg_loss:.4f} saved -> {out_path}")
+         history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})
+         if avg_loss < best_loss:
+             best_loss = avg_loss
+             torch.save({"state_dict": model.state_dict()}, os.path.join(export_dir, "resnet_item_embedder_best.pth"))
+
+     # Write metrics
+     metrics_path = os.path.join(export_dir, "resnet_metrics.json")
+     with open(metrics_path, "w") as f:
+         json.dump({"best_triplet_loss": best_loss, "history": history}, f)
+
+
+ if __name__ == "__main__":
+     main()
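The core of the loop, isolated on synthetic tensors, is a standard triplet-margin update; a smoke-test sketch that avoids the Polyvore data entirely:

    # Sketch: one training step of the embedder on random tensors.
    import torch
    import torch.nn as nn
    from models.resnet_embedder import ResNetItemEmbedder

    model = ResNetItemEmbedder(embedding_dim=512, pretrained=False)
    criterion = nn.TripletMarginLoss(margin=0.2, p=2)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

    a, p, n = (torch.randn(4, 3, 224, 224) for _ in range(3))
    loss = criterion(model(a), model(p), model(n))
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
    print(float(loss))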
train_vit_triplet.py ADDED
@@ -0,0 +1,145 @@
+ import os
+ import json
+ import argparse
+ from typing import List
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+
+ from data.polyvore import PolyvoreOutfitTripletDataset
+ from models.vit_outfit import OutfitCompatibilityModel
+ from models.resnet_embedder import ResNetItemEmbedder
+ from utils.export import ensure_export_dir
+
+
+ def parse_args() -> argparse.Namespace:
+     p = argparse.ArgumentParser()
+     p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/kaggle/input/polyvore-outfits"))
+     p.add_argument("--epochs", type=int, default=30)
+     p.add_argument("--batch_size", type=int, default=32)
+     p.add_argument("--lr", type=float, default=5e-4)
+     p.add_argument("--embedding_dim", type=int, default=512)
+     p.add_argument("--triplet_margin", type=float, default=0.3)
+     p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
+     p.add_argument("--eval_every", type=int, default=1)
+     return p.parse_args()
+
+
+ def embed_outfit(imgs: List[torch.Tensor], embedder: ResNetItemEmbedder, device: str, max_len: int = 4) -> torch.Tensor:
+     if len(imgs) == 0:
+         return torch.zeros((max_len, embedder.proj.out_features), device=device)
+     k = min(len(imgs), max_len)
+     x = torch.stack(imgs[:k], dim=0).to(device)
+     with torch.no_grad():
+         e = embedder(x)  # (k, D)
+     tokens = torch.zeros((max_len, e.shape[-1]), device=device)
+     tokens[:k] = e
+     return tokens
+
+
+ def main() -> None:
+     args = parse_args()
+     device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
+
+     dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")
+
+     def collate(batch):
+         return batch  # variable-length outfits are handled inside the training loop
+
+     loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
+
+     model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
+     # Frozen item embedder: ImageNet-pretrained backbone with a randomly initialized projection.
+     # Loading the trained ResNet checkpoint here (when available) may yield better tokens.
+     embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
+     for p in embedder.parameters():
+         p.requires_grad_(False)
+
+     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
+     triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)
+
+     export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
+     best_loss = float("inf")
+     hist = []
+     for epoch in range(args.epochs):
+         model.train()
+         for batch in loader:
+             # batch: List[(ga_imgs, gb_imgs, bd_imgs)]
+             anchor_tokens = []
+             positive_tokens = []
+             negative_tokens = []
+             for ga, gb, bd in batch:
+                 ta = embed_outfit(ga, embedder, device)
+                 tb = embed_outfit(gb, embedder, device)
+                 tn = embed_outfit(bd, embedder, device)
+                 anchor_tokens.append(ta.unsqueeze(0))
+                 positive_tokens.append(tb.unsqueeze(0))
+                 negative_tokens.append(tn.unsqueeze(0))
+             A = torch.cat(anchor_tokens, dim=0)  # (B, N, D)
+             P = torch.cat(positive_tokens, dim=0)
+             N = torch.cat(negative_tokens, dim=0)
+
+             # Outfit-level embeddings via the encoder's pooled output
+             ea = model.encoder(A).mean(dim=1)
+             ep = model.encoder(P).mean(dim=1)
+             en = model.encoder(N).mean(dim=1)
+             loss = triplet(ea, ep, en)
+             optimizer.zero_grad(set_to_none=True)
+             loss.backward()
+             optimizer.step()
+
+         # Validate on outfit_triplets_valid.json when present (prepare_polyvore.py creates it
+         # as part of the 70/10/10 split); otherwise training loss is the only signal.
+         val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")
+         val_loss = None
+         if os.path.exists(val_path) and (epoch + 1) % args.eval_every == 0:
+             val_ds = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
+             val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=2, collate_fn=lambda x: x)
+             model.eval()
+             losses = []
+             with torch.no_grad():
+                 for vbatch in val_loader:
+                     anchor_tokens = []
+                     positive_tokens = []
+                     negative_tokens = []
+                     for ga, gb, bd in vbatch:
+                         ta = embed_outfit(ga, embedder, device)
+                         tb = embed_outfit(gb, embedder, device)
+                         tn = embed_outfit(bd, embedder, device)
+                         anchor_tokens.append(ta.unsqueeze(0))
+                         positive_tokens.append(tb.unsqueeze(0))
+                         negative_tokens.append(tn.unsqueeze(0))
+                     A = torch.cat(anchor_tokens, dim=0)
+                     P = torch.cat(positive_tokens, dim=0)
+                     N = torch.cat(negative_tokens, dim=0)
+                     ea = model.encoder(A).mean(dim=1)
+                     ep = model.encoder(P).mean(dim=1)
+                     en = model.encoder(N).mean(dim=1)
+                     losses.append(triplet(ea, ep, en).item())
+             val_loss = sum(losses) / max(1, len(losses))
+
+         out_path = args.export
+         if not out_path.startswith("models/"):
+             out_path = os.path.join(export_dir, os.path.basename(args.export))
+         torch.save({"state_dict": model.state_dict()}, out_path)
+         if val_loss is not None:
+             print(f"Epoch {epoch+1}/{args.epochs} triplet_loss={loss.item():.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
+             hist.append({"epoch": epoch + 1, "triplet_loss": float(loss.item()), "val_triplet_loss": float(val_loss)})
+             if val_loss < best_loss:
+                 best_loss = val_loss
+                 torch.save({"state_dict": model.state_dict()}, os.path.join(export_dir, "vit_outfit_model_best.pth"))
+         else:
+             print(f"Epoch {epoch+1}/{args.epochs} triplet_loss={loss.item():.4f} saved -> {out_path}")
+             hist.append({"epoch": epoch + 1, "triplet_loss": float(loss.item())})
+
+     metrics_path = os.path.join(export_dir, "vit_metrics.json")
+     payload = {"best_val_triplet_loss": best_loss if best_loss != float("inf") else None, "history": hist}
+     with open(metrics_path, "w") as f:
+         json.dump(payload, f)
+
+
+ if __name__ == "__main__":
+     main()
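embed_outfit's contract is the key detail here: every outfit, regardless of length, becomes a fixed (max_len, D) token block, zero-padded so triplets can be batched. A sketch:

    # Sketch: the padding behavior of embed_outfit on a 2-item outfit.
    import torch
    from models.resnet_embedder import ResNetItemEmbedder
    from train_vit_triplet import embed_outfit

    embedder = ResNetItemEmbedder(embedding_dim=512, pretrained=False).eval()
    imgs = [torch.randn(3, 224, 224) for _ in range(2)]  # a 2-item outfit
    tokens = embed_outfit(imgs, embedder, device="cpu", max_len=4)
    print(tokens.shape)            # torch.Size([4, 512])
    print(tokens[2:].abs().sum())  # tensor(0.) — the padded rows are zeros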
utils/__init__.py ADDED
@@ -0,0 +1,5 @@
+
+
+
+
+
utils/data_fetch.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ import zipfile
+ from pathlib import Path
+ from typing import Optional, List
+
+ import requests
+
+ try:
+     from huggingface_hub import snapshot_download  # type: ignore
+ except Exception:  # pragma: no cover
+     snapshot_download = None
+
+ try:
+     import kagglehub  # type: ignore
+     from kagglehub import KaggleDatasetAdapter  # type: ignore
+ except Exception:  # pragma: no cover
+     kagglehub = None
+     KaggleDatasetAdapter = None
+
+
+ def _download_zip(url: str, dest_dir: str) -> str:
+     os.makedirs(dest_dir, exist_ok=True)
+     local_zip = os.path.join(dest_dir, "dataset.zip")
+     with requests.get(url, stream=True, timeout=60) as r:
+         r.raise_for_status()
+         with open(local_zip, "wb") as f:
+             for chunk in r.iter_content(chunk_size=1024 * 1024):
+                 if chunk:
+                     f.write(chunk)
+     with zipfile.ZipFile(local_zip, "r") as zf:
+         zf.extractall(dest_dir)
+     os.remove(local_zip)
+     return dest_dir
+
+
+ def _unzip_inner_archives(root: str) -> None:
+     """Find and extract any zip files inside root (e.g., images.zip)."""
+     for dirpath, _dirnames, filenames in os.walk(root):
+         for fn in filenames:
+             if fn.lower().endswith(".zip"):
+                 zpath = os.path.join(dirpath, fn)
+                 try:
+                     with zipfile.ZipFile(zpath, "r") as zf:
+                         zf.extractall(dirpath)
+                     # Remove the archive after extraction to save disk space
+                     try:
+                         os.remove(zpath)
+                     except Exception:
+                         pass
+                 except Exception as e:  # pragma: no cover
+                     print(f"Failed to unzip inner archive {zpath}: {e}")
+
+
+ def _ensure_images_dir(root: str) -> None:
+     """Ensure a stable images/ path exists under root. Create a symlink if needed."""
+     images_root = os.path.join(root, "images")
+     if os.path.isdir(images_root):
+         return
+     # Look for a directory containing many jpg/png files
+     candidate_dirs: List[str] = []
+     for dirpath, dirnames, filenames in os.walk(root):
+         if dirpath == root:
+             # Skip root-level files; look deeper
+             continue
+         img_files = [f for f in filenames if f.lower().endswith((".jpg", ".jpeg", ".png"))]
+         if len(img_files) > 1000:  # heuristic: a big image folder
+             candidate_dirs.append(dirpath)
+     # Prefer the shallowest candidate
+     candidate_dirs.sort(key=lambda p: len(Path(p).parts))
+     if candidate_dirs:
+         src = candidate_dirs[0]
+         try:
+             os.symlink(src, images_root)
+             print(f"Created images symlink: {images_root} -> {src}")
+         except Exception:
+             # Fallback: create an empty folder (training will fail fast if images are missing)
+             os.makedirs(images_root, exist_ok=True)
+     else:
+         os.makedirs(images_root, exist_ok=True)
+
+
+ def ensure_dataset_ready() -> Optional[str]:
+     """
+     Ensure the Polyvore dataset is present locally.
+     Priority:
+       1) If POLYVORE_ROOT exists and has splits, return it
+       2) Try a Hugging Face dataset repo (defaults to Stylique/Polyvore if not set)
+       3) If DATA_ZIP_URL is set, download and unzip
+       4) Try KaggleHub (best-effort)
+     Returns the resolved root path, or None if nothing could be prepared.
+     """
+     root = os.getenv("POLYVORE_ROOT", "./data/Polyvore")
+     auto_fetch = os.getenv("AUTO_FETCH_DATA", "true").lower() == "true"
+     Path(root).mkdir(parents=True, exist_ok=True)
+
+     # Already prepared?
+     if os.path.isdir(os.path.join(root, "splits")):
+         _unzip_inner_archives(root)
+         _ensure_images_dir(root)
+         return root
+     if not auto_fetch:
+         return None
+
+     # Try the HF dataset repo
+     repo = os.getenv("HF_DATASET_REPO", "Stylique/Polyvore")
+     if repo and snapshot_download is not None:
+         try:
+             snapshot_download(repo, repo_type="dataset", local_dir=root)
+             _unzip_inner_archives(root)
+             _ensure_images_dir(root)
+             # If splits are not provided, the caller prepares them
+             return root
+         except Exception as e:  # pragma: no cover
+             print(f"HF dataset download failed: {e}")
+
+     # Try a direct ZIP URL
+     zip_url = os.getenv("DATA_ZIP_URL")
+     if zip_url:
+         try:
+             _download_zip(zip_url, root)
+             _unzip_inner_archives(root)
+             _ensure_images_dir(root)
+         except Exception as e:  # pragma: no cover
+             print(f"ZIP download failed: {e}")
+             return None
+
+     # Try KaggleHub (no Kaggle keys required for public datasets)
+     if kagglehub is not None and KaggleDatasetAdapter is not None:
+         try:
+             # Loading one file triggers a local download of the dataset.
+             # POLYVORE_FILE_PATH can select a specific CSV/JSON.
+             file_path = os.getenv("POLYVORE_FILE_PATH", "")
+             kagglehub.load_dataset(
+                 KaggleDatasetAdapter.PANDAS,
+                 "dnepozitek/polyvore-outfits",
+                 file_path,
+             )
+             # KaggleHub caches under ~/.cache/kagglehub/datasets/<slug>/...; copy to root if needed.
+             # For simplicity, run the prepare script with POLYVORE_ROOT pointing at the extracted images.
+             _unzip_inner_archives(root)
+             _ensure_images_dir(root)
+         except Exception as e:  # pragma: no cover
+             print(f"KaggleHub download failed: {e}")
+
+     return root
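A sketch of the environment-variable contract, assuming the defaults above:

    # Sketch: ensure_dataset_ready() reads its configuration from the environment at call time.
    import os
    os.environ["POLYVORE_ROOT"] = "./data/Polyvore"
    os.environ["AUTO_FETCH_DATA"] = "false"  # disable network fetches

    from utils.data_fetch import ensure_dataset_ready

    root = ensure_dataset_ready()
    print(root)  # "./data/Polyvore" if root/splits already exists, otherwise None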
utils/export.py ADDED
@@ -0,0 +1,37 @@
+ import os
+
+ import torch
+
+
+ def ensure_export_dir(path: str) -> str:
+     os.makedirs(path, exist_ok=True)
+     return path
+
+
+ def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
+     model.eval()
+     traced = torch.jit.trace(model, example_inputs)
+     torch.jit.save(traced, out_path)
+     return out_path
+
+
+ def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str:
+     model.eval()
+     torch.onnx.export(
+         model,
+         example_inputs,
+         out_path,
+         export_params=True,
+         opset_version=opset,
+         do_constant_folding=True,
+         input_names=["input"],
+         output_names=["output"],
+         dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
+     )
+     return out_path
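A sketch of exporting the item embedder with these helpers; the output paths are hypothetical, and the example input fixes the spatial size while batch stays dynamic for ONNX:

    # Sketch: TorchScript and ONNX export of the embedder.
    import torch
    from models.resnet_embedder import ResNetItemEmbedder
    from utils.export import ensure_export_dir, export_onnx, export_torchscript

    export_dir = ensure_export_dir("models/exports")
    model = ResNetItemEmbedder(embedding_dim=512, pretrained=False).eval()
    example = torch.randn(1, 3, 224, 224)

    export_torchscript(model, example, f"{export_dir}/resnet_item_embedder.ts")
    export_onnx(model, example, f"{export_dir}/resnet_item_embedder.onnx")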
utils/transforms.py ADDED
@@ -0,0 +1,31 @@
+ import torchvision.transforms as T
+
+
+ def build_inference_transform(image_size: int = 224) -> T.Compose:
+     return T.Compose(
+         [
+             T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC),
+             T.ToTensor(),
+             T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ]
+     )
+
+
+ def build_train_transforms(image_size: int = 224) -> T.Compose:
+     return T.Compose(
+         [
+             T.Resize(int(image_size * 1.1), interpolation=T.InterpolationMode.BICUBIC),
+             T.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(0.75, 1.33)),
+             T.RandomHorizontalFlip(),
+             T.ColorJitter(0.2, 0.2, 0.2, 0.1),
+             T.ToTensor(),
+             T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ]
+     )
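Both pipelines emit normalized (3, 224, 224) tensors; only the training one adds random augmentation. A sketch:

    # Sketch: applying the two pipelines to a stand-in wardrobe photo.
    from PIL import Image
    from utils.transforms import build_inference_transform, build_train_transforms

    img = Image.new("RGB", (640, 480), "white")  # stand-in for a real item image
    x_infer = build_inference_transform()(img)
    x_train = build_train_transforms()(img)
    print(x_infer.shape, x_train.shape)  # torch.Size([3, 224, 224]) for both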