"""
OutfitTransformerCIR - Complementary Item Retrieval Model
==========================================================
Architecture based on the OutfitTransformer of Sarkar et al., with two modifications:
- LaBSE instead of BERT for multilingual text encoding
- Set-wise Outfit Ranking Loss instead of InfoNCE
Usage:
from model import OutfitTransformerCIR
model = OutfitTransformerCIR()
    model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()
# context_images: (B, S, 512) - ResNet-18 features
# context_texts: (B, S, 768) - LaBSE embeddings
predicted = model(context_images, context_texts)
# predicted: (B, 128) - Missing item embedding
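
The prediction can then score candidate items by cosine similarity. A minimal
sketch, assuming hypothetical cand_image (N, 512) and cand_text (N, 768)
tensors; the explicit normalize is this example's choice, since
encode_single_item does not L2-normalize its output:
    cand = model.encode_single_item(cand_image, cand_text)  # (N, 128)
    cand = torch.nn.functional.normalize(cand, p=2, dim=-1)
    scores = cand @ predicted.T  # (N, B) cosine scores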
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class OutfitTransformerCIR(nn.Module):
"""
Complementary Item Retrieval Transformer
Given context items (partial outfit), predicts the embedding of a missing item
that would complete the outfit harmoniously.
Architecture:
- Visual projection: 512 (ResNet-18) → 64
- Text projection: 768 (LaBSE) → 64
- Combined: 64 + 64 = 128 dim item embedding
- Transformer Encoder: 6 layers, 16 heads
- Learnable [QUERY] token for missing item prediction
Args:
embedding_dim (int): Final embedding dimension (default: 128)
nhead (int): Number of attention heads (default: 16)
num_layers (int): Number of transformer layers (default: 6)
use_projection (bool): Whether to apply projection layers.
- True: Input is raw features (512 + 768)
- False: Input is pre-projected features (64 + 64)
"""
def __init__(self, embedding_dim=128, nhead=16, num_layers=6, use_projection=True):
        super().__init__()
self.use_projection = use_projection
self.embedding_dim = embedding_dim
        # Projection layers (trained, not frozen). Created in both modes;
        # unused when use_projection=False (inputs are then pre-projected).
self.visual_proj = nn.Linear(512, 64)
self.text_proj = nn.Linear(768, 64)
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=embedding_dim,
nhead=nhead,
dim_feedforward=512,
batch_first=True,
dropout=0.1
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Learnable query token (represents the missing item)
self.query_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
# Output projection with normalization
self.output_proj = nn.Sequential(
nn.Linear(embedding_dim, embedding_dim),
nn.LayerNorm(embedding_dim)
)
def encode_items(self, images, texts):
"""
Encode multiple items (for context).
Args:
images: (B, S, D_img) where D_img=512 (raw) or 64 (projected)
texts: (B, S, D_txt) where D_txt=768 (raw) or 64 (projected)
Returns:
(B, S, 128) - Unified item embeddings
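        Example (shape check with random stand-in features):
            >>> m = OutfitTransformerCIR()
            >>> m.encode_items(torch.randn(1, 3, 512), torch.randn(1, 3, 768)).shape
            torch.Size([1, 3, 128])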
"""
if self.use_projection:
img_emb = self.visual_proj(images)
txt_emb = self.text_proj(texts)
else:
img_emb = images
txt_emb = texts
return torch.cat((img_emb, txt_emb), dim=-1)
def encode_single_item(self, image, text):
"""
Encode a single item (for candidate scoring).
Args:
image: (B, D_img)
text: (B, D_txt)
Returns:
(B, 128) - Item embedding
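        Note: unlike forward(), the result is not L2-normalized; callers
        scoring candidates by cosine similarity should normalize it first.
        Example (shape check with random stand-in features):
            >>> m = OutfitTransformerCIR()
            >>> m.encode_single_item(torch.randn(4, 512), torch.randn(4, 768)).shape
            torch.Size([4, 128])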
"""
if self.use_projection:
img_emb = self.visual_proj(image)
txt_emb = self.text_proj(text)
else:
img_emb = image
txt_emb = text
return torch.cat((img_emb, txt_emb), dim=-1)
def forward(self, context_images, context_texts, padding_mask=None):
"""
Predict the embedding of a missing item.
Args:
context_images: (B, S, 512) - ResNet-18 features of context items
context_texts: (B, S, 768) - LaBSE embeddings of context items
padding_mask: (B, S) - True indicates padding positions
Returns:
(B, 128) - Predicted embedding for the missing item
Example:
>>> model = OutfitTransformerCIR()
>>> # Outfit with 3 items: t-shirt, jeans, watch
>>> img_features = torch.randn(1, 3, 512) # ResNet-18 outputs
>>> txt_features = torch.randn(1, 3, 768) # LaBSE outputs
>>> predicted = model(img_features, txt_features)
>>> # predicted: (1, 128) - embedding for ideal 4th item (e.g., shoes)
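            >>> # Padded batches: True marks slots to ignore (here, slot 3)
            >>> pad = torch.tensor([[False, False, True]])
            >>> predicted = model(img_features, txt_features, padding_mask=pad)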
"""
batch_size = context_images.size(0)
device = context_images.device
# 1. Encode context items
item_embeddings = self.encode_items(context_images, context_texts)
# 2. Prepend learnable query token
query = self.query_token.expand(batch_size, -1, -1)
x = torch.cat([query, item_embeddings], dim=1)
        # 3. Build key padding mask (the query position is never masked)
if padding_mask is not None:
query_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=device)
full_mask = torch.cat([query_mask, padding_mask], dim=1)
else:
full_mask = None
# 4. Transformer forward
out = self.transformer(x, src_key_padding_mask=full_mask)
# 5. Extract query output (first position)
query_out = out[:, 0, :]
# 6. Project and L2 normalize
predicted = self.output_proj(query_out)
predicted = F.normalize(predicted, p=2, dim=-1)
return predicted
# Convenience function for loading
def load_model(checkpoint_path, device="cpu"):
"""
Load a trained OutfitTransformerCIR model.
Args:
checkpoint_path: Path to pytorch_model.bin
device: "cpu" or "cuda"
Returns:
Loaded model in eval mode
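    Example:
        >>> model = load_model("pytorch_model.bin")          # CPU
        >>> model = load_model("pytorch_model.bin", "cuda")  # GPU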
"""
model = OutfitTransformerCIR(
embedding_dim=128,
nhead=16,
num_layers=6,
use_projection=True
)
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
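

# Quick smoke test with random stand-in features (a sketch only: real inputs
# come from ResNet-18 and LaBSE encoders, which are not part of this file).
if __name__ == "__main__":
    model = OutfitTransformerCIR()
    model.eval()

    # A batch of 2 outfits, each with up to 4 context items
    context_images = torch.randn(2, 4, 512)   # stand-in for ResNet-18 features
    context_texts = torch.randn(2, 4, 768)    # stand-in for LaBSE embeddings
    padding_mask = torch.zeros(2, 4, dtype=torch.bool)
    padding_mask[1, 3] = True                 # second outfit has only 3 items

    with torch.no_grad():
        predicted = model(context_images, context_texts, padding_mask)
    print(predicted.shape)                    # torch.Size([2, 128])
    print(predicted.norm(dim=-1))             # all ~1.0 after L2 normalization

    # Score 5 hypothetical candidates for the first outfit
    with torch.no_grad():
        cand = model.encode_single_item(torch.randn(5, 512), torch.randn(5, 768))
    cand = F.normalize(cand, p=2, dim=-1)     # normalize for cosine scoring
    scores = cand @ predicted[0]              # (5,) cosine scores
    print(scores.argmax().item())             # index of the best candidate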