Ali Mohsin committed
Commit 4716563 · Parent(s): 4ef6571
initial starter
Browse files
- Dockerfile +31 -0
- README.md +1 -5
- app.py +263 -0
- configs/default.yaml +33 -0
- data/__init__.py +5 -0
- data/polyvore.py +127 -0
- inference.py +156 -0
- models/__init__.py +5 -0
- models/resnet_embedder.py +35 -0
- models/vit_outfit.py +40 -0
- requirements.txt +17 -0
- scripts/prepare_polyvore.py +187 -0
- train_resnet.py +80 -0
- train_vit_triplet.py +145 -0
- utils/__init__.py +5 -0
- utils/data_fetch.py +148 -0
- utils/export.py +37 -0
- utils/transforms.py +31 -0
Dockerfile
ADDED
@@ -0,0 +1,31 @@
FROM python:3.11-slim

ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    HF_HUB_ENABLE_HF_TRANSFER=1

RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    git \
    curl \
    ca-certificates \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY recommendation/requirements.txt /app/requirements.txt
RUN pip install --upgrade pip && pip install -r /app/requirements.txt

COPY recommendation /app/

EXPOSE 8000
EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
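Once the image is running with the default CMD, the FastAPI service answers on port 8000. A minimal smoke test in Python (host and published port are assumptions about how the container was started):

import requests

# Hit the /health endpoint served by the uvicorn CMD above.
resp = requests.get("http://localhost:8000/health", timeout=10)
resp.raise_for_status()
print(resp.json())  # e.g. {"status": "ok", "device": "cpu", "resnet": "resnet_v1", "vit": "vit_v1"}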
README.md
CHANGED
@@ -1,4 +1,3 @@
----
 title: Recomendation
 emoji: 🏆
 colorFrom: purple
@@ -6,7 +5,4 @@ colorTo: green
 sdk: gradio
 sdk_version: 5.44.1
 app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+pinned: false
app.py
ADDED
@@ -0,0 +1,263 @@
import os
import base64
import io
from typing import List, Optional, Any, Dict

import gradio as gr
import numpy as np
import requests
import torch
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from PIL import Image
from starlette.staticfiles import StaticFiles
import threading

from inference import InferenceService
from utils.data_fetch import ensure_dataset_ready


AI_API_KEY = os.getenv("AI_API_KEY")


def require_api_key(x_api_key: Optional[str]):
    if AI_API_KEY and x_api_key != AI_API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")


class EmbedRequest(BaseModel):
    image_urls: Optional[List[str]] = None
    images_base64: Optional[List[str]] = None


class Item(BaseModel):
    id: str
    embedding: Optional[List[float]] = None
    category: Optional[str] = None
    image_url: Optional[str] = None


class ComposeRequest(BaseModel):
    items: List[Item]
    context: Optional[Dict[str, Any]] = None


app = FastAPI(title="Dressify Recommendation Service")
service = InferenceService()

# Non-blocking bootstrap: fetch data, prepare splits, and train if needed in background
BOOT_STATUS = "idle"


def _background_bootstrap():
    global BOOT_STATUS
    try:
        BOOT_STATUS = "preparing-dataset"
        ds_root = ensure_dataset_ready()
        if not ds_root:
            BOOT_STATUS = "dataset-not-prepared"
            return

        # Prepare 70/10/10 splits if missing
        splits_dir = os.path.join(ds_root, "splits")
        need_prepare = not (
            os.path.isfile(os.path.join(splits_dir, "train.json")) or
            os.path.isfile(os.path.join(splits_dir, "outfit_triplets_train.json"))
        )
        if need_prepare:
            BOOT_STATUS = "creating-splits"
            os.makedirs(splits_dir, exist_ok=True)
            from scripts.prepare_polyvore import main as prepare_main
            os.environ.setdefault("PYTHONWARNINGS", "ignore")
            import sys
            argv_bak = sys.argv
            try:
                sys.argv = ["prepare_polyvore.py", "--root", ds_root, "--random_split"]
                prepare_main()
            finally:
                sys.argv = argv_bak

        # Train if checkpoints are absent
        export_dir = os.getenv("EXPORT_DIR", "models/exports")
        os.makedirs(export_dir, exist_ok=True)
        resnet_ckpt = os.path.join(export_dir, "resnet_item_embedder_best.pth")
        vit_ckpt = os.path.join(export_dir, "vit_outfit_model_best.pth")
        import subprocess
        if not os.path.exists(resnet_ckpt):
            BOOT_STATUS = "training-resnet"
            subprocess.run([
                "python", "train_resnet.py", "--data_root", ds_root, "--epochs", "3",
                "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
            ], check=False)
        if not os.path.exists(vit_ckpt):
            BOOT_STATUS = "training-vit"
            subprocess.run([
                "python", "train_vit_triplet.py", "--data_root", ds_root, "--epochs", "3",
                "--export", os.path.join(export_dir, "vit_outfit_model.pth")
            ], check=False)
        service.reload_models()
        BOOT_STATUS = "ready"
    except Exception as e:
        BOOT_STATUS = f"error: {e}"


threading.Thread(target=_background_bootstrap, daemon=True).start()


@app.get("/health")
def health() -> dict:
    return {"status": "ok", "device": service.device, "resnet": service.resnet_version, "vit": service.vit_version}


@app.post("/embed")
def embed(req: EmbedRequest, x_api_key: Optional[str] = Header(None)) -> dict:
    require_api_key(x_api_key)
    images: List[Image.Image] = []
    if req.image_urls:
        for url in req.image_urls:
            resp = requests.get(url, timeout=20)
            resp.raise_for_status()
            images.append(Image.open(io.BytesIO(resp.content)).convert("RGB"))
    if req.images_base64:
        for b64 in req.images_base64:
            images.append(Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB"))
    if not images:
        raise HTTPException(status_code=400, detail="No images provided")
    embs = service.embed_images(images)
    return {"embeddings": [e.tolist() for e in embs], "model_version": service.resnet_version}


@app.post("/compose")
def compose(req: ComposeRequest, x_api_key: Optional[str] = Header(None)) -> dict:
    require_api_key(x_api_key)
    items = [
        {
            "id": it.id,
            "embedding": np.array(it.embedding, dtype=np.float32) if it.embedding is not None else None,
            "category": it.category,
            "image_url": it.image_url,
        }
        for it in req.items
    ]
    outfits = service.compose_outfits(items, context=req.context or {})
    return {"outfits": outfits, "version": service.vit_version}


@app.get("/artifacts")
def artifacts() -> dict:
    # List exported model artifacts for download
    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    files = []
    if os.path.isdir(export_dir):
        for fn in os.listdir(export_dir):
            if fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")):
                files.append({
                    "name": fn,
                    "path": f"{export_dir}/{fn}",
                    "url": f"/files/{fn}",
                })
    return {"artifacts": files}


# --------- Gradio UI ---------

def _load_images_from_files(files: List[str]) -> List[Image.Image]:
    images: List[Image.Image] = []
    for fp in files:
        try:
            with Image.open(fp) as im:
                images.append(im.convert("RGB"))
        except Exception:
            continue
    return images


def gradio_embed(files: List[str]):
    if not files:
        return "[]"
    images = _load_images_from_files(files)
    if not images:
        return "[]"
    embs = service.embed_images(images)
    return str([e.tolist() for e in embs])


def gradio_compose(files: List[str], occasion: str, weather: str, num_outfits: int):
    if not files:
        return ""  # the output component is a Textbox, so return a string
    images = _load_images_from_files(files)
    if not images:
        return ""
    embs = service.embed_images(images)
    items = [
        {"id": f"item_{i}", "embedding": embs[i], "category": None, "image_url": None}
        for i in range(len(embs))
    ]
    results = service.compose_outfits(items, context={"occasion": occasion, "weather": weather, "num_outfits": int(num_outfits)})
    # Render as a simple text summary
    lines = []
    for r in results:
        lines.append(f"score={r['score']:.3f}, items={r['item_ids']}")
    return "\n".join(lines)


with gr.Blocks() as demo:
    gr.Markdown("# Dressify Recommendations – HF Test UI")
    with gr.Tab("Embed"):
        inp = gr.Files(label="Upload Items (multiple images)")
        out = gr.Textbox(label="Embeddings (JSON)")
        btn = gr.Button("Compute Embeddings")
        btn.click(fn=gradio_embed, inputs=inp, outputs=out)
    with gr.Tab("Compose"):
        inp2 = gr.Files(label="Upload Wardrobe (multiple images)")
        occasion = gr.Dropdown(choices=["casual", "business", "formal", "sport"], value="casual", label="Occasion")
        weather = gr.Dropdown(choices=["any", "hot", "mild", "cold", "rain"], value="any", label="Weather")
        num_outfits = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Num outfits")
        out2 = gr.Textbox(label="Recommendations")
        btn2 = gr.Button("Generate")
        btn2.click(fn=gradio_compose, inputs=[inp2, occasion, weather, num_outfits], outputs=out2)
    with gr.Tab("Downloads"):
        gr.Markdown("Download trained artifacts from models/exports")
        file_list = gr.JSON(label="Artifacts JSON")
        def list_artifacts_for_ui():
            export_dir = os.getenv("EXPORT_DIR", "models/exports")
            files = []
            if os.path.isdir(export_dir):
                for fn in os.listdir(export_dir):
                    if fn.endswith((".pth", ".pt", ".onnx", ".ts", ".json")):
                        files.append({
                            "name": fn,
                            "path": f"{export_dir}/{fn}",
                            "url": f"/files/{fn}",
                        })
            return {"artifacts": files}
        refresh = gr.Button("Refresh")
        refresh.click(fn=list_artifacts_for_ui, inputs=[], outputs=file_list)
    with gr.Tab("Status"):
        gr.Markdown("Startup & training status")
        status = gr.Textbox(label="Status", value=lambda: BOOT_STATUS)
        refresh_status = gr.Button("Refresh Status")
        refresh_status.click(fn=lambda: BOOT_STATUS, inputs=[], outputs=status)


try:
    # Mount Gradio onto the FastAPI root path
    app = gr.mount_gradio_app(app, demo, path="/")
except Exception:
    # If mounting fails in certain runners, we still want FastAPI to be available
    pass

# Mount static files for direct artifact download
export_dir = os.getenv("EXPORT_DIR", "models/exports")
os.makedirs(export_dir, exist_ok=True)
try:
    app.mount("/files", StaticFiles(directory=export_dir), name="files")
except Exception:
    pass


if __name__ == "__main__":
    # Local testing
    demo.launch()
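A minimal client sketch for the two POST endpoints (the base URL, API key, and file names are assumptions; the X-API-Key header is only enforced when AI_API_KEY is set):

import base64
import requests

BASE = "http://localhost:8000"       # assumption: local deployment
HEADERS = {"X-API-Key": "changeme"}  # assumption: matches AI_API_KEY if one is set

def b64(path: str) -> str:
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode()

# 1) Embed two wardrobe images shipped as base64 (file names are placeholders).
emb = requests.post(f"{BASE}/embed",
                    json={"images_base64": [b64("shirt.jpg"), b64("jeans.jpg")]},
                    headers=HEADERS).json()

# 2) Compose outfits from the returned embeddings (compose needs at least two items).
items = [
    {"id": "shirt", "embedding": emb["embeddings"][0], "category": "shirt"},
    {"id": "jeans", "embedding": emb["embeddings"][1], "category": "jeans"},
]
out = requests.post(f"{BASE}/compose",
                    json={"items": items, "context": {"num_outfits": 3}},
                    headers=HEADERS).json()
print(out["outfits"])  # [{"item_ids": [...], "score": ...}, ...]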
configs/default.yaml
ADDED
@@ -0,0 +1,33 @@
# Default training/eval configuration

seed: 42
image_size: 224

resnet:
  backbone: resnet50
  embedding_dim: 512
  batch_size: 64
  epochs: 20
  lr: 0.001
  weight_decay: 0.0001
  use_imagenet_pretrained: true

vit:
  embedding_dim: 512
  num_layers: 6
  num_heads: 8
  ff_multiplier: 4
  dropout: 0.1
  batch_size: 64
  epochs: 30
  lr: 0.0005
  weight_decay: 0.05

paths:
  polyvore_root: /kaggle/input/polyvore-outfits
  resnet_ckpt: models/exports/resnet_item_embedder.pth
  vit_ckpt: models/exports/vit_outfit_model.pth
  export_dir: models/exports
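The config is plain YAML; a minimal loading sketch with PyYAML (pinned in requirements.txt). Note the training scripts currently read hyperparameters from argparse, so wiring this file into them is an assumption:

import yaml

with open("configs/default.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["resnet"]["embedding_dim"])  # 512
print(cfg["paths"]["export_dir"])      # models/exports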
data/__init__.py
ADDED
@@ -0,0 +1,5 @@
(empty package marker)
data/polyvore.py
ADDED
@@ -0,0 +1,127 @@
import os
import json
from typing import List, Dict, Any

import torch
from torch.utils.data import Dataset
from PIL import Image

from utils.transforms import build_train_transforms


class PolyvoreTripletDataset(Dataset):
    """
    Creates (anchor, positive, negative) image triplets for training the ResNet embedder.
    Assumes a JSON list or multiple files that describe compatible pairs/sets and item image paths.

    Expected structure (customize as needed):
      root/
        images/<item_id>.jpg
        splits/train.json  # [{"anchor": id, "positive": id, "negative": id}, ...]
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        super().__init__()
        self.root = root
        self.split = split
        self.transforms = build_train_transforms(image_size=image_size)
        triplet_path = os.path.join(root, "splits", f"{split}.json")
        if not os.path.exists(triplet_path):
            raise FileNotFoundError(f"Triplet file not found: {triplet_path}")
        with open(triplet_path, "r") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)

    def _load_image(self, item_id: str) -> Image.Image:
        # Customize if images are arranged differently
        img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
        if not os.path.exists(img_path):
            raise FileNotFoundError(img_path)
        return Image.open(img_path).convert("RGB")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        s = self.samples[idx]
        a = self._load_image(str(s["anchor"]))
        p = self._load_image(str(s["positive"]))
        n = self._load_image(str(s["negative"]))
        return self.transforms(a), self.transforms(p), self.transforms(n)


class PolyvoreOutfitDataset(Dataset):
    """
    Produces (tokens, label) where tokens is a sequence of item embeddings or images preprocessed downstream.
    For simplicity here we return a list of image tensors to be embedded externally or pre-embedded offline.

    Expected structure:
      root/
        images/<item_id>.jpg
        splits/outfits_train.json  # [{"items": [id1, id2, ...], "label": 1}, ...]
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        super().__init__()
        self.root = root
        self.split = split
        self.transforms = build_train_transforms(image_size=image_size)
        outfit_path = os.path.join(root, "splits", f"outfits_{split}.json")
        if not os.path.exists(outfit_path):
            raise FileNotFoundError(f"Outfit file not found: {outfit_path}")
        with open(outfit_path, "r") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)
        # Outfit slot constraints (at least upper, bottom, shoes, accessory) would be enforced
        # here if category metadata were available; without it we rely on item count >= 3 and
        # let the model learn, so the samples are kept as-is.

    def _load_image(self, item_id: str) -> Image.Image:
        img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
        if not os.path.exists(img_path):
            raise FileNotFoundError(img_path)
        return Image.open(img_path).convert("RGB")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        s = self.samples[idx]
        imgs = [self.transforms(self._load_image(str(i))) for i in s["items"]]
        label = torch.tensor(int(s.get("label", 1)), dtype=torch.long)
        # Returns a list of tensors; the training loop can embed them and pack to (N, D)
        return imgs, label


class PolyvoreOutfitTripletDataset(Dataset):
    """
    Outfit-level triplets for ViT triplet training: (good1, good2, bad).
    Expects file `outfit_triplets_<split>.json` with entries:
      {"good_a": [id...], "good_b": [id...], "bad": [id...]}
    """

    def __init__(self, root: str, split: str = "train", image_size: int = 224) -> None:
        super().__init__()
        self.root = root
        self.split = split
        self.transforms = build_train_transforms(image_size=image_size)
        trip_path = os.path.join(root, "splits", f"outfit_triplets_{split}.json")
        if not os.path.exists(trip_path):
            raise FileNotFoundError(f"Outfit triplet file not found: {trip_path}")
        with open(trip_path, "r") as f:
            self.samples: List[Dict[str, Any]] = json.load(f)

    def _load_image(self, item_id: str) -> Image.Image:
        img_path = os.path.join(self.root, "images", f"{item_id}.jpg")
        if not os.path.exists(img_path):
            raise FileNotFoundError(img_path)
        return Image.open(img_path).convert("RGB")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        s = self.samples[idx]
        ga = [self.transforms(self._load_image(str(i))) for i in s["good_a"]]
        gb = [self.transforms(self._load_image(str(i))) for i in s["good_b"]]
        bd = [self.transforms(self._load_image(str(i))) for i in s["bad"]]
        return ga, gb, bd
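A usage sketch for the triplet dataset (the dataset root is an assumption; it must follow the images/ + splits/ layout from the docstrings):

from torch.utils.data import DataLoader
from data.polyvore import PolyvoreTripletDataset

ds = PolyvoreTripletDataset("/data/polyvore", split="train")  # assumed path
loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=4)
anchor, positive, negative = next(iter(loader))
print(anchor.shape)  # torch.Size([32, 3, 224, 224]) with the default image_size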
inference.py
ADDED
@@ -0,0 +1,156 @@
import os
from typing import List, Dict, Any

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from utils.transforms import build_inference_transform
from models.resnet_embedder import ResNetItemEmbedder
from models.vit_outfit import OutfitCompatibilityModel


def _get_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class InferenceService:
    def __init__(self) -> None:
        self.device = _get_device()
        self.transform = build_inference_transform()
        self.embed_dim = int(os.getenv("EMBED_DIM", "512"))
        self.resnet_version = "resnet_v1"
        self.vit_version = "vit_v1"

        self.resnet = self._load_resnet().to(self.device).eval()
        self.vit = self._load_vit().to(self.device).eval()

        for m in [self.resnet, self.vit]:
            for p in m.parameters():
                p.requires_grad_(False)

    def _load_resnet(self) -> nn.Module:
        strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
        ckpt_path = os.getenv("RESNET_CHECKPOINT", "models/exports/resnet_item_embedder.pth")
        model = ResNetItemEmbedder(embedding_dim=self.embed_dim)
        if strategy == "random":
            return model
        # Prefer the best checkpoint if present
        best_path = os.path.join(os.path.dirname(ckpt_path), "resnet_item_embedder_best.pth")
        ckpt_to_use = best_path if os.path.exists(best_path) else ckpt_path
        if os.path.exists(ckpt_to_use):
            state = torch.load(ckpt_to_use, map_location="cpu")
            # Accept either a bare state_dict or {"state_dict": ...}
            state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
            model.load_state_dict(state_dict, strict=False)
        return model

    def _load_vit(self) -> nn.Module:
        strategy = os.getenv("MODEL_LOAD_STRATEGY", "state_dict")
        ckpt_path = os.getenv("VIT_CHECKPOINT", "models/exports/vit_outfit_model.pth")
        model = OutfitCompatibilityModel(embedding_dim=self.embed_dim)
        if strategy == "random":
            return model
        best_path = os.path.join(os.path.dirname(ckpt_path), "vit_outfit_model_best.pth")
        ckpt_to_use = best_path if os.path.exists(best_path) else ckpt_path
        if os.path.exists(ckpt_to_use):
            state = torch.load(ckpt_to_use, map_location="cpu")
            state_dict = state.get("state_dict", state) if isinstance(state, dict) else state
            model.load_state_dict(state_dict, strict=False)
        return model

    def reload_models(self) -> None:
        """Reload weights from current checkpoint locations (used after background training)."""
        self.resnet = self._load_resnet().to(self.device).eval()
        self.vit = self._load_vit().to(self.device).eval()
        for m in [self.resnet, self.vit]:
            for p in m.parameters():
                p.requires_grad_(False)

    @torch.inference_mode()
    def embed_images(self, images: List[Image.Image]) -> List[np.ndarray]:
        if len(images) == 0:
            return []
        batch = torch.stack([self.transform(img) for img in images]).to(self.device)
        emb = self.resnet(batch)
        emb = nn.functional.normalize(emb, dim=-1)
        return [e.detach().cpu().numpy().astype(np.float32) for e in emb]

    @torch.inference_mode()
    def compose_outfits(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Ensure every item has an embedding
        proc_items: List[Dict[str, Any]] = []
        for it in items:
            e = it.get("embedding")
            if e is None and it.get("image") is not None:
                # Not used in the Gradio path, but kept for completeness
                emb = self.embed_images([it["image"]])[0]
            elif e is None:
                # Skip items that have neither embedding nor image
                continue
            else:
                emb = np.asarray(e, dtype=np.float32)
            proc_items.append({"id": it.get("id"), "embedding": emb, "category": it.get("category")})

        if len(proc_items) < 2:
            return []

        # Candidate generation: enforce minimum slots (upper, bottom, shoes, accessory) if categories are provided
        rng = np.random.default_rng(42)
        num_outfits = int(context.get("num_outfits", 3))
        min_size, max_size = 3, 5
        candidates: List[List[int]] = []
        ids = list(range(len(proc_items)))

        # Slot-aware sampling if categories exist
        def cat_of(i: int) -> str:
            return (proc_items[i].get("category") or "").lower()

        uppers = [i for i in ids if any(k in cat_of(i) for k in ["top", "shirt", "tshirt", "blouse", "jacket", "hoodie"])]
        bottoms = [i for i in ids if any(k in cat_of(i) for k in ["pant", "trouser", "jean", "skirt", "short"])]
        shoes = [i for i in ids if "shoe" in cat_of(i)]
        accs = [i for i in ids if any(k in cat_of(i) for k in ["watch", "belt", "ring", "bracelet", "accessor"])]

        for _ in range(num_outfits * 10):
            if uppers and bottoms and shoes and accs:
                subset = [rng.choice(uppers).item(), rng.choice(bottoms).item(), rng.choice(shoes).item(), rng.choice(accs).item()]
                # Optionally add one more random item
                remain = list(set(ids) - set(subset))
                if remain and rng.random() < 0.5:
                    subset.append(rng.choice(remain).item())
            else:
                # Clamp so we never sample more items than exist (avoids a crash on tiny wardrobes)
                k = int(min(rng.integers(min_size, max_size + 1), len(ids)))
                subset = rng.choice(ids, size=k, replace=False).tolist()
            candidates.append(subset)

        # Score candidates with the ViT compatibility model
        def score_subset(idx_subset: List[int]) -> float:
            embs = torch.tensor(np.stack([proc_items[i]["embedding"] for i in idx_subset]), dtype=torch.float32, device=self.device)
            embs = embs.unsqueeze(0)  # (1, N, D)
            return float(self.vit.score_compatibility(embs).item())

        scored = [(subset, score_subset(subset)) for subset in candidates]
        scored.sort(key=lambda x: x[1], reverse=True)
        topk = scored[:num_outfits]
        return [
            {
                "item_ids": [proc_items[i]["id"] for i in subset],
                "score": float(score),
            }
            for subset, score in topk
        ]
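Driving the service directly, without FastAPI (image paths are placeholders; with no checkpoints on disk the models fall back to their initial weights):

from PIL import Image
from inference import InferenceService

svc = InferenceService()
imgs = [Image.open(p).convert("RGB") for p in ["shirt.jpg", "jeans.jpg", "shoes.jpg"]]
embs = svc.embed_images(imgs)  # list of L2-normalized float32 vectors, one per image
items = [{"id": f"item_{i}", "embedding": e, "category": None} for i, e in enumerate(embs)]
print(svc.compose_outfits(items, context={"num_outfits": 2}))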
models/__init__.py
ADDED
@@ -0,0 +1,5 @@
(empty package marker)
models/resnet_embedder.py
ADDED
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
import torchvision.models as tvm


class ResNetItemEmbedder(nn.Module):
    def __init__(self, embedding_dim: int = 512, backbone: str = "resnet50", pretrained: bool = True) -> None:
        super().__init__()
        if backbone == "resnet50":
            model = tvm.resnet50(weights=tvm.ResNet50_Weights.DEFAULT if pretrained else None)
            feat_dim = 2048
        elif backbone == "resnet101":
            model = tvm.resnet101(weights=tvm.ResNet101_Weights.DEFAULT if pretrained else None)
            feat_dim = 2048
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        # Remove the classifier, keep the global-average-pooled output
        modules = list(model.children())[:-1]  # drop fc
        self.backbone = nn.Sequential(*modules)
        self.proj = nn.Linear(feat_dim, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, 3, H, W)
        feats = self.backbone(x)  # (B, C, 1, 1)
        feats = feats.flatten(1)  # (B, C)
        emb = self.proj(feats)    # (B, D)
        return emb
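A quick shape check for the embedder (random input; torchvision downloads ImageNet weights unless pretrained=False):

import torch
from models.resnet_embedder import ResNetItemEmbedder

m = ResNetItemEmbedder(embedding_dim=512, pretrained=False).eval()
with torch.no_grad():
    out = m(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 512])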
models/vit_outfit.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn


class OutfitCompatibilityModel(nn.Module):
    def __init__(self, embedding_dim: int = 512, num_layers: int = 4, num_heads: int = 8, ff_multiplier: int = 4, dropout: float = 0.1) -> None:
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=ff_multiplier * embedding_dim,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.compatibility_head = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, embedding_dim // 2),
            nn.GELU(),
            nn.Linear(embedding_dim // 2, 1),
        )

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (B, N, D) sequence of item embeddings
        h = self.encoder(tokens)                 # (B, N, D)
        pooled = h.mean(dim=1)                   # (B, D)
        score = self.compatibility_head(pooled)  # (B, 1)
        return score.squeeze(-1)                 # (B,)

    def score_compatibility(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.forward(tokens)
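The model scores a batch of outfits, each given as a sequence of item embeddings; a shape check:

import torch
from models.vit_outfit import OutfitCompatibilityModel

m = OutfitCompatibilityModel(embedding_dim=512).eval()
tokens = torch.randn(2, 4, 512)  # 2 outfits × 4 items × 512-d item embeddings
with torch.no_grad():
    print(m.score_compatibility(tokens).shape)  # torch.Size([2])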
requirements.txt
ADDED
@@ -0,0 +1,17 @@
fastapi==0.111.0
uvicorn[standard]==0.30.1
gradio==5.44.1
pydantic==2.8.2
python-multipart==0.0.9
Pillow==10.4.0
numpy==2.0.1
torch==2.3.1
torchvision==0.18.1
timm==1.0.8
requests==2.32.3
scikit-learn==1.5.1
faiss-cpu==1.8.0.post1
pyyaml==6.0.2
open-clip-torch==2.26.1
huggingface-hub==0.24.5
kagglehub[pandas-datasets]
scripts/prepare_polyvore.py
ADDED
@@ -0,0 +1,187 @@
import os
import json
import random
import argparse
from pathlib import Path
from typing import Dict, Any, List, Set


def load_outfits_json(root: str, split: str) -> List[Dict[str, Any]]:
    candidates = [
        os.path.join(root, f"{split}.json"),
        os.path.join(root, f"{split}_no_dup.json"),
        os.path.join(root, "splits", f"{split}.json"),
        os.path.join(root, "splits", f"{split}_no_dup.json"),
    ]
    for p in candidates:
        if os.path.exists(p):
            with open(p, "r") as f:
                data = json.load(f)
            # Expect a list where each entry has an "items" key listing item ids
            return data
    raise FileNotFoundError(f"Could not find {split}.json in {root} or {root}/splits")


def try_load_any_outfits(root: str) -> List[Dict[str, Any]]:
    candidates = [
        os.path.join(root, "outfits.json"),
        os.path.join(root, "all.json"),
        os.path.join(root, "data.json"),
    ]
    for p in candidates:
        if os.path.exists(p):
            with open(p, "r") as f:
                return json.load(f)
    # As a last resort, merge the available splits
    merged: List[Dict[str, Any]] = []
    for sp in ["train", "valid", "test"]:
        try:
            merged.extend(load_outfits_json(root, sp))
        except FileNotFoundError:
            continue
    return merged


def collect_all_items(outfits: List[Dict[str, Any]]) -> List[str]:
    s: Set[str] = set()
    for o in outfits:
        for it in o.get("items", []):
            s.add(str(it))
    return sorted(s)


def build_triplets(outfits: List[Dict[str, Any]], all_items: List[str], max_triplets: int = 200000) -> List[Dict[str, str]]:
    rng = random.Random(42)
    all_items_set = set(all_items)
    triplets: List[Dict[str, str]] = []
    for o in outfits:
        items = [str(i) for i in o.get("items", [])]
        if len(items) < 2:
            continue
        local_set = set(items)
        for i in range(len(items) - 1):
            a = items[i]
            p = items[i + 1]
            # Pick a negative that is not in this outfit
            negatives = list(all_items_set - local_set)
            if not negatives:
                continue
            n = rng.choice(negatives)
            triplets.append({"anchor": a, "positive": p, "negative": n})
            if len(triplets) >= max_triplets:
                return triplets
    return triplets


def build_outfit_pairs(outfits: List[Dict[str, Any]], num_negatives_per_pos: int = 1) -> List[Dict[str, Any]]:
    rng = random.Random(123)
    all_items = collect_all_items(outfits)
    all_set = set(all_items)
    pairs: List[Dict[str, Any]] = []
    # Positive samples
    for o in outfits:
        items = [str(i) for i in o.get("items", [])]
        if len(items) < 2:
            continue
        pairs.append({"items": items, "label": 1})
        # Negatives by corrupting one item
        for _ in range(num_negatives_per_pos):
            if not items:
                continue
            idx = rng.randrange(len(items))
            neg_pool = list(all_set - set(items))
            if not neg_pool:
                continue
            neg_item = rng.choice(neg_pool)
            neg_items = items.copy()
            neg_items[idx] = neg_item
            pairs.append({"items": neg_items, "label": 0})
    return pairs


def build_outfit_triplets(outfits: List[Dict[str, Any]], num_triplets: int = 200000) -> List[Dict[str, Any]]:
    rng = random.Random(999)
    # Collect only valid positive outfits (len >= 3, or ideally slot-complete)
    pos = [o for o in outfits if len(o.get("items", [])) >= 3]
    all_items = collect_all_items(outfits)
    all_set = set(all_items)
    triplets: List[Dict[str, Any]] = []
    for _ in range(num_triplets):
        if len(pos) < 2:
            break
        ga = rng.choice(pos)
        gb = rng.choice(pos)
        # Ensure ga != gb
        if ga is gb:
            continue
        # Create the bad outfit by corrupting one item in ga
        items_ga = [str(i) for i in ga.get("items", [])]
        if not items_ga:
            continue
        corrupt_idx = rng.randrange(len(items_ga))
        neg_pool = list(all_set - set(items_ga))
        if not neg_pool:
            continue
        neg_item = rng.choice(neg_pool)
        bad = items_ga.copy()
        bad[corrupt_idx] = neg_item
        triplets.append({"good_a": items_ga, "good_b": [str(i) for i in gb.get("items", [])], "bad": bad})
    return triplets


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", type=str, required=True, help="Polyvore dataset root")
    ap.add_argument("--out", type=str, default=None, help="Output directory for splits (default: <root>/splits)")
    ap.add_argument("--max_triplets", type=int, default=200000)
    ap.add_argument("--neg_per_pos", type=int, default=1)
    ap.add_argument("--random_split", action="store_true", help="Create 70/10/10 random split if official splits are missing")
    args = ap.parse_args()

    out_dir = args.out or os.path.join(args.root, "splits")
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    if args.random_split:
        all_outfits = try_load_any_outfits(args.root)
        if not all_outfits:
            raise FileNotFoundError("No outfits found to split. Provide official splits or an outfits.json file.")
        rng = random.Random(2024)
        rng.shuffle(all_outfits)
        n = len(all_outfits)
        n_train = int(0.7 * n)
        n_valid = int(0.1 * n)
        splits = {
            "train": all_outfits[:n_train],
            "valid": all_outfits[n_train:n_train + n_valid],
            "test": all_outfits[n_train + n_valid:],
        }
    else:
        splits = {}
        for split in ["train", "valid", "test"]:
            try:
                splits[split] = load_outfits_json(args.root, split)
            except FileNotFoundError as e:
                print(f"Skipping {split}: {e}")
                splits[split] = []

    for split, outfits in splits.items():
        if not outfits:
            continue
        all_items = collect_all_items(outfits)
        triplets = build_triplets(outfits, all_items, max_triplets=args.max_triplets)
        pairs = build_outfit_pairs(outfits, num_negatives_per_pos=args.neg_per_pos)

        with open(os.path.join(out_dir, f"{split}.json"), "w") as f:
            json.dump(triplets, f)
        with open(os.path.join(out_dir, f"outfits_{split}.json"), "w") as f:
            json.dump(pairs, f)
        triplets_o = build_outfit_triplets(outfits)
        with open(os.path.join(out_dir, f"outfit_triplets_{split}.json"), "w") as f:
            json.dump(triplets_o, f)
        print(f"Wrote {split}: {len(triplets)} item-triplets, {len(pairs)} outfit-pairs, {len(triplets_o)} outfit-triplets -> {out_dir}")


if __name__ == "__main__":
    main()
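The script is normally run from the CLI; app.py also drives it in-process, which looks like this (the dataset path is an assumption):

import sys
from scripts.prepare_polyvore import main as prepare_main

# Equivalent to: python scripts/prepare_polyvore.py --root /data/polyvore --random_split
sys.argv = ["prepare_polyvore.py", "--root", "/data/polyvore", "--random_split"]
prepare_main()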
train_resnet.py
ADDED
@@ -0,0 +1,80 @@
import os
import json
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from data.polyvore import PolyvoreTripletDataset
from models.resnet_embedder import ResNetItemEmbedder
from utils.export import ensure_export_dir


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/kaggle/input/polyvore-outfits"))
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--embedding_dim", type=int, default=512)
    p.add_argument("--out", type=str, default="models/exports/resnet_item_embedder.pth")
    return p.parse_args()


def main() -> None:
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")

    dataset = PolyvoreTripletDataset(args.data_root, split="train")
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    model = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    criterion = nn.TripletMarginLoss(margin=0.2, p=2)

    export_dir = ensure_export_dir(os.path.dirname(args.out) or "models/exports")
    best_loss = float("inf")
    history = []
    for epoch in range(args.epochs):
        model.train()
        running = 0.0
        steps = 0
        for batch in loader:
            # Each batch is (anchor, positive, negative)
            anchor, positive, negative = batch
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)
            emb_a = model(anchor)
            emb_p = model(positive)
            emb_n = model(negative)
            loss = criterion(emb_a, emb_p, emb_n)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            running += loss.item()
            steps += 1
        avg_loss = running / max(1, steps)
        out_path = args.out
        if not out_path.startswith("models/"):
            out_path = os.path.join(export_dir, os.path.basename(args.out))
        torch.save({"state_dict": model.state_dict()}, out_path)
        print(f"Epoch {epoch+1}/{args.epochs} avg_triplet_loss={avg_loss:.4f} saved -> {out_path}")
        history.append({"epoch": epoch + 1, "avg_triplet_loss": avg_loss})
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({"state_dict": model.state_dict()}, os.path.join(export_dir, "resnet_item_embedder_best.pth"))

    # Write metrics
    metrics_path = os.path.join(export_dir, "resnet_metrics.json")
    with open(metrics_path, "w") as f:
        json.dump({"best_triplet_loss": best_loss, "history": history}, f)


if __name__ == "__main__":
    main()
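Launching a short run the same way the bootstrap in app.py does (the dataset path is an assumption):

import subprocess

subprocess.run(
    ["python", "train_resnet.py", "--data_root", "/data/polyvore",
     "--epochs", "3", "--out", "models/exports/resnet_item_embedder.pth"],
    check=True,
)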
train_vit_triplet.py
ADDED
@@ -0,0 +1,145 @@
import os
import json
import argparse
from typing import List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from data.polyvore import PolyvoreOutfitTripletDataset
from models.vit_outfit import OutfitCompatibilityModel
from models.resnet_embedder import ResNetItemEmbedder
from utils.export import ensure_export_dir


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--data_root", type=str, default=os.getenv("POLYVORE_ROOT", "/kaggle/input/polyvore-outfits"))
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=5e-4)
    p.add_argument("--embedding_dim", type=int, default=512)
    p.add_argument("--triplet_margin", type=float, default=0.3)
    p.add_argument("--export", type=str, default="models/exports/vit_outfit_model.pth")
    p.add_argument("--eval_every", type=int, default=1)
    return p.parse_args()


def embed_outfit(imgs: List[torch.Tensor], embedder: ResNetItemEmbedder, device: str, max_len: int = 4) -> torch.Tensor:
    if len(imgs) == 0:
        return torch.zeros((max_len, embedder.proj.out_features), device=device)
    k = min(len(imgs), max_len)
    x = torch.stack(imgs[:k], dim=0).to(device)
    with torch.no_grad():
        e = embedder(x)  # (k, D)
    tokens = torch.zeros((max_len, e.shape[-1]), device=device)
    tokens[:k] = e
    return tokens


def main() -> None:
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")

    dataset = PolyvoreOutfitTripletDataset(args.data_root, split="train")

    def collate(batch):
        return batch  # variable-length outfits are handled inside the training loop

    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)

    model = OutfitCompatibilityModel(embedding_dim=args.embedding_dim).to(device)
    embedder = ResNetItemEmbedder(embedding_dim=args.embedding_dim).to(device).eval()
    for p in embedder.parameters():
        p.requires_grad_(False)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-2)
    triplet = nn.TripletMarginWithDistanceLoss(distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y), margin=args.triplet_margin)

    export_dir = ensure_export_dir(os.path.dirname(args.export) or "models/exports")
    best_loss = float("inf")
    hist = []
    for epoch in range(args.epochs):
        model.train()
        for batch in loader:
            # batch: List[(ga_imgs, gb_imgs, bd_imgs)]
            anchor_tokens = []
            positive_tokens = []
            negative_tokens = []
            for ga, gb, bd in batch:
                ta = embed_outfit(ga, embedder, device)
                tb = embed_outfit(gb, embedder, device)
                tn = embed_outfit(bd, embedder, device)
                anchor_tokens.append(ta.unsqueeze(0))
                positive_tokens.append(tb.unsqueeze(0))
                negative_tokens.append(tn.unsqueeze(0))
            A = torch.cat(anchor_tokens, dim=0)  # (B, N, D)
            P = torch.cat(positive_tokens, dim=0)
            N = torch.cat(negative_tokens, dim=0)

            # Outfit-level embeddings via the ViT encoder's pooled output
            ea = model.encoder(A).mean(dim=1)
            ep = model.encoder(P).mean(dim=1)
            en = model.encoder(N).mean(dim=1)
            loss = triplet(ea, ep, en)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

        # Simple validation using the valid split if available; for a true 70/10/10 split,
        # prepare_polyvore.py creates outfit_triplets_valid.json
        val_path = os.path.join(args.data_root, "splits", "outfit_triplets_valid.json")
        val_loss = None
        if os.path.exists(val_path) and (epoch + 1) % args.eval_every == 0:
            val_ds = PolyvoreOutfitTripletDataset(args.data_root, split="valid")
            val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=2, collate_fn=lambda x: x)
            model.eval()
            losses = []
            with torch.no_grad():
                for vbatch in val_loader:
                    anchor_tokens = []
                    positive_tokens = []
                    negative_tokens = []
                    for ga, gb, bd in vbatch:
                        ta = embed_outfit(ga, embedder, device)
                        tb = embed_outfit(gb, embedder, device)
                        tn = embed_outfit(bd, embedder, device)
                        anchor_tokens.append(ta.unsqueeze(0))
                        positive_tokens.append(tb.unsqueeze(0))
                        negative_tokens.append(tn.unsqueeze(0))
                    A = torch.cat(anchor_tokens, dim=0)
                    P = torch.cat(positive_tokens, dim=0)
                    N = torch.cat(negative_tokens, dim=0)
                    ea = model.encoder(A).mean(dim=1)
                    ep = model.encoder(P).mean(dim=1)
                    en = model.encoder(N).mean(dim=1)
                    losses.append(triplet(ea, ep, en).item())
            val_loss = sum(losses) / max(1, len(losses))

        out_path = args.export
        if not out_path.startswith("models/"):
            out_path = os.path.join(export_dir, os.path.basename(args.export))
        torch.save({"state_dict": model.state_dict()}, out_path)
        if val_loss is not None:
            print(f"Epoch {epoch+1}/{args.epochs} triplet_loss={loss.item():.4f} val_triplet_loss={val_loss:.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(loss.item()), "val_triplet_loss": float(val_loss)})
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save({"state_dict": model.state_dict()}, os.path.join(export_dir, "vit_outfit_model_best.pth"))
        else:
            print(f"Epoch {epoch+1}/{args.epochs} triplet_loss={loss.item():.4f} saved -> {out_path}")
            hist.append({"epoch": epoch + 1, "triplet_loss": float(loss.item())})

    metrics_path = os.path.join(export_dir, "vit_metrics.json")
    payload = {"best_val_triplet_loss": best_loss if best_loss != float("inf") else None, "history": hist}
    with open(metrics_path, "w") as f:
        json.dump(payload, f)


if __name__ == "__main__":
    main()
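The loss here is a triplet margin over cosine distance; a tiny self-contained check of that construction, with random tensors standing in for pooled outfit embeddings:

import torch
import torch.nn as nn

triplet = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1 - nn.functional.cosine_similarity(x, y),
    margin=0.3,
)
ea, ep, en = (torch.randn(8, 512) for _ in range(3))  # anchor / positive / negative
print(triplet(ea, ep, en))  # scalar tensor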
utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
(empty package marker)
utils/data_fetch.py
ADDED
@@ -0,0 +1,148 @@
import os
import shutil
import zipfile
from pathlib import Path
from typing import Optional, List

import requests

try:
    from huggingface_hub import snapshot_download  # type: ignore
except Exception:  # pragma: no cover
    snapshot_download = None

try:
    import kagglehub  # type: ignore
    from kagglehub import KaggleDatasetAdapter  # type: ignore
except Exception:  # pragma: no cover
    kagglehub = None
    KaggleDatasetAdapter = None


def _download_zip(url: str, dest_dir: str) -> str:
    os.makedirs(dest_dir, exist_ok=True)
    local_zip = os.path.join(dest_dir, "dataset.zip")
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(local_zip, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(dest_dir)
    os.remove(local_zip)
    return dest_dir


def _unzip_inner_archives(root: str) -> None:
    """Find and extract any zip files inside root (e.g., images.zip)."""
    for dirpath, _dirnames, filenames in os.walk(root):
        for fn in filenames:
            if fn.lower().endswith(".zip"):
                zpath = os.path.join(dirpath, fn)
                try:
                    with zipfile.ZipFile(zpath, "r") as zf:
                        zf.extractall(dirpath)
                    # Remove the original zip to save disk (at the cost of re-downloading if needed)
                    try:
                        os.remove(zpath)
                    except Exception:
                        pass
                except Exception as e:  # pragma: no cover
                    print(f"Failed to unzip inner archive {zpath}: {e}")


def _ensure_images_dir(root: str) -> None:
    """Ensure a stable images/ path exists under root. Create a symlink if needed."""
    images_root = os.path.join(root, "images")
    if os.path.isdir(images_root):
        return
    # Try to find a folder with many jpg/png files
    candidate_dirs: List[str] = []
    for dirpath, dirnames, filenames in os.walk(root):
        if dirpath == root:
            # Skip root-level files; look deeper
            continue
        img_files = [f for f in filenames if f.lower().endswith((".jpg", ".jpeg", ".png"))]
        if len(img_files) > 1000:  # heuristic: big image folder
            candidate_dirs.append(dirpath)
    # Prefer the shallowest candidate
    candidate_dirs.sort(key=lambda p: len(Path(p).parts))
    if candidate_dirs:
        src = candidate_dirs[0]
        try:
            os.symlink(src, images_root)
            print(f"Created images symlink: {images_root} -> {src}")
        except Exception:
            # Fallback: create the folder and leave it empty (training will fail fast if images are missing)
            os.makedirs(images_root, exist_ok=True)
    else:
        os.makedirs(images_root, exist_ok=True)


def ensure_dataset_ready() -> Optional[str]:
    """
    Ensure the Polyvore dataset is present locally.
    Priority:
      1) If POLYVORE_ROOT exists and has splits, return it
|
| 88 |
+
2) Try Hugging Face dataset repo (defaults to Stylique/Polyvore if not set)
|
| 89 |
+
3) If DATA_ZIP_URL is set, download and unzip
|
| 90 |
+
4) Try KaggleHub (best-effort)
|
| 91 |
+
Returns resolved root path or None if nothing done.
|
| 92 |
+
"""
|
| 93 |
+
root = os.getenv("POLYVORE_ROOT", "./data/Polyvore")
|
| 94 |
+
auto_fetch = os.getenv("AUTO_FETCH_DATA", "true").lower() == "true"
|
| 95 |
+
Path(root).mkdir(parents=True, exist_ok=True)
|
| 96 |
+
|
| 97 |
+
# Already prepared?
|
| 98 |
+
if os.path.isdir(os.path.join(root, "splits")):
|
| 99 |
+
_unzip_inner_archives(root)
|
| 100 |
+
_ensure_images_dir(root)
|
| 101 |
+
return root
|
| 102 |
+
if not auto_fetch:
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
# Try HF dataset repo
|
| 106 |
+
repo = os.getenv("HF_DATASET_REPO", "Stylique/Polyvore")
|
| 107 |
+
if repo and snapshot_download is not None:
|
| 108 |
+
try:
|
| 109 |
+
snapshot_download(repo, repo_type="dataset", local_dir=root)
|
| 110 |
+
_unzip_inner_archives(root)
|
| 111 |
+
_ensure_images_dir(root)
|
| 112 |
+
# If splits not provided, they'll be prepared by the caller
|
| 113 |
+
return root
|
| 114 |
+
except Exception as e: # pragma: no cover
|
| 115 |
+
print(f"HF dataset download failed: {e}")
|
| 116 |
+
|
| 117 |
+
# Try ZIP URL
|
| 118 |
+
zip_url = os.getenv("DATA_ZIP_URL")
|
| 119 |
+
if zip_url:
|
| 120 |
+
try:
|
| 121 |
+
_download_zip(zip_url, root)
|
| 122 |
+
_unzip_inner_archives(root)
|
| 123 |
+
_ensure_images_dir(root)
|
| 124 |
+
except Exception as e: # pragma: no cover
|
| 125 |
+
print(f"ZIP download failed: {e}")
|
| 126 |
+
return None
|
| 127 |
+
|
| 128 |
+
# Try KaggleHub (no Kaggle keys required for public datasets)
|
| 129 |
+
if kagglehub is not None and KaggleDatasetAdapter is not None:
|
| 130 |
+
try:
|
| 131 |
+
# Attempt to load core file to trigger dataset download locally
|
| 132 |
+
# User can override POLYVORE_FILE_PATH to select a specific CSV/JSON
|
| 133 |
+
file_path = os.getenv("POLYVORE_FILE_PATH", "")
|
| 134 |
+
kagglehub.load_dataset(
|
| 135 |
+
KaggleDatasetAdapter.PANDAS,
|
| 136 |
+
"dnepozitek/polyvore-outfits",
|
| 137 |
+
file_path,
|
| 138 |
+
)
|
| 139 |
+
# KaggleHub stores under ~/.cache/kagglehub/datasets/<slug>/...; copy to root if needed
|
| 140 |
+
# For simplicity, assume user will run prepare script using POLYVORE_ROOT pointing to extracted images
|
| 141 |
+
_unzip_inner_archives(root)
|
| 142 |
+
_ensure_images_dir(root)
|
| 143 |
+
except Exception as e: # pragma: no cover
|
| 144 |
+
print(f"KaggleHub download failed: {e}")
|
| 145 |
+
|
| 146 |
+
return root
|
| 147 |
+
|
| 148 |
+
|
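A typical call site for the helper above, as a sketch; the environment variables are the ones read by ensure_dataset_ready, and the values shown are only examples.

# Sketch: resolve the dataset root before training; env var values are examples.
import os

os.environ.setdefault("POLYVORE_ROOT", "./data/Polyvore")
os.environ.setdefault("HF_DATASET_REPO", "Stylique/Polyvore")
os.environ.setdefault("AUTO_FETCH_DATA", "true")

from utils.data_fetch import ensure_dataset_ready

root = ensure_dataset_ready()
if root is None:
    raise SystemExit("Dataset unavailable and AUTO_FETCH_DATA is disabled")
print(f"Polyvore root resolved to: {root}")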
utils/export.py
ADDED
@@ -0,0 +1,37 @@
import os
from typing import Optional

import torch


def ensure_export_dir(path: str) -> str:
    os.makedirs(path, exist_ok=True)
    return path


def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
    model.eval()
    traced = torch.jit.trace(model, example_inputs)
    torch.jit.save(traced, out_path)
    return out_path


def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str:
    model.eval()
    torch.onnx.export(
        model,
        example_inputs,
        out_path,
        export_params=True,
        opset_version=opset,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )
    return out_path
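A quick usage sketch of the two exporters (the toy module and file names are placeholders, not the repo's models):

# Sketch: export a toy module with the helpers above.
import torch
import torch.nn as nn

from utils.export import ensure_export_dir, export_onnx, export_torchscript

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
example = torch.randn(1, 8)

out_dir = ensure_export_dir("models")
export_torchscript(model, example, f"{out_dir}/toy_model.pt")
export_onnx(model, example, f"{out_dir}/toy_model.onnx")  # batch axis stays dynamic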
utils/transforms.py
ADDED
@@ -0,0 +1,31 @@
from typing import Tuple

import torchvision.transforms as T


def build_inference_transform(image_size: int = 224) -> T.Compose:
    return T.Compose(
        [
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )


def build_train_transforms(image_size: int = 224) -> T.Compose:
    return T.Compose(
        [
            T.Resize(int(image_size * 1.1), interpolation=T.InterpolationMode.BICUBIC),
            T.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(0.75, 1.33)),
            T.RandomHorizontalFlip(),
            T.ColorJitter(0.2, 0.2, 0.2, 0.1),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
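A short usage sketch (the image path is illustrative): the inference transform yields an ImageNet-normalized tensor ready for the embedder.

# Sketch: preprocess one image for inference.
from PIL import Image

from utils.transforms import build_inference_transform

tfm = build_inference_transform(image_size=224)
img = Image.open("item.jpg").convert("RGB")  # illustrative path
x = tfm(img).unsqueeze(0)  # shape (1, 3, 224, 224), ImageNet mean/std normalized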