Outfit Transformer CIR (Complementary Item Retrieval)
A multimodal Transformer model for fashion outfit completion and complementary item retrieval. Given a partial outfit (e.g., a t-shirt and jeans), the model predicts the ideal embedding for a missing item (e.g., shoes) that would complete the outfit harmoniously.
Model Description
This model is based on the OutfitTransformer architecture proposed by Sarkar et al. for outfit recommendation, with several key modifications:
Differences from Original Paper
| Aspect | Original (Sarkar et al.) | This Implementation |
|---|---|---|
| Text Encoder | BERT (768-dim) | LaBSE (768-dim) |
| Text Language | English only | Multilingual (109 languages) |
| Negative Sampling | Random | Hard Negative Mining (same category) |
Why LaBSE instead of BERT?
LaBSE (Language-agnostic BERT Sentence Embedding) was chosen because:
- Multilingual Support: Works with 109 languages, enabling Turkish/English fashion descriptions
- Cross-lingual Alignment: "Mavi tişört" and "blue t-shirt" produce similar embeddings (see the sketch below)
- Same Dimensionality: Still outputs 768-dim vectors, compatible with the original architecture
- Production Ready: Better suited for real-world e-commerce applications
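A quick way to sanity-check the cross-lingual claim is to compare LaBSE embeddings of the two phrases directly. This is a minimal sketch using the Hugging Face transformers API (the exact similarity value will vary):

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/LaBSE")
labse = AutoModel.from_pretrained("sentence-transformers/LaBSE").eval()

def embed(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        return labse(**inputs).pooler_output  # (1, 768)

turkish = embed("Mavi tişört")
english = embed("blue t-shirt")
print(torch.cosine_similarity(turkish, english).item())  # high similarity expected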
Loss Function: Set-wise Outfit Ranking Loss
Following the original paper, we use the Set-wise Outfit Ranking Loss (Section 3.2.2):
L_set = L_all + L_hard
Where:
- L_all: Margin-based ranking over all negatives
- L_hard: Extra penalty on the hardest negative (closest wrong answer)
# L_all: margin-based ranking over all negatives
# pos_dist: (B, 1) distances to the positive, neg_dist: (B, K) distances to the K negatives
diff_all = pos_dist - neg_dist + margin          # margin = 2.0
loss_all = torch.relu(diff_all).mean()

# L_hard: extra penalty on the hardest (closest) negative
min_neg_dist = neg_dist.min(dim=1, keepdim=True).values
diff_hard = pos_dist - min_neg_dist + margin
loss_hard = torch.relu(diff_hard).mean()

total_loss = loss_all + loss_hard
Why this helps:
- InfoNCE folds every negative into a single softmax term, so no individual negative is penalized directly
- The set-wise loss adds an explicit penalty on the hardest negative
- This reduces the hubness problem, where a few popular items dominate retrieval
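Putting both terms together, here is a self-contained sketch of the loss. The Euclidean distance and the tensor shapes are assumptions; adapt them to your own training pipeline:

import torch

def setwise_outfit_ranking_loss(pred, pos, negs, margin=2.0):
    # pred: (B, D) predicted embeddings, pos: (B, D) positive items,
    # negs: (B, K, D) K negative items per outfit
    pos_dist = (pred - pos).norm(dim=-1, keepdim=True)      # (B, 1)
    neg_dist = (pred.unsqueeze(1) - negs).norm(dim=-1)      # (B, K)

    loss_all = torch.relu(pos_dist - neg_dist + margin).mean()   # L_all
    hardest = neg_dist.min(dim=1, keepdim=True).values           # closest wrong answer
    loss_hard = torch.relu(pos_dist - hardest + margin).mean()   # L_hard
    return loss_all + loss_hard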
Architecture
                       OutfitTransformerCIR

  ResNet-18 (Frozen)                 LaBSE (Frozen)
       512-dim                          768-dim
          │                                │
  Visual Proj 512 → 64            Text Proj 768 → 64     ← Trained
          │                                │
          └────────────────┬───────────────┘
                           │
                 Concat: 64 + 64 = 128
                           │
        [QUERY] (Learnable Token) + Item Embeddings
                           │
                  Transformer Encoder
         6 layers, 16 heads, d_model=128, ff=512
                           │
         Output Projection + LayerNorm + L2 Norm
                           │
              128-dim Predicted Embedding
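As a reading aid, here is a minimal, illustrative sketch of that forward pass. Layer names, and details such as how the [QUERY] position is read out or how padding is masked, are assumptions and will not match the released weights; use the provided OutfitTransformerCIR class for actual inference:

import torch
import torch.nn as nn

class OutfitTransformerSketch(nn.Module):
    """Illustrative skeleton of the architecture above; not the shipped class."""
    def __init__(self, embedding_dim=128, nhead=16, num_layers=6, ff_dim=512):
        super().__init__()
        self.visual_proj = nn.Linear(512, embedding_dim // 2)   # ResNet-18 features -> 64
        self.text_proj = nn.Linear(768, embedding_dim // 2)     # LaBSE features -> 64
        self.query_token = nn.Parameter(torch.randn(1, 1, embedding_dim))  # learnable [QUERY]
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=nhead, dim_feedforward=ff_dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.out_proj = nn.Linear(embedding_dim, embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, image_feats, text_feats):
        # image_feats: (B, N, 512), text_feats: (B, N, 768); padding mask omitted here
        items = torch.cat([self.visual_proj(image_feats),
                           self.text_proj(text_feats)], dim=-1)      # (B, N, 128)
        query = self.query_token.expand(items.size(0), -1, -1)       # (B, 1, 128)
        tokens = torch.cat([query, items], dim=1)                    # (B, N+1, 128)
        encoded = self.encoder(tokens)
        pred = self.norm(self.out_proj(encoded[:, 0]))               # read the [QUERY] position
        return nn.functional.normalize(pred, dim=-1)                 # L2-normalized (B, 128)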
Benchmark Results
Evaluated on the Polyvore Outfits dataset (disjoint split):
| Metric | Score |
|---|---|
| FITB Accuracy | 56.39% |
| MRR | 0.7447 |
| Recall@1 | 56.39% |
| Recall@2 | 80.86% |
| Recall@3 | 93.56% |
| NDCG@3 | 0.7818 |
| NDCG@5 | 0.8095 |
Comparison with Baselines
| Model | FITB Accuracy | Notes |
|---|---|---|
| Random | 25.00% | 4-choice task |
| Type-Aware (Vasileva 2018) | ~53% | Category-specific spaces |
| Ours (LaBSE + SetWise) | 56.39% | Multilingual, margin-based |
| Sarkar et al. (reported) | ~57% | English BERT, InfoNCE |
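FITB is scored as a 4-way nearest-neighbour choice: the predicted embedding is compared against four candidate items and the closest one is selected. A minimal sketch of how that accuracy can be computed (tensor names are illustrative; with L2-normalized embeddings, Euclidean distance and cosine similarity give the same ranking):

import torch

def fitb_accuracy(pred, candidates, answer_idx):
    # pred: (B, 128) predicted embeddings, candidates: (B, 4, 128) candidate items,
    # answer_idx: (B,) index of the ground-truth candidate
    dists = (pred.unsqueeze(1) - candidates).norm(dim=-1)  # (B, 4)
    choice = dists.argmin(dim=1)                           # nearest candidate wins
    return (choice == answer_idx).float().mean().item()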
Usage
Installation
pip install torch torchvision transformers
Loading the Model
import torch
from model import OutfitTransformerCIR
# Load model
model = OutfitTransformerCIR(embedding_dim=128, nhead=16, num_layers=6)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()
Inference Example
# Assume you have pre-extracted features:
# context_images: (1, num_items, 512) - ResNet-18 features
# context_texts: (1, num_items, 768) - LaBSE embeddings
with torch.no_grad():
    # Predict the missing-item embedding
    predicted_embedding = model(context_images, context_texts)  # (1, 128)

# Use cosine similarity to find the closest items in your database
# item_database: (N, 128) tensor of L2-normalized item embeddings
similarities = torch.cosine_similarity(predicted_embedding, item_database)  # (N,)
top_matches = similarities.argsort(descending=True)[:10]
Feature Extraction (for your own items)
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModel
from PIL import Image
import torch
import torch.nn as nn
# Image encoder (ResNet-18)
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet = nn.Sequential(*list(resnet.children())[:-1])
resnet.eval()
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Text encoder (LaBSE)
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/LaBSE")
labse = AutoModel.from_pretrained("sentence-transformers/LaBSE")
labse.eval()
def extract_features(image_path, text_description):
    # Image: 512-dim ResNet-18 features
    image = Image.open(image_path).convert('RGB')
    img_tensor = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        img_features = resnet(img_tensor).flatten(1)  # (1, 512)

    # Text: 768-dim LaBSE embedding
    inputs = tokenizer(text_description, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        txt_features = labse(**inputs).pooler_output  # (1, 768)

    return img_features, txt_features
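Putting the pieces together, a sketch that completes a hypothetical two-item outfit end to end, reusing extract_features and the model loaded above (the file paths and descriptions are placeholders):

# Partial outfit: a t-shirt and jeans; we want an embedding for the missing item.
items = [("tshirt.jpg", "blue t-shirt"), ("jeans.jpg", "slim-fit jeans")]

img_feats, txt_feats = [], []
for path, desc in items:
    img, txt = extract_features(path, desc)   # (1, 512), (1, 768)
    img_feats.append(img)
    txt_feats.append(txt)

context_images = torch.stack(img_feats, dim=1)  # (1, num_items, 512)
context_texts = torch.stack(txt_feats, dim=1)   # (1, num_items, 768)

with torch.no_grad():
    predicted_embedding = model(context_images, context_texts)  # (1, 128)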
Training Details
| Hyperparameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning Rate | 1e-5 |
| Weight Decay | 0.01 |
| Batch Size | 64 |
| Epochs | 50 |
| Warmup Epochs | 2 |
| LR Scheduler | StepLR (step=10, gamma=0.5) |
| Margin (loss) | 2.0 |
| Num Negatives | 10 |
| Hard Negative Ratio | 50% (same category; see the sketch below) |
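The Hard Negative Ratio row means that roughly half of the 10 negatives per outfit are drawn from the same category as the target item, and the rest at random. A sketch of that sampling logic (the data structures are assumptions):

import random

def sample_negatives(target_id, target_category, items_by_category, all_item_ids,
                     num_negatives=10, hard_ratio=0.5):
    # Hard negatives: items from the same category as the target (e.g., other shoes)
    same_category = [i for i in items_by_category[target_category] if i != target_id]
    hard = random.sample(same_category, min(int(num_negatives * hard_ratio), len(same_category)))
    # Easy negatives: random items from the full catalogue
    pool = [i for i in all_item_ids if i != target_id and i not in hard]
    easy = random.sample(pool, num_negatives - len(hard))
    return hard + easy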
Training Data
- Dataset: Polyvore Outfits (Maryland split, disjoint)
- Train: ~17K outfits, ~250K items
- Validation: ~2K outfits
- Test: ~3K outfits
Limitations
- Fixed Item Length: Model expects max 8 items per outfit (padding applied)
- Frozen Encoders: ResNet-18 and LaBSE are frozen during training
- Hubness: Some popular items may dominate retrieval (mitigated with CSLS, Cross-domain Similarity Local Scaling; see the sketch after this list)
- Fashion Domain: Trained on Polyvore data, may not generalize to other domains
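CSLS rescales cosine similarity by how "hub-like" each database item is (its average similarity to its own k nearest neighbours), which demotes items that sit close to everything. A minimal sketch over L2-normalized embeddings (k and the variable names are illustrative):

import torch

def csls_scores(query_emb, item_embs, k=10):
    # query_emb: (1, D), item_embs: (N, D); both L2-normalized
    sims = query_emb @ item_embs.T                                     # (1, N) cosine similarities
    item_sims = item_embs @ item_embs.T                                # (N, N)
    # average similarity of each item to its k nearest neighbours (skip self-similarity)
    r_items = item_sims.topk(k + 1, dim=1).values[:, 1:].mean(dim=1)   # (N,)
    r_query = sims.topk(k, dim=1).values.mean(dim=1, keepdim=True)     # (1, 1)
    return 2 * sims - r_query - r_items.unsqueeze(0)                   # higher is better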
Citation
If you use this model, please cite:
@misc{outfit-cir-transformer,
author = {Kuyumcu, Furkan},
title = {Outfit Transformer CIR: Multilingual Complementary Item Retrieval},
year = {2026},
publisher = {Hugging Face},
url = {https://huggingface.co/fkuyumcu/outfit-cir-transformer}
}
Original Paper Reference
@inproceedings{sarkar2022outfitbert,
title={OutfitTransformer: Learning Outfit Representations for Fashion Recommendation},
author={Sarkar, Rohan and others},
booktitle={CVPR Workshop on Computer Vision for Fashion, Art, and Design},
year={2022}
}
License
MIT License