"""
OutfitTransformerCIR - Complementary Item Retrieval Model
=========================================================

Architecture based on Sarkar et al. with modifications:

- LaBSE instead of BERT for multilingual text encoding
- Set-wise Outfit Ranking Loss instead of InfoNCE

Usage:
    import torch
    from model import OutfitTransformerCIR

    model = OutfitTransformerCIR()
    model.load_state_dict(torch.load("pytorch_model.bin"))
    model.eval()

    # context_images: (B, S, 512) - ResNet-18 features
    # context_texts: (B, S, 768) - LaBSE embeddings
    predicted = model(context_images, context_texts)
    # predicted: (B, 128) - Missing item embedding
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class OutfitTransformerCIR(nn.Module):
    """
    Complementary Item Retrieval Transformer.

    Given context items (a partial outfit), predicts the embedding of a
    missing item that would complete the outfit harmoniously.

    Architecture:
        - Visual projection: 512 (ResNet-18) → 64
        - Text projection: 768 (LaBSE) → 64
        - Combined: 64 + 64 = 128-dim item embedding
        - Transformer encoder: 6 layers, 16 heads
        - Learnable [QUERY] token for missing-item prediction

    Args:
        embedding_dim (int): Final embedding dimension (default: 128)
        nhead (int): Number of attention heads (default: 16)
        num_layers (int): Number of transformer layers (default: 6)
        use_projection (bool): Whether to apply the projection layers.
            - True: inputs are raw features (512 + 768)
            - False: inputs are pre-projected features (64 + 64)
    """
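    # Illustrative sketch of the use_projection=False mode: the Linear
    # projections below are skipped, so callers must pass features already
    # projected to 64 + 64 dims. img64/txt64 are hypothetical pre-projected
    # tensors of shape (B, S, 64):
    #
    #   model = OutfitTransformerCIR(use_projection=False)
    #   emb = model.encode_items(img64, txt64)  # -> (B, S, 128)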
    def __init__(self, embedding_dim=128, nhead=16, num_layers=6, use_projection=True):
        super().__init__()

        self.use_projection = use_projection
        self.embedding_dim = embedding_dim

        # Project each modality into a 64-dim space; the two halves are
        # concatenated into a single 128-dim item embedding.
        self.visual_proj = nn.Linear(512, 64)
        self.text_proj = nn.Linear(768, 64)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=512,
            batch_first=True,
            dropout=0.1,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learnable [QUERY] token, prepended to the item sequence; its output
        # position carries the missing-item prediction.
        self.query_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        self.output_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
        )

    def encode_items(self, images, texts):
        """
        Encode multiple items (the outfit context).

        Args:
            images: (B, S, D_img) where D_img=512 (raw) or 64 (projected)
            texts: (B, S, D_txt) where D_txt=768 (raw) or 64 (projected)

        Returns:
            (B, S, 128) - Unified item embeddings
        """
        if self.use_projection:
            img_emb = self.visual_proj(images)
            txt_emb = self.text_proj(texts)
        else:
            img_emb = images
            txt_emb = texts

        return torch.cat((img_emb, txt_emb), dim=-1)

    def encode_single_item(self, image, text):
        """
        Encode a single item (for candidate scoring).

        Args:
            image: (B, D_img)
            text: (B, D_txt)

        Returns:
            (B, 128) - Item embedding
        """
        if self.use_projection:
            img_emb = self.visual_proj(image)
            txt_emb = self.text_proj(text)
        else:
            img_emb = image
            txt_emb = text

        return torch.cat((img_emb, txt_emb), dim=-1)

    def forward(self, context_images, context_texts, padding_mask=None):
        """
        Predict the embedding of a missing item.

        Args:
            context_images: (B, S, 512) - ResNet-18 features of context items
            context_texts: (B, S, 768) - LaBSE embeddings of context items
            padding_mask: (B, S) - True indicates padding positions

        Returns:
            (B, 128) - Predicted embedding for the missing item

        Example:
            >>> model = OutfitTransformerCIR()
            >>> # Outfit with 3 items: t-shirt, jeans, watch
            >>> img_features = torch.randn(1, 3, 512)  # ResNet-18 outputs
            >>> txt_features = torch.randn(1, 3, 768)  # LaBSE outputs
            >>> predicted = model(img_features, txt_features)
            >>> # predicted: (1, 128) - embedding for an ideal 4th item (e.g., shoes)
        """
        batch_size = context_images.size(0)
        device = context_images.device

        # Fuse the two modalities into per-item embeddings: (B, S, 128)
        item_embeddings = self.encode_items(context_images, context_texts)

        # Prepend the learnable [QUERY] token: (B, S + 1, 128)
        query = self.query_token.expand(batch_size, -1, -1)
        x = torch.cat([query, item_embeddings], dim=1)

        # Extend the padding mask so the [QUERY] position is never masked
        if padding_mask is not None:
            query_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=device)
            full_mask = torch.cat([query_mask, padding_mask], dim=1)
        else:
            full_mask = None

        out = self.transformer(x, src_key_padding_mask=full_mask)

        # The [QUERY] token's output represents the missing item
        query_out = out[:, 0, :]

        # Project and L2-normalize so retrieval can use cosine similarity
        predicted = self.output_proj(query_out)
        predicted = F.normalize(predicted, p=2, dim=-1)

        return predicted


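# The helper below is an illustrative sketch (score_candidates is not part of
# the class above): it shows how forward() and encode_single_item() can be
# combined for candidate scoring. Since forward() L2-normalizes its output,
# candidates are normalized here too, making the dot product a cosine similarity.
def score_candidates(model, context_images, context_texts, cand_images, cand_texts):
    """
    Rank candidate items against a partial outfit (illustrative sketch).

    Args:
        context_images: (B, S, 512), context_texts: (B, S, 768)
        cand_images: (N, 512), cand_texts: (N, 768)

    Returns:
        (B, N) - Cosine similarity of each candidate to the predicted item
    """
    with torch.no_grad():
        predicted = model(context_images, context_texts)  # (B, 128), L2-normalized
        candidates = model.encode_single_item(cand_images, cand_texts)  # (N, 128)
        candidates = F.normalize(candidates, p=2, dim=-1)
        return predicted @ candidates.T  # (B, N)

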
def load_model(checkpoint_path, device="cpu"):
    """
    Load a trained OutfitTransformerCIR model.

    Args:
        checkpoint_path: Path to pytorch_model.bin
        device: "cpu" or "cuda"

    Returns:
        Loaded model in eval mode
    """
    model = OutfitTransformerCIR(
        embedding_dim=128,
        nhead=16,
        num_layers=6,
        use_projection=True,
    )

    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model


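# Minimal smoke test (illustrative; uses randomly initialized weights rather
# than a checkpoint). It also shows how to build the padding mask for
# variable-length outfits: True marks padded positions, here the 3rd slot of
# the 2nd outfit.
if __name__ == "__main__":
    model = OutfitTransformerCIR()
    model.eval()

    context_images = torch.randn(2, 3, 512)  # 2 outfits, up to 3 items each
    context_texts = torch.randn(2, 3, 768)
    padding_mask = torch.tensor([[False, False, False],
                                 [False, False, True]])  # 2nd outfit has 2 items

    with torch.no_grad():
        predicted = model(context_images, context_texts, padding_mask=padding_mask)
    print(predicted.shape)  # torch.Size([2, 128])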